1use std::borrow::Cow;
2
3use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
4use serde::{Deserialize, Deserializer, Serialize};
5use url::Url;
6
/// Smallest `limit` value the search input accepts.
const SEARCH_MIN_LIMIT: u8 = 1;
/// `limit` value used when the caller omits the field.
const SEARCH_DEFAULT_LIMIT: u8 = 5;
/// Largest `limit` value the search input accepts.
const SEARCH_MAX_LIMIT: u8 = 10;

/// Upper bound, in UTF-8 bytes, on the `text` field of the summarize input.
const SUMMARIZE_MAX_TEXT_BYTES: usize = 50_000;
11
/// Arguments accepted by the search tool.
///
/// Deserialization rejects unknown fields. The fields are public, so a value
/// constructed directly (rather than deserialized) bypasses the per-field
/// validation described below.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct SearchToolInput {
    /// Search query. When deserialized it is trimmed, and a blank value is
    /// rejected (see `deserialize_trimmed_query`).
    #[serde(deserialize_with = "deserialize_trimmed_query")]
    pub query: String,

    /// Result limit (presumably the maximum number of results to return —
    /// confirm against the caller). Defaults to `SEARCH_DEFAULT_LIMIT` when
    /// omitted; a supplied value must lie within
    /// `SEARCH_MIN_LIMIT..=SEARCH_MAX_LIMIT`.
    #[serde(
        default = "default_search_limit",
        deserialize_with = "deserialize_search_limit"
    )]
    pub limit: u8,
}
24
25impl SearchToolInput {
26 pub fn limit_as_usize(&self) -> usize {
27 self.limit as usize
28 }
29}
30
31impl JsonSchema for SearchToolInput {
32 fn schema_name() -> Cow<'static, str> {
33 "SearchToolInput".into()
34 }
35
36 fn schema_id() -> Cow<'static, str> {
37 concat!(module_path!(), "::SearchToolInput").into()
38 }
39
40 fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
41 json_schema!({
42 "type": "object",
43 "additionalProperties": false,
44 "required": ["query"],
45 "properties": {
46 "query": {
47 "type": "string",
48 "minLength": 1,
49 "pattern": ".*\\S.*"
50 },
51 "limit": {
52 "type": "integer",
53 "minimum": SEARCH_MIN_LIMIT,
54 "maximum": SEARCH_MAX_LIMIT,
55 "default": SEARCH_DEFAULT_LIMIT
56 }
57 }
58 })
59 }
60}
61
/// Arguments accepted by the summarize tool.
///
/// The hand-written `Deserialize` impl for this type only yields values in
/// which exactly one of the two fields is `Some`. The fields are public, so a
/// value constructed directly is not covered by that validation.
#[derive(Debug, Clone)]
pub struct SummarizeToolInput {
    /// URL input. After deserialization this holds the `Url`-normalized form
    /// of an `http`/`https` URL.
    pub url: Option<String>,
    /// Raw text input. After deserialization this is non-blank and at most
    /// `SUMMARIZE_MAX_TEXT_BYTES` bytes long.
    pub text: Option<String>,
}
67
68impl JsonSchema for SummarizeToolInput {
69 fn schema_name() -> Cow<'static, str> {
70 "SummarizeToolInput".into()
71 }
72
73 fn schema_id() -> Cow<'static, str> {
74 concat!(module_path!(), "::SummarizeToolInput").into()
75 }
76
77 fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
78 json_schema!({
79 "type": "object",
80 "additionalProperties": false,
81 "properties": {
82 "url": {
83 "type": "string"
84 },
85 "text": {
86 "type": "string"
87 }
88 }
89 })
90 }
91}
92
93impl<'de> Deserialize<'de> for SummarizeToolInput {
94 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
95 where
96 D: Deserializer<'de>,
97 {
98 #[derive(Deserialize)]
99 #[serde(deny_unknown_fields)]
100 struct RawSummarizeToolInput {
101 url: Option<String>,
102 text: Option<String>,
103 }
104
105 let raw = RawSummarizeToolInput::deserialize(deserializer)?;
106
107 let normalized_url = raw
108 .url
109 .and_then(|url| if url.is_empty() { None } else { Some(url) });
110 let normalized_text = raw
111 .text
112 .and_then(|text| if text.is_empty() { None } else { Some(text) });
113
114 let has_url = normalized_url.is_some();
115 let has_text = normalized_text.is_some();
116 if has_url == has_text {
117 return Err(serde::de::Error::custom(
118 "exactly one of `url` or `text` must be provided",
119 ));
120 }
121
122 if let Some(raw_url) = normalized_url {
123 if raw_url != raw_url.trim() {
124 return Err(serde::de::Error::custom(
125 "`url` cannot have leading or trailing whitespace",
126 ));
127 }
128
129 let parsed = Url::parse(&raw_url).map_err(|source| {
130 serde::de::Error::custom(format!(
131 "`url` must be an absolute HTTP(S) URL ({source})"
132 ))
133 })?;
134
135 if !matches!(parsed.scheme(), "http" | "https") {
136 return Err(serde::de::Error::custom("`url` must use `http` or `https`"));
137 }
138
139 return Ok(Self {
140 url: Some(parsed.to_string()),
141 text: None,
142 });
143 }
144
145 let text = normalized_text.expect("xor check ensures text exists");
146 if text.trim().is_empty() {
147 return Err(serde::de::Error::custom("`text` cannot be blank"));
148 }
149
150 let byte_len = text.len();
151 if byte_len > SUMMARIZE_MAX_TEXT_BYTES {
152 return Err(serde::de::Error::custom(format!(
153 "`text` exceeds {SUMMARIZE_MAX_TEXT_BYTES} UTF-8 bytes"
154 )));
155 }
156
157 Ok(Self {
158 url: None,
159 text: Some(text),
160 })
161 }
162}
163
/// One entry in the search tool's result list.
///
/// Deserialization rejects unknown fields.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct SearchResultCard {
    /// Title of the result.
    pub title: String,
    /// URL of the result, stored as a plain string (not validated here).
    pub url: String,

    /// Optional snippet; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub snippet: Option<String>,
}
173
/// Payload returned by the search tool.
///
/// Deserialization rejects unknown fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct SearchToolOutput {
    /// Result cards, in the order supplied by the producer.
    pub results: Vec<SearchResultCard>,
    /// Count of returned results (presumably `results.len()`; nothing in this
    /// type enforces that — verify against the producer).
    pub total_returned: usize,
}
180
181impl JsonSchema for SearchToolOutput {
182 fn schema_name() -> Cow<'static, str> {
183 "SearchToolOutput".into()
184 }
185
186 fn schema_id() -> Cow<'static, str> {
187 concat!(module_path!(), "::SearchToolOutput").into()
188 }
189
190 fn json_schema(generator: &mut SchemaGenerator) -> Schema {
191 let results_schema = generator
192 .subschema_for::<Vec<SearchResultCard>>()
193 .to_value();
194
195 json_schema!({
196 "type": "object",
197 "additionalProperties": false,
198 "required": ["results", "total_returned"],
199 "properties": {
200 "results": results_schema,
201 "total_returned": {
202 "type": "integer",
203 "minimum": 0
204 }
205 }
206 })
207 }
208}
209
/// Payload returned by the summarize tool.
///
/// Deserialization rejects unknown fields. Both optional fields are left out
/// of the serialized form when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct SummarizeToolOutput {
    /// Markdown output (presumably the rendered summary — confirm with the
    /// producer).
    pub markdown: String,

    /// Optional text value; its exact meaning is not established in this
    /// file — confirm with the producer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Optional source URL (presumably set when the input was a URL —
    /// confirm with the producer).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source_url: Option<String>,
}
221
/// Serde `default` hook for `SearchToolInput::limit`: supplies
/// `SEARCH_DEFAULT_LIMIT` when the field is absent from the input.
fn default_search_limit() -> u8 {
    SEARCH_DEFAULT_LIMIT
}
225
226fn deserialize_trimmed_query<'de, D>(deserializer: D) -> Result<String, D::Error>
227where
228 D: Deserializer<'de>,
229{
230 let raw_query = String::deserialize(deserializer)?;
231 let trimmed = raw_query.trim();
232 if trimmed.is_empty() {
233 return Err(serde::de::Error::custom("`query` cannot be blank"));
234 }
235
236 Ok(trimmed.to_string())
237}
238
239fn deserialize_search_limit<'de, D>(deserializer: D) -> Result<u8, D::Error>
240where
241 D: Deserializer<'de>,
242{
243 let limit = u8::deserialize(deserializer)?;
244 if !(SEARCH_MIN_LIMIT..=SEARCH_MAX_LIMIT).contains(&limit) {
245 return Err(serde::de::Error::custom(format!(
246 "`limit` must be between {SEARCH_MIN_LIMIT} and {SEARCH_MAX_LIMIT}"
247 )));
248 }
249
250 Ok(limit)
251}