1use schemars::{generate::SchemaSettings, JsonSchema, SchemaGenerator};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use snafu::{ResultExt, Snafu};
5
/// A tool made available to the model.
///
/// The enum is `#[serde(untagged)]`: no variant name is written to JSON. Each variant
/// serializes as an object with a single key (its field). Deserialization tries the
/// variants in declaration order and picks the first whose required key is present, so
/// variant order matters.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum Tool {
    /// Function calling; serializes as `{"function_declarations": [...]}`.
    Function {
        /// The functions the model may call.
        function_declarations: Vec<FunctionDeclaration>,
    },
    /// Google Search tool; serializes as `{"google_search": {}}`.
    GoogleSearch {
        /// Search settings (the config type currently has no fields).
        google_search: GoogleSearchConfig,
    },
    /// URL context tool; serializes as `{"url_context": {}}`.
    URLContext {
        /// URL-context settings (the config type currently has no fields).
        url_context: URLContextConfig,
    },
    /// Google Maps tool; serializes as `{"google_maps": {...}}`.
    GoogleMaps {
        /// Maps settings; see [`GoogleMapsConfig`].
        google_maps: GoogleMapsConfig,
    },
    /// Code execution tool. Unlike the other variants, its key is explicitly renamed to
    /// camelCase, so it serializes as `{"codeExecution": {}}`.
    CodeExecution {
        #[serde(rename = "codeExecution")]
        code_execution: CodeExecutionConfig,
    },
}
34
/// Settings for [`Tool::GoogleSearch`]. It has no fields and serializes as `{}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GoogleSearchConfig {}
38
/// Settings for [`Tool::URLContext`]. It has no fields and serializes as `{}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct URLContextConfig {}
42
/// Settings for [`Tool::GoogleMaps`]. Field names serialize in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct GoogleMapsConfig {
    /// Serialized as `enableWidget` and omitted entirely when `None`.
    /// NOTE(review): presumably asks the service to return a Maps widget — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_widget: Option<bool>,
}
51
/// Settings for [`Tool::CodeExecution`]. It has no fields and serializes as `{}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CodeExecutionConfig {}
55
/// A piece of source code together with its language.
/// NOTE(review): presumably code emitted by the model for the code-execution tool — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ExecutableCode {
    /// Language of `code`.
    pub language: CodeLanguage,
    /// The source text.
    pub code: String,
}
67
/// Language of an [`ExecutableCode`] snippet, serialized in SCREAMING_SNAKE_CASE.
///
/// There is no catch-all variant, so any string other than `"PYTHON"` fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CodeLanguage {
    /// Serialized as `"PYTHON"`.
    Python,
}
75
/// Result of running a piece of code: an outcome status plus the captured output text.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CodeExecutionResult {
    /// How the execution ended.
    pub outcome: CodeExecutionOutcome,
    /// Output produced by the execution.
    /// NOTE(review): whether this is stdout, stderr or both is not visible here — confirm.
    pub output: String,
}
87
/// Status of a code execution, serialized in SCREAMING_SNAKE_CASE.
///
/// There is no catch-all variant, so values other than the three below fail to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CodeExecutionOutcome {
    /// Serialized as `"OUTCOME_OK"`.
    OutcomeOk,
    /// Serialized as `"OUTCOME_FAILED"`.
    OutcomeFailed,
    /// Serialized as `"OUTCOME_DEADLINE_EXCEEDED"`.
    OutcomeDeadlineExceeded,
}
99
100impl Tool {
101 pub fn new(function_declaration: FunctionDeclaration) -> Self {
103 Self::Function {
104 function_declarations: vec![function_declaration],
105 }
106 }
107
108 pub fn with_functions(function_declarations: Vec<FunctionDeclaration>) -> Self {
110 Self::Function {
111 function_declarations,
112 }
113 }
114
115 pub fn google_search() -> Self {
117 Self::GoogleSearch {
118 google_search: GoogleSearchConfig {},
119 }
120 }
121
122 pub fn url_context() -> Self {
124 Self::URLContext {
125 url_context: URLContextConfig {},
126 }
127 }
128
129 pub fn google_maps(enable_widget: Option<bool>) -> Self {
131 Self::GoogleMaps {
132 google_maps: GoogleMapsConfig { enable_widget },
133 }
134 }
135
136 pub fn code_execution() -> Self {
141 Self::CodeExecution {
142 code_execution: CodeExecutionConfig {},
143 }
144 }
145}
146
/// Execution behavior of a declared function, serialized in SCREAMING_SNAKE_CASE.
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Behavior {
    /// Serialized as `"BLOCKING"`; this is the `Default` variant.
    #[default]
    Blocking,
    /// Serialized as `"NON_BLOCKING"`.
    /// NOTE(review): presumably lets the model continue while the call runs — confirm.
    NonBlocking,
}
160
/// Declaration of a function the model may call.
///
/// The optional fields are omitted from the serialized JSON when `None`. The two schema
/// fields are crate-private and are filled in by [`FunctionDeclaration::with_parameters`]
/// and [`FunctionDeclaration::with_response`].
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionDeclaration {
    /// Function name.
    pub name: String,
    /// Human-readable description of what the function does.
    pub description: String,
    /// Optional execution behavior; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub behavior: Option<Behavior>,
    /// JSON schema of the arguments; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) parameters: Option<Value>,
    /// JSON schema of the return value; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) response: Option<Value>,
}
180
181fn generate_parameters_schema<Parameters>() -> Value
183where
184 Parameters: JsonSchema + Serialize,
185{
186 let schema_generator = SchemaGenerator::new(SchemaSettings::openapi3().with(|s| {
188 s.inline_subschemas = true;
189 s.meta_schema = None;
190 }));
191
192 let mut schema = schema_generator.into_root_schema_for::<Parameters>();
193
194 schema.remove("title");
196 schema.to_value()
197}
198
199impl FunctionDeclaration {
200 pub fn new(
202 name: impl Into<String>,
203 description: impl Into<String>,
204 behavior: Option<Behavior>,
205 ) -> Self {
206 Self {
207 name: name.into(),
208 description: description.into(),
209 behavior,
210 ..Default::default()
211 }
212 }
213
214 pub fn with_parameters<Parameters>(mut self) -> Self
216 where
217 Parameters: JsonSchema + Serialize,
218 {
219 self.parameters = Some(generate_parameters_schema::<Parameters>());
220 self
221 }
222
223 pub fn with_response<Response>(mut self) -> Self
225 where
226 Response: JsonSchema + Serialize,
227 {
228 self.response = Some(generate_parameters_schema::<Response>());
229 self
230 }
231}
232
/// A call to a declared function: the function name plus its JSON arguments.
///
/// No `rename_all` is applied, so the JSON keys are `name`, `args` and `thought_signature`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
    /// Name of the function being called.
    pub name: String,
    /// Call arguments. [`FunctionCall::get`] expects this to be a JSON object.
    pub args: serde_json::Value,
    /// Optional opaque signature string; omitted from serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
