use llm::{
builder::{FunctionBuilder, LLMBackend, LLMBuilder, ParamBuilder},
chat::{ChatMessage, ToolChoice},
FunctionCall, LLMProvider, ToolCall,
};
use serde_json::{json, Value};
use std::env;
use std::error::Error;
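
// Unified tool calling example: the same tool definitions and call/response loop
// work across the OpenAI, Anthropic, Google, and Ollama backends.
//
// Usage:
//   cargo run --example unified_tool_calling_example -- <provider> <scenario>
// where <provider> is openai | anthropic | google | ollama (default: openai)
// and <scenario> is simple | multi | choice (default: simple).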
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let args: Vec<String> = env::args().collect();
let provider = if args.len() > 1 { &args[1] } else { "openai" };
let scenario = if args.len() > 2 { &args[2] } else { "simple" };
let llm = create_llm(provider)?;
println!("=== Unified Tool Calling Example ===");
println!("Provider: {provider}");
println!("Scenario: {scenario}");
println!("=================================\n");
match scenario {
"simple" => run_simple_scenario(llm.as_ref()).await?,
"multi" => run_multi_turn_scenario(llm.as_ref()).await?,
"choice" => run_tool_choice_scenario(llm.as_ref()).await?,
_ => {
println!("Unknown scenario: {scenario}. Available scenarios: simple, multi, choice");
println!("Example: cargo run --example unified_tool_calling_example -- openai multi");
}
}
Ok(())
}
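
/// Builds an LLM provider for the chosen backend, reading the matching API key
/// from the environment and registering the three demo tools shared by every scenario.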
fn create_llm(provider_name: &str) -> Result<Box<dyn LLMProvider>, Box<dyn Error>> {
let backend = match provider_name.to_lowercase().as_str() {
"openai" => LLMBackend::OpenAI,
"anthropic" => LLMBackend::Anthropic,
"google" => LLMBackend::Google,
"ollama" => LLMBackend::Ollama,
_ => {
return Err(format!(
"Unsupported provider: {provider_name}. Use 'openai', 'anthropic', or 'google'"
)
.into());
}
};
let api_key = match backend {
LLMBackend::OpenAI => {
env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY environment variable not set")
}
LLMBackend::Anthropic => {
env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY environment variable not set")
}
LLMBackend::Google => {
env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY environment variable not set")
}
        // Ollama usually runs locally and does not require a real API key.
        LLMBackend::Ollama => env::var("OLLAMA_API_KEY").unwrap_or("ollama".into()),
_ => unreachable!(),
};
let model = match backend {
LLMBackend::OpenAI => "gpt-4o-mini",
LLMBackend::Anthropic => "claude-3-5-haiku-latest",
LLMBackend::Google => "gemini-1.5-flash",
LLMBackend::Ollama => "llama3.1:latest",
_ => unreachable!(),
};
let llm = LLMBuilder::new()
.backend(backend)
.api_key(api_key)
.model(model)
.max_tokens(1024)
.temperature(0.7)
.function(
FunctionBuilder::new("get_weather")
.description("Get the current weather in a specific location")
.param(ParamBuilder::new("location").type_of("string").description(
"The city and state/country, e.g., 'San Francisco, CA' or 'Tokyo, Japan'",
))
.required(vec!["location".to_string()]),
)
.function(
FunctionBuilder::new("get_current_time")
.description("Get the current time in a specific time zone")
.param(
ParamBuilder::new("timezone")
.type_of("string")
.description("The timezone, e.g., 'EST', 'PST', 'UTC', 'JST', etc."),
)
.required(vec!["timezone".to_string()]),
)
.function(
FunctionBuilder::new("search_restaurants")
.description("Search for restaurants in a specific location")
.param(
ParamBuilder::new("location")
.type_of("string")
.description("The city or neighborhood to search in"),
)
.param(
ParamBuilder::new("cuisine")
.type_of("string")
.description("Type of cuisine, e.g., 'Italian', 'Japanese', etc."),
)
.required(vec!["location".to_string()]),
)
.build()?;
Ok(llm)
}
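
/// Single-turn scenario: one user query, one round of tool execution, one final answer.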
async fn run_simple_scenario(llm: &dyn LLMProvider) -> Result<(), Box<dyn Error>> {
println!("SCENARIO: Simple Tool Calling");
println!("This demonstrates a single query that triggers tool use\n");
let messages = vec![ChatMessage::user()
.content("What's the current weather in Tokyo?")
.build()];
println!("User: What's the current weather in Tokyo?\n");
println!("Sending request to model...");
let response = llm.chat_with_tools(&messages, llm.tools()).await?;
    if let Some(tool_calls) = response.tool_calls() {
        println!("\nModel requested {} tool call(s)", tool_calls.len());
        // Execute every requested call first, collecting one result per call,
        // so the follow-up message carries a complete set of tool results.
        let mut tool_results = Vec::new();
        for call in &tool_calls {
            println!("Tool call: {}", call.function.name);
            println!("Arguments: {}\n", call.function.arguments);
            let result = process_tool_call(call)?;
            println!(
                "Tool response: {}\n",
                serde_json::to_string_pretty(&result)?
            );
            tool_results.push(ToolCall {
                id: call.id.clone(),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: call.function.name.clone(),
                    arguments: serde_json::to_string(&result)?,
                },
            });
        }
        // Replay the assistant's tool-use turn, then hand back all the results.
        let mut follow_up = messages.clone();
        follow_up.push(
            ChatMessage::assistant()
                .tool_use(tool_calls.clone())
                .content("")
                .build(),
        );
        follow_up.push(
            ChatMessage::user()
                .tool_result(tool_results)
                .content("")
                .build(),
        );
        println!("Getting final response with tool results...");
        let final_response = llm.chat_with_tools(&follow_up, llm.tools()).await?;
        println!("\nFinal response: {final_response}");
} else {
println!("\nModel provided a direct response (no tools used):\n{response}");
}
Ok(())
}
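
/// Multi-turn scenario: carries a growing conversation across four user turns,
/// resolving any tool calls the model makes on each turn.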
async fn run_multi_turn_scenario(llm: &dyn LLMProvider) -> Result<(), Box<dyn Error>> {
println!("SCENARIO: Multi-turn Conversation with Tool Calling");
println!("This demonstrates maintaining context across multiple turns with tool use\n");
let mut conversation = Vec::new();
let user_query = "I'm planning a trip to Tokyo. What's the weather like there?";
println!("User: {user_query}\n");
await_tool_response(llm, &mut conversation, user_query).await?;
let user_query = "What time is it there right now?";
println!("\nUser: {user_query}\n");
await_tool_response(llm, &mut conversation, user_query).await?;
let user_query = "Can you recommend some good sushi restaurants in Tokyo?";
println!("\nUser: {user_query}\n");
await_tool_response(llm, &mut conversation, user_query).await?;
let user_query =
"Based on the weather and time, when would be a good time to visit those restaurants?";
println!("\nUser: {user_query}\n");
await_tool_response(llm, &mut conversation, user_query).await?;
Ok(())
}
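
/// Tool-choice scenario: runs the same query under Auto, Any, a forced tool, and None.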
async fn run_tool_choice_scenario(llm: &dyn LLMProvider) -> Result<(), Box<dyn Error>> {
println!("SCENARIO: Tool Choice Options");
println!("This demonstrates controlling how the model uses tools\n");
let query = "What's the weather like in Tokyo and what time is it there?";
test_tool_choice(llm, ToolChoice::Auto, query).await?;
test_tool_choice(llm, ToolChoice::Any, query).await?;
test_tool_choice(llm, ToolChoice::Tool("get_weather".to_string()), query).await?;
test_tool_choice(llm, ToolChoice::None, query).await?;
Ok(())
}
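
/// Rebuilds an OpenAI provider with the given `ToolChoice` and reports which
/// tools, if any, the model decides to call.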
async fn test_tool_choice(
llm: &dyn LLMProvider,
tool_choice: ToolChoice,
query: &str,
) -> Result<(), Box<dyn Error>> {
println!("\n--- Testing {tool_choice:?} ---");
let mut builder = LLMBuilder::new();
    // Re-register the provider's tools by name and description only; parameter
    // schemas are omitted to keep this demo short.
    if let Some(tools) = llm.tools() {
        for tool in tools {
            builder = builder.function(
                FunctionBuilder::new(&tool.function.name)
                    .description(&tool.function.description),
            );
        }
    }
    // The concrete backend cannot be recovered from a `&dyn LLMProvider`, so this
    // scenario always rebuilds against OpenAI to exercise the tool-choice modes.
    let custom_llm = builder
        .backend(LLMBackend::OpenAI)
        .api_key(env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY environment variable not set"))
        .model("gpt-4o-mini")
        .max_tokens(1024)
        .tool_choice(tool_choice)
        .build()?;
let messages = vec![ChatMessage::user().content(query).build()];
println!("User: {query}\n");
let response = custom_llm
.chat_with_tools(&messages, custom_llm.tools())
.await?;
if let Some(tool_calls) = response.tool_calls() {
println!("Tools called:");
for call in tool_calls {
println!("- {}", call.function.name);
}
} else {
println!("No tools called");
}
println!("\nResponse: {response}");
Ok(())
}
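
/// Sends one user turn, executes any tool calls the model makes, and appends
/// every intermediate message to `conversation` so later turns keep full context.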
async fn await_tool_response(
llm: &dyn LLMProvider,
conversation: &mut Vec<ChatMessage>,
user_query: &str,
) -> Result<(), Box<dyn Error>> {
conversation.push(ChatMessage::user().content(user_query).build());
let response = llm.chat_with_tools(conversation, llm.tools()).await?;
if let Some(tool_calls) = response.tool_calls() {
println!("Model is using tools: {}", tool_calls.len());
conversation.push(
ChatMessage::assistant()
.tool_use(tool_calls.clone())
.content(response.text().unwrap_or_default())
.build(),
);
let mut tool_results = Vec::new();
for call in &tool_calls {
println!("Tool call: {}", call.function.name);
println!("Arguments: {}", call.function.arguments);
let result = process_tool_call(call)?;
println!("Tool response: {}", serde_json::to_string_pretty(&result)?);
tool_results.push(ToolCall {
id: call.id.clone(),
call_type: "function".to_string(),
function: FunctionCall {
name: call.function.name.clone(),
arguments: serde_json::to_string(&result)?,
},
});
}
conversation.push(ChatMessage::user().tool_result(tool_results).build());
println!("Getting final response with tool results...");
let final_response = llm.chat_with_tools(conversation, llm.tools()).await?;
println!("\nAssistant: {}", final_response.text().unwrap_or_default());
conversation.push(
ChatMessage::assistant()
.content(final_response.text().unwrap_or_default())
.build(),
);
} else {
let response_text = response.text().unwrap_or_default();
println!("\nAssistant: {response_text}");
conversation.push(ChatMessage::assistant().content(response_text).build());
}
Ok(())
}
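
/// Dispatches a tool call to a mock implementation; all values returned here are
/// canned demo data, not live weather, time, or restaurant results.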
fn process_tool_call(tool_call: &ToolCall) -> Result<Value, Box<dyn Error>> {
let args: Value = serde_json::from_str(&tool_call.function.arguments)?;
match tool_call.function.name.as_str() {
"get_weather" => {
let location = args["location"].as_str().unwrap_or("unknown location");
Ok(json!({
"location": location,
"temperature": 22,
"units": "celsius",
"conditions": "Partly cloudy",
"humidity": "65%",
"forecast": "Clear skies expected later today"
}))
}
"get_current_time" => {
let timezone = args["timezone"].as_str().unwrap_or("UTC");
Ok(json!({
"timezone": timezone,
"current_time": "14:30",
"date": "April 2, 2025",
"day_of_week": "Wednesday"
}))
}
"search_restaurants" => {
let location = args["location"].as_str().unwrap_or("unknown location");
let cuisine = args["cuisine"].as_str().unwrap_or("any");
Ok(json!({
"location": location,
"cuisine": cuisine,
"restaurants": [
{
"name": "Sushi Dai",
"rating": 4.8,
"price_range": "$$$",
"specialty": "Omakase"
},
{
"name": "Tsukiji Sushisay",
"rating": 4.6,
"price_range": "$$",
"specialty": "Market-fresh sushi"
},
{
"name": "Sukiyabashi Jiro",
"rating": 4.9,
"price_range": "$$$$",
"specialty": "Premium sushi experience"
}
]
}))
}
_ => Ok(json!({
"error": "Unknown function",
"function": tool_call.function.name
})),
}
}