rig/
cli_chatbot.rs

1use std::io::{self, Write};
2
3use futures::StreamExt;
4
5use crate::{
6    agent::{Agent, prompt_request::streaming::MultiTurnStreamItem},
7    completion::{Chat, CompletionError, CompletionModel, Message, PromptError},
8    streaming::StreamingPrompt,
9};
10
/// Type-state marker for a [`ChatbotBuilder`] whose `agent` field has not been set yet.
///
/// It carries no data. It only exists so that `build` (which is implemented solely for
/// `ChatbotBuilder<Agent<M>>`) cannot be called before an agent has been supplied.
#[derive(Debug, Clone, Copy, Default)]
pub struct AgentNotSet;
13
/// Builder pattern for CLI chatbots.
///
/// The type parameter `A` tracks whether an agent has been supplied. It starts out as
/// [`AgentNotSet`] and becomes an [`Agent`] once `agent(...)` is called. `build` is only
/// implemented for the latter state, so a chatbot cannot be built without an agent.
///
/// # Example
/// ```rust,ignore
/// let chatbot = ChatbotBuilder::new().agent(my_agent).show_usage().build();
///
/// chatbot.run().await?;
/// ```
pub struct ChatbotBuilder<A> {
    agent: A,
    multi_turn_depth: usize,
    show_usage: bool,
}
26
27impl Default for ChatbotBuilder<AgentNotSet> {
28    fn default() -> Self {
29        ChatbotBuilder {
30            agent: AgentNotSet,
31            multi_turn_depth: 0,
32            show_usage: false,
33        }
34    }
35}
36
37impl ChatbotBuilder<AgentNotSet> {
38    pub fn new() -> Self {
39        Default::default()
40    }
41
42    /// Sets the agent that will be used to drive the CLI interface
43    pub fn agent<M>(self, agent: Agent<M>) -> ChatbotBuilder<Agent<M>>
44    where
45        M: CompletionModel + 'static,
46    {
47        ChatbotBuilder {
48            agent,
49            multi_turn_depth: self.multi_turn_depth,
50            show_usage: self.show_usage,
51        }
52    }
53}
54
55impl<A> ChatbotBuilder<A> {
56    /// Sets the `show_usage` flag, so that after a request the number of tokens
57    /// in the input and output will be printed
58    pub fn show_usage(self) -> Self {
59        Self {
60            show_usage: true,
61            ..self
62        }
63    }
64
65    /// Sets the maximum depth for multi-turn, i.e. toolcalls
66    pub fn multi_turn_depth(self, multi_turn_depth: usize) -> Self {
67        Self {
68            multi_turn_depth,
69            ..self
70        }
71    }
72}
73
74impl<M> ChatbotBuilder<Agent<M>>
75where
76    M: CompletionModel + 'static,
77{
78    /// Consumes the `ChatbotBuilder`, returning a `Chatbot` which can be run
79    pub fn build(self) -> Chatbot<M> {
80        Chatbot {
81            agent: self.agent,
82            multi_turn_depth: self.multi_turn_depth,
83            show_usage: self.show_usage,
84        }
85    }
86}
87
88/// A CLI chatbot.
89/// Only takes [Agent] types unlike [cli_chatbot] which takes any `impl Chat` type.
90///
91/// # Example
92/// ```rust
93/// let chatbot = ChatbotBuilder::new().agent(my_agent).show_usage().build();
94///
95/// chatbot.run().await?;
96pub struct Chatbot<M>
97where
98    M: CompletionModel + 'static,
99{
100    agent: Agent<M>,
101    multi_turn_depth: usize,
102    show_usage: bool,
103}
104
105impl<M> Chatbot<M>
106where
107    M: CompletionModel + 'static,
108{
109    pub async fn run(self) -> Result<(), PromptError> {
110        let stdin = io::stdin();
111        let mut stdout = io::stdout();
112        let mut chat_log = vec![];
113
114        println!("Welcome to the chatbot! Type 'exit' to quit.");
115
116        loop {
117            print!("> ");
118            // Flush stdout to ensure the prompt appears before input
119            stdout.flush().unwrap();
120
121            let mut input = String::new();
122            match stdin.read_line(&mut input) {
123                Ok(_) => {
124                    // Remove the newline character from the input
125                    let input = input.trim();
126
127                    if input.is_empty() {
128                        continue;
129                    }
130
131                    // Check for a command to exit
132                    if input == "exit" {
133                        break;
134                    }
135
136                    tracing::info!("Prompt:\n{}\n", input);
137
138                    let mut usage = None;
139                    let mut response = String::new();
140
141                    println!();
142                    println!("========================== Response ============================");
143
144                    let mut stream_response = self
145                        .agent
146                        .stream_prompt(input)
147                        .with_history(chat_log.clone())
148                        .multi_turn(self.multi_turn_depth)
149                        .await;
150
151                    while let Some(chunk) = stream_response.next().await {
152                        match chunk {
153                            Ok(MultiTurnStreamItem::Text(s)) => {
154                                let text = s.text.as_str();
155                                print!("{text}");
156                                response.push_str(text);
157                            }
158                            Ok(MultiTurnStreamItem::FinalResponse(r)) => {
159                                if self.show_usage {
160                                    usage = Some(r.usage());
161                                }
162                            }
163
164                            Err(e) => {
165                                return Err(PromptError::CompletionError(
166                                    CompletionError::ResponseError(e.to_string()),
167                                ));
168                            }
169                        }
170                    }
171
172                    println!("================================================================");
173                    println!();
174
175                    // `with_history` does not push to history, we have handle that
176                    chat_log.push(Message::user(input));
177                    chat_log.push(Message::assistant(response.clone()));
178
179                    if let Some(usage) = usage {
180                        println!(
181                            "Input: {} tokens\nOutput: {} tokens",
182                            usage.input_tokens, usage.output_tokens
183                        )
184                    }
185
186                    tracing::info!("Response:\n{}\n", response);
187                }
188                Err(error) => println!("Error reading input: {error}"),
189            }
190        }
191
192        Ok(())
193    }
194}
195
196/// Utility function to create a simple REPL CLI chatbot from a type that implements the
197/// `Chat` trait.
198///
199/// Where the [Chatbot] type takes an agent, this takes any type that implements the [Chat] trait.
200pub async fn cli_chatbot(chatbot: impl Chat) -> Result<(), PromptError> {
201    let stdin = io::stdin();
202    let mut stdout = io::stdout();
203    let mut chat_log = vec![];
204
205    println!("Welcome to the chatbot! Type 'exit' to quit.");
206    loop {
207        print!("> ");
208        // Flush stdout to ensure the prompt appears before input
209        stdout.flush().unwrap();
210
211        let mut input = String::new();
212        match stdin.read_line(&mut input) {
213            Ok(_) => {
214                // Remove the newline character from the input
215                let input = input.trim();
216                // Check for a command to exit
217                if input == "exit" {
218                    break;
219                }
220                tracing::info!("Prompt:\n{}\n", input);
221
222                let response = chatbot.chat(input, chat_log.clone()).await?;
223                chat_log.push(Message::user(input));
224                chat_log.push(Message::assistant(response.clone()));
225
226                println!("========================== Response ============================");
227                println!("{response}");
228                println!("================================================================\n\n");
229
230                tracing::info!("Response:\n{}\n", response);
231            }
232            Err(error) => println!("Error reading input: {error}"),
233        }
234    }
235
236    Ok(())
237}