1use serde_json::Value;
18use std::collections::HashMap;
19
20use crate::proto::{
22 self, CodeExecution as ProtoCodeExecution, CollectionsSearch as ProtoCollectionsSearch,
23 DocumentSearch as ProtoDocumentSearch, Function as ProtoFunction,
24 FunctionCall as ProtoFunctionCall, Mcp as ProtoMcp, ToolCall as ProtoToolCall, ToolCallStatus,
25 ToolCallType, ToolChoice as ProtoToolChoice, ToolMode, WebSearch as ProtoWebSearch,
26 XSearch as ProtoXSearch,
27};
28
29#[derive(Clone, Debug)]
55pub enum Tool {
56 Function(FunctionTool),
58 WebSearch(WebSearchTool),
60 XSearch(XSearchTool),
62 CodeExecution,
64 CollectionsSearch(CollectionsSearchTool),
66 Mcp(McpTool),
68 DocumentSearch(DocumentSearchTool),
70}
71
72impl Tool {
73 pub fn to_proto(&self) -> proto::Tool {
75 let tool = match self {
76 Tool::Function(f) => proto::tool::Tool::Function(f.to_proto()),
77 Tool::WebSearch(w) => proto::tool::Tool::WebSearch(w.to_proto()),
78 Tool::XSearch(x) => proto::tool::Tool::XSearch(x.to_proto()),
79 Tool::CodeExecution => proto::tool::Tool::CodeExecution(ProtoCodeExecution {}),
80 Tool::CollectionsSearch(c) => proto::tool::Tool::CollectionsSearch(c.to_proto()),
81 Tool::Mcp(m) => proto::tool::Tool::Mcp(m.to_proto()),
82 Tool::DocumentSearch(d) => proto::tool::Tool::DocumentSearch(d.to_proto()),
83 };
84
85 proto::Tool { tool: Some(tool) }
86 }
87}
88
89#[derive(Clone, Debug)]
91pub struct FunctionTool {
92 pub name: String,
94 pub description: String,
96 pub parameters: Value,
98}
99
100impl FunctionTool {
101 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
103 Self {
104 name: name.into(),
105 description: description.into(),
106 parameters: serde_json::json!({
107 "type": "object",
108 "properties": {},
109 }),
110 }
111 }
112
113 pub fn with_parameters(mut self, parameters: Value) -> Self {
115 self.parameters = parameters;
116 self
117 }
118
119 fn to_proto(&self) -> ProtoFunction {
120 ProtoFunction {
121 name: self.name.clone(),
122 description: self.description.clone(),
123 strict: false,
124 parameters: self.parameters.to_string(),
125 }
126 }
127}
128
/// Configuration for the web search tool.
#[derive(Debug, Default, Clone)]
pub struct WebSearchTool {
    /// Domains forwarded in the proto `excluded_domains` list (empty by default).
    pub excluded_domains: Vec<String>,
    /// Domains forwarded in the proto `allowed_domains` list (empty by default).
    pub allowed_domains: Vec<String>,
    /// Image-understanding flag; `None` leaves the proto field unset.
    pub enable_image_understanding: Option<bool>,
}
139
140impl WebSearchTool {
141 pub fn new() -> Self {
143 Self::default()
144 }
145
146 pub fn with_excluded_domains(mut self, domains: Vec<String>) -> Self {
148 self.excluded_domains = domains;
149 self
150 }
151
152 pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
154 self.allowed_domains = domains;
155 self
156 }
157
158 pub fn with_image_understanding(mut self, enable: bool) -> Self {
160 self.enable_image_understanding = Some(enable);
161 self
162 }
163
164 fn to_proto(&self) -> ProtoWebSearch {
165 ProtoWebSearch {
166 excluded_domains: self.excluded_domains.clone(),
167 allowed_domains: self.allowed_domains.clone(),
168 enable_image_understanding: self.enable_image_understanding,
169 }
170 }
171}
172
173#[derive(Clone, Debug, Default)]
175pub struct XSearchTool {
176 pub from_date: Option<prost_types::Timestamp>,
178 pub to_date: Option<prost_types::Timestamp>,
180 pub allowed_x_handles: Vec<String>,
182 pub excluded_x_handles: Vec<String>,
184 pub enable_image_understanding: Option<bool>,
186 pub enable_video_understanding: Option<bool>,
188}
189
190impl XSearchTool {
191 pub fn new() -> Self {
193 Self::default()
194 }
195
196 pub fn with_date_range(
198 mut self,
199 from: Option<prost_types::Timestamp>,
200 to: Option<prost_types::Timestamp>,
201 ) -> Self {
202 self.from_date = from;
203 self.to_date = to;
204 self
205 }
206
207 pub fn with_allowed_handles(mut self, handles: Vec<String>) -> Self {
209 self.allowed_x_handles = handles;
210 self
211 }
212
213 pub fn with_excluded_handles(mut self, handles: Vec<String>) -> Self {
215 self.excluded_x_handles = handles;
216 self
217 }
218
219 pub fn with_media_understanding(mut self, images: bool, videos: bool) -> Self {
221 self.enable_image_understanding = Some(images);
222 self.enable_video_understanding = Some(videos);
223 self
224 }
225
226 fn to_proto(&self) -> ProtoXSearch {
227 ProtoXSearch {
228 from_date: self.from_date,
229 to_date: self.to_date,
230 allowed_x_handles: self.allowed_x_handles.clone(),
231 excluded_x_handles: self.excluded_x_handles.clone(),
232 enable_image_understanding: self.enable_image_understanding,
233 enable_video_understanding: self.enable_video_understanding,
234 }
235 }
236}
237
/// Configuration for searching a set of collections.
#[derive(Debug, Clone)]
pub struct CollectionsSearchTool {
    /// Collection identifiers forwarded in the proto `collection_ids` list.
    pub collection_ids: Vec<String>,
    /// Optional result limit; `None` leaves the proto field unset.
    pub limit: Option<i32>,
}
246
247impl CollectionsSearchTool {
248 pub fn new(collection_ids: Vec<String>) -> Self {
250 Self {
251 collection_ids,
252 limit: None,
253 }
254 }
255
256 pub fn with_limit(mut self, limit: i32) -> Self {
258 self.limit = Some(limit);
259 self
260 }
261
262 fn to_proto(&self) -> ProtoCollectionsSearch {
263 ProtoCollectionsSearch {
264 collection_ids: self.collection_ids.clone(),
265 limit: self.limit,
266 }
267 }
268}
269
/// Connection settings for a remote MCP server.
///
/// `Debug` is implemented by hand rather than derived. Debug output (e.g. a
/// request logged with `{:?}`) therefore never contains the authorization
/// token or header values; it shows only whether a token is set and which
/// header names are present.
#[derive(Clone)]
pub struct McpTool {
    /// Label forwarded as the proto `server_label` (empty by default).
    pub server_label: String,
    /// Description forwarded as the proto `server_description` (empty by default).
    pub server_description: String,
    /// URL of the MCP server.
    pub server_url: String,
    /// Tool names forwarded in the proto `allowed_tool_names` list.
    pub allowed_tool_names: Vec<String>,
    /// Optional authorization token forwarded as the proto `authorization`.
    pub authorization: Option<String>,
    /// Additional headers forwarded as the proto `extra_headers` map.
    pub extra_headers: HashMap<String, String>,
}

