#[cfg(feature = "composio")]
use std::sync::Arc;
#[cfg(feature = "composio")]
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
#[cfg(feature = "composio")]
use async_trait::async_trait;
#[cfg(feature = "composio")]
use crate::integrations::composio::transport::SseTransport;
#[cfg(feature = "composio")]
use crate::integrations::composio::ComposioError;
#[cfg(feature = "composio")]
use crate::reasoning::circuit_breaker::CircuitBreakerRegistry;
#[cfg(feature = "composio")]
use crate::reasoning::executor::ActionExecutor;
#[cfg(feature = "composio")]
use crate::reasoning::inference::ToolDefinition;
#[cfg(feature = "composio")]
use crate::reasoning::loop_types::{LoopConfig, Observation, ProposedAction};
/// Fixed-window, client-side limiter for outbound MCP `tools/call` requests.
///
/// A `max_per_minute` of `0` disables limiting entirely.
#[cfg(feature = "composio")]
struct McpRateLimiter {
    /// Maximum calls admitted per window; `0` means unlimited.
    max_per_minute: u32,
    /// Calls admitted in the current window (never exceeds `max_per_minute`).
    calls: AtomicU32,
    /// Start of the current window, in whole seconds since the Unix epoch.
    window_start: AtomicU64,
}

#[cfg(feature = "composio")]
impl McpRateLimiter {
    /// Length of one rate-limit window, in seconds.
    const WINDOW_SECS: u64 = 60;

    /// Creates a limiter admitting at most `max_per_minute` calls per window.
    ///
    /// `None` (or `Some(0)`) yields an unlimited limiter.
    fn new(max_per_minute: Option<u32>) -> Self {
        Self {
            max_per_minute: max_per_minute.unwrap_or(0),
            calls: AtomicU32::new(0),
            window_start: AtomicU64::new(Self::now_secs()),
        }
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// A clock set before the epoch is treated as `0`.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Records one call attempt and returns `true` if it fits in the current budget.
    ///
    /// The first call after the window has elapsed opens a new window. A wall clock that
    /// stepped backwards also opens a new window. The window reset and the counter are
    /// separate atomics, so counting is approximate for calls racing exactly at a window
    /// boundary. This is a best-effort guard, not a hard quota.
    fn check(&self) -> bool {
        if self.max_per_minute == 0 {
            return true;
        }
        let now = Self::now_secs();
        let window = self.window_start.load(Ordering::SeqCst);
        // `now < window` means the clock went backwards. Checking it first avoids the
        // `u64` underflow (a panic in debug builds) and lets the limiter self-heal.
        let expired = now < window || now - window >= Self::WINDOW_SECS;
        // Only the caller that wins the CAS resets the window. Losers fall through and
        // count against the freshly reset window.
        if expired
            && self
                .window_start
                .compare_exchange(window, now, Ordering::SeqCst, Ordering::SeqCst)
                .is_ok()
        {
            self.calls.store(1, Ordering::SeqCst);
            return true;
        }
        // Increment only while under budget. Rejected calls leave the counter untouched,
        // so it can never grow past `max_per_minute` (or wrap around).
        self.calls
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| {
                (count < self.max_per_minute).then_some(count + 1)
            })
            .is_ok()
    }
}
/// Executes tool calls against a Composio MCP server reached over an [`SseTransport`].
///
/// Tool definitions are fetched once via `tools/list` at construction time. They are
/// then served from memory.
#[cfg(feature = "composio")]
pub struct ComposioToolExecutor {
    /// Shared connection to the MCP server.
    transport: Arc<SseTransport>,
    /// Tool schemas discovered at construction time.
    tool_definitions: Vec<ToolDefinition>,
    /// Client-side cap on outbound `tools/call` requests.
    rate_limiter: McpRateLimiter,
}

#[cfg(feature = "composio")]
impl ComposioToolExecutor {
    /// Like [`Self::discover`], but caps outbound `tools/call` requests at
    /// `max_calls_per_minute`.
    ///
    /// `None` or `Some(0)` means unlimited. The discovery request itself is not counted.
    ///
    /// # Errors
    /// Propagates any error from [`Self::discover`].
    pub async fn discover_with_rate_limit(
        transport: Arc<SseTransport>,
        max_calls_per_minute: Option<u32>,
    ) -> Result<Self, ComposioError> {
        let mut executor = Self::discover(transport).await?;
        executor.rate_limiter = McpRateLimiter::new(max_calls_per_minute);
        Ok(executor)
    }

