use crate::filter::FilterValue;
use std::collections::HashMap;
use std::time::Instant;
/// Coarse classification of a SQL statement, as produced by
/// [`QueryType::from_sql`] (and cached on `QueryContext`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// A statement starting with `SELECT` that was not classified as [`QueryType::Count`].
    Select,
    /// A statement starting with `INSERT`.
    Insert,
    /// A statement starting with `UPDATE`.
    Update,
    /// A statement starting with `DELETE`.
    Delete,
    /// A `SELECT` statement that `from_sql` detected as calling the `COUNT` aggregate.
    Count,
    /// Never returned by `from_sql` in this file.
    /// NOTE(review): presumably assigned explicitly by callers for raw queries — confirm.
    Raw,
    /// `BEGIN` or `START TRANSACTION`.
    TransactionBegin,
    /// `COMMIT`.
    TransactionCommit,
    /// `ROLLBACK`.
    TransactionRollback,
    /// Any statement whose leading keyword `from_sql` does not recognise.
    Unknown,
}
impl QueryType {
    /// Classifies `sql` by its leading keyword.
    ///
    /// Leading/trailing whitespace is ignored and keywords are matched
    /// ASCII case-insensitively. A `SELECT` that calls the `COUNT` aggregate
    /// anywhere in the statement (subqueries and string literals included)
    /// is reported as [`QueryType::Count`]. The call must appear as a
    /// standalone word followed by `(`, optionally after whitespace, so
    /// `discount(x)` or `ROW_COUNT()` stay [`QueryType::Select`].
    /// `BEGIN` and `START TRANSACTION` both map to
    /// [`QueryType::TransactionBegin`]. Anything else, including `WITH ...`
    /// CTEs, DDL, or SQL that begins with a comment, yields
    /// [`QueryType::Unknown`]. [`QueryType::Raw`] is never returned here.
    pub fn from_sql(sql: &str) -> Self {
        // Returns `true` if `upper` (already ASCII-uppercased) contains the
        // word `COUNT` followed, after optional whitespace, by `(`.
        fn contains_count_call(upper: &str) -> bool {
            const KEYWORD: &str = "COUNT";
            let bytes = upper.as_bytes();
            upper.match_indices(KEYWORD).any(|(start, _)| {
                // A preceding identifier character means the match is the tail
                // of a longer name such as `DISCOUNT(` or `ROW_COUNT(`.
                let inside_identifier = start > 0 && {
                    let prev = bytes[start - 1];
                    prev.is_ascii_alphanumeric() || prev == b'_' || prev == b'$'
                };
                // `KEYWORD` is ASCII, so `start + len` is a char boundary.
                !inside_identifier
                    && upper[start + KEYWORD.len()..]
                        .trim_start()
                        .starts_with('(')
            })
        }

        // SQL keywords are ASCII, so an ASCII-only uppercase is sufficient and
        // avoids Unicode case mapping of identifiers and literals.
        let sql = sql.trim().to_ascii_uppercase();
        if sql.starts_with("SELECT") {
            if contains_count_call(&sql) {
                Self::Count
            } else {
                Self::Select
            }
        } else if sql.starts_with("INSERT") {
            Self::Insert
        } else if sql.starts_with("UPDATE") {
            Self::Update
        } else if sql.starts_with("DELETE") {
            Self::Delete
        } else if sql.starts_with("BEGIN") || sql.starts_with("START TRANSACTION") {
            Self::TransactionBegin
        } else if sql.starts_with("COMMIT") {
            Self::TransactionCommit
        } else if sql.starts_with("ROLLBACK") {
            Self::TransactionRollback
        } else {
            Self::Unknown
        }
    }

    /// Returns `true` for read-only statements: [`QueryType::Select`] and [`QueryType::Count`].
    pub fn is_read(&self) -> bool {
        matches!(self, Self::Select | Self::Count)
    }

    /// Returns `true` for data-modifying statements: `Insert`, `Update` and `Delete`.
    pub fn is_write(&self) -> bool {
        matches!(self, Self::Insert | Self::Update | Self::Delete)
    }

