1use anyhow::Result;
15use mistralrs::{
16 tool, AgentBuilder, AgentEvent, AgentStopReason, IsqType, PagedAttentionMetaBuilder,
17 TextModelBuilder,
18};
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use std::io::Write;
22
23#[derive(Debug, Serialize, Deserialize, JsonSchema)]
25struct WeatherInfo {
26 temperature: f32,
28 conditions: String,
30 humidity: u8,
32}
33
34#[derive(Debug, Serialize, Deserialize, JsonSchema)]
36struct SearchResult {
37 title: String,
39 snippet: String,
41}
42
43#[tool(description = "Get the current weather for a location")]
45fn get_weather(
46 #[description = "The city name to get weather for"] city: String,
47 #[description = "Temperature unit: 'celsius' or 'fahrenheit'"]
48 #[default = "celsius"]
49 unit: Option<String>,
50) -> Result<WeatherInfo> {
51 std::thread::sleep(std::time::Duration::from_millis(100));
54
55 let temp = match unit.as_deref() {
56 Some("fahrenheit") => 72.5,
57 _ => 22.5,
58 };
59
60 Ok(WeatherInfo {
61 temperature: temp,
62 conditions: format!("Sunny with clear skies in {}", city),
63 humidity: 45,
64 })
65}
66
67#[tool(description = "Search the web for information on a topic")]
69async fn web_search(
70 #[description = "The search query"] query: String,
71 #[description = "Maximum number of results to return"]
72 #[default = 3u32]
73 max_results: Option<u32>,
74) -> Result<Vec<SearchResult>> {
75 tokio::time::sleep(std::time::Duration::from_millis(150)).await;
78
79 let num_results = max_results.unwrap_or(3) as usize;
80
81 let results: Vec<SearchResult> = (0..num_results)
82 .map(|i| SearchResult {
83 title: format!("Result {} for: {}", i + 1, query),
84 snippet: format!(
85 "This is a snippet of information about '{}' from result {}.",
86 query,
87 i + 1
88 ),
89 })
90 .collect();
91
92 Ok(results)
93}
94
95#[tokio::main]
96async fn main() -> Result<()> {
97 let model = TextModelBuilder::new("../hf_models/qwen3_4b")
100 .with_isq(IsqType::Q4K)
101 .with_logging()
102 .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
103 .build()
104 .await?;
105
106 let agent = AgentBuilder::new(model)
111 .with_system_prompt(
112 "You are a helpful assistant with access to weather and web search tools. \
113 Use them when needed to answer user questions accurately.",
114 )
115 .with_max_iterations(5)
116 .with_parallel_tool_execution(true) .register_tool(get_weather_tool_with_callback())
118 .register_tool(web_search_tool_with_callback())
119 .build();
120
121 println!("=== Agent with Streaming Output ===\n");
122 println!(
123 "User: What's the weather like in Boston, and can you find me some good restaurants there?\n"
124 );
125 print!("Assistant: ");
126
127 let mut stream = agent
129 .run_stream(
130 "What's the weather like in Boston, and can you find me some good restaurants there?",
131 )
132 .await?;
133
134 let stdout = std::io::stdout();
135 let mut handle = stdout.lock();
136
137 while let Some(event) = stream.next().await {
139 match event {
140 AgentEvent::TextDelta(text) => {
141 write!(handle, "{}", text)?;
143 handle.flush()?;
144 }
145 AgentEvent::ToolCallsStart(calls) => {
146 writeln!(handle, "\n\n[Calling {} tool(s)...]", calls.len())?;
148 for call in &calls {
149 writeln!(
150 handle,
151 " - {}: {}",
152 call.function.name, call.function.arguments
153 )?;
154 }
155 }
156 AgentEvent::ToolResult(result) => {
157 let status = if result.result.is_ok() { "OK" } else { "ERROR" };
159 writeln!(
160 handle,
161 " [Tool {} completed: {}]",
162 result.tool_name, status
163 )?;
164 }
165 AgentEvent::ToolCallsComplete => {
166 writeln!(handle, "[All tools completed, continuing...]\n")?;
168 write!(handle, "Assistant: ")?;
169 handle.flush()?;
170 }
171 AgentEvent::Complete(response) => {
172 writeln!(handle, "\n\n=== Agent Execution Summary ===")?;
174 writeln!(handle, "Completed in {} iteration(s)", response.iterations)?;
175 writeln!(handle, "Stop reason: {:?}", response.stop_reason)?;
176 writeln!(handle, "Steps taken: {}", response.steps.len())?;
177
178 match response.stop_reason {
179 AgentStopReason::TextResponse => {
180 writeln!(handle, "Final response delivered successfully.")?;
181 }
182 AgentStopReason::MaxIterations => {
183 writeln!(
184 handle,
185 "Agent reached maximum iterations without producing a final response."
186 )?;
187 }
188 AgentStopReason::NoAction => {
189 writeln!(handle, "Agent produced no response.")?;
190 }
191 AgentStopReason::Error(e) => {
192 writeln!(handle, "Agent encountered an error: {}", e)?;
193 }
194 }
195 }
196 }
197 }
198
199 Ok(())
200}