Skip to main content

rig_core/integrations/
cli_chatbot.rs

1use crate::{
2    agent::{Agent, MultiTurnStreamItem, Text},
3    completion::{Chat, CompletionError, CompletionModel, PromptError, Usage},
4    markers::{Missing, Provided},
5    message::Message,
6    streaming::{StreamedAssistantContent, StreamingPrompt},
7    wasm_compat::WasmCompatSend,
8};
9use futures::StreamExt;
10use std::io::{self, Write};
11
12pub struct ChatImpl<T>(T)
13where
14    T: Chat;
15
16pub struct AgentImpl<M>
17where
18    M: CompletionModel + 'static,
19{
20    agent: Agent<M>,
21    max_turns: usize,
22    show_usage: bool,
23    usage: Usage,
24}
25
/// Type-state builder for [`ChatBot`].
///
/// The type parameter records whether a backend has been chosen yet: it
/// starts out as `Missing` and becomes `Provided<..>` once `agent` or `chat`
/// has been called, which is what makes `build` available.
pub struct ChatBotBuilder<T = Missing>(T);
27
/// An interactive command-line chatbot wrapping a backend `T`.
///
/// Created through [`ChatBotBuilder`]; `run` is available when `T`
/// implements the private `CliChat` trait.
pub struct ChatBot<T>(T);
29
30/// Trait to abstract message behavior away from cli_chat/`run` loop
31#[allow(private_interfaces)]
32trait CliChat {
33    async fn request(
34        &mut self,
35        prompt: &str,
36        history: &mut Vec<Message>,
37    ) -> Result<String, PromptError>;
38
39    fn show_usage(&self) -> bool {
40        false
41    }
42
43    fn usage(&self) -> Option<Usage> {
44        None
45    }
46}
47
48impl<T> CliChat for ChatImpl<T>
49where
50    T: Chat,
51{
52    async fn request(
53        &mut self,
54        prompt: &str,
55        history: &mut Vec<Message>,
56    ) -> Result<String, PromptError> {
57        let res = self.0.chat(prompt, history).await?;
58        println!("{res}");
59
60        Ok(res)
61    }
62}
63
64impl<M> CliChat for AgentImpl<M>
65where
66    M: CompletionModel + WasmCompatSend + 'static,
67{
68    async fn request(
69        &mut self,
70        prompt: &str,
71        history: &mut Vec<Message>,
72    ) -> Result<String, PromptError> {
73        let mut response_stream = self
74            .agent
75            .stream_prompt(prompt)
76            .with_history(history.clone())
77            .multi_turn(self.max_turns)
78            .await;
79
80        let mut acc = String::new();
81        let mut messages = None;
82
83        let result = loop {
84            let Some(chunk) = response_stream.next().await else {
85                println!();
86                break Ok(acc);
87            };
88
89            match chunk {
90                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
91                    Text { text },
92                ))) => {
93                    print!("{}", text);
94                    acc.push_str(&text);
95                }
96                Ok(MultiTurnStreamItem::FinalResponse(final_response)) => {
97                    self.usage = final_response.usage();
98                    messages = final_response.history().map(|history| history.to_vec());
99                }
100                Err(e) => {
101                    break Err(PromptError::CompletionError(
102                        CompletionError::ResponseError(e.to_string()),
103                    ));
104                }
105                _ => continue,
106            }
107        };
108
109        if let Ok(response) = &result {
110            if let Some(messages) = messages {
111                history.extend(messages);
112            } else {
113                history.push(Message::user(prompt));
114                history.push(Message::assistant(response.as_str()));
115            }
116        }
117
118        result
119    }
120
121    fn show_usage(&self) -> bool {
122        self.show_usage
123    }
124
125    fn usage(&self) -> Option<Usage> {
126        Some(self.usage)
127    }
128}
129
130impl Default for ChatBotBuilder<Missing> {
131    fn default() -> Self {
132        Self(Missing)
133    }
134}
135
136impl ChatBotBuilder<Missing> {
137    pub fn new() -> Self {
138        Self::default()
139    }
140
141    pub fn agent<M: CompletionModel + 'static>(
142        self,
143        agent: Agent<M>,
144    ) -> ChatBotBuilder<Provided<AgentImpl<M>>> {
145        ChatBotBuilder(Provided(AgentImpl {
146            agent,
147            max_turns: 1,
148            show_usage: false,
149            usage: Usage::default(),
150        }))
151    }
152
153    pub fn chat<T: Chat>(self, chatbot: T) -> ChatBotBuilder<Provided<ChatImpl<T>>> {
154        ChatBotBuilder(Provided(ChatImpl(chatbot)))
155    }
156}
157
158impl<T> ChatBotBuilder<Provided<ChatImpl<T>>>
159where
160    T: Chat,
161{
162    pub fn build(self) -> ChatBot<ChatImpl<T>> {
163        ChatBot(self.0.0)
164    }
165}
166
167impl<M> ChatBotBuilder<Provided<AgentImpl<M>>>
168where
169    M: CompletionModel + 'static,
170{
171    pub fn max_turns(self, max_turns: usize) -> Self {
172        ChatBotBuilder(Provided(AgentImpl {
173            max_turns,
174            ..self.0.0
175        }))
176    }
177
178    pub fn show_usage(self) -> Self {
179        ChatBotBuilder(Provided(AgentImpl {
180            show_usage: true,
181            ..self.0.0
182        }))
183    }
184
185    pub fn build(self) -> ChatBot<AgentImpl<M>> {
186        ChatBot(self.0.0)
187    }
188}
189
190#[allow(private_bounds)]
191impl<T> ChatBot<T>
192where
193    T: CliChat,
194{
195    pub async fn run(mut self) -> Result<(), PromptError> {
196        let stdin = io::stdin();
197        let mut stdout = io::stdout();
198        let mut history = vec![];
199
200        loop {
201            print!("> ");
202            stdout.flush().map_err(|e| {
203                PromptError::CompletionError(CompletionError::ResponseError(format!(
204                    "failed to flush stdout: {e}"
205                )))
206            })?;
207
208            let mut input = String::new();
209            match stdin.read_line(&mut input) {
210                Ok(_) => {
211                    let input = input.trim();
212                    if input == "exit" {
213                        break;
214                    }
215
216                    tracing::info!("Prompt:\n{input}\n");
217
218                    println!();
219                    println!("========================== Response ============================");
220
221                    self.0.request(input, &mut history).await?;
222
223                    println!("================================================================");
224                    println!();
225
226                    if self.0.show_usage()
227                        && let Some(Usage {
228                            input_tokens,
229                            output_tokens,
230                            ..
231                        }) = self.0.usage()
232                    {
233                        println!("Input {input_tokens} tokens\nOutput {output_tokens} tokens");
234                    }
235                }
236                Err(e) => println!("Error reading request: {e}"),
237            }
238        }
239
240        Ok(())
241    }
242}