1use super::catalog::ModelCatalog;
7use super::health::{CircuitBreaker, HealthChecker};
8use super::routing::Router;
9use super::traits::*;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
/// Tunables that control how the federation gateway handles requests.
#[derive(Clone, Debug)]
pub struct GatewayConfig {
    /// Retries allowed after the first attempt; the retry loop runs `0..=max_retries`,
    /// so a request gets at most `max_retries + 1` attempts.
    pub max_retries: u32,
    /// Intended time budget for a single inference attempt on a node.
    pub inference_timeout: Duration,
    /// Whether tracing is enabled.
    /// NOTE(review): not read by any code visible in this file — confirm its consumer.
    pub enable_tracing: bool,
}
28
29impl Default for GatewayConfig {
30 fn default() -> Self {
31 Self {
32 max_retries: 3,
33 inference_timeout: Duration::from_secs(30),
34 enable_tracing: true,
35 }
36 }
37}
38
/// Gateway-wide counters held in `AtomicU64`s so they can be updated through `&self`.
struct StatsTracker {
    /// Requests received (unary and streaming).
    total_requests: AtomicU64,
    /// Requests that completed successfully.
    successful_requests: AtomicU64,
    /// Requests that ended in an error.
    failed_requests: AtomicU64,
    /// Sum of token counts reported by successful responses.
    total_tokens: AtomicU64,
    /// Sum of successful-request latencies, in whole milliseconds.
    total_latency_ms: AtomicU64,
    /// Streams that have been handed out and not yet decremented.
    active_streams: AtomicU64,
}
52
53impl StatsTracker {
54 fn new() -> Self {
55 Self {
56 total_requests: AtomicU64::new(0),
57 successful_requests: AtomicU64::new(0),
58 failed_requests: AtomicU64::new(0),
59 total_tokens: AtomicU64::new(0),
60 total_latency_ms: AtomicU64::new(0),
61 active_streams: AtomicU64::new(0),
62 }
63 }
64
65 fn record_request(&self) {
66 self.total_requests.fetch_add(1, Ordering::SeqCst);
67 }
68
69 fn record_success(&self, latency: Duration, tokens: Option<u32>) {
70 self.successful_requests.fetch_add(1, Ordering::SeqCst);
71 self.total_latency_ms
72 .fetch_add(latency.as_millis() as u64, Ordering::SeqCst);
73 if let Some(t) = tokens {
74 self.total_tokens.fetch_add(t as u64, Ordering::SeqCst);
75 }
76 }
77
78 fn record_failure(&self) {
79 self.failed_requests.fetch_add(1, Ordering::SeqCst);
80 }
81
82 #[allow(dead_code)]
83 fn increment_streams(&self) {
84 self.active_streams.fetch_add(1, Ordering::SeqCst);
85 }
86
87 #[allow(dead_code)]
88 fn decrement_streams(&self) {
89 self.active_streams.fetch_sub(1, Ordering::SeqCst);
90 }
91
92 fn snapshot(&self) -> GatewayStats {
93 let total = self.total_requests.load(Ordering::SeqCst);
94 let successful = self.successful_requests.load(Ordering::SeqCst);
95 let total_latency = self.total_latency_ms.load(Ordering::SeqCst);
96
97 let avg_latency = if successful > 0 {
98 Duration::from_millis(total_latency / successful)
99 } else {
100 Duration::ZERO
101 };
102
103 GatewayStats {
104 total_requests: total,
105 successful_requests: successful,
106 failed_requests: self.failed_requests.load(Ordering::SeqCst),
107 total_tokens: self.total_tokens.load(Ordering::SeqCst),
108 avg_latency,
109 active_streams: self.active_streams.load(Ordering::SeqCst) as u32,
110 }
111 }
112}
113
114impl Default for StatsTracker {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120pub struct FederationGateway {
126 config: GatewayConfig,
127 router: Arc<Router>,
128 health: Arc<HealthChecker>,
129 circuit_breaker: Arc<CircuitBreaker>,
130 middlewares: Vec<Box<dyn GatewayMiddleware>>,
131 stats: StatsTracker,
132}
133
134impl FederationGateway {
135 pub fn new(
136 config: GatewayConfig,
137 router: Arc<Router>,
138 health: Arc<HealthChecker>,
139 circuit_breaker: Arc<CircuitBreaker>,
140 ) -> Self {
141 Self {
142 config,
143 router,
144 health,
145 circuit_breaker,
146 middlewares: Vec::new(),
147 stats: StatsTracker::new(),
148 }
149 }
150
151 #[must_use]
153 pub fn with_middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
154 self.middlewares.push(Box::new(middleware));
155 self
156 }
157
158 async fn execute_with_retries(
160 &self,
161 mut request: InferenceRequest,
162 ) -> FederationResult<InferenceResponse> {
163 for middleware in &self.middlewares {
165 middleware.before_route(&mut request)?;
166 }
167
168 let mut last_error = None;
169 let mut tried_nodes = Vec::new();
170
171 for attempt in 0..=self.config.max_retries {
172 let target = match self.router.route(&request).await {
176 Ok(t) => t,
177 Err(e) => {
178 last_error = Some(e);
179 continue;
180 }
181 };
182
183 if self.circuit_breaker.is_open(&target.node_id) {
185 last_error = Some(FederationError::CircuitOpen(target.node_id.clone()));
186 tried_nodes.push(target.node_id);
187 continue;
188 }
189
190 let start = Instant::now();
192 match self.execute_on_node(&target, &request).await {
193 Ok(mut response) => {
194 let latency = start.elapsed();
195
196 self.health.report_success(&target.node_id, latency);
198 self.circuit_breaker.record_success(&target.node_id);
199 self.stats.record_success(latency, response.tokens);
200
201 for middleware in &self.middlewares {
203 middleware.after_infer(&request, &mut response)?;
204 }
205
206 return Ok(response);
207 }
208 Err(e) => {
209 self.health.report_failure(&target.node_id);
211 self.circuit_breaker.record_failure(&target.node_id);
212
213 for middleware in &self.middlewares {
215 middleware.on_error(&request, &e);
216 }
217
218 last_error = Some(e);
219 tried_nodes.push(target.node_id);
220
221 if attempt < self.config.max_retries {
222 tokio::time::sleep(Duration::from_millis(100 * (attempt as u64 + 1))).await;
224 }
225 }
226 }
227 }
228
229 self.stats.record_failure();
230 Err(last_error
231 .unwrap_or_else(|| FederationError::Internal("All retries exhausted".to_string())))
232 }
233
234 #[allow(clippy::unused_async)] async fn execute_on_node(
237 &self,
238 target: &RouteTarget,
239 _request: &InferenceRequest,
240 ) -> FederationResult<InferenceResponse> {
241 if target.endpoint.is_empty() {
245 Ok(InferenceResponse {
247 output: b"simulated output".to_vec(),
248 served_by: target.node_id.clone(),
249 latency: Duration::from_millis(50),
250 tokens: Some(10),
251 })
252 } else {
253 Ok(InferenceResponse {
256 output: b"simulated output".to_vec(),
257 served_by: target.node_id.clone(),
258 latency: Duration::from_millis(50),
259 tokens: Some(10),
260 })
261 }
262 }
263}
264
265impl GatewayTrait for FederationGateway {
266 fn infer(
267 &self,
268 request: InferenceRequest,
269 ) -> BoxFuture<'_, FederationResult<InferenceResponse>> {
270 Box::pin(async move {
271 self.stats.record_request();
272 self.execute_with_retries(request).await
273 })
274 }
275
276 fn infer_stream(
277 &self,
278 request: InferenceRequest,
279 ) -> BoxFuture<'_, FederationResult<Box<dyn TokenStream>>> {
280 Box::pin(async move {
281 self.stats.record_request();
282 self.stats.increment_streams();
283
284 let target = self.router.route(&request).await?;
286
287 let stream = FederationTokenStream::new(
289 target,
290 request,
291 Arc::clone(&self.health),
292 Arc::clone(&self.circuit_breaker),
293 );
294
295 let stream: Box<dyn TokenStream> = Box::new(stream);
296 Ok(stream)
297 })
298 }
299
300 fn stats(&self) -> GatewayStats {
301 self.stats.snapshot()
302 }
303}
304
305struct FederationTokenStream {
311 target: RouteTarget,
312 _request: InferenceRequest,
313 health: Arc<HealthChecker>,
314 circuit_breaker: Arc<CircuitBreaker>,
315 tokens_generated: u32,
316 finished: bool,
317}
318
319impl FederationTokenStream {
320 fn new(
321 target: RouteTarget,
322 request: InferenceRequest,
323 health: Arc<HealthChecker>,
324 circuit_breaker: Arc<CircuitBreaker>,
325 ) -> Self {
326 Self {
327 target,
328 _request: request,
329 health,
330 circuit_breaker,
331 tokens_generated: 0,
332 finished: false,
333 }
334 }
335}
336
337impl TokenStream for FederationTokenStream {
338 fn next_token(&mut self) -> BoxFuture<'_, Option<FederationResult<Vec<u8>>>> {
339 Box::pin(async move {
340 if self.finished {
341 return None;
342 }
343
344 self.tokens_generated += 1;
346
347 if self.tokens_generated > 10 {
348 self.finished = true;
349 self.health
350 .report_success(&self.target.node_id, Duration::from_millis(50));
351 self.circuit_breaker.record_success(&self.target.node_id);
352 return None;
353 }
354
355 Some(Ok(format!("token_{}", self.tokens_generated).into_bytes()))
356 })
357 }
358
359 fn cancel(&mut self) -> BoxFuture<'_, ()> {
360 Box::pin(async move {
361 self.finished = true;
362 })
363 }
364}
365
366pub struct GatewayBuilder {
372 config: GatewayConfig,
373 catalog: Option<Arc<ModelCatalog>>,
374 health: Option<Arc<HealthChecker>>,
375 circuit_breaker: Option<Arc<CircuitBreaker>>,
376 router: Option<Arc<Router>>,
377 middlewares: Vec<Box<dyn GatewayMiddleware>>,
378}
379
380include!("middleware.rs");
381include!("gateway_03.rs");