    /// Sends `tools/list` over `transport` and builds an executor from the result.
    ///
    /// The response may be either `{ "tools": [...] }` or a bare array of tools.
    /// The returned executor is not rate limited.
    ///
    /// # Errors
    /// Returns the transport's error if the request fails. Returns
    /// `ComposioError::TransportError` if the tool list is not a JSON array.
    pub async fn discover(transport: Arc<SseTransport>) -> Result<Self, ComposioError> {
        let result = transport
            .request("tools/list", serde_json::json!({}))
            .await?;
        // Fall back to the whole response (bare-array form) without cloning it.
        let tools_value = match result.get("tools") {
            Some(tools) => tools.clone(),
            None => result,
        };
        let raw_tools: Vec<serde_json::Value> =
            serde_json::from_value(tools_value).map_err(|e| ComposioError::TransportError {
                reason: format!("failed to parse tools/list response: {}", e),
            })?;
        let tool_definitions = raw_tools.iter().map(Self::parse_tool_definition).collect();
        Ok(Self {
            transport,
            tool_definitions,
            rate_limiter: McpRateLimiter::new(None),
        })
    }

    /// Converts one raw `tools/list` entry into a [`ToolDefinition`].
    ///
    /// Fallbacks for missing fields:
    /// - `name` becomes `"unknown"` and `description` becomes `""`.
    /// - The schema is read from `inputSchema` (MCP naming), then from `parameters`.
    /// - If neither is present, an empty object schema is used.
    fn parse_tool_definition(raw: &serde_json::Value) -> ToolDefinition {
        let name = raw["name"].as_str().unwrap_or("unknown").to_string();
        let description = raw["description"].as_str().unwrap_or("").to_string();
        let parameters = raw
            .get("inputSchema")
            .or_else(|| raw.get("parameters"))
            .cloned()
            .unwrap_or_else(|| {
                serde_json::json!({
                    "type": "object",
                    "properties": {},
                    "required": []
                })
            });
        ToolDefinition {
            name,
            description,
            parameters,
        }
    }

    /// Returns the tool schemas discovered at construction time.
    pub fn tool_definitions(&self) -> &[ToolDefinition] {
        &self.tool_definitions
    }

    /// Parses the model-supplied argument string into a JSON value.
    ///
    /// An empty or whitespace-only string means "no arguments" and yields `{}`.
    ///
    /// # Errors
    /// Returns a message describing the parse failure for any other invalid JSON.
    fn parse_arguments(arguments: &str) -> Result<serde_json::Value, String> {
        if arguments.trim().is_empty() {
            return Ok(serde_json::json!({}));
        }
        serde_json::from_str(arguments)
            .map_err(|e| format!("invalid JSON arguments for tool call: {}", e))
    }

    /// Renders a `tools/call` result as text.
    ///
    /// Joins the `text` fields of the `content` array with newlines. If there are no
    /// text parts, it falls back to the pretty-printed raw result.
    fn render_tool_output(result: &serde_json::Value) -> String {
        let texts: Vec<&str> = result
            .get("content")
            .and_then(serde_json::Value::as_array)
            .map(|parts| {
                parts
                    .iter()
                    .filter_map(|part| part.get("text").and_then(serde_json::Value::as_str))
                    .collect()
            })
            .unwrap_or_default();
        if texts.is_empty() {
            serde_json::to_string_pretty(result).unwrap_or_default()
        } else {
            texts.join("\n")
        }
    }

    /// Invokes tool `name` on the MCP server with the JSON-encoded `arguments`.
    ///
    /// # Errors
    /// Returns a human-readable message when any of these happen:
    /// - `arguments` is not valid JSON. This is checked before a rate-limit slot is used.
    /// - The rate limit is exhausted.
    /// - The transport fails.
    /// - The server marks the result with `"isError": true`. The rendered output is the
    ///   error text in this case.
    async fn call_tool(&self, name: &str, arguments: &str) -> Result<String, String> {
        // Reject malformed calls up front. Silently substituting `{}` would run a
        // (possibly side-effecting) tool with no input.
        let args = Self::parse_arguments(arguments)?;
        if !self.rate_limiter.check() {
            tracing::warn!(tool = name, "MCP rate limit exceeded");
            return Err(format!(
                "Rate limit exceeded: max {} calls/min for MCP server",
                self.rate_limiter.max_per_minute
            ));
        }
        let params = serde_json::json!({
            "name": name,
            "arguments": args,
        });
        let result = self
            .transport
            .request("tools/call", params)
            .await
            .map_err(|e| e.to_string())?;
        let output = Self::render_tool_output(&result);
        // MCP reports tool-execution failures in-band rather than as protocol errors.
        if result.get("isError").and_then(serde_json::Value::as_bool) == Some(true) {
            return Err(output);
        }
        Ok(output)
    }
}
#[cfg(feature = "composio")]
#[async_trait]
impl ActionExecutor for ComposioToolExecutor {
    /// Runs each proposed tool call against the MCP server, one after another.
    ///
    /// Every `ProposedAction::ToolCall` produces exactly one observation tagged with its
    /// `call_id`: a tool result on success or a tool error on failure. Other action kinds
    /// are skipped without an observation. `_config` and `_circuit_breakers` are not
    /// consulted.
    async fn execute_actions(
        &self,
        actions: &[ProposedAction],
        _config: &LoopConfig,
        _circuit_breakers: &CircuitBreakerRegistry,
    ) -> Vec<Observation> {
        let mut results = Vec::with_capacity(actions.len());
        for action in actions {
            let ProposedAction::ToolCall {
                call_id,
                name,
                arguments,
            } = action
            else {
                continue;
            };
            let observation = match self.call_tool(name, arguments).await {
                Ok(output) => Observation::tool_result(name.clone(), output),
                Err(err) => Observation::tool_error(name.clone(), err),
            };
            results.push(observation.with_call_id(call_id.clone()));
        }
        results
    }

