rig/integrations/
cli_chatbot.rs1use crate::{
2 agent::{Agent, MultiTurnStreamItem, Text},
3 completion::{Chat, CompletionError, CompletionModel, PromptError, Usage},
4 message::Message,
5 streaming::{StreamedAssistantContent, StreamingPrompt},
6 wasm_compat::WasmCompatSend,
7};
8use futures::StreamExt;
9use std::io::{self, Write};
10
/// Type-state marker for a [`ChatBotBuilder`] that has no chat backend selected yet.
///
/// Choose a backend with the builder's `agent` or `chat` method before calling `build`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoImplProvided;
12
/// Non-streaming chat backend that wraps any [`Chat`] implementation.
///
/// The `Chat` bound lives on the impls that actually use the wrapped value rather than
/// on the struct definition itself (Rust API guidelines, C-STRUCT-BOUNDS), so the type
/// does not force the bound onto every place that merely names it.
pub struct ChatImpl<T>(T);
16
17pub struct AgentImpl<M>
18where
19 M: CompletionModel + 'static,
20{
21 agent: Agent<M>,
22 multi_turn_depth: usize,
23 show_usage: bool,
24 usage: Usage,
25}
26
/// Type-state builder for [`ChatBot`]; the type parameter records which backend
/// (if any) has been chosen so far.
#[must_use = "a ChatBotBuilder does nothing until `build` is called"]
pub struct ChatBotBuilder<T>(T);

/// Interactive command-line chatbot produced by [`ChatBotBuilder`]; start it with `run`.
#[must_use = "a ChatBot does nothing until `run` is awaited"]
pub struct ChatBot<T>(T);
30
31#[allow(private_interfaces)]
33trait CliChat {
34 async fn request(&mut self, prompt: &str, history: Vec<Message>)
35 -> Result<String, PromptError>;
36
37 fn show_usage(&self) -> bool {
38 false
39 }
40
41 fn usage(&self) -> Option<Usage> {
42 None
43 }
44}
45
46impl<T> CliChat for ChatImpl<T>
47where
48 T: Chat,
49{
50 async fn request(
51 &mut self,
52 prompt: &str,
53 history: Vec<Message>,
54 ) -> Result<String, PromptError> {
55 let res = self.0.chat(prompt, history).await?;
56 println!("{res}");
57
58 Ok(res)
59 }
60}
61
62impl<M> CliChat for AgentImpl<M>
63where
64 M: CompletionModel + WasmCompatSend + 'static,
65{
66 async fn request(
67 &mut self,
68 prompt: &str,
69 history: Vec<Message>,
70 ) -> Result<String, PromptError> {
71 let mut response_stream = self
72 .agent
73 .stream_prompt(prompt)
74 .with_history(history)
75 .multi_turn(self.multi_turn_depth)
76 .await;
77
78 let mut acc = String::new();
79
80 loop {
81 let Some(chunk) = response_stream.next().await else {
82 break Ok(acc);
83 };
84
85 match chunk {
86 Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text {
87 text,
88 }))) => {
89 print!("{}", text);
90 acc.push_str(&text);
91 }
92 Ok(MultiTurnStreamItem::FinalResponse(final_response)) => {
93 self.usage = final_response.usage();
94 }
95 Err(e) => {
96 break Err(PromptError::CompletionError(
97 CompletionError::ResponseError(e.to_string()),
98 ));
99 }
100 _ => continue,
101 }
102 }
103 }
104
105 fn show_usage(&self) -> bool {
106 self.show_usage
107 }
108
109 fn usage(&self) -> Option<Usage> {
110 Some(self.usage)
111 }
112}
113
114impl Default for ChatBotBuilder<NoImplProvided> {
115 fn default() -> Self {
116 Self(NoImplProvided)
117 }
118}
119
120impl ChatBotBuilder<NoImplProvided> {
121 pub fn new() -> Self {
122 Self::default()
123 }
124
125 pub fn agent<M: CompletionModel + 'static>(
126 self,
127 agent: Agent<M>,
128 ) -> ChatBotBuilder<AgentImpl<M>> {
129 ChatBotBuilder(AgentImpl {
130 agent,
131 multi_turn_depth: 1,
132 show_usage: false,
133 usage: Usage::default(),
134 })
135 }
136
137 pub fn chat<T: Chat>(self, chatbot: T) -> ChatBotBuilder<ChatImpl<T>> {
138 ChatBotBuilder(ChatImpl(chatbot))
139 }
140}
141
142impl<T> ChatBotBuilder<ChatImpl<T>>
143where
144 T: Chat,
145{
146 pub fn build(self) -> ChatBot<ChatImpl<T>> {
147 let ChatBotBuilder(chat_impl) = self;
148 ChatBot(chat_impl)
149 }
150}
151
152impl<M> ChatBotBuilder<AgentImpl<M>>
153where
154 M: CompletionModel + 'static,
155{
156 pub fn multi_turn_depth(self, multi_turn_depth: usize) -> Self {
157 ChatBotBuilder(AgentImpl {
158 multi_turn_depth,
159 ..self.0
160 })
161 }
162
163 pub fn show_usage(self) -> Self {
164 ChatBotBuilder(AgentImpl {
165 show_usage: true,
166 ..self.0
167 })
168 }
169
170 pub fn build(self) -> ChatBot<AgentImpl<M>> {
171 ChatBot(self.0)
172 }
173}
174
175#[allow(private_bounds)]
176impl<T> ChatBot<T>
177where
178 T: CliChat,
179{
180 pub async fn run(mut self) -> Result<(), PromptError> {
181 let stdin = io::stdin();
182 let mut stdout = io::stdout();
183 let mut history = vec![];
184
185 loop {
186 print!("> ");
187 stdout.flush().unwrap();
188
189 let mut input = String::new();
190 match stdin.read_line(&mut input) {
191 Ok(_) => {
192 let input = input.trim();
193 if input == "exit" {
194 break;
195 }
196
197 tracing::info!("Prompt:\n{input}\n");
198
199 println!();
200 println!("========================== Response ============================");
201
202 let response = self.0.request(input, history.clone()).await?;
203 history.push(Message::user(input));
204 history.push(Message::assistant(response));
205
206 println!("================================================================");
207 println!();
208
209 if self.0.show_usage() {
210 let Usage {
211 input_tokens,
212 output_tokens,
213 ..
214 } = self.0.usage().unwrap();
215 println!("Input {input_tokens} tokens\nOutput {output_tokens} tokens");
216 }
217 }
218 Err(e) => println!("Error reading request: {e}"),
219 }
220 }
221
222 Ok(())
223 }
224}