Skip to main content

kagi_mcp/
schema.rs

1use std::borrow::Cow;
2
3use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
4use serde::{Deserialize, Deserializer, Serialize};
5use url::Url;
6
/// Smallest `limit` value the search tool input accepts.
const SEARCH_MIN_LIMIT: u8 = 1;
/// `limit` value used when a search request omits the field.
const SEARCH_DEFAULT_LIMIT: u8 = 5;
/// Largest `limit` value the search tool input accepts.
const SEARCH_MAX_LIMIT: u8 = 10;
/// Upper bound, in UTF-8 bytes, on the `text` payload of a summarize request.
const SUMMARIZE_MAX_TEXT_BYTES: usize = 50_000;
11
12#[derive(Debug, Clone, Deserialize)]
13#[serde(deny_unknown_fields)]
14pub struct SearchToolInput {
15    #[serde(deserialize_with = "deserialize_trimmed_query")]
16    pub query: String,
17
18    #[serde(
19        default = "default_search_limit",
20        deserialize_with = "deserialize_search_limit"
21    )]
22    pub limit: u8,
23}
24
25impl SearchToolInput {
26    pub fn limit_as_usize(&self) -> usize {
27        self.limit as usize
28    }
29}
30
31impl JsonSchema for SearchToolInput {
32    fn schema_name() -> Cow<'static, str> {
33        "SearchToolInput".into()
34    }
35
36    fn schema_id() -> Cow<'static, str> {
37        concat!(module_path!(), "::SearchToolInput").into()
38    }
39
40    fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
41        json_schema!({
42            "type": "object",
43            "additionalProperties": false,
44            "required": ["query"],
45            "properties": {
46                "query": {
47                    "type": "string",
48                    "minLength": 1,
49                    "pattern": ".*\\S.*"
50                },
51                "limit": {
52                    "type": "integer",
53                    "minimum": SEARCH_MIN_LIMIT,
54                    "maximum": SEARCH_MAX_LIMIT,
55                    "default": SEARCH_DEFAULT_LIMIT
56                }
57            }
58        })
59    }
60}
61
/// Arguments for the summarize tool.
///
/// Both fields are optional at the type level. Values produced by this
/// type's `Deserialize` implementation always have exactly one of them set.
#[derive(Clone, Debug)]
pub struct SummarizeToolInput {
    /// URL to summarize, when the request supplied one.
    pub url: Option<String>,
    /// Raw text to summarize, when the request supplied it instead of a URL.
    pub text: Option<String>,
}
67
68impl JsonSchema for SummarizeToolInput {
69    fn schema_name() -> Cow<'static, str> {
70        "SummarizeToolInput".into()
71    }
72
73    fn schema_id() -> Cow<'static, str> {
74        concat!(module_path!(), "::SummarizeToolInput").into()
75    }
76
77    fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
78        json_schema!({
79            "type": "object",
80            "additionalProperties": false,
81            "properties": {
82                "url": {
83                    "type": "string"
84                },
85                "text": {
86                    "type": "string"
87                }
88            }
89        })
90    }
91}
92
93impl<'de> Deserialize<'de> for SummarizeToolInput {
94    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
95    where
96        D: Deserializer<'de>,
97    {
98        #[derive(Deserialize)]
99        #[serde(deny_unknown_fields)]
100        struct RawSummarizeToolInput {
101            url: Option<String>,
102            text: Option<String>,
103        }
104
105        let raw = RawSummarizeToolInput::deserialize(deserializer)?;
106
107        let normalized_url = raw
108            .url
109            .and_then(|url| if url.is_empty() { None } else { Some(url) });
110        let normalized_text = raw
111            .text
112            .and_then(|text| if text.is_empty() { None } else { Some(text) });
113
114        let has_url = normalized_url.is_some();
115        let has_text = normalized_text.is_some();
116        if has_url == has_text {
117            return Err(serde::de::Error::custom(
118                "exactly one of `url` or `text` must be provided",
119            ));
120        }
121
122        if let Some(raw_url) = normalized_url {
123            if raw_url != raw_url.trim() {
124                return Err(serde::de::Error::custom(
125                    "`url` cannot have leading or trailing whitespace",
126                ));
127            }
128
129            let parsed = Url::parse(&raw_url).map_err(|source| {
130                serde::de::Error::custom(format!(
131                    "`url` must be an absolute HTTP(S) URL ({source})"
132                ))
133            })?;
134
135            if !matches!(parsed.scheme(), "http" | "https") {
136                return Err(serde::de::Error::custom("`url` must use `http` or `https`"));
137            }
138
139            return Ok(Self {
140                url: Some(parsed.to_string()),
141                text: None,
142            });
143        }
144
145        let text = normalized_text.expect("xor check ensures text exists");
146        if text.trim().is_empty() {
147            return Err(serde::de::Error::custom("`text` cannot be blank"));
148        }
149
150        let byte_len = text.len();
151        if byte_len > SUMMARIZE_MAX_TEXT_BYTES {
152            return Err(serde::de::Error::custom(format!(
153                "`text` exceeds {SUMMARIZE_MAX_TEXT_BYTES} UTF-8 bytes"
154            )));
155        }
156
157        Ok(Self {
158            url: None,
159            text: Some(text),
160        })
161    }
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
165#[serde(deny_unknown_fields)]
166pub struct SearchResultCard {
167    pub title: String,
168    pub url: String,
169
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub snippet: Option<String>,
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
175#[serde(deny_unknown_fields)]
176pub struct SearchToolOutput {
177    pub results: Vec<SearchResultCard>,
178    pub total_returned: usize,
179}
180
181impl JsonSchema for SearchToolOutput {
182    fn schema_name() -> Cow<'static, str> {
183        "SearchToolOutput".into()
184    }
185
186    fn schema_id() -> Cow<'static, str> {
187        concat!(module_path!(), "::SearchToolOutput").into()
188    }
189
190    fn json_schema(generator: &mut SchemaGenerator) -> Schema {
191        let results_schema = generator
192            .subschema_for::<Vec<SearchResultCard>>()
193            .to_value();
194
195        json_schema!({
196            "type": "object",
197            "additionalProperties": false,
198            "required": ["results", "total_returned"],
199            "properties": {
200                "results": results_schema,
201                "total_returned": {
202                    "type": "integer",
203                    "minimum": 0
204                }
205            }
206        })
207    }
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
211#[serde(deny_unknown_fields)]
212pub struct SummarizeToolOutput {
213    pub markdown: String,
214
215    #[serde(skip_serializing_if = "Option::is_none")]
216    pub text: Option<String>,
217
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub source_url: Option<String>,
220}
221
222fn default_search_limit() -> u8 {
223    SEARCH_DEFAULT_LIMIT
224}
225
226fn deserialize_trimmed_query<'de, D>(deserializer: D) -> Result<String, D::Error>
227where
228    D: Deserializer<'de>,
229{
230    let raw_query = String::deserialize(deserializer)?;
231    let trimmed = raw_query.trim();
232    if trimmed.is_empty() {
233        return Err(serde::de::Error::custom("`query` cannot be blank"));
234    }
235
236    Ok(trimmed.to_string())
237}
238
239fn deserialize_search_limit<'de, D>(deserializer: D) -> Result<u8, D::Error>
240where
241    D: Deserializer<'de>,
242{
243    let limit = u8::deserialize(deserializer)?;
244    if !(SEARCH_MIN_LIMIT..=SEARCH_MAX_LIMIT).contains(&limit) {
245        return Err(serde::de::Error::custom(format!(
246            "`limit` must be between {SEARCH_MIN_LIMIT} and {SEARCH_MAX_LIMIT}"
247        )));
248    }
249
250    Ok(limit)
251}