use std::collections::HashMap;
use serde::Serialize;
use validator::*;
use super::model_validate::validate_json_schema_value;
use crate::tool::web_search::request::{ContentSize, SearchEngine, SearchRecencyFilter};
/// Thinking-mode setting.
///
/// Serializes as `{"type": "enabled"}` or `{"type": "disabled"}`, plus a
/// `clear_thinking` key only when that flag has been set.
#[derive(Debug, Clone, Serialize)]
pub struct ThinkingType {
    /// On/off switch, written under the JSON key `type`.
    #[serde(rename = "type")]
    pub mode: ThinkingMode,
    /// Optional `clear_thinking` flag; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub clear_thinking: Option<bool>,
}
/// On/off value for [`ThinkingType`], serialized in lowercase
/// (`"enabled"` / `"disabled"`).
///
/// Derives `PartialEq`/`Eq` like the other plain enums in this module
/// (`ResultSequence`, `MCPTransportType`), so callers and tests can compare
/// modes directly instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingMode {
    /// Serialized as `"enabled"`.
    Enabled,
    /// Serialized as `"disabled"`.
    Disabled,
}
impl ThinkingType {
pub fn enabled() -> Self {
Self {
mode: ThinkingMode::Enabled,
clear_thinking: None,
}
}
pub fn disabled() -> Self {
Self {
mode: ThinkingMode::Disabled,
clear_thinking: None,
}
}
pub fn with_clear_thinking(mut self, clear: bool) -> Self {
self.clear_thinking = Some(clear);
self
}
}
/// A tool definition, serialized as an internally tagged object.
///
/// The `type` key carries the snake_case variant name, and the payload sits
/// under a key of the same name, e.g. `{"type": "function", "function": {...}}`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tools {
    /// `{"type": "function", "function": {...}}`
    Function { function: Function },
    /// `{"type": "retrieval", "retrieval": {...}}`
    Retrieval { retrieval: Retrieval },
    /// `{"type": "web_search", "web_search": {...}}`
    WebSearch { web_search: WebSearch },
    /// `{"type": "mcp", "mcp": {...}}`. Renamed explicitly because
    /// `snake_case` would otherwise split the all-caps name into `m_c_p`.
    #[serde(rename = "mcp")]
    MCP { mcp: MCP },
}
/// Function-tool definition: a name, a description, and an optional
/// parameter schema.
#[derive(Debug, Clone, Serialize, Validate)]
pub struct Function {
    /// Function name; the `length` rule requires 1 to 64 characters.
    #[validate(length(min = 1, max = 64))]
    pub name: String,
    /// Free-text description; always serialized.
    pub description: String,
    /// Parameter schema, checked by `validate_json_schema_value`;
    /// omitted from the JSON when `None`.
    #[validate(custom(function = "validate_json_schema_value"))]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
}
impl Function {
    /// Builds a function definition. `parameters` is always stored as `Some`.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: serde_json::Value,
    ) -> Self {
        let name: String = name.into();
        let description: String = description.into();
        let parameters = Some(parameters);
        Self {
            name,
            description,
            parameters,
        }
    }
}
/// Retrieval-tool configuration.
///
/// The fields are private, so construct it through `Retrieval::new`.
#[derive(Debug, Clone, Serialize)]
pub struct Retrieval {
    /// Knowledge identifier, sent as `knowledge_id` (presumably a knowledge-base id).
    knowledge_id: String,
    /// Optional prompt template; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_template: Option<String>,
}
impl Retrieval {
    /// Creates a retrieval configuration for `knowledge_id`, with an optional
    /// prompt template.
    pub fn new(knowledge_id: impl Into<String>, prompt_template: Option<String>) -> Self {
        let knowledge_id: String = knowledge_id.into();
        Self {
            prompt_template,
            knowledge_id,
        }
    }
}
/// Value for `WebSearch::result_sequence`, serialized as `"before"` or `"after"`.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ResultSequence {
    /// Serialized as `"before"`.
    Before,
    /// Serialized as `"after"`.
    After,
}
/// Web-search tool configuration.
///
/// Only `search_engine` is required. Every other field is optional and is
/// dropped from the serialized JSON while it is `None`. Field order here is
/// the key order in the output.
#[derive(Debug, Clone, Serialize, Validate)]
pub struct WebSearch {
    /// Engine to use; always serialized.
    pub search_engine: SearchEngine,
    /// `enable` flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable: Option<bool>,
    /// Explicit query text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_query: Option<String>,
    /// `search_intent` flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_intent: Option<bool>,
    /// Result count; the `Validate` derive enforces 1..=50 when set.
    #[validate(range(min = 1, max = 50))]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub count: Option<u32>,
    /// Domain filter string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_domain_filter: Option<String>,
    /// Recency filter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_recency_filter: Option<SearchRecencyFilter>,
    /// Content size option.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content_size: Option<ContentSize>,
    /// `"before"` / `"after"` placement value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result_sequence: Option<ResultSequence>,
    /// `search_result` flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_result: Option<bool>,
    /// `require_search` flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub require_search: Option<bool>,
    /// Custom search prompt text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_prompt: Option<String>,
}
impl WebSearch {
    /// Creates a configuration for `search_engine` with every optional field unset.
    pub fn new(search_engine: SearchEngine) -> Self {
        Self {
            search_engine,
            enable: None,
            search_query: None,
            search_intent: None,
            count: None,
            search_domain_filter: None,
            search_recency_filter: None,
            content_size: None,
            result_sequence: None,
            search_result: None,
            require_search: None,
            search_prompt: None,
        }
    }

