prax_query/middleware/
logging.rs

1//! Logging middleware for query tracing.
2
3use super::context::QueryContext;
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use std::sync::atomic::{AtomicU64, Ordering};
6
/// Verbosity of query logging, ordered from least (`Off`) to most (`Trace`) verbose.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, which lets callers gate
/// output with comparisons such as `level >= LogLevel::Info`. The default is
/// [`LogLevel::Info`].
#[derive(Clone, Copy, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
pub enum LogLevel {
    /// Disable query logging entirely.
    Off,
    /// Report failed queries only.
    Error,
    /// Report failures plus slow-query warnings.
    Warn,
    /// Report every query (the default level).
    #[default]
    Info,
    /// Report queries together with their parameters.
    Debug,
    /// Report everything, including internal details.
    Trace,
}
24
25/// Configuration for the logging middleware.
26#[derive(Debug, Clone)]
27pub struct LoggingConfig {
28    /// Minimum log level.
29    pub level: LogLevel,
30    /// Threshold for slow query warnings (microseconds).
31    pub slow_query_threshold_us: u64,
32    /// Whether to log query parameters.
33    pub log_params: bool,
34    /// Whether to log response data.
35    pub log_response: bool,
36    /// Maximum length of logged SQL (0 = unlimited).
37    pub max_sql_length: usize,
38    /// Prefix for log messages.
39    pub prefix: String,
40}
41
42impl Default for LoggingConfig {
43    fn default() -> Self {
44        Self {
45            level: LogLevel::Info,
46            slow_query_threshold_us: 1_000_000, // 1 second
47            log_params: false,
48            log_response: false,
49            max_sql_length: 500,
50            prefix: "prax".to_string(),
51        }
52    }
53}
54
55/// Middleware that logs queries.
56///
57/// # Example
58///
59/// ```rust,ignore
60/// use prax_query::middleware::{LoggingMiddleware, LogLevel};
61///
62/// let logging = LoggingMiddleware::new()
63///     .with_level(LogLevel::Debug)
64///     .with_params(true)
65///     .with_slow_threshold(500_000); // 500ms
66/// ```
67pub struct LoggingMiddleware {
68    config: LoggingConfig,
69    query_count: AtomicU64,
70}
71
72impl LoggingMiddleware {
73    /// Create a new logging middleware with default settings.
74    pub fn new() -> Self {
75        Self {
76            config: LoggingConfig::default(),
77            query_count: AtomicU64::new(0),
78        }
79    }
80
81    /// Create with custom configuration.
82    pub fn with_config(config: LoggingConfig) -> Self {
83        Self {
84            config,
85            query_count: AtomicU64::new(0),
86        }
87    }
88
89    /// Set the log level.
90    pub fn with_level(mut self, level: LogLevel) -> Self {
91        self.config.level = level;
92        self
93    }
94
95    /// Enable parameter logging.
96    pub fn with_params(mut self, enabled: bool) -> Self {
97        self.config.log_params = enabled;
98        self
99    }
100
101    /// Enable response logging.
102    pub fn with_response(mut self, enabled: bool) -> Self {
103        self.config.log_response = enabled;
104        self
105    }
106
107    /// Set slow query threshold in microseconds.
108    pub fn with_slow_threshold(mut self, threshold_us: u64) -> Self {
109        self.config.slow_query_threshold_us = threshold_us;
110        self
111    }
112
113    /// Set the log prefix.
114    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
115        self.config.prefix = prefix.into();
116        self
117    }
118
119    /// Get the total query count.
120    pub fn query_count(&self) -> u64 {
121        self.query_count.load(Ordering::Relaxed)
122    }
123
124    fn truncate_sql(&self, sql: &str) -> String {
125        if self.config.max_sql_length == 0 || sql.len() <= self.config.max_sql_length {
126            sql.to_string()
127        } else {
128            format!("{}...", &sql[..self.config.max_sql_length])
129        }
130    }
131
132    fn log_before(&self, ctx: &QueryContext, query_id: u64) {
133        if self.config.level < LogLevel::Debug {
134            return;
135        }
136
137        let sql = self.truncate_sql(ctx.sql());
138        let query_type = format!("{:?}", ctx.query_type());
139
140        if self.config.log_params && self.config.level >= LogLevel::Trace {
141            tracing::debug!(
142                target: "prax::query",
143                query_id = query_id,
144                query_type = %query_type,
145                sql = %sql,
146                params = ?ctx.params(),
147                model = ?ctx.metadata().model,
148                operation = ?ctx.metadata().operation,
149                request_id = ?ctx.metadata().request_id,
150                "[{}] Starting query",
151                self.config.prefix
152            );
153        } else {
154            tracing::debug!(
155                target: "prax::query",
156                query_id = query_id,
157                query_type = %query_type,
158                sql = %sql,
159                "[{}] Starting query",
160                self.config.prefix
161            );
162        }
163    }
164
165    fn log_after(&self, ctx: &QueryContext, response: &QueryResponse, query_id: u64) {
166        let duration_us = response.execution_time_us;
167        let is_slow = duration_us >= self.config.slow_query_threshold_us;
168
169        if is_slow && self.config.level >= LogLevel::Warn {
170            let sql = self.truncate_sql(ctx.sql());
171            tracing::warn!(
172                target: "prax::query",
173                query_id = query_id,
174                duration_us = duration_us,
175                duration_ms = duration_us / 1000,
176                sql = %sql,
177                threshold_us = self.config.slow_query_threshold_us,
178                "[{}] Slow query detected",
179                self.config.prefix
180            );
181        } else if self.config.level >= LogLevel::Info {
182            let sql = self.truncate_sql(ctx.sql());
183
184            if self.config.log_response && self.config.level >= LogLevel::Trace {
185                tracing::info!(
186                    target: "prax::query",
187                    query_id = query_id,
188                    duration_us = duration_us,
189                    rows_affected = ?response.rows_affected,
190                    from_cache = response.from_cache,
191                    sql = %sql,
192                    response = ?response.data,
193                    "[{}] Query completed",
194                    self.config.prefix
195                );
196            } else {
197                tracing::info!(
198                    target: "prax::query",
199                    query_id = query_id,
200                    duration_us = duration_us,
201                    rows_affected = ?response.rows_affected,
202                    from_cache = response.from_cache,
203                    "[{}] Query completed",
204                    self.config.prefix
205                );
206            }
207        }
208    }
209
210    fn log_error(&self, ctx: &QueryContext, error: &crate::QueryError, query_id: u64) {
211        if self.config.level >= LogLevel::Error {
212            let sql = self.truncate_sql(ctx.sql());
213            tracing::error!(
214                target: "prax::query",
215                query_id = query_id,
216                sql = %sql,
217                error = %error,
218                "[{}] Query failed",
219                self.config.prefix
220            );
221        }
222    }
223}
224
225impl Default for LoggingMiddleware {
226    fn default() -> Self {
227        Self::new()
228    }
229}
230
231impl Middleware for LoggingMiddleware {
232    fn handle<'a>(
233        &'a self,
234        ctx: QueryContext,
235        next: Next<'a>,
236    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
237        Box::pin(async move {
238            let query_id = self.query_count.fetch_add(1, Ordering::SeqCst);
239
240            // Log before
241            self.log_before(&ctx, query_id);
242
243            // Execute query
244            let result = next.run(ctx.clone()).await;
245
246            // Log after
247            match &result {
248                Ok(response) => self.log_after(&ctx, response, query_id),
249                Err(error) => self.log_error(&ctx, error, query_id),
250            }
251
252            result
253        })
254    }
255
256    fn name(&self) -> &'static str {
257        "LoggingMiddleware"
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// The derived ordering must follow declaration order, since the middleware
    /// gates output with `>=` / `<` comparisons on `LogLevel`.
    #[test]
    fn test_log_level_ordering() {
        assert!(LogLevel::Off < LogLevel::Error);
        assert!(LogLevel::Error < LogLevel::Warn);
        assert!(LogLevel::Warn < LogLevel::Info);
        assert!(LogLevel::Info < LogLevel::Debug);
        assert!(LogLevel::Debug < LogLevel::Trace);
    }

    #[test]
    fn test_log_level_default_is_info() {
        assert_eq!(LogLevel::default(), LogLevel::Info);
        assert_eq!(LoggingConfig::default().level, LogLevel::Info);
    }

    #[test]
    fn test_logging_middleware_builder() {
        let middleware = LoggingMiddleware::new()
            .with_level(LogLevel::Debug)
            .with_params(true)
            .with_slow_threshold(500_000);

        assert_eq!(middleware.config.level, LogLevel::Debug);
        assert!(middleware.config.log_params);
        assert_eq!(middleware.config.slow_query_threshold_us, 500_000);
    }

    #[test]
    fn test_truncate_sql() {
        let middleware = LoggingMiddleware::new();

        let short = "SELECT * FROM users";
        assert_eq!(middleware.truncate_sql(short), short);

        let config = LoggingConfig {
            max_sql_length: 10,
            ..Default::default()
        };
        let middleware = LoggingMiddleware::with_config(config);
        let long = "SELECT * FROM users WHERE id = 1";
        // Exactly `max_sql_length` bytes are kept before the ellipsis.
        assert_eq!(middleware.truncate_sql(long), "SELECT * F...");
    }

    #[test]
    fn test_truncate_sql_unlimited() {
        let config = LoggingConfig {
            max_sql_length: 0,
            ..Default::default()
        };
        let middleware = LoggingMiddleware::with_config(config);
        let long = "SELECT * FROM users WHERE id = 1";
        assert_eq!(middleware.truncate_sql(long), long);
    }

    #[test]
    fn test_query_count() {
        let middleware = LoggingMiddleware::new();
        assert_eq!(middleware.query_count(), 0);
    }
}