use anyhow::Result;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Dispatches [`AgentRequest`]s to registered [`Handler`]s, matching on [`Protocol`].
pub struct AgentRouter {
    /// Registered handlers, kept in registration order behind an async read/write lock.
    handlers: Arc<RwLock<Vec<Box<dyn Handler>>>>,
    /// Circuit breaker owned by the router.
    // NOTE(review): no visible code reads this field yet; the allow is scoped to this
    // field only so the dead_code lint stays active for the rest of the struct.
    #[allow(dead_code)]
    circuit_breaker: CircuitBreaker,
}
/// A protocol-specific request processor that an [`AgentRouter`] can dispatch to.
///
/// Implementors must be `Send + Sync` because the router stores them as shared
/// trait objects behind an async lock.
#[async_trait::async_trait]
pub trait Handler: Send + Sync {
    /// The protocol this handler serves; the router compares it with
    /// `AgentRequest::protocol` to select a handler.
    fn protocol(&self) -> Protocol;

    /// Processes `request` and produces a response.
    ///
    /// # Errors
    ///
    /// Implementation-defined; the router returns the error to its caller unchanged.
    async fn handle(&self, request: &AgentRequest) -> Result<AgentResponse>;
}
/// A request to be dispatched by an [`AgentRouter`].
///
/// Derives `PartialEq` (all field types support it) so callers and tests can compare
/// requests directly instead of field by field.
#[derive(Debug, Clone, PartialEq)]
pub struct AgentRequest {
    /// Caller-chosen request identifier.
    pub id: String,
    /// Protocol the router uses to pick a [`Handler`].
    pub protocol: Protocol,
    /// JSON payload; the router passes the whole request to the selected handler.
    pub payload: serde_json::Value,
}
/// The value a [`Handler`] returns for a processed request.
///
/// Derives `PartialEq` (all field types support it) so responses can be compared
/// directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct AgentResponse {
    /// Identifier set by the handler.
    // NOTE(review): presumably meant to echo `AgentRequest::id`; the router does not
    // enforce this (the test handler always returns "test").
    pub request_id: String,
    /// Success flag as reported by the handler.
    pub success: bool,
    /// Handler-defined JSON result.
    pub result: serde_json::Value,
}
/// Pairs a request with the handler slot chosen by `AgentRouter::balance_load`.
#[derive(Debug)]
#[derive(Clone)]
pub struct RouteDecision {
    /// The request being assigned.
    pub request: AgentRequest,
    /// Slot index assigned round-robin by position in the input batch.
    // NOTE(review): this index is not checked against the registered handlers, so it
    // may not refer to a handler that serves `request.protocol`.
    pub handler_index: usize,
}
/// Circuit-breaker configuration plus its shared state.
///
/// NOTE(review): no visible code changes `state` after construction or consults the
/// threshold/timeout; the breaking logic presumably still has to be written.
#[derive(Debug)]
#[derive(Clone)]
pub struct CircuitBreaker {
    /// Failure threshold (5 when built by `CircuitBreaker::new`).
    pub failure_threshold: u32,
    /// Reset timeout (60 when built by `CircuitBreaker::new`).
    // NOTE(review): the unit is not stated anywhere; presumably seconds — confirm.
    pub reset_timeout: u64,
    /// Current breaker state; clones share it because the `Arc` is cloned, not the state.
    // Never read yet, so the dead_code allow is scoped to just this field.
    #[allow(dead_code)]
    state: Arc<RwLock<CircuitState>>,
}
/// State of a `CircuitBreaker`; `CircuitBreaker::new` starts in `Closed`.
///
/// NOTE(review): the variant names suggest the conventional closed/open/half-open
/// breaker states, but no transitions between them exist in the visible code.
#[derive(Clone, Debug)]
enum CircuitState {
    /// Initial state assigned by `CircuitBreaker::new`.
    Closed,
    /// Not constructed anywhere yet.
    // Allowed until code that switches into this state exists.
    #[allow(dead_code)]
    Open,
    /// Not constructed anywhere yet.
    // Allowed until code that switches into this state exists.
    #[allow(dead_code)]
    HalfOpen,
}
impl Default for AgentRouter {
fn default() -> Self {
Self::new()
}
}
impl AgentRouter {
    /// Number of slots `balance_load` spreads requests over. It is a fixed value that is
    /// independent of how many handlers are registered; it is kept for backward compatibility.
    const DEFAULT_BALANCE_SLOTS: usize = 3;

