erniebot_rs/chat/
function.rs1use schemars::schema::RootSchema;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4
5use super::Role;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
8#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
9pub struct Function {
11 pub name: String,
13 pub description: String,
15 pub parameters: RootSchema,
17 #[serde(skip_serializing_if = "Option::is_none")]
19 pub response: Option<RootSchema>,
20 #[serde(skip_serializing_if = "Option::is_none")]
22 pub examples: Option<Vec<Vec<Example>>>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
27#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
28pub struct Example {
29 pub role: Role,
31 pub content: Option<String>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub name: Option<String>,
40 #[serde(skip_serializing_if = "Option::is_none")]
42 pub function_call: Option<FunctionCall>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
47#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
48pub struct FunctionCall {
49 pub name: String,
51 pub arguments: String,
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub thoughts: Option<String>,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
59#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
60pub struct ToolChoice {
62 pub r#type: String, pub function: Value,
64}
65
66impl ToolChoice {
67 pub fn new(function: Function) -> Self {
68 Self {
69 r#type: "function".to_string(),
70 function: serde_json::json!(
71 {
72 "name": function.name,
73 }
74 ),
75 }
76 }
77 pub fn from_function_name(name: String) -> Self {
78 Self {
79 r#type: "function".to_string(),
80 function: serde_json::json!(
81 {
82 "name": name,
83 }
84 ),
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use schemars::{schema::RootSchema, schema_for, JsonSchema};
    use serde::{Deserialize, Serialize};

    /// The schema generated for a simple struct must describe both fields as
    /// required strings, and a default `RootSchema` must still be a JSON object.
    #[test]
    fn test_schema() {
        #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
        #[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
        struct TestStruct {
            pub date: String,
            pub place: String,
        }

        let schema = serde_json::to_value(schema_for!(TestStruct)).unwrap();
        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["date"]["type"], "string");
        assert_eq!(schema["properties"]["place"]["type"], "string");
        assert_eq!(schema["required"], serde_json::json!(["date", "place"]));

        // `Function::parameters` falls back to this default; it must still
        // serialize to a JSON object rather than something like `null`.
        let default_schema = serde_json::to_value(RootSchema::default()).unwrap();
        assert!(default_schema.is_object());
    }

    /// `ToolChoice::new` must emit the exact wire format and agree with
    /// `ToolChoice::from_function_name`.
    #[test]
    fn test_tool_choice() {
        use super::Function;
        let function = Function {
            name: "test".to_string(),
            description: "test".to_string(),
            ..Default::default()
        };
        let tool_choice = super::ToolChoice::new(function);

        let json = serde_json::to_string(&tool_choice).unwrap();
        assert_eq!(json, r#"{"type":"function","function":{"name":"test"}}"#);

        let by_name = super::ToolChoice::from_function_name("test".to_string());
        assert_eq!(by_name, tool_choice);
    }
}