244
/// Errors returned by [`FunctionCall::get`] when extracting a single argument.
#[derive(Debug, Snafu)]
pub enum FunctionCallError {
    /// The argument exists but could not be deserialized into the requested type.
    #[snafu(display("failed to deserialize parameter '{key}'"))]
    Deserialization {
        /// Underlying serde_json error.
        source: serde_json::Error,
        /// Name of the argument that failed to deserialize.
        key: String,
    },

    /// The arguments object has no entry named `key`.
    #[snafu(display("parameter '{key}' is missing in arguments '{args}'"))]
    MissingParameter {
        /// Name of the missing argument.
        key: String,
        /// A copy of the full arguments value, for diagnostics.
        args: serde_json::Value,
    },

    /// The arguments value was not a JSON object.
    #[snafu(display("arguments should be an object; actual: {actual}"))]
    ArgumentTypeMismatch {
        /// The JSON text of the offending arguments value.
        actual: String,
    },
}
262
263impl FunctionCall {
264 pub fn new(name: impl Into<String>, args: serde_json::Value) -> Self {
266 Self {
267 name: name.into(),
268 args,
269 thought_signature: None,
270 }
271 }
272
273 pub fn with_thought_signature(
275 name: impl Into<String>,
276 args: serde_json::Value,
277 thought_signature: impl Into<String>,
278 ) -> Self {
279 Self {
280 name: name.into(),
281 args,
282 thought_signature: Some(thought_signature.into()),
283 }
284 }
285
286 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<T, FunctionCallError> {
288 match &self.args {
289 serde_json::Value::Object(obj) => {
290 if let Some(value) = obj.get(key) {
291 serde_json::from_value(value.clone()).with_context(|_| DeserializationSnafu {
292 key: key.to_string(),
293 })
294 } else {
295 Err(MissingParameterSnafu {
296 key: key.to_string(),
297 args: self.args.clone(),
298 }
299 .build())
300 }
301 }
302 _ => Err(ArgumentTypeMismatchSnafu {
303 actual: self.args.to_string(),
304 }
305 .build()),
306 }
307 }
308}
309
/// The result of a function call, sent back under the called function's name.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionResponse {
    /// Name of the function that produced this response.
    pub name: String,
    /// Response payload as JSON; omitted from serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response: Option<serde_json::Value>,
}
320
321impl FunctionResponse {
322 pub fn new(name: impl Into<String>, response: serde_json::Value) -> Self {
324 Self {
325 name: name.into(),
326 response: Some(response),
327 }
328 }
329
330 pub fn from_schema<Response>(
332 name: impl Into<String>,
333 response: Response,
334 ) -> Result<Self, serde_json::Error>
335 where
336 Response: JsonSchema + Serialize,
337 {
338 let json = serde_json::to_value(&response)?;
339 Ok(Self {
340 name: name.into(),
341 response: Some(json),
342 })
343 }
344
345 pub fn from_str(
347 name: impl Into<String>,
348 response: impl Into<String>,
349 ) -> Result<Self, serde_json::Error> {
350 let json = serde_json::from_str(&response.into())?;
351 Ok(Self {
352 name: name.into(),
353 response: Some(json),
354 })
355 }
356}
357
/// Request-level configuration shared by all tools.
///
/// No `rename_all` is applied, so keys serialize as `function_calling_config` and
/// `retrieval_config`. Each is omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct ToolConfig {
    /// How function calling is allowed to behave; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_calling_config: Option<FunctionCallingConfig>,
    /// Retrieval settings (e.g. a location); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retrieval_config: Option<RetrievalConfig>,
}
368
/// Function-calling configuration; currently only the calling mode.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCallingConfig {
    /// The function-calling mode.
    pub mode: FunctionCallingMode,
}
375
/// Function-calling mode, serialized as `"AUTO"`, `"ANY"` or `"NONE"`.
///
/// NOTE(review): the exact semantics of each mode are defined by the remote API, not here —
/// presumably model-decides / must-call / never-call; confirm against the API docs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
    /// Serialized as `"AUTO"`.
    Auto,
    /// Serialized as `"ANY"`.
    Any,
    /// Serialized as `"NONE"`.
    None,
}
387
/// Retrieval settings. Field names serialize in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RetrievalConfig {
    /// Serialized as `latLng` and omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lat_lng: Option<LatLng>,
}
396
/// A geographic coordinate pair.
/// NOTE(review): units are presumably decimal degrees; no range is enforced here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LatLng {
    /// Latitude component.
    pub latitude: f64,
    /// Longitude component.
    pub longitude: f64,
}
405
406impl LatLng {
407 pub fn new(latitude: f64, longitude: f64) -> Self {
409 Self {
410 latitude,
411 longitude,
412 }
413 }
414}