// barnacle_rs/middleware.rs

1use axum::body::Body;
2use axum::extract::{OriginalUri, Request};
3use axum::http::request::Parts;
4use axum::http::Response;
5use axum::response::IntoResponse;
6use http_body_util::BodyExt;
7use serde::de::DeserializeOwned;
8use std::marker::PhantomData;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use tower::{Layer, Service};
12use std::future::Future;
13use tracing::debug;
14use std::pin::Pin;
15
16use crate::types::{ApiKeyConfig, ResetOnSuccess, NO_KEY};
17use crate::RedisBarnacleStore;
18use crate::{
19    types::{BarnacleConfig, BarnacleContext, BarnacleKey},
20    BarnacleStore,
21};
22use crate::error::BarnacleError;
23
24/// Trait to extract the key from any payload type
25pub trait KeyExtractable {
26    fn extract_key(&self, request_parts: &Parts) -> BarnacleKey;
27}
28
29/// Error type for BarnacleLayerBuilder
30#[derive(Debug, thiserror::Error)]
31pub enum BarnacleLayerBuilderError {
32    #[error("Missing store")]
33    MissingStore,
34    #[error("Missing config")]
35    MissingConfig,
36}
37
38/// Builder for BarnacleLayer
39pub struct BarnacleLayerBuilder<T = (), S = RedisBarnacleStore, State = (), E = BarnacleError, V = (), M = ()> {
40    store: Option<S>,
41    config: Option<BarnacleConfig>,
42    state: Option<State>,
43    api_key_validator: Option<V>,
44    api_key_middleware_config: Option<ApiKeyConfig>,
45    request_modifier: Option<M>,
46    _phantom: PhantomData<(T, E)>,
47}
48
49impl<T, S, State, E, V, M> BarnacleLayerBuilder<T, S, State, E, V, M>
50where
51    S: BarnacleStore + 'static,
52    State: Clone +Send + Sync + 'static,
53    V: Clone + Send + Sync + 'static,
54    M: Clone + Send + Sync + 'static,
55{
56    pub fn with_store(mut self, store: S) -> Self {
57        self.store = Some(store);
58        self
59    }
60    pub fn with_config(mut self, config: BarnacleConfig) -> Self {
61        self.config = Some(config);
62        self
63    }
64    pub fn with_state(mut self, state: State) -> Self {
65        self.state = Some(state);
66        self
67    }
68    pub fn with_api_key_validator(mut self, validator: V) -> Self {
69        self.api_key_validator = Some(validator);
70        self
71    }
72    pub fn with_api_key_middleware_config(mut self, config: ApiKeyConfig) -> Self {
73        self.api_key_middleware_config = Some(config);
74        self
75    }
76    pub fn with_request_modifier(mut self, modifier: M) -> Self {
77        self.request_modifier = Some(modifier);
78        self
79    }
80    pub fn build(self) -> Result<BarnacleLayer<T, S, State, E, V, M>, BarnacleLayerBuilderError> {
81        Ok(BarnacleLayer {
82            store: self.store.ok_or(BarnacleLayerBuilderError::MissingStore)?,
83            config: self.config.ok_or(BarnacleLayerBuilderError::MissingConfig)?,
84            state: self.state,
85            api_key_validator: self.api_key_validator,
86            api_key_middleware_config: self.api_key_middleware_config,
87            request_modifier: self.request_modifier,
88            _phantom: PhantomData,
89        })
90    }
91}
92
93/// Generic rate limiting and API key layer
94pub struct BarnacleLayer<T = (), S = RedisBarnacleStore, State = (), E = BarnacleError, V = (), M = ()> {
95    store: S,
96    config: BarnacleConfig,
97    state: Option<State>,
98    api_key_validator: Option<V>,
99    api_key_middleware_config: Option<ApiKeyConfig>,
100    request_modifier: Option<M>,
101    _phantom: PhantomData<(T, E)>,
102}
103
104impl<T, S, State, E, V, M> Clone for BarnacleLayer<T, S, State, E, V, M>
105where
106    S: Clone + BarnacleStore + 'static,
107    State: Clone + Send + Sync + 'static,
108    V: Clone + Send + Sync + 'static,
109    M: Clone + Send + Sync + 'static,
110{
111    fn clone(&self) -> Self {
112        Self {
113            store: self.store.clone(),
114            config: self.config.clone(),
115            state: self.state.clone(),
116            api_key_validator: self.api_key_validator.clone(),
117            api_key_middleware_config: self.api_key_middleware_config.clone(),
118            request_modifier: self.request_modifier.clone(),
119            _phantom: PhantomData,
120        }
121    }
122}
123
124impl<T, S, State, E, V, M> BarnacleLayer<T, S, State, E, V, M>
125where
126    S: BarnacleStore + 'static,
127    State: Send + Sync + 'static,
128    V: Clone + Send + Sync + 'static,
129    M: Clone + Send + Sync + 'static,
130{
131    pub fn builder() -> BarnacleLayerBuilder<T, S, State, E, V, M> {
132        BarnacleLayerBuilder {
133            store: None,
134            config: None,
135            state: None,
136            api_key_validator: None,
137            api_key_middleware_config: None,
138            request_modifier: None,
139            _phantom: PhantomData,
140        }
141    }
142}
143
144impl<Inner, T, S, State, E, V, M> Layer<Inner> for BarnacleLayer<T, S, State, E, V, M>
145where
146    T: DeserializeOwned + KeyExtractable + Send + 'static,
147    S: Clone + BarnacleStore + 'static,
148    State: Clone + Send + Sync + 'static,
149    E: IntoResponse + Send + Sync + 'static,
150    Inner: Clone,
151    V: Clone + Send + Sync + 'static,
152    M: Clone + Send + Sync + 'static,
153{
154    type Service = BarnacleMiddleware<Inner, T, S, State, E, V, M>;
155    fn layer(&self, inner: Inner) -> Self::Service {
156        BarnacleMiddleware {
157            inner,
158            store: self.store.clone(),
159            config: self.config.clone(),
160            state: self.state.clone(),
161            api_key_validator: self.api_key_validator.clone(),
162            api_key_config: self.api_key_middleware_config.clone(),
163            request_modifier: self.request_modifier.clone(),
164            _phantom: PhantomData,
165        }
166    }
167}
168
169/// Helper function to handle rate limit reset logic
170async fn handle_rate_limit_reset<S>(
171    store: &S,
172    config: &BarnacleConfig,
173    context: &BarnacleContext,
174    status_code: u16,
175    is_fallback: bool,
176) where
177    S: BarnacleStore + 'static,
178{
179    if config.reset_on_success == ResetOnSuccess::Not {
180        return;
181    }
182
183    let key_type = if is_fallback { "fallback key" } else { "key" };
184    if !config.is_success_status(status_code) {
185        debug!(
186            "Not resetting rate limit for {} {:?} due to error status: {}",
187            key_type,
188            context.key,
189            status_code
190        );
191        return;
192    }
193
194    let mut contexts = vec![context.clone()];
195
196    if let ResetOnSuccess::Multiple(_, extra_contexts) = &config.reset_on_success {
197        contexts.extend(extra_contexts.iter().cloned());
198    }
199
200    for ctx in contexts.iter_mut() {
201        if ctx.key == BarnacleKey::Custom(NO_KEY.to_string()) {
202            ctx.key = context.key.clone();
203        }
204        match store.reset(ctx).await {
205            Ok(_) => debug!(
206                "Rate limit reset for {} {:?} after successful request (status: {}) path: {}",
207                key_type,
208                ctx.key,
209                status_code,
210                ctx.path
211            ),
212            Err(e) => debug!(
213                "Failed to reset rate limit for {} {:?}: {} path: {}",
214                key_type,
215                ctx.key,
216                e,
217                ctx.path
218            ),
219        }
220    }
221}
222
223fn get_fallback_key_common(
224    extensions: &axum::http::Extensions,
225    headers: &axum::http::HeaderMap,
226    path: &str,
227    method: &axum::http::Method,
228) -> BarnacleKey {
229    // 1. Try ConnectInfo<SocketAddr> (only available in full Request)
230    if let Some(addr) = extensions.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>() {
231        debug!("IP via ConnectInfo: {}", addr.ip());
232        return BarnacleKey::Ip(addr.ip().to_string());
233    }
234
235    // 2. Try X-Forwarded-For header
236    if let Some(forwarded) = headers.get("x-forwarded-for") {
237        if let Ok(forwarded) = forwarded.to_str() {
238            let ip = forwarded.split(',').next().unwrap_or("").trim();
239            if !ip.is_empty() && ip != "unknown" {
240                return BarnacleKey::Ip(ip.to_string());
241            }
242        }
243    }
244
245    // 3. Try X-Real-IP header
246    if let Some(real_ip) = headers.get("x-real-ip") {
247        if let Ok(real_ip) = real_ip.to_str() {
248            if !real_ip.is_empty() && real_ip != "unknown" {
249                return BarnacleKey::Ip(real_ip.to_string());
250            }
251        }
252    }
253
254    // 4. For local requests, use a unique identifier based on route + method
255    let method_str = method.as_str();
256    let local_key = format!("local:{}:{}", method_str, path);
257    debug!("Local key: {}", local_key);
258    BarnacleKey::Ip(local_key)
259}
260
261
262
263/// The actual middleware that handles payload-based key extraction
264pub struct BarnacleMiddleware<Inner, T, S, State = (), E = BarnacleError, V = (), M = ()> {
265    inner: Inner,
266    store: S,
267    config: BarnacleConfig,
268    state: Option<State>,
269    api_key_validator: Option<V>,
270    api_key_config: Option<ApiKeyConfig>,
271    request_modifier: Option<M>,
272    _phantom: PhantomData<(T, E)>,
273}
274
275impl<Inner, T, S, State, E, V, M> Clone for BarnacleMiddleware<Inner, T, S, State, E, V, M>
276where
277    Inner: Clone,
278    S: Clone + BarnacleStore + 'static,
279    State: Clone + Send + Sync + 'static,
280    V: Clone + Send + Sync + 'static,
281    M: Clone + Send + Sync + 'static,
282{
283    fn clone(&self) -> Self {
284        Self {
285            inner: self.inner.clone(),
286            store: self.store.clone(),
287            config: self.config.clone(),
288            state: self.state.clone(),
289            api_key_validator: self.api_key_validator.clone(),
290            api_key_config: self.api_key_config.clone(),
291            request_modifier: self.request_modifier.clone(),
292            _phantom: PhantomData,
293        }
294    }
295}
296
297// --- ValidatorCall trait for owned types ---
298pub trait ValidatorCall<T, S, State, E> {
299    fn call(
300        &self,
301        api_key: T,
302        api_key_config: S,
303        parts: Arc<Parts>,
304        state: State,
305    ) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send>>;
306}
307
308// Implementation for closures
309impl<F, Fut, T, S, State, E> ValidatorCall<T, S, State, E> for F
310where
311    F: Fn(T, S, Arc<Parts>, State) -> Fut + Send + Sync,
312    Fut: Future<Output = Result<(), E>> + Send + 'static,
313    T: Send + 'static,
314    S: Send + 'static,
315    State: Send + 'static,
316    E: Send + 'static,
317{
318    fn call(
319        &self,
320        api_key: T,
321        api_key_config: S,
322        parts: Arc<Parts>,
323        state: State,
324    ) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send>> {
325        Box::pin((self)(api_key, api_key_config, parts, state))
326    }
327}
328
329// Implementation for ()
330impl<T, S, State, E> ValidatorCall<T, S, State, E> for () {
331    fn call(
332        &self,
333        _api_key: T,
334        _api_key_config: S,
335        _parts: Arc<Parts>,
336        _state: State,
337    ) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send>> {
338        Box::pin(async { Ok(()) })
339    }
340}
341
/// Asynchronous hook that may rewrite request parts, taking owned arguments.
///
/// `P` is the parts type; the middleware uses `http::request::Parts`. The generic is named `P`
/// so it does not shadow that imported type. Returning `Err(e)` rejects the request with `e`'s
/// response.
pub trait RequestModifier<P, St, E> {
    /// Consumes `parts` and `state`, resolving to the (possibly modified) parts.
    fn modify(&self, parts: P, state: St) -> Pin<Box<dyn Future<Output = Result<P, E>> + Send>>;
}
350
351// Blanket impl to require Send for Parts
352impl<Parts, State, E> RequestModifier<Parts, State, E> for ()
353where
354    Parts: Send + 'static,
355{
356    fn modify(
357        &self,
358        parts: Parts,
359        _state: State,
360    ) -> Pin<Box<dyn Future<Output = Result<Parts, E>> + Send>> {
361        Box::pin(async { Ok(parts) })
362    }
363}
364
365// Implementation for closures
366impl<F, Fut, Parts, State, E> RequestModifier<Parts, State, E> for F
367where
368    F: Fn(Parts, State) -> Fut + Send + Sync,
369    Fut: Future<Output = Result<Parts, E>> + Send + 'static,
370    Parts: Send + 'static,
371    State: Send + 'static,
372    E: Send + 'static,
373{
374    fn modify(
375        &self,
376        parts: Parts,
377        state: State,
378    ) -> Pin<Box<dyn Future<Output = Result<Parts, E>> + Send>> {
379        Box::pin((self)(parts, state))
380    }
381}
382
383// Provide a KeyExtractable impl for ()
384impl KeyExtractable for () {
385    fn extract_key(&self, request_parts: &Parts) -> BarnacleKey {
386        // Use fallback key logic
387        let extensions = &request_parts.extensions;
388        let headers = &request_parts.headers;
389        let path = request_parts.uri.path();
390        let method = &request_parts.method;
391        get_fallback_key_common(extensions, headers, path, method)
392    }
393}
394
395impl<Inner, B, T, S, State, E, V, M> Service<Request<B>> for BarnacleMiddleware<Inner, T, S, State, E, V, M>
396where
397    Inner: Service<Request<axum::body::Body>, Response = Response<Body>> + Clone + Send + 'static,
398    Inner::Future: Send + 'static,
399    B: axum::body::HttpBody + Send + 'static,
400    B::Data: Send,
401    B::Error: std::error::Error + Send + Sync,
402    S: Clone + BarnacleStore + 'static,
403    State: Clone + Send + Sync + 'static,
404    T: KeyExtractable + DeserializeOwned + Send + 'static,
405    E: IntoResponse + Send + Sync + 'static + From<BarnacleError>,
406    V: ValidatorCall<String, ApiKeyConfig, State, E> + Clone + Send + Sync + 'static,
407    M: RequestModifier<Parts, State, E> + Clone + Send + Sync + 'static,
408{
409    type Response = Inner::Response;
410    type Error = Inner::Error;
411    type Future = std::pin::Pin<
412        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
413    >;
414
415    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
416        self.inner.poll_ready(cx)
417    }
418
419    fn call(&mut self, req: Request<B>) -> Self::Future {
420        debug!("[middleware.rs] Unified BarnacleMiddleware::call invoked");
421        let mut inner = self.inner.clone();
422        let store = self.store.clone();
423        let config = self.config.clone();
424        let state = self.state.clone();
425        let validator_state = state.clone(); // Separate clone for validator to avoid move issues
426        let api_key_validator = self.api_key_validator.clone();
427        let api_key_config = self.api_key_config.clone();
428        let request_modifier = self.request_modifier.clone();
429        Box::pin(async move {
430            debug!("[middleware.rs] Entered async block in call");
431            let current_path = req
432                .extensions()
433                .get::<OriginalUri>()
434                .map(|original_url| original_url.path().to_owned())
435                .unwrap_or(req.uri().path().to_owned());
436            
437            debug!("[middleware.rs] current_path: {}", current_path);
438            let (parts, body) = req.into_parts();
439            debug!("[middleware.rs] Request parts and body split");
440
441            // API key validation (if configured)
442            let mut api_key_used: Option<String> = None;
443            let api_key_config = api_key_config.unwrap_or_default();
444            let api_key = parts.headers.get(api_key_config.header_name.as_str()).and_then(|h| h.to_str().ok()).unwrap_or("");
445            debug!("[middleware.rs] About to call validator with key: '{}'", api_key);
446
447            let validation_result = if let Some(validator) = api_key_validator.as_ref() {
448                let is_stateless_validator = std::any::TypeId::of::<V>() == std::any::TypeId::of::<()>();
449                let is_unit_state = std::any::TypeId::of::<State>() == std::any::TypeId::of::<()>();
450                if is_stateless_validator && is_unit_state {
451                    // Both validator and state are (), safe to call with zeroed State
452                    validator.call(api_key.to_string(), api_key_config, Arc::new(parts.clone()), unsafe { std::mem::zeroed() }).await
453                } else {
454                    match validator_state {
455                        Some(validator_state) => {
456                            validator.call(api_key.to_string(), api_key_config, Arc::new(parts.clone()), validator_state).await
457                        }
458                        None => {
459                            // Return a more appropriate error for missing validator state
460                            Err(E::from(BarnacleError::custom("Barnacle: API key validator requires state, but none was provided. Use with_state() or use () for stateless validators.", None)))
461                        }
462                    }
463                }
464            } else {
465                Ok(())
466            };
467            match validation_result {
468                Ok(_) => {
469                    debug!("[middleware.rs] Validator returned Ok for: '{}'", api_key);
470                    if !api_key.is_empty() {
471                        api_key_used = Some(api_key.to_string());
472                    }
473                },
474                Err(e) => {
475                    debug!("[middleware.rs] Validator returned Err");
476                    return Ok(e.into_response());
477                }
478            }
479
480            // Apply request modifier after validation (if configured)
481            let modified_parts = if let Some(modifier) = request_modifier.as_ref() {
482                // Clone state for modifier to avoid move issues
483                let modifier_state = state.clone();
484                if let Some(modifier_state) = modifier_state {
485                    modifier.modify(parts, modifier_state).await
486                } else {
487                    Err(E::from(BarnacleError::custom("Barnacle: Request modifier requires state, but none was provided.", None)))
488                }
489            } else {
490                Ok(parts)
491            };
492            let parts = match modified_parts {
493                Ok(modified_parts) => {
494                    debug!("[middleware.rs] Request modifier returned Ok");
495                    modified_parts
496                },
497                Err(e) => {
498                    debug!("[middleware.rs] Request modifier returned Err");
499                    return Ok(e.into_response());
500                }
501            };
502
503            // Unified logic: always try to extract key from body (for T=(), uses fallback)
504            let (rate_limit_context, body_bytes) = match body.collect().await {
505                Ok(collected) => {
506                    let bytes = collected.to_bytes();
507                    let (key, used_fallback) = if let Some(ref api_key) = api_key_used {
508                        // Use API key as the rate limiting key
509                        (BarnacleKey::ApiKey(api_key.clone()), false)
510                    } else {
511                        match serde_json::from_slice::<T>(&bytes) {
512                            Ok(payload) => (payload.extract_key(&parts), false),
513                            Err(_) => (
514                                get_fallback_key_common(
515                                    &parts.extensions,
516                                    &parts.headers,
517                                    &current_path,
518                                    &parts.method,
519                                ),
520                                true,
521                            ),
522                        }
523                    };
524                    let context = BarnacleContext {
525                        key,
526                        path: current_path.clone(),
527                        method: parts.method.as_str().to_string(),
528                    };
529                    if used_fallback {
530                        debug!("[middleware.rs] (unified) Using fallback key for rate limiting");
531                    } else if api_key_used.is_some() {
532                        debug!("[middleware.rs] (unified) Using API key for rate limiting");
533                    } else {
534                        debug!("[middleware.rs] (unified) Extracted key from payload for rate limiting");
535                    }
536                    (context, Some(bytes))
537                }
538                Err(_) => {
539                    debug!("[middleware.rs] (unified) Failed to collect body, using fallback key");
540                    let fallback_key = get_fallback_key_common(
541                        &parts.extensions,
542                        &parts.headers,
543                        &current_path,
544                        &parts.method,
545                    );
546                    let context = BarnacleContext {
547                        key: fallback_key,
548                        path: current_path.clone(),
549                        method: parts.method.as_str().to_string(),
550                    };
551                    (context, None)
552                }
553            };
554            debug!("[middleware.rs] (unified) About to increment rate limit for context: {:?}", rate_limit_context);
555            tracing::debug!("[middleware.rs] Rate limit increment: api_key={:?}, path={}, method={}", rate_limit_context.key, rate_limit_context.path, rate_limit_context.method);
556            let result = match store.increment(&rate_limit_context, &config).await {
557                Ok(result) => result,
558                Err(e) => {
559                    debug!("[middleware.rs] (unified) Rate limit store error: {}", e);
560                    return Ok(E::from(e).into_response());
561                }
562            };
563            debug!("[middleware.rs] (unified) Rate limit check passed for key: {:?}, remaining: {}, retry_after: {:?}", rate_limit_context.key, result.remaining, result.retry_after);
564            let reconstructed_body = match body_bytes {
565                Some(bytes) => axum::body::Body::from(bytes),
566                None => axum::body::Body::empty(),
567            };
568            let new_req = Request::from_parts(parts, reconstructed_body);
569            debug!("[middleware.rs] (unified) Calling inner service");
570            let response = inner.call(new_req).await?;
571            // Add rate limit headers to successful response
572            let mut response_with_headers = response;
573            {
574                let headers = response_with_headers.headers_mut();
575                if let Ok(remaining_header) = result.remaining.to_string().parse() {
576                    headers.insert("X-RateLimit-Remaining", remaining_header);
577                    debug!("[middleware.rs] (unified) Added X-RateLimit-Remaining: {}", result.remaining);
578                }
579                if let Ok(limit_header) = config.max_requests.to_string().parse() {
580                    headers.insert("X-RateLimit-Limit", limit_header);
581                    debug!("[middleware.rs] (unified) Added X-RateLimit-Limit: {}", config.max_requests);
582                }
583                if let Some(retry_after) = result.retry_after {
584                    if let Ok(reset_header) = retry_after.as_secs().to_string().parse() {
585                        headers.insert("X-RateLimit-Reset", reset_header);
586                        debug!("[middleware.rs] (unified) Added X-RateLimit-Reset: {}", retry_after.as_secs());
587                    }
588                }
589            }
590            handle_rate_limit_reset(
591                &store,
592                &config,
593                &rate_limit_context,
594                response_with_headers.status().as_u16(),
595                false,
596            )
597            .await;
598            debug!("[middleware.rs] (unified) Returning final response");
599            Ok(response_with_headers)
600        })
601    }
602}