1use std::collections::HashMap;
17
18use serde::Serialize;
19use validator::*;
20
21use super::model_validate::validate_json_schema_value;
22use crate::tool::web_search::request::{ContentSize, SearchEngine, SearchRecencyFilter};
23
24#[derive(Debug, Clone, Serialize)]
54pub struct ThinkingType {
55 #[serde(rename = "type")]
57 pub mode: ThinkingMode,
58
59 #[serde(skip_serializing_if = "Option::is_none")]
65 pub clear_thinking: Option<bool>,
66}
67
68#[derive(Debug, Clone, Serialize)]
70#[serde(rename_all = "lowercase")]
71pub enum ThinkingMode {
72 Enabled,
73 Disabled,
74}
75
76impl ThinkingType {
77 pub fn enabled() -> Self {
79 Self {
80 mode: ThinkingMode::Enabled,
81 clear_thinking: None,
82 }
83 }
84
85 pub fn disabled() -> Self {
87 Self {
88 mode: ThinkingMode::Disabled,
89 clear_thinking: None,
90 }
91 }
92
93 pub fn with_clear_thinking(mut self, clear: bool) -> Self {
98 self.clear_thinking = Some(clear);
99 self
100 }
101}
102
103#[derive(Debug, Clone, Serialize)]
144#[serde(tag = "type")]
145#[serde(rename_all = "snake_case")]
146pub enum Tools {
147 Function { function: Function },
153
154 Retrieval { retrieval: Retrieval },
159
160 WebSearch { web_search: WebSearch },
165
166 #[serde(rename = "mcp")]
171 MCP { mcp: MCP },
172}
173
174#[derive(Debug, Clone, Serialize, Validate)]
184pub struct Function {
185 #[validate(length(min = 1, max = 64))]
187 pub name: String,
188
189 pub description: String,
191
192 #[serde(skip_serializing_if = "Option::is_none")]
196 #[validate(custom(function = "validate_json_schema_value"))]
197 pub parameters: Option<serde_json::Value>,
198}
199
200impl Function {
201 pub fn new(
223 name: impl Into<String>,
224 description: impl Into<String>,
225 parameters: serde_json::Value,
226 ) -> Self {
227 Self {
228 name: name.into(),
229 description: description.into(),
230 parameters: Some(parameters),
231 }
232 }
233}
234
235#[derive(Debug, Clone, Serialize)]
240pub struct Retrieval {
241 knowledge_id: String,
242 #[serde(skip_serializing_if = "Option::is_none")]
243 prompt_template: Option<String>,
244}
245
246impl Retrieval {
247 pub fn new(knowledge_id: impl Into<String>, prompt_template: Option<String>) -> Self {
249 Self {
250 knowledge_id: knowledge_id.into(),
251 prompt_template,
252 }
253 }
254}
255
256#[derive(Debug, Clone, Serialize, PartialEq)]
260#[serde(rename_all = "snake_case")]
261pub enum ResultSequence {
262 Before,
263 After,
264}
265
266#[derive(Debug, Clone, Serialize, Validate)]
269pub struct WebSearch {
270 pub search_engine: SearchEngine,
273
274 #[serde(skip_serializing_if = "Option::is_none")]
276 pub enable: Option<bool>,
277
278 #[serde(skip_serializing_if = "Option::is_none")]
280 pub search_query: Option<String>,
281
282 #[serde(skip_serializing_if = "Option::is_none")]
285 pub search_intent: Option<bool>,
286
287 #[serde(skip_serializing_if = "Option::is_none")]
289 #[validate(range(min = 1, max = 50))]
290 pub count: Option<u32>,
291
292 #[serde(skip_serializing_if = "Option::is_none")]
294 pub search_domain_filter: Option<String>,
295
296 #[serde(skip_serializing_if = "Option::is_none")]
298 pub search_recency_filter: Option<SearchRecencyFilter>,
299
300 #[serde(skip_serializing_if = "Option::is_none")]
302 pub content_size: Option<ContentSize>,
303
304 #[serde(skip_serializing_if = "Option::is_none")]
306 pub result_sequence: Option<ResultSequence>,
307
308 #[serde(skip_serializing_if = "Option::is_none")]
310 pub search_result: Option<bool>,
311
312 #[serde(skip_serializing_if = "Option::is_none")]
314 pub require_search: Option<bool>,
315
316 #[serde(skip_serializing_if = "Option::is_none")]
318 pub search_prompt: Option<String>,
319}
320
321impl WebSearch {
322 pub fn new(search_engine: SearchEngine) -> Self {
325 Self {
326 search_engine,
327 enable: None,
328 search_query: None,
329 search_intent: None,
330 count: None,
331 search_domain_filter: None,
332 search_recency_filter: None,
333 content_size: None,
334 result_sequence: None,
335 search_result: None,
336 require_search: None,
337 search_prompt: None,
338 }
339 }
340
341 pub fn with_enable(mut self, enable: bool) -> Self {
343 self.enable = Some(enable);
344 self
345 }
346 pub fn with_search_query(mut self, query: impl Into<String>) -> Self {
348 self.search_query = Some(query.into());
349 self
350 }
351 pub fn with_search_intent(mut self, search_intent: bool) -> Self {
353 self.search_intent = Some(search_intent);
354 self
355 }
356 pub fn with_count(mut self, count: u32) -> Self {
358 self.count = Some(count);
359 self
360 }
361 pub fn with_search_domain_filter(mut self, domain: impl Into<String>) -> Self {
363 self.search_domain_filter = Some(domain.into());
364 self
365 }
366 pub fn with_search_recency_filter(mut self, filter: SearchRecencyFilter) -> Self {
368 self.search_recency_filter = Some(filter);
369 self
370 }
371 pub fn with_content_size(mut self, size: ContentSize) -> Self {
373 self.content_size = Some(size);
374 self
375 }
376 pub fn with_result_sequence(mut self, seq: ResultSequence) -> Self {
378 self.result_sequence = Some(seq);
379 self
380 }
381 pub fn with_search_result(mut self, enable: bool) -> Self {
383 self.search_result = Some(enable);
384 self
385 }
386 pub fn with_require_search(mut self, require: bool) -> Self {
388 self.require_search = Some(require);
389 self
390 }
391 pub fn with_search_prompt(mut self, prompt: impl Into<String>) -> Self {
393 self.search_prompt = Some(prompt.into());
394 self
395 }
396}
397#[derive(Debug, Clone, Serialize, Validate)]
401pub struct MCP {
402 #[validate(length(min = 1))]
405 pub server_label: String,
406
407 #[serde(skip_serializing_if = "Option::is_none")]
409 #[validate(url)]
410 pub server_url: Option<String>,
411
412 #[serde(skip_serializing_if = "Option::is_none")]
414 pub transport_type: Option<MCPTransportType>,
415
416 #[serde(skip_serializing_if = "Vec::is_empty")]
418 pub allowed_tools: Vec<String>,
419
420 #[serde(skip_serializing_if = "Option::is_none")]
422 pub headers: Option<HashMap<String, String>>,
423}
424
425impl MCP {
426 pub fn new(server_label: impl Into<String>) -> Self {
429 Self {
430 server_label: server_label.into(),
431 server_url: None,
432 transport_type: Some(MCPTransportType::StreamableHttp),
433 allowed_tools: Vec::new(),
434 headers: None,
435 }
436 }
437
438 pub fn with_server_url(mut self, url: impl Into<String>) -> Self {
440 self.server_url = Some(url.into());
441 self
442 }
443 pub fn with_transport_type(mut self, transport: MCPTransportType) -> Self {
445 self.transport_type = Some(transport);
446 self
447 }
448 pub fn with_allowed_tools(mut self, tools: impl Into<Vec<String>>) -> Self {
450 self.allowed_tools = tools.into();
451 self
452 }
453 pub fn add_allowed_tool(mut self, tool: impl Into<String>) -> Self {
455 self.allowed_tools.push(tool.into());
456 self
457 }
458 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
460 self.headers = Some(headers);
461 self
462 }
463 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
465 let mut map = self.headers.unwrap_or_default();
466 map.insert(key.into(), value.into());
467 self.headers = Some(map);
468 self
469 }
470}
471
472#[derive(Debug, Clone, Serialize, PartialEq)]
474#[serde(rename_all = "kebab-case")]
475pub enum MCPTransportType {
476 Sse,
477 StreamableHttp,
478}
479
480#[derive(Debug, Clone, Copy, Serialize)]
490#[serde(rename_all = "snake_case")]
491#[serde(tag = "type")]
492pub enum ResponseFormat {
493 Text,
495 JsonObject,
497}
498
#[cfg(test)]
mod tests {
    //! Unit tests for the request-side tool, thinking and response-format
    //! types: builder behaviour, serde output shape and `validator` rules.