    /// Returns `true` for transaction-control statements (begin, commit, rollback).
    pub fn is_transaction(&self) -> bool {
        matches!(
            self,
            Self::TransactionBegin | Self::TransactionCommit | Self::TransactionRollback
        )
    }
}
/// Lifecycle stage of a query, stored on `QueryContext`.
///
/// `QueryContext::new` starts in `Before`. Later stages are set via
/// `QueryContext::set_phase`, which enforces no ordering between them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryPhase {
    /// Initial phase of a freshly created `QueryContext`.
    Before,
    /// NOTE(review): presumably set while the query is executing — confirm with the executor.
    During,
    /// NOTE(review): presumably set after the query completed successfully — confirm.
    AfterSuccess,
    /// NOTE(review): presumably set after the query failed — confirm.
    AfterError,
}
/// Caller-supplied context attached to a query: model/operation names,
/// request/user/tenant identifiers, an optional schema override, free-form
/// string tags and JSON-valued attributes. Built via `QueryMetadata::new`
/// and the `with_*` methods.
#[derive(Debug, Clone)]
pub struct QueryMetadata {
    /// Model name, e.g. `"User"`.
    pub model: Option<String>,
    /// Operation name, e.g. `"findMany"`.
    pub operation: Option<String>,
    /// Request identifier (presumably for log/trace correlation — confirm).
    pub request_id: Option<String>,
    /// Identifier of the acting user, if known.
    pub user_id: Option<String>,
    /// Tenant identifier. NOTE(review): how it is enforced is not shown in this file.
    pub tenant_id: Option<String>,
    /// Schema override, accessed via `set_schema_override` / `schema_override`.
    /// NOTE(review): presumably replaces the default schema at execution — confirm.
    pub schema_override: Option<String>,
    /// String key/value tags; inserting an existing key replaces its value.
    pub tags: HashMap<String, String>,
    /// Arbitrary JSON-valued attributes keyed by name.
    pub attributes: HashMap<String, serde_json::Value>,
}
impl Default for QueryMetadata {
fn default() -> Self {
Self::new()
}
}
impl QueryMetadata {
    /// Creates metadata with every optional field unset and no tags or attributes.
    pub fn new() -> Self {
        Self {
            model: None,
            operation: None,
            request_id: None,
            user_id: None,
            tenant_id: None,
            schema_override: None,
            tags: HashMap::new(),
            attributes: HashMap::new(),
        }
    }

