1use std::io::{self, Write};
2
3use futures::StreamExt;
4
5use crate::{
6 agent::{Agent, prompt_request::streaming::MultiTurnStreamItem},
7 completion::{Chat, CompletionError, CompletionModel, Message, PromptError},
8 streaming::StreamingPrompt,
9};
10
/// Type-state marker meaning "no agent has been supplied yet".
///
/// A [`ChatbotBuilder`] starts out parameterised by this type. It only becomes
/// buildable once an agent replaces it via `ChatbotBuilder::agent`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AgentNotSet;
13
/// Type-state builder for a [`Chatbot`].
///
/// The type parameter `A` records whether an agent has been provided. It is
/// [`AgentNotSet`] until `agent(...)` is called, after which it is an
/// [`Agent`], and only then is `build()` available.
pub struct ChatbotBuilder<A> {
    /// Agent slot: either the `AgentNotSet` marker or a concrete `Agent<M>`.
    agent: A,
    /// Value forwarded to the streaming request's `multi_turn(...)` call.
    multi_turn_depth: usize,
    /// When `true`, token usage is printed after every response.
    show_usage: bool,
}
26
27impl Default for ChatbotBuilder<AgentNotSet> {
28 fn default() -> Self {
29 ChatbotBuilder {
30 agent: AgentNotSet,
31 multi_turn_depth: 0,
32 show_usage: false,
33 }
34 }
35}
36
37impl ChatbotBuilder<AgentNotSet> {
38 pub fn new() -> Self {
39 Default::default()
40 }
41
42 pub fn agent<M>(self, agent: Agent<M>) -> ChatbotBuilder<Agent<M>>
44 where
45 M: CompletionModel + 'static,
46 {
47 ChatbotBuilder {
48 agent,
49 multi_turn_depth: self.multi_turn_depth,
50 show_usage: self.show_usage,
51 }
52 }
53}
54
55impl<A> ChatbotBuilder<A> {
56 pub fn show_usage(self) -> Self {
59 Self {
60 show_usage: true,
61 ..self
62 }
63 }
64
65 pub fn multi_turn_depth(self, multi_turn_depth: usize) -> Self {
67 Self {
68 multi_turn_depth,
69 ..self
70 }
71 }
72}
73
74impl<M> ChatbotBuilder<Agent<M>>
75where
76 M: CompletionModel + 'static,
77{
78 pub fn build(self) -> Chatbot<M> {
80 Chatbot {
81 agent: self.agent,
82 multi_turn_depth: self.multi_turn_depth,
83 show_usage: self.show_usage,
84 }
85 }
86}
87
88pub struct Chatbot<M>
97where
98 M: CompletionModel + 'static,
99{
100 agent: Agent<M>,
101 multi_turn_depth: usize,
102 show_usage: bool,
103}
104
105impl<M> Chatbot<M>
106where
107 M: CompletionModel + 'static,
108{
109 pub async fn run(self) -> Result<(), PromptError> {
110 let stdin = io::stdin();
111 let mut stdout = io::stdout();
112 let mut chat_log = vec![];
113
114 println!("Welcome to the chatbot! Type 'exit' to quit.");
115
116 loop {
117 print!("> ");
118 stdout.flush().unwrap();
120
121 let mut input = String::new();
122 match stdin.read_line(&mut input) {
123 Ok(_) => {
124 let input = input.trim();
126
127 if input.is_empty() {
128 continue;
129 }
130
131 if input == "exit" {
133 break;
134 }
135
136 tracing::info!("Prompt:\n{}\n", input);
137
138 let mut usage = None;
139 let mut response = String::new();
140
141 println!();
142 println!("========================== Response ============================");
143
144 let mut stream_response = self
145 .agent
146 .stream_prompt(input)
147 .with_history(chat_log.clone())
148 .multi_turn(self.multi_turn_depth)
149 .await;
150
151 while let Some(chunk) = stream_response.next().await {
152 match chunk {
153 Ok(MultiTurnStreamItem::Text(s)) => {
154 let text = s.text.as_str();
155 print!("{text}");
156 response.push_str(text);
157 }
158 Ok(MultiTurnStreamItem::FinalResponse(r)) => {
159 if self.show_usage {
160 usage = Some(r.usage());
161 }
162 }
163
164 Err(e) => {
165 return Err(PromptError::CompletionError(
166 CompletionError::ResponseError(e.to_string()),
167 ));
168 }
169 }
170 }
171
172 println!("================================================================");
173 println!();
174
175 chat_log.push(Message::user(input));
177 chat_log.push(Message::assistant(response.clone()));
178
179 if let Some(usage) = usage {
180 println!(
181 "Input: {} tokens\nOutput: {} tokens",
182 usage.input_tokens, usage.output_tokens
183 )
184 }
185
186 tracing::info!("Response:\n{}\n", response);
187 }
188 Err(error) => println!("Error reading input: {error}"),
189 }
190 }
191
192 Ok(())
193 }
194}
195
196pub async fn cli_chatbot(chatbot: impl Chat) -> Result<(), PromptError> {
201 let stdin = io::stdin();
202 let mut stdout = io::stdout();
203 let mut chat_log = vec![];
204
205 println!("Welcome to the chatbot! Type 'exit' to quit.");
206 loop {
207 print!("> ");
208 stdout.flush().unwrap();
210
211 let mut input = String::new();
212 match stdin.read_line(&mut input) {
213 Ok(_) => {
214 let input = input.trim();
216 if input == "exit" {
218 break;
219 }
220 tracing::info!("Prompt:\n{}\n", input);
221
222 let response = chatbot.chat(input, chat_log.clone()).await?;
223 chat_log.push(Message::user(input));
224 chat_log.push(Message::assistant(response.clone()));
225
226 println!("========================== Response ============================");
227 println!("{response}");
228 println!("================================================================\n\n");
229
230 tracing::info!("Response:\n{}\n", response);
231 }
232 Err(error) => println!("Error reading input: {error}"),
233 }
234 }
235
236 Ok(())
237}