mermaid_cli/providers/tool/
web.rs1use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::domain::{ToolDefinition, ToolMetadata, ToolOutcome, ToolRunMetadata};
13
14use super::super::ctx::{ExecContext, ProgressEvent};
15use super::ToolExecutor;
16use super::web_client::{WebFetchResult, WebSearchClient};
17
18pub struct WebSearchTool {
22 client: Arc<WebSearchClient>,
23}
24
25impl WebSearchTool {
26 pub fn new(api_key: String) -> Self {
27 Self {
28 client: Arc::new(WebSearchClient::new(api_key)),
29 }
30 }
31}
32
33#[async_trait]
34impl ToolExecutor for WebSearchTool {
35 fn name(&self) -> &'static str {
36 "web_search"
37 }
38
39 fn schema(&self) -> ToolDefinition {
40 ToolDefinition {
41 name: "web_search".to_string(),
42 description:
43 "Search the web via Ollama Cloud's search API. Takes either a single `query` + `max_results`, or an array of `queries` for parallel fan-out."
44 .to_string(),
45 input_schema: serde_json::json!({
46 "type": "object",
47 "properties": {
48 "query": { "type": "string" },
49 "max_results": { "type": "integer", "minimum": 1, "maximum": 10, "default": 5 },
50 "queries": {
51 "type": "array",
52 "items": {
53 "type": "object",
54 "properties": {
55 "query": { "type": "string" },
56 "max_results": { "type": "integer", "minimum": 1, "maximum": 10 }
57 },
58 "required": ["query"]
59 }
60 }
61 }
62 }),
63 }
64 }
65
66 async fn execute(&self, args: serde_json::Value, ctx: ExecContext) -> ToolOutcome {
67 let queries = match parse_queries(&args) {
68 Ok(q) => q,
69 Err(e) => return ToolOutcome::error(e, 0.0),
70 };
71 if queries.is_empty() {
72 return ToolOutcome::error("web_search requires at least one query", 0.0);
73 }
74
75 let start = std::time::Instant::now();
76 let mut combined = String::new();
77 let mut result_count = 0usize;
78 let mut sources = Vec::new();
79 for (idx, (query, count)) in queries.iter().enumerate() {
80 let _ = ctx
81 .progress
82 .send(ProgressEvent::Status(format!(
83 "searching {}/{}: {}",
84 idx + 1,
85 queries.len(),
86 query
87 )))
88 .await;
89
90 let search = self.client.search_query(query, *count);
91 tokio::select! {
92 biased;
93 _ = ctx.token.cancelled() => return ToolOutcome::cancelled(),
94 result = search => {
95 match result {
96 Ok(results) => {
97 result_count += results.len();
98 sources.extend(results.iter().map(|result| result.url.clone()));
99 let formatted = self.client.format_results(&results);
100 if queries.len() > 1 {
101 combined.push_str(&format!("=== query: {} ===\n{}\n\n", query, formatted));
102 } else {
103 combined = formatted;
104 }
105 },
106 Err(e) => {
107 return ToolOutcome::error(
108 format!("web_search({}): {}", query, e),
109 start.elapsed().as_secs_f64(),
110 );
111 },
112 }
113 }
114 }
115 }
116
117 let duration_secs = start.elapsed().as_secs_f64();
118 let requested_count = queries.iter().map(|(_, count)| *count).sum();
119 let query_texts = queries.iter().map(|(query, _)| query.clone()).collect();
120 ToolOutcome::success(
121 combined,
122 format!(
123 "{} {} returned",
124 result_count,
125 if result_count == 1 {
126 "result"
127 } else {
128 "results"
129 }
130 ),
131 duration_secs,
132 )
133 .with_metadata(ToolRunMetadata {
134 detail: ToolMetadata::WebSearch {
135 queries: query_texts,
136 requested_count,
137 result_count,
138 sources,
139 },
140 result_count: Some(result_count),
141 ..ToolRunMetadata::default()
142 })
143 }
144}
145
146pub struct WebFetchTool {
149 client: Arc<WebSearchClient>,
150}
151
152impl WebFetchTool {
153 pub fn new(api_key: String) -> Self {
154 Self {
155 client: Arc::new(WebSearchClient::new(api_key)),
156 }
157 }
158}
159
160#[async_trait]
161impl ToolExecutor for WebFetchTool {
162 fn name(&self) -> &'static str {
163 "web_fetch"
164 }
165
166 fn schema(&self) -> ToolDefinition {
167 ToolDefinition {
168 name: "web_fetch".to_string(),
169 description: "Retrieve a single URL's main content as text (Ollama Cloud fetch API)."
170 .to_string(),
171 input_schema: serde_json::json!({
172 "type": "object",
173 "properties": { "url": { "type": "string" } },
174 "required": ["url"]
175 }),
176 }
177 }
178
179 async fn execute(&self, args: serde_json::Value, ctx: ExecContext) -> ToolOutcome {
180 let Some(url) = args.get("url").and_then(|v| v.as_str()) else {
181 return ToolOutcome::error("web_fetch requires 'url' (string)", 0.0);
182 };
183 let start = std::time::Instant::now();
184 let fetch = self.client.fetch_url(url);
185
186 tokio::select! {
187 biased;
188 _ = ctx.token.cancelled() => ToolOutcome::cancelled(),
189 result = fetch => match result {
190 Ok(page) => {
191 let output = format_fetch(url, &page);
192 let duration_secs = start.elapsed().as_secs_f64();
193 let line_count = output.lines().count();
194 let byte_count = output.len();
195 let title = if page.title.is_empty() {
196 None
197 } else {
198 Some(page.title)
199 };
200 ToolOutcome::success(
201 output,
202 format!("{} {} fetched", line_count, if line_count == 1 { "line" } else { "lines" }),
203 duration_secs,
204 )
205 .with_metadata(ToolRunMetadata {
206 detail: ToolMetadata::WebFetch {
207 url: url.to_string(),
208 title,
209 line_count,
210 byte_count,
211 },
212 line_count: Some(line_count),
213 byte_count: Some(byte_count),
214 ..ToolRunMetadata::default()
215 })
216 },
217 Err(e) => ToolOutcome::error(
218 format!("web_fetch({}): {}", url, e),
219 start.elapsed().as_secs_f64(),
220 ),
221 },
222 }
223 }
224}
225
226fn format_fetch(url: &str, page: &WebFetchResult) -> String {
227 let title = if page.title.is_empty() {
228 "(no title)"
229 } else {
230 page.title.as_str()
231 };
232 format!("# {}\n\nURL: {}\n\n{}", title, url, page.content)
233}
234
235fn parse_queries(args: &serde_json::Value) -> Result<Vec<(String, usize)>, String> {
236 if let Some(arr) = args.get("queries").and_then(|v| v.as_array()) {
237 let mut out = Vec::with_capacity(arr.len());
238 for v in arr {
239 let Some(obj) = v.as_object() else {
240 return Err(
241 "web_search: 'queries' must be an array of {query, max_results}".to_string(),
242 );
243 };
244 let Some(query) = obj.get("query").and_then(|x| x.as_str()) else {
245 return Err("web_search: each query entry needs 'query' (string)".to_string());
246 };
247 let count = obj
248 .get("max_results")
249 .or_else(|| obj.get("result_count"))
250 .and_then(|x| x.as_u64())
251 .unwrap_or(5)
252 .clamp(1, 10) as usize;
253 out.push((query.to_string(), count));
254 }
255 return Ok(out);
256 }
257 if let Some(query) = args.get("query").and_then(|v| v.as_str()) {
258 let count = args
259 .get("max_results")
260 .or_else(|| args.get("result_count"))
261 .and_then(|v| v.as_u64())
262 .unwrap_or(5)
263 .clamp(1, 10) as usize;
264 return Ok(vec![(query.to_string(), count)]);
265 }
266 Err("web_search requires 'query' (string) or 'queries' (array)".to_string())
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_queries_single_form() {
        let args = serde_json::json!({"query": "rust async", "max_results": 3});
        let q = parse_queries(&args).unwrap();
        assert_eq!(q.len(), 1);
        assert_eq!(q[0].0, "rust async");
        assert_eq!(q[0].1, 3);
    }

    #[test]
    fn parse_queries_array_form() {
        // The alias value differs from the default (5), so this actually
        // proves `result_count` is read rather than ignored.
        let args = serde_json::json!({"queries": [
            {"query": "a", "max_results": 2},
            {"query": "b", "result_count": 7},
        ]});
        let q = parse_queries(&args).unwrap();
        assert_eq!(q, vec![("a".to_string(), 2), ("b".to_string(), 7)]);
    }

    #[test]
    fn parse_queries_missing_errors() {
        let args = serde_json::json!({});
        assert!(parse_queries(&args).is_err());
    }

    #[test]
    fn parse_queries_clamps_count() {
        let args = serde_json::json!({"query": "q", "max_results": 999});
        let q = parse_queries(&args).unwrap();
        assert_eq!(q[0].1, 10);
        let args = serde_json::json!({"query": "q", "max_results": 0});
        let q = parse_queries(&args).unwrap();
        assert_eq!(q[0].1, 1);
    }

    #[test]
    fn parse_queries_defaults_count_to_five() {
        let q = parse_queries(&serde_json::json!({"query": "q"})).unwrap();
        assert_eq!(q, vec![("q".to_string(), 5)]);
        // A present but non-u64 count also falls back to the default.
        let q = parse_queries(&serde_json::json!({"query": "q", "max_results": -3})).unwrap();
        assert_eq!(q[0].1, 5);
    }

    #[test]
    fn parse_queries_array_takes_precedence_over_single() {
        let args = serde_json::json!({"query": "single", "queries": [{"query": "multi"}]});
        let q = parse_queries(&args).unwrap();
        assert_eq!(q, vec![("multi".to_string(), 5)]);
    }

    #[test]
    fn parse_queries_rejects_malformed_entries() {
        assert!(parse_queries(&serde_json::json!({"queries": ["bare string"]})).is_err());
        assert!(parse_queries(&serde_json::json!({"queries": [{"max_results": 3}]})).is_err());
    }
}