sentinel_proxy/agents/
manager.rs

1//! Agent manager for coordinating external processing agents.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use base64::{engine::general_purpose::STANDARD, Engine as _};
8use sentinel_agent_protocol::{
9    EventType, RequestBodyChunkEvent, RequestHeadersEvent, ResponseHeadersEvent,
10};
11use sentinel_common::{
12    errors::{SentinelError, SentinelResult},
13    types::CircuitBreakerConfig,
14    CircuitBreaker,
15};
16use sentinel_config::{AgentConfig, FailureMode};
17use tokio::sync::{RwLock, Semaphore};
18use tracing::{debug, error, info, warn};
19
20use super::agent::Agent;
21use super::context::AgentCallContext;
22use super::decision::AgentDecision;
23use super::metrics::AgentMetrics;
24use super::pool::AgentConnectionPool;
25
26/// Agent manager handling all external agents.
27pub struct AgentManager {
28    /// Configured agents
29    agents: Arc<RwLock<HashMap<String, Arc<Agent>>>>,
30    /// Connection pools for agents
31    connection_pools: Arc<RwLock<HashMap<String, Arc<AgentConnectionPool>>>>,
32    /// Circuit breakers per agent
33    circuit_breakers: Arc<RwLock<HashMap<String, Arc<CircuitBreaker>>>>,
34    /// Global agent metrics
35    metrics: Arc<AgentMetrics>,
36    /// Maximum concurrent agent calls
37    #[allow(dead_code)]
38    max_concurrent_calls: usize,
39    /// Global semaphore for agent calls
40    call_semaphore: Arc<Semaphore>,
41}
42
43impl AgentManager {
44    /// Create new agent manager.
45    pub async fn new(
46        agents: Vec<AgentConfig>,
47        max_concurrent_calls: usize,
48    ) -> SentinelResult<Self> {
49        let mut agent_map = HashMap::new();
50        let mut pools = HashMap::new();
51        let mut breakers = HashMap::new();
52
53        for config in agents {
54            let pool = Arc::new(AgentConnectionPool::new(
55                10, // max connections
56                2,  // min idle
57                5,  // max idle
58                Duration::from_secs(60),
59            ));
60
61            let circuit_breaker = Arc::new(CircuitBreaker::new(
62                config
63                    .circuit_breaker
64                    .clone()
65                    .unwrap_or_else(CircuitBreakerConfig::default),
66            ));
67
68            let agent = Arc::new(Agent::new(
69                config.clone(),
70                Arc::clone(&pool),
71                Arc::clone(&circuit_breaker),
72            ));
73
74            agent_map.insert(config.id.clone(), agent);
75            pools.insert(config.id.clone(), pool);
76            breakers.insert(config.id.clone(), circuit_breaker);
77        }
78
79        Ok(Self {
80            agents: Arc::new(RwLock::new(agent_map)),
81            connection_pools: Arc::new(RwLock::new(pools)),
82            circuit_breakers: Arc::new(RwLock::new(breakers)),
83            metrics: Arc::new(AgentMetrics::default()),
84            max_concurrent_calls,
85            call_semaphore: Arc::new(Semaphore::new(max_concurrent_calls)),
86        })
87    }
88
89    /// Process request headers through agents.
90    pub async fn process_request_headers(
91        &self,
92        ctx: &AgentCallContext,
93        headers: &HashMap<String, Vec<String>>,
94        route_agents: &[String],
95    ) -> SentinelResult<AgentDecision> {
96        let event = RequestHeadersEvent {
97            metadata: ctx.metadata.clone(),
98            method: headers
99                .get(":method")
100                .and_then(|v| v.first())
101                .unwrap_or(&"GET".to_string())
102                .clone(),
103            uri: headers
104                .get(":path")
105                .and_then(|v| v.first())
106                .unwrap_or(&"/".to_string())
107                .clone(),
108            headers: headers.clone(),
109        };
110
111        self.process_event(EventType::RequestHeaders, &event, route_agents, ctx)
112            .await
113    }
114
115    /// Process request body chunk through agents.
116    pub async fn process_request_body(
117        &self,
118        ctx: &AgentCallContext,
119        data: &[u8],
120        is_last: bool,
121        route_agents: &[String],
122    ) -> SentinelResult<AgentDecision> {
123        // Check body size limits
124        let max_size = 1024 * 1024; // 1MB default
125        if data.len() > max_size {
126            warn!(
127                correlation_id = %ctx.correlation_id,
128                size = data.len(),
129                "Request body exceeds agent inspection limit"
130            );
131            return Ok(AgentDecision::default_allow());
132        }
133
134        let event = RequestBodyChunkEvent {
135            correlation_id: ctx.correlation_id.to_string(),
136            data: STANDARD.encode(data),
137            is_last,
138            total_size: ctx.request_body.as_ref().map(|b| b.len()),
139        };
140
141        self.process_event(EventType::RequestBodyChunk, &event, route_agents, ctx)
142            .await
143    }
144
145    /// Process response headers through agents.
146    pub async fn process_response_headers(
147        &self,
148        ctx: &AgentCallContext,
149        status: u16,
150        headers: &HashMap<String, Vec<String>>,
151        route_agents: &[String],
152    ) -> SentinelResult<AgentDecision> {
153        let event = ResponseHeadersEvent {
154            correlation_id: ctx.correlation_id.to_string(),
155            status,
156            headers: headers.clone(),
157        };
158
159        self.process_event(EventType::ResponseHeaders, &event, route_agents, ctx)
160            .await
161    }
162
163    /// Process an event through relevant agents.
164    async fn process_event<T: serde::Serialize>(
165        &self,
166        event_type: EventType,
167        event: &T,
168        route_agents: &[String],
169        ctx: &AgentCallContext,
170    ) -> SentinelResult<AgentDecision> {
171        // Get relevant agents for this route and event type
172        let agents = self.agents.read().await;
173        let relevant_agents: Vec<_> = route_agents
174            .iter()
175            .filter_map(|id| agents.get(id))
176            .filter(|agent| agent.handles_event(event_type))
177            .collect();
178
179        if relevant_agents.is_empty() {
180            return Ok(AgentDecision::default_allow());
181        }
182
183        debug!(
184            correlation_id = %ctx.correlation_id,
185            event_type = ?event_type,
186            agent_count = relevant_agents.len(),
187            "Processing event through agents"
188        );
189
190        // Process through each agent sequentially
191        let mut combined_decision = AgentDecision::default_allow();
192
193        for agent in relevant_agents {
194            // Acquire semaphore permit
195            let _permit = self.call_semaphore.acquire().await.map_err(|_| {
196                SentinelError::Internal {
197                    message: "Failed to acquire agent call permit".to_string(),
198                    correlation_id: Some(ctx.correlation_id.to_string()),
199                    source: None,
200                }
201            })?;
202
203            // Check circuit breaker
204            if !agent.circuit_breaker().is_closed().await {
205                warn!(
206                    agent_id = %agent.id(),
207                    correlation_id = %ctx.correlation_id,
208                    "Circuit breaker open, skipping agent"
209                );
210
211                // Handle based on failure mode
212                if agent.failure_mode() == FailureMode::Closed {
213                    return Ok(AgentDecision::block(503, "Service unavailable"));
214                }
215                continue;
216            }
217
218            // Call agent with timeout
219            let start = Instant::now();
220            let timeout = Duration::from_millis(agent.timeout_ms());
221
222            match tokio::time::timeout(timeout, agent.call_event(event_type, event)).await {
223                Ok(Ok(response)) => {
224                    let duration = start.elapsed();
225                    agent.record_success(duration).await;
226
227                    // Merge response into combined decision
228                    combined_decision.merge(response.into());
229
230                    // If decision is to block/redirect/challenge, stop processing
231                    if !combined_decision.is_allow() {
232                        break;
233                    }
234                }
235                Ok(Err(e)) => {
236                    agent.record_failure().await;
237                    error!(
238                        agent_id = %agent.id(),
239                        correlation_id = %ctx.correlation_id,
240                        error = %e,
241                        "Agent call failed"
242                    );
243
244                    if agent.failure_mode() == FailureMode::Closed {
245                        return Err(e);
246                    }
247                }
248                Err(_) => {
249                    agent.record_timeout().await;
250                    warn!(
251                        agent_id = %agent.id(),
252                        correlation_id = %ctx.correlation_id,
253                        timeout_ms = agent.timeout_ms(),
254                        "Agent call timed out"
255                    );
256
257                    if agent.failure_mode() == FailureMode::Closed {
258                        return Ok(AgentDecision::block(504, "Gateway timeout"));
259                    }
260                }
261            }
262        }
263
264        Ok(combined_decision)
265    }
266
267    /// Initialize agent connections.
268    pub async fn initialize(&self) -> SentinelResult<()> {
269        let agents = self.agents.read().await;
270
271        for (id, agent) in agents.iter() {
272            info!("Initializing agent: {}", id);
273            if let Err(e) = agent.initialize().await {
274                error!("Failed to initialize agent {}: {}", id, e);
275                // Continue with other agents
276            }
277        }
278
279        Ok(())
280    }
281
282    /// Shutdown all agents.
283    pub async fn shutdown(&self) {
284        info!("Shutting down agent manager");
285
286        let agents = self.agents.read().await;
287        for (id, agent) in agents.iter() {
288            debug!("Shutting down agent: {}", id);
289            agent.shutdown().await;
290        }
291    }
292
293    /// Get agent metrics.
294    pub fn metrics(&self) -> &AgentMetrics {
295        &self.metrics
296    }
297}