use crate::client::ComposioClient;
use crate::error::ComposioError;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// One tool returned by a tool-router session search.
///
/// The optional fields are omitted from serialized output when they are `None`.
/// Field declaration order is also the serialized key order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Tool identifier, e.g. `GITHUB_CREATE_ISSUE`.
    pub slug: String,
    /// Human-readable tool name.
    pub name: String,
    /// Description of what the tool does.
    pub description: String,
    /// Toolkit the tool belongs to, e.g. `github`.
    pub toolkit: String,
    /// Whether the toolkit is connected (as reported by the server).
    pub is_connected: bool,
    /// Search score reported by the server (presumably relevance; confirm against API docs).
    pub score: f64,
    /// JSON schema describing the tool's input, if the server returned one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_schema: Option<serde_json::Value>,
    /// Ordered steps for using the tool, if the server returned them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub execution_plan: Option<Vec<String>>,
    /// Known pitfalls when calling the tool, if the server returned them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub known_pitfalls: Option<Vec<String>>,
}
/// Searches the tools exposed by a tool-router session, using a shared [`ComposioClient`].
pub struct ToolSearch {
    /// Shared client supplying the base URL and HTTP client.
    client: Arc<ComposioClient>,
}
impl ToolSearch {
    /// Creates a search helper backed by the given shared client.
    pub fn new(client: Arc<ComposioClient>) -> Self {
        Self { client }
    }

    /// Searches the tools available to `session_id` for `query`.
    ///
    /// Equivalent to [`ToolSearch::search_filtered`] with no toolkit filter and no limit.
    /// The request body is identical to the one this method has always sent.
    ///
    /// # Errors
    ///
    /// See [`ToolSearch::search_filtered`].
    pub async fn search(
        &self,
        query: &str,
        session_id: &str,
    ) -> Result<Vec<SearchResult>, ComposioError> {
        self.search_filtered(query, session_id, None, None).await
    }

    /// Searches the tools available to `session_id`, optionally restricted to
    /// `toolkits` and capped at `limit`.
    ///
    /// Sends `POST {base_url}/tool_router/session/{session_id}/search`. When
    /// `toolkits` or `limit` is `Some`, it is forwarded as the `toolkits` or
    /// `limit` field of the JSON body.
    ///
    /// Tool entries that do not deserialize into [`SearchResult`] are skipped rather
    /// than failing the whole search.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following happens:
    /// - The request cannot be sent.
    /// - The server responds with a non-success status (built via `ComposioError::from_response`).
    /// - The body is not valid JSON.
    /// - The body has no `data.tools` array (`ComposioError::InvalidInput`).
    pub async fn search_filtered(
        &self,
        query: &str,
        session_id: &str,
        toolkits: Option<Vec<&str>>,
        limit: Option<usize>,
    ) -> Result<Vec<SearchResult>, ComposioError> {
        let url = format!(
            "{}/tool_router/session/{}/search",
            self.client.config().base_url,
            session_id
        );

        // Ask for the schema and plan alongside each hit (presumably these populate
        // `input_schema` / `execution_plan`; confirm against the API docs).
        let mut body = serde_json::json!({
            "query": query,
            "include_schema": true,
            "include_plan": true,
        });
        if let Some(tk) = toolkits {
            body["toolkits"] = serde_json::json!(tk);
        }
        if let Some(lim) = limit {
            body["limit"] = serde_json::json!(lim);
        }

        let response = self
            .client
            .http_client()
            .post(&url)
            .json(&body)
            .send()
            .await?;
        if !response.status().is_success() {
            return Err(ComposioError::from_response(response).await);
        }

        let data: serde_json::Value = response.json().await?;
        Self::parse_tools(data)
    }

    /// Extracts and deserializes the `data.tools` array from a search response body.
    ///
    /// Moves each tool's JSON out of `data` instead of cloning it. Entries that fail to
    /// deserialize are dropped (best effort, matching the previous behaviour).
    fn parse_tools(mut data: serde_json::Value) -> Result<Vec<SearchResult>, ComposioError> {
        // `pointer_mut` returns `None` for any shape mismatch instead of panicking like
        // `IndexMut` would on a non-object value from an untrusted response.
        let tools = match data
            .pointer_mut("/data/tools")
            .map(serde_json::Value::take)
        {
            Some(serde_json::Value::Array(tools)) => tools,
            _ => {
                return Err(ComposioError::InvalidInput(
                    "Invalid search response format".to_string(),
                ))
            }
        };

        Ok(tools
            .into_iter()
            .filter_map(|tool| serde_json::from_value(tool).ok())
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully populated result survives a serialize/deserialize round trip.
    #[test]
    fn test_search_result_serialization() {
        let result = SearchResult {
            slug: "GITHUB_CREATE_ISSUE".to_string(),
            name: "Create Issue".to_string(),
            description: "Create a new issue in a repository".to_string(),
            toolkit: "github".to_string(),
            is_connected: true,
            score: 0.95,
            input_schema: Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "title": { "type": "string" },
                    "body": { "type": "string" }
                }
            })),
            execution_plan: Some(vec![
                "Ensure GitHub is connected".to_string(),
                "Provide repository owner and name".to_string(),
            ]),
            known_pitfalls: Some(vec!["Title is required".to_string()]),
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("GITHUB_CREATE_ISSUE"));
        assert!(json.contains("0.95"));
        let deserialized: SearchResult = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.slug, "GITHUB_CREATE_ISSUE");
        assert_eq!(deserialized.score, 0.95);
    }

    /// `None` optional fields are omitted from serialized output.
    #[test]
    fn test_search_result_without_optional_fields() {
        let result = SearchResult {
            slug: "GMAIL_SEND_EMAIL".to_string(),
            name: "Send Email".to_string(),
            description: "Send an email".to_string(),
            toolkit: "gmail".to_string(),
            is_connected: false,
            score: 0.88,
            input_schema: None,
            execution_plan: None,
            known_pitfalls: None,
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(!json.contains("input_schema"));
        assert!(!json.contains("execution_plan"));
        assert!(!json.contains("known_pitfalls"));
    }

    /// A tool payload that omits every optional field still deserializes,
    /// with those fields set to `None`.
    #[test]
    fn test_search_result_deserializes_without_optional_fields() {
        let json = r#"{
            "slug": "GMAIL_SEND_EMAIL",
            "name": "Send Email",
            "description": "Send an email",
            "toolkit": "gmail",
            "is_connected": false,
            "score": 0.88
        }"#;
        let result: SearchResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.slug, "GMAIL_SEND_EMAIL");
        assert!(!result.is_connected);
        assert!(result.input_schema.is_none());
        assert!(result.execution_plan.is_none());
        assert!(result.known_pitfalls.is_none());
    }

    /// A tool missing a required field fails to deserialize. That failure is
    /// what causes the search methods to skip such entries.
    #[test]
    fn test_search_result_missing_required_field_is_rejected() {
        let json = r#"{
            "slug": "GMAIL_SEND_EMAIL",
            "name": "Send Email",
            "description": "Send an email",
            "toolkit": "gmail",
            "is_connected": false
        }"#;
        assert!(serde_json::from_str::<SearchResult>(json).is_err());
    }
}