Skip to main content

rig/integrations/
cli_chatbot.rs

1use crate::{
2    agent::{Agent, MultiTurnStreamItem, Text},
3    completion::{Chat, CompletionError, CompletionModel, PromptError, Usage},
4    markers::{Missing, Provided},
5    message::Message,
6    streaming::{StreamedAssistantContent, StreamingPrompt},
7    wasm_compat::WasmCompatSend,
8};
9use futures::StreamExt;
10use std::io::{self, Write};
11
12pub struct ChatImpl<T>(T)
13where
14    T: Chat;
15
16pub struct AgentImpl<M>
17where
18    M: CompletionModel + 'static,
19{
20    agent: Agent<M>,
21    max_turns: usize,
22    show_usage: bool,
23    usage: Usage,
24}
25
26pub struct ChatBotBuilder<T = Missing>(T);
27
/// An interactive command-line chatbot wrapping a single backend value.
pub struct ChatBot<Backend>(Backend);
29
30/// Trait to abstract message behavior away from cli_chat/`run` loop
31#[allow(private_interfaces)]
32trait CliChat {
33    async fn request(&mut self, prompt: &str, history: Vec<Message>)
34    -> Result<String, PromptError>;
35
36    fn show_usage(&self) -> bool {
37        false
38    }
39
40    fn usage(&self) -> Option<Usage> {
41        None
42    }
43}
44
45impl<T> CliChat for ChatImpl<T>
46where
47    T: Chat,
48{
49    async fn request(
50        &mut self,
51        prompt: &str,
52        history: Vec<Message>,
53    ) -> Result<String, PromptError> {
54        let res = self.0.chat(prompt, &history).await?;
55        println!("{res}");
56
57        Ok(res)
58    }
59}
60
61impl<M> CliChat for AgentImpl<M>
62where
63    M: CompletionModel + WasmCompatSend + 'static,
64{
65    async fn request(
66        &mut self,
67        prompt: &str,
68        history: Vec<Message>,
69    ) -> Result<String, PromptError> {
70        let mut response_stream = self
71            .agent
72            .stream_prompt(prompt)
73            .with_history(&history)
74            .multi_turn(self.max_turns)
75            .await;
76
77        let mut acc = String::new();
78
79        loop {
80            let Some(chunk) = response_stream.next().await else {
81                break Ok(acc);
82            };
83
84            match chunk {
85                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
86                    Text { text },
87                ))) => {
88                    print!("{}", text);
89                    acc.push_str(&text);
90                }
91                Ok(MultiTurnStreamItem::FinalResponse(final_response)) => {
92                    self.usage = final_response.usage();
93                }
94                Err(e) => {
95                    break Err(PromptError::CompletionError(
96                        CompletionError::ResponseError(e.to_string()),
97                    ));
98                }
99                _ => continue,
100            }
101        }
102    }
103
104    fn show_usage(&self) -> bool {
105        self.show_usage
106    }
107
108    fn usage(&self) -> Option<Usage> {
109        Some(self.usage)
110    }
111}
112
113impl Default for ChatBotBuilder<Missing> {
114    fn default() -> Self {
115        Self(Missing)
116    }
117}
118
119impl ChatBotBuilder<Missing> {
120    pub fn new() -> Self {
121        Self::default()
122    }
123
124    pub fn agent<M: CompletionModel + 'static>(
125        self,
126        agent: Agent<M>,
127    ) -> ChatBotBuilder<Provided<AgentImpl<M>>> {
128        ChatBotBuilder(Provided(AgentImpl {
129            agent,
130            max_turns: 1,
131            show_usage: false,
132            usage: Usage::default(),
133        }))
134    }
135
136    pub fn chat<T: Chat>(self, chatbot: T) -> ChatBotBuilder<Provided<ChatImpl<T>>> {
137        ChatBotBuilder(Provided(ChatImpl(chatbot)))
138    }
139}
140
141impl<T> ChatBotBuilder<Provided<ChatImpl<T>>>
142where
143    T: Chat,
144{
145    pub fn build(self) -> ChatBot<ChatImpl<T>> {
146        ChatBot(self.0.0)
147    }
148}
149
150impl<M> ChatBotBuilder<Provided<AgentImpl<M>>>
151where
152    M: CompletionModel + 'static,
153{
154    pub fn max_turns(self, max_turns: usize) -> Self {
155        ChatBotBuilder(Provided(AgentImpl {
156            max_turns,
157            ..self.0.0
158        }))
159    }
160
161    pub fn show_usage(self) -> Self {
162        ChatBotBuilder(Provided(AgentImpl {
163            show_usage: true,
164            ..self.0.0
165        }))
166    }
167
168    pub fn build(self) -> ChatBot<AgentImpl<M>> {
169        ChatBot(self.0.0)
170    }
171}
172
173#[allow(private_bounds)]
174impl<T> ChatBot<T>
175where
176    T: CliChat,
177{
178    pub async fn run(mut self) -> Result<(), PromptError> {
179        let stdin = io::stdin();
180        let mut stdout = io::stdout();
181        let mut history = vec![];
182
183        loop {
184            print!("> ");
185            stdout.flush().map_err(|e| {
186                PromptError::CompletionError(CompletionError::ResponseError(format!(
187                    "failed to flush stdout: {e}"
188                )))
189            })?;
190
191            let mut input = String::new();
192            match stdin.read_line(&mut input) {
193                Ok(_) => {
194                    let input = input.trim();
195                    if input == "exit" {
196                        break;
197                    }
198
199                    tracing::info!("Prompt:\n{input}\n");
200
201                    println!();
202                    println!("========================== Response ============================");
203
204                    let response = self.0.request(input, history.clone()).await?;
205                    history.push(Message::user(input));
206                    history.push(Message::assistant(response));
207
208                    println!("================================================================");
209                    println!();
210
211                    if self.0.show_usage()
212                        && let Some(Usage {
213                            input_tokens,
214                            output_tokens,
215                            ..
216                        }) = self.0.usage()
217                    {
218                        println!("Input {input_tokens} tokens\nOutput {output_tokens} tokens");
219                    }
220                }
221                Err(e) => println!("Error reading request: {e}"),
222            }
223        }
224
225        Ok(())
226    }
227}