use crate::client::ComposioClient;
use crate::error::ComposioError;
use crate::models::ToolExecutionResponse;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::task::JoinHandle;
/// A single tool invocation to be run inside a tool-router session.
///
/// Serializes to a JSON object with `tool_slug`, `arguments` and, only when it
/// is `Some`, `connected_account_id`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ToolCall {
    /// Slug identifying the tool, e.g. `"GITHUB_CREATE_ISSUE"`.
    pub tool_slug: String,
    /// Tool-specific arguments, carried as arbitrary JSON.
    pub arguments: serde_json::Value,
    /// Optional connected-account identifier; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub connected_account_id: Option<String>,
}
/// Aggregate outcome of running a batch of [`ToolCall`]s.
///
/// `results` holds exactly one entry per submitted call, in submission order,
/// so `successful + failed == results.len()`.
#[derive(Debug)]
pub struct MultiExecutionResult {
    /// Per-call outcome, in the same order the calls were submitted.
    pub results: Vec<Result<ToolExecutionResponse, ComposioError>>,
    /// Count of `Ok` entries in `results`.
    pub successful: usize,
    /// Count of `Err` entries in `results`.
    pub failed: usize,
    /// Elapsed time for the whole batch, in milliseconds (measured with `Instant`).
    pub total_time_ms: u128,
}
/// Runs batches of [`ToolCall`]s against a tool-router session, either
/// concurrently (`execute_parallel`) or one after another (`execute_sequential`).
pub struct MultiExecutor {
    /// Shared API client; the `Arc` lets each spawned task hold its own handle.
    client: Arc<ComposioClient>,
}
impl MultiExecutor {
    /// Maximum number of tool calls accepted by a single
    /// [`MultiExecutor::execute_parallel`] batch.
    pub const MAX_PARALLEL_TOOLS: usize = 20;

    /// Creates an executor that issues all requests through the shared `client`.
    pub fn new(client: Arc<ComposioClient>) -> Self {
        Self { client }
    }

    /// Executes `tools` concurrently within the tool-router session `session_id`.
    ///
    /// Every call runs on its own spawned tokio task, so all requests are in
    /// flight at once. `results` are returned in the same order as `tools`,
    /// regardless of the order in which the calls complete.
    ///
    /// NOTE: the tasks are detached `tokio::spawn` tasks; if the future returned
    /// here is dropped early, already-spawned requests keep running.
    ///
    /// # Errors
    ///
    /// Returns `ComposioError::ValidationError` if `tools` is empty or contains
    /// more than [`Self::MAX_PARALLEL_TOOLS`] entries. Failures of individual
    /// calls (transport errors, non-success HTTP status, undecodable body, or a
    /// panicked/cancelled task) do not fail the batch. They are recorded in
    /// `results` and counted in `failed`.
    pub async fn execute_parallel(
        &self,
        session_id: &str,
        tools: Vec<ToolCall>,
    ) -> Result<MultiExecutionResult, ComposioError> {
        if tools.is_empty() {
            return Err(ComposioError::ValidationError(
                "At least one tool must be provided".to_string(),
            ));
        }
        if tools.len() > Self::MAX_PARALLEL_TOOLS {
            return Err(ComposioError::ValidationError(format!(
                "Maximum {} tools can be executed in parallel",
                Self::MAX_PARALLEL_TOOLS
            )));
        }

        let start_time = std::time::Instant::now();
        let url = self.execute_url(session_id);

        let handles: Vec<JoinHandle<Result<ToolExecutionResponse, ComposioError>>> = tools
            .into_iter()
            .map(|tool| {
                // Spawned tasks must own their data ('static): the Arc clone is a
                // refcount bump, the URL a small string copy.
                let client = Arc::clone(&self.client);
                let url = url.clone();
                tokio::spawn(async move { Self::execute_one(&client, &url, &tool).await })
            })
            .collect();

        // Awaiting handles in spawn order keeps `results` aligned with the input order.
        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            let result = handle.await.unwrap_or_else(|join_err| {
                // A JoinError is either a panic or a cancellation; report which one.
                let what = if join_err.is_panic() {
                    "panicked"
                } else {
                    "cancelled"
                };
                Err(ComposioError::ExecutionError(format!(
                    "Task {}: {}",
                    what, join_err
                )))
            });
            results.push(result);
        }

