use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use crate::matrixrpc::{
ErrorCode, JsonRpcError, JsonRpcId, JsonRpcResponse,
ServiceId, ServiceStatus, RegistryService,
};
use crate::matrixrpc::transport::{StdioTransport, TransportConfig as TransportSettings};
use crate::matrixrpc::router::{ToolRouter, ToolRouteResult, ToolRouterError};
#[derive(Debug, thiserror::Error)]
pub enum ToolExecutorError {
#[error("Transport error: {0}")]
TransportError(String),
#[error("Tool '{tool}' execution timed out after {timeout_ms}ms")]
Timeout { tool: String, timeout_ms: u64 },
#[error("Tool '{tool}' execution failed after {attempts} attempts")]
RetryExhausted { tool: String, attempts: u32, last_error: String },
#[error("Service '{0}' is not connected")]
ServiceNotConnected(ServiceId),
#[error("Invalid response from service: {0}")]
InvalidResponse(String),
#[error("Tool '{tool}' execution failed: {message}")]
ExecutionFailed { tool: String, message: String, data: Option<serde_json::Value> },
#[error("Routing error: {0}")]
RoutingError(#[from] ToolRouterError),
#[error("Internal error: {0}")]
Internal(String),
}
/// Retry policy for tool execution: how many attempts to make and how long to
/// wait between them.
#[derive(Debug, Clone)]
pub struct RetryStrategy {
    /// Total number of attempts, counting the first one (`1` means no retries).
    pub max_attempts: u32,
    /// Base delay in milliseconds, scaled according to `backoff`.
    pub base_interval_ms: u64,
    /// How the delay grows from one retry to the next.
    pub backoff: BackoffStrategy,
}

impl Default for RetryStrategy {
    /// Three attempts, 1000 ms base interval, exponential backoff.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_interval_ms: 1000,
            backoff: BackoffStrategy::Exponential,
        }
    }
}

impl RetryStrategy {
    /// `max_attempts` attempts with exponential backoff starting at `base_interval_ms`.
    pub fn new(max_attempts: u32, base_interval_ms: u64) -> Self {
        Self {
            max_attempts,
            base_interval_ms,
            backoff: BackoffStrategy::Exponential,
        }
    }

    /// A single attempt with no delay (retries disabled).
    pub fn none() -> Self {
        Self {
            max_attempts: 1,
            base_interval_ms: 0,
            backoff: BackoffStrategy::Fixed,
        }
    }

    /// `max_attempts` attempts separated by a constant `interval_ms` delay.
    pub fn fixed(max_attempts: u32, interval_ms: u64) -> Self {
        Self {
            max_attempts,
            base_interval_ms: interval_ms,
            backoff: BackoffStrategy::Fixed,
        }
    }

    /// `max_attempts` attempts whose delay grows linearly: `base`, `2*base`, `3*base`, ...
    pub fn linear(max_attempts: u32, base_interval_ms: u64) -> Self {
        Self {
            max_attempts,
            base_interval_ms,
            backoff: BackoffStrategy::Linear,
        }
    }

    /// Delay in milliseconds for the 0-based index `attempt`. Index 0 yields
    /// `base_interval_ms` for every backoff kind.
    ///
    /// - `Fixed`:       `base`
    /// - `Linear`:      `base * (attempt + 1)`
    /// - `Exponential`: `base * 2^attempt`
    ///
    /// Returns `0` when `attempt >= max_attempts`. Products that do not fit in a
    /// `u64` saturate at `u64::MAX`. Unchecked arithmetic would panic in debug
    /// builds and wrap in release (e.g. `2^64` wraps to 0, i.e. no delay at all).
    pub fn get_delay_ms(&self, attempt: u32) -> u64 {
        if attempt >= self.max_attempts {
            return 0;
        }
        let factor = match self.backoff {
            BackoffStrategy::Fixed => 1,
            BackoffStrategy::Linear => u64::from(attempt) + 1,
            BackoffStrategy::Exponential => 2u64.checked_pow(attempt).unwrap_or(u64::MAX),
        };
        self.base_interval_ms.saturating_mul(factor)
    }
}

/// Growth pattern of the delay between retry attempts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackoffStrategy {
    /// The same delay before every retry.
    Fixed,
    /// The delay grows by `base_interval_ms` with each retry.
    Linear,
    /// The delay doubles with each retry.
    Exponential,
}
/// Execution settings for a [`ToolExecutor`]: a timeout, a retry policy and
/// transport settings.
#[derive(Debug, Clone)]
pub struct ExecutionConfig {
    /// Timeout in milliseconds (30 000 by default).
    pub timeout_ms: u64,
    /// Retry policy (`RetryStrategy::default()` unless overridden).
    pub retry: RetryStrategy,
    /// Transport settings (`TransportConfig`, imported under the alias `TransportSettings`).
    pub transport: TransportSettings,
}

