// use_mistral_tools/use_mistral_tools.rs
use anyhow::Context;
use anyhow::Result;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;

use allms::{
    llm::{
        tools::{LLMTools, MistralCodeInterpreterConfig, MistralWebSearchConfig},
        MistralModels,
    },
    Completions,
};
13
14#[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)]
16struct AINewsArticles {
17 pub articles: Vec<AINewsArticle>,
18}
19
20#[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)]
21struct AINewsArticle {
22 pub title: String,
23 pub url: String,
24 pub description: String,
25}
26
27#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
29pub struct CodeInterpreterResponse {
30 pub problem: String,
31 pub code: String,
32 pub solution: String,
33}
34
35#[tokio::main]
36async fn main() -> Result<()> {
37 env_logger::init();
38
39 let mistral_api_key: String =
40 std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
41
42 let web_search_tool = LLMTools::MistralWebSearch(MistralWebSearchConfig::new());
44 let mistral_responses = Completions::new(
45 MistralModels::MistralMedium3_1,
46 &mistral_api_key,
47 None,
48 None,
49 )
50 .add_tool(web_search_tool);
51
52 match mistral_responses
53 .get_answer::<AINewsArticles>("Find up to 5 most recent news items about Artificial Intelligence, Generative AI, and Large Language Models.
54 For each news item, provide the title, url, and a short description.")
55 .await
56 {
57 Ok(response) => println!("AI news articles:\n{:#?}", response),
58 Err(e) => eprintln!("Error: {:?}", e),
59 }
60
61 let code_interpreter_tool =
63 LLMTools::MistralCodeInterpreter(MistralCodeInterpreterConfig::new());
64 let mistral_responses = Completions::new(
65 MistralModels::MistralMedium3_1,
66 &mistral_api_key,
67 None,
68 None,
69 )
70 .add_tool(code_interpreter_tool);
71
72 match mistral_responses
73 .get_answer::<CodeInterpreterResponse>(
74 "Calculate the mean and standard deviation of [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]",
75 )
76 .await
77 {
78 Ok(response) => println!("Code interpreter response:\n{:#?}", response),
79 Err(e) => eprintln!("Error: {:?}", e),
80 }
81
82 Ok(())
83}