        Ok(Self::summarize(results, start_time))
    }

    /// Executes `tools` one at a time, in order, within the session `session_id`.
    ///
    /// Unlike [`MultiExecutor::execute_parallel`], no batch-size validation is
    /// performed. An empty `tools` yields an empty, all-zero result.
    ///
    /// # Errors
    ///
    /// This function currently always returns `Ok`. Failures of individual calls
    /// are recorded in `results` and counted in `failed`, and the remaining calls
    /// still run.
    pub async fn execute_sequential(
        &self,
        session_id: &str,
        tools: Vec<ToolCall>,
    ) -> Result<MultiExecutionResult, ComposioError> {
        let start_time = std::time::Instant::now();
        // The endpoint depends only on the session, so build it once, not per call.
        let url = self.execute_url(session_id);

        let mut results = Vec::with_capacity(tools.len());
        for tool in &tools {
            results.push(Self::execute_one(&self.client, &url, tool).await);
        }

        Ok(Self::summarize(results, start_time))
    }

    /// Builds the tool-router execute endpoint URL for `session_id`.
    fn execute_url(&self, session_id: &str) -> String {
        // NOTE(review): `session_id` is interpolated without percent-encoding; if it
        // can ever contain '/', '?' or '#', it would alter the request path — confirm
        // session ids are always URL-safe.
        format!(
            "{}/tool_router/session/{}/execute",
            self.client.config().base_url,
            session_id
        )
    }

    /// POSTs a single `tool` call to `url` and decodes the response body.
    ///
    /// A non-success HTTP status is converted via `ComposioError::from_response`.
    /// Transport and decode errors are propagated with `?`.
    async fn execute_one(
        client: &ComposioClient,
        url: &str,
        tool: &ToolCall,
    ) -> Result<ToolExecutionResponse, ComposioError> {
        // Serialize the ToolCall itself so the request body follows its serde
        // contract: `connected_account_id` is omitted when `None` instead of
        // being sent as an explicit `null`.
        let response = client.http_client().post(url).json(tool).send().await?;
        if !response.status().is_success() {
            return Err(ComposioError::from_response(response).await);
        }
        Ok(response.json().await?)
    }

    /// Tallies `results` into a [`MultiExecutionResult`], stamping elapsed time
    /// since `start_time`.
    fn summarize(
        results: Vec<Result<ToolExecutionResponse, ComposioError>>,
        start_time: std::time::Instant,
    ) -> MultiExecutionResult {
        let successful = results.iter().filter(|r| r.is_ok()).count();
        let failed = results.len() - successful;
        MultiExecutionResult {
            results,
            successful,
            failed,
            total_time_ms: start_time.elapsed().as_millis(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully populated ToolCall serializes every field and round-trips losslessly.
    #[test]
    fn test_tool_call_serialization() {
        let call = ToolCall {
            tool_slug: "GITHUB_CREATE_ISSUE".to_string(),
            arguments: serde_json::json!({
                "title": "Test Issue",
                "body": "Test body"
            }),
            connected_account_id: Some("ca_123".to_string()),
        };

        // Check structure, not substrings, so a value landing under the wrong key fails.
        let value = serde_json::to_value(&call).unwrap();
        assert_eq!(value["tool_slug"], "GITHUB_CREATE_ISSUE");
        assert_eq!(value["arguments"]["title"], "Test Issue");
        assert_eq!(value["arguments"]["body"], "Test body");
        assert_eq!(value["connected_account_id"], "ca_123");

        let json = serde_json::to_string(&call).unwrap();
        let deserialized: ToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.tool_slug, call.tool_slug);
        assert_eq!(deserialized.arguments, call.arguments);
        assert_eq!(deserialized.connected_account_id, call.connected_account_id);
    }

    /// `connected_account_id: None` is omitted entirely rather than serialized as `null`.
    #[test]
    fn test_tool_call_without_account_id() {
        let call = ToolCall {
            tool_slug: "GMAIL_SEND_EMAIL".to_string(),
            arguments: serde_json::json!({ "to": "user@example.com" }),
            connected_account_id: None,
        };

        let value = serde_json::to_value(&call).unwrap();
        assert!(value.get("connected_account_id").is_none());

        let json = serde_json::to_string(&call).unwrap();
        assert!(!json.contains("connected_account_id"));
    }

    /// A payload lacking `connected_account_id` deserializes with the field as `None`.
    #[test]
    fn test_tool_call_missing_account_id_deserializes_to_none() {
        let call: ToolCall =
            serde_json::from_str(r#"{"tool_slug":"GMAIL_SEND_EMAIL","arguments":{}}"#).unwrap();
        assert_eq!(call.tool_slug, "GMAIL_SEND_EMAIL");
        assert_eq!(call.arguments, serde_json::json!({}));
        assert_eq!(call.connected_account_id, None);
    }
}