1use crate::filter::FilterValue;
4use std::collections::HashMap;
5use std::time::Instant;
6
/// Coarse classification of a SQL statement, derived from its leading keyword by
/// [`QueryType::from_sql`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// A `SELECT` statement that does not call `COUNT`.
    Select,
    /// An `INSERT` statement.
    Insert,
    /// An `UPDATE` statement.
    Update,
    /// A `DELETE` statement.
    Delete,
    /// A `SELECT` statement whose text calls `COUNT(...)`.
    Count,
    /// Raw SQL. Never returned by [`QueryType::from_sql`]; presumably assigned
    /// explicitly by callers — TODO confirm.
    Raw,
    /// `BEGIN` or `START TRANSACTION`.
    TransactionBegin,
    /// `COMMIT`.
    TransactionCommit,
    /// `ROLLBACK`. Only the prefix is inspected, so `ROLLBACK TO SAVEPOINT ...`
    /// lands here as well.
    TransactionRollback,
    /// A statement whose leading keyword is not recognised (including empty input).
    Unknown,
}

impl QueryType {
    /// Classifies `sql` by its leading keyword, compared ASCII case-insensitively.
    ///
    /// Leading whitespace and leading SQL comments (`-- ...` up to the end of the
    /// line, and `/* ... */`) are skipped first, so hinted or annotated queries
    /// classify the same as their bare form.
    ///
    /// A `SELECT` is reported as [`QueryType::Count`] when its text contains a call
    /// to `COUNT`: the word `COUNT` not preceded by an identifier character and
    /// followed, optionally after whitespace, by `(`. This is a textual heuristic —
    /// SQL is not parsed, so a `COUNT(` inside a subquery or string literal still
    /// matches — but identifiers that merely end in `count` (e.g. `discount(...)`)
    /// do not.
    pub fn from_sql(sql: &str) -> Self {
        let sql = Self::skip_leading_trivia(sql);
        let has_prefix = |keyword: &str| Self::starts_with_ignore_ascii_case(sql, keyword);

        if has_prefix("SELECT") {
            if Self::contains_count_call(sql) {
                Self::Count
            } else {
                Self::Select
            }
        } else if has_prefix("INSERT") {
            Self::Insert
        } else if has_prefix("UPDATE") {
            Self::Update
        } else if has_prefix("DELETE") {
            Self::Delete
        } else if has_prefix("BEGIN") || has_prefix("START TRANSACTION") {
            Self::TransactionBegin
        } else if has_prefix("COMMIT") {
            Self::TransactionCommit
        } else if has_prefix("ROLLBACK") {
            Self::TransactionRollback
        } else {
            Self::Unknown
        }
    }

    /// Returns `true` for statements that only read data (`Select`, `Count`).
    pub fn is_read(&self) -> bool {
        matches!(self, Self::Select | Self::Count)
    }

    /// Returns `true` for data-modifying statements (`Insert`, `Update`, `Delete`).
    pub fn is_write(&self) -> bool {
        matches!(self, Self::Insert | Self::Update | Self::Delete)
    }

    /// Returns `true` for transaction-control statements.
    pub fn is_transaction(&self) -> bool {
        matches!(
            self,
            Self::TransactionBegin | Self::TransactionCommit | Self::TransactionRollback
        )
    }

    /// Returns `sql` with leading whitespace and leading `--` / `/* */` comments
    /// removed. An unterminated comment consumes the rest of the input.
    fn skip_leading_trivia(mut sql: &str) -> &str {
        loop {
            sql = sql.trim_start();
            if let Some(rest) = sql.strip_prefix("--") {
                // Line comment: drop everything through the next newline.
                sql = rest.find('\n').map_or("", |end| &rest[end + 1..]);
            } else if let Some(rest) = sql.strip_prefix("/*") {
                // Block comment: drop everything through the first `*/`.
                sql = rest.find("*/").map_or("", |end| &rest[end + 2..]);
            } else {
                return sql;
            }
        }
    }

    /// ASCII case-insensitive `str::starts_with`; `prefix` is expected to be ASCII.
    fn starts_with_ignore_ascii_case(haystack: &str, prefix: &str) -> bool {
        match haystack.as_bytes().get(..prefix.len()) {
            Some(head) => head.eq_ignore_ascii_case(prefix.as_bytes()),
            None => false,
        }
    }

