Skip to main content

asimov_runner/programs/
prompter.rs

1// This is free and unencumbered software released into the public domain.
2
3use crate::{Executor, ExecutorError, Input, TextOutput};
4use async_trait::async_trait;
5use derive_more::Debug;
6use std::{
7    ffi::OsStr,
8    io::{Cursor, Read},
9    process::Stdio,
10};
11use tokio::io::{AsyncRead, AsyncWrite};
12
13pub use asimov_patterns::PrompterOptions;
14pub use asimov_prompt::{Prompt, PromptMessage, PromptRole};
15
16pub type PrompterResult = std::result::Result<String, ExecutorError>;
17
18/// See: https://asimov-specs.github.io/program-patterns/#prompter
19#[allow(unused)]
20#[derive(Debug)]
21pub struct Prompter {
22    executor: Executor,
23    options: PrompterOptions,
24    input: Prompt,
25    output: TextOutput,
26}
27
28impl Prompter {
29    pub fn new(
30        program: impl AsRef<OsStr>,
31        input: Prompt,
32        output: TextOutput,
33        options: PrompterOptions,
34    ) -> Self {
35        let mut executor = Executor::new(program);
36
37        executor
38            .command()
39            .args(if let Some(ref input) = options.input {
40                vec![format!("--input={}", input)]
41            } else {
42                vec![]
43            })
44            .args(if let Some(ref output) = options.output {
45                vec![format!("--output={}", output)]
46            } else {
47                vec![]
48            })
49            .args(if let Some(ref model) = options.model {
50                vec![format!("--model={}", model)]
51            } else {
52                vec![]
53            })
54            .args(&options.other)
55            .stdin(Stdio::piped())
56            .stdout(Stdio::piped())
57            .stderr(Stdio::piped());
58
59        Self {
60            executor,
61            options,
62            input,
63            output,
64        }
65    }
66
67    pub async fn execute(&mut self) -> PrompterResult {
68        let mut process = self.executor.spawn().await?;
69
70        let prompt = self.input.clone();
71        let mut stdin = process.stdin.take().expect("should capture stdin");
72        tokio::spawn(async move {
73            use tokio::io::AsyncWriteExt;
74            stdin
75                .write_all(prompt.to_string().as_bytes())
76                .await
77                .expect("should write to stdin");
78        });
79
80        let mut stdout = self.executor.wait(process).await?;
81        let mut result = String::new();
82        stdout.read_to_string(&mut result)?;
83
84        Ok(result)
85    }
86}
87
88impl asimov_patterns::Prompter<String, ExecutorError> for Prompter {}
89
90#[async_trait]
91impl asimov_patterns::Execute<String, ExecutorError> for Prompter {
92    async fn execute(&mut self) -> PrompterResult {
93        self.execute().await
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use asimov_patterns::Execute;

    /// Builds a prompt consisting of a single user message.
    fn user_prompt(text: &'static str) -> Prompt {
        Prompt::builder()
            .messages(vec![PromptMessage(PromptRole::User, text.into())])
            .build()
    }

    #[tokio::test]
    async fn test_execute() {
        let mut prompter = Prompter::new(
            "cat",
            user_prompt("Hello, world!"),
            TextOutput::Ignored,
            PrompterOptions::default(),
        );
        // `expect` reports the underlying error on failure, unlike `assert!(is_ok())`.
        let output = prompter
            .execute()
            .await
            .expect("`cat` should echo the rendered prompt");
        assert_eq!(output, "user: Hello, world!\n");
    }

    #[tokio::test]
    async fn test_execute_tolerates_unread_stdin() {
        // `true` exits without reading stdin, so writing the prompt may hit a
        // broken pipe; that must not be reported as a failure.
        let mut prompter = Prompter::new(
            "true",
            user_prompt("Hello, world!"),
            TextOutput::Ignored,
            PrompterOptions::default(),
        );
        // Go through the trait to also cover the `Execute` delegation.
        let output = Execute::execute(&mut prompter)
            .await
            .expect("a program that ignores stdin should still succeed");
        assert_eq!(output, "");
    }
}