use anyhow::Result;
use mistralrs::{
tool, AgentBuilder, AgentEvent, AgentStopReason, IsqBits, ModelBuilder,
PagedAttentionMetaBuilder,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::io::Write;
// Mock weather reading produced by the `get_weather` tool and serialized back to the model.
// Plain `//` comments are used on purpose so nothing is added to the derived JSON schema.
#[derive(Debug)]
#[derive(Serialize, Deserialize, JsonSchema)]
struct WeatherInfo {
    // Reading in whichever unit the tool call asked for (see `get_weather`).
    temperature: f32,
    // Human-readable description, e.g. "Sunny with clear skies in <city>".
    conditions: String,
    // Humidity value; presumably a relative-humidity percentage — confirm if it matters.
    humidity: u8,
}
// One fabricated hit returned by the mock `web_search` tool.
// Plain `//` comments are used on purpose so nothing is added to the derived JSON schema.
#[derive(Debug)]
#[derive(Serialize, Deserialize, JsonSchema)]
struct SearchResult {
    // Headline of the hit, e.g. "Result 1 for: <query>".
    title: String,
    // Short placeholder excerpt mentioning the query and the hit's position.
    snippet: String,
}
/// Mock weather lookup exposed to the agent as a tool.
///
/// Sleeps for 100 ms to simulate latency, then returns a fixed reading for `city`:
/// a temperature of 72.5 when `unit` is "fahrenheit" and 22.5 otherwise (including
/// when `unit` is `None`). Because `unit` is produced by the model, it is trimmed and
/// compared case-insensitively, so "Fahrenheit" or " FAHRENHEIT " are honoured instead
/// of silently falling back to Celsius.
///
/// # Errors
/// Never fails in this mock; the `Result` return matches the tool signature.
#[tool(description = "Get the current weather for a location")]
fn get_weather(
    #[description = "The city name to get weather for"] city: String,
    #[description = "Temperature unit: 'celsius' or 'fahrenheit'"]
    #[default = "celsius"]
    unit: Option<String>,
) -> Result<WeatherInfo> {
    std::thread::sleep(std::time::Duration::from_millis(100));
    let temp = match unit.as_deref().map(str::trim) {
        Some(u) if u.eq_ignore_ascii_case("fahrenheit") => 72.5,
        _ => 22.5,
    };
    Ok(WeatherInfo {
        temperature: temp,
        conditions: format!("Sunny with clear skies in {}", city),
        humidity: 45,
    })
}
/// Mock web search exposed to the agent as a tool.
///
/// Sleeps for 150 ms (asynchronously) to simulate latency, then fabricates placeholder
/// hits for `query`: `max_results` of them, or 3 when omitted. `max_results` is chosen
/// by the model, so the count is capped at `MAX_RESULTS_CAP`. Without the cap, a request
/// for up to `u32::MAX` results would try to allocate billions of entries. Capping
/// matches the parameter's "maximum" semantics.
///
/// # Errors
/// Never fails in this mock; the `Result` return matches the tool signature.
#[tool(description = "Search the web for information on a topic")]
async fn web_search(
    #[description = "The search query"] query: String,
    #[description = "Maximum number of results to return"]
    #[default = 3u32]
    max_results: Option<u32>,
) -> Result<Vec<SearchResult>> {
    // Keep in sync with the `#[default = 3u32]` attribute above.
    const DEFAULT_RESULTS: u32 = 3;
    // Upper bound on fabricated results, protecting against oversized model requests.
    const MAX_RESULTS_CAP: u32 = 10;

    tokio::time::sleep(std::time::Duration::from_millis(150)).await;
    let count = max_results.unwrap_or(DEFAULT_RESULTS).min(MAX_RESULTS_CAP);
    // Results are numbered from 1 in the generated text.
    let results: Vec<SearchResult> = (1..=count)
        .map(|n| SearchResult {
            title: format!("Result {} for: {}", n, query),
            snippet: format!(
                "This is a snippet of information about '{}' from result {}.",
                query, n
            ),
        })
        .collect();
    Ok(results)
}
#[tokio::main]
async fn main() -> Result<()> {
let model = ModelBuilder::new("google/gemma-4-E4B-it")
.with_auto_isq(IsqBits::Four)
.with_logging()
.with_paged_attn(PagedAttentionMetaBuilder::default().build()?)
.build()
.await?;
let agent = AgentBuilder::new(model)
.with_system_prompt(
"You are a helpful assistant with access to weather and web search tools. \
Use them when needed to answer user questions accurately.",
)
.with_max_iterations(5)
.with_parallel_tool_execution(true) .register_tool(get_weather_tool_with_callback())
.register_tool(web_search_tool_with_callback())
.build();
println!("=== Agent with Streaming Output ===\n");
println!(
"User: What's the weather like in Boston, and can you find me some good restaurants there?\n"
);
print!("Assistant: ");
let mut stream = agent
.run_stream(
"What's the weather like in Boston, and can you find me some good restaurants there?",
)
.await?;
let stdout = std::io::stdout();
let mut handle = stdout.lock();
while let Some(event) = stream.next().await {
match event {
AgentEvent::TextDelta(text) => {
write!(handle, "{}", text)?;
handle.flush()?;
}
AgentEvent::ToolCallsStart(calls) => {
writeln!(handle, "\n\n[Calling {} tool(s)...]", calls.len())?;
for call in &calls {
writeln!(
handle,
" - {}: {}",
call.function.name, call.function.arguments
)?;
}
}
AgentEvent::ToolResult(result) => {
let status = if result.result.is_ok() { "OK" } else { "ERROR" };
writeln!(
handle,
" [Tool {} completed: {}]",
result.tool_name, status
)?;
}
AgentEvent::ToolCallsComplete => {
writeln!(handle, "[All tools completed, continuing...]\n")?;
write!(handle, "Assistant: ")?;
handle.flush()?;
}
AgentEvent::Complete(response) => {
writeln!(handle, "\n\n=== Agent Execution Summary ===")?;
writeln!(handle, "Completed in {} iteration(s)", response.iterations)?;
writeln!(handle, "Stop reason: {:?}", response.stop_reason)?;
writeln!(handle, "Steps taken: {}", response.steps.len())?;
match response.stop_reason {
AgentStopReason::TextResponse => {
writeln!(handle, "Final response delivered successfully.")?;
}
AgentStopReason::MaxIterations => {
writeln!(
handle,
"Agent reached maximum iterations without producing a final response."
)?;
}
AgentStopReason::NoAction => {
writeln!(handle, "Agent produced no response.")?;
}
AgentStopReason::Error(e) => {
writeln!(handle, "Agent encountered an error: {}", e)?;
}
}
}
}
}
Ok(())
}