Skip to main content

mistralai_client/v1/
tool.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::{any::Any, collections::HashMap, fmt::Debug};
4
5// -----------------------------------------------------------------------------
6// Definitions
7
8#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
9pub struct ToolCall {
10    pub function: ToolCallFunction,
11}
12
13#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
14pub struct ToolCallFunction {
15    pub name: String,
16    pub arguments: String,
17}
18
19#[derive(Clone, Debug, Deserialize, Serialize)]
20pub struct Tool {
21    pub r#type: ToolType,
22    pub function: ToolFunction,
23}
24impl Tool {
25    pub fn new(
26        function_name: String,
27        function_description: String,
28        function_parameters: Vec<ToolFunctionParameter>,
29    ) -> Self {
30        let properties: HashMap<String, ToolFunctionParameterProperty> = function_parameters
31            .into_iter()
32            .map(|param| {
33                (
34                    param.name,
35                    ToolFunctionParameterProperty {
36                        r#type: param.r#type,
37                        description: param.description,
38                    },
39                )
40            })
41            .collect();
42        let property_names = properties.keys().cloned().collect();
43
44        let parameters = ToolFunctionParameters {
45            r#type: ToolFunctionParametersType::Object,
46            properties,
47            required: property_names,
48        };
49
50        Self {
51            r#type: ToolType::Function,
52            function: ToolFunction {
53                name: function_name,
54                description: function_description,
55                parameters,
56            },
57        }
58    }
59}
60
61// -----------------------------------------------------------------------------
62// Request
63
64#[derive(Clone, Debug, Deserialize, Serialize)]
65pub struct ToolFunction {
66    name: String,
67    description: String,
68    parameters: ToolFunctionParameters,
69}
70
71#[derive(Clone, Debug, Deserialize, Serialize)]
72pub struct ToolFunctionParameter {
73    name: String,
74    description: String,
75    r#type: ToolFunctionParameterType,
76}
77impl ToolFunctionParameter {
78    pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self {
79        Self {
80            name,
81            r#type,
82            description,
83        }
84    }
85}
86
87#[derive(Clone, Debug, Deserialize, Serialize)]
88pub struct ToolFunctionParameters {
89    r#type: ToolFunctionParametersType,
90    properties: HashMap<String, ToolFunctionParameterProperty>,
91    required: Vec<String>,
92}
93
94#[derive(Clone, Debug, Deserialize, Serialize)]
95pub struct ToolFunctionParameterProperty {
96    r#type: ToolFunctionParameterType,
97    description: String,
98}
99
100#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
101pub enum ToolFunctionParametersType {
102    #[serde(rename = "object")]
103    Object,
104}
105
106#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
107pub enum ToolFunctionParameterType {
108    #[serde(rename = "string")]
109    String,
110}
111
112#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
113pub enum ToolType {
114    #[serde(rename = "function")]
115    Function,
116}
117
118/// An enum representing how functions should be called.
119#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
120pub enum ToolChoice {
121    /// The model is forced to call a function.
122    #[serde(rename = "any")]
123    Any,
124    /// The model can choose to either generate a message or call a function.
125    #[serde(rename = "auto")]
126    Auto,
127    /// The model won't call a function and will generate a message instead.
128    #[serde(rename = "none")]
129    None,
130}
131
132// -----------------------------------------------------------------------------
133// Custom
134
135#[async_trait]
136pub trait Function: Send {
137    async fn execute(&self, arguments: String) -> Box<dyn Any + Send>;
138}
139
140impl Debug for dyn Function {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        write!(f, "Function()")
143    }
144}