use anyhow::Result;
use mistralrs::{
CalledFunction, IsqBits, ModelBuilder, RequestBuilder, SearchResult, TextMessageRole,
TextMessages, Tool, ToolChoice, ToolType,
};
use std::fs;
use std::sync::Arc;
use walkdir::WalkDir;
/// Recursively scans the current working directory for files whose name contains `query`.
///
/// Each matching regular file becomes a [`SearchResult`]:
/// - `title` is the file name (lossily converted to UTF-8),
/// - `description` and `url` are both the displayed path,
/// - `content` is the file's text, or an empty string when it cannot be read as UTF-8.
///
/// Results are returned in descending order of `title`.
///
/// # Errors
///
/// Returns the first error reported by `WalkDir` while traversing the tree.
fn local_search(query: &str) -> Result<Vec<SearchResult>> {
    let mut hits: Vec<SearchResult> = Vec::new();
    for walked in WalkDir::new(".") {
        let walked = walked?;
        if !walked.file_type().is_file() {
            continue;
        }
        let file_name = walked.file_name().to_string_lossy();
        if !file_name.contains(query) {
            continue;
        }
        let location = walked.path().display().to_string();
        // Unreadable / non-UTF-8 files are still reported, just with empty content.
        let body = fs::read_to_string(walked.path()).unwrap_or_default();
        hits.push(SearchResult {
            title: file_name.into_owned(),
            description: location.clone(),
            url: location,
            content: body,
        });
    }
    // Stable ascending sort by title, then reverse: descending by title, with equal
    // titles ending up in reverse traversal order. Comparing by reference avoids
    // cloning the title on every comparison.
    hits.sort_by(|a, b| a.title.cmp(&b.title));
    hits.reverse();
    Ok(hits)
}
#[tokio::main]
async fn main() -> Result<()> {
let model = ModelBuilder::new("google/gemma-4-E4B-it")
.with_auto_isq(IsqBits::Four)
.with_logging()
.with_tool_callback(
"local_search",
Arc::new(|f: &CalledFunction| {
let args: serde_json::Value = serde_json::from_str(&f.arguments)?;
let query = args["query"].as_str().unwrap_or("");
Ok(serde_json::to_string(&local_search(query)?)?)
}),
)
.build()
.await?;
let parameters = std::collections::HashMap::from([(
"query".to_string(),
serde_json::json!({"type": "string", "description": "Query"}),
)]);
let tool = Tool {
tp: ToolType::Function,
function: mistralrs::Function {
description: Some("Local filesystem search".to_string()),
name: "local_search".to_string(),
parameters: Some(parameters),
},
};
let messages =
TextMessages::new().add_message(TextMessageRole::User, "Where is Cargo.toml in this repo?");
let messages = RequestBuilder::from(messages)
.set_tools(vec![tool])
.set_tool_choice(ToolChoice::Auto);
let response = model.send_chat_request(messages).await?;
println!("{}", response.choices[0].message.content.as_ref().unwrap());
Ok(())
}