impl std::fmt::Debug for McpTool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a token is configured, never its value.
        let authorization = self.authorization.as_ref().map(|_| "<redacted>");
        // Header values may carry credentials too, so list the (sorted) names only.
        let mut header_names: Vec<&str> = self.extra_headers.keys().map(String::as_str).collect();
        header_names.sort_unstable();

        f.debug_struct("McpTool")
            .field("server_label", &self.server_label)
            .field("server_description", &self.server_description)
            .field("server_url", &self.server_url)
            .field("allowed_tool_names", &self.allowed_tool_names)
            .field("authorization", &authorization)
            .field("extra_headers", &header_names)
            .finish()
    }
}
286
287impl McpTool {
288 pub fn new(server_url: impl Into<String>) -> Self {
290 Self {
291 server_label: String::new(),
292 server_description: String::new(),
293 server_url: server_url.into(),
294 allowed_tool_names: Vec::new(),
295 authorization: None,
296 extra_headers: HashMap::new(),
297 }
298 }
299
300 pub fn with_label(mut self, label: impl Into<String>) -> Self {
302 self.server_label = label.into();
303 self
304 }
305
306 pub fn with_description(mut self, description: impl Into<String>) -> Self {
308 self.server_description = description.into();
309 self
310 }
311
312 pub fn with_allowed_tools(mut self, tools: Vec<String>) -> Self {
314 self.allowed_tool_names = tools;
315 self
316 }
317
318 pub fn with_authorization(mut self, token: impl Into<String>) -> Self {
320 self.authorization = Some(token.into());
321 self
322 }
323
324 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
326 self.extra_headers = headers;
327 self
328 }
329
330 fn to_proto(&self) -> ProtoMcp {
331 ProtoMcp {
332 server_label: self.server_label.clone(),
333 server_description: self.server_description.clone(),
334 server_url: self.server_url.clone(),
335 allowed_tool_names: self.allowed_tool_names.clone(),
336 authorization: self.authorization.clone(),
337 extra_headers: self.extra_headers.clone(),
338 }
339 }
340}
341
/// Configuration for the document search tool.
#[derive(Debug, Default, Clone)]
pub struct DocumentSearchTool {
    /// Optional result limit; `None` (the default) leaves the proto field unset.
    pub limit: Option<i32>,
}
348
349impl DocumentSearchTool {
350 pub fn new() -> Self {
352 Self::default()
353 }
354
355 pub fn with_limit(mut self, limit: i32) -> Self {
357 self.limit = Some(limit);
358 self
359 }
360
361 fn to_proto(&self) -> ProtoDocumentSearch {
362 ProtoDocumentSearch { limit: self.limit }
363 }
364}
365
/// Tool-selection policy, converted to a `ProtoToolChoice` by `ToolChoice::to_proto`.
///
/// Derives `PartialEq`/`Eq`, like `ToolCallKind` and `ToolCallStatusKind`, so
/// choices can be compared directly (e.g. in `assert_eq!`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ToolChoice {
    /// Encoded as the `Mode` choice with `ToolMode::Auto`.
    Auto,
    /// Encoded as the `Mode` choice with `ToolMode::Required`.
    Required,
    /// Encoded as the `FunctionName` choice carrying this function name.
    Function(String),
}
379
380impl ToolChoice {
381 pub fn to_proto(&self) -> ProtoToolChoice {
383 let tool_choice = match self {
384 ToolChoice::Auto => proto::tool_choice::ToolChoice::Mode(ToolMode::Auto as i32),
385 ToolChoice::Required => proto::tool_choice::ToolChoice::Mode(ToolMode::Required as i32),
386 ToolChoice::Function(name) => {
387 proto::tool_choice::ToolChoice::FunctionName(name.clone())
388 }
389 };
390
391 ProtoToolChoice {
392 tool_choice: Some(tool_choice),
393 }
394 }
395}
396
397#[derive(Clone, Debug)]
402pub struct ToolCall {
403 pub id: String,
405 pub call_type: ToolCallKind,
407 pub status: ToolCallStatusKind,
409 pub error_message: Option<String>,
411 pub function: FunctionCall,
413}
414
415impl ToolCall {
416 pub fn from_proto(proto: ProtoToolCall) -> Option<Self> {
418 let function = match proto.tool? {
419 proto::tool_call::Tool::Function(f) => FunctionCall {
420 name: f.name,
421 arguments: f.arguments,
422 },
423 };
424
425 Some(Self {
426 id: proto.id,
427 call_type: ToolCallKind::from_proto(proto.r#type),
428 status: ToolCallStatusKind::from_proto(proto.status),
429 error_message: proto.error_message,
430 function,
431 })
432 }
433
434 pub fn to_proto(&self) -> ProtoToolCall {
436 ProtoToolCall {
437 id: self.id.clone(),
438 r#type: self.call_type.to_proto() as i32,
439 status: self.status.to_proto() as i32,
440 error_message: self.error_message.clone(),
441 tool: Some(proto::tool_call::Tool::Function(ProtoFunctionCall {
442 name: self.function.name.clone(),
443 arguments: self.function.arguments.clone(),
444 })),
445 }
446 }
447}
448
/// The kind of tool behind a [`ToolCall`], mirroring the proto `ToolCallType`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallKind {
    /// Maps to `ToolCallType::ClientSideTool`.
    ClientSideTool,
    /// Maps to `ToolCallType::WebSearchTool`.
    WebSearchTool,
    /// Maps to `ToolCallType::XSearchTool`.
    XSearchTool,
    /// Maps to `ToolCallType::CodeExecutionTool`.
    CodeExecutionTool,
    /// Maps to `ToolCallType::CollectionsSearchTool`.
    CollectionsSearchTool,
    /// Maps to `ToolCallType::McpTool`.
    McpTool,
    /// Any value not listed above; encoded back as `ToolCallType::Invalid`.
    Unknown,
}
467
468impl ToolCallKind {
469 fn from_proto(value: i32) -> Self {
470 match value {
471 x if x == ToolCallType::ClientSideTool as i32 => ToolCallKind::ClientSideTool,
472 x if x == ToolCallType::WebSearchTool as i32 => ToolCallKind::WebSearchTool,
473 x if x == ToolCallType::XSearchTool as i32 => ToolCallKind::XSearchTool,
474 x if x == ToolCallType::CodeExecutionTool as i32 => ToolCallKind::CodeExecutionTool,
475 x if x == ToolCallType::CollectionsSearchTool as i32 => {
476 ToolCallKind::CollectionsSearchTool
477 }
478 x if x == ToolCallType::McpTool as i32 => ToolCallKind::McpTool,
479 _ => ToolCallKind::Unknown,
480 }
481 }
482
483 fn to_proto(&self) -> ToolCallType {
484 match self {
485 ToolCallKind::ClientSideTool => ToolCallType::ClientSideTool,
486 ToolCallKind::WebSearchTool => ToolCallType::WebSearchTool,
487 ToolCallKind::XSearchTool => ToolCallType::XSearchTool,
488 ToolCallKind::CodeExecutionTool => ToolCallType::CodeExecutionTool,
489 ToolCallKind::CollectionsSearchTool => ToolCallType::CollectionsSearchTool,
490 ToolCallKind::McpTool => ToolCallType::McpTool,
491 ToolCallKind::Unknown => ToolCallType::Invalid,
492 }
493 }
494}
495
/// Status of a [`ToolCall`], mirroring the proto `ToolCallStatus`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallStatusKind {
    /// Maps to `ToolCallStatus::InProgress`; also the decode fallback for
    /// unrecognized values.
    InProgress,
    /// Maps to `ToolCallStatus::Completed`.
    Completed,
    /// Maps to `ToolCallStatus::Incomplete`.
    Incomplete,
    /// Maps to `ToolCallStatus::Failed`.
    Failed,
}
508
509impl ToolCallStatusKind {
510 fn from_proto(value: i32) -> Self {
511 match value {
512 x if x == ToolCallStatus::InProgress as i32 => ToolCallStatusKind::InProgress,
513 x if x == ToolCallStatus::Completed as i32 => ToolCallStatusKind::Completed,
514 x if x == ToolCallStatus::Incomplete as i32 => ToolCallStatusKind::Incomplete,
515 x if x == ToolCallStatus::Failed as i32 => ToolCallStatusKind::Failed,
516 _ => ToolCallStatusKind::InProgress, }
518 }
519
520 fn to_proto(&self) -> ToolCallStatus {
521 match self {
522 ToolCallStatusKind::InProgress => ToolCallStatus::InProgress,
523 ToolCallStatusKind::Completed => ToolCallStatus::Completed,
524 ToolCallStatusKind::Incomplete => ToolCallStatus::Incomplete,
525 ToolCallStatusKind::Failed => ToolCallStatus::Failed,
526 }
527 }
528}
529
/// A function invocation carried by a [`ToolCall`].
///
/// Derives `PartialEq`/`Eq`, like the other value types here
/// (`ToolCallKind`, `ToolCallStatusKind`), so calls can be compared directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FunctionCall {
    /// Name of the function being invoked.
    pub name: String,
    /// Raw argument string, decoded as JSON by `FunctionCall::parse_arguments`
    /// and `FunctionCall::arguments_json`.
    pub arguments: String,
}
538
539impl FunctionCall {
540 pub fn parse_arguments<T: serde::de::DeserializeOwned>(&self) -> serde_json::Result<T> {
542 serde_json::from_str(&self.arguments)
543 }
544
545 pub fn arguments_json(&self) -> serde_json::Result<Value> {
547 serde_json::from_str(&self.arguments)
548 }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A fully populated client-side call used by the conversion tests.
    fn sample_tool_call() -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            call_type: ToolCallKind::ClientSideTool,
            status: ToolCallStatusKind::Completed,
            error_message: Some("upstream timeout".to_string()),
            function: FunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"location":"Paris"}"#.to_string(),
            },
        }
    }

    #[test]
    fn test_function_tool_creation() {
        let tool = FunctionTool::new("get_weather", "Get current weather");
        assert_eq!(tool.name, "get_weather");
        assert_eq!(tool.description, "Get current weather");
        // Without `with_parameters`, the schema is an empty object schema.
        assert_eq!(tool.parameters, json!({"type": "object", "properties": {}}));
    }

    #[test]
    fn test_function_tool_with_parameters() {
        let params = json!({
            "type": "object",
            "properties": {
                "location": {"type": "string"}
            }
        });

        let tool = FunctionTool::new("get_weather", "Get weather").with_parameters(params.clone());

        assert_eq!(tool.parameters, params);
    }

    #[test]
    fn test_function_tool_to_proto() {
        let params = json!({
            "type": "object",
            "properties": {"location": {"type": "string"}}
        });
        let wire = FunctionTool::new("get_weather", "Get weather")
            .with_parameters(params.clone())
            .to_proto();

        assert_eq!(wire.name, "get_weather");
        assert_eq!(wire.description, "Get weather");
        assert!(!wire.strict);
        // Parameters travel as a JSON string; parse back to compare structurally.
        let decoded: Value = serde_json::from_str(&wire.parameters).unwrap();
        assert_eq!(decoded, params);
    }

    #[test]
    fn test_web_search_tool() {
        let tool = WebSearchTool::new()
            .with_excluded_domains(vec!["spam.com".to_string()])
            .with_image_understanding(true);
        assert_eq!(tool.excluded_domains, vec!["spam.com".to_string()]);
        assert!(tool.allowed_domains.is_empty());
        assert_eq!(tool.enable_image_understanding, Some(true));

        let wire = tool.to_proto();
        assert_eq!(wire.excluded_domains, vec!["spam.com".to_string()]);
        assert!(wire.allowed_domains.is_empty());
        assert_eq!(wire.enable_image_understanding, Some(true));
    }

    #[test]
    fn test_x_search_tool() {
        let tool = XSearchTool::new()
            .with_allowed_handles(vec!["@rustlang".to_string()])
            .with_media_understanding(true, false);
        assert_eq!(tool.allowed_x_handles, vec!["@rustlang".to_string()]);
        assert_eq!(tool.enable_image_understanding, Some(true));
        assert_eq!(tool.enable_video_understanding, Some(false));
        assert!(tool.from_date.is_none() && tool.to_date.is_none());

        let wire = tool.to_proto();
        assert_eq!(wire.allowed_x_handles, vec!["@rustlang".to_string()]);
        assert_eq!(wire.enable_video_understanding, Some(false));
    }

    #[test]
    fn test_tool_choice_auto() {
        let choice = ToolChoice::Auto.to_proto();
        assert_eq!(
            choice.tool_choice,
            Some(proto::tool_choice::ToolChoice::Mode(ToolMode::Auto as i32))
        );
    }

    #[test]
    fn test_tool_choice_required() {
        let choice = ToolChoice::Required.to_proto();
        assert_eq!(
            choice.tool_choice,
            Some(proto::tool_choice::ToolChoice::Mode(ToolMode::Required as i32))
        );
    }

    #[test]
    fn test_tool_choice_function() {
        let choice = ToolChoice::Function("my_function".to_string()).to_proto();
        assert_eq!(
            choice.tool_choice,
            Some(proto::tool_choice::ToolChoice::FunctionName(
                "my_function".to_string()
            ))
        );
    }

    #[test]
    fn test_tool_to_proto_sets_oneof() {
        let code = Tool::CodeExecution.to_proto();
        assert!(matches!(code.tool, Some(proto::tool::Tool::CodeExecution(_))));

        let function = Tool::Function(FunctionTool::new("f", "d")).to_proto();
        assert!(matches!(
            function.tool,
            Some(proto::tool::Tool::Function(ref f)) if f.name == "f"
        ));
    }

    #[test]
    fn test_tool_call_round_trip() {
        let call = sample_tool_call();
        let decoded = ToolCall::from_proto(call.to_proto()).expect("tool oneof is set");

        assert_eq!(decoded.id, call.id);
        assert_eq!(decoded.call_type, call.call_type);
        assert_eq!(decoded.status, call.status);
        assert_eq!(decoded.error_message, call.error_message);
        assert_eq!(decoded.function.name, call.function.name);
        assert_eq!(decoded.function.arguments, call.function.arguments);
    }

    #[test]
    fn test_tool_call_from_proto_without_tool() {
        let mut wire = sample_tool_call().to_proto();
        wire.tool = None;
        assert!(ToolCall::from_proto(wire).is_none());
    }

    #[test]
    fn test_tool_call_kind_round_trip() {
        let kinds = [
            ToolCallKind::ClientSideTool,
            ToolCallKind::WebSearchTool,
            ToolCallKind::XSearchTool,
            ToolCallKind::CodeExecutionTool,
            ToolCallKind::CollectionsSearchTool,
            ToolCallKind::McpTool,
            ToolCallKind::Unknown,
        ];
        for kind in &kinds {
            assert_eq!(&ToolCallKind::from_proto(kind.to_proto() as i32), kind);
        }
        // Values outside the known set decode as `Unknown`.
        assert_eq!(ToolCallKind::from_proto(i32::MAX), ToolCallKind::Unknown);
    }

    #[test]
    fn test_tool_call_status_round_trip() {
        let statuses = [
            ToolCallStatusKind::InProgress,
            ToolCallStatusKind::Completed,
            ToolCallStatusKind::Incomplete,
            ToolCallStatusKind::Failed,
        ];
        for status in &statuses {
            assert_eq!(&ToolCallStatusKind::from_proto(status.to_proto() as i32), status);
        }
        // Unrecognized values fall back to `InProgress`.
        assert_eq!(
            ToolCallStatusKind::from_proto(i32::MAX),
            ToolCallStatusKind::InProgress
        );
    }

    #[test]
    fn test_function_call_parse_arguments() {
        let call = FunctionCall {
            name: "test_fn".to_string(),
            arguments: r#"{"param": "value"}"#.to_string(),
        };

        let json = call.arguments_json().unwrap();
        assert_eq!(json["param"], "value");

        let typed: HashMap<String, String> = call.parse_arguments().unwrap();
        assert_eq!(typed.get("param").map(String::as_str), Some("value"));
    }

    #[test]
    fn test_function_call_invalid_arguments() {
        let call = FunctionCall {
            name: "test_fn".to_string(),
            arguments: "not json".to_string(),
        };
        assert!(call.arguments_json().is_err());
        assert!(call.parse_arguments::<HashMap<String, String>>().is_err());
    }

    #[test]
    fn test_mcp_tool() {
        let mut headers = HashMap::new();
        headers.insert("X-Trace".to_string(), "1".to_string());

        let tool = McpTool::new("https://example.com/mcp")
            .with_label("My MCP Server")
            .with_authorization("secret")
            .with_headers(headers.clone());
        assert_eq!(tool.server_url, "https://example.com/mcp");
        assert_eq!(tool.server_label, "My MCP Server");
        assert!(tool.server_description.is_empty());

        let wire = tool.to_proto();
        assert_eq!(wire.server_url, "https://example.com/mcp");
        assert_eq!(wire.server_label, "My MCP Server");
        assert_eq!(wire.authorization.as_deref(), Some("secret"));
        assert_eq!(wire.extra_headers, headers);
    }

    #[test]
    fn test_collections_search_tool() {
        let tool = CollectionsSearchTool::new(vec!["coll_1".to_string()]).with_limit(10);

        assert_eq!(tool.collection_ids.len(), 1);
        assert_eq!(tool.limit, Some(10));

        let wire = tool.to_proto();
        assert_eq!(wire.collection_ids, vec!["coll_1".to_string()]);
        assert_eq!(wire.limit, Some(10));
    }

    #[test]
    fn test_document_search_tool() {
        let tool = DocumentSearchTool::new().with_limit(20);

        assert_eq!(tool.limit, Some(20));
        assert_eq!(tool.to_proto().limit, Some(20));
        assert_eq!(DocumentSearchTool::new().to_proto().limit, None);
    }
}