prax_query/middleware/
context.rs

1//! Query context for middleware.
2
3use crate::filter::FilterValue;
4use std::collections::HashMap;
5use std::time::Instant;
6
/// The type of query being executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// SELECT query.
    Select,
    /// INSERT query.
    Insert,
    /// UPDATE query.
    Update,
    /// DELETE query.
    Delete,
    /// COUNT query (a `SELECT` whose projection list starts with `COUNT(`).
    Count,
    /// Raw SQL query.
    Raw,
    /// Transaction begin.
    TransactionBegin,
    /// Transaction commit.
    TransactionCommit,
    /// Transaction rollback.
    TransactionRollback,
    /// Unknown query type.
    Unknown,
}

impl QueryType {
    /// Detect the query type from a SQL string.
    ///
    /// This is a cheap heuristic, not a SQL parser. It looks only at the
    /// leading keyword(s), compared ASCII case-insensitively and on word
    /// boundaries. Leading whitespace and leading `-- …` / `/* … */` comments
    /// are skipped. No allocation is performed.
    ///
    /// * `SELECT COUNT(...) …` (projection list *starting* with `COUNT(`) is
    ///   [`QueryType::Count`]; any other `SELECT` is [`QueryType::Select`].
    /// * `BEGIN` and `START TRANSACTION` are [`QueryType::TransactionBegin`].
    /// * `ROLLBACK` (including `ROLLBACK TO SAVEPOINT …`) is
    ///   [`QueryType::TransactionRollback`].
    /// * Anything else, including `WITH` CTEs, is [`QueryType::Unknown`].
    ///   [`QueryType::Raw`] is never returned by this function.
    pub fn from_sql(sql: &str) -> Self {
        let sql = Self::skip_leading_comments(sql);

        if let Some(projection) = Self::strip_keyword(sql, "SELECT") {
            // Only an aggregate at the head of the projection makes this a COUNT
            // query; `COUNT(` inside a later column, a subquery or a string
            // literal does not.
            let is_count = matches!(
                Self::strip_keyword(projection.trim_start(), "COUNT"),
                Some(rest) if rest.trim_start().starts_with('(')
            );
            return if is_count { Self::Count } else { Self::Select };
        }

        let starts_with = |keyword: &str| Self::strip_keyword(sql, keyword).is_some();

        if starts_with("INSERT") {
            Self::Insert
        } else if starts_with("UPDATE") {
            Self::Update
        } else if starts_with("DELETE") {
            Self::Delete
        } else if starts_with("BEGIN")
            || Self::strip_keyword(sql, "START")
                .and_then(|rest| Self::strip_keyword(rest.trim_start(), "TRANSACTION"))
                .is_some()
        {
            Self::TransactionBegin
        } else if starts_with("COMMIT") {
            Self::TransactionCommit
        } else if starts_with("ROLLBACK") {
            Self::TransactionRollback
        } else {
            Self::Unknown
        }
    }

    /// Check if this is a read operation.
    pub fn is_read(&self) -> bool {
        matches!(self, Self::Select | Self::Count)
    }

    /// Check if this is a write operation.
    pub fn is_write(&self) -> bool {
        matches!(self, Self::Insert | Self::Update | Self::Delete)
    }

    /// Check if this is a transaction operation.
    pub fn is_transaction(&self) -> bool {
        matches!(
            self,
            Self::TransactionBegin | Self::TransactionCommit | Self::TransactionRollback
        )
    }