impl Default for ExecutionConfig {
    /// 30 000 ms timeout with default retry and transport settings.
    fn default() -> Self {
        Self::new(30_000)
    }
}

impl ExecutionConfig {
    /// Configuration with the given `timeout_ms` and default retry/transport settings.
    pub fn new(timeout_ms: u64) -> Self {
        Self {
            timeout_ms,
            retry: RetryStrategy::default(),
            transport: TransportSettings::default(),
        }
    }

    /// Builder-style setter that replaces the retry policy.
    pub fn retry(self, retry: RetryStrategy) -> Self {
        Self { retry, ..self }
    }

    /// Builder-style setter that replaces the transport settings.
    pub fn transport(self, transport: TransportSettings) -> Self {
        Self { transport, ..self }
    }

    /// Configuration with the given `timeout_ms` and a single attempt (no retries).
    pub fn no_retry(timeout_ms: u64) -> Self {
        Self::new(timeout_ms).retry(RetryStrategy::none())
    }
}
/// Bookkeeping entry for one service's transport connection: the owning
/// service id plus an async-lockable "connected" flag.
#[derive(Debug)]
// `service_id` is stored but not read anywhere in this file yet.
#[allow(dead_code)]
struct TransportConnection {
    service_id: ServiceId,
    connected: RwLock<bool>,
}

impl TransportConnection {
    /// Creates an entry for `service_id` whose flag starts out `false` (disconnected).
    fn new(service_id: ServiceId) -> Self {
        let connected = RwLock::new(false);
        Self { service_id, connected }
    }

