asimov_runner/
provider.rs

1// This is free and unencumbered software released into the public domain.
2
3use crate::{Runner, RunnerError};
4use async_trait::async_trait;
5use std::{ffi::OsStr, io::Read, process::Stdio};
6
7pub use asimov_patterns::{Prompt, PromptMessage, PromptRole, ProviderOptions};
8
9pub type ProviderResult = std::result::Result<String, RunnerError>;
10
11/// LLM inference provider. Consumes text input, produces text output.
12#[derive(Debug)]
13pub struct Provider {
14    runner: Runner,
15    #[allow(unused)]
16    options: ProviderOptions,
17}
18
19impl Provider {
20    pub fn new(program: impl AsRef<OsStr>, options: ProviderOptions) -> Self {
21        let mut runner = Runner::new(program);
22
23        runner
24            .command()
25            .stdin(Stdio::piped())
26            .stdout(Stdio::piped())
27            .stderr(Stdio::piped());
28
29        Self { runner, options }
30    }
31}
32
33impl asimov_patterns::Provider<String, RunnerError> for Provider {}
34
35#[async_trait]
36impl asimov_patterns::Execute<String, RunnerError> for Provider {
37    async fn execute(&mut self) -> ProviderResult {
38        let mut process = self.runner.spawn().await?;
39
40        let prompt = self.options.prompt.clone();
41        let mut stdin = process.stdin.take().expect("Failed to capture stdin");
42        tokio::spawn(async move {
43            use tokio::io::AsyncWriteExt;
44            stdin
45                .write_all(prompt.to_string().as_bytes())
46                .await
47                .expect("Failed to write to stdin");
48        });
49
50        let mut stdout = self.runner.wait(process).await?;
51        let mut result = String::new();
52        stdout.read_to_string(&mut result)?;
53
54        Ok(result)
55    }
56}
57
#[cfg(test)]
mod tests {
    use super::*;
    use asimov_patterns::Execute;

    /// Piping the prompt through `cat` must echo it back unchanged.
    #[tokio::test]
    async fn test_execute() {
        let options = ProviderOptions {
            prompt: "Hello, world!".into(),
        };
        let mut provider = Provider::new("cat", options);

        let output = provider.execute().await;
        assert!(output.is_ok());
        assert_eq!(output.unwrap(), String::from("Hello, world!"));
    }
}