    /// Whether `sql` contains a call to `COUNT`: the word `COUNT` (any ASCII case)
    /// not preceded by an identifier byte and followed, after optional whitespace,
    /// by `(`.
    fn contains_count_call(sql: &str) -> bool {
        const WORD: &[u8] = b"COUNT";
        let bytes = sql.as_bytes();
        // Non-ASCII bytes count as identifier bytes so that `count` embedded in a
        // Unicode identifier is not mistaken for the aggregate function.
        let is_ident_byte =
            |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || !b.is_ascii();

        bytes
            .windows(WORD.len())
            .enumerate()
            .any(|(start, window)| {
                if !window.eq_ignore_ascii_case(WORD) {
                    return false;
                }
                if start > 0 && is_ident_byte(bytes[start - 1]) {
                    return false;
                }
                let next = bytes[start + WORD.len()..]
                    .iter()
                    .find(|b| !b.is_ascii_whitespace());
                next == Some(&b'(')
            })
    }
}
78
/// Lifecycle phase of a query context.
///
/// `Hash` is derived (matching `QueryType`) so phases can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryPhase {
    /// Initial phase of a newly created query context.
    Before,
    /// Presumably the phase while the query is executing; set by callers via
    /// `set_phase` — TODO confirm.
    During,
    /// Presumably set after the query completed successfully — TODO confirm.
    AfterSuccess,
    /// Presumably set after the query failed — TODO confirm.
    AfterError,
}
91
92#[derive(Debug, Clone)]
94pub struct QueryMetadata {
95 pub model: Option<String>,
97 pub operation: Option<String>,
99 pub request_id: Option<String>,
101 pub user_id: Option<String>,
103 pub tenant_id: Option<String>,
105 pub schema_override: Option<String>,
107 pub tags: HashMap<String, String>,
109 pub attributes: HashMap<String, serde_json::Value>,
111}
112
113impl Default for QueryMetadata {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl QueryMetadata {
120 pub fn new() -> Self {
122 Self {
123 model: None,
124 operation: None,
125 request_id: None,
126 user_id: None,
127 tenant_id: None,
128 schema_override: None,
129 tags: HashMap::new(),
130 attributes: HashMap::new(),
131 }
132 }
133
134 pub fn with_model(mut self, model: impl Into<String>) -> Self {
136 self.model = Some(model.into());
137 self
138 }
139
140 pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
142 self.operation = Some(operation.into());
143 self
144 }
145
146 pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
148 self.request_id = Some(id.into());
149 self
150 }
151
152 pub fn with_user_id(mut self, id: impl Into<String>) -> Self {
154 self.user_id = Some(id.into());
155 self
156 }
157
158 pub fn with_tenant_id(mut self, id: impl Into<String>) -> Self {
160 self.tenant_id = Some(id.into());
161 self
162 }
163
164 pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
166 self.tags.insert(key.into(), value.into());
167 self
168 }
169
170 pub fn with_attribute(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
172 self.attributes.insert(key.into(), value);
173 self
174 }
175
176 pub fn set_schema_override(&mut self, schema: Option<String>) {
178 self.schema_override = schema;
179 }
180
181 pub fn schema_override(&self) -> Option<&str> {
183 self.schema_override.as_deref()
184 }
185}
186
187#[derive(Debug, Clone)]
189pub struct QueryContext {
190 sql: String,
192 params: Vec<FilterValue>,
194 query_type: QueryType,
196 metadata: QueryMetadata,
198 started_at: Instant,
200 phase: QueryPhase,
202 skip_execution: bool,
204 cached_response: Option<serde_json::Value>,
206}
207
208impl QueryContext {
209 pub fn new(sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
211 let sql = sql.into();
212 let query_type = QueryType::from_sql(&sql);
213 Self {
214 sql,
215 params,
216 query_type,
217 metadata: QueryMetadata::new(),
218 started_at: Instant::now(),
219 phase: QueryPhase::Before,
220 skip_execution: false,
221 cached_response: None,
222 }
223 }
224
225 pub fn sql(&self) -> &str {
227 &self.sql
228 }
229
230 pub fn sql_mut(&mut self) -> &mut String {
232 &mut self.sql
233 }
234
235 pub fn set_sql(&mut self, sql: impl Into<String>) {
237 self.sql = sql.into();
238 self.query_type = QueryType::from_sql(&self.sql);
239 }
240
241 pub fn with_sql(mut self, sql: impl Into<String>) -> Self {
243 self.set_sql(sql);
244 self
245 }
246
247 pub fn params(&self) -> &[FilterValue] {
249 &self.params
250 }
251
252 pub fn params_mut(&mut self) -> &mut Vec<FilterValue> {
254 &mut self.params
255 }
256
257 pub fn query_type(&self) -> QueryType {
259 self.query_type
260 }
261
262 pub fn metadata(&self) -> &QueryMetadata {
264 &self.metadata
265 }
266
267 pub fn metadata_mut(&mut self) -> &mut QueryMetadata {
269 &mut self.metadata
270 }
271
272 pub fn with_metadata(mut self, metadata: QueryMetadata) -> Self {
274 self.metadata = metadata;
275 self
276 }
277
278 pub fn elapsed(&self) -> std::time::Duration {
280 self.started_at.elapsed()
281 }
282
283 pub fn elapsed_us(&self) -> u64 {
285 self.started_at.elapsed().as_micros() as u64
286 }
287
288 pub fn phase(&self) -> QueryPhase {
290 self.phase
291 }
292
293 pub fn set_phase(&mut self, phase: QueryPhase) {
295 self.phase = phase;
296 }
297
298 pub fn should_skip(&self) -> bool {
300 self.skip_execution
301 }
302
303 pub fn skip_with_response(&mut self, response: serde_json::Value) {
305 self.skip_execution = true;
306 self.cached_response = Some(response);
307 }
308
309 pub fn cached_response(&self) -> Option<&serde_json::Value> {
311 self.cached_response.as_ref()
312 }
313
314 pub fn is_read(&self) -> bool {
316 self.query_type.is_read()
317 }
318
319 pub fn is_write(&self) -> bool {
321 self.query_type.is_write()
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_type_detection() {
        // (sql, expected classification) pairs checked in one table.
        let table = [
            ("SELECT * FROM users", QueryType::Select),
            ("INSERT INTO users VALUES (1)", QueryType::Insert),
            ("UPDATE users SET name = 'test'", QueryType::Update),
            ("DELETE FROM users WHERE id = 1", QueryType::Delete),
            ("SELECT COUNT(*) FROM users", QueryType::Count),
            ("BEGIN", QueryType::TransactionBegin),
            ("COMMIT", QueryType::TransactionCommit),
            ("ROLLBACK", QueryType::TransactionRollback),
        ];
        for &(sql, expected) in table.iter() {
            // Pairing with `sql` makes a failure report which statement misclassified.
            assert_eq!((sql, QueryType::from_sql(sql)), (sql, expected));
        }
    }

    #[test]
    fn test_query_type_categories() {
        assert!([QueryType::Select, QueryType::Count]
            .iter()
            .all(QueryType::is_read));
        assert!(!QueryType::Insert.is_read());

        assert!([QueryType::Insert, QueryType::Update, QueryType::Delete]
            .iter()
            .all(QueryType::is_write));
        assert!(!QueryType::Select.is_write());

        assert!([
            QueryType::TransactionBegin,
            QueryType::TransactionCommit,
            QueryType::TransactionRollback,
        ]
        .iter()
        .all(QueryType::is_transaction));
    }

    #[test]
    fn test_query_context() {
        let text = "SELECT * FROM users";
        let context = QueryContext::new(text, Vec::new());

        assert_eq!(context.sql(), text);
        assert_eq!(context.query_type(), QueryType::Select);
        assert!(context.is_read());
        assert!(!context.is_write());
    }

    #[test]
    fn test_query_metadata() {
        let meta = QueryMetadata::new()
            .with_model("User")
            .with_operation("findMany")
            .with_request_id("req-123")
            .with_tag("env", "production");

        assert_eq!(meta.model.as_deref(), Some("User"));
        assert_eq!(meta.operation.as_deref(), Some("findMany"));
        assert_eq!(meta.tags.get("env").map(String::as_str), Some("production"));
    }

    #[test]
    fn test_context_skip_execution() {
        let mut context = QueryContext::new("SELECT * FROM users", Vec::new());
        assert!(!context.should_skip());

        let response = serde_json::json!({"cached": true});
        context.skip_with_response(response);

        assert!(context.should_skip());
        assert!(context.cached_response().is_some());
    }
}