Skip to main content

apr_cli/federation/
gateway.rs

1//! Federation Gateway - Main entry point for distributed inference
2//!
3//! The gateway orchestrates the full inference lifecycle:
4//! routing, execution, retries, and response handling.
5
6use super::catalog::ModelCatalog;
7use super::health::{CircuitBreaker, HealthChecker};
8use super::routing::Router;
9use super::traits::*;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
14// ============================================================================
15// Gateway Configuration
16// ============================================================================
17
/// Tunable settings for the federation gateway.
#[derive(Clone, Debug)]
pub struct GatewayConfig {
    /// Upper bound on retries for one request.
    pub max_retries: u32,
    /// Time budget for a single inference call.
    pub inference_timeout: Duration,
    /// Whether request tracing is switched on.
    pub enable_tracing: bool,
}
28
29impl Default for GatewayConfig {
30    fn default() -> Self {
31        Self {
32            max_retries: 3,
33            inference_timeout: Duration::from_secs(30),
34            enable_tracing: true,
35        }
36    }
37}
38
39// ============================================================================
40// Gateway Statistics
41// ============================================================================
42
/// Request counters for the gateway, kept in atomics so they can be
/// updated through a shared reference.
struct StatsTracker {
    /// Requests seen so far.
    total_requests: AtomicU64,
    /// Requests that finished successfully.
    successful_requests: AtomicU64,
    /// Requests that ended in failure.
    failed_requests: AtomicU64,
    /// Sum of the token counts reported by successful requests.
    total_tokens: AtomicU64,
    /// Sum of successful-request latencies, in milliseconds.
    total_latency_ms: AtomicU64,
    /// Streams currently counted as open.
    active_streams: AtomicU64,
}
52
53impl StatsTracker {
54    fn new() -> Self {
55        Self {
56            total_requests: AtomicU64::new(0),
57            successful_requests: AtomicU64::new(0),
58            failed_requests: AtomicU64::new(0),
59            total_tokens: AtomicU64::new(0),
60            total_latency_ms: AtomicU64::new(0),
61            active_streams: AtomicU64::new(0),
62        }
63    }
64
65    fn record_request(&self) {
66        self.total_requests.fetch_add(1, Ordering::SeqCst);
67    }
68
69    fn record_success(&self, latency: Duration, tokens: Option<u32>) {
70        self.successful_requests.fetch_add(1, Ordering::SeqCst);
71        self.total_latency_ms
72            .fetch_add(latency.as_millis() as u64, Ordering::SeqCst);
73        if let Some(t) = tokens {
74            self.total_tokens.fetch_add(t as u64, Ordering::SeqCst);
75        }
76    }
77
78    fn record_failure(&self) {
79        self.failed_requests.fetch_add(1, Ordering::SeqCst);
80    }
81
82    #[allow(dead_code)]
83    fn increment_streams(&self) {
84        self.active_streams.fetch_add(1, Ordering::SeqCst);
85    }
86
87    #[allow(dead_code)]
88    fn decrement_streams(&self) {
89        self.active_streams.fetch_sub(1, Ordering::SeqCst);
90    }
91
92    fn snapshot(&self) -> GatewayStats {
93        let total = self.total_requests.load(Ordering::SeqCst);
94        let successful = self.successful_requests.load(Ordering::SeqCst);
95        let total_latency = self.total_latency_ms.load(Ordering::SeqCst);
96
97        let avg_latency = if successful > 0 {
98            Duration::from_millis(total_latency / successful)
99        } else {
100            Duration::ZERO
101        };
102
103        GatewayStats {
104            total_requests: total,
105            successful_requests: successful,
106            failed_requests: self.failed_requests.load(Ordering::SeqCst),
107            total_tokens: self.total_tokens.load(Ordering::SeqCst),
108            avg_latency,
109            active_streams: self.active_streams.load(Ordering::SeqCst) as u32,
110        }
111    }
112}
113
114impl Default for StatsTracker {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120// ============================================================================
121// Federation Gateway
122// ============================================================================
123
124/// The main federation gateway
125pub struct FederationGateway {
126    config: GatewayConfig,
127    router: Arc<Router>,
128    health: Arc<HealthChecker>,
129    circuit_breaker: Arc<CircuitBreaker>,
130    middlewares: Vec<Box<dyn GatewayMiddleware>>,
131    stats: StatsTracker,
132}
133
134impl FederationGateway {
135    pub fn new(
136        config: GatewayConfig,
137        router: Arc<Router>,
138        health: Arc<HealthChecker>,
139        circuit_breaker: Arc<CircuitBreaker>,
140    ) -> Self {
141        Self {
142            config,
143            router,
144            health,
145            circuit_breaker,
146            middlewares: Vec::new(),
147            stats: StatsTracker::new(),
148        }
149    }
150
151    /// Add middleware to the gateway
152    #[must_use]
153    pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
154        self.middlewares.push(Box::new(middleware));
155        self
156    }
157
158    /// Execute inference with retries
159    async fn execute_with_retries(
160        &self,
161        mut request: InferenceRequest,
162    ) -> FederationResult<InferenceResponse> {
163        // Apply before_route middlewares
164        for middleware in &self.middlewares {
165            middleware.before_route(&mut request)?;
166        }
167
168        let mut last_error = None;
169        let mut tried_nodes = Vec::new();
170
171        for attempt in 0..=self.config.max_retries {
172            // Route request (excluding already-tried nodes)
173            // In production, we'd modify the request to exclude tried_nodes
174            // For now, use the original request
175            let target = match self.router.route(&request).await {
176                Ok(t) => t,
177                Err(e) => {
178                    last_error = Some(e);
179                    continue;
180                }
181            };
182
183            // Check circuit breaker
184            if self.circuit_breaker.is_open(&target.node_id) {
185                last_error = Some(FederationError::CircuitOpen(target.node_id.clone()));
186                tried_nodes.push(target.node_id);
187                continue;
188            }
189
190            // Execute inference
191            let start = Instant::now();
192            match self.execute_on_node(&target, &request).await {
193                Ok(mut response) => {
194                    let latency = start.elapsed();
195
196                    // Record success
197                    self.health.report_success(&target.node_id, latency);
198                    self.circuit_breaker.record_success(&target.node_id);
199                    self.stats.record_success(latency, response.tokens);
200
201                    // Apply after_infer middlewares
202                    for middleware in &self.middlewares {
203                        middleware.after_infer(&request, &mut response)?;
204                    }
205
206                    return Ok(response);
207                }
208                Err(e) => {
209                    // Record failure
210                    self.health.report_failure(&target.node_id);
211                    self.circuit_breaker.record_failure(&target.node_id);
212
213                    // Notify middlewares
214                    for middleware in &self.middlewares {
215                        middleware.on_error(&request, &e);
216                    }
217
218                    last_error = Some(e);
219                    tried_nodes.push(target.node_id);
220
221                    if attempt < self.config.max_retries {
222                        // Brief backoff before retry
223                        tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
224                    }
225                }
226            }
227        }
228
229        self.stats.record_failure();
230        Err(last_error
231            .unwrap_or_else(|| FederationError::Internal("All retries exhausted".to_string())))
232    }
233
234    /// Execute inference on a specific node
235    #[allow(clippy::unused_async)] // Will be async when HTTP calls implemented
236    async fn execute_on_node(
237        &self,
238        target: &RouteTarget,
239        _request: &InferenceRequest,
240    ) -> FederationResult<InferenceResponse> {
241        // In production, this would make an HTTP/gRPC call to the target node
242        // For now, we simulate the response
243
244        if target.endpoint.is_empty() {
245            // Simulated response for testing
246            Ok(InferenceResponse {
247                output: b"simulated output".to_vec(),
248                served_by: target.node_id.clone(),
249                latency: Duration::from_millis(50),
250                tokens: Some(10),
251            })
252        } else {
253            // Would make actual HTTP call here
254            // For now, return simulated response
255            Ok(InferenceResponse {
256                output: b"simulated output".to_vec(),
257                served_by: target.node_id.clone(),
258                latency: Duration::from_millis(50),
259                tokens: Some(10),
260            })
261        }
262    }
263}
264
265impl GatewayTrait for FederationGateway {
266    fn infer(
267        &self,
268        request: InferenceRequest,
269    ) -> BoxFuture<'_, FederationResult<InferenceResponse>> {
270        Box::pin(async move {
271            self.stats.record_request();
272            self.execute_with_retries(request).await
273        })
274    }
275
276    fn infer_stream(
277        &self,
278        request: InferenceRequest,
279    ) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>> {
280        Box::pin(async move {
281            self.stats.record_request();
282            self.stats.increment_streams();
283
284            // Route request
285            let target = self.router.route(&request).await?;
286
287            // Create streaming connection
288            let stream = FederationTokenStream::new(
289                target,
290                request,
291                Arc::clone(&self.health),
292                Arc::clone(&self.circuit_breaker),
293            );
294
295            let stream: Box<dyn TokenStream> = Box::new(stream);
296            Ok(stream)
297        })
298    }
299
300    fn stats(&self) -> GatewayStats {
301        self.stats.snapshot()
302    }
303}
304
305// ============================================================================
306// Token Stream Implementation
307// ============================================================================
308
309/// Streaming token response
310struct FederationTokenStream {
311    target: RouteTarget,
312    _request: InferenceRequest,
313    health: Arc<HealthChecker>,
314    circuit_breaker: Arc<CircuitBreaker>,
315    tokens_generated: u32,
316    finished: bool,
317}
318
319impl FederationTokenStream {
320    fn new(
321        target: RouteTarget,
322        request: InferenceRequest,
323        health: Arc<HealthChecker>,
324        circuit_breaker: Arc<CircuitBreaker>,
325    ) -> Self {
326        Self {
327            target,
328            _request: request,
329            health,
330            circuit_breaker,
331            tokens_generated: 0,
332            finished: false,
333        }
334    }
335}
336
337impl TokenStream for FederationTokenStream {
338    fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>> {
339        Box::pin(async move {
340            if self.finished {
341                return None;
342            }
343
344            // Simulate token generation (in production, would read from connection)
345            self.tokens_generated += 1;
346
347            if self.tokens_generated > 10 {
348                self.finished = true;
349                self.health
350                    .report_success(&self.target.node_id, Duration::from_millis(50));
351                self.circuit_breaker.record_success(&self.target.node_id);
352                return None;
353            }
354
355            Some(Ok(format!("token_{}", self.tokens_generated).into_bytes()))
356        })
357    }
358
359    fn cancel(&mut self) -> BoxFuture<'_, ()> {
360        Box::pin(async move {
361            self.finished = true;
362        })
363    }
364}
365
366// ============================================================================
367// Gateway Builder
368// ============================================================================
369
370/// Builder for creating federation gateways
371pub struct GatewayBuilder {
372    config: GatewayConfig,
373    catalog: Option<Arc<ModelCatalog>>,
374    health: Option<Arc<HealthChecker>>,
375    circuit_breaker: Option<Arc<CircuitBreaker>>,
376    router: Option<Arc<Router>>,
377    middlewares: Vec<Box<dyn GatewayMiddleware>>,
378}
379
380include!("middleware.rs");
381include!("gateway_03.rs");