rig/
cli_chatbot.rs

1use crate::{
2    agent::{Agent, MultiTurnStreamItem, Text},
3    completion::{Chat, CompletionError, CompletionModel, PromptError, Usage},
4    message::Message,
5    streaming::{StreamedAssistantContent, StreamingPrompt},
6};
7use futures::StreamExt;
8use std::io::{self, Write};
9
/// Type-state marker for a [`ChatBotBuilder`] that has not yet been given a
/// backend (neither an agent nor a chat implementation).
///
/// The marker carries no data, so the common traits are derived to keep it
/// printable and freely copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoImplProvided;
11
12pub struct ChatImpl<T>(T)
13where
14    T: Chat;
15
16pub struct AgentImpl<M>
17where
18    M: CompletionModel + 'static,
19{
20    agent: Agent<M>,
21    multi_turn_depth: usize,
22    show_usage: bool,
23    usage: Usage,
24}
25
/// Type-state builder for a [`ChatBot`].
///
/// The type parameter records which backend (if any) has been selected so
/// far; backend-specific configuration methods are only available once the
/// matching backend has been chosen.
///
/// `Debug` is derived so the builder can be inspected whenever its backend
/// type is itself `Debug`.
#[derive(Debug)]
pub struct ChatBotBuilder<T>(T);
27
/// A ready-to-run command-line chatbot wrapping a configured backend.
///
/// `Debug` is derived so the chatbot can be inspected whenever its backend
/// type is itself `Debug`.
#[derive(Debug)]
pub struct ChatBot<T>(T);
29
/// Trait to abstract message behavior away from cli_chat/`run` loop
///
/// Each backend decides how a prompt is sent and how the reply is obtained;
/// the loop only needs the final response text and, optionally, token usage.
// The trait is private but is used as a bound on public items, hence the lint
// allowance below.
#[allow(private_interfaces)]
trait CliChat {
    /// Sends `prompt` together with the prior conversation `history` and
    /// returns the complete response text.
    ///
    /// The implementations in this file also print the response to stdout as
    /// part of this call.
    ///
    /// # Errors
    ///
    /// Returns a [`PromptError`] when the backend fails to produce a response.
    async fn request(&mut self, prompt: &str, history: Vec<Message>)
    -> Result<String, PromptError>;

    /// Whether token usage should be printed after each response.
    ///
    /// Defaults to `false`.
    fn show_usage(&self) -> bool {
        false
    }

    /// Token usage to report, if the backend tracks it.
    ///
    /// Defaults to `None`.
    fn usage(&self) -> Option<Usage> {
        None
    }
}
44
45impl<T> CliChat for ChatImpl<T>
46where
47    T: Chat,
48{
49    async fn request(
50        &mut self,
51        prompt: &str,
52        history: Vec<Message>,
53    ) -> Result<String, PromptError> {
54        let res = self.0.chat(prompt, history).await?;
55        println!("{res}");
56
57        Ok(res)
58    }
59}
60
61impl<M> CliChat for AgentImpl<M>
62where
63    M: CompletionModel + 'static,
64{
65    async fn request(
66        &mut self,
67        prompt: &str,
68        history: Vec<Message>,
69    ) -> Result<String, PromptError> {
70        let mut response_stream = self
71            .agent
72            .stream_prompt(prompt)
73            .with_history(history)
74            .multi_turn(self.multi_turn_depth)
75            .await;
76
77        let mut acc = String::new();
78
79        loop {
80            let Some(chunk) = response_stream.next().await else {
81                break Ok(acc);
82            };
83
84            match chunk {
85                Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text {
86                    text,
87                }))) => {
88                    print!("{}", text);
89                    acc.push_str(&text);
90                }
91                Ok(MultiTurnStreamItem::FinalResponse(final_response)) => {
92                    self.usage = final_response.usage();
93                }
94                Err(e) => {
95                    break Err(PromptError::CompletionError(
96                        CompletionError::ResponseError(e.to_string()),
97                    ));
98                }
99                _ => continue,
100            }
101        }
102    }
103
104    fn show_usage(&self) -> bool {
105        self.show_usage
106    }
107
108    fn usage(&self) -> Option<Usage> {
109        Some(self.usage)
110    }
111}
112
113impl Default for ChatBotBuilder<NoImplProvided> {
114    fn default() -> Self {
115        Self(NoImplProvided)
116    }
117}
118
119impl ChatBotBuilder<NoImplProvided> {
120    pub fn new() -> Self {
121        Self::default()
122    }
123
124    pub fn agent<M: CompletionModel + 'static>(
125        self,
126        agent: Agent<M>,
127    ) -> ChatBotBuilder<AgentImpl<M>> {
128        ChatBotBuilder(AgentImpl {
129            agent,
130            multi_turn_depth: 1,
131            show_usage: false,
132            usage: Usage::default(),
133        })
134    }
135
136    pub fn chat<T: Chat>(self, chatbot: T) -> ChatBotBuilder<ChatImpl<T>> {
137        ChatBotBuilder(ChatImpl(chatbot))
138    }
139}
140
141impl<T> ChatBotBuilder<ChatImpl<T>>
142where
143    T: Chat,
144{
145    pub fn build(self) -> ChatBot<ChatImpl<T>> {
146        let ChatBotBuilder(chat_impl) = self;
147        ChatBot(chat_impl)
148    }
149}
150
151impl<M> ChatBotBuilder<AgentImpl<M>>
152where
153    M: CompletionModel + 'static,
154{
155    pub fn multi_turn_depth(self, multi_turn_depth: usize) -> Self {
156        ChatBotBuilder(AgentImpl {
157            multi_turn_depth,
158            ..self.0
159        })
160    }
161
162    pub fn show_usage(self) -> Self {
163        ChatBotBuilder(AgentImpl {
164            show_usage: true,
165            ..self.0
166        })
167    }
168
169    pub fn build(self) -> ChatBot<AgentImpl<M>> {
170        ChatBot(self.0)
171    }
172}
173
174#[allow(private_bounds)]
175impl<T> ChatBot<T>
176where
177    T: CliChat,
178{
179    pub async fn run(mut self) -> Result<(), PromptError> {
180        let stdin = io::stdin();
181        let mut stdout = io::stdout();
182        let mut history = vec![];
183
184        loop {
185            print!("> ");
186            stdout.flush().unwrap();
187
188            let mut input = String::new();
189            match stdin.read_line(&mut input) {
190                Ok(_) => {
191                    let input = input.trim();
192                    if input == "exit" {
193                        break;
194                    }
195
196                    tracing::info!("Prompt:\n{input}\n");
197
198                    println!();
199                    println!("========================== Response ============================");
200
201                    let response = self.0.request(input, history.clone()).await?;
202                    history.push(Message::user(input));
203                    history.push(Message::assistant(response));
204
205                    println!("================================================================");
206                    println!();
207
208                    if self.0.show_usage() {
209                        let Usage {
210                            input_tokens,
211                            output_tokens,
212                            ..
213                        } = self.0.usage().unwrap();
214                        println!("Input {input_tokens} tokens\nOutput {output_tokens} tokens");
215                    }
216                }
217                Err(e) => println!("Error reading request: {e}"),
218            }
219        }
220
221        Ok(())
222    }
223}