rig_onchain_kit/
reasoning_loop.rs

1use anyhow::Result;
2use futures::StreamExt;
3use rig::agent::Agent;
4use rig::completion::AssistantContent;
5use rig::completion::Message;
6use rig::message::{ToolResultContent, UserContent};
7use rig::providers::anthropic::completion::CompletionModel;
8use rig::streaming::{StreamingChat, StreamingChoice};
9use rig::OneOrMany;
10use std::io::Write;
11use std::sync::Arc;
12use tokio::sync::mpsc::Sender;
13
/// An event emitted by [`ReasoningLoop::stream`] over its optional channel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopResponse {
    /// A chunk of assistant text, forwarded as it is streamed.
    Message(String),
    /// A tool the model invoked and its result; if the tool failed, `result`
    /// holds the error's text instead.
    ToolCall { name: String, result: String },
}
18
19pub struct ReasoningLoop {
20    agent: Arc<Agent<CompletionModel>>,
21    stdout: bool,
22}
23
24impl ReasoningLoop {
25    pub fn new(agent: Arc<Agent<CompletionModel>>) -> Self {
26        Self {
27            agent,
28            stdout: true,
29        }
30    }
31
32    pub async fn stream(
33        &self,
34        messages: Vec<Message>,
35        tx: Option<Sender<LoopResponse>>,
36    ) -> Result<Vec<Message>> {
37        if tx.is_none() && !self.stdout {
38            panic!("enable stdout or provide tx channel");
39        }
40
41        let mut current_messages = messages;
42        let agent = self.agent.clone();
43        let stdout = self.stdout;
44
45        'outer: loop {
46            let mut current_response = String::new();
47
48            let mut stream =
49                agent.stream_chat(" ", current_messages.clone()).await?;
50
51            while let Some(chunk) = stream.next().await {
52                match chunk? {
53                    StreamingChoice::Message(text) => {
54                        if stdout {
55                            print!("{}", text);
56                            std::io::stdout().flush()?;
57                        } else if let Some(tx) = &tx {
58                            tx.send(LoopResponse::Message(text.clone()))
59                                .await?;
60                        }
61                        current_response.push_str(&text);
62                    }
63                    StreamingChoice::ToolCall(name, tool_id, params) => {
64                        // Add the assistant's response up to this point with the tool call
65                        if !current_response.is_empty() {
66                            current_messages.push(Message::Assistant {
67                                content: OneOrMany::one(
68                                    AssistantContent::text(
69                                        current_response.clone(),
70                                    ),
71                                ),
72                            });
73                            current_response.clear();
74                        }
75
76                        // Add the tool use message from the assistant
77                        current_messages.push(Message::Assistant {
78                            content: OneOrMany::one(
79                                AssistantContent::tool_call(
80                                    tool_id.clone(),
81                                    name.clone(),
82                                    params.clone(),
83                                ),
84                            ),
85                        });
86
87                        // Call the tool and get result
88                        let result = self
89                            .agent
90                            .tools
91                            .call(&name, params.to_string())
92                            .await;
93
94                        if stdout {
95                            println!("Tool result: {:?}", result);
96                        }
97
98                        // Add the tool result as a user message
99                        current_messages.push(Message::User {
100                            content: OneOrMany::one(
101                                UserContent::tool_result(
102                                    tool_id,
103                                    OneOrMany::one(ToolResultContent::text(
104                                        match &result {
105                                            Ok(content) => {
106                                                content.to_string()
107                                            }
108                                            Err(err) => err.to_string(),
109                                        },
110                                    )),
111                                ),
112                            ),
113                        });
114
115                        if let Some(tx) = &tx {
116                            tx.send(LoopResponse::ToolCall {
117                                name,
118                                result: match &result {
119                                    Ok(content) => content.to_string(),
120                                    Err(err) => err.to_string(),
121                                },
122                            })
123                            .await?;
124                        }
125
126                        continue 'outer;
127                    }
128                }
129            }
130
131            // Add any remaining response to messages
132            if !current_response.is_empty() {
133                current_messages.push(Message::Assistant {
134                    content: OneOrMany::one(AssistantContent::text(
135                        current_response,
136                    )),
137                });
138            }
139
140            // If we get here, there were no tool calls in this iteration
141            break;
142        }
143
144        Ok(current_messages)
145    }
146
147    pub fn with_stdout(mut self, enabled: bool) -> Self {
148        self.stdout = enabled;
149        self
150    }
151}