erniebot_rs/chat/function.rs

1use schemars::schema::RootSchema;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4
5use super::Role;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
8#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
9/// Definition of the function structure that some models(like Erniebot) can select and call.
10pub struct Function {
11    /// The name of the function.
12    pub name: String,
13    /// The description of the function.
14    pub description: String,
15    /// The format of parameters of the function, following the JSON schema format.
16    pub parameters: RootSchema,
17    /// The format of the response of the function, following the JSON schema format.
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub response: Option<RootSchema>,
20    /// The examples of the function. each instance of the outer vector represents a round of conversation, and each instance of the inner vector represents a message in the round of conversation. More details can be found in the example of chat_with_function.rs.
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub examples: Option<Vec<Vec<Example>>>,
23}
24
25/// Example of a message involved in a function calling process
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
27#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
28pub struct Example {
29    /// Same as the role in Message, can be "user", "assistant", or "function".
30    pub role: Role,
31    /// Dialog content instructions:
32
33    /// (1) If the current message contains a function_call and the role is "assistant", the message can be empty. However, in other scenarios, it cannot be empty.
34
35    /// (2) The content corresponding to the last message cannot be a blank character, including spaces, "\n", "\r", r"\f", etc.
36    pub content: Option<String>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    /// The "author" of the message. the This member is required when the role value is "function", and in this case is should be the name in the function_call in the response content
39    pub name: Option<String>,
40    /// this is function calling result of last round of function call, serving as chat history.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub function_call: Option<FunctionCall>,
43}
44
45/// This is function calling result of last round of function call
46#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
47#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
48pub struct FunctionCall {
49    /// name of a function
50    pub name: String,
51    /// arguments of a function call that LLM model outputs, following the JSON format.
52    pub arguments: String,
53    /// The thinking process of the model
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub thoughts: Option<String>,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
59#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
60/// In the context of function calls, prompt the large model to select a specific function (not mandatory). Note: The specified function name must exist within the list of functions.
61pub struct ToolChoice {
62    pub r#type: String, //only one valid value: "function"
63    pub function: Value,
64}
65
66impl ToolChoice {
67    pub fn new(function: Function) -> Self {
68        Self {
69            r#type: "function".to_string(),
70            function: serde_json::json!(
71                {
72                    "name": function.name,
73                }
74            ),
75        }
76    }
77    pub fn from_function_name(name: String) -> Self {
78        Self {
79            r#type: "function".to_string(),
80            function: serde_json::json!(
81                {
82                    "name": name,
83                }
84            ),
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    use schemars::{schema::RootSchema, schema_for, JsonSchema};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use super::{Function, ToolChoice};

    #[test]
    fn test_schema() {
        #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
        #[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
        struct TestStruct {
            pub date: String,
            pub place: String,
        }

        // A derived schema must describe an object whose two fields are both required.
        let schema = serde_json::to_value(schema_for!(TestStruct)).unwrap();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["date"].is_object());
        assert!(schema["properties"]["place"].is_object());
        assert_eq!(schema["required"], json!(["date", "place"]));

        // The default schema (what `Function::default()` carries) must still serialize
        // to a JSON object.
        let default_schema = serde_json::to_value(RootSchema::default()).unwrap();
        assert!(default_schema.is_object());
    }

    #[test]
    fn test_tool_choice() {
        let function = Function {
            name: "test".to_string(),
            description: "test".to_string(),
            ..Default::default()
        };
        let tool_choice = ToolChoice::new(function);

        // Pin the exact wire format expected by the API.
        assert_eq!(
            serde_json::to_value(&tool_choice).unwrap(),
            json!({ "type": "function", "function": { "name": "test" } })
        );
        // Both constructors must produce the same value for the same name.
        assert_eq!(
            tool_choice,
            ToolChoice::from_function_name("test".to_string())
        );
    }
}