// rig/integrations/cli_chatbot.rs

use std::io::{self, Write};

use futures::StreamExt;

use crate::{
    agent::{Agent, MultiTurnStreamItem, Text},
    completion::{Chat, CompletionError, CompletionModel, PromptError, Usage},
    markers::{Missing, Provided},
    message::Message,
    streaming::{StreamedAssistantContent, StreamingPrompt},
    wasm_compat::WasmCompatSend,
};
11
/// Adapter that lets any [`Chat`] implementor back a [`ChatBot`]; replies are
/// printed in one piece once the whole answer is available.
///
/// The `T: Chat` requirement is stated on the impls that actually need it
/// rather than on the struct definition, so the wrapper itself places no
/// bound on `T`.
pub struct ChatImpl<T>(T);
15
16pub struct AgentImpl<M>
17where
18 M: CompletionModel + 'static,
19{
20 agent: Agent<M>,
21 max_turns: usize,
22 show_usage: bool,
23 usage: Usage,
24}
25
26pub struct ChatBotBuilder<T = Missing>(T);
27
/// An interactive command-line chat session wrapping backend adapter `T`.
///
/// Obtain one from [`ChatBotBuilder`] and start it with `run`.
pub struct ChatBot<T>(T);
29
30#[allow(private_interfaces)]
32trait CliChat {
33 async fn request(&mut self, prompt: &str, history: Vec<Message>)
34 -> Result<String, PromptError>;
35
36 fn show_usage(&self) -> bool {
37 false
38 }
39
40 fn usage(&self) -> Option<Usage> {
41 None
42 }
43}
44
45impl<T> CliChat for ChatImpl<T>
46where
47 T: Chat,
48{
49 async fn request(
50 &mut self,
51 prompt: &str,
52 history: Vec<Message>,
53 ) -> Result<String, PromptError> {
54 let res = self.0.chat(prompt, &history).await?;
55 println!("{res}");
56
57 Ok(res)
58 }
59}
60
61impl<M> CliChat for AgentImpl<M>
62where
63 M: CompletionModel + WasmCompatSend + 'static,
64{
65 async fn request(
66 &mut self,
67 prompt: &str,
68 history: Vec<Message>,
69 ) -> Result<String, PromptError> {
70 let mut response_stream = self
71 .agent
72 .stream_prompt(prompt)
73 .with_history(&history)
74 .multi_turn(self.max_turns)
75 .await;
76
77 let mut acc = String::new();
78
79 loop {
80 let Some(chunk) = response_stream.next().await else {
81 break Ok(acc);
82 };
83
84 match chunk {
85 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
86 Text { text },
87 ))) => {
88 print!("{}", text);
89 acc.push_str(&text);
90 }
91 Ok(MultiTurnStreamItem::FinalResponse(final_response)) => {
92 self.usage = final_response.usage();
93 }
94 Err(e) => {
95 break Err(PromptError::CompletionError(
96 CompletionError::ResponseError(e.to_string()),
97 ));
98 }
99 _ => continue,
100 }
101 }
102 }
103
104 fn show_usage(&self) -> bool {
105 self.show_usage
106 }
107
108 fn usage(&self) -> Option<Usage> {
109 Some(self.usage)
110 }
111}
112
113impl Default for ChatBotBuilder<Missing> {
114 fn default() -> Self {
115 Self(Missing)
116 }
117}
118
119impl ChatBotBuilder<Missing> {
120 pub fn new() -> Self {
121 Self::default()
122 }
123
124 pub fn agent<M: CompletionModel + 'static>(
125 self,
126 agent: Agent<M>,
127 ) -> ChatBotBuilder<Provided<AgentImpl<M>>> {
128 ChatBotBuilder(Provided(AgentImpl {
129 agent,
130 max_turns: 1,
131 show_usage: false,
132 usage: Usage::default(),
133 }))
134 }
135
136 pub fn chat<T: Chat>(self, chatbot: T) -> ChatBotBuilder<Provided<ChatImpl<T>>> {
137 ChatBotBuilder(Provided(ChatImpl(chatbot)))
138 }
139}
140
141impl<T> ChatBotBuilder<Provided<ChatImpl<T>>>
142where
143 T: Chat,
144{
145 pub fn build(self) -> ChatBot<ChatImpl<T>> {
146 ChatBot(self.0.0)
147 }
148}
149
150impl<M> ChatBotBuilder<Provided<AgentImpl<M>>>
151where
152 M: CompletionModel + 'static,
153{
154 pub fn max_turns(self, max_turns: usize) -> Self {
155 ChatBotBuilder(Provided(AgentImpl {
156 max_turns,
157 ..self.0.0
158 }))
159 }
160
161 pub fn show_usage(self) -> Self {
162 ChatBotBuilder(Provided(AgentImpl {
163 show_usage: true,
164 ..self.0.0
165 }))
166 }
167
168 pub fn build(self) -> ChatBot<AgentImpl<M>> {
169 ChatBot(self.0.0)
170 }
171}
172
173#[allow(private_bounds)]
174impl<T> ChatBot<T>
175where
176 T: CliChat,
177{
178 pub async fn run(mut self) -> Result<(), PromptError> {
179 let stdin = io::stdin();
180 let mut stdout = io::stdout();
181 let mut history = vec![];
182
183 loop {
184 print!("> ");
185 stdout.flush().map_err(|e| {
186 PromptError::CompletionError(CompletionError::ResponseError(format!(
187 "failed to flush stdout: {e}"
188 )))
189 })?;
190
191 let mut input = String::new();
192 match stdin.read_line(&mut input) {
193 Ok(_) => {
194 let input = input.trim();
195 if input == "exit" {
196 break;
197 }
198
199 tracing::info!("Prompt:\n{input}\n");
200
201 println!();
202 println!("========================== Response ============================");
203
204 let response = self.0.request(input, history.clone()).await?;
205 history.push(Message::user(input));
206 history.push(Message::assistant(response));
207
208 println!("================================================================");
209 println!();
210
211 if self.0.show_usage()
212 && let Some(Usage {
213 input_tokens,
214 output_tokens,
215 ..
216 }) = self.0.usage()
217 {
218 println!("Input {input_tokens} tokens\nOutput {output_tokens} tokens");
219 }
220 }
221 Err(e) => println!("Error reading request: {e}"),
222 }
223 }
224
225 Ok(())
226 }
227}