Skip to main content

reifydb_sub_server/
interceptor.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4//! Request-level interceptors for pre/post query execution hooks.
5//!
6//! This module provides an async interceptor mechanism that allows consumers
7//! to hook into the request lifecycle — before and after query execution.
8//! Interceptors can reject requests (for auth, rate limiting, credit checks)
9//! or observe results (for logging, billing, usage tracking).
10//!
11//! # Example
12//!
13//! ```ignore
14//! use reifydb::server;
15//!
16//! struct MyInterceptor;
17//!
18//! impl RequestInterceptor for MyInterceptor {
//!     // `'a` must cover both `self` and `ctx`, because the returned future captures `ctx`.
//!     fn pre_execute<'a>(&'a self, ctx: &'a mut RequestContext)
//!         -> Pin<Box<dyn Future<Output = Result<(), ExecuteError>> + Send + 'a>>
21//!     {
22//!         Box::pin(async move {
23//!             if ctx.metadata.get("authorization").is_none() {
24//!                 return Err(ExecuteError::Rejected {
25//!                     code: "AUTH_REQUIRED".into(),
26//!                     message: "Missing API key".into(),
27//!                 });
28//!             }
29//!             Ok(())
30//!         })
31//!     }
32//!
//!     fn post_execute<'a>(&'a self, ctx: &'a ResponseContext)
//!         -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
35//!     {
36//!         Box::pin(async move {
37//!             tracing::info!("query executed: {:?}", ctx.total);
38//!         })
39//!     }
40//! }
41//!
42//! let db = server::memory()
43//!     .with_request_interceptor(MyInterceptor)
44//!     .build()?;
45//! ```
46
47use std::{collections::HashMap, future::Future, panic::AssertUnwindSafe, pin::Pin, sync::Arc};
48
49use futures_util::FutureExt;
50use reifydb_core::{actors::server::Operation, metric::ExecutionMetrics};
51use reifydb_type::{
52	params::Params,
53	value::{duration::Duration, identity::IdentityId},
54};
55use tracing::error;
56
57use crate::execute::ExecuteError;
58
/// Transport over which a request reached the server.
///
/// [`Protocol::Http`] is the default variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Protocol {
	/// HTTP transport.
	#[default]
	Http,
	/// WebSocket transport.
	WebSocket,
	/// gRPC transport.
	Grpc,
}
67
68/// Protocol-agnostic metadata extracted from the request transport layer.
69///
70/// HTTP headers, gRPC metadata, and WS auth tokens are all normalized into
71/// a string-keyed map. Header names are lowercased for consistent lookup.
72///
73/// Note: this is a single-value map — duplicate keys are overwritten
74/// (last-write-wins). Multi-valued headers (e.g. `Set-Cookie`) only
75/// retain the last value. This is intentional for simplicity; most
76/// interceptor use cases only need single-valued lookups.
77#[derive(Debug, Clone, Default)]
78pub struct RequestMetadata {
79	headers: HashMap<String, String>,
80	protocol: Protocol,
81}
82
83impl RequestMetadata {
84	/// Create empty metadata for the given protocol.
85	pub fn new(protocol: Protocol) -> Self {
86		Self {
87			headers: HashMap::new(),
88			protocol,
89		}
90	}
91
92	/// Insert a header (key is lowercased). Duplicate keys are overwritten.
93	pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) {
94		self.headers.insert(key.into().to_ascii_lowercase(), value.into());
95	}
96
97	/// Get a header value by name (case-insensitive).
98	pub fn get(&self, key: &str) -> Option<&str> {
99		self.headers.get(&key.to_ascii_lowercase()).map(|s| s.as_str())
100	}
101
102	/// Get the protocol.
103	pub fn protocol(&self) -> Protocol {
104		self.protocol
105	}
106
107	/// Get all headers.
108	pub fn headers(&self) -> &HashMap<String, String> {
109		&self.headers
110	}
111}
112
/// Context passed by mutable reference to [`RequestInterceptor::pre_execute`].
///
/// Fields are public and mutable so interceptors can override values
/// (e.g., resolve API key → set identity, store key_id in metadata for post_execute).
///
/// When interceptors are chained via [`RequestInterceptorChain::pre_execute`],
/// the same context is handed to each interceptor in turn. Each one therefore
/// sees any changes made by the interceptors that ran before it.
pub struct RequestContext {
	/// The resolved identity. Pre-execute interceptors may replace this.
	pub identity: IdentityId,
	/// The operation type.
	pub operation: Operation,
	/// The RQL string being executed.
	pub rql: String,
	/// Query parameters supplied with the request.
	pub params: Params,
	/// Protocol-agnostic request metadata (headers, etc.); see [`RequestMetadata`].
	pub metadata: RequestMetadata,
}
129
/// Context passed to [`RequestInterceptor::post_execute`].
///
/// Interceptors receive it by shared reference (`&ResponseContext`). The same
/// reference is passed to every interceptor in the chain.
pub struct ResponseContext {
	/// The identity that executed the request (may have been mutated by pre_execute).
	pub identity: IdentityId,
	/// The operation type.
	pub operation: Operation,
	/// The RQL string that was executed.
	pub rql: String,
	/// Rich metrics for each statement in the request.
	pub metrics: ExecutionMetrics,
	/// Query parameters.
	pub params: Params,
	/// Protocol-agnostic request metadata; see [`RequestMetadata`].
	pub metadata: RequestMetadata,
	/// Execution result: `Ok(frame_count)` on success, or `Err` carrying the
	/// error message as a string.
	pub result: Result<usize, String>,
	/// Wall-clock execution duration.
	pub total: Duration,
	/// Compute-only duration (excludes queue wait and scheduling overhead).
	pub compute: Duration,
}
151
/// Async trait for request-level interceptors.
///
/// Interceptors run in the tokio async context (before compute pool dispatch),
/// so they can perform async I/O (database lookups, network calls, etc.).
///
/// Multiple interceptors are chained: `pre_execute` runs in registration order,
/// `post_execute` runs in reverse order (like middleware stacks).
///
/// Methods return boxed futures rather than being `async fn`. This keeps the
/// trait usable as `dyn RequestInterceptor`, which is how
/// [`RequestInterceptorChain`] stores it (`Arc<dyn RequestInterceptor>`).
/// The shared lifetime `'a` ties the returned future to both `self` and `ctx`,
/// so the future may borrow either.
pub trait RequestInterceptor: Send + Sync + 'static {
	/// Called before query execution.
	///
	/// Return `Ok(())` to allow the request to proceed.
	/// Return `Err(ExecuteError)` to reject the request.
	/// May mutate the context (e.g., set identity from API key lookup).
	///
	/// When run through [`RequestInterceptorChain::pre_execute`], the first `Err`
	/// stops the chain, and interceptors registered later are not called. Panics
	/// are not caught by the chain here.
	fn pre_execute<'a>(
		&'a self,
		ctx: &'a mut RequestContext,
	) -> Pin<Box<dyn Future<Output = Result<(), ExecuteError>> + Send + 'a>>;

	/// Called after query execution completes (success or failure).
	///
	/// This is called even if the execution failed, so interceptors can
	/// log failures and track usage regardless of outcome.
	///
	/// When run through [`RequestInterceptorChain::post_execute`], a panic raised
	/// while the returned future is polled is caught and logged, and the
	/// remaining interceptors still run.
	fn post_execute<'a>(&'a self, ctx: &'a ResponseContext) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
}
176
177/// Ordered chain of request interceptors, cheap to clone (Arc internally).
178#[derive(Clone)]
179pub struct RequestInterceptorChain {
180	interceptors: Arc<Vec<Arc<dyn RequestInterceptor>>>,
181}
182
183impl RequestInterceptorChain {
184	pub fn new(interceptors: Vec<Arc<dyn RequestInterceptor>>) -> Self {
185		Self {
186			interceptors: Arc::new(interceptors),
187		}
188	}
189
190	pub fn empty() -> Self {
191		Self {
192			interceptors: Arc::new(Vec::new()),
193		}
194	}
195
196	pub fn is_empty(&self) -> bool {
197		self.interceptors.is_empty()
198	}
199
200	/// Run all pre_execute interceptors in order.
201	/// Stops and returns Err on first rejection.
202	pub async fn pre_execute(&self, ctx: &mut RequestContext) -> Result<(), ExecuteError> {
203		for interceptor in self.interceptors.iter() {
204			interceptor.pre_execute(ctx).await?;
205		}
206		Ok(())
207	}
208
209	/// Run all post_execute interceptors in reverse order.
210	///
211	/// If an interceptor panics, the panic is caught and logged so that
212	/// remaining interceptors still run.
213	pub async fn post_execute(&self, ctx: &ResponseContext) {
214		for interceptor in self.interceptors.iter().rev() {
215			if let Err(panic) = AssertUnwindSafe(interceptor.post_execute(ctx)).catch_unwind().await {
216				let msg = panic
217					.downcast_ref::<&str>()
218					.copied()
219					.or_else(|| panic.downcast_ref::<String>().map(|s| s.as_str()))
220					.unwrap_or("unknown panic");
221				error!("post_execute interceptor panicked: {}", msg);
222			}
223		}
224	}
225}
226
227impl Default for RequestInterceptorChain {
228	fn default() -> Self {
229		Self::empty()
230	}
231}