    use super::*;

    // ---- ThinkingType ----

    #[test]
    fn test_thinking_type_enabled_serialization() {
        let thinking = ThinkingType::enabled();
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"enabled\""));
        assert!(!json.contains("clear_thinking"));
    }

    #[test]
    fn test_thinking_type_disabled_serialization() {
        let thinking = ThinkingType::disabled();
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"disabled\""));
        assert!(!json.contains("clear_thinking"));
    }

    #[test]
    fn test_thinking_type_with_clear_thinking_serialization() {
        let thinking = ThinkingType::enabled().with_clear_thinking(false);
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"enabled\""));
        assert!(json.contains("\"clear_thinking\":false"));
    }

    #[test]
    fn test_thinking_type_disabled_with_clear_thinking() {
        let thinking = ThinkingType::disabled().with_clear_thinking(true);
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"disabled\""));
        assert!(json.contains("\"clear_thinking\":true"));
    }

    // ---- Function ----

    #[test]
    fn test_function_new() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        let func = Function::new("test_func", "A test function", params);

        assert_eq!(func.name, "test_func");
        assert_eq!(func.description, "A test function");
        assert!(func.parameters.is_some());
    }

    #[test]
    fn test_function_serialization() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {
                "value": {"type": "number"}
            }
        });
        let func = Function::new("test_func", "A test function", params);
        let json = serde_json::to_string(&func).unwrap();

        assert!(json.contains("\"name\":\"test_func\""));
        assert!(json.contains("\"description\":\"A test function\""));
        assert!(json.contains("\"properties\""));
    }

    #[test]
    fn test_function_validation() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {}
        });
        let func = Function::new("valid_name", "Description", params.clone());

        assert!(func.validate().is_ok());

        // Name length must be within 1..=64.
        let invalid_name = Function::new("", "Description", params.clone());
        assert!(invalid_name.validate().is_err());

        let long_name = Function::new("a".repeat(65), "Description", params);
        assert!(long_name.validate().is_err());
    }

    // ---- Retrieval ----

    #[test]
    fn test_retrieval_new() {
        let retrieval = Retrieval::new("kb_123", Some("template".to_string()));
        assert_eq!(retrieval.knowledge_id, "kb_123");
        assert_eq!(retrieval.prompt_template, Some("template".to_string()));
    }

    #[test]
    fn test_retrieval_new_without_template() {
        let retrieval = Retrieval::new("kb_456", None);
        assert_eq!(retrieval.knowledge_id, "kb_456");
        assert!(retrieval.prompt_template.is_none());
    }

    #[test]
    fn test_retrieval_serialization() {
        let retrieval = Retrieval::new("kb_789", None);
        let json = serde_json::to_string(&retrieval).unwrap();
        assert!(json.contains("\"knowledge_id\":\"kb_789\""));
        assert!(!json.contains("prompt_template"));
    }

    // ---- WebSearch ----

    #[test]
    fn test_web_search_new() {
        let web_search = WebSearch::new(SearchEngine::SearchPro);
        assert_eq!(web_search.search_engine, SearchEngine::SearchPro);
        assert!(web_search.enable.is_none());
    }

    #[test]
    fn test_web_search_with_enable() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_enable(true);
        assert_eq!(web_search.enable, Some(true));
    }

    #[test]
    fn test_web_search_with_search_query() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_query("test query");
        assert_eq!(web_search.search_query, Some("test query".to_string()));
    }

    #[test]
    fn test_web_search_with_search_intent() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_intent(true);
        assert_eq!(web_search.search_intent, Some(true));
    }

    #[test]
    fn test_web_search_with_count() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_count(10);
        assert_eq!(web_search.count, Some(10));
    }

    #[test]
    fn test_web_search_with_search_domain_filter() {
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_domain_filter("example.com");
        assert_eq!(
            web_search.search_domain_filter,
            Some("example.com".to_string())
        );
    }

    #[test]
    fn test_web_search_with_search_recency_filter() {
        let filter = SearchRecencyFilter::OneDay;
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_recency_filter(filter.clone());
        assert_eq!(web_search.search_recency_filter, Some(filter));
    }

    #[test]
    fn test_web_search_with_content_size() {
        let size = ContentSize::Medium;
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_content_size(size.clone());
        assert_eq!(web_search.content_size, Some(size));
    }

    #[test]
    fn test_web_search_with_result_sequence() {
        let seq = ResultSequence::After;
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_result_sequence(seq.clone());
        assert_eq!(web_search.result_sequence, Some(seq));
    }

    #[test]
    fn test_web_search_with_search_result() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_result(true);
        assert_eq!(web_search.search_result, Some(true));
    }

    #[test]
    fn test_web_search_with_require_search() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_require_search(true);
        assert_eq!(web_search.require_search, Some(true));
    }

    #[test]
    fn test_web_search_with_search_prompt() {
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_prompt("custom prompt");
        assert_eq!(web_search.search_prompt, Some("custom prompt".to_string()));
    }

    #[test]
    fn test_web_search_serialization() {
        let web_search = WebSearch::new(SearchEngine::SearchPro)
            .with_enable(true)
            .with_count(5);
        let json = serde_json::to_string(&web_search).unwrap();
        assert!(json.contains("\"search_engine\""));
        assert!(json.contains("\"enable\":true"));
        assert!(json.contains("\"count\":5"));
    }

    #[test]
    fn test_web_search_count_validation() {
        // `count` is only range-checked (1..=50) when it is set.
        assert!(WebSearch::new(SearchEngine::SearchPro).validate().is_ok());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(1)
            .validate()
            .is_ok());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(50)
            .validate()
            .is_ok());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(0)
            .validate()
            .is_err());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(51)
            .validate()
            .is_err());
    }

    // ---- MCP ----

    #[test]
    fn test_mcp_new() {
        let mcp = MCP::new("server_label");
        assert_eq!(mcp.server_label, "server_label");
        assert_eq!(mcp.transport_type, Some(MCPTransportType::StreamableHttp));
        assert!(mcp.allowed_tools.is_empty());
    }

    #[test]
    fn test_mcp_with_server_url() {
        let mcp = MCP::new("server_label").with_server_url("https://example.com");
        assert_eq!(mcp.server_url, Some("https://example.com".to_string()));
    }

    #[test]
    fn test_mcp_with_transport_type() {
        let mcp = MCP::new("server_label").with_transport_type(MCPTransportType::Sse);
        assert_eq!(mcp.transport_type, Some(MCPTransportType::Sse));
    }

    #[test]
    fn test_mcp_with_allowed_tools() {
        let mcp = MCP::new("server_label")
            .with_allowed_tools(vec!["tool1".to_string(), "tool2".to_string()]);
        assert_eq!(mcp.allowed_tools.len(), 2);
        assert!(mcp.allowed_tools.contains(&"tool1".to_string()));
    }

    #[test]
    fn test_mcp_add_allowed_tool() {
        let mcp = MCP::new("server_label")
            .add_allowed_tool("tool1")
            .add_allowed_tool("tool2");
        assert_eq!(mcp.allowed_tools.len(), 2);
    }

    #[test]
    fn test_mcp_with_headers() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer token".to_string());
        let mcp = MCP::new("server_label").with_headers(headers.clone());
        assert_eq!(mcp.headers, Some(headers));
    }

    #[test]
    fn test_mcp_with_header() {
        let mcp = MCP::new("server_label").with_header("Authorization", "Bearer token");
        let headers = mcp.headers.unwrap();
        assert_eq!(
            headers.get("Authorization"),
            Some(&"Bearer token".to_string())
        );
    }

    #[test]
    fn test_mcp_serialization() {
        let mcp = MCP::new("server_label")
            .with_server_url("https://example.com")
            .with_transport_type(MCPTransportType::Sse);
        let json = serde_json::to_string(&mcp).unwrap();
        assert!(json.contains("\"server_label\":\"server_label\""));
        assert!(json.contains("\"server_url\":\"https://example.com\""));
        assert!(json.contains("\"transport_type\":\"sse\""));
        assert!(!json.contains("allowed_tools"));
    }

    #[test]
    fn test_mcp_validation() {
        // `server_label` must be non-empty; `server_url`, when set, must parse as a URL.
        assert!(MCP::new("server_label").validate().is_ok());
        assert!(MCP::new("").validate().is_err());
        assert!(MCP::new("server_label")
            .with_server_url("https://example.com")
            .validate()
            .is_ok());
        assert!(MCP::new("server_label")
            .with_server_url("not a url")
            .validate()
            .is_err());
    }

    #[test]
    fn test_mcp_transport_type_sse_serialization() {
        let transport = MCPTransportType::Sse;
        let json = serde_json::to_string(&transport).unwrap();
        assert!(json.contains("\"sse\""));
    }

    #[test]
    fn test_mcp_transport_type_streamable_http_serialization() {
        let transport = MCPTransportType::StreamableHttp;
        let json = serde_json::to_string(&transport).unwrap();
        assert!(json.contains("\"streamable-http\""));
    }

    // ---- ResponseFormat ----

    #[test]
    fn test_response_format_text_serialization() {
        let format = ResponseFormat::Text;
        let json = serde_json::to_string(&format).unwrap();
        assert!(json.contains("\"type\":\"text\""));
    }

    #[test]
    fn test_response_format_json_object_serialization() {
        let format = ResponseFormat::JsonObject;
        let json = serde_json::to_string(&format).unwrap();
        assert!(json.contains("\"type\":\"json_object\""));
    }

    // ---- Tools ----

    #[test]
    fn test_tools_function_serialization() {
        let func = Function::new("test_func", "test", serde_json::json!({}));
        let tools = Tools::Function { function: func };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"test_func\""));
    }

    #[test]
    fn test_tools_retrieval_serialization() {
        let retrieval = Retrieval::new("kb_123", None);
        let tools = Tools::Retrieval { retrieval };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"retrieval\""));
        assert!(json.contains("\"knowledge_id\":\"kb_123\""));
    }

    #[test]
    fn test_tools_web_search_serialization() {
        let web_search = WebSearch::new(SearchEngine::SearchPro);
        let tools = Tools::WebSearch { web_search };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"web_search\""));
        assert!(json.contains("\"search_engine\""));
    }

    #[test]
    fn test_tools_mcp_serialization() {
        let mcp = MCP::new("server_label");
        let tools = Tools::MCP { mcp };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"mcp\""));
        assert!(json.contains("\"server_label\":\"server_label\""));
    }

    // ---- ResultSequence ----

    #[test]
    fn test_result_sequence_before_serialization() {
        let seq = ResultSequence::Before;
        let json = serde_json::to_string(&seq).unwrap();
        assert!(json.contains("\"before\""));
    }

    #[test]
    fn test_result_sequence_after_serialization() {
        let seq = ResultSequence::After;
        let json = serde_json::to_string(&seq).unwrap();
        assert!(json.contains("\"after\""));
    }
}