    /// If `sql` starts with `keyword` (ASCII case-insensitive) followed by a
    /// non-identifier character or end of input, return the remainder after
    /// the keyword; otherwise `None`.
    fn strip_keyword<'a>(sql: &'a str, keyword: &str) -> Option<&'a str> {
        // `get` returns None when `sql` is too short or the cut would split a
        // multi-byte char; in both cases the ASCII keyword cannot match.
        let head = sql.get(..keyword.len())?;
        if !head.eq_ignore_ascii_case(keyword) {
            return None;
        }
        let rest = &sql[keyword.len()..];
        match rest.chars().next() {
            // e.g. "SELECTED" or "COUNTRY": keyword is only a prefix of a longer word.
            Some(c) if c.is_alphanumeric() || c == '_' => None,
            _ => Some(rest),
        }
    }

    /// Skip leading whitespace and any leading `-- …` line comments or
    /// `/* … */` block comments (an unterminated comment consumes the rest).
    fn skip_leading_comments(mut sql: &str) -> &str {
        loop {
            sql = sql.trim_start();
            if let Some(rest) = sql.strip_prefix("--") {
                sql = rest.find('\n').map_or("", |i| &rest[i + 1..]);
            } else if let Some(rest) = sql.strip_prefix("/*") {
                sql = rest.find("*/").map_or("", |i| &rest[i + 2..]);
            } else {
                return sql;
            }
        }
    }
}
78
/// Lifecycle stage a query is in while it travels through the middleware chain.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryPhase {
    /// The query has not been executed yet.
    Before,
    /// The query is currently executing.
    During,
    /// The query finished without an error.
    AfterSuccess,
    /// The query finished with an error.
    AfterError,
}
91
92/// Metadata about a query.
93#[derive(Debug, Clone)]
94pub struct QueryMetadata {
95    /// The model being queried (if known).
96    pub model: Option<String>,
97    /// The operation name (e.g., "findMany", "create").
98    pub operation: Option<String>,
99    /// Request ID for tracing.
100    pub request_id: Option<String>,
101    /// User ID for auditing.
102    pub user_id: Option<String>,
103    /// Tenant ID for multi-tenancy.
104    pub tenant_id: Option<String>,
105    /// Schema override for multi-tenancy.
106    pub schema_override: Option<String>,
107    /// Custom tags for filtering.
108    pub tags: HashMap<String, String>,
109    /// Custom attributes.
110    pub attributes: HashMap<String, serde_json::Value>,
111}
112
113impl Default for QueryMetadata {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl QueryMetadata {
120    /// Create new empty metadata.
121    pub fn new() -> Self {
122        Self {
123            model: None,
124            operation: None,
125            request_id: None,
126            user_id: None,
127            tenant_id: None,
128            schema_override: None,
129            tags: HashMap::new(),
130            attributes: HashMap::new(),
131        }
132    }
133
134    /// Set the model name.
135    pub fn with_model(mut self, model: impl Into<String>) -> Self {
136        self.model = Some(model.into());
137        self
138    }
139
140    /// Set the operation name.
141    pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
142        self.operation = Some(operation.into());
143        self
144    }
145
146    /// Set the request ID.
147    pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
148        self.request_id = Some(id.into());
149        self
150    }
151
152    /// Set the user ID.
153    pub fn with_user_id(mut self, id: impl Into<String>) -> Self {
154        self.user_id = Some(id.into());
155        self
156    }
157
158    /// Set the tenant ID.
159    pub fn with_tenant_id(mut self, id: impl Into<String>) -> Self {
160        self.tenant_id = Some(id.into());
161        self
162    }
163
164    /// Add a tag.
165    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
166        self.tags.insert(key.into(), value.into());
167        self
168    }
169
170    /// Add an attribute.
171    pub fn with_attribute(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
172        self.attributes.insert(key.into(), value);
173        self
174    }
175
176    /// Set the schema override for multi-tenancy.
177    pub fn set_schema_override(&mut self, schema: Option<String>) {
178        self.schema_override = schema;
179    }
180
181    /// Get the schema override.
182    pub fn schema_override(&self) -> Option<&str> {
183        self.schema_override.as_deref()
184    }
185}
186
187/// Context passed through the middleware chain.
188#[derive(Debug, Clone)]
189pub struct QueryContext {
190    /// The SQL query string.
191    sql: String,
192    /// Query parameters.
193    params: Vec<FilterValue>,
194    /// Query type.
195    query_type: QueryType,
196    /// Query metadata.
197    metadata: QueryMetadata,
198    /// When the query started.
199    started_at: Instant,
200    /// Current execution phase.
201    phase: QueryPhase,
202    /// Whether the query should be skipped (e.g., cache hit).
203    skip_execution: bool,
204    /// Cached response (if skipping execution).
205    cached_response: Option<serde_json::Value>,
206}
207
208impl QueryContext {
209    /// Create a new query context.
210    pub fn new(sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
211        let sql = sql.into();
212        let query_type = QueryType::from_sql(&sql);
213        Self {
214            sql,
215            params,
216            query_type,
217            metadata: QueryMetadata::new(),
218            started_at: Instant::now(),
219            phase: QueryPhase::Before,
220            skip_execution: false,
221            cached_response: None,
222        }
223    }
224
225    /// Get the SQL string.
226    pub fn sql(&self) -> &str {
227        &self.sql
228    }
229
230    /// Get mutable SQL string (for query modification).
231    pub fn sql_mut(&mut self) -> &mut String {
232        &mut self.sql
233    }
234
235    /// Set a new SQL string.
236    pub fn set_sql(&mut self, sql: impl Into<String>) {
237        self.sql = sql.into();
238        self.query_type = QueryType::from_sql(&self.sql);
239    }
240
241    /// Set a new SQL string (builder pattern).
242    pub fn with_sql(mut self, sql: impl Into<String>) -> Self {
243        self.set_sql(sql);
244        self
245    }
246
247    /// Get the query parameters.
248    pub fn params(&self) -> &[FilterValue] {
249        &self.params
250    }
251
252    /// Get mutable parameters.
253    pub fn params_mut(&mut self) -> &mut Vec<FilterValue> {
254        &mut self.params
255    }
256
257    /// Get the query type.
258    pub fn query_type(&self) -> QueryType {
259        self.query_type
260    }
261
262    /// Get the metadata.
263    pub fn metadata(&self) -> &QueryMetadata {
264        &self.metadata
265    }
266
267    /// Get mutable metadata.
268    pub fn metadata_mut(&mut self) -> &mut QueryMetadata {
269        &mut self.metadata
270    }
271
272    /// Set metadata.
273    pub fn with_metadata(mut self, metadata: QueryMetadata) -> Self {
274        self.metadata = metadata;
275        self
276    }
277
278    /// Get elapsed time since query started.
279    pub fn elapsed(&self) -> std::time::Duration {
280        self.started_at.elapsed()
281    }
282
283    /// Get elapsed time in microseconds.
284    pub fn elapsed_us(&self) -> u64 {
285        self.started_at.elapsed().as_micros() as u64
286    }
287
288    /// Get the current phase.
289    pub fn phase(&self) -> QueryPhase {
290        self.phase
291    }
292
293    /// Set the current phase.
294    pub fn set_phase(&mut self, phase: QueryPhase) {
295        self.phase = phase;
296    }
297
298    /// Check if execution should be skipped.
299    pub fn should_skip(&self) -> bool {
300        self.skip_execution
301    }
302
303    /// Mark query to skip execution (e.g., for cache hit).
304    pub fn skip_with_response(&mut self, response: serde_json::Value) {
305        self.skip_execution = true;
306        self.cached_response = Some(response);
307    }
308
309    /// Get the cached response if skipping.
310    pub fn cached_response(&self) -> Option<&serde_json::Value> {
311        self.cached_response.as_ref()
312    }
313
314    /// Check if this is a read query.
315    pub fn is_read(&self) -> bool {
316        self.query_type.is_read()
317    }
318
319    /// Check if this is a write query.
320    pub fn is_write(&self) -> bool {
321        self.query_type.is_write()
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_type_detection() {
        let cases = [
            ("SELECT * FROM users", QueryType::Select),
            ("INSERT INTO users VALUES (1)", QueryType::Insert),
            ("UPDATE users SET name = 'test'", QueryType::Update),
            ("DELETE FROM users WHERE id = 1", QueryType::Delete),
            ("SELECT COUNT(*) FROM users", QueryType::Count),
            ("BEGIN", QueryType::TransactionBegin),
            ("START TRANSACTION", QueryType::TransactionBegin),
            ("COMMIT", QueryType::TransactionCommit),
            ("ROLLBACK", QueryType::TransactionRollback),
        ];
        for (sql, expected) in cases {
            assert_eq!(QueryType::from_sql(sql), expected, "sql: {}", sql);
        }
    }

    #[test]
    fn test_query_type_detection_ignores_case_and_leading_whitespace() {
        assert_eq!(QueryType::from_sql("select * from users"), QueryType::Select);
        assert_eq!(
            QueryType::from_sql("  \n\tinsert into users values (1)"),
            QueryType::Insert
        );
        assert_eq!(QueryType::from_sql("EXPLAIN SELECT 1"), QueryType::Unknown);
        assert_eq!(QueryType::from_sql(""), QueryType::Unknown);
    }

    #[test]
    fn test_query_type_categories() {
        assert!(QueryType::Select.is_read());
        assert!(QueryType::Count.is_read());
        assert!(!QueryType::Insert.is_read());

        assert!(QueryType::Insert.is_write());
        assert!(QueryType::Update.is_write());
        assert!(QueryType::Delete.is_write());
        assert!(!QueryType::Select.is_write());

        assert!(QueryType::TransactionBegin.is_transaction());
        assert!(QueryType::TransactionCommit.is_transaction());
        assert!(QueryType::TransactionRollback.is_transaction());
    }

    #[test]
    fn test_query_context() {
        let ctx = QueryContext::new("SELECT * FROM users", vec![]);
        assert_eq!(ctx.sql(), "SELECT * FROM users");
        assert_eq!(ctx.query_type(), QueryType::Select);
        assert_eq!(ctx.phase(), QueryPhase::Before);
        assert!(ctx.is_read());
        assert!(!ctx.is_write());
    }

    #[test]
    fn test_set_sql_redetects_query_type() {
        let mut ctx = QueryContext::new("SELECT * FROM users", vec![]);
        ctx.set_sql("DELETE FROM users WHERE id = 1");
        assert_eq!(ctx.query_type(), QueryType::Delete);
        assert!(ctx.is_write());
        assert!(!ctx.is_read());

        let ctx = ctx.with_sql("COMMIT");
        assert_eq!(ctx.sql(), "COMMIT");
        assert_eq!(ctx.query_type(), QueryType::TransactionCommit);
        assert!(!ctx.is_read() && !ctx.is_write());
    }

    #[test]
    fn test_context_phase_transitions() {
        let mut ctx = QueryContext::new("SELECT 1", vec![]);
        ctx.set_phase(QueryPhase::During);
        assert_eq!(ctx.phase(), QueryPhase::During);
        ctx.set_phase(QueryPhase::AfterSuccess);
        assert_eq!(ctx.phase(), QueryPhase::AfterSuccess);
    }

    #[test]
    fn test_query_metadata() {
        let metadata = QueryMetadata::new()
            .with_model("User")
            .with_operation("findMany")
            .with_request_id("req-123")
            .with_tag("env", "production");

        assert_eq!(metadata.model, Some("User".to_string()));
        assert_eq!(metadata.operation, Some("findMany".to_string()));
        assert_eq!(metadata.request_id.as_deref(), Some("req-123"));
        assert_eq!(metadata.tags.get("env"), Some(&"production".to_string()));
    }

    #[test]
    fn test_metadata_schema_override() {
        let mut metadata = QueryMetadata::new().with_tenant_id("t-1");
        assert_eq!(metadata.schema_override(), None);

        metadata.set_schema_override(Some("tenant_1".to_string()));
        assert_eq!(metadata.schema_override(), Some("tenant_1"));

        metadata.set_schema_override(None);
        assert_eq!(metadata.schema_override(), None);
        assert_eq!(metadata.tenant_id.as_deref(), Some("t-1"));
    }

    #[test]
    fn test_context_skip_execution() {
        let mut ctx = QueryContext::new("SELECT * FROM users", vec![]);
        assert!(!ctx.should_skip());
        assert!(ctx.cached_response().is_none());

        ctx.skip_with_response(serde_json::json!({"cached": true}));
        assert!(ctx.should_skip());
        assert_eq!(
            ctx.cached_response(),
            Some(&serde_json::json!({"cached": true}))
        );
    }
}