    /// Sets `model` and returns the updated metadata.
    #[must_use = "builder methods consume `self`; use the returned metadata"]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets `operation` and returns the updated metadata.
    #[must_use = "builder methods consume `self`; use the returned metadata"]
    pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
        self.operation = Some(operation.into());
        self
    }

    /// Sets `request_id` and returns the updated metadata.
    #[must_use = "builder methods consume `self`; use the returned metadata"]
    pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
        self.request_id = Some(id.into());
        self
    }

    /// Sets `user_id` and returns the updated metadata.
    #[must_use = "builder methods consume `self`; use the returned metadata"]
    pub fn with_user_id(mut self, id: impl Into<String>) -> Self {
        self.user_id = Some(id.into());
        self
    }

    /// Sets `tenant_id` and returns the updated metadata.
    #[must_use = "builder methods consume `self`; use the returned metadata"]
    pub fn with_tenant_id(mut self, id: impl Into<String>) -> Self {
        self.tenant_id = Some(id.into());
        self
    }

    /// Adds a tag, replacing any existing value stored under `key`.
    #[must_use = "builder methods consume `self`; use the returned metadata"]
    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.tags.insert(key.into(), value.into());
        self
    }

    /// Adds a JSON attribute, replacing any existing value stored under `key`.
    #[must_use = "builder methods consume `self`; use the returned metadata"]
    pub fn with_attribute(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.attributes.insert(key.into(), value);
        self
    }

    /// Replaces the schema override in place; passing `None` clears it.
    pub fn set_schema_override(&mut self, schema: Option<String>) {
        self.schema_override = schema;
    }

    /// Returns the schema override, if one is set, as a borrowed string.
    pub fn schema_override(&self) -> Option<&str> {
        self.schema_override.as_deref()
    }
}
/// Per-query state passed through the query pipeline: the SQL text and its
/// parameters, a cached [`QueryType`], caller metadata, a start timestamp,
/// the current [`QueryPhase`], and an optional cached response set via
/// `skip_with_response`.
#[derive(Debug, Clone)]
pub struct QueryContext {
    /// The SQL text.
    sql: String,
    /// Bound parameters (presumably positional, matching placeholders in `sql` — confirm).
    params: Vec<FilterValue>,
    /// Classification computed from `sql` in `new` / `set_sql`.
    /// Not refreshed when `sql` is edited through `sql_mut`.
    query_type: QueryType,
    /// Caller-supplied metadata.
    metadata: QueryMetadata,
    /// Captured at construction; basis for `elapsed` / `elapsed_us`.
    started_at: Instant,
    /// Current lifecycle phase; starts at `QueryPhase::Before`.
    phase: QueryPhase,
    /// Set to `true` by `skip_with_response`.
    skip_execution: bool,
    /// Response stored by `skip_with_response`.
    cached_response: Option<serde_json::Value>,
}
impl QueryContext {
    /// Creates a context for `sql` with bound `params`.
    ///
    /// The query type is derived from `sql` via [`QueryType::from_sql`], the
    /// timer starts now, the phase is [`QueryPhase::Before`], the metadata is
    /// empty, and no skip/cached response is set.
    pub fn new(sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
        let sql = sql.into();
        let query_type = QueryType::from_sql(&sql);
        Self {
            sql,
            params,
            query_type,
            metadata: QueryMetadata::new(),
            started_at: Instant::now(),
            phase: QueryPhase::Before,
            skip_execution: false,
            cached_response: None,
        }
    }

    /// Returns the SQL text.
    pub fn sql(&self) -> &str {
        &self.sql
    }

    /// Returns a mutable handle to the SQL text.
    ///
    /// Edits made through this reference do **not** re-run classification.
    /// [`query_type`](Self::query_type), [`is_read`](Self::is_read) and
    /// [`is_write`](Self::is_write) keep reporting the type computed by
    /// [`new`](Self::new) or the last [`set_sql`](Self::set_sql). Use
    /// `set_sql` when the kind of statement may change.
    pub fn sql_mut(&mut self) -> &mut String {
        &mut self.sql
    }

    /// Replaces the SQL text and reclassifies it with [`QueryType::from_sql`].
    pub fn set_sql(&mut self, sql: impl Into<String>) {
        self.sql = sql.into();
        self.query_type = QueryType::from_sql(&self.sql);
    }

    /// Builder form of [`set_sql`](Self::set_sql).
    #[must_use = "`with_sql` consumes the context; use the returned value"]
    pub fn with_sql(mut self, sql: impl Into<String>) -> Self {
        self.set_sql(sql);
        self
    }

    /// Returns the bound parameters.
    pub fn params(&self) -> &[FilterValue] {
        &self.params
    }

    /// Returns a mutable handle to the bound parameters.
    pub fn params_mut(&mut self) -> &mut Vec<FilterValue> {
        &mut self.params
    }

    /// Returns the cached classification.
    ///
    /// It can be stale after edits made through [`sql_mut`](Self::sql_mut).
    pub fn query_type(&self) -> QueryType {
        self.query_type
    }

    /// Returns the attached metadata.
    pub fn metadata(&self) -> &QueryMetadata {
        &self.metadata
    }

    /// Returns a mutable handle to the attached metadata.
    pub fn metadata_mut(&mut self) -> &mut QueryMetadata {
        &mut self.metadata
    }

    /// Replaces the whole metadata and returns the updated context.
    #[must_use = "`with_metadata` consumes the context; use the returned value"]
    pub fn with_metadata(mut self, metadata: QueryMetadata) -> Self {
        self.metadata = metadata;
        self
    }

    /// Time elapsed since this context was created.
    pub fn elapsed(&self) -> std::time::Duration {
        self.started_at.elapsed()
    }