    /// Sets `enable`.
    pub fn with_enable(self, enable: bool) -> Self {
        Self {
            enable: Some(enable),
            ..self
        }
    }

    /// Sets `search_query`.
    pub fn with_search_query(self, query: impl Into<String>) -> Self {
        Self {
            search_query: Some(query.into()),
            ..self
        }
    }

    /// Sets `search_intent`.
    pub fn with_search_intent(self, search_intent: bool) -> Self {
        Self {
            search_intent: Some(search_intent),
            ..self
        }
    }

    /// Sets `count`. The value is not checked here; `validate()` enforces 1..=50.
    pub fn with_count(self, count: u32) -> Self {
        Self {
            count: Some(count),
            ..self
        }
    }

    /// Sets `search_domain_filter`.
    pub fn with_search_domain_filter(self, domain: impl Into<String>) -> Self {
        Self {
            search_domain_filter: Some(domain.into()),
            ..self
        }
    }

    /// Sets `search_recency_filter`.
    pub fn with_search_recency_filter(self, filter: SearchRecencyFilter) -> Self {
        Self {
            search_recency_filter: Some(filter),
            ..self
        }
    }

    /// Sets `content_size`.
    pub fn with_content_size(self, size: ContentSize) -> Self {
        Self {
            content_size: Some(size),
            ..self
        }
    }

    /// Sets `result_sequence`.
    pub fn with_result_sequence(self, seq: ResultSequence) -> Self {
        Self {
            result_sequence: Some(seq),
            ..self
        }
    }

    /// Sets `search_result`.
    pub fn with_search_result(self, enable: bool) -> Self {
        Self {
            search_result: Some(enable),
            ..self
        }
    }

    /// Sets `require_search`.
    pub fn with_require_search(self, require: bool) -> Self {
        Self {
            require_search: Some(require),
            ..self
        }
    }

    /// Sets `search_prompt`.
    pub fn with_search_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            search_prompt: Some(prompt.into()),
            ..self
        }
    }
}
/// MCP server tool configuration.
///
/// Optional fields are dropped from the JSON while `None`, and
/// `allowed_tools` is dropped while the list is empty.
#[derive(Debug, Clone, Serialize, Validate)]
pub struct MCP {
    /// Server label; must be non-empty (`length(min = 1)`).
    #[validate(length(min = 1))]
    pub server_label: String,
    /// Server URL; must pass the `url` rule when set.
    #[validate(url)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub server_url: Option<String>,
    /// Transport to use (`"sse"` or `"streamable-http"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub transport_type: Option<MCPTransportType>,
    /// Tool names to allow; omitted from the JSON while empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub allowed_tools: Vec<String>,
    /// Extra header key/value pairs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
}
impl MCP {
    /// Creates a configuration labelled `server_label`, with the transport
    /// pre-set to `StreamableHttp` and everything else empty or unset.
    pub fn new(server_label: impl Into<String>) -> Self {
        let server_label: String = server_label.into();
        Self {
            server_label,
            server_url: None,
            transport_type: Some(MCPTransportType::StreamableHttp),
            allowed_tools: Vec::new(),
            headers: None,
        }
    }

