rig/integrations/
cli_chatbot.rs

1use crate::{
2    agent::{Agent, MultiTurnStreamItem, Text},
3    completion::{Chat, CompletionError, CompletionModel, PromptError, Usage},
4    message::Message,
5    streaming::{StreamedAssistantContent, StreamingPrompt},
6    wasm_compat::WasmCompatSend,
7};
8use futures::StreamExt;
9use std::io::{self, Write};
10
/// Placeholder backend type for [`ChatBotBuilder`].
///
/// `ChatBotBuilder::new()` / `ChatBotBuilder::default()` produce a
/// `ChatBotBuilder<NoImplProvided>`; calling `agent` or `chat` on it swaps this
/// marker for a concrete backend.
11pub struct NoImplProvided;
12
13pub struct ChatImpl<T>(T)
14where
15    T: Chat;
16
17pub struct AgentImpl<M>
18where
19    M: CompletionModel + 'static,
20{
21    agent: Agent<M>,
22    multi_turn_depth: usize,
23    show_usage: bool,
24    usage: Usage,
25}
26
/// Builder for [`ChatBot`].
///
/// The type parameter tracks which backend has been selected so far:
/// [`NoImplProvided`] initially, then [`ChatImpl`] or [`AgentImpl`] once `chat`
/// or `agent` has been called.
27pub struct ChatBotBuilder<T>(T);
28
/// Interactive command-line chatbot wrapping a backend `T`.
///
/// Instances are produced by the `build` methods of [`ChatBotBuilder`]; `run`
/// is available when the backend implements the private `CliChat` trait.
29pub struct ChatBot<T>(T);
30
31/// Trait to abstract message behavior away from cli_chat/`run` loop
32#[allow(private_interfaces)]
33trait CliChat {
34    async fn request(&mut self, prompt: &str, history: Vec<Message>)
35    -> Result<String, PromptError>;
36
37    fn show_usage(&self) -> bool {
38        false
39    }
40
41    fn usage(&self) -> Option<Usage> {
42        None
43    }
44}
45
46impl<T> CliChat for ChatImpl<T>
47where
48    T: Chat,
49{
50    async fn request(
51        &mut self,
52        prompt: &str,
53        history: Vec<Message>,
54    ) -> Result<String, PromptError> {
55        let res = self.0.chat(prompt, history).await?;
56        println!("{res}");
57
58        Ok(res)
59    }
60}
61
62impl<M> CliChat for AgentImpl<M>
63where
64    M: CompletionModel + WasmCompatSend + 'static,
65{
66    async fn request(
67        &mut self,
68        prompt: &str,
69        history: Vec<Message>,
70    ) -> Result<String, PromptError> {
71        let mut response_stream = self
72            .agent
73            .stream_prompt(prompt)
74            .with_history(history)
75            .multi_turn(self.multi_turn_depth)
76            .await;
77
78        let mut acc = String::new();
79
80        loop {
81            let Some(chunk) = response_stream.next().await else {
82                break Ok(acc);
83            };
84
85            match chunk {
86                Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text {
87                    text,
88                }))) => {
89                    print!("{}", text);
90                    acc.push_str(&text);
91                }
92                Ok(MultiTurnStreamItem::FinalResponse(final_response)) => {
93                    self.usage = final_response.usage();
94                }
95                Err(e) => {
96                    break Err(PromptError::CompletionError(
97                        CompletionError::ResponseError(e.to_string()),
98                    ));
99                }
100                _ => continue,
101            }
102        }
103    }
104
105    fn show_usage(&self) -> bool {
106        self.show_usage
107    }
108
109    fn usage(&self) -> Option<Usage> {
110        Some(self.usage)
111    }
112}
113
114impl Default for ChatBotBuilder<NoImplProvided> {
115    fn default() -> Self {
116        Self(NoImplProvided)
117    }
118}
119
120impl ChatBotBuilder<NoImplProvided> {
121    pub fn new() -> Self {
122        Self::default()
123    }
124
125    pub fn agent<M: CompletionModel + 'static>(
126        self,
127        agent: Agent<M>,
128    ) -> ChatBotBuilder<AgentImpl<M>> {
129        ChatBotBuilder(AgentImpl {
130            agent,
131            multi_turn_depth: 1,
132            show_usage: false,
133            usage: Usage::default(),
134        })
135    }
136
137    pub fn chat<T: Chat>(self, chatbot: T) -> ChatBotBuilder<ChatImpl<T>> {
138        ChatBotBuilder(ChatImpl(chatbot))
139    }
140}
141
142impl<T> ChatBotBuilder<ChatImpl<T>>
143where
144    T: Chat,
145{
146    pub fn build(self) -> ChatBot<ChatImpl<T>> {
147        let ChatBotBuilder(chat_impl) = self;
148        ChatBot(chat_impl)
149    }
150}
151
152impl<M> ChatBotBuilder<AgentImpl<M>>
153where
154    M: CompletionModel + 'static,
155{
156    pub fn multi_turn_depth(self, multi_turn_depth: usize) -> Self {
157        ChatBotBuilder(AgentImpl {
158            multi_turn_depth,
159            ..self.0
160        })
161    }
162
163    pub fn show_usage(self) -> Self {
164        ChatBotBuilder(AgentImpl {
165            show_usage: true,
166            ..self.0
167        })
168    }
169
170    pub fn build(self) -> ChatBot<AgentImpl<M>> {
171        ChatBot(self.0)
172    }
173}
174
175#[allow(private_bounds)]
176impl<T> ChatBot<T>
177where
178    T: CliChat,
179{
180    pub async fn run(mut self) -> Result<(), PromptError> {
181        let stdin = io::stdin();
182        let mut stdout = io::stdout();
183        let mut history = vec![];
184
185        loop {
186            print!("> ");
187            stdout.flush().unwrap();
188
189            let mut input = String::new();
190            match stdin.read_line(&mut input) {
191                Ok(_) => {
192                    let input = input.trim();
193                    if input == "exit" {
194                        break;
195                    }
196
197                    tracing::info!("Prompt:\n{input}\n");
198
199                    println!();
200                    println!("========================== Response ============================");
201
202                    let response = self.0.request(input, history.clone()).await?;
203                    history.push(Message::user(input));
204                    history.push(Message::assistant(response));
205
206                    println!("================================================================");
207                    println!();
208
209                    if self.0.show_usage() {
210                        let Usage {
211                            input_tokens,
212                            output_tokens,
213                            ..
214                        } = self.0.usage().unwrap();
215                        println!("Input {input_tokens} tokens\nOutput {output_tokens} tokens");
216                    }
217                }
218                Err(e) => println!("Error reading request: {e}"),
219            }
220        }
221
222        Ok(())
223    }
224}