Skip to main content

lash_tools/web/
web_search.rs

1use serde_json::{Value, json};
2
3use lash_core::{ToolCall, ToolDefinition, ToolResult, ToolScheduling};
4
5use lash_tool_support::{StaticToolExecute, StaticToolProvider, object_schema};
6
7/// Web search via Tavily API.
8pub struct WebSearch {
9    api_key: String,
10    client: reqwest::Client,
11}
12
13impl WebSearch {
14    pub fn new(api_key: impl Into<String>) -> Self {
15        Self {
16            api_key: api_key.into(),
17            client: reqwest::Client::builder()
18                .timeout(std::time::Duration::from_secs(30))
19                .build()
20                .unwrap_or_default(),
21        }
22    }
23}
24
25/// Build the cached `search_web` tool provider for the given Tavily API key.
26pub fn web_search_provider(api_key: impl Into<String>) -> StaticToolProvider<WebSearch> {
27    StaticToolProvider::new(vec![web_search_tool_definition()], WebSearch::new(api_key))
28}
29
30#[async_trait::async_trait]
31impl StaticToolExecute for WebSearch {
32    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
33        let args = call.args;
34        let query = args
35            .get("query")
36            .and_then(|v| v.as_str())
37            .unwrap_or_default();
38        let limit = args
39            .get("limit")
40            .and_then(|v| v.as_u64())
41            .unwrap_or(5)
42            .clamp(1, 20);
43
44        if self.api_key.trim().is_empty() {
45            return ToolResult::err(json!("Tavily API key is required for web.search"));
46        }
47
48        let body = json!({
49            "query": query,
50            "max_results": limit,
51        });
52
53        let resp = self
54            .client
55            .post("https://api.tavily.com/search")
56            .bearer_auth(&self.api_key)
57            .json(&body)
58            .send()
59            .await;
60        match resp {
61            Ok(r) if r.status().is_success() => match r.json::<serde_json::Value>().await {
62                Ok(data) => ToolResult::ok(json!({
63                    "results": sanitize_results(data.get("results")),
64                })),
65                Err(e) => ToolResult::err_fmt(format_args!("Failed to parse response: {e}")),
66            },
67            Ok(r) => {
68                let status = r.status();
69                let body = r.text().await.unwrap_or_default();
70                ToolResult::err_fmt(format_args!("Tavily API error ({status}): {body}"))
71            }
72            Err(e) => ToolResult::err_fmt(format_args!("Request failed: {e}")),
73        }
74    }
75}
76
77fn sanitize_results(results: Option<&Value>) -> Vec<Value> {
78    results
79        .and_then(Value::as_array)
80        .into_iter()
81        .flatten()
82        .map(|item| {
83            json!({
84                "title": item.get("title").and_then(Value::as_str).unwrap_or_default(),
85                "url": item.get("url").and_then(Value::as_str).unwrap_or_default(),
86                "content": item.get("content").and_then(Value::as_str).unwrap_or_default(),
87            })
88        })
89        .collect()
90}
91
92fn web_search_tool_definition() -> ToolDefinition {
93    ToolDefinition::raw(
94                "tool:search_web",
95                "search_web",
96                "Search the web for candidate sources. Returns ranked `results` with snippet text; use `web.fetch` when you need the page itself. This tool does not expose Tavily's optional generated answer; summarize from result snippets and fetched pages.",
97                object_schema(
98                    serde_json::json!({
99                        "query": { "type": "string" },
100                        "limit": {
101                            "type": "integer",
102                            "minimum": 1,
103                            "maximum": 20,
104                            "default": 5,
105                            "description": "Maximum results to return (default 5)"
106                        }
107                    }),
108                    &["query"],
109                ),
110                serde_json::json!({
111                    "type": "object",
112                    "properties": {
113                        "results": {
114                            "type": "array",
115                            "items": {
116                                "type": "object",
117                                "properties": {
118                                    "title": { "type": "string" },
119                                    "url": { "type": "string" },
120                                    "content": {
121                                        "type": "string",
122                                        "description": "Search-result snippet text."
123                                    }
124                                },
125                                "required": ["title", "url", "content"],
126                                "additionalProperties": false
127                            }
128                        }
129                    },
130                    "required": ["results"],
131                    "additionalProperties": false
132                }),
133            )
134            .with_examples(vec![
135                "await web.search({ query: \"latest Rust release notes\", limit: 5 })?".into(),
136            ])
137            .with_lashlang_binding(lash_tool_support::lashlang_binding(
138                ["web"],
139                "search",
140                &["web_search"],
141            ))
142            .with_scheduling(ToolScheduling::Parallel)
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// The model-facing contract exposes `limit` (not Tavily's `max_results`),
    /// mentions it in every example, and declares only `results` as output.
    #[test]
    fn search_web_uses_limit_argument_in_model_contract() {
        let definition = web_search_tool_definition();
        let contract = &definition.contract;

        // Input side: argument naming and bounds.
        let properties = contract
            .input_schema
            .get("properties")
            .and_then(serde_json::Value::as_object)
            .expect("object properties");
        assert!(properties.contains_key("limit"));
        assert!(!properties.contains_key("max_results"));

        let limit = &properties["limit"];
        assert_eq!(limit["default"], serde_json::json!(5));
        assert_eq!(limit["maximum"], serde_json::json!(20));

        let every_example_mentions_limit = contract
            .examples
            .iter()
            .all(|example| example.contains("limit"));
        assert!(every_example_mentions_limit);

        // Output side: an object that requires `results` and has no `answer`.
        let output_schema = &contract.output_schema;
        assert_eq!(output_schema["type"], serde_json::json!("object"));
        assert_eq!(output_schema["required"], serde_json::json!(["results"]));

        let output_properties = output_schema["properties"].as_object().unwrap();
        assert!(!output_properties.contains_key("answer"));

        // Manifest defaults.
        let manifest = &definition.manifest;
        assert_eq!(manifest.activation, lash_core::ToolActivation::Always);
        assert_eq!(
            manifest.availability.base,
            lash_core::ToolAvailability::Showcased
        );
    }

    /// Fields outside `title` / `url` / `content` are stripped from results.
    #[test]
    fn search_web_sanitizes_tavily_results_to_contract() {
        let raw = serde_json::json!([
            {
                "title": "Title",
                "url": "https://example.com",
                "content": "Snippet",
                "score": 0.9,
                "raw_content": null,
                "favicon": "https://example.com/favicon.ico"
            }
        ]);
        let expected = vec![serde_json::json!({
            "title": "Title",
            "url": "https://example.com",
            "content": "Snippet"
        })];

        let results = sanitize_results(Some(&raw));

        assert_eq!(results, expected);
    }
}