Skip to main content

reifydb_sub_server/
interceptor.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4//! Request-level interceptors for pre/post query execution hooks.
5//!
6//! This module provides an async interceptor mechanism that allows consumers
7//! to hook into the request lifecycle — before and after query execution.
8//! Interceptors can reject requests (for auth, rate limiting, credit checks)
9//! or observe results (for logging, billing, usage tracking).
10//!
11//! # Example
12//!
13//! ```ignore
//! use std::{future::Future, pin::Pin};
//!
//! use reifydb::server;
//! // `RequestInterceptor`, `RequestContext`, `ResponseContext` and `ExecuteError`
//! // must also be in scope (they are defined in `reifydb_sub_server`).
//!
//! struct MyInterceptor;
//!
//! impl RequestInterceptor for MyInterceptor {
//!     // Both `self` and `ctx` share the lifetime `'a`, exactly as in the trait,
//!     // so the returned future may borrow `ctx`.
//!     fn pre_execute<'a>(&'a self, ctx: &'a mut RequestContext)
//!         -> Pin<Box<dyn Future<Output = Result<(), ExecuteError>> + Send + 'a>>
//!     {
//!         Box::pin(async move {
//!             if ctx.metadata.get("authorization").is_none() {
//!                 return Err(ExecuteError::Rejected {
//!                     code: "AUTH_REQUIRED".into(),
//!                     message: "Missing API key".into(),
//!                 });
//!             }
//!             Ok(())
//!         })
//!     }
//!
//!     fn post_execute<'a>(&'a self, ctx: &'a ResponseContext)
//!         -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
//!     {
//!         Box::pin(async move {
//!             tracing::info!(duration_ms = ctx.duration.as_millis(), "query executed");
//!         })
//!     }
40//! }
41//!
42//! let db = server::memory()
43//!     .with_request_interceptor(MyInterceptor)
44//!     .build()?;
45//! ```
46
47use std::{collections::HashMap, future::Future, panic::AssertUnwindSafe, pin::Pin, sync::Arc, time::Duration};
48
49use futures_util::FutureExt;
50use reifydb_type::{params::Params, value::identity::IdentityId};
51use tracing::error;
52
53use crate::execute::ExecuteError;
54
/// Category of database work that a request performs.
///
/// Both the pre- and post-execute contexts carry this value, so interceptors
/// can branch on it (for example, apply a policy to one kind of operation only).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Operation {
	Query,
	Command,
	Admin,
	Subscribe,
}
63
/// The transport protocol used for the request.
///
/// Derives `Hash` (like [`Operation`]) so it can key maps and sets, e.g.
/// per-protocol counters kept by a metrics interceptor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Protocol {
	/// HTTP. This is the value returned by `Protocol::default()`.
	#[default]
	Http,
	/// WebSocket.
	WebSocket,
	/// gRPC.
	Grpc,
}
72
73/// Protocol-agnostic metadata extracted from the request transport layer.
74///
75/// HTTP headers, gRPC metadata, and WS auth tokens are all normalized into
76/// a string-keyed map. Header names are lowercased for consistent lookup.
77///
78/// Note: this is a single-value map — duplicate keys are overwritten
79/// (last-write-wins). Multi-valued headers (e.g. `Set-Cookie`) only
80/// retain the last value. This is intentional for simplicity; most
81/// interceptor use cases only need single-valued lookups.
82#[derive(Debug, Clone, Default)]
83pub struct RequestMetadata {
84	headers: HashMap<String, String>,
85	protocol: Protocol,
86}
87
88impl RequestMetadata {
89	/// Create empty metadata for the given protocol.
90	pub fn new(protocol: Protocol) -> Self {
91		Self {
92			headers: HashMap::new(),
93			protocol,
94		}
95	}
96
97	/// Insert a header (key is lowercased). Duplicate keys are overwritten.
98	pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) {
99		self.headers.insert(key.into().to_ascii_lowercase(), value.into());
100	}
101
102	/// Get a header value by name (case-insensitive).
103	pub fn get(&self, key: &str) -> Option<&str> {
104		self.headers.get(&key.to_ascii_lowercase()).map(|s| s.as_str())
105	}
106
107	/// Get the protocol.
108	pub fn protocol(&self) -> Protocol {
109		self.protocol
110	}
111
112	/// Get all headers.
113	pub fn headers(&self) -> &HashMap<String, String> {
114		&self.headers
115	}
116}
117
/// Context handed to [`RequestInterceptor::pre_execute`] before a request runs.
///
/// Every field is `pub`, and `pre_execute` receives `&mut RequestContext`, so
/// interceptors can overwrite values in place. For example, an interceptor can
/// resolve an API key to an identity, or store a key id in `metadata` so that
/// `post_execute` can read it later.
///
/// NOTE(review): whether rewritten `statements`/`params` are what actually
/// executes depends on the code that builds and consumes this context — confirm
/// at the call site.
pub struct RequestContext {
	/// Identity the request runs as. Pre-execute interceptors may replace it.
	pub identity: IdentityId,
	/// Kind of operation being performed.
	pub operation: Operation,
	/// RQL statements to execute.
	pub statements: Vec<String>,
	/// Query parameters.
	pub params: Params,
	/// Transport metadata (headers, etc.), normalized across protocols.
	pub metadata: RequestMetadata,
}
134
/// Context handed to [`RequestInterceptor::post_execute`] once execution has finished.
///
/// `post_execute` receives `&ResponseContext`, so this is read-only for
/// interceptors. Useful for logging, billing and usage tracking.
pub struct ResponseContext {
	/// The identity that executed the request (may have been mutated by pre_execute).
	pub identity: IdentityId,
	/// Kind of operation that was performed.
	pub operation: Operation,
	/// The RQL statements that were executed.
	pub statements: Vec<String>,
	/// Query parameters.
	pub params: Params,
	/// Transport metadata (headers, etc.), normalized across protocols.
	pub metadata: RequestMetadata,
	/// Execution result: `Ok(frame_count)` on success, `Err(message)` on failure.
	pub result: Result<usize, String>,
	/// Wall-clock execution duration.
	pub duration: Duration,
	/// Compute-only duration (excludes queue wait and scheduling overhead).
	pub compute_duration: Duration,
}
154
155/// Async trait for request-level interceptors.
156///
157/// Interceptors run in the tokio async context (before compute pool dispatch),
158/// so they can perform async I/O (database lookups, network calls, etc.).
159///
160/// Multiple interceptors are chained: `pre_execute` runs in registration order,
161/// `post_execute` runs in reverse order (like middleware stacks).
162pub trait RequestInterceptor: Send + Sync + 'static {
163	/// Called before query execution.
164	///
165	/// Return `Ok(())` to allow the request to proceed.
166	/// Return `Err(ExecuteError)` to reject the request.
167	/// May mutate the context (e.g., set identity from API key lookup).
168	fn pre_execute<'a>(
169		&'a self,
170		ctx: &'a mut RequestContext,
171	) -> Pin<Box<dyn Future<Output = Result<(), ExecuteError>> + Send + 'a>>;
172
173	/// Called after query execution completes (success or failure).
174	///
175	/// This is called even if the execution failed, so interceptors can
176	/// log failures and track usage regardless of outcome.
177	fn post_execute<'a>(&'a self, ctx: &'a ResponseContext) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
178}
179
180/// Ordered chain of request interceptors, cheap to clone (Arc internally).
181#[derive(Clone)]
182pub struct RequestInterceptorChain {
183	interceptors: Arc<Vec<Arc<dyn RequestInterceptor>>>,
184}
185
186impl RequestInterceptorChain {
187	pub fn new(interceptors: Vec<Arc<dyn RequestInterceptor>>) -> Self {
188		Self {
189			interceptors: Arc::new(interceptors),
190		}
191	}
192
193	pub fn empty() -> Self {
194		Self {
195			interceptors: Arc::new(Vec::new()),
196		}
197	}
198
199	pub fn is_empty(&self) -> bool {
200		self.interceptors.is_empty()
201	}
202
203	/// Run all pre_execute interceptors in order.
204	/// Stops and returns Err on first rejection.
205	pub async fn pre_execute(&self, ctx: &mut RequestContext) -> Result<(), ExecuteError> {
206		for interceptor in self.interceptors.iter() {
207			interceptor.pre_execute(ctx).await?;
208		}
209		Ok(())
210	}
211
212	/// Run all post_execute interceptors in reverse order.
213	///
214	/// If an interceptor panics, the panic is caught and logged so that
215	/// remaining interceptors still run.
216	pub async fn post_execute(&self, ctx: &ResponseContext) {
217		for interceptor in self.interceptors.iter().rev() {
218			if let Err(panic) = AssertUnwindSafe(interceptor.post_execute(ctx)).catch_unwind().await {
219				let msg = panic
220					.downcast_ref::<&str>()
221					.copied()
222					.or_else(|| panic.downcast_ref::<String>().map(|s| s.as_str()))
223					.unwrap_or("unknown panic");
224				error!("post_execute interceptor panicked: {}", msg);
225			}
226		}
227	}
228}
229
230impl Default for RequestInterceptorChain {
231	fn default() -> Self {
232		Self::empty()
233	}
234}