Skip to main content

mermaid_cli/providers/tool/
web.rs

1//! Web tools: `web_search` and `web_fetch`.
2//!
3//! Both delegate to `web_client::WebSearchClient` — a thin HTTP
4//! client for Ollama Cloud's web API (bearer-token path, via
5//! `OLLAMA_API_KEY`). The wrapper's job is cancellation plumbing +
6//! multi-query fan-out.
7
8use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::domain::{ToolDefinition, ToolMetadata, ToolOutcome, ToolRunMetadata};
13
14use super::super::ctx::{ExecContext, ProgressEvent};
15use super::ToolExecutor;
16use super::web_client::{WebFetchResult, WebSearchClient};
17
/// `web_search` — query Ollama Cloud's web-search endpoint.
///
/// Accepts either a single `{query, max_results}` or a list
/// `{queries: [{query, max_results}]}`. When a list is supplied, the
/// executor sends the queries through the same client one after another,
/// in list order, and concatenates the formatted results.
pub struct WebSearchTool {
    /// Shared client used both to run each search and to format the
    /// results it returns.
    client: Arc<WebSearchClient>,
}
24
25impl WebSearchTool {
26    pub fn new(api_key: String) -> Self {
27        Self {
28            client: Arc::new(WebSearchClient::new(api_key)),
29        }
30    }
31}
32
33#[async_trait]
34impl ToolExecutor for WebSearchTool {
35    fn name(&self) -> &'static str {
36        "web_search"
37    }
38
39    fn schema(&self) -> ToolDefinition {
40        ToolDefinition {
41            name: "web_search".to_string(),
42            description:
43                "Search the web via Ollama Cloud's search API. Takes either a single `query` + `max_results`, or an array of `queries` for parallel fan-out."
44                    .to_string(),
45            input_schema: serde_json::json!({
46                "type": "object",
47                "properties": {
48                    "query": { "type": "string" },
49                    "max_results": { "type": "integer", "minimum": 1, "maximum": 10, "default": 5 },
50                    "queries": {
51                        "type": "array",
52                        "items": {
53                            "type": "object",
54                            "properties": {
55                                "query": { "type": "string" },
56                                "max_results": { "type": "integer", "minimum": 1, "maximum": 10 }
57                            },
58                            "required": ["query"]
59                        }
60                    }
61                }
62            }),
63        }
64    }
65
66    async fn execute(&self, args: serde_json::Value, ctx: ExecContext) -> ToolOutcome {
67        let queries = match parse_queries(&args) {
68            Ok(q) => q,
69            Err(e) => return ToolOutcome::error(e, 0.0),
70        };
71        if queries.is_empty() {
72            return ToolOutcome::error("web_search requires at least one query", 0.0);
73        }
74
75        let start = std::time::Instant::now();
76        let mut combined = String::new();
77        let mut result_count = 0usize;
78        let mut sources = Vec::new();
79        for (idx, (query, count)) in queries.iter().enumerate() {
80            let _ = ctx
81                .progress
82                .send(ProgressEvent::Status(format!(
83                    "searching {}/{}: {}",
84                    idx + 1,
85                    queries.len(),
86                    query
87                )))
88                .await;
89
90            let search = self.client.search_query(query, *count);
91            tokio::select! {
92                biased;
93                _ = ctx.token.cancelled() => return ToolOutcome::cancelled(),
94                result = search => {
95                    match result {
96                        Ok(results) => {
97                            result_count += results.len();
98                            sources.extend(results.iter().map(|result| result.url.clone()));
99                            let formatted = self.client.format_results(&results);
100                            if queries.len() > 1 {
101                                combined.push_str(&format!("=== query: {} ===\n{}\n\n", query, formatted));
102                            } else {
103                                combined = formatted;
104                            }
105                        },
106                        Err(e) => {
107                            return ToolOutcome::error(
108                                format!("web_search({}): {}", query, e),
109                                start.elapsed().as_secs_f64(),
110                            );
111                        },
112                    }
113                }
114            }
115        }
116
117        let duration_secs = start.elapsed().as_secs_f64();
118        let requested_count = queries.iter().map(|(_, count)| *count).sum();
119        let query_texts = queries.iter().map(|(query, _)| query.clone()).collect();
120        ToolOutcome::success(
121            combined,
122            format!(
123                "{} {} returned",
124                result_count,
125                if result_count == 1 {
126                    "result"
127                } else {
128                    "results"
129                }
130            ),
131            duration_secs,
132        )
133        .with_metadata(ToolRunMetadata {
134            detail: ToolMetadata::WebSearch {
135                queries: query_texts,
136                requested_count,
137                result_count,
138                sources,
139            },
140            result_count: Some(result_count),
141            ..ToolRunMetadata::default()
142        })
143    }
144}
145
/// `web_fetch` — retrieve a URL's readable content (Ollama Cloud's
/// fetch endpoint). Single URL, single response.
pub struct WebFetchTool {
    /// Shared client through which the fetch request is issued.
    client: Arc<WebSearchClient>,
}
151
152impl WebFetchTool {
153    pub fn new(api_key: String) -> Self {
154        Self {
155            client: Arc::new(WebSearchClient::new(api_key)),
156        }
157    }
158}
159
160#[async_trait]
161impl ToolExecutor for WebFetchTool {
162    fn name(&self) -> &'static str {
163        "web_fetch"
164    }
165
166    fn schema(&self) -> ToolDefinition {
167        ToolDefinition {
168            name: "web_fetch".to_string(),
169            description: "Retrieve a single URL's main content as text (Ollama Cloud fetch API)."
170                .to_string(),
171            input_schema: serde_json::json!({
172                "type": "object",
173                "properties": { "url": { "type": "string" } },
174                "required": ["url"]
175            }),
176        }
177    }
178
179    async fn execute(&self, args: serde_json::Value, ctx: ExecContext) -> ToolOutcome {
180        let Some(url) = args.get("url").and_then(|v| v.as_str()) else {
181            return ToolOutcome::error("web_fetch requires 'url' (string)", 0.0);
182        };
183        let start = std::time::Instant::now();
184        let fetch = self.client.fetch_url(url);
185
186        tokio::select! {
187            biased;
188            _ = ctx.token.cancelled() => ToolOutcome::cancelled(),
189            result = fetch => match result {
190                Ok(page) => {
191                    let output = format_fetch(url, &page);
192                    let duration_secs = start.elapsed().as_secs_f64();
193                    let line_count = output.lines().count();
194                    let byte_count = output.len();
195                    let title = if page.title.is_empty() {
196                        None
197                    } else {
198                        Some(page.title)
199                    };
200                    ToolOutcome::success(
201                        output,
202                        format!("{} {} fetched", line_count, if line_count == 1 { "line" } else { "lines" }),
203                        duration_secs,
204                    )
205                    .with_metadata(ToolRunMetadata {
206                        detail: ToolMetadata::WebFetch {
207                            url: url.to_string(),
208                            title,
209                            line_count,
210                            byte_count,
211                        },
212                        line_count: Some(line_count),
213                        byte_count: Some(byte_count),
214                        ..ToolRunMetadata::default()
215                    })
216                },
217                Err(e) => ToolOutcome::error(
218                    format!("web_fetch({}): {}", url, e),
219                    start.elapsed().as_secs_f64(),
220                ),
221            },
222        }
223    }
224}
225
226fn format_fetch(url: &str, page: &WebFetchResult) -> String {
227    let title = if page.title.is_empty() {
228        "(no title)"
229    } else {
230        page.title.as_str()
231    };
232    format!("# {}\n\nURL: {}\n\n{}", title, url, page.content)
233}
234
235fn parse_queries(args: &serde_json::Value) -> Result<Vec<(String, usize)>, String> {
236    if let Some(arr) = args.get("queries").and_then(|v| v.as_array()) {
237        let mut out = Vec::with_capacity(arr.len());
238        for v in arr {
239            let Some(obj) = v.as_object() else {
240                return Err(
241                    "web_search: 'queries' must be an array of {query, max_results}".to_string(),
242                );
243            };
244            let Some(query) = obj.get("query").and_then(|x| x.as_str()) else {
245                return Err("web_search: each query entry needs 'query' (string)".to_string());
246            };
247            let count = obj
248                .get("max_results")
249                .or_else(|| obj.get("result_count"))
250                .and_then(|x| x.as_u64())
251                .unwrap_or(5)
252                .clamp(1, 10) as usize;
253            out.push((query.to_string(), count));
254        }
255        return Ok(out);
256    }
257    if let Some(query) = args.get("query").and_then(|v| v.as_str()) {
258        let count = args
259            .get("max_results")
260            .or_else(|| args.get("result_count"))
261            .and_then(|v| v.as_u64())
262            .unwrap_or(5)
263            .clamp(1, 10) as usize;
264        return Ok(vec![(query.to_string(), count)]);
265    }
266    Err("web_search requires 'query' (string) or 'queries' (array)".to_string())
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// The flat `{query, max_results}` shape yields exactly one entry.
    #[test]
    fn parse_queries_single_form() {
        let parsed =
            parse_queries(&serde_json::json!({"query": "rust async", "max_results": 3})).unwrap();
        assert_eq!(parsed, vec![("rust async".to_string(), 3)]);
    }

    /// The `queries` array yields one entry per element and honours the
    /// `result_count` alias.
    #[test]
    fn parse_queries_array_form() {
        let parsed = parse_queries(&serde_json::json!({"queries": [
            {"query": "a", "max_results": 2},
            {"query": "b", "result_count": 5},
        ]}))
        .unwrap();
        let limits: Vec<usize> = parsed.iter().map(|(_, limit)| *limit).collect();
        assert_eq!(limits.len(), 2);
        assert_eq!(limits[1], 5);
    }

    /// With neither `query` nor `queries`, parsing fails.
    #[test]
    fn parse_queries_missing_errors() {
        assert!(parse_queries(&serde_json::json!({})).is_err());
    }

    /// Out-of-range limits are pulled back into `1..=10`.
    #[test]
    fn parse_queries_clamps_count() {
        for (given, expected) in [(999u64, 10usize), (0, 1)] {
            let parsed =
                parse_queries(&serde_json::json!({"query": "q", "max_results": given})).unwrap();
            assert_eq!(parsed[0].1, expected);
        }
    }
}