Skip to main content

agent/
main.rs

//! Example demonstrating the agentic loop with tool calling
//!
//! This example shows how to:
//! 1. Define sync and async tools using the `#[tool]` macro
//! 2. Create an agent with registered tools
//! 3. Run the agentic loop (non-streaming)
//! 4. Execute tools in parallel
//!
//! For streaming output, see the `agent_streaming` example.
//!
//! Run with: `cargo run --release --example agent -p mistralrs`
12
13use anyhow::Result;
14use mistralrs::{
15    tool, AgentBuilder, AgentStopReason, IsqType, PagedAttentionMetaBuilder, TextModelBuilder,
16};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19
// Payload type produced by `get_weather`.
//
// The `///` comments on this type and its fields are left verbatim on purpose:
// `#[derive(JsonSchema)]` normally copies doc comments into the generated
// schema's descriptions, so they may be machine/model-facing text rather than
// plain documentation.
// NOTE(review): confirm whether the `#[tool]` macro actually consumes this
// schema for return types.
/// Weather information returned by the get_weather tool
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct WeatherInfo {
    /// Temperature in the requested unit
    // `get_weather` fills this with a fixed mock reading (72.5 or 22.5).
    temperature: f32,
    /// Weather conditions description
    // Free-form text; `get_weather` embeds the requested city name in it.
    conditions: String,
    /// Humidity percentage
    // `u8` admits 0-255; nothing in this file enforces a 0-100 range.
    humidity: u8,
}
30
// Element type of the list produced by `web_search`.
//
// The `///` comments on this type and its fields are left verbatim on purpose:
// `#[derive(JsonSchema)]` normally copies doc comments into the generated
// schema's descriptions, so they may be machine/model-facing text rather than
// plain documentation.
// NOTE(review): confirm whether the `#[tool]` macro actually consumes this
// schema for return types.
/// Search result from web search
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct SearchResult {
    /// Title of the search result
    // `web_search` formats this as "Result <n> for: <query>".
    title: String,
    /// Snippet from the result
    // `web_search` fills this with a one-sentence placeholder mentioning the query.
    snippet: String,
}
39
40/// Get the current weather for a location (sync tool)
41#[tool(description = "Get the current weather for a location")]
42fn get_weather(
43    #[description = "The city name to get weather for"] city: String,
44    #[description = "Temperature unit: 'celsius' or 'fahrenheit'"]
45    #[default = "celsius"]
46    unit: Option<String>,
47) -> Result<WeatherInfo> {
48    // Mock implementation - in a real application, this would call a weather API
49    // Simulate some work
50    std::thread::sleep(std::time::Duration::from_millis(100));
51
52    let temp = match unit.as_deref() {
53        Some("fahrenheit") => 72.5,
54        _ => 22.5,
55    };
56
57    Ok(WeatherInfo {
58        temperature: temp,
59        conditions: format!("Sunny with clear skies in {}", city),
60        humidity: 45,
61    })
62}
63
64/// Search the web for information (async tool - demonstrates native async support)
65#[tool(description = "Search the web for information on a topic")]
66async fn web_search(
67    #[description = "The search query"] query: String,
68    #[description = "Maximum number of results to return"]
69    #[default = 3u32]
70    max_results: Option<u32>,
71) -> Result<Vec<SearchResult>> {
72    // Mock implementation - in a real application, this would call a search API
73    // Simulate async I/O with tokio::time::sleep
74    tokio::time::sleep(std::time::Duration::from_millis(150)).await;
75
76    let num_results = max_results.unwrap_or(3) as usize;
77
78    let results: Vec<SearchResult> = (0..num_results)
79        .map(|i| SearchResult {
80            title: format!("Result {} for: {}", i + 1, query),
81            snippet: format!(
82                "This is a snippet of information about '{}' from result {}.",
83                query,
84                i + 1
85            ),
86        })
87        .collect();
88
89    Ok(results)
90}
91
92#[tokio::main]
93async fn main() -> Result<()> {
94    // Build the model
95    // Using a model that supports tool calling (e.g., Llama 3.1, Qwen, Mistral)
96    let model = TextModelBuilder::new("../hf_models/qwen3_4b")
97        .with_isq(IsqType::Q4K)
98        .with_logging()
99        .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
100        .build()
101        .await?;
102
103    // Create the agent with registered tools
104    // - get_weather is a sync tool (runs in spawn_blocking)
105    // - web_search is an async tool (runs natively async)
106    // Both can execute in parallel when the model calls multiple tools
107    let agent = AgentBuilder::new(model)
108        .with_system_prompt(
109            "You are a helpful assistant with access to weather and web search tools. \
110             Use them when needed to answer user questions accurately.",
111        )
112        .with_max_iterations(5)
113        .with_parallel_tool_execution(true) // Enable parallel tool execution (default)
114        .register_tool(get_weather_tool_with_callback())
115        .register_tool(web_search_tool_with_callback())
116        .build();
117
118    println!("=== Agent Example (Non-Streaming) ===\n");
119
120    let user_message =
121        "What's the weather like in Boston, and can you find me some good restaurants there?";
122    println!("User: {}\n", user_message);
123
124    // Run the agent (waits for complete response)
125    let response = agent.run(user_message).await?;
126
127    // Print the final response
128    if let Some(text) = &response.final_response {
129        println!("Assistant: {}\n", text);
130    }
131
132    // Print execution summary
133    println!("=== Execution Summary ===");
134    println!("Completed in {} iteration(s)", response.iterations);
135    println!("Stop reason: {:?}", response.stop_reason);
136    println!("Steps taken: {}", response.steps.len());
137
138    // Print details of each step
139    for (i, step) in response.steps.iter().enumerate() {
140        println!("\n--- Step {} ---", i + 1);
141        if !step.tool_calls.is_empty() {
142            println!("Tool calls:");
143            for call in &step.tool_calls {
144                println!("  - {}: {}", call.function.name, call.function.arguments);
145            }
146            println!("Tool results:");
147            for result in &step.tool_results {
148                let status = if result.result.is_ok() { "OK" } else { "ERROR" };
149                println!("  - {}: {}", result.tool_name, status);
150            }
151        }
152    }
153
154    match response.stop_reason {
155        AgentStopReason::TextResponse => {
156            println!("\nFinal response delivered successfully.");
157        }
158        AgentStopReason::MaxIterations => {
159            println!("\nAgent reached maximum iterations without producing a final response.");
160        }
161        AgentStopReason::NoAction => {
162            println!("\nAgent produced no response.");
163        }
164        AgentStopReason::Error(e) => {
165            println!("\nAgent encountered an error: {}", e);
166        }
167    }
168
169    Ok(())
170}