    /// Reads the current value of the connected flag.
    async fn is_connected(&self) -> bool {
        let flag = self.connected.read().await;
        *flag
    }
}
/// Executes tool calls: resolves them through a [`ToolRouter`], checks service
/// state in the [`RegistryService`], and retries transient failures per
/// [`ExecutionConfig`].
#[derive(Debug)]
pub struct ToolExecutor {
// Resolves tool names to services and builds tool requests.
router: Arc<ToolRouter>,
// Consulted before each attempt to check that the target service is `Running`.
registry: Arc<RegistryService>,
// Timeout / retry / transport settings for this executor.
config: ExecutionConfig,
// Per-service connection entries, added by `register_connection` and dropped by `remove_connection`.
connections: RwLock<HashMap<ServiceId, Arc<TransportConnection>>>,
}
use std::collections::HashMap;
impl ToolExecutor {
pub fn new(router: Arc<ToolRouter>, registry: Arc<RegistryService>) -> Self {
Self {
router,
registry,
config: ExecutionConfig::default(),
connections: RwLock::new(HashMap::new()),
}
}
pub fn with_config(
router: Arc<ToolRouter>,
registry: Arc<RegistryService>,
config: ExecutionConfig,
) -> Self {
Self {
router,
registry,
config,
connections: RwLock::new(HashMap::new()),
}
}
pub async fn execute(
&self,
tool_name: &str,
params: serde_json::Value,
) -> Result<serde_json::Value, ToolExecutorError> {
let request_id = JsonRpcId::generate();
let route_result = self.router
.route(tool_name, params.clone(), request_id.clone())
.await?;
let tool_timeout = self.router.get_timeout(tool_name).await;
self.execute_with_retry(route_result, tool_timeout).await
}
pub async fn execute_with_id(
&self,
tool_name: &str,
params: serde_json::Value,
request_id: JsonRpcId,
) -> Result<serde_json::Value, ToolExecutorError> {
let route_result = self.router
.route(tool_name, params.clone(), request_id.clone())
.await?;
let tool_timeout = self.router.get_timeout(tool_name).await;
self.execute_with_retry(route_result, tool_timeout).await
}
async fn execute_with_retry(
&self,
route_result: ToolRouteResult,
timeout_ms: u64,
) -> Result<serde_json::Value, ToolExecutorError> {
let mut attempts = 0;
let mut last_error: Option<String> = None;
while attempts < self.config.retry.max_attempts {
attempts += 1;
if attempts > 1 {
let delay_ms = self.config.retry.get_delay_ms(attempts - 1);
if delay_ms > 0 {
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
}
let service = self.registry.get(&route_result.service_id).await;
match service {
Some(s) if s.status == ServiceStatus::Running => {
}
Some(s) => {
last_error = Some(format!("Service status: {:?}", s.status));
continue; }
None => {
return Err(ToolExecutorError::ServiceNotConnected(route_result.service_id.clone()));
}
}
let result = self.execute_single(&route_result, timeout_ms).await;
match result {
Ok(value) => return Ok(value),
Err(ToolExecutorError::Timeout { .. }) => {
last_error = Some("Timeout".to_string());
}
Err(ToolExecutorError::TransportError(_)) => {
last_error = Some(result.unwrap_err().to_string());
}
Err(e) => {
return Err(e);
}
}
}
Err(ToolExecutorError::RetryExhausted {
tool: route_result.tool_name.clone(),
attempts,
last_error: last_error.unwrap_or_else(|| "Unknown error".to_string()),
})
}
async fn execute_single(
&self,
route_result: &ToolRouteResult,
__timeout_ms: u64,
) -> Result<serde_json::Value, ToolExecutorError> {
let __request = self.router.create_tool_request(route_result.clone());
let _connection = self.get_connection(&route_result.service_id).await?;
Err(ToolExecutorError::ServiceNotConnected(route_result.service_id.clone()))
}
async fn get_connection(
&self,
service_id: &ServiceId,
) -> Result<Arc<TransportConnection>, ToolExecutorError> {
let connections = self.connections.read().await;
if let Some(conn) = connections.get(service_id) {
if conn.is_connected().await {
return Ok(conn.clone());
}
}
drop(connections);
Err(ToolExecutorError::ServiceNotConnected(service_id.clone()))
}
#[allow(dead_code)]
fn process_response(
&self,
response: JsonRpcResponse,
tool_name: &str,
) -> Result<serde_json::Value, ToolExecutorError> {
if response.is_success() {
response.result.clone().ok_or_else(|| {
ToolExecutorError::InvalidResponse("No result in success response".to_string())
})
} else if response.is_error() {
let error = response.error.clone().unwrap_or_else(|| {
JsonRpcError::internal_error("Unknown error")
});
if error.code == ErrorCode::TIMEOUT_ERROR || error.code == ErrorCode::TRANSPORT_ERROR {
Err(ToolExecutorError::Timeout {
tool: tool_name.to_string(),
timeout_ms: self.config.timeout_ms,
})
} else {
Err(ToolExecutorError::ExecutionFailed {
tool: tool_name.to_string(),
message: error.message,
data: error.data,
})
}
} else {
Err(ToolExecutorError::InvalidResponse(
"Response has neither result nor error".to_string()
))
}
}
pub async fn register_connection(
&self,
service_id: ServiceId,
_transport: StdioTransport,
) {
let connection = Arc::new(TransportConnection::new(service_id.clone()));
*connection.connected.write().await = true;
let mut connections = self.connections.write().await;
connections.insert(service_id, connection);
}
pub async fn remove_connection(&self, service_id: &ServiceId) {
let mut connections = self.connections.write().await;
if let Some(conn) = connections.get(service_id) {
*conn.connected.write().await = false;
}
connections.remove(service_id);
}
pub async fn is_connected(&self, service_id: &ServiceId) -> bool {
let connections = self.connections.read().await;
match connections.get(service_id) {
Some(c) => c.is_connected().await,
None => false,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses a 30 s timeout and three attempts.
    #[test]
    fn test_execution_config_defaults() {
        let ExecutionConfig { timeout_ms, retry, .. } = ExecutionConfig::default();
        assert_eq!(timeout_ms, 30_000);
        assert_eq!(retry.max_attempts, 3);
    }

    /// Table of (strategy, attempt index, expected delay in ms).
    #[test]
    fn test_retry_strategy_delays() {
        let cases = [
            (RetryStrategy::fixed(3, 1000), 0, 1000),
            (RetryStrategy::fixed(3, 1000), 1, 1000),
            (RetryStrategy::linear(3, 1000), 0, 1000),
            (RetryStrategy::linear(3, 1000), 1, 2000),
            (RetryStrategy::new(3, 1000), 0, 1000),
            (RetryStrategy::new(3, 1000), 1, 2000),
            (RetryStrategy::new(3, 1000), 2, 4000),
        ];
        for (strategy, attempt, expected) in &cases {
            assert_eq!(
                strategy.get_delay_ms(*attempt),
                *expected,
                "{:?} at attempt {}",
                strategy,
                attempt
            );
        }
    }

    /// `none()` makes exactly one attempt and never waits.
    #[test]
    fn test_retry_strategy_none() {
        let strategy = RetryStrategy::none();
        assert_eq!((strategy.max_attempts, strategy.get_delay_ms(0)), (1, 0));
    }

    /// `no_retry` keeps the requested timeout and allows a single attempt.
    #[test]
    fn test_execution_config_no_retry() {
        let config = ExecutionConfig::no_retry(5000);
        assert_eq!((config.timeout_ms, config.retry.max_attempts), (5000, 1));
    }
}