Skip to main content

tools/
main.rs

1use std::collections::HashMap;
2
3use anyhow::Result;
4use mistralrs::{
5    Function, IsqType, RequestBuilder, TextMessageRole, TextModelBuilder, Tool, ToolChoice,
6    ToolType,
7};
8use serde_json::{json, Value};
9
10#[derive(serde::Deserialize, Debug, Clone)]
11struct GetWeatherInput {
12    place: String,
13}
14
15fn get_weather(input: GetWeatherInput) -> String {
16    format!("Weather in {}: Temperature: 25C. Wind: calm. Dew point: 10C. Precipitiation: 5cm of rain expected.", input.place)
17}
18
19#[tokio::main]
20async fn main() -> Result<()> {
21    let model = TextModelBuilder::new("meta-llama/Meta-Llama-3.1-8B-Instruct")
22        .with_logging()
23        .with_isq(IsqType::Q8_0)
24        .build()
25        .await?;
26
27    let parameters: HashMap<String, Value> = serde_json::from_value(json!({
28        "type": "object",
29        "properties": {
30            "place": {
31                "type": "string",
32                "description": "The place to get the weather for.",
33            },
34        },
35        "required": ["place"],
36    }))?;
37
38    let tools = vec![Tool {
39        tp: ToolType::Function,
40        function: Function {
41            description: Some("Get the weather for a certain city.".to_string()),
42            name: "get_weather".to_string(),
43            parameters: Some(parameters),
44        },
45    }];
46
47    // We will keep all the messages here
48    let mut messages = RequestBuilder::new()
49        .add_message(TextMessageRole::User, "What is the weather in Boston?")
50        .set_tools(tools)
51        .set_tool_choice(ToolChoice::Auto);
52
53    let response = model.send_chat_request(messages.clone()).await?;
54
55    let message = &response.choices[0].message;
56
57    if let Some(tool_calls) = &message.tool_calls {
58        let called = &tool_calls[0];
59        if called.function.name == "get_weather" {
60            let input: GetWeatherInput = serde_json::from_str(&called.function.arguments)?;
61            println!("Called tool `get_weather` with arguments {input:?}");
62
63            let result = get_weather(input);
64            println!("Output of tool call: {result}");
65
66            // Add tool call message from assistant so it knows what it called
67            // Then, add message from the tool
68            messages = messages
69                .add_message_with_tool_call(
70                    TextMessageRole::Assistant,
71                    String::new(),
72                    vec![called.clone()],
73                )
74                .add_tool_message(result, called.id.clone())
75                .set_tool_choice(ToolChoice::None);
76
77            let response = model.send_chat_request(messages.clone()).await?;
78
79            let message = &response.choices[0].message;
80            println!("Output of model: {:?}", message.content);
81        }
82    }
83
84    Ok(())
85}