1use schemars::{generate::SchemaSettings, JsonSchema, SchemaGenerator};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use snafu::{ResultExt, Snafu};
5
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
8#[serde(untagged)]
9pub enum Tool {
10 Function {
12 function_declarations: Vec<FunctionDeclaration>,
14 },
15 GoogleSearch {
17 google_search: GoogleSearchConfig,
19 },
20 URLContext {
21 url_context: URLContextConfig,
22 },
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub struct GoogleSearchConfig {}
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub struct URLContextConfig {}
32
33impl Tool {
34 pub fn new(function_declaration: FunctionDeclaration) -> Self {
36 Self::Function {
37 function_declarations: vec![function_declaration],
38 }
39 }
40
41 pub fn with_functions(function_declarations: Vec<FunctionDeclaration>) -> Self {
43 Self::Function {
44 function_declarations,
45 }
46 }
47
48 pub fn google_search() -> Self {
50 Self::GoogleSearch {
51 google_search: GoogleSearchConfig {},
52 }
53 }
54
55 pub fn url_context() -> Self {
57 Self::URLContext {
58 url_context: URLContextConfig {},
59 }
60 }
61}
62
63#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
65#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
66pub enum Behavior {
67 #[default]
70 Blocking,
71 NonBlocking,
75}
76
77#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
79pub struct FunctionDeclaration {
80 pub name: String,
82 pub description: String,
84 #[serde(skip_serializing_if = "Option::is_none")]
86 pub behavior: Option<Behavior>,
87 #[serde(skip_serializing_if = "Option::is_none")]
89 pub(crate) parameters: Option<Value>,
90 #[serde(skip_serializing_if = "Option::is_none")]
94 pub(crate) response: Option<Value>,
95}
96
97fn generate_parameters_schema<Parameters>() -> Value
99where
100 Parameters: JsonSchema + Serialize,
101{
102 let schema_generator = SchemaGenerator::new(SchemaSettings::openapi3().with(|s| {
104 s.inline_subschemas = true;
105 s.meta_schema = None;
106 }));
107
108 let mut schema = schema_generator.into_root_schema_for::<Parameters>();
109
110 schema.remove("title");
112 schema.to_value()
113}
114
115impl FunctionDeclaration {
116 pub fn new(
118 name: impl Into<String>,
119 description: impl Into<String>,
120 behavior: Option<Behavior>,
121 ) -> Self {
122 Self {
123 name: name.into(),
124 description: description.into(),
125 behavior,
126 ..Default::default()
127 }
128 }
129
130 pub fn with_parameters<Parameters>(mut self) -> Self
132 where
133 Parameters: JsonSchema + Serialize,
134 {
135 self.parameters = Some(generate_parameters_schema::<Parameters>());
136 self
137 }
138
139 pub fn with_response<Response>(mut self) -> Self
141 where
142 Response: JsonSchema + Serialize,
143 {
144 self.response = Some(generate_parameters_schema::<Response>());
145 self
146 }
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
151pub struct FunctionCall {
152 pub name: String,
154 pub args: serde_json::Value,
156 #[serde(skip_serializing_if = "Option::is_none")]
158 pub thought_signature: Option<String>,
159}
160
161#[derive(Debug, Snafu)]
162pub enum FunctionCallError {
163 #[snafu(display("failed to deserialize parameter '{key}'"))]
164 Deserialization {
165 source: serde_json::Error,
166 key: String,
167 },
168
169 #[snafu(display("parameter '{key}' is missing in arguments '{args}'"))]
170 MissingParameter {
171 key: String,
172 args: serde_json::Value,
173 },
174
175 #[snafu(display("arguments should be an object; actual: {actual}"))]
176 ArgumentTypeMismatch { actual: String },
177}
178
179impl FunctionCall {
180 pub fn new(name: impl Into<String>, args: serde_json::Value) -> Self {
182 Self {
183 name: name.into(),
184 args,
185 thought_signature: None,
186 }
187 }
188
189 pub fn with_thought_signature(
191 name: impl Into<String>,
192 args: serde_json::Value,
193 thought_signature: impl Into<String>,
194 ) -> Self {
195 Self {
196 name: name.into(),
197 args,
198 thought_signature: Some(thought_signature.into()),
199 }
200 }
201
202 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<T, FunctionCallError> {
204 match &self.args {
205 serde_json::Value::Object(obj) => {
206 if let Some(value) = obj.get(key) {
207 serde_json::from_value(value.clone()).with_context(|_| DeserializationSnafu {
208 key: key.to_string(),
209 })
210 } else {
211 Err(MissingParameterSnafu {
212 key: key.to_string(),
213 args: self.args.clone(),
214 }
215 .build())
216 }
217 }
218 _ => Err(ArgumentTypeMismatchSnafu {
219 actual: self.args.to_string(),
220 }
221 .build()),
222 }
223 }
224}
225
226#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
228pub struct FunctionResponse {
229 pub name: String,
231 #[serde(skip_serializing_if = "Option::is_none")]
234 pub response: Option<serde_json::Value>,
235}
236
237impl FunctionResponse {
238 pub fn new(name: impl Into<String>, response: serde_json::Value) -> Self {
240 Self {
241 name: name.into(),
242 response: Some(response),
243 }
244 }
245
246 pub fn from_schema<Response>(
248 name: impl Into<String>,
249 response: Response,
250 ) -> Result<Self, serde_json::Error>
251 where
252 Response: JsonSchema + Serialize,
253 {
254 let json = serde_json::to_value(&response)?;
255 Ok(Self {
256 name: name.into(),
257 response: Some(json),
258 })
259 }
260
261 pub fn from_str(
263 name: impl Into<String>,
264 response: impl Into<String>,
265 ) -> Result<Self, serde_json::Error> {
266 let json = serde_json::from_str(&response.into())?;
267 Ok(Self {
268 name: name.into(),
269 response: Some(json),
270 })
271 }
272}
273
274#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
276pub struct ToolConfig {
277 #[serde(skip_serializing_if = "Option::is_none")]
279 pub function_calling_config: Option<FunctionCallingConfig>,
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
284pub struct FunctionCallingConfig {
285 pub mode: FunctionCallingMode,
287}
288
289#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
291#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
292pub enum FunctionCallingMode {
293 Auto,
295 Any,
297 None,
299}