Skip to main content

lash_tools/web/
web_search.rs

1use serde_json::{Value, json};
2
3use lash_core::{ToolCall, ToolDefinition, ToolResult, ToolScheduling};
4
5use lash_tool_support::{
6    StaticToolExecute, StaticToolProvider, ToolDefinitionLashlangExt, object_schema,
7};
8
9/// Web search via Tavily API.
10pub struct WebSearch {
11    api_key: String,
12    client: reqwest::Client,
13}
14
15impl WebSearch {
16    pub fn new(api_key: impl Into<String>) -> Self {
17        Self {
18            api_key: api_key.into(),
19            client: reqwest::Client::builder()
20                .timeout(std::time::Duration::from_secs(30))
21                .build()
22                .unwrap_or_default(),
23        }
24    }
25}
26
27/// Build the cached `search_web` tool provider for the given Tavily API key.
28pub fn web_search_provider(api_key: impl Into<String>) -> StaticToolProvider<WebSearch> {
29    StaticToolProvider::new(vec![web_search_tool_definition()], WebSearch::new(api_key))
30}
31
32#[async_trait::async_trait]
33impl StaticToolExecute for WebSearch {
34    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
35        let args = call.args;
36        let query = args
37            .get("query")
38            .and_then(|v| v.as_str())
39            .unwrap_or_default();
40        let limit = args
41            .get("limit")
42            .and_then(|v| v.as_u64())
43            .unwrap_or(5)
44            .clamp(1, 20);
45
46        if self.api_key.trim().is_empty() {
47            return ToolResult::err(json!("Tavily API key is required for web.search"));
48        }
49
50        let body = json!({
51            "query": query,
52            "max_results": limit,
53        });
54
55        let resp = self
56            .client
57            .post("https://api.tavily.com/search")
58            .bearer_auth(&self.api_key)
59            .json(&body)
60            .send()
61            .await;
62        match resp {
63            Ok(r) if r.status().is_success() => match r.json::<serde_json::Value>().await {
64                Ok(data) => ToolResult::ok(json!({
65                    "results": sanitize_results(data.get("results")),
66                })),
67                Err(e) => ToolResult::err_fmt(format_args!("Failed to parse response: {e}")),
68            },
69            Ok(r) => {
70                let status = r.status();
71                let body = r.text().await.unwrap_or_default();
72                ToolResult::err_fmt(format_args!("Tavily API error ({status}): {body}"))
73            }
74            Err(e) => ToolResult::err_fmt(format_args!("Request failed: {e}")),
75        }
76    }
77}
78
79fn sanitize_results(results: Option<&Value>) -> Vec<Value> {
80    results
81        .and_then(Value::as_array)
82        .into_iter()
83        .flatten()
84        .map(|item| {
85            json!({
86                "title": item.get("title").and_then(Value::as_str).unwrap_or_default(),
87                "url": item.get("url").and_then(Value::as_str).unwrap_or_default(),
88                "content": item.get("content").and_then(Value::as_str).unwrap_or_default(),
89            })
90        })
91        .collect()
92}
93
94fn web_search_tool_definition() -> ToolDefinition {
95    ToolDefinition::raw(
96                "tool:search_web",
97                "search_web",
98                "Search the web for candidate sources. Returns ranked `results` with snippet text; use `web.fetch` when you need the page itself. This tool does not expose Tavily's optional generated answer; summarize from result snippets and fetched pages.",
99                object_schema(
100                    serde_json::json!({
101                        "query": { "type": "string" },
102                        "limit": {
103                            "type": "integer",
104                            "minimum": 1,
105                            "maximum": 20,
106                            "default": 5,
107                            "description": "Maximum results to return (default 5)"
108                        }
109                    }),
110                    &["query"],
111                ),
112                serde_json::json!({
113                    "type": "object",
114                    "properties": {
115                        "results": {
116                            "type": "array",
117                            "items": {
118                                "type": "object",
119                                "properties": {
120                                    "title": { "type": "string" },
121                                    "url": { "type": "string" },
122                                    "content": {
123                                        "type": "string",
124                                        "description": "Search-result snippet text."
125                                    }
126                                },
127                                "required": ["title", "url", "content"],
128                                "additionalProperties": false
129                            }
130                        }
131                    },
132                    "required": ["results"],
133                    "additionalProperties": false
134                }),
135            )
136            .with_examples(vec![
137                "await web.search({ query: \"latest Rust release notes\", limit: 5 })?".into(),
138            ])
139            .with_lashlang_binding(lash_tool_support::lashlang_binding(
140                ["web"],
141                "search",
142                &["web_search"],
143            ))
144            .with_scheduling(ToolScheduling::Parallel)
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// The model contract exposes `limit` (not Tavily's `max_results`), and the
    /// output schema promises only `results`, with no generated `answer`.
    #[test]
    fn search_web_uses_limit_argument_in_model_contract() {
        let definition = web_search_tool_definition();

        let properties = definition
            .contract
            .input_schema
            .canonical
            .get("properties")
            .and_then(serde_json::Value::as_object)
            .expect("object properties");
        assert!(properties.contains_key("limit"));
        assert!(!properties.contains_key("max_results"));
        assert_eq!(properties["limit"]["default"], serde_json::json!(5));
        assert_eq!(properties["limit"]["maximum"], serde_json::json!(20));
        assert!(
            definition
                .contract
                .examples
                .iter()
                .all(|example| example.contains("limit"))
        );
        assert_eq!(
            definition.contract.output_schema.canonical["type"],
            serde_json::json!("object")
        );
        assert_eq!(
            definition.contract.output_schema.canonical["required"],
            serde_json::json!(["results"])
        );
        assert!(
            !definition.contract.output_schema.canonical["properties"]
                .as_object()
                .unwrap()
                .contains_key("answer")
        );
        assert_eq!(
            definition.manifest.activation,
            lash_core::ToolActivation::Always
        );
    }

    /// Extra Tavily fields are stripped so results match the declared output schema.
    #[test]
    fn search_web_sanitizes_tavily_results_to_contract() {
        let results = sanitize_results(Some(&serde_json::json!([
            {
                "title": "Title",
                "url": "https://example.com",
                "content": "Snippet",
                "score": 0.9,
                "raw_content": null,
                "favicon": "https://example.com/favicon.ico"
            }
        ])));

        assert_eq!(
            results,
            vec![serde_json::json!({
                "title": "Title",
                "url": "https://example.com",
                "content": "Snippet"
            })]
        );
    }

    /// Missing, non-array, or partially typed Tavily payloads degrade to empty
    /// values instead of panicking or leaking non-string fields.
    #[test]
    fn search_web_sanitize_tolerates_missing_or_malformed_results() {
        assert!(sanitize_results(None).is_empty());
        assert!(sanitize_results(Some(&serde_json::json!({ "not": "an array" }))).is_empty());
        assert_eq!(
            sanitize_results(Some(&serde_json::json!([
                { "title": 42, "url": "https://example.com" }
            ]))),
            vec![serde_json::json!({
                "title": "",
                "url": "https://example.com",
                "content": ""
            })]
        );
    }
}