// apr_cli/federation/middleware.rs

2impl GatewayBuilder {
3    pub fn new() -> Self {
4        Self {
5            config: GatewayConfig::default(),
6            catalog: None,
7            health: None,
8            circuit_breaker: None,
9            router: None,
10            middlewares: Vec::new(),
11        }
12    }
13
14    #[must_use]
15    pub fn config(mut self, config: GatewayConfig) -> Self {
16        self.config = config;
17        self
18    }
19
20    #[must_use]
21    pub fn catalog(mut self, catalog: Arc<ModelCatalog>) -> Self {
22        self.catalog = Some(catalog);
23        self
24    }
25
26    #[must_use]
27    pub fn health(mut self, health: Arc<HealthChecker>) -> Self {
28        self.health = Some(health);
29        self
30    }
31
32    #[must_use]
33    pub fn circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
34        self.circuit_breaker = Some(cb);
35        self
36    }
37
38    #[must_use]
39    pub fn router(mut self, router: Arc<Router>) -> Self {
40        self.router = Some(router);
41        self
42    }
43
44    #[must_use]
45    pub fn middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
46        self.middlewares.push(Box::new(middleware));
47        self
48    }
49
50    pub fn build(self) -> FederationGateway {
51        let catalog = self
52            .catalog
53            .unwrap_or_else(|| Arc::new(ModelCatalog::new()));
54        let health = self
55            .health
56            .unwrap_or_else(|| Arc::new(HealthChecker::default()));
57        let circuit_breaker = self
58            .circuit_breaker
59            .unwrap_or_else(|| Arc::new(CircuitBreaker::default()));
60
61        let router = self.router.unwrap_or_else(|| {
62            Arc::new(Router::new(
63                super::routing::RouterConfig::default(),
64                Arc::clone(&catalog),
65                Arc::clone(&health),
66                Arc::clone(&circuit_breaker),
67            ))
68        });
69
70        let mut gateway = FederationGateway::new(self.config, router, health, circuit_breaker);
71
72        for middleware in self.middlewares {
73            gateway.middlewares.push(middleware);
74        }
75
76        gateway
77    }
78}
79
80impl Default for GatewayBuilder {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
// ============================================================================
// Example Middlewares
// ============================================================================

/// Logging middleware.
///
/// Carries the prefix used to tag the log lines this middleware emits.
#[derive(Debug, Clone)]
pub struct LoggingMiddleware {
    /// Tag printed inside square brackets at the start of each log line.
    prefix: String,
}

95impl LoggingMiddleware {
96    pub fn new(prefix: impl Into<String>) -> Self {
97        Self {
98            prefix: prefix.into(),
99        }
100    }
101}
102
103impl GatewayMiddleware for LoggingMiddleware {
104    fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()> {
105        eprintln!(
106            "[{}] Routing request {} for {:?}",
107            self.prefix, request.request_id, request.capability
108        );
109        Ok(())
110    }
111
112    fn after_infer(
113        &self,
114        request: &InferenceRequest,
115        response: &mut InferenceResponse,
116    ) -> FederationResult<()> {
117        eprintln!(
118            "[{}] Request {} served by {:?} in {:?}",
119            self.prefix, request.request_id, response.served_by, response.latency
120        );
121        Ok(())
122    }
123
124    fn on_error(&self, request: &InferenceRequest, error: &FederationError) {
125        eprintln!(
126            "[{}] Request {} failed: {}",
127            self.prefix, request.request_id, error
128        );
129    }
130}
131
/// Rate limiting middleware (placeholder).
///
/// Only stores the configured limit; this type holds no counters or timing
/// state with which the limit could be enforced.
#[derive(Debug)]
pub struct RateLimitMiddleware {
    // Stored but never read by the current implementation, so the dead-code
    // lint is silenced until enforcement exists.
    #[allow(dead_code)]
    requests_per_second: u32,
    // In production, would use a token bucket or sliding window
}

139impl RateLimitMiddleware {
140    pub fn new(requests_per_second: u32) -> Self {
141        Self {
142            requests_per_second,
143        }
144    }
145}
146
147impl GatewayMiddleware for RateLimitMiddleware {
148    fn before_route(&self, _request: &mut InferenceRequest) -> FederationResult<()> {
149        // In production, would check rate limit and return error if exceeded
150        // For now, always allow
151        Ok(())
152    }
153
154    fn after_infer(
155        &self,
156        _request: &InferenceRequest,
157        _response: &mut InferenceResponse,
158    ) -> FederationResult<()> {
159        Ok(())
160    }
161
162    fn on_error(&self, _request: &InferenceRequest, _error: &FederationError) {}
163}