1use crate::{
2 agent::{Agent, MultiTurnStreamItem, Text},
3 completion::{Chat, CompletionError, CompletionModel, PromptError, Usage},
4 message::Message,
5 streaming::{StreamedAssistantContent, StreamingPrompt},
6};
7use futures::StreamExt;
8use std::io::{self, Write};
9
/// Typestate marker for a [`ChatBotBuilder`] that has not been given a chat
/// backend yet.
///
/// It carries no data; the common traits are derived so the marker can be
/// printed, copied and default-constructed like any other unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoImplProvided;
11
12pub struct ChatImpl<T>(T)
13where
14 T: Chat;
15
16pub struct AgentImpl<M>
17where
18 M: CompletionModel + 'static,
19{
20 agent: Agent<M>,
21 multi_turn_depth: usize,
22 show_usage: bool,
23 usage: Usage,
24}
25
/// Builder for a [`ChatBot`]; the type parameter is the backend selected so
/// far (or [`NoImplProvided`] while none has been chosen).
pub struct ChatBotBuilder<Impl>(Impl);
27
/// Interactive command-line chatbot wrapping the backend that answers each
/// prompt.
pub struct ChatBot<Backend>(Backend);
29
30#[allow(private_interfaces)]
32trait CliChat {
33 async fn request(&mut self, prompt: &str, history: Vec<Message>)
34 -> Result<String, PromptError>;
35
36 fn show_usage(&self) -> bool {
37 false
38 }
39
40 fn usage(&self) -> Option<Usage> {
41 None
42 }
43}
44
45impl<T> CliChat for ChatImpl<T>
46where
47 T: Chat,
48{
49 async fn request(
50 &mut self,
51 prompt: &str,
52 history: Vec<Message>,
53 ) -> Result<String, PromptError> {
54 let res = self.0.chat(prompt, history).await?;
55 println!("{res}");
56
57 Ok(res)
58 }
59}
60
61impl<M> CliChat for AgentImpl<M>
62where
63 M: CompletionModel + 'static,
64{
65 async fn request(
66 &mut self,
67 prompt: &str,
68 history: Vec<Message>,
69 ) -> Result<String, PromptError> {
70 let mut response_stream = self
71 .agent
72 .stream_prompt(prompt)
73 .with_history(history)
74 .multi_turn(self.multi_turn_depth)
75 .await;
76
77 let mut acc = String::new();
78
79 loop {
80 let Some(chunk) = response_stream.next().await else {
81 break Ok(acc);
82 };
83
84 match chunk {
85 Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text {
86 text,
87 }))) => {
88 print!("{}", text);
89 acc.push_str(&text);
90 }
91 Ok(MultiTurnStreamItem::FinalResponse(final_response)) => {
92 self.usage = final_response.usage();
93 }
94 Err(e) => {
95 break Err(PromptError::CompletionError(
96 CompletionError::ResponseError(e.to_string()),
97 ));
98 }
99 _ => continue,
100 }
101 }
102 }
103
104 fn show_usage(&self) -> bool {
105 self.show_usage
106 }
107
108 fn usage(&self) -> Option<Usage> {
109 Some(self.usage)
110 }
111}
112
113impl Default for ChatBotBuilder<NoImplProvided> {
114 fn default() -> Self {
115 Self(NoImplProvided)
116 }
117}
118
119impl ChatBotBuilder<NoImplProvided> {
120 pub fn new() -> Self {
121 Self::default()
122 }
123
124 pub fn agent<M: CompletionModel + 'static>(
125 self,
126 agent: Agent<M>,
127 ) -> ChatBotBuilder<AgentImpl<M>> {
128 ChatBotBuilder(AgentImpl {
129 agent,
130 multi_turn_depth: 1,
131 show_usage: false,
132 usage: Usage::default(),
133 })
134 }
135
136 pub fn chat<T: Chat>(self, chatbot: T) -> ChatBotBuilder<ChatImpl<T>> {
137 ChatBotBuilder(ChatImpl(chatbot))
138 }
139}
140
141impl<T> ChatBotBuilder<ChatImpl<T>>
142where
143 T: Chat,
144{
145 pub fn build(self) -> ChatBot<ChatImpl<T>> {
146 let ChatBotBuilder(chat_impl) = self;
147 ChatBot(chat_impl)
148 }
149}
150
151impl<M> ChatBotBuilder<AgentImpl<M>>
152where
153 M: CompletionModel + 'static,
154{
155 pub fn multi_turn_depth(self, multi_turn_depth: usize) -> Self {
156 ChatBotBuilder(AgentImpl {
157 multi_turn_depth,
158 ..self.0
159 })
160 }
161
162 pub fn show_usage(self) -> Self {
163 ChatBotBuilder(AgentImpl {
164 show_usage: true,
165 ..self.0
166 })
167 }
168
169 pub fn build(self) -> ChatBot<AgentImpl<M>> {
170 ChatBot(self.0)
171 }
172}
173
174#[allow(private_bounds)]
175impl<T> ChatBot<T>
176where
177 T: CliChat,
178{
179 pub async fn run(mut self) -> Result<(), PromptError> {
180 let stdin = io::stdin();
181 let mut stdout = io::stdout();
182 let mut history = vec![];
183
184 loop {
185 print!("> ");
186 stdout.flush().unwrap();
187
188 let mut input = String::new();
189 match stdin.read_line(&mut input) {
190 Ok(_) => {
191 let input = input.trim();
192 if input == "exit" {
193 break;
194 }
195
196 tracing::info!("Prompt:\n{input}\n");
197
198 println!();
199 println!("========================== Response ============================");
200
201 let response = self.0.request(input, history.clone()).await?;
202 history.push(Message::user(input));
203 history.push(Message::assistant(response));
204
205 println!("================================================================");
206 println!();
207
208 if self.0.show_usage() {
209 let Usage {
210 input_tokens,
211 output_tokens,
212 ..
213 } = self.0.usage().unwrap();
214 println!("Input {input_tokens} tokens\nOutput {output_tokens} tokens");
215 }
216 }
217 Err(e) => println!("Error reading request: {e}"),
218 }
219 }
220
221 Ok(())
222 }
223}