reifydb_sub_server/interceptor.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4//! Request-level interceptors for pre/post query execution hooks.
5//!
6//! This module provides an async interceptor mechanism that allows consumers
7//! to hook into the request lifecycle — before and after query execution.
8//! Interceptors can reject requests (for auth, rate limiting, credit checks)
9//! or observe results (for logging, billing, usage tracking).
10//!
11//! # Example
12//!
13//! ```ignore
14//! use reifydb::server;
15//!
16//! struct MyInterceptor;
17//!
18//! impl RequestInterceptor for MyInterceptor {
//!     fn pre_execute<'a>(&'a self, ctx: &'a mut RequestContext)
//!         -> Pin<Box<dyn Future<Output = Result<(), ExecuteError>> + Send + 'a>>
21//! {
22//! Box::pin(async move {
23//! if ctx.metadata.get("authorization").is_none() {
24//! return Err(ExecuteError::Rejected {
25//! code: "AUTH_REQUIRED".into(),
26//! message: "Missing API key".into(),
27//! });
28//! }
29//! Ok(())
30//! })
31//! }
32//!
//!     fn post_execute<'a>(&'a self, ctx: &'a ResponseContext)
//!         -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
35//! {
36//! Box::pin(async move {
37//! tracing::info!(duration_ms = ctx.duration.as_millis(), "query executed");
38//! })
39//! }
40//! }
41//!
42//! let db = server::memory()
43//! .with_request_interceptor(MyInterceptor)
44//! .build()?;
45//! ```
46
47use std::{collections::HashMap, future::Future, panic::AssertUnwindSafe, pin::Pin, sync::Arc, time::Duration};
48
49use futures_util::FutureExt;
50use reifydb_type::{params::Params, value::identity::IdentityId};
51use tracing::error;
52
53use crate::execute::ExecuteError;
54
/// Kind of database operation a request asks the server to perform.
///
/// The value is carried on both the request and the response context, so
/// interceptors can branch on it before and after execution.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Operation {
    /// The request is a query.
    Query,
    /// The request is a command.
    Command,
    /// The request is an admin operation.
    Admin,
    /// The request is a subscription.
    Subscribe,
}
63
/// The transport protocol a request arrived over.
///
/// `Http` is the default. The type derives `Hash` (like [`Operation`]) so it
/// can be used as a `HashMap`/`HashSet` key, e.g. for per-protocol counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Protocol {
    /// Plain HTTP request (the default).
    #[default]
    Http,
    /// Request received over a WebSocket connection.
    WebSocket,
    /// Request received over gRPC.
    Grpc,
}
72
73/// Protocol-agnostic metadata extracted from the request transport layer.
74///
75/// HTTP headers, gRPC metadata, and WS auth tokens are all normalized into
76/// a string-keyed map. Header names are lowercased for consistent lookup.
77///
78/// Note: this is a single-value map — duplicate keys are overwritten
79/// (last-write-wins). Multi-valued headers (e.g. `Set-Cookie`) only
80/// retain the last value. This is intentional for simplicity; most
81/// interceptor use cases only need single-valued lookups.
82#[derive(Debug, Clone, Default)]
83pub struct RequestMetadata {
84 headers: HashMap<String, String>,
85 protocol: Protocol,
86}
87
88impl RequestMetadata {
89 /// Create empty metadata for the given protocol.
90 pub fn new(protocol: Protocol) -> Self {
91 Self {
92 headers: HashMap::new(),
93 protocol,
94 }
95 }
96
97 /// Insert a header (key is lowercased). Duplicate keys are overwritten.
98 pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) {
99 self.headers.insert(key.into().to_ascii_lowercase(), value.into());
100 }
101
102 /// Get a header value by name (case-insensitive).
103 pub fn get(&self, key: &str) -> Option<&str> {
104 self.headers.get(&key.to_ascii_lowercase()).map(|s| s.as_str())
105 }
106
107 /// Get the protocol.
108 pub fn protocol(&self) -> Protocol {
109 self.protocol
110 }
111
112 /// Get all headers.
113 pub fn headers(&self) -> &HashMap<String, String> {
114 &self.headers
115 }
116}
117
118/// Context available to pre-execute interceptors.
119///
120/// Fields are public and mutable so interceptors can override values
121/// (e.g., resolve API key → set identity, store key_id in metadata for post_execute).
122pub struct RequestContext {
123 /// The resolved identity. Pre-execute interceptors may replace this.
124 pub identity: IdentityId,
125 /// The operation type.
126 pub operation: Operation,
127 /// The RQL statements being executed.
128 pub statements: Vec<String>,
129 /// Query parameters.
130 pub params: Params,
131 /// Protocol-agnostic request metadata (headers, etc.).
132 pub metadata: RequestMetadata,
133}
134
135/// Context available to post-execute interceptors.
136pub struct ResponseContext {
137 /// The identity that executed the request (may have been mutated by pre_execute).
138 pub identity: IdentityId,
139 /// The operation type.
140 pub operation: Operation,
141 /// The RQL statements that were executed.
142 pub statements: Vec<String>,
143 /// Query parameters.
144 pub params: Params,
145 /// Protocol-agnostic request metadata.
146 pub metadata: RequestMetadata,
147 /// Execution result: Ok(frame_count) or Err with the error message.
148 pub result: Result<usize, String>,
149 /// Wall-clock execution duration.
150 pub duration: Duration,
151 /// Compute-only duration (excludes queue wait and scheduling overhead).
152 pub compute_duration: Duration,
153}
154
155/// Async trait for request-level interceptors.
156///
157/// Interceptors run in the tokio async context (before compute pool dispatch),
158/// so they can perform async I/O (database lookups, network calls, etc.).
159///
160/// Multiple interceptors are chained: `pre_execute` runs in registration order,
161/// `post_execute` runs in reverse order (like middleware stacks).
162pub trait RequestInterceptor: Send + Sync + 'static {
163 /// Called before query execution.
164 ///
165 /// Return `Ok(())` to allow the request to proceed.
166 /// Return `Err(ExecuteError)` to reject the request.
167 /// May mutate the context (e.g., set identity from API key lookup).
168 fn pre_execute<'a>(
169 &'a self,
170 ctx: &'a mut RequestContext,
171 ) -> Pin<Box<dyn Future<Output = Result<(), ExecuteError>> + Send + 'a>>;
172
173 /// Called after query execution completes (success or failure).
174 ///
175 /// This is called even if the execution failed, so interceptors can
176 /// log failures and track usage regardless of outcome.
177 fn post_execute<'a>(&'a self, ctx: &'a ResponseContext) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
178}
179
180/// Ordered chain of request interceptors, cheap to clone (Arc internally).
181#[derive(Clone)]
182pub struct RequestInterceptorChain {
183 interceptors: Arc<Vec<Arc<dyn RequestInterceptor>>>,
184}
185
186impl RequestInterceptorChain {
187 pub fn new(interceptors: Vec<Arc<dyn RequestInterceptor>>) -> Self {
188 Self {
189 interceptors: Arc::new(interceptors),
190 }
191 }
192
193 pub fn empty() -> Self {
194 Self {
195 interceptors: Arc::new(Vec::new()),
196 }
197 }
198
199 pub fn is_empty(&self) -> bool {
200 self.interceptors.is_empty()
201 }
202
203 /// Run all pre_execute interceptors in order.
204 /// Stops and returns Err on first rejection.
205 pub async fn pre_execute(&self, ctx: &mut RequestContext) -> Result<(), ExecuteError> {
206 for interceptor in self.interceptors.iter() {
207 interceptor.pre_execute(ctx).await?;
208 }
209 Ok(())
210 }
211
212 /// Run all post_execute interceptors in reverse order.
213 ///
214 /// If an interceptor panics, the panic is caught and logged so that
215 /// remaining interceptors still run.
216 pub async fn post_execute(&self, ctx: &ResponseContext) {
217 for interceptor in self.interceptors.iter().rev() {
218 if let Err(panic) = AssertUnwindSafe(interceptor.post_execute(ctx)).catch_unwind().await {
219 let msg = panic
220 .downcast_ref::<&str>()
221 .copied()
222 .or_else(|| panic.downcast_ref::<String>().map(|s| s.as_str()))
223 .unwrap_or("unknown panic");
224 error!("post_execute interceptor panicked: {}", msg);
225 }
226 }
227 }
228}
229
230impl Default for RequestInterceptorChain {
231 fn default() -> Self {
232 Self::empty()
233 }
234}