// apr_cli/federation/middleware.rs
2impl GatewayBuilder {
3 pub fn new() -> Self {
4 Self {
5 config: GatewayConfig::default(),
6 catalog: None,
7 health: None,
8 circuit_breaker: None,
9 router: None,
10 middlewares: Vec::new(),
11 }
12 }
13
14 #[must_use]
15 pub fn config(mut self, config: GatewayConfig) -> Self {
16 self.config = config;
17 self
18 }
19
20 #[must_use]
21 pub fn catalog(mut self, catalog: Arc<ModelCatalog>) -> Self {
22 self.catalog = Some(catalog);
23 self
24 }
25
26 #[must_use]
27 pub fn health(mut self, health: Arc<HealthChecker>) -> Self {
28 self.health = Some(health);
29 self
30 }
31
32 #[must_use]
33 pub fn circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
34 self.circuit_breaker = Some(cb);
35 self
36 }
37
38 #[must_use]
39 pub fn router(mut self, router: Arc<Router>) -> Self {
40 self.router = Some(router);
41 self
42 }
43
44 #[must_use]
45 pub fn middleware(mut self, middleware: impl GatewayMiddleware + 'static) -> Self {
46 self.middlewares.push(Box::new(middleware));
47 self
48 }
49
50 pub fn build(self) -> FederationGateway {
51 let catalog = self
52 .catalog
53 .unwrap_or_else(|| Arc::new(ModelCatalog::new()));
54 let health = self
55 .health
56 .unwrap_or_else(|| Arc::new(HealthChecker::default()));
57 let circuit_breaker = self
58 .circuit_breaker
59 .unwrap_or_else(|| Arc::new(CircuitBreaker::default()));
60
61 let router = self.router.unwrap_or_else(|| {
62 Arc::new(Router::new(
63 super::routing::RouterConfig::default(),
64 Arc::clone(&catalog),
65 Arc::clone(&health),
66 Arc::clone(&circuit_breaker),
67 ))
68 });
69
70 let mut gateway = FederationGateway::new(self.config, router, health, circuit_breaker);
71
72 for middleware in self.middlewares {
73 gateway.middlewares.push(middleware);
74 }
75
76 gateway
77 }
78}
79
80impl Default for GatewayBuilder {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
/// Middleware that tags its log lines with a caller-chosen prefix.
pub struct LoggingMiddleware {
    /// Tag printed in square brackets at the start of each log line.
    prefix: String,
}

impl LoggingMiddleware {
    /// Creates a logging middleware whose lines are tagged with `prefix`.
    ///
    /// Accepts anything convertible into a `String` (e.g. `&str` or `String`).
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        Self { prefix }
    }
}
102
103impl GatewayMiddleware for LoggingMiddleware {
104 fn before_route(&self, request: &mut InferenceRequest) -> FederationResult<()> {
105 eprintln!(
106 "[{}] Routing request {} for {:?}",
107 self.prefix, request.request_id, request.capability
108 );
109 Ok(())
110 }
111
112 fn after_infer(
113 &self,
114 request: &InferenceRequest,
115 response: &mut InferenceResponse,
116 ) -> FederationResult<()> {
117 eprintln!(
118 "[{}] Request {} served by {:?} in {:?}",
119 self.prefix, request.request_id, response.served_by, response.latency
120 );
121 Ok(())
122 }
123
124 fn on_error(&self, request: &InferenceRequest, error: &FederationError) {
125 eprintln!(
126 "[{}] Request {} failed: {}",
127 self.prefix, request.request_id, error
128 );
129 }
130}
131
/// Middleware configured with a requests-per-second budget.
pub struct RateLimitMiddleware {
    // Stored by `new` but not read by any code visible in this file, so the
    // dead-code lint is silenced on purpose.
    #[allow(dead_code)]
    requests_per_second: u32,
}

impl RateLimitMiddleware {
    /// Creates the middleware, recording `requests_per_second` as its budget.
    pub fn new(requests_per_second: u32) -> Self {
        Self { requests_per_second }
    }
}
146
147impl GatewayMiddleware for RateLimitMiddleware {
148 fn before_route(&self, _request: &mut InferenceRequest) -> FederationResult<()> {
149 Ok(())
152 }
153
154 fn after_infer(
155 &self,
156 _request: &InferenceRequest,
157 _response: &mut InferenceResponse,
158 ) -> FederationResult<()> {
159 Ok(())
160 }
161
162 fn on_error(&self, _request: &InferenceRequest, _error: &FederationError) {}
163}