use futures::StreamExt;
use serde_json::json;
use stakai::{ContentPart, GenerateRequest, Inference, Message, Model, StreamEvent, Tool};
use std::collections::HashMap;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Inference::new();
let weather_tool = Tool::function("get_weather", "Get the current weather for a location")
.parameters(json!({
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature unit"
}
},
"required": ["city"]
}));
let time_tool =
Tool::function("get_time", "Get the current time for a location").parameters(json!({
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
}
},
"required": ["city"]
}));
let mut request = GenerateRequest::new(
Model::custom("gpt-4o-mini", "openai"),
vec![Message::new(
stakai::Role::User,
"What's the weather and time in Paris?",
)],
);
request.options = request
.options
.add_tool(weather_tool.clone())
.add_tool(time_tool.clone());
println!("--- Streaming tool calls from model\n");
let mut stream = client.stream(&request).await?;
let mut tool_calls: HashMap<String, ToolCallBuilder> = HashMap::new();
let mut text_content = String::new();
while let Some(event) = stream.next().await {
match event? {
StreamEvent::Start { id } => {
println!("Stream started: {}", id);
}
StreamEvent::TextDelta { delta, .. } => {
print!("{}", delta);
text_content.push_str(&delta);
}
StreamEvent::ReasoningDelta { delta, .. } => {
print!("[Reasoning: {}]", delta);
}
StreamEvent::ToolCallStart { id, name } => {
println!("\n\n🔧 Tool call started:");
println!(" ID: {}", id);
println!(" Function: {}", name);
tool_calls.insert(
id.clone(),
ToolCallBuilder {
id: id.clone(),
name: name.clone(),
arguments: String::new(),
},
);
}
StreamEvent::ToolCallDelta { id, delta } => {
if let Some(builder) = tool_calls.get_mut(&id) {
builder.arguments.push_str(&delta);
}
}
StreamEvent::ToolCallEnd {
id,
name,
arguments,
..
} => {
println!("\n✅ Tool call completed:");
println!(" ID: {}", id);
println!(" Function: {}", name);
println!(" Arguments: {}", arguments);
tool_calls.insert(
id.clone(),
ToolCallBuilder {
id: id.clone(),
name: name.clone(),
arguments: arguments.to_string(),
},
);
}
StreamEvent::Finish { usage, reason } => {
println!("\n\n--- Stream finished");
println!("Reason: {:?}", reason);
println!(
"Usage: {} prompt + {} completion = {} total tokens",
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
);
}
StreamEvent::Error { message } => {
eprintln!("Error: {}", message);
}
}
}
if !tool_calls.is_empty() {
println!("\n\n--- Executing tools");
let mut tool_results = Vec::new();
for (_, builder) in tool_calls.iter() {
let result = execute_tool(&builder.name, &builder.arguments)?;
println!("\n🔨 Executed: {}", builder.name);
println!(" Result: {}", result);
tool_results.push((builder.id.clone(), result));
}
let mut messages = request.messages.clone();
for (_, builder) in tool_calls.iter() {
messages.push(Message::new(
stakai::Role::Assistant,
vec![ContentPart::tool_call(
builder.id.clone(),
builder.name.clone(),
serde_json::from_str(&builder.arguments).unwrap_or(serde_json::json!({})),
)],
));
}
for (call_id, result) in tool_results {
messages.push(Message::new(
stakai::Role::Tool,
vec![ContentPart::tool_result(call_id, result)],
));
}
let mut follow_up = GenerateRequest::new(Model::custom("gpt-4o-mini", "openai"), messages);
follow_up.options = follow_up.options.add_tool(weather_tool).add_tool(time_tool);
println!("\n\n--- Getting final response with tool results\n");
let mut final_stream = client.stream(&follow_up).await?;
while let Some(event) = final_stream.next().await {
if let StreamEvent::TextDelta { delta, .. } = event? {
print!("{}", delta);
}
}
println!("\n");
}
Ok(())
}
/// A tool call being assembled from streaming events.
///
/// `arguments` holds the raw argument text as received from the stream. It is
/// not guaranteed to be valid JSON until the call has completed.
#[derive(Debug, Clone, Default)]
struct ToolCallBuilder {
    /// Provider-assigned identifier of the tool call.
    id: String,
    /// Name of the function the model wants to invoke.
    name: String,
    /// Accumulated (or final) JSON-encoded argument text.
    arguments: String,
}
/// Executes one of the mock tools offered to the model and returns its result.
///
/// `arguments` is the JSON-encoded argument object streamed by the model. An
/// empty or whitespace-only string is treated as `{}`, since a model may call a
/// tool without emitting any argument text. A missing or non-string `city`
/// is reported as `"Unknown"`.
///
/// `get_weather` honours the optional `unit` parameter declared in its schema:
/// `"fahrenheit"` converts the mock temperature, and anything else yields
/// Celsius. The unit used is echoed back so the model can label the value.
///
/// An unknown tool name produces an `{"error": ...}` object rather than an
/// `Err`, so the mistake can be reported to the model.
///
/// # Errors
///
/// Returns an error if `arguments` is non-empty and is not valid JSON.
fn execute_tool(
    name: &str,
    arguments: &str,
) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
    const MOCK_TEMPERATURE_CELSIUS: f64 = 22.5;

    let args: serde_json::Value = if arguments.trim().is_empty() {
        json!({})
    } else {
        serde_json::from_str(arguments)?
    };
    // Indexing a non-object (or a missing key) yields `Null`, so this never panics.
    let city = args["city"].as_str().unwrap_or("Unknown");

    let result = match name {
        "get_weather" => {
            let (unit, temperature) = match args["unit"].as_str() {
                Some("fahrenheit") => ("fahrenheit", MOCK_TEMPERATURE_CELSIUS * 9.0 / 5.0 + 32.0),
                _ => ("celsius", MOCK_TEMPERATURE_CELSIUS),
            };
            json!({
                "city": city,
                "temperature": temperature,
                "unit": unit,
                "condition": "Sunny",
                "humidity": 65
            })
        }
        "get_time" => json!({
            "city": city,
            "time": "14:30",
            "timezone": "CET"
        }),
        _ => json!({
            "error": format!("Unknown tool: {}", name)
        }),
    };
    Ok(result)
}