Skip to main content

lash_tools/web/
web_search.rs

1use serde_json::{Value, json};
2
3use lash_core::{ToolCall, ToolDefinition, ToolResult, ToolScheduling};
4
5use lash_tool_support::{
6    StaticToolExecute, StaticToolProvider, ToolDefinitionLashlangExt, object_schema,
7};
8
9/// Web search via Tavily API.
10pub struct WebSearch {
11    api_key: String,
12    client: reqwest::Client,
13}
14
15impl WebSearch {
16    pub fn new(api_key: impl Into<String>) -> Self {
17        Self {
18            api_key: api_key.into(),
19            client: reqwest::Client::builder()
20                .timeout(std::time::Duration::from_secs(30))
21                .build()
22                .unwrap_or_default(),
23        }
24    }
25}
26
27/// Build the cached `search_web` tool provider for the given Tavily API key.
28pub fn web_search_provider(api_key: impl Into<String>) -> StaticToolProvider<WebSearch> {
29    StaticToolProvider::new(vec![web_search_tool_definition()], WebSearch::new(api_key))
30}
31
32#[async_trait::async_trait]
33impl StaticToolExecute for WebSearch {
34    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
35        let args = call.args;
36        let query = args
37            .get("query")
38            .and_then(|v| v.as_str())
39            .unwrap_or_default();
40        let limit = args
41            .get("limit")
42            .and_then(|v| v.as_u64())
43            .unwrap_or(5)
44            .clamp(1, 20);
45
46        if self.api_key.trim().is_empty() {
47            return ToolResult::err(json!("Tavily API key is required for web.search"));
48        }
49
50        let body = json!({
51            "query": query,
52            "max_results": limit,
53        });
54
55        let resp = self
56            .client
57            .post("https://api.tavily.com/search")
58            .bearer_auth(&self.api_key)
59            .json(&body)
60            .send()
61            .await;
62        match resp {
63            Ok(r) if r.status().is_success() => match r.json::<serde_json::Value>().await {
64                Ok(data) => ToolResult::ok(json!({
65                    "results": sanitize_results(data.get("results")),
66                })),
67                Err(e) => ToolResult::err_fmt(format_args!("Failed to parse response: {e}")),
68            },
69            Ok(r) => {
70                let status = r.status();
71                let body = r.text().await.unwrap_or_default();
72                ToolResult::err_fmt(format_args!("Tavily API error ({status}): {body}"))
73            }
74            Err(e) => ToolResult::err_fmt(format_args!("Request failed: {e}")),
75        }
76    }
77}
78
79fn sanitize_results(results: Option<&Value>) -> Vec<Value> {
80    results
81        .and_then(Value::as_array)
82        .into_iter()
83        .flatten()
84        .map(|item| {
85            json!({
86                "title": item.get("title").and_then(Value::as_str).unwrap_or_default(),
87                "url": item.get("url").and_then(Value::as_str).unwrap_or_default(),
88                "content": item.get("content").and_then(Value::as_str).unwrap_or_default(),
89            })
90        })
91        .collect()
92}
93
94fn web_search_tool_definition() -> ToolDefinition {
95    ToolDefinition::raw(
96                "tool:search_web",
97                "search_web",
98                "Search the web for candidate sources. Returns ranked `results` with snippet text; use `web.fetch` when you need the page itself. This tool does not expose Tavily's optional generated answer; summarize from result snippets and fetched pages.",
99                object_schema(
100                    serde_json::json!({
101                        "query": { "type": "string" },
102                        "limit": {
103                            "type": "integer",
104                            "minimum": 1,
105                            "maximum": 20,
106                            "default": 5,
107                            "description": "Maximum results to return (default 5)"
108                        }
109                    }),
110                    &["query"],
111                ),
112                serde_json::json!({
113                    "type": "object",
114                    "properties": {
115                        "results": {
116                            "type": "array",
117                            "items": {
118                                "type": "object",
119                                "properties": {
120                                    "title": { "type": "string" },
121                                    "url": { "type": "string" },
122                                    "content": {
123                                        "type": "string",
124                                        "description": "Search-result snippet text."
125                                    }
126                                },
127                                "required": ["title", "url", "content"],
128                                "additionalProperties": false
129                            }
130                        }
131                    },
132                    "required": ["results"],
133                    "additionalProperties": false
134                }),
135            )
136            .with_examples(vec![
137                "await web.search({ query: \"latest Rust release notes\", limit: 5 })?".into(),
138            ])
139            .with_lashlang_binding(lash_tool_support::lashlang_binding(
140                ["web"],
141                "search",
142                &["web_search"],
143            ))
144            .with_scheduling(ToolScheduling::Parallel)
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// The input/output contract exposes `limit` (not Tavily's `max_results`),
    /// omits Tavily's generated `answer`, and keeps the expected manifest flags.
    #[test]
    fn search_web_uses_limit_argument_in_model_contract() {
        let definition = web_search_tool_definition();

        let properties = definition
            .contract
            .input_schema
            .get("properties")
            .and_then(serde_json::Value::as_object)
            .expect("object properties");
        assert!(properties.contains_key("limit"));
        assert!(!properties.contains_key("max_results"));
        assert_eq!(properties["limit"]["default"], serde_json::json!(5));
        assert_eq!(properties["limit"]["maximum"], serde_json::json!(20));
        assert!(
            definition
                .contract
                .examples
                .iter()
                .all(|example| example.contains("limit"))
        );
        assert_eq!(
            definition.contract.output_schema["type"],
            serde_json::json!("object")
        );
        assert_eq!(
            definition.contract.output_schema["required"],
            serde_json::json!(["results"])
        );
        assert!(
            !definition.contract.output_schema["properties"]
                .as_object()
                .unwrap()
                .contains_key("answer")
        );
        assert_eq!(
            definition.manifest.activation,
            lash_core::ToolActivation::Always
        );
        assert_eq!(
            definition.manifest.availability.base,
            lash_core::ToolAvailability::Showcased
        );
    }

    /// Extra Tavily fields (score, raw_content, favicon) are stripped.
    #[test]
    fn search_web_sanitizes_tavily_results_to_contract() {
        let results = sanitize_results(Some(&serde_json::json!([
            {
                "title": "Title",
                "url": "https://example.com",
                "content": "Snippet",
                "score": 0.9,
                "raw_content": null,
                "favicon": "https://example.com/favicon.ico"
            }
        ])));

        assert_eq!(
            results,
            vec![serde_json::json!({
                "title": "Title",
                "url": "https://example.com",
                "content": "Snippet"
            })]
        );
    }

    /// Missing or malformed upstream data degrades to empty output or empty
    /// strings instead of panicking or leaking non-string values.
    #[test]
    fn search_web_sanitize_tolerates_missing_or_malformed_results() {
        assert!(sanitize_results(None).is_empty());
        assert!(sanitize_results(Some(&serde_json::json!({ "not": "an array" }))).is_empty());

        let results = sanitize_results(Some(&serde_json::json!([
            { "title": 7, "url": "https://example.com" }
        ])));
        assert_eq!(
            results,
            vec![serde_json::json!({
                "title": "",
                "url": "https://example.com",
                "content": ""
            })]
        );
    }
}