// Example: agent_streaming/main.rs
//! Example demonstrating the agentic loop with streaming output
//!
//! This example shows how to:
//! 1. Define sync and async tools using the `#[tool]` macro
//! 2. Create an agent with registered tools
//! 3. Run the agentic loop with streaming output
//! 4. Process streaming events in real-time
//! 5. Execute tools in parallel
//!
//! For non-streaming output, see the `agent` example.
//!
//! Run with: `cargo run --release --example agent_streaming -p mistralrs`

use anyhow::Result;
use mistralrs::{
    tool, AgentBuilder, AgentEvent, AgentStopReason, IsqType, PagedAttentionMetaBuilder,
    TextModelBuilder,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::io::Write;

23/// Weather information returned by the get_weather tool
24#[derive(Debug, Serialize, Deserialize, JsonSchema)]
25struct WeatherInfo {
26    /// Temperature in the requested unit
27    temperature: f32,
28    /// Weather conditions description
29    conditions: String,
30    /// Humidity percentage
31    humidity: u8,
32}
33
34/// Search result from web search
35#[derive(Debug, Serialize, Deserialize, JsonSchema)]
36struct SearchResult {
37    /// Title of the search result
38    title: String,
39    /// Snippet from the result
40    snippet: String,
41}
42
43/// Get the current weather for a location (sync tool)
44#[tool(description = "Get the current weather for a location")]
45fn get_weather(
46    #[description = "The city name to get weather for"] city: String,
47    #[description = "Temperature unit: 'celsius' or 'fahrenheit'"]
48    #[default = "celsius"]
49    unit: Option<String>,
50) -> Result<WeatherInfo> {
51    // Mock implementation - in a real application, this would call a weather API
52    // Simulate some work
53    std::thread::sleep(std::time::Duration::from_millis(100));
54
55    let temp = match unit.as_deref() {
56        Some("fahrenheit") => 72.5,
57        _ => 22.5,
58    };
59
60    Ok(WeatherInfo {
61        temperature: temp,
62        conditions: format!("Sunny with clear skies in {}", city),
63        humidity: 45,
64    })
65}
66
67/// Search the web for information (async tool - demonstrates native async support)
68#[tool(description = "Search the web for information on a topic")]
69async fn web_search(
70    #[description = "The search query"] query: String,
71    #[description = "Maximum number of results to return"]
72    #[default = 3u32]
73    max_results: Option<u32>,
74) -> Result<Vec<SearchResult>> {
75    // Mock implementation - in a real application, this would call a search API
76    // Simulate async I/O with tokio::time::sleep
77    tokio::time::sleep(std::time::Duration::from_millis(150)).await;
78
79    let num_results = max_results.unwrap_or(3) as usize;
80
81    let results: Vec<SearchResult> = (0..num_results)
82        .map(|i| SearchResult {
83            title: format!("Result {} for: {}", i + 1, query),
84            snippet: format!(
85                "This is a snippet of information about '{}' from result {}.",
86                query,
87                i + 1
88            ),
89        })
90        .collect();
91
92    Ok(results)
93}
94
95#[tokio::main]
96async fn main() -> Result<()> {
97    // Build the model
98    // Using a model that supports tool calling (e.g., Llama 3.1, Qwen, Mistral)
99    let model = TextModelBuilder::new("../hf_models/qwen3_4b")
100        .with_isq(IsqType::Q4K)
101        .with_logging()
102        .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
103        .build()
104        .await?;
105
106    // Create the agent with registered tools
107    // - get_weather is a sync tool (runs in spawn_blocking)
108    // - web_search is an async tool (runs natively async)
109    // Both can execute in parallel when the model calls multiple tools
110    let agent = AgentBuilder::new(model)
111        .with_system_prompt(
112            "You are a helpful assistant with access to weather and web search tools. \
113             Use them when needed to answer user questions accurately.",
114        )
115        .with_max_iterations(5)
116        .with_parallel_tool_execution(true) // Enable parallel tool execution (default)
117        .register_tool(get_weather_tool_with_callback())
118        .register_tool(web_search_tool_with_callback())
119        .build();
120
121    println!("=== Agent with Streaming Output ===\n");
122    println!(
123        "User: What's the weather like in Boston, and can you find me some good restaurants there?\n"
124    );
125    print!("Assistant: ");
126
127    // Run the agent with streaming output
128    let mut stream = agent
129        .run_stream(
130            "What's the weather like in Boston, and can you find me some good restaurants there?",
131        )
132        .await?;
133
134    let stdout = std::io::stdout();
135    let mut handle = stdout.lock();
136
137    // Process streaming events
138    while let Some(event) = stream.next().await {
139        match event {
140            AgentEvent::TextDelta(text) => {
141                // Print text as it streams - this gives real-time output
142                write!(handle, "{}", text)?;
143                handle.flush()?;
144            }
145            AgentEvent::ToolCallsStart(calls) => {
146                // Model is about to call tools
147                writeln!(handle, "\n\n[Calling {} tool(s)...]", calls.len())?;
148                for call in &calls {
149                    writeln!(
150                        handle,
151                        "  - {}: {}",
152                        call.function.name, call.function.arguments
153                    )?;
154                }
155            }
156            AgentEvent::ToolResult(result) => {
157                // A single tool finished execution
158                let status = if result.result.is_ok() { "OK" } else { "ERROR" };
159                writeln!(
160                    handle,
161                    "  [Tool {} completed: {}]",
162                    result.tool_name, status
163                )?;
164            }
165            AgentEvent::ToolCallsComplete => {
166                // All tools finished, model will continue generating
167                writeln!(handle, "[All tools completed, continuing...]\n")?;
168                write!(handle, "Assistant: ")?;
169                handle.flush()?;
170            }
171            AgentEvent::Complete(response) => {
172                // Agent finished executing
173                writeln!(handle, "\n\n=== Agent Execution Summary ===")?;
174                writeln!(handle, "Completed in {} iteration(s)", response.iterations)?;
175                writeln!(handle, "Stop reason: {:?}", response.stop_reason)?;
176                writeln!(handle, "Steps taken: {}", response.steps.len())?;
177
178                match response.stop_reason {
179                    AgentStopReason::TextResponse => {
180                        writeln!(handle, "Final response delivered successfully.")?;
181                    }
182                    AgentStopReason::MaxIterations => {
183                        writeln!(
184                            handle,
185                            "Agent reached maximum iterations without producing a final response."
186                        )?;
187                    }
188                    AgentStopReason::NoAction => {
189                        writeln!(handle, "Agent produced no response.")?;
190                    }
191                    AgentStopReason::Error(e) => {
192                        writeln!(handle, "Agent encountered an error: {}", e)?;
193                    }
194                }
195            }
196        }
197    }
198
199    Ok(())
200}