Skip to main content

composio_sdk/models/
custom_tools.rs

1//! Custom Tools module
2//!
3//! This module provides functionality for creating and managing custom tools.
4//! Custom tools can be:
5//! - Simple tools without authentication
6//! - Toolkit-based tools with authentication and proxy execution
7//!
8//! # Examples
9//!
10//! ## Simple Custom Tool
11//! ```no_run
12//! use composio::CustomToolsRegistry;
13//! use serde_json::json;
14//!
15//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
16//! # let client = composio::ComposioClient::builder().api_key("key").build()?;
17//! let mut registry = CustomToolsRegistry::new(client.into());
18//!
19//! registry.register_simple(
20//!     "calculate_sum",
21//!     "Calculate the sum of two numbers",
22//!     json!({
23//!         "type": "object",
24//!         "properties": {
25//!             "a": {"type": "number"},
26//!             "b": {"type": "number"}
27//!         },
28//!         "required": ["a", "b"]
29//!     }),
30//!     |request| {
31//!         let a = request["a"].as_f64().unwrap_or(0.0);
32//!         let b = request["b"].as_f64().unwrap_or(0.0);
33//!         Ok(json!({"result": a + b}))
34//!     }
35//! );
36//! # Ok(())
37//! # }
38//! ```
39//!
//! ## Toolkit-Based Custom Tool
//!
//! An authenticated tool's executor closure is synchronous, while
//! [`ExecuteRequestFn::execute`] is `async`, so the closure cannot `.await` a
//! proxy request it is handed. Proxy execution is also not implemented yet.
//! The example below therefore works with the connected account's
//! credentials directly.
//!
//! ```no_run
//! use composio::CustomToolsRegistry;
//! use serde_json::json;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! # let client = composio::ComposioClient::builder().api_key("key").build()?;
//! let mut registry = CustomToolsRegistry::new(client.into());
//!
//! registry.register_with_auth(
//!     "create_custom_issue",
//!     "Create a custom GitHub issue",
//!     "github",
//!     json!({
//!         "type": "object",
//!         "properties": {
//!             "title": {"type": "string"},
//!             "body": {"type": "string"}
//!         },
//!         "required": ["title"]
//!     }),
//!     |request, _execute_request, auth_credentials| {
//!         // `auth_credentials` is decoded from the `state` of the user's most
//!         // recently created active connected account for the toolkit.
//!         Ok(json!({
//!             "title": request["title"],
//!             "credential_keys": auth_credentials.keys().collect::<Vec<_>>(),
//!         }))
//!     }
//! );
//! # Ok(())
//! # }
//! ```
75
76use crate::client::ComposioClient;
77use crate::error::ComposioError;
78use crate::models::response::ToolProxyResponse;
79use crate::models::tools::{ProxyParameter, ToolInfo, ToolkitRef};
80use async_trait::async_trait;
81use serde_json::Value as JsonValue;
82use std::collections::HashMap;
83use std::sync::Arc;
84
85/// Function for executing proxy requests to external APIs
86#[async_trait]
87pub trait ExecuteRequestFn: Send + Sync {
88    /// Execute a proxy request
89    ///
90    /// # Arguments
91    /// * `endpoint` - API endpoint (relative or absolute URL)
92    /// * `method` - HTTP method (GET, POST, PUT, DELETE, PATCH)
93    /// * `body` - Request body (optional)
94    /// * `connected_account_id` - Connected account to use for auth (optional)
95    /// * `parameters` - Additional headers/query parameters (optional)
96    async fn execute(
97        &self,
98        endpoint: &str,
99        method: &str,
100        body: Option<JsonValue>,
101        connected_account_id: Option<&str>,
102        parameters: Option<Vec<ProxyParameter>>,
103    ) -> Result<ToolProxyResponse, ComposioError>;
104}
105
106/// Executor for custom tools
107///
108/// This trait abstracts over different types of custom tool executors:
109/// - Simple executors (no authentication)
110/// - Authenticated executors (with proxy and credentials)
111#[async_trait]
112pub trait CustomToolExecutor: Send + Sync {
113    /// Execute the custom tool
114    ///
115    /// # Arguments
116    /// * `request` - Input arguments as JSON
117    /// * `execute_request` - Optional proxy executor for authenticated calls
118    /// * `auth_credentials` - Optional authentication credentials
119    async fn execute(
120        &self,
121        request: JsonValue,
122        execute_request: Option<&dyn ExecuteRequestFn>,
123        auth_credentials: Option<&HashMap<String, JsonValue>>,
124    ) -> Result<JsonValue, ComposioError>;
125}
126
127/// Simple executor that doesn't require authentication
128struct SimpleExecutor<F>
129where
130    F: Fn(JsonValue) -> Result<JsonValue, ComposioError> + Send + Sync,
131{
132    func: F,
133}
134
135#[async_trait]
136impl<F> CustomToolExecutor for SimpleExecutor<F>
137where
138    F: Fn(JsonValue) -> Result<JsonValue, ComposioError> + Send + Sync,
139{
140    async fn execute(
141        &self,
142        request: JsonValue,
143        _execute_request: Option<&dyn ExecuteRequestFn>,
144        _auth_credentials: Option<&HashMap<String, JsonValue>>,
145    ) -> Result<JsonValue, ComposioError> {
146        (self.func)(request)
147    }
148}
149
150/// Authenticated executor that requires proxy and credentials
151struct AuthenticatedExecutor<F>
152where
153    F: Fn(JsonValue, &dyn ExecuteRequestFn, &HashMap<String, JsonValue>) -> Result<JsonValue, ComposioError>
154        + Send
155        + Sync,
156{
157    func: F,
158}
159
160#[async_trait]
161impl<F> CustomToolExecutor for AuthenticatedExecutor<F>
162where
163    F: Fn(JsonValue, &dyn ExecuteRequestFn, &HashMap<String, JsonValue>) -> Result<JsonValue, ComposioError>
164        + Send
165        + Sync,
166{
167    async fn execute(
168        &self,
169        request: JsonValue,
170        execute_request: Option<&dyn ExecuteRequestFn>,
171        auth_credentials: Option<&HashMap<String, JsonValue>>,
172    ) -> Result<JsonValue, ComposioError> {
173        let execute_request = execute_request
174            .ok_or_else(|| ComposioError::InvalidInput("Execute request function required".to_string()))?;
175        let auth_credentials = auth_credentials
176            .ok_or_else(|| ComposioError::InvalidInput("Auth credentials required".to_string()))?;
177        
178        (self.func)(request, execute_request, auth_credentials)
179    }
180}
181
182/// Custom tool definition with execution logic
183pub struct CustomTool {
184    /// Tool slug (unique identifier, uppercase)
185    pub slug: String,
186    
187    /// Tool name (human-readable)
188    pub name: String,
189    
190    /// Tool description
191    pub description: String,
192    
193    /// Toolkit slug (if toolkit-based)
194    pub toolkit: Option<String>,
195    
196    /// Input parameters schema (JSON Schema)
197    pub input_schema: JsonValue,
198    
199    /// Output schema (optional)
200    pub output_schema: Option<JsonValue>,
201    
202    /// Whether authentication is required
203    pub requires_auth: bool,
204    
205    /// Executor function
206    executor: Box<dyn CustomToolExecutor>,
207    
208    /// Client for API operations
209    client: Arc<ComposioClient>,
210}
211
212impl CustomTool {
213    /// Create a simple custom tool without authentication
214    ///
215    /// # Arguments
216    /// * `name` - Tool name (will be converted to uppercase slug)
217    /// * `description` - Tool description
218    /// * `input_schema` - JSON Schema for input parameters
219    /// * `executor` - Function to execute the tool
220    /// * `client` - Composio client
221    pub fn new_simple<F>(
222        name: &str,
223        description: &str,
224        input_schema: JsonValue,
225        executor: F,
226        client: Arc<ComposioClient>,
227    ) -> Self
228    where
229        F: Fn(JsonValue) -> Result<JsonValue, ComposioError> + Send + Sync + 'static,
230    {
231        let slug = name.to_uppercase().replace(' ', "_");
232        
233        Self {
234            slug,
235            name: name.to_string(),
236            description: description.to_string(),
237            toolkit: None,
238            input_schema,
239            output_schema: None,
240            requires_auth: false,
241            executor: Box::new(SimpleExecutor { func: executor }),
242            client,
243        }
244    }
245    
246    /// Create a toolkit-based custom tool with authentication
247    ///
248    /// # Arguments
249    /// * `name` - Tool name (will be prefixed with toolkit and converted to uppercase)
250    /// * `description` - Tool description
251    /// * `toolkit` - Toolkit slug (e.g., "github")
252    /// * `input_schema` - JSON Schema for input parameters
253    /// * `executor` - Function to execute the tool (receives proxy executor and credentials)
254    /// * `client` - Composio client
255    pub fn new_with_auth<F>(
256        name: &str,
257        description: &str,
258        toolkit: &str,
259        input_schema: JsonValue,
260        executor: F,
261        client: Arc<ComposioClient>,
262    ) -> Self
263    where
264        F: Fn(JsonValue, &dyn ExecuteRequestFn, &HashMap<String, JsonValue>) -> Result<JsonValue, ComposioError>
265            + Send
266            + Sync
267            + 'static,
268    {
269        let toolkit_upper = toolkit.to_uppercase();
270        let name_upper = name.to_uppercase().replace(' ', "_");
271        let slug = format!("{}_{}", toolkit_upper, name_upper);
272        let full_name = format!("{}_{}", toolkit.to_lowercase(), name);
273        
274        Self {
275            slug,
276            name: full_name,
277            description: description.to_string(),
278            toolkit: Some(toolkit.to_string()),
279            input_schema,
280            output_schema: None,
281            requires_auth: true,
282            executor: Box::new(AuthenticatedExecutor { func: executor }),
283            client,
284        }
285    }
286    
287    /// Execute the custom tool
288    ///
289    /// # Arguments
290    /// * `arguments` - Input arguments
291    /// * `user_id` - User ID (required for authenticated tools)
292    pub async fn execute(
293        &self,
294        arguments: HashMap<String, JsonValue>,
295        user_id: Option<&str>,
296    ) -> Result<JsonValue, ComposioError> {
297        let request = JsonValue::Object(
298            arguments.into_iter()
299                .map(|(k, v)| (k, v))
300                .collect()
301        );
302        
303        if self.requires_auth {
304            let user_id = user_id.ok_or_else(|| {
305                ComposioError::InvalidInput("user_id required for authenticated tools".to_string())
306            })?;
307            
308            let auth_credentials = self.get_auth_credentials(user_id).await?;
309            
310            // Create proxy executor
311            let proxy_executor = ProxyExecutor {
312                client: self.client.clone(),
313                toolkit: self.toolkit.clone().unwrap(),
314            };
315            
316            self.executor.execute(
317                request,
318                Some(&proxy_executor),
319                Some(&auth_credentials),
320            ).await
321        } else {
322            self.executor.execute(request, None, None).await
323        }
324    }
325    
326    /// Get authentication credentials for a user
327    async fn get_auth_credentials(&self, user_id: &str) -> Result<HashMap<String, JsonValue>, ComposioError> {
328        let toolkit = self.toolkit.as_ref()
329            .ok_or_else(|| ComposioError::InvalidInput("Toolkit required for auth".to_string()))?;
330        
331        // Get connected accounts for this toolkit and user
332        let params = crate::models::connected_accounts::ConnectedAccountListParams {
333            user_ids: Some(vec![user_id.to_string()]),
334            toolkit_slugs: Some(vec![toolkit.clone()]),
335            statuses: Some(vec![crate::models::connected_accounts::ConnectionStatus::Active]),
336            ..Default::default()
337        };
338        
339        let accounts = self.client.list_connected_accounts(params).await?;
340        
341        if accounts.items.is_empty() {
342            return Err(ComposioError::ValidationError(format!(
343                "No active connected accounts found for toolkit {} and user {}",
344                toolkit, user_id
345            )));
346        }
347        
348        // Get most recent account
349        let account = accounts.items.into_iter()
350            .max_by(|a, b| a.created_at.cmp(&b.created_at))
351            .unwrap();
352        
353        // Extract credentials from state
354        if let Some(state) = account.state {
355            Ok(serde_json::from_value(state)?)
356        } else {
357            Err(ComposioError::ValidationError(
358                "Connected account has no state data".to_string()
359            ))
360        }
361    }
362    
363    /// Convert to ToolInfo format (for API compatibility)
364    pub fn to_tool_info(&self) -> ToolInfo {
365        ToolInfo {
366            slug: self.slug.clone(),
367            name: self.name.clone(),
368            description: self.description.clone(),
369            input_parameters: self.input_schema.clone(),
370            output_parameters: self.output_schema.clone().unwrap_or(JsonValue::Object(Default::default())),
371            scopes: vec![],
372            version: "1.0.0".to_string(),
373            available_versions: vec![],
374            toolkit: ToolkitRef {
375                slug: self.toolkit.clone().unwrap_or_else(|| "custom".to_string()).to_uppercase(),
376                name: Some(self.toolkit.clone().unwrap_or_else(|| "custom".to_string())),
377                logo: None,
378            },
379            is_deprecated: false,
380            no_auth: !self.requires_auth,
381            tags: vec![],
382        }
383    }
384}
385
386/// Proxy executor implementation
387struct ProxyExecutor {
388    #[allow(dead_code)]
389    client: Arc<ComposioClient>,
390    #[allow(dead_code)]
391    toolkit: String,
392}
393
394#[async_trait]
395impl ExecuteRequestFn for ProxyExecutor {
396    async fn execute(
397        &self,
398        _endpoint: &str,
399        _method: &str,
400        _body: Option<JsonValue>,
401        _connected_account_id: Option<&str>,
402        _parameters: Option<Vec<ProxyParameter>>,
403    ) -> Result<ToolProxyResponse, ComposioError> {
404        // TODO: Implement actual proxy execution via HTTP
405        // This would call the /api/v3/tools/execute/proxy endpoint
406        Err(ComposioError::InvalidInput(
407            "Proxy execution not yet fully implemented - requires proxy API endpoint".to_string()
408        ))
409    }
410}
411
412/// Registry for managing custom tools
413pub struct CustomToolsRegistry {
414    tools: HashMap<String, Arc<CustomTool>>,
415    client: Arc<ComposioClient>,
416}
417
418impl CustomToolsRegistry {
419    /// Create a new custom tools registry
420    pub fn new(client: Arc<ComposioClient>) -> Self {
421        Self {
422            tools: HashMap::new(),
423            client,
424        }
425    }
426    
427    /// Register a simple custom tool without authentication
428    ///
429    /// # Arguments
430    /// * `name` - Tool name
431    /// * `description` - Tool description
432    /// * `input_schema` - JSON Schema for input parameters
433    /// * `executor` - Function to execute the tool
434    ///
435    /// # Returns
436    /// Arc reference to the registered tool
437    pub fn register_simple<F>(
438        &mut self,
439        name: &str,
440        description: &str,
441        input_schema: JsonValue,
442        executor: F,
443    ) -> Arc<CustomTool>
444    where
445        F: Fn(JsonValue) -> Result<JsonValue, ComposioError> + Send + Sync + 'static,
446    {
447        let tool = Arc::new(CustomTool::new_simple(
448            name,
449            description,
450            input_schema,
451            executor,
452            self.client.clone(),
453        ));
454        
455        self.tools.insert(tool.slug.clone(), tool.clone());
456        tool
457    }
458    
459    /// Register a toolkit-based custom tool with authentication
460    ///
461    /// # Arguments
462    /// * `name` - Tool name
463    /// * `description` - Tool description
464    /// * `toolkit` - Toolkit slug
465    /// * `input_schema` - JSON Schema for input parameters
466    /// * `executor` - Function to execute the tool
467    ///
468    /// # Returns
469    /// Arc reference to the registered tool
470    pub fn register_with_auth<F>(
471        &mut self,
472        name: &str,
473        description: &str,
474        toolkit: &str,
475        input_schema: JsonValue,
476        executor: F,
477    ) -> Arc<CustomTool>
478    where
479        F: Fn(JsonValue, &dyn ExecuteRequestFn, &HashMap<String, JsonValue>) -> Result<JsonValue, ComposioError>
480            + Send
481            + Sync
482            + 'static,
483    {
484        let tool = Arc::new(CustomTool::new_with_auth(
485            name,
486            description,
487            toolkit,
488            input_schema,
489            executor,
490            self.client.clone(),
491        ));
492        
493        self.tools.insert(tool.slug.clone(), tool.clone());
494        tool
495    }
496    
497    /// Get a custom tool by slug
498    pub fn get(&self, slug: &str) -> Option<Arc<CustomTool>> {
499        self.tools.get(slug).cloned()
500    }
501    
502    /// Execute a custom tool
503    ///
504    /// # Arguments
505    /// * `slug` - Tool slug
506    /// * `arguments` - Input arguments
507    /// * `user_id` - User ID (required for authenticated tools)
508    pub async fn execute(
509        &self,
510        slug: &str,
511        arguments: HashMap<String, JsonValue>,
512        user_id: Option<&str>,
513    ) -> Result<JsonValue, ComposioError> {
514        let tool = self.get(slug)
515            .ok_or_else(|| ComposioError::ValidationError(format!("Custom tool {} not found", slug)))?;
516        
517        tool.execute(arguments, user_id).await
518    }
519    
520    /// List all registered custom tools
521    pub fn list(&self) -> Vec<Arc<CustomTool>> {
522        self.tools.values().cloned().collect()
523    }
524    
525    /// Get all custom tools as ToolInfo (for API compatibility)
526    pub fn list_as_tools(&self) -> Vec<ToolInfo> {
527        self.tools.values()
528            .map(|tool| tool.to_tool_info())
529            .collect()
530    }
531}
532
533#[cfg(test)]
534mod tests {
535    use super::*;
536    use serde_json::json;
537    
538    #[test]
539    fn test_custom_tool_simple() {
540        let client = Arc::new(
541            ComposioClient::builder()
542                .api_key("test_key")
543                .build()
544                .unwrap()
545        );
546        
547        let tool = CustomTool::new_simple(
548            "calculate sum",
549            "Calculate the sum of two numbers",
550            json!({
551                "type": "object",
552                "properties": {
553                    "a": {"type": "number"},
554                    "b": {"type": "number"}
555                }
556            }),
557            |request| {
558                let a = request["a"].as_f64().unwrap_or(0.0);
559                let b = request["b"].as_f64().unwrap_or(0.0);
560                Ok(json!({"result": a + b}))
561            },
562            client,
563        );
564        
565        assert_eq!(tool.slug, "CALCULATE_SUM");
566        assert_eq!(tool.name, "calculate sum");
567        assert!(!tool.requires_auth);
568        assert!(tool.toolkit.is_none());
569    }
570    
571    #[test]
572    fn test_custom_tool_with_auth() {
573        let client = Arc::new(
574            ComposioClient::builder()
575                .api_key("test_key")
576                .build()
577                .unwrap()
578        );
579        
580        let tool = CustomTool::new_with_auth(
581            "create issue",
582            "Create a GitHub issue",
583            "github",
584            json!({
585                "type": "object",
586                "properties": {
587                    "title": {"type": "string"}
588                }
589            }),
590            |_request, _execute_request, _auth_credentials| {
591                Ok(json!({"id": 123}))
592            },
593            client,
594        );
595        
596        assert_eq!(tool.slug, "GITHUB_CREATE_ISSUE");
597        assert_eq!(tool.name, "github_create issue");
598        assert!(tool.requires_auth);
599        assert_eq!(tool.toolkit, Some("github".to_string()));
600    }
601    
602    #[test]
603    fn test_registry() {
604        let client = Arc::new(
605            ComposioClient::builder()
606                .api_key("test_key")
607                .build()
608                .unwrap()
609        );
610        
611        let mut registry = CustomToolsRegistry::new(client);
612        
613        registry.register_simple(
614            "test_tool",
615            "A test tool",
616            json!({"type": "object"}),
617            |_request| Ok(json!({"success": true}))
618        );
619        
620        assert!(registry.get("TEST_TOOL").is_some());
621        assert_eq!(registry.list().len(), 1);
622    }
623}