use futures_util::future::BoxFuture;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
/// A statically-typed tool that can be invoked with JSON arguments.
///
/// `call` returns `impl Future`, so this trait is meant to be implemented on
/// concrete types; `DynTool` in this file offers the same four methods with a
/// boxed future for use behind `dyn`.
pub trait Tool: Send + Sync {
/// Name of the tool. `ToolRunner::register` uses this string as the
/// registry key, and incoming calls are matched against it.
fn name(&self) -> &str;
/// Description text for the tool (not interpreted anywhere in this file).
fn description(&self) -> &str;
/// Parameter schema as a string.
/// NOTE(review): presumably a serialized JSON Schema document; nothing in
/// this file parses or validates it — confirm against callers.
fn parameters_json_schema(&self) -> &str;
/// Invokes the tool with `args` and resolves to its JSON result or an error.
///
/// The returned future must be `Send`.
fn call(
&self,
args: Value,
) -> impl std::future::Future<Output = Result<Value, anyhow::Error>> + Send;
}
/// Counterpart of `Tool` whose `call` returns a boxed future, which lets tools
/// be stored as `Arc<dyn DynTool>` (as `ToolRunner` does).
///
/// Every `Tool` gets this trait through the blanket impl in this file.
pub trait DynTool: Send + Sync {
/// Name of the tool; used as the key when registering with `ToolRunner`.
fn name(&self) -> &str;
/// Description text for the tool.
fn description(&self) -> &str;
/// Parameter schema as a string.
/// NOTE(review): presumably serialized JSON Schema — not validated here.
fn parameters_json_schema(&self) -> &str;
/// Invokes the tool with `args`; the boxed future borrows `self` and
/// resolves to the JSON result or an error.
fn call(&self, args: Value) -> BoxFuture<'_, Result<Value, anyhow::Error>>;
}
/// Blanket adapter: every [`Tool`] is automatically usable as a [`DynTool`].
///
/// The metadata accessors forward to the `Tool` implementation unchanged;
/// `call` boxes the tool's future so it fits the `dyn`-friendly signature.
impl<T: Tool + ?Sized> DynTool for T {
    fn name(&self) -> &str {
        // Fully qualified so it is clear this forwards to `Tool`, not to itself.
        <T as Tool>::name(self)
    }

    fn description(&self) -> &str {
        <T as Tool>::description(self)
    }

    fn parameters_json_schema(&self) -> &str {
        <T as Tool>::parameters_json_schema(self)
    }