    /// Returns an owned copy of the tool schemas discovered at construction time.
    fn tool_definitions(&self) -> Vec<ToolDefinition> {
        self.tool_definitions.clone()
    }
}
#[cfg(test)]
#[cfg(feature = "composio")]
mod tests {
    use super::*;

    /// An MCP `tools/list` entry with an `inputSchema` maps onto a `ToolDefinition`.
    #[test]
    fn test_tool_definition_parsing() {
        let raw = serde_json::json!([
            {
                "name": "TWITTER_CREATE_TWEET",
                "description": "Post a tweet to Twitter/X",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "text": { "type": "string", "description": "Tweet text" }
                    },
                    "required": ["text"]
                }
            }
        ]);
        let tools: Vec<serde_json::Value> = serde_json::from_value(raw).unwrap();
        let defs: Vec<ToolDefinition> = tools
            .into_iter()
            .map(|t| {
                let name = t["name"].as_str().unwrap_or("unknown").to_string();
                let description = t["description"].as_str().unwrap_or("").to_string();
                // Same fallback schema as production discovery.
                let parameters = t
                    .get("inputSchema")
                    .or_else(|| t.get("parameters"))
                    .cloned()
                    .unwrap_or_else(|| {
                        serde_json::json!({
                            "type": "object",
                            "properties": {},
                            "required": []
                        })
                    });
                ToolDefinition {
                    name,
                    description,
                    parameters,
                }
            })
            .collect();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "TWITTER_CREATE_TWEET");
        assert_eq!(defs[0].description, "Post a tweet to Twitter/X");
        assert!(defs[0].parameters["properties"]["text"].is_object());
    }

    /// A limiter built with `None` never rejects.
    #[test]
    fn test_rate_limiter_unlimited() {
        let limiter = McpRateLimiter::new(None);
        for _ in 0..1000 {
            assert!(limiter.check());
        }
    }

    /// Exactly `max_per_minute` calls are admitted within one window.
    #[test]
    fn test_rate_limiter_enforced() {
        let limiter = McpRateLimiter::new(Some(5));
        for _ in 0..5 {
            assert!(limiter.check());
        }
        assert!(!limiter.check());
    }

    /// Once the window has elapsed, the budget is restored.
    #[test]
    fn test_rate_limiter_window_reset() {
        let limiter = McpRateLimiter::new(Some(1));
        assert!(limiter.check());
        assert!(!limiter.check());
        // Pretend the current window started long ago so the next call opens a new one.
        limiter.window_start.store(0, Ordering::SeqCst);
        assert!(limiter.check());
        assert!(!limiter.check());
    }

    /// Text parts of an MCP `content` array are extracted; non-text parts are ignored.
    #[test]
    fn test_mcp_content_extraction() {
        let result = serde_json::json!({
            "content": [
                { "type": "text", "text": "Tweet posted successfully" },
                { "type": "image", "data": "aGVsbG8=", "mimeType": "image/png" }
            ]
        });
        // Fail loudly if the fixture shape is wrong instead of passing vacuously.
        let arr = result
            .get("content")
            .and_then(|c| c.as_array())
            .expect("fixture must contain a `content` array");
        let texts: Vec<&str> = arr
            .iter()
            .filter_map(|c| c.get("text").and_then(|t| t.as_str()))
            .collect();
        assert_eq!(texts, vec!["Tweet posted successfully"]);
    }
}