    /// Sets `server_url`. It is not checked here; `validate()` applies the `url` rule.
    pub fn with_server_url(self, url: impl Into<String>) -> Self {
        Self {
            server_url: Some(url.into()),
            ..self
        }
    }

    /// Replaces the transport type.
    pub fn with_transport_type(self, transport: MCPTransportType) -> Self {
        Self {
            transport_type: Some(transport),
            ..self
        }
    }

    /// Replaces the whole allowed-tools list.
    pub fn with_allowed_tools(self, tools: impl Into<Vec<String>>) -> Self {
        Self {
            allowed_tools: tools.into(),
            ..self
        }
    }

    /// Appends one tool name to the allowed-tools list.
    pub fn add_allowed_tool(mut self, tool: impl Into<String>) -> Self {
        self.allowed_tools.push(tool.into());
        self
    }

    /// Replaces the whole header map.
    pub fn with_headers(self, headers: HashMap<String, String>) -> Self {
        Self {
            headers: Some(headers),
            ..self
        }
    }

    /// Inserts a single header, creating the map on first use. An existing
    /// value for the same key is overwritten.
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers
            .get_or_insert_with(HashMap::new)
            .insert(key.into(), value.into());
        self
    }
}
/// Transport for an MCP server, serialized in kebab-case.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum MCPTransportType {
    /// Serialized as `"sse"`.
    Sse,
    /// Serialized as `"streamable-http"`.
    StreamableHttp,
}
/// Requested response format, serialized as `{"type": "text"}` or
/// `{"type": "json_object"}`.
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// `{"type": "text"}`
    Text,
    /// `{"type": "json_object"}`
    JsonObject,
}
#[cfg(test)]
mod tests {
    // Unit tests for the request-model types: builder defaults, serialized
    // JSON shape, and the `Validate` derive rules.
    use super::*;

