1use axum::body::Body;
2use axum::extract::{OriginalUri, Request};
3use axum::http::request::Parts;
4use axum::http::Response;
5use axum::response::IntoResponse;
6use http_body_util::BodyExt;
7use serde::de::DeserializeOwned;
8use std::marker::PhantomData;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use tower::{Layer, Service};
12use std::future::Future;
13use tracing::debug;
14use std::pin::Pin;
15
16use crate::types::{ApiKeyConfig, ResetOnSuccess, NO_KEY};
17use crate::RedisBarnacleStore;
18use crate::{
19 types::{BarnacleConfig, BarnacleContext, BarnacleKey},
20 BarnacleStore,
21};
22use crate::error::BarnacleError;
23
24pub trait KeyExtractable {
26 fn extract_key(&self, request_parts: &Parts) -> BarnacleKey;
27}
28
/// Errors returned by `BarnacleLayerBuilder::build` when a mandatory setting was never provided.
#[derive(Debug, thiserror::Error)]
pub enum BarnacleLayerBuilderError {
    /// No store was supplied (see `with_store`).
    #[error("Missing store")]
    MissingStore,
    /// No rate-limit configuration was supplied (see `with_config`).
    #[error("Missing config")]
    MissingConfig,
}
37
/// Builder for `BarnacleLayer`.
///
/// `store` and `config` must be set before `build`; every other setting is optional.
/// Generic parameters:
/// - `T`: payload type deserialized from the body to extract a key (`()` = fallback key only),
/// - `S`: rate-limit store,
/// - `State`: shared state handed to the API-key validator and the request modifier,
/// - `E`: error type turned into the HTTP response when validation/modification/limiting fails,
/// - `V`: API-key validator (`()` = none),
/// - `M`: request modifier (`()` = none).
pub struct BarnacleLayerBuilder<T = (), S = RedisBarnacleStore, State = (), E = BarnacleError, V = (), M = ()> {
    store: Option<S>,
    config: Option<BarnacleConfig>,
    state: Option<State>,
    api_key_validator: Option<V>,
    api_key_middleware_config: Option<ApiKeyConfig>,
    request_modifier: Option<M>,
    // Marks the otherwise unused `T` and `E` type parameters.
    _phantom: PhantomData<(T, E)>,
}
48
49impl<T, S, State, E, V, M> BarnacleLayerBuilder<T, S, State, E, V, M>
50where
51 S: BarnacleStore + 'static,
52 State: Clone +Send + Sync + 'static,
53 V: Clone + Send + Sync + 'static,
54 M: Clone + Send + Sync + 'static,
55{
56 pub fn with_store(mut self, store: S) -> Self {
57 self.store = Some(store);
58 self
59 }
60 pub fn with_config(mut self, config: BarnacleConfig) -> Self {
61 self.config = Some(config);
62 self
63 }
64 pub fn with_state(mut self, state: State) -> Self {
65 self.state = Some(state);
66 self
67 }
68 pub fn with_api_key_validator(mut self, validator: V) -> Self {
69 self.api_key_validator = Some(validator);
70 self
71 }
72 pub fn with_api_key_middleware_config(mut self, config: ApiKeyConfig) -> Self {
73 self.api_key_middleware_config = Some(config);
74 self
75 }
76 pub fn with_request_modifier(mut self, modifier: M) -> Self {
77 self.request_modifier = Some(modifier);
78 self
79 }
80 pub fn build(self) -> Result<BarnacleLayer<T, S, State, E, V, M>, BarnacleLayerBuilderError> {
81 Ok(BarnacleLayer {
82 store: self.store.ok_or(BarnacleLayerBuilderError::MissingStore)?,
83 config: self.config.ok_or(BarnacleLayerBuilderError::MissingConfig)?,
84 state: self.state,
85 api_key_validator: self.api_key_validator,
86 api_key_middleware_config: self.api_key_middleware_config,
87 request_modifier: self.request_modifier,
88 _phantom: PhantomData,
89 })
90 }
91}
92
/// Tower `Layer` that wraps a service in `BarnacleMiddleware`.
///
/// Create it with `BarnacleLayer::builder()`; the generic parameters have the same meaning as
/// on `BarnacleLayerBuilder`. Every wrapped service receives clones of these settings.
pub struct BarnacleLayer<T = (), S = RedisBarnacleStore, State = (), E = BarnacleError, V = (), M = ()> {
    store: S,
    config: BarnacleConfig,
    state: Option<State>,
    api_key_validator: Option<V>,
    api_key_middleware_config: Option<ApiKeyConfig>,
    request_modifier: Option<M>,
    // Marks the otherwise unused `T` and `E` type parameters.
    _phantom: PhantomData<(T, E)>,
}
103
104impl<T, S, State, E, V, M> Clone for BarnacleLayer<T, S, State, E, V, M>
105where
106 S: Clone + BarnacleStore + 'static,
107 State: Clone + Send + Sync + 'static,
108 V: Clone + Send + Sync + 'static,
109 M: Clone + Send + Sync + 'static,
110{
111 fn clone(&self) -> Self {
112 Self {
113 store: self.store.clone(),
114 config: self.config.clone(),
115 state: self.state.clone(),
116 api_key_validator: self.api_key_validator.clone(),
117 api_key_middleware_config: self.api_key_middleware_config.clone(),
118 request_modifier: self.request_modifier.clone(),
119 _phantom: PhantomData,
120 }
121 }
122}
123
124impl<T, S, State, E, V, M> BarnacleLayer<T, S, State, E, V, M>
125where
126 S: BarnacleStore + 'static,
127 State: Send + Sync + 'static,
128 V: Clone + Send + Sync + 'static,
129 M: Clone + Send + Sync + 'static,
130{
131 pub fn builder() -> BarnacleLayerBuilder<T, S, State, E, V, M> {
132 BarnacleLayerBuilder {
133 store: None,
134 config: None,
135 state: None,
136 api_key_validator: None,
137 api_key_middleware_config: None,
138 request_modifier: None,
139 _phantom: PhantomData,
140 }
141 }
142}
143
144impl<Inner, T, S, State, E, V, M> Layer<Inner> for BarnacleLayer<T, S, State, E, V, M>
145where
146 T: DeserializeOwned + KeyExtractable + Send + 'static,
147 S: Clone + BarnacleStore + 'static,
148 State: Clone + Send + Sync + 'static,
149 E: IntoResponse + Send + Sync + 'static,
150 Inner: Clone,
151 V: Clone + Send + Sync + 'static,
152 M: Clone + Send + Sync + 'static,
153{
154 type Service = BarnacleMiddleware<Inner, T, S, State, E, V, M>;
155 fn layer(&self, inner: Inner) -> Self::Service {
156 BarnacleMiddleware {
157 inner,
158 store: self.store.clone(),
159 config: self.config.clone(),
160 state: self.state.clone(),
161 api_key_validator: self.api_key_validator.clone(),
162 api_key_config: self.api_key_middleware_config.clone(),
163 request_modifier: self.request_modifier.clone(),
164 _phantom: PhantomData,
165 }
166 }
167}
168
169async fn handle_rate_limit_reset<S>(
171 store: &S,
172 config: &BarnacleConfig,
173 context: &BarnacleContext,
174 status_code: u16,
175 is_fallback: bool,
176) where
177 S: BarnacleStore + 'static,
178{
179 if config.reset_on_success == ResetOnSuccess::Not {
180 return;
181 }
182
183 let key_type = if is_fallback { "fallback key" } else { "key" };
184 if !config.is_success_status(status_code) {
185 debug!(
186 "Not resetting rate limit for {} {:?} due to error status: {}",
187 key_type,
188 context.key,
189 status_code
190 );
191 return;
192 }
193
194 let mut contexts = vec![context.clone()];
195
196 if let ResetOnSuccess::Multiple(_, extra_contexts) = &config.reset_on_success {
197 contexts.extend(extra_contexts.iter().cloned());
198 }
199
200 for ctx in contexts.iter_mut() {
201 if ctx.key == BarnacleKey::Custom(NO_KEY.to_string()) {
202 ctx.key = context.key.clone();
203 }
204 match store.reset(ctx).await {
205 Ok(_) => debug!(
206 "Rate limit reset for {} {:?} after successful request (status: {}) path: {}",
207 key_type,
208 ctx.key,
209 status_code,
210 ctx.path
211 ),
212 Err(e) => debug!(
213 "Failed to reset rate limit for {} {:?}: {} path: {}",
214 key_type,
215 ctx.key,
216 e,
217 ctx.path
218 ),
219 }
220 }
221}
222
223fn get_fallback_key_common(
224 extensions: &axum::http::Extensions,
225 headers: &axum::http::HeaderMap,
226 path: &str,
227 method: &axum::http::Method,
228) -> BarnacleKey {
229 if let Some(addr) = extensions.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>() {
231 debug!("IP via ConnectInfo: {}", addr.ip());
232 return BarnacleKey::Ip(addr.ip().to_string());
233 }
234
235 if let Some(forwarded) = headers.get("x-forwarded-for") {
237 if let Ok(forwarded) = forwarded.to_str() {
238 let ip = forwarded.split(',').next().unwrap_or("").trim();
239 if !ip.is_empty() && ip != "unknown" {
240 return BarnacleKey::Ip(ip.to_string());
241 }
242 }
243 }
244
245 if let Some(real_ip) = headers.get("x-real-ip") {
247 if let Ok(real_ip) = real_ip.to_str() {
248 if !real_ip.is_empty() && real_ip != "unknown" {
249 return BarnacleKey::Ip(real_ip.to_string());
250 }
251 }
252 }
253
254 let method_str = method.as_str();
256 let local_key = format!("local:{}:{}", method_str, path);
257 debug!("Local key: {}", local_key);
258 BarnacleKey::Ip(local_key)
259}
260
261
262
/// The tower `Service` produced by `BarnacleLayer`.
///
/// Holds the wrapped service plus clones of the layer's settings. Generic parameters match
/// `BarnacleLayer`, with `Inner` being the wrapped service.
pub struct BarnacleMiddleware<Inner, T, S, State = (), E = BarnacleError, V = (), M = ()> {
    inner: Inner,
    store: S,
    config: BarnacleConfig,
    state: Option<State>,
    api_key_validator: Option<V>,
    // `None` means `ApiKeyConfig::default()` is used at request time.
    api_key_config: Option<ApiKeyConfig>,
    request_modifier: Option<M>,
    // Marks the otherwise unused `T` and `E` type parameters.
    _phantom: PhantomData<(T, E)>,
}
274
275impl<Inner, T, S, State, E, V, M> Clone for BarnacleMiddleware<Inner, T, S, State, E, V, M>
276where
277 Inner: Clone,
278 S: Clone + BarnacleStore + 'static,
279 State: Clone + Send + Sync + 'static,
280 V: Clone + Send + Sync + 'static,
281 M: Clone + Send + Sync + 'static,
282{
283 fn clone(&self) -> Self {
284 Self {
285 inner: self.inner.clone(),
286 store: self.store.clone(),
287 config: self.config.clone(),
288 state: self.state.clone(),
289 api_key_validator: self.api_key_validator.clone(),
290 api_key_config: self.api_key_config.clone(),
291 request_modifier: self.request_modifier.clone(),
292 _phantom: PhantomData,
293 }
294 }
295}
296
297pub trait ValidatorCall<T, S, State, E> {
299 fn call(
300 &self,
301 api_key: T,
302 api_key_config: S,
303 parts: Arc<Parts>,
304 state: State,
305 ) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send>>;
306}
307
308impl<F, Fut, T, S, State, E> ValidatorCall<T, S, State, E> for F
310where
311 F: Fn(T, S, Arc<Parts>, State) -> Fut + Send + Sync,
312 Fut: Future<Output = Result<(), E>> + Send + 'static,
313 T: Send + 'static,
314 S: Send + 'static,
315 State: Send + 'static,
316 E: Send + 'static,
317{
318 fn call(
319 &self,
320 api_key: T,
321 api_key_config: S,
322 parts: Arc<Parts>,
323 state: State,
324 ) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send>> {
325 Box::pin((self)(api_key, api_key_config, parts, state))
326 }
327}
328
329impl<T, S, State, E> ValidatorCall<T, S, State, E> for () {
331 fn call(
332 &self,
333 _api_key: T,
334 _api_key_config: S,
335 _parts: Arc<Parts>,
336 _state: State,
337 ) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send>> {
338 Box::pin(async { Ok(()) })
339 }
340}
341
/// An asynchronous hook that may rewrite the request head before rate limiting.
///
/// `P` is the request-head type (the middleware uses `axum::http::request::Parts`). Naming it
/// `P` avoids shadowing that imported type. `State` is the shared state and `E` the error
/// type. Returning `Err(e)` makes the middleware respond with `e.into_response()`.
pub trait RequestModifier<P, State, E> {
    /// Consumes `parts` and returns the (possibly modified) head, or an error.
    fn modify(&self, parts: P, state: State) -> Pin<Box<dyn Future<Output = Result<P, E>> + Send>>;
}
350
351impl<Parts, State, E> RequestModifier<Parts, State, E> for ()
353where
354 Parts: Send + 'static,
355{
356 fn modify(
357 &self,
358 parts: Parts,
359 _state: State,
360 ) -> Pin<Box<dyn Future<Output = Result<Parts, E>> + Send>> {
361 Box::pin(async { Ok(parts) })
362 }
363}
364
365impl<F, Fut, Parts, State, E> RequestModifier<Parts, State, E> for F
367where
368 F: Fn(Parts, State) -> Fut + Send + Sync,
369 Fut: Future<Output = Result<Parts, E>> + Send + 'static,
370 Parts: Send + 'static,
371 State: Send + 'static,
372 E: Send + 'static,
373{
374 fn modify(
375 &self,
376 parts: Parts,
377 state: State,
378 ) -> Pin<Box<dyn Future<Output = Result<Parts, E>> + Send>> {
379 Box::pin((self)(parts, state))
380 }
381}
382
383impl KeyExtractable for () {
385 fn extract_key(&self, request_parts: &Parts) -> BarnacleKey {
386 let extensions = &request_parts.extensions;
388 let headers = &request_parts.headers;
389 let path = request_parts.uri.path();
390 let method = &request_parts.method;
391 get_fallback_key_common(extensions, headers, path, method)
392 }
393}
394
395impl<Inner, B, T, S, State, E, V, M> Service<Request<B>> for BarnacleMiddleware<Inner, T, S, State, E, V, M>
396where
397 Inner: Service<Request<axum::body::Body>, Response = Response<Body>> + Clone + Send + 'static,
398 Inner::Future: Send + 'static,
399 B: axum::body::HttpBody + Send + 'static,
400 B::Data: Send,
401 B::Error: std::error::Error + Send + Sync,
402 S: Clone + BarnacleStore + 'static,
403 State: Clone + Send + Sync + 'static,
404 T: KeyExtractable + DeserializeOwned + Send + 'static,
405 E: IntoResponse + Send + Sync + 'static + From<BarnacleError>,
406 V: ValidatorCall<String, ApiKeyConfig, State, E> + Clone + Send + Sync + 'static,
407 M: RequestModifier<Parts, State, E> + Clone + Send + Sync + 'static,
408{
409 type Response = Inner::Response;
410 type Error = Inner::Error;
411 type Future = std::pin::Pin<
412 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
413 >;
414
415 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
416 self.inner.poll_ready(cx)
417 }
418
419 fn call(&mut self, req: Request<B>) -> Self::Future {
420 debug!("[middleware.rs] Unified BarnacleMiddleware::call invoked");
421 let mut inner = self.inner.clone();
422 let store = self.store.clone();
423 let config = self.config.clone();
424 let state = self.state.clone();
425 let validator_state = state.clone(); let api_key_validator = self.api_key_validator.clone();
427 let api_key_config = self.api_key_config.clone();
428 let request_modifier = self.request_modifier.clone();
429 Box::pin(async move {
430 debug!("[middleware.rs] Entered async block in call");
431 let current_path = req
432 .extensions()
433 .get::<OriginalUri>()
434 .map(|original_url| original_url.path().to_owned())
435 .unwrap_or(req.uri().path().to_owned());
436
437 debug!("[middleware.rs] current_path: {}", current_path);
438 let (parts, body) = req.into_parts();
439 debug!("[middleware.rs] Request parts and body split");
440
441 let mut api_key_used: Option<String> = None;
443 let api_key_config = api_key_config.unwrap_or_default();
444 let api_key = parts.headers.get(api_key_config.header_name.as_str()).and_then(|h| h.to_str().ok()).unwrap_or("");
445 debug!("[middleware.rs] About to call validator with key: '{}'", api_key);
446
447 let validation_result = if let Some(validator) = api_key_validator.as_ref() {
448 let is_stateless_validator = std::any::TypeId::of::<V>() == std::any::TypeId::of::<()>();
449 let is_unit_state = std::any::TypeId::of::<State>() == std::any::TypeId::of::<()>();
450 if is_stateless_validator && is_unit_state {
451 validator.call(api_key.to_string(), api_key_config, Arc::new(parts.clone()), unsafe { std::mem::zeroed() }).await
453 } else {
454 match validator_state {
455 Some(validator_state) => {
456 validator.call(api_key.to_string(), api_key_config, Arc::new(parts.clone()), validator_state).await
457 }
458 None => {
459 Err(E::from(BarnacleError::custom("Barnacle: API key validator requires state, but none was provided. Use with_state() or use () for stateless validators.", None)))
461 }
462 }
463 }
464 } else {
465 Ok(())
466 };
467 match validation_result {
468 Ok(_) => {
469 debug!("[middleware.rs] Validator returned Ok for: '{}'", api_key);
470 if !api_key.is_empty() {
471 api_key_used = Some(api_key.to_string());
472 }
473 },
474 Err(e) => {
475 debug!("[middleware.rs] Validator returned Err");
476 return Ok(e.into_response());
477 }
478 }
479
480 let modified_parts = if let Some(modifier) = request_modifier.as_ref() {
482 let modifier_state = state.clone();
484 if let Some(modifier_state) = modifier_state {
485 modifier.modify(parts, modifier_state).await
486 } else {
487 Err(E::from(BarnacleError::custom("Barnacle: Request modifier requires state, but none was provided.", None)))
488 }
489 } else {
490 Ok(parts)
491 };
492 let parts = match modified_parts {
493 Ok(modified_parts) => {
494 debug!("[middleware.rs] Request modifier returned Ok");
495 modified_parts
496 },
497 Err(e) => {
498 debug!("[middleware.rs] Request modifier returned Err");
499 return Ok(e.into_response());
500 }
501 };
502
503 let (rate_limit_context, body_bytes) = match body.collect().await {
505 Ok(collected) => {
506 let bytes = collected.to_bytes();
507 let (key, used_fallback) = if let Some(ref api_key) = api_key_used {
508 (BarnacleKey::ApiKey(api_key.clone()), false)
510 } else {
511 match serde_json::from_slice::<T>(&bytes) {
512 Ok(payload) => (payload.extract_key(&parts), false),
513 Err(_) => (
514 get_fallback_key_common(
515 &parts.extensions,
516 &parts.headers,
517 ¤t_path,
518 &parts.method,
519 ),
520 true,
521 ),
522 }
523 };
524 let context = BarnacleContext {
525 key,
526 path: current_path.clone(),
527 method: parts.method.as_str().to_string(),
528 };
529 if used_fallback {
530 debug!("[middleware.rs] (unified) Using fallback key for rate limiting");
531 } else if api_key_used.is_some() {
532 debug!("[middleware.rs] (unified) Using API key for rate limiting");
533 } else {
534 debug!("[middleware.rs] (unified) Extracted key from payload for rate limiting");
535 }
536 (context, Some(bytes))
537 }
538 Err(_) => {
539 debug!("[middleware.rs] (unified) Failed to collect body, using fallback key");
540 let fallback_key = get_fallback_key_common(
541 &parts.extensions,
542 &parts.headers,
543 ¤t_path,
544 &parts.method,
545 );
546 let context = BarnacleContext {
547 key: fallback_key,
548 path: current_path.clone(),
549 method: parts.method.as_str().to_string(),
550 };
551 (context, None)
552 }
553 };
554 debug!("[middleware.rs] (unified) About to increment rate limit for context: {:?}", rate_limit_context);
555 tracing::debug!("[middleware.rs] Rate limit increment: api_key={:?}, path={}, method={}", rate_limit_context.key, rate_limit_context.path, rate_limit_context.method);
556 let result = match store.increment(&rate_limit_context, &config).await {
557 Ok(result) => result,
558 Err(e) => {
559 debug!("[middleware.rs] (unified) Rate limit store error: {}", e);
560 return Ok(E::from(e).into_response());
561 }
562 };
563 debug!("[middleware.rs] (unified) Rate limit check passed for key: {:?}, remaining: {}, retry_after: {:?}", rate_limit_context.key, result.remaining, result.retry_after);
564 let reconstructed_body = match body_bytes {
565 Some(bytes) => axum::body::Body::from(bytes),
566 None => axum::body::Body::empty(),
567 };
568 let new_req = Request::from_parts(parts, reconstructed_body);
569 debug!("[middleware.rs] (unified) Calling inner service");
570 let response = inner.call(new_req).await?;
571 let mut response_with_headers = response;
573 {
574 let headers = response_with_headers.headers_mut();
575 if let Ok(remaining_header) = result.remaining.to_string().parse() {
576 headers.insert("X-RateLimit-Remaining", remaining_header);
577 debug!("[middleware.rs] (unified) Added X-RateLimit-Remaining: {}", result.remaining);
578 }
579 if let Ok(limit_header) = config.max_requests.to_string().parse() {
580 headers.insert("X-RateLimit-Limit", limit_header);
581 debug!("[middleware.rs] (unified) Added X-RateLimit-Limit: {}", config.max_requests);
582 }
583 if let Some(retry_after) = result.retry_after {
584 if let Ok(reset_header) = retry_after.as_secs().to_string().parse() {
585 headers.insert("X-RateLimit-Reset", reset_header);
586 debug!("[middleware.rs] (unified) Added X-RateLimit-Reset: {}", retry_after.as_secs());
587 }
588 }
589 }
590 handle_rate_limit_reset(
591 &store,
592 &config,
593 &rate_limit_context,
594 response_with_headers.status().as_u16(),
595 false,
596 )
597 .await;
598 debug!("[middleware.rs] (unified) Returning final response");
599 Ok(response_with_headers)
600 })
601 }
602}