Skip to main content

custom_tool_call/
main.rs

1use anyhow::Result;
2use mistralrs::{
3    CalledFunction, IsqType, RequestBuilder, SearchResult, TextMessageRole, TextMessages,
4    TextModelBuilder, Tool, ToolChoice, ToolType,
5};
6use std::fs;
7use std::sync::Arc;
8use walkdir::WalkDir;
9
10fn local_search(query: &str) -> Result<Vec<SearchResult>> {
11    let mut results = Vec::new();
12    for entry in WalkDir::new(".") {
13        let entry = entry?;
14        if entry.file_type().is_file() {
15            let name = entry.file_name().to_string_lossy();
16            if name.contains(query) {
17                let path = entry.path().display().to_string();
18                let content = fs::read_to_string(entry.path()).unwrap_or_default();
19                results.push(SearchResult {
20                    title: name.into_owned(),
21                    description: path.clone(),
22                    url: path,
23                    content,
24                });
25            }
26        }
27    }
28    results.sort_by_key(|r| r.title.clone());
29    results.reverse();
30    Ok(results)
31}
32
33#[tokio::main]
34async fn main() -> Result<()> {
35    // Build the model and register the *tool callback*.
36    let model = TextModelBuilder::new("NousResearch/Hermes-3-Llama-3.1-8B")
37        .with_isq(IsqType::Q4K)
38        .with_logging()
39        .with_tool_callback(
40            "local_search",
41            Arc::new(|f: &CalledFunction| {
42                let args: serde_json::Value = serde_json::from_str(&f.arguments)?;
43                let query = args["query"].as_str().unwrap_or("");
44                Ok(serde_json::to_string(&local_search(query)?)?)
45            }),
46        )
47        .build()
48        .await?;
49
50    // Define the JSON schema for the tool the model can call.
51    let parameters = std::collections::HashMap::from([(
52        "query".to_string(),
53        serde_json::json!({"type": "string", "description": "Query"}),
54    )]);
55    let tool = Tool {
56        tp: ToolType::Function,
57        function: mistralrs::Function {
58            description: Some("Local filesystem search".to_string()),
59            name: "local_search".to_string(),
60            parameters: Some(parameters),
61        },
62    };
63
64    // Ask the user question and allow the model to call the tool automatically.
65    let messages =
66        TextMessages::new().add_message(TextMessageRole::User, "Where is Cargo.toml in this repo?");
67    let messages = RequestBuilder::from(messages)
68        .set_tools(vec![tool])
69        .set_tool_choice(ToolChoice::Auto);
70
71    let response = model.send_chat_request(messages).await?;
72    println!("{}", response.choices[0].message.content.as_ref().unwrap());
73    Ok(())
74}