    #[test]
    fn test_thinking_type_enabled_serialization() {
        let thinking = ThinkingType::enabled();
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"enabled\""));
        assert!(!json.contains("clear_thinking"));
    }

    #[test]
    fn test_thinking_type_disabled_serialization() {
        let thinking = ThinkingType::disabled();
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"disabled\""));
        assert!(!json.contains("clear_thinking"));
    }

    #[test]
    fn test_thinking_type_with_clear_thinking_serialization() {
        let thinking = ThinkingType::enabled().with_clear_thinking(false);
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"enabled\""));
        assert!(json.contains("\"clear_thinking\":false"));
    }

    #[test]
    fn test_thinking_type_disabled_with_clear_thinking() {
        let thinking = ThinkingType::disabled().with_clear_thinking(true);
        let json = serde_json::to_string(&thinking).unwrap();
        assert!(json.contains("\"type\":\"disabled\""));
        assert!(json.contains("\"clear_thinking\":true"));
    }

    #[test]
    fn test_function_new() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        let func = Function::new("test_func", "A test function", params);
        assert_eq!(func.name, "test_func");
        assert_eq!(func.description, "A test function");
        assert!(func.parameters.is_some());
    }

    #[test]
    fn test_function_serialization() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {
                "value": {"type": "number"}
            }
        });
        let func = Function::new("test_func", "A test function", params);
        let json = serde_json::to_string(&func).unwrap();
        assert!(json.contains("\"name\":\"test_func\""));
        assert!(json.contains("\"description\":\"A test function\""));
        assert!(json.contains("\"properties\""));
    }

    #[test]
    fn test_function_validation() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {}
        });
        let func = Function::new("valid_name", "Description", params.clone());
        assert!(func.validate().is_ok());
        let invalid_name = Function::new("", "Description", params.clone());
        assert!(invalid_name.validate().is_err());
        let long_name = Function::new("a".repeat(65), "Description", params);
        assert!(long_name.validate().is_err());
    }

    #[test]
    fn test_retrieval_new() {
        let retrieval = Retrieval::new("kb_123", Some("template".to_string()));
        assert_eq!(retrieval.knowledge_id, "kb_123");
        assert_eq!(retrieval.prompt_template, Some("template".to_string()));
    }

    #[test]
    fn test_retrieval_new_without_template() {
        let retrieval = Retrieval::new("kb_456", None);
        assert_eq!(retrieval.knowledge_id, "kb_456");
        assert!(retrieval.prompt_template.is_none());
    }

    #[test]
    fn test_retrieval_serialization() {
        let retrieval = Retrieval::new("kb_789", None);
        let json = serde_json::to_string(&retrieval).unwrap();
        assert!(json.contains("\"knowledge_id\":\"kb_789\""));
        assert!(!json.contains("prompt_template"));
    }

    #[test]
    fn test_web_search_new() {
        let web_search = WebSearch::new(SearchEngine::SearchPro);
        assert_eq!(web_search.search_engine, SearchEngine::SearchPro);
        assert!(web_search.enable.is_none());
    }

    #[test]
    fn test_web_search_with_enable() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_enable(true);
        assert_eq!(web_search.enable, Some(true));
    }

    #[test]
    fn test_web_search_with_search_query() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_query("test query");
        assert_eq!(web_search.search_query, Some("test query".to_string()));
    }

    #[test]
    fn test_web_search_with_search_intent() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_intent(true);
        assert_eq!(web_search.search_intent, Some(true));
    }

    #[test]
    fn test_web_search_with_count() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_count(10);
        assert_eq!(web_search.count, Some(10));
    }

    #[test]
    fn test_web_search_count_validation() {
        // `count` carries `range(min = 1, max = 50)`; leaving it unset is valid.
        assert!(WebSearch::new(SearchEngine::SearchPro).validate().is_ok());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(1)
            .validate()
            .is_ok());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(50)
            .validate()
            .is_ok());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(0)
            .validate()
            .is_err());
        assert!(WebSearch::new(SearchEngine::SearchPro)
            .with_count(51)
            .validate()
            .is_err());
    }

    #[test]
    fn test_web_search_with_search_domain_filter() {
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_domain_filter("example.com");
        assert_eq!(
            web_search.search_domain_filter,
            Some("example.com".to_string())
        );
    }

    #[test]
    fn test_web_search_with_search_recency_filter() {
        let filter = SearchRecencyFilter::OneDay;
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_recency_filter(filter.clone());
        assert_eq!(web_search.search_recency_filter, Some(filter));
    }

    #[test]
    fn test_web_search_with_content_size() {
        let size = ContentSize::Medium;
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_content_size(size.clone());
        assert_eq!(web_search.content_size, Some(size));
    }

    #[test]
    fn test_web_search_with_result_sequence() {
        let seq = ResultSequence::After;
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_result_sequence(seq.clone());
        assert_eq!(web_search.result_sequence, Some(seq));
    }

    #[test]
    fn test_web_search_with_search_result() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_search_result(true);
        assert_eq!(web_search.search_result, Some(true));
    }

    #[test]
    fn test_web_search_with_require_search() {
        let web_search = WebSearch::new(SearchEngine::SearchPro).with_require_search(true);
        assert_eq!(web_search.require_search, Some(true));
    }

    #[test]
    fn test_web_search_with_search_prompt() {
        let web_search =
            WebSearch::new(SearchEngine::SearchPro).with_search_prompt("custom prompt");
        assert_eq!(web_search.search_prompt, Some("custom prompt".to_string()));
    }

    #[test]
    fn test_web_search_serialization() {
        let web_search = WebSearch::new(SearchEngine::SearchPro)
            .with_enable(true)
            .with_count(5);
        let json = serde_json::to_string(&web_search).unwrap();
        assert!(json.contains("\"search_engine\""));
        assert!(json.contains("\"enable\":true"));
        assert!(json.contains("\"count\":5"));
    }

    #[test]
    fn test_mcp_new() {
        let mcp = MCP::new("server_label");
        assert_eq!(mcp.server_label, "server_label");
        assert_eq!(mcp.transport_type, Some(MCPTransportType::StreamableHttp));
        assert!(mcp.allowed_tools.is_empty());
    }

    #[test]
    fn test_mcp_with_server_url() {
        let mcp = MCP::new("server_label").with_server_url("https://example.com");
        assert_eq!(mcp.server_url, Some("https://example.com".to_string()));
    }

    #[test]
    fn test_mcp_with_transport_type() {
        let mcp = MCP::new("server_label").with_transport_type(MCPTransportType::Sse);
        assert_eq!(mcp.transport_type, Some(MCPTransportType::Sse));
    }

    #[test]
    fn test_mcp_with_allowed_tools() {
        let mcp = MCP::new("server_label")
            .with_allowed_tools(vec!["tool1".to_string(), "tool2".to_string()]);
        assert_eq!(mcp.allowed_tools.len(), 2);
        assert!(mcp.allowed_tools.contains(&"tool1".to_string()));
    }

    #[test]
    fn test_mcp_add_allowed_tool() {
        let mcp = MCP::new("server_label")
            .add_allowed_tool("tool1")
            .add_allowed_tool("tool2");
        assert_eq!(mcp.allowed_tools.len(), 2);
    }

    #[test]
    fn test_mcp_with_headers() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer token".to_string());
        let mcp = MCP::new("server_label").with_headers(headers.clone());
        assert_eq!(mcp.headers, Some(headers));
    }

    #[test]
    fn test_mcp_with_header() {
        let mcp = MCP::new("server_label").with_header("Authorization", "Bearer token");
        let headers = mcp.headers.unwrap();
        assert_eq!(
            headers.get("Authorization"),
            Some(&"Bearer token".to_string())
        );
    }

    #[test]
    fn test_mcp_validation() {
        // `server_label` needs length >= 1; `server_url` must parse as a URL when set.
        assert!(MCP::new("server_label").validate().is_ok());
        assert!(MCP::new("server_label")
            .with_server_url("https://example.com")
            .validate()
            .is_ok());
        assert!(MCP::new("").validate().is_err());
        assert!(MCP::new("server_label")
            .with_server_url("not a url")
            .validate()
            .is_err());
    }

    #[test]
    fn test_mcp_serialization() {
        let mcp = MCP::new("server_label")
            .with_server_url("https://example.com")
            .with_transport_type(MCPTransportType::Sse);
        let json = serde_json::to_string(&mcp).unwrap();
        assert!(json.contains("\"server_label\":\"server_label\""));
        assert!(json.contains("\"server_url\":\"https://example.com\""));
        assert!(json.contains("\"transport_type\":\"sse\""));
        assert!(!json.contains("allowed_tools"));
    }

    #[test]
    fn test_mcp_transport_type_sse_serialization() {
        let transport = MCPTransportType::Sse;
        let json = serde_json::to_string(&transport).unwrap();
        assert!(json.contains("\"sse\""));
    }

    #[test]
    fn test_mcp_transport_type_streamable_http_serialization() {
        let transport = MCPTransportType::StreamableHttp;
        let json = serde_json::to_string(&transport).unwrap();
        assert!(json.contains("\"streamable-http\""));
    }

    #[test]
    fn test_response_format_text_serialization() {
        let format = ResponseFormat::Text;
        let json = serde_json::to_string(&format).unwrap();
        assert!(json.contains("\"type\":\"text\""));
    }

    #[test]
    fn test_response_format_json_object_serialization() {
        let format = ResponseFormat::JsonObject;
        let json = serde_json::to_string(&format).unwrap();
        assert!(json.contains("\"type\":\"json_object\""));
    }

    #[test]
    fn test_tools_function_serialization() {
        let func = Function::new("test_func", "test", serde_json::json!({}));
        let tools = Tools::Function { function: func };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"test_func\""));
    }

    #[test]
    fn test_tools_retrieval_serialization() {
        let retrieval = Retrieval::new("kb_123", None);
        let tools = Tools::Retrieval { retrieval };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"retrieval\""));
        assert!(json.contains("\"knowledge_id\":\"kb_123\""));
    }

    #[test]
    fn test_tools_web_search_serialization() {
        let web_search = WebSearch::new(SearchEngine::SearchPro);
        let tools = Tools::WebSearch { web_search };
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("\"type\":\"web_search\""));
        assert!(json.contains("\"search_engine\""));
    }

    #[test]
    fn test_tools_mcp_serialization() {
        let mcp = MCP::new("server_label");
        let tools = Tools::MCP { mcp };
        let json = serde_json::to_string(&tools).unwrap();
        // Assertion failures already print `json`; no debug print needed.
        assert!(json.contains("\"type\":\"mcp\""), "unexpected JSON: {json}");
        assert!(json.contains("\"server_label\":\"server_label\""));
    }

    #[test]
    fn test_result_sequence_before_serialization() {
        let seq = ResultSequence::Before;
        let json = serde_json::to_string(&seq).unwrap();
        assert!(json.contains("\"before\""));
    }

    #[test]
    fn test_result_sequence_after_serialization() {
        let seq = ResultSequence::After;
        let json = serde_json::to_string(&seq).unwrap();
        assert!(json.contains("\"after\""));
    }
}