    /// Creates a router with no registered handlers and a default [`CircuitBreaker`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            handlers: Arc::new(RwLock::new(Vec::new())),
            circuit_breaker: CircuitBreaker::new(),
        }
    }

    /// Dispatches `request` to the first registered handler whose
    /// [`Handler::protocol`] equals `request.protocol`.
    ///
    /// Handlers are tried in registration order. The circuit breaker is not consulted.
    ///
    /// # Errors
    ///
    /// Returns an error if no registered handler serves `request.protocol`; otherwise
    /// returns whatever the selected handler returns.
    pub async fn route(&self, request: AgentRequest) -> Result<AgentResponse> {
        let handlers = self.handlers.read().await;
        let handler = handlers
            .iter()
            .find(|candidate| candidate.protocol() == request.protocol)
            .ok_or_else(|| anyhow::anyhow!("No handler for protocol: {:?}", request.protocol))?;
        // NOTE(review): the read guard stays alive until `handle` completes, so a
        // concurrent `register_handler` waits for every in-flight request. tokio's RwLock
        // is write-preferring, so routes issued after that writer queues also wait.
        // Storing `Arc<dyn Handler>` and dropping the guard before awaiting would avoid this.
        handler.handle(&request).await
    }

    /// Appends `handler` to the registry. It is tried after all earlier registrations,
    /// so a second handler for an already-served protocol is never selected by `route`.
    pub async fn register_handler(&self, handler: Box<dyn Handler>) {
        self.handlers.write().await.push(handler);
    }

    /// Assigns the `i`-th request to slot `i % 3`, preserving input order.
    ///
    /// Neither the registered handlers nor the request protocols are consulted; use
    /// [`AgentRouter::balance_load_across`] to choose a different slot count.
    #[must_use]
    pub fn balance_load(&self, requests: Vec<AgentRequest>) -> Vec<RouteDecision> {
        let slots = std::num::NonZeroUsize::new(Self::DEFAULT_BALANCE_SLOTS)
            .expect("DEFAULT_BALANCE_SLOTS is a non-zero constant");
        self.balance_load_across(requests, slots)
    }

    /// Assigns the `i`-th request to slot `i % slots`, preserving input order.
    ///
    /// `slots` is `NonZeroUsize`, so the modulo can never divide by zero. Returns an
    /// empty vector for empty input.
    #[must_use]
    pub fn balance_load_across(
        &self,
        requests: Vec<AgentRequest>,
        slots: std::num::NonZeroUsize,
    ) -> Vec<RouteDecision> {
        let slots = slots.get();
        requests
            .into_iter()
            .enumerate()
            .map(|(i, request)| RouteDecision {
                request,
                handler_index: i % slots,
            })
            .collect()
    }
}
impl Default for CircuitBreaker {
fn default() -> Self {
Self::new()
}
}
impl CircuitBreaker {
#[must_use]
pub fn new() -> Self {
Self {
failure_threshold: 5,
reset_timeout: 60,
state: Arc::new(RwLock::new(CircuitState::Closed)),
}
}
}
/// Protocols a [`Handler`] can serve; the router's routing key.
///
/// Derives `Hash` in addition to `Eq` so protocols can key a `HashMap`/`HashSet`
/// (for example, a per-protocol handler index).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// Agents.md-style protocol (presumably AGENTS.md-based — confirm).
    AgentsMd,
    /// MCP (presumably Model Context Protocol — confirm).
    Mcp,
    /// HTTP.
    Http,
    /// WebSocket.
    WebSocket,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler stub that serves `Protocol::AgentsMd` and always reports success.
    struct TestHandler;

    #[async_trait::async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: &AgentRequest) -> Result<AgentResponse> {
            Ok(AgentResponse {
                request_id: "test".to_string(),
                success: true,
                result: serde_json::json!({}),
            })
        }

        fn protocol(&self) -> Protocol {
            Protocol::AgentsMd
        }
    }

    /// Builds a request with the given id and protocol and an empty JSON object payload.
    fn make_request(id: &str, protocol: Protocol) -> AgentRequest {
        AgentRequest {
            id: id.to_string(),
            protocol,
            payload: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn test_router_creation() {
        let router = AgentRouter::new();
        let handlers = router.handlers.read().await;
        assert_eq!(handlers.len(), 0);
    }

    #[tokio::test]
    async fn test_handler_registration() {
        let router = AgentRouter::new();
        router.register_handler(Box::new(TestHandler)).await;
        let handlers = router.handlers.read().await;
        assert_eq!(handlers.len(), 1);
    }

    #[tokio::test]
    async fn test_request_routing() {
        let router = AgentRouter::new();
        router.register_handler(Box::new(TestHandler)).await;
        let response = router
            .route(make_request("test", Protocol::AgentsMd))
            .await
            .unwrap();
        assert!(response.success);
    }

    // Covers the "no handler" error branch of `route`, which was previously untested.
    #[tokio::test]
    async fn test_routing_without_matching_handler_fails() {
        let router = AgentRouter::new();
        router.register_handler(Box::new(TestHandler)).await;
        let err = router
            .route(make_request("test", Protocol::Mcp))
            .await
            .unwrap_err();
        assert_eq!(err.to_string(), "No handler for protocol: Mcp");
    }

    #[test]
    fn test_load_balancing() {
        let router = AgentRouter::new();
        let requests = vec![
            make_request("1", Protocol::AgentsMd),
            make_request("2", Protocol::Mcp),
        ];
        let decisions = router.balance_load(requests);
        assert_eq!(decisions.len(), 2);
        assert_eq!(decisions[0].handler_index, 0);
        assert_eq!(decisions[1].handler_index, 1);
    }

    // Pins the wrap-around of the fixed 3-slot round-robin: the fourth request reuses slot 0.
    #[test]
    fn test_load_balancing_wraps_around() {
        let router = AgentRouter::new();
        let requests: Vec<AgentRequest> = (0..4)
            .map(|i| make_request(&i.to_string(), Protocol::Http))
            .collect();
        let slots: Vec<usize> = router
            .balance_load(requests)
            .iter()
            .map(|decision| decision.handler_index)
            .collect();
        assert_eq!(slots, vec![0, 1, 2, 0]);
    }

    #[test]
    fn test_load_balancing_empty_input() {
        let router = AgentRouter::new();
        assert!(router.balance_load(Vec::new()).is_empty());
    }
}