    fn call(&self, args: Value) -> BoxFuture<'_, Result<Value, anyhow::Error>> {
        // Keep the invocation inside the async block: `Tool::call` itself is
        // only entered once the boxed future is polled for the first time.
        let pending = async move { <T as Tool>::call(self, args).await };
        Box::pin(pending)
    }
}
/// Registry of tools together with the dispatcher that executes tool calls
/// against them (see the `impl ToolRunner` block).
///
/// The derived `Clone` clones the `Arc`, so all clones share one underlying map.
#[derive(Clone, Default)]
pub struct ToolRunner {
/// Registered tools behind an async read/write lock. `register` inserts each
/// tool under the string returned by its `name()`; the field is public, so
/// other code may also insert entries directly.
pub tools: Arc<tokio::sync::RwLock<HashMap<String, Arc<dyn DynTool>>>>,
}
/// Debug output shows only how many tools are registered, not the tools themselves.
impl std::fmt::Debug for ToolRunner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formatting must not wait on the lock: if it cannot be taken right
        // now, the count is reported as 0.
        let tools_count = match self.tools.try_read() {
            Ok(registry) => registry.len(),
            Err(_) => 0,
        };
        let mut out = f.debug_struct("ToolRunner");
        out.field("tools_count", &tools_count);
        out.finish()
    }
}
impl ToolRunner {
pub fn new() -> Self {
Self {
tools: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
}
}
pub async fn register(&self, tool: Arc<dyn DynTool>) {
self.tools
.write()
.await
.insert(tool.name().to_string(), tool);
}
pub async fn process_tool_calls(
&self,
calls: Vec<crate::types::ToolCall>,
) -> Vec<crate::types::ToolResult> {
let mut results = Vec::new();
for call in calls {
let tools = self.tools.read().await;
if let Some(tool) = tools.get(&call.name) {
match tool.call(call.args).await {
Ok(val) => {
results.push(crate::types::ToolResult {
id: Some(call.id),
name: call.name.clone(),
result: Some(val),
error: None,
});
}
Err(e) => {
results.push(crate::types::ToolResult {
id: Some(call.id),
name: call.name.clone(),
result: None,
error: Some(e.to_string()),
});
}
}
} else if call.name == "google_search" || call.name == "web_search" {
let query = call
.args
.get("query")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
tracing::debug!(
"Built-in search fallback for '{}': query={}",
call.name,
query
);
#[cfg(not(target_arch = "wasm32"))]
{
match builtin_web_search(&query).await {
Ok(val) => {
results.push(crate::types::ToolResult {
id: Some(call.id),
name: call.name.clone(),
result: Some(val),
error: None,
});
}
Err(e) => {
results.push(crate::types::ToolResult {
id: Some(call.id),
name: call.name.clone(),
result: None,
error: Some(format!("Search fallback error: {e}")),
});
}
}
}
#[cfg(target_arch = "wasm32")]
{
results.push(crate::types::ToolResult {
id: Some(call.id),
name: call.name.clone(),
result: Some(serde_json::json!({
"results": [],
"note": "Search not available in WASM environment"
})),
error: None,
});
}
} else {
results.push(crate::types::ToolResult {
id: Some(call.id),
name: call.name.clone(),
result: None,
error: Some(format!("Tool {} not found", call.name)),
});
}
}
results
}
}
/// Performs a web search for `query` by running an inline Python script with
/// `python3` and parsing the JSON it prints on stdout.
///
/// The script fetches DuckDuckGo's HTML results page and prints
/// `{"results": [{"title", "url", "snippet"?}, ...]}` with at most 10 entries.
/// For an empty query, or when the HTTP request fails, the script still exits
/// successfully and prints an empty `results` list with a `note` or `error`
/// key, so those cases come back as `Ok`.
///
/// # Errors
///
/// Returns an error if `python3` cannot be started, if it exits with a
/// non-zero status (stderr is included in the message), or if its stdout is
/// not valid JSON.
#[cfg(not(target_arch = "wasm32"))]
async fn builtin_web_search(query: &str) -> Result<Value, anyhow::Error> {
    use tokio::process::Command;

    // The script is Python source: its lines must start at column 0 and block
    // bodies must be indented, regardless of how the surrounding Rust is laid out.
    let script = r#"
import sys, json, urllib.request, urllib.parse, html, re
query = sys.argv[1] if len(sys.argv) > 1 else ""
if not query:
    print(json.dumps({"results": [], "note": "Empty query"}))
    sys.exit(0)
url = "https://html.duckduckgo.com/html/?q=" + urllib.parse.quote_plus(query)
headers = {"User-Agent": "Mozilla/5.0 (compatible; AntigravitySDK/1.0)"}
req = urllib.request.Request(url, headers=headers)
try:
    resp = urllib.request.urlopen(req, timeout=10)
    body = resp.read().decode("utf-8", errors="replace")
except Exception as e:
    print(json.dumps({"results": [], "error": str(e)}))
    sys.exit(0)
results = []
# Parse result blocks: each result link has class "result__a"
for m in re.finditer(r'<a[^>]+class="result__a"[^>]*href="([^"]*)"[^>]*>(.*?)</a>', body, re.DOTALL):
    link = html.unescape(m.group(1))
    title = re.sub(r'<[^>]+>', '', html.unescape(m.group(2))).strip()
    # DuckDuckGo wraps links through a redirect; extract the actual URL
    if "uddg=" in link:
        actual = urllib.parse.unquote(link.split("uddg=")[-1].split("&")[0])
        link = actual
    results.append({"title": title, "url": link})
    if len(results) >= 10:
        break
# Try to get snippets
snippets = re.findall(r'<a[^>]+class="result__snippet"[^>]*>(.*?)</a>', body, re.DOTALL)
for i, snip in enumerate(snippets):
    if i < len(results):
        results[i]["snippet"] = re.sub(r'<[^>]+>', '', html.unescape(snip)).strip()
print(json.dumps({"results": results}))
"#;

    // The query is passed as a separate argv entry (read via `sys.argv[1]`),
    // never interpolated into the script text.
    let output = Command::new("python3")
        .arg("-c")
        .arg(script)
        .arg(query)
        // If this future is dropped mid-search, do not leave python3 running.
        .kill_on_drop(true)
        .output()
        .await
        .map_err(|e| anyhow::anyhow!("Failed to run python3 for search: {e}"))?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(anyhow::anyhow!("Search script failed: {stderr}"));
    }

    let stdout = String::from_utf8_lossy(&output.stdout);
    let parsed: Value = serde_json::from_str(stdout.trim())
        .map_err(|e| anyhow::anyhow!("Failed to parse search results: {e}"))?;
    Ok(parsed)
}