rig/
cli_chatbot.rs

1use std::io::{self, Write};
2
3use futures::StreamExt;
4
5use crate::{
6    agent::{Agent, Text, prompt_request::streaming::MultiTurnStreamItem},
7    completion::{Chat, CompletionError, CompletionModel, Message, PromptError},
8    streaming::{StreamedAssistantContent, StreamingPrompt},
9};
10
/// Type-state marker for a [`ChatbotBuilder`] whose `agent` field has not been set yet.
///
/// A builder in this state has no `build` method; supply an agent with
/// `ChatbotBuilder::agent` first.
#[derive(Debug, Clone, Copy, Default)]
pub struct AgentNotSet;
13
/// Builder for a CLI [`Chatbot`].
///
/// The type parameter `A` tracks whether an agent has been supplied. The builder
/// starts as `ChatbotBuilder<AgentNotSet>`. It only gains a `build` method once
/// `agent` has turned it into a `ChatbotBuilder<Agent<M>>`.
///
/// # Example
/// ```rust,ignore
/// let chatbot = ChatbotBuilder::new().agent(my_agent).show_usage().build();
///
/// chatbot.run().await?;
/// ```
pub struct ChatbotBuilder<A> {
    /// The agent that will drive the chatbot, or [`AgentNotSet`] before one is supplied.
    agent: A,
    /// Maximum multi-turn depth (i.e. tool calls) forwarded to the built chatbot.
    multi_turn_depth: usize,
    /// Whether the built chatbot prints token usage after each response.
    show_usage: bool,
}
26
27impl Default for ChatbotBuilder<AgentNotSet> {
28    fn default() -> Self {
29        ChatbotBuilder {
30            agent: AgentNotSet,
31            multi_turn_depth: 0,
32            show_usage: false,
33        }
34    }
35}
36
37impl ChatbotBuilder<AgentNotSet> {
38    pub fn new() -> Self {
39        Default::default()
40    }
41
42    /// Sets the agent that will be used to drive the CLI interface
43    pub fn agent<M>(self, agent: Agent<M>) -> ChatbotBuilder<Agent<M>>
44    where
45        M: CompletionModel + 'static,
46    {
47        ChatbotBuilder {
48            agent,
49            multi_turn_depth: self.multi_turn_depth,
50            show_usage: self.show_usage,
51        }
52    }
53}
54
55impl<A> ChatbotBuilder<A> {
56    /// Sets the `show_usage` flag, so that after a request the number of tokens
57    /// in the input and output will be printed
58    pub fn show_usage(self) -> Self {
59        Self {
60            show_usage: true,
61            ..self
62        }
63    }
64
65    /// Sets the maximum depth for multi-turn, i.e. toolcalls
66    pub fn multi_turn_depth(self, multi_turn_depth: usize) -> Self {
67        Self {
68            multi_turn_depth,
69            ..self
70        }
71    }
72}
73
74impl<M> ChatbotBuilder<Agent<M>>
75where
76    M: CompletionModel + 'static,
77{
78    /// Consumes the `ChatbotBuilder`, returning a `Chatbot` which can be run
79    pub fn build(self) -> Chatbot<M> {
80        Chatbot {
81            agent: self.agent,
82            multi_turn_depth: self.multi_turn_depth,
83            show_usage: self.show_usage,
84        }
85    }
86}
87
88/// A CLI chatbot.
89/// Only takes [Agent] types unlike [cli_chatbot] which takes any `impl Chat` type.
90///
91/// # Example
92/// ```rust
93/// let chatbot = ChatbotBuilder::new().agent(my_agent).show_usage().build();
94///
95/// chatbot.run().await?;
96pub struct Chatbot<M>
97where
98    M: CompletionModel + 'static,
99{
100    agent: Agent<M>,
101    multi_turn_depth: usize,
102    show_usage: bool,
103}
104
105impl<M> Chatbot<M>
106where
107    M: CompletionModel + 'static,
108{
109    pub async fn run(self) -> Result<(), PromptError> {
110        let stdin = io::stdin();
111        let mut stdout = io::stdout();
112        let mut chat_log = vec![];
113
114        println!("Welcome to the chatbot! Type 'exit' to quit.");
115
116        loop {
117            print!("> ");
118            // Flush stdout to ensure the prompt appears before input
119            stdout.flush().unwrap();
120
121            let mut input = String::new();
122            match stdin.read_line(&mut input) {
123                Ok(_) => {
124                    // Remove the newline character from the input
125                    let input = input.trim();
126
127                    if input.is_empty() {
128                        continue;
129                    }
130
131                    // Check for a command to exit
132                    if input == "exit" {
133                        break;
134                    }
135
136                    tracing::info!("Prompt:\n{}\n", input);
137
138                    let mut usage = None;
139                    let mut response = String::new();
140
141                    println!();
142                    println!("========================== Response ============================");
143
144                    let mut stream_response = self
145                        .agent
146                        .stream_prompt(input)
147                        .with_history(chat_log.clone())
148                        .multi_turn(self.multi_turn_depth)
149                        .await;
150
151                    while let Some(chunk) = stream_response.next().await {
152                        match chunk {
153                            Ok(MultiTurnStreamItem::StreamItem(
154                                StreamedAssistantContent::Text(Text { text }),
155                            )) => {
156                                print!("{text}");
157                                response.push_str(&text);
158                            }
159                            Ok(MultiTurnStreamItem::FinalResponse(r)) => {
160                                if self.show_usage {
161                                    usage = Some(r.usage());
162                                }
163                            }
164                            Err(e) => {
165                                return Err(PromptError::CompletionError(
166                                    CompletionError::ResponseError(e.to_string()),
167                                ));
168                            }
169                            _ => {}
170                        }
171                    }
172
173                    println!("================================================================");
174                    println!();
175
176                    // `with_history` does not push to history, we have handle that
177                    chat_log.push(Message::user(input));
178                    chat_log.push(Message::assistant(response.clone()));
179
180                    if let Some(usage) = usage {
181                        println!(
182                            "Input: {} tokens\nOutput: {} tokens",
183                            usage.input_tokens, usage.output_tokens
184                        )
185                    }
186
187                    tracing::info!("Response:\n{}\n", response);
188                }
189                Err(error) => println!("Error reading input: {error}"),
190            }
191        }
192
193        Ok(())
194    }
195}
196
197/// Utility function to create a simple REPL CLI chatbot from a type that implements the
198/// `Chat` trait.
199///
200/// Where the [Chatbot] type takes an agent, this takes any type that implements the [Chat] trait.
201pub async fn cli_chatbot(chatbot: impl Chat) -> Result<(), PromptError> {
202    let stdin = io::stdin();
203    let mut stdout = io::stdout();
204    let mut chat_log = vec![];
205
206    println!("Welcome to the chatbot! Type 'exit' to quit.");
207    loop {
208        print!("> ");
209        // Flush stdout to ensure the prompt appears before input
210        stdout.flush().unwrap();
211
212        let mut input = String::new();
213        match stdin.read_line(&mut input) {
214            Ok(_) => {
215                // Remove the newline character from the input
216                let input = input.trim();
217                // Check for a command to exit
218                if input == "exit" {
219                    break;
220                }
221                tracing::info!("Prompt:\n{}\n", input);
222
223                let response = chatbot.chat(input, chat_log.clone()).await?;
224                chat_log.push(Message::user(input));
225                chat_log.push(Message::assistant(response.clone()));
226
227                println!("========================== Response ============================");
228                println!("{response}");
229                println!("================================================================\n\n");
230
231                tracing::info!("Response:\n{}\n", response);
232            }
233            Err(error) => println!("Error reading input: {error}"),
234        }
235    }
236
237    Ok(())
238}