    /// Time elapsed since creation, in whole microseconds, saturating at `u64::MAX`.
    pub fn elapsed_us(&self) -> u64 {
        let micros = self.elapsed().as_micros();
        // `as_micros` yields u128; clamp first so the narrowing cast cannot truncate.
        micros.min(u128::from(u64::MAX)) as u64
    }

    /// Returns the current phase.
    pub fn phase(&self) -> QueryPhase {
        self.phase
    }

    /// Sets the current phase; no transition ordering is enforced here.
    pub fn set_phase(&mut self, phase: QueryPhase) {
        self.phase = phase;
    }

    /// Returns `true` once [`skip_with_response`](Self::skip_with_response) has been called.
    pub fn should_skip(&self) -> bool {
        self.skip_execution
    }

    /// Marks the query as skipped and stores `response` as its cached result.
    ///
    /// NOTE(review): presumably the executor checks `should_skip` and returns
    /// `cached_response` instead of running the SQL. That logic is not in this file.
    pub fn skip_with_response(&mut self, response: serde_json::Value) {
        self.skip_execution = true;
        self.cached_response = Some(response);
    }

    /// Returns the response stored by `skip_with_response`, if any.
    pub fn cached_response(&self) -> Option<&serde_json::Value> {
        self.cached_response.as_ref()
    }

    /// Shorthand for `self.query_type().is_read()`.
    pub fn is_read(&self) -> bool {
        self.query_type.is_read()
    }

    /// Shorthand for `self.query_type().is_write()`.
    pub fn is_write(&self) -> bool {
        self.query_type.is_write()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const USERS_SELECT: &str = "SELECT * FROM users";

    #[test]
    fn test_query_type_detection() {
        let cases = [
            (USERS_SELECT, QueryType::Select),
            ("INSERT INTO users VALUES (1)", QueryType::Insert),
            ("UPDATE users SET name = 'test'", QueryType::Update),
            ("DELETE FROM users WHERE id = 1", QueryType::Delete),
            ("SELECT COUNT(*) FROM users", QueryType::Count),
            ("BEGIN", QueryType::TransactionBegin),
            ("COMMIT", QueryType::TransactionCommit),
            ("ROLLBACK", QueryType::TransactionRollback),
        ];
        for &(statement, expected) in cases.iter() {
            assert_eq!(QueryType::from_sql(statement), expected, "sql: {}", statement);
        }
    }

    #[test]
    fn test_query_type_categories() {
        let reads = [QueryType::Select, QueryType::Count];
        let writes = [QueryType::Insert, QueryType::Update, QueryType::Delete];
        let transactions = [
            QueryType::TransactionBegin,
            QueryType::TransactionCommit,
            QueryType::TransactionRollback,
        ];

        assert!(reads.iter().all(QueryType::is_read));
        assert!(writes.iter().all(QueryType::is_write));
        assert!(transactions.iter().all(QueryType::is_transaction));

        // Cross-category negatives.
        assert!(!QueryType::Insert.is_read());
        assert!(!QueryType::Select.is_write());
    }

    #[test]
    fn test_query_context() {
        let context = QueryContext::new(USERS_SELECT, Vec::new());
        assert_eq!(context.sql(), USERS_SELECT);
        assert_eq!(context.query_type(), QueryType::Select);
        assert!(context.is_read() && !context.is_write());
    }

    #[test]
    fn test_query_metadata() {
        let meta = QueryMetadata::new()
            .with_model("User")
            .with_operation("findMany")
            .with_request_id("req-123")
            .with_tag("env", "production");

        assert_eq!(meta.model.as_deref(), Some("User"));
        assert_eq!(meta.operation.as_deref(), Some("findMany"));
        assert_eq!(meta.tags.get("env").map(String::as_str), Some("production"));
    }

    #[test]
    fn test_context_skip_execution() {
        let mut context = QueryContext::new(USERS_SELECT, Vec::new());
        assert!(!context.should_skip());

        context.skip_with_response(serde_json::json!({"cached": true}));

        assert!(context.should_skip());
        assert!(context.cached_response().is_some());
    }
}