rig_core/integrations/
cli_chatbot.rs1use crate::{
2 agent::{Agent, MultiTurnStreamItem, Text},
3 completion::{Chat, CompletionError, CompletionModel, PromptError, Usage},
4 markers::{Missing, Provided},
5 message::Message,
6 streaming::{StreamedAssistantContent, StreamingPrompt},
7 wasm_compat::WasmCompatSend,
8};
9use futures::StreamExt;
10use std::io::{self, Write};
11
12pub struct ChatImpl<T>(T)
13where
14 T: Chat;
15
16pub struct AgentImpl<M>
17where
18 M: CompletionModel + 'static,
19{
20 agent: Agent<M>,
21 max_turns: usize,
22 show_usage: bool,
23 usage: Usage,
24}
25
26pub struct ChatBotBuilder<T = Missing>(T);
27
28pub struct ChatBot<T>(T);
29
30#[allow(private_interfaces)]
32trait CliChat {
33 async fn request(
34 &mut self,
35 prompt: &str,
36 history: &mut Vec<Message>,
37 ) -> Result<String, PromptError>;
38
39 fn show_usage(&self) -> bool {
40 false
41 }
42
43 fn usage(&self) -> Option<Usage> {
44 None
45 }
46}
47
48impl<T> CliChat for ChatImpl<T>
49where
50 T: Chat,
51{
52 async fn request(
53 &mut self,
54 prompt: &str,
55 history: &mut Vec<Message>,
56 ) -> Result<String, PromptError> {
57 let res = self.0.chat(prompt, history).await?;
58 println!("{res}");
59
60 Ok(res)
61 }
62}
63
64impl<M> CliChat for AgentImpl<M>
65where
66 M: CompletionModel + WasmCompatSend + 'static,
67{
68 async fn request(
69 &mut self,
70 prompt: &str,
71 history: &mut Vec<Message>,
72 ) -> Result<String, PromptError> {
73 let mut response_stream = self
74 .agent
75 .stream_prompt(prompt)
76 .with_history(history.clone())
77 .multi_turn(self.max_turns)
78 .await;
79
80 let mut acc = String::new();
81 let mut messages = None;
82
83 let result = loop {
84 let Some(chunk) = response_stream.next().await else {
85 println!();
86 break Ok(acc);
87 };
88
89 match chunk {
90 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
91 Text { text },
92 ))) => {
93 print!("{}", text);
94 acc.push_str(&text);
95 }
96 Ok(MultiTurnStreamItem::FinalResponse(final_response)) => {
97 self.usage = final_response.usage();
98 messages = final_response.history().map(|history| history.to_vec());
99 }
100 Err(e) => {
101 break Err(PromptError::CompletionError(
102 CompletionError::ResponseError(e.to_string()),
103 ));
104 }
105 _ => continue,
106 }
107 };
108
109 if let Ok(response) = &result {
110 if let Some(messages) = messages {
111 history.extend(messages);
112 } else {
113 history.push(Message::user(prompt));
114 history.push(Message::assistant(response.as_str()));
115 }
116 }
117
118 result
119 }
120
121 fn show_usage(&self) -> bool {
122 self.show_usage
123 }
124
125 fn usage(&self) -> Option<Usage> {
126 Some(self.usage)
127 }
128}
129
130impl Default for ChatBotBuilder<Missing> {
131 fn default() -> Self {
132 Self(Missing)
133 }
134}
135
136impl ChatBotBuilder<Missing> {
137 pub fn new() -> Self {
138 Self::default()
139 }
140
141 pub fn agent<M: CompletionModel + 'static>(
142 self,
143 agent: Agent<M>,
144 ) -> ChatBotBuilder<Provided<AgentImpl<M>>> {
145 ChatBotBuilder(Provided(AgentImpl {
146 agent,
147 max_turns: 1,
148 show_usage: false,
149 usage: Usage::default(),
150 }))
151 }
152
153 pub fn chat<T: Chat>(self, chatbot: T) -> ChatBotBuilder<Provided<ChatImpl<T>>> {
154 ChatBotBuilder(Provided(ChatImpl(chatbot)))
155 }
156}
157
158impl<T> ChatBotBuilder<Provided<ChatImpl<T>>>
159where
160 T: Chat,
161{
162 pub fn build(self) -> ChatBot<ChatImpl<T>> {
163 ChatBot(self.0.0)
164 }
165}
166
167impl<M> ChatBotBuilder<Provided<AgentImpl<M>>>
168where
169 M: CompletionModel + 'static,
170{
171 pub fn max_turns(self, max_turns: usize) -> Self {
172 ChatBotBuilder(Provided(AgentImpl {
173 max_turns,
174 ..self.0.0
175 }))
176 }
177
178 pub fn show_usage(self) -> Self {
179 ChatBotBuilder(Provided(AgentImpl {
180 show_usage: true,
181 ..self.0.0
182 }))
183 }
184
185 pub fn build(self) -> ChatBot<AgentImpl<M>> {
186 ChatBot(self.0.0)
187 }
188}
189
190#[allow(private_bounds)]
191impl<T> ChatBot<T>
192where
193 T: CliChat,
194{
195 pub async fn run(mut self) -> Result<(), PromptError> {
196 let stdin = io::stdin();
197 let mut stdout = io::stdout();
198 let mut history = vec![];
199
200 loop {
201 print!("> ");
202 stdout.flush().map_err(|e| {
203 PromptError::CompletionError(CompletionError::ResponseError(format!(
204 "failed to flush stdout: {e}"
205 )))
206 })?;
207
208 let mut input = String::new();
209 match stdin.read_line(&mut input) {
210 Ok(_) => {
211 let input = input.trim();
212 if input == "exit" {
213 break;
214 }
215
216 tracing::info!("Prompt:\n{input}\n");
217
218 println!();
219 println!("========================== Response ============================");
220
221 self.0.request(input, &mut history).await?;
222
223 println!("================================================================");
224 println!();
225
226 if self.0.show_usage()
227 && let Some(Usage {
228 input_tokens,
229 output_tokens,
230 ..
231 }) = self.0.usage()
232 {
233 println!("Input {input_tokens} tokens\nOutput {output_tokens} tokens");
234 }
235 }
236 Err(e) => println!("Error reading request: {e}"),
237 }
238 }
239
240 Ok(())
241 }
242}