Skip to main content

composio_sdk/meta_tools/
multi_executor.rs

1//! Multi-Tool Executor Implementation
2//!
3//! Native Rust implementation of COMPOSIO_MULTI_EXECUTE_TOOL meta tool.
4//! Executes up to 20 tools in parallel using Tokio's async runtime.
5
6use crate::client::ComposioClient;
7use crate::error::ComposioError;
8use crate::models::ToolExecutionResponse;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::task::JoinHandle;
12
13/// Tool call specification for parallel execution
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ToolCall {
16    /// Tool slug to execute
17    pub tool_slug: String,
18    
19    /// Tool arguments
20    pub arguments: serde_json::Value,
21    
22    /// Optional connected account ID
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub connected_account_id: Option<String>,
25}
26
27/// Result of parallel tool execution
28#[derive(Debug)]
29pub struct MultiExecutionResult {
30    /// Individual tool results (in same order as input)
31    pub results: Vec<Result<ToolExecutionResponse, ComposioError>>,
32    
33    /// Number of successful executions
34    pub successful: usize,
35    
36    /// Number of failed executions
37    pub failed: usize,
38    
39    /// Total execution time in milliseconds
40    pub total_time_ms: u128,
41}
42
43/// Multi-tool executor
44pub struct MultiExecutor {
45    client: Arc<ComposioClient>,
46}
47
48impl MultiExecutor {
49    /// Create a new multi-executor instance
50    ///
51    /// # Arguments
52    ///
53    /// * `client` - Composio client instance
54    ///
55    /// # Example
56    ///
57    /// ```no_run
58    /// use composio_sdk::{ComposioClient, meta_tools::MultiExecutor};
59    /// use std::sync::Arc;
60    ///
61    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
62    /// let client = ComposioClient::builder()
63    ///     .api_key("your-api-key")
64    ///     .build()?;
65    ///
66    /// let executor = MultiExecutor::new(Arc::new(client));
67    /// # Ok(())
68    /// # }
69    /// ```
70    pub fn new(client: Arc<ComposioClient>) -> Self {
71        Self { client }
72    }
73
74    /// Execute multiple tools in parallel
75    ///
76    /// # Arguments
77    ///
78    /// * `session_id` - Session ID for execution context
79    /// * `tools` - Vector of tool calls to execute (max 20)
80    ///
81    /// # Returns
82    ///
83    /// Multi-execution result with individual results and statistics
84    ///
85    /// # Example
86    ///
87    /// ```no_run
88    /// # use composio_sdk::{ComposioClient, meta_tools::{MultiExecutor, ToolCall}};
89    /// # use std::sync::Arc;
90    /// # use serde_json::json;
91    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
92    /// # let client = Arc::new(ComposioClient::builder().api_key("key").build()?);
93    /// let executor = MultiExecutor::new(client);
94    ///
95    /// let tools = vec![
96    ///     ToolCall {
97    ///         tool_slug: "GITHUB_GET_REPOS".to_string(),
98    ///         arguments: json!({ "owner": "composio" }),
99    ///         connected_account_id: None,
100    ///     },
101    ///     ToolCall {
102    ///         tool_slug: "GITHUB_GET_ISSUES".to_string(),
103    ///         arguments: json!({ "owner": "composio", "repo": "composio" }),
104    ///         connected_account_id: None,
105    ///     },
106    /// ];
107    ///
108    /// let result = executor.execute_parallel("session_123", tools).await?;
109    /// println!("Successful: {}, Failed: {}", result.successful, result.failed);
110    /// # Ok(())
111    /// # }
112    /// ```
113    pub async fn execute_parallel(
114        &self,
115        session_id: &str,
116        tools: Vec<ToolCall>,
117    ) -> Result<MultiExecutionResult, ComposioError> {
118        // Validate tool count
119        if tools.is_empty() {
120            return Err(ComposioError::ValidationError(
121                "At least one tool must be provided".to_string(),
122            ));
123        }
124
125        if tools.len() > 20 {
126            return Err(ComposioError::ValidationError(
127                "Maximum 20 tools can be executed in parallel".to_string(),
128            ));
129        }
130
131        let start_time = std::time::Instant::now();
132
133        // Spawn parallel execution tasks
134        let mut handles: Vec<JoinHandle<Result<ToolExecutionResponse, ComposioError>>> = Vec::new();
135
136        for tool in tools {
137            let client = self.client.clone();
138            let session_id = session_id.to_string();
139
140            let handle = tokio::spawn(async move {
141                let url = format!(
142                    "{}/tool_router/session/{}/execute",
143                    client.config().base_url,
144                    session_id
145                );
146
147                let response = client
148                    .http_client()
149                    .post(&url)
150                    .json(&serde_json::json!({
151                        "tool_slug": tool.tool_slug,
152                        "arguments": tool.arguments,
153                        "connected_account_id": tool.connected_account_id,
154                    }))
155                    .send()
156                    .await?;
157
158                if !response.status().is_success() {
159                    return Err(ComposioError::from_response(response).await);
160                }
161
162                let result: ToolExecutionResponse = response.json().await?;
163
164                Ok(result)
165            });
166
167            handles.push(handle);
168        }
169
170        // Wait for all tasks to complete
171        let mut results = Vec::new();
172        let mut successful = 0;
173        let mut failed = 0;
174
175        for handle in handles {
176            match handle.await {
177                Ok(result) => {
178                    if result.is_ok() {
179                        successful += 1;
180                    } else {
181                        failed += 1;
182                    }
183                    results.push(result);
184                }
185                Err(e) => {
186                    failed += 1;
187                    results.push(Err(ComposioError::ExecutionError(format!(
188                        "Task panicked: {}",
189                        e
190                    ))));
191                }
192            }
193        }
194
195        let total_time_ms = start_time.elapsed().as_millis();
196
197        Ok(MultiExecutionResult {
198            results,
199            successful,
200            failed,
201            total_time_ms,
202        })
203    }
204
205    /// Execute tools sequentially (fallback for when parallel execution is not desired)
206    ///
207    /// # Arguments
208    ///
209    /// * `session_id` - Session ID
210    /// * `tools` - Vector of tool calls
211    ///
212    /// # Example
213    ///
214    /// ```no_run
215    /// # use composio_sdk::{ComposioClient, meta_tools::{MultiExecutor, ToolCall}};
216    /// # use std::sync::Arc;
217    /// # use serde_json::json;
218    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
219    /// # let client = Arc::new(ComposioClient::builder().api_key("key").build()?);
220    /// let executor = MultiExecutor::new(client);
221    ///
222    /// let tools = vec![
223    ///     ToolCall {
224    ///         tool_slug: "GITHUB_CREATE_ISSUE".to_string(),
225    ///         arguments: json!({ "title": "Bug", "body": "Description" }),
226    ///         connected_account_id: None,
227    ///     },
228    /// ];
229    ///
230    /// let result = executor.execute_sequential("session_123", tools).await?;
231    /// # Ok(())
232    /// # }
233    /// ```
234    pub async fn execute_sequential(
235        &self,
236        session_id: &str,
237        tools: Vec<ToolCall>,
238    ) -> Result<MultiExecutionResult, ComposioError> {
239        let start_time = std::time::Instant::now();
240        let mut results = Vec::new();
241        let mut successful = 0;
242        let mut failed = 0;
243
244        for tool in tools {
245            let url = format!(
246                "{}/tool_router/session/{}/execute",
247                self.client.config().base_url,
248                session_id
249            );
250
251            let result = async {
252                let response = self
253                    .client
254                    .http_client()
255                    .post(&url)
256                    .json(&serde_json::json!({
257                        "tool_slug": tool.tool_slug,
258                        "arguments": tool.arguments,
259                        "connected_account_id": tool.connected_account_id,
260                    }))
261                    .send()
262                    .await?;
263
264                if !response.status().is_success() {
265                    return Err(ComposioError::from_response(response).await);
266                }
267
268                let result: ToolExecutionResponse = response.json().await?;
269
270                Ok(result)
271            }
272            .await;
273
274            if result.is_ok() {
275                successful += 1;
276            } else {
277                failed += 1;
278            }
279
280            results.push(result);
281        }
282
283        let total_time_ms = start_time.elapsed().as_millis();
284
285        Ok(MultiExecutionResult {
286            results,
287            successful,
288            failed,
289            total_time_ms,
290        })
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a fully populated `ToolCall` and checks each field
    /// structurally, rather than by substring match on the JSON text.
    #[test]
    fn test_tool_call_serialization() {
        let call = ToolCall {
            tool_slug: "GITHUB_CREATE_ISSUE".to_string(),
            arguments: serde_json::json!({
                "title": "Test Issue",
                "body": "Test body"
            }),
            connected_account_id: Some("ca_123".to_string()),
        };

        let json = serde_json::to_string(&call).unwrap();

        // Parse back into a generic Value so each key is checked against its value.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["tool_slug"], "GITHUB_CREATE_ISSUE");
        assert_eq!(value["arguments"]["title"], "Test Issue");
        assert_eq!(value["arguments"]["body"], "Test body");
        assert_eq!(value["connected_account_id"], "ca_123");

        let deserialized: ToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.tool_slug, "GITHUB_CREATE_ISSUE");
        assert_eq!(deserialized.arguments, call.arguments);
        assert_eq!(deserialized.connected_account_id.as_deref(), Some("ca_123"));
    }

    /// A `None` account ID must be omitted entirely, not serialized as `null`.
    #[test]
    fn test_tool_call_without_account_id() {
        let call = ToolCall {
            tool_slug: "GMAIL_SEND_EMAIL".to_string(),
            arguments: serde_json::json!({ "to": "user@example.com" }),
            connected_account_id: None,
        };

        let json = serde_json::to_string(&call).unwrap();
        assert!(!json.contains("connected_account_id"));

        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(value.get("connected_account_id").is_none());
    }

    /// Input without a `connected_account_id` key deserializes to `None`.
    #[test]
    fn test_tool_call_missing_account_id_deserializes_to_none() {
        let call: ToolCall =
            serde_json::from_str(r#"{"tool_slug":"GMAIL_SEND_EMAIL","arguments":{}}"#).unwrap();
        assert_eq!(call.tool_slug, "GMAIL_SEND_EMAIL");
        assert!(call.connected_account_id.is_none());
    }
}