prax_query/middleware/
logging.rs

1//! Logging middleware for query tracing.
2
use super::context::QueryContext;
use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
use std::borrow::Cow;
use std::sync::atomic::{AtomicU64, Ordering};
6
/// Verbosity threshold for query logging.
///
/// Variants are declared from least to most verbose, so the derived `Ord`
/// allows direct comparisons such as `level >= LogLevel::Info`.
/// The default is [`LogLevel::Info`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum LogLevel {
    /// Disable all query logging.
    Off,
    /// Report failed queries only.
    Error,
    /// Report failures plus slow-query warnings.
    Warn,
    /// Report every completed query (the default).
    #[default]
    Info,
    /// Also report each query as it starts, with parameters when enabled.
    Debug,
    /// Most verbose level; also allows response data to be logged.
    Trace,
}
25
26
27/// Configuration for the logging middleware.
28#[derive(Debug, Clone)]
29pub struct LoggingConfig {
30    /// Minimum log level.
31    pub level: LogLevel,
32    /// Threshold for slow query warnings (microseconds).
33    pub slow_query_threshold_us: u64,
34    /// Whether to log query parameters.
35    pub log_params: bool,
36    /// Whether to log response data.
37    pub log_response: bool,
38    /// Maximum length of logged SQL (0 = unlimited).
39    pub max_sql_length: usize,
40    /// Prefix for log messages.
41    pub prefix: String,
42}
43
44impl Default for LoggingConfig {
45    fn default() -> Self {
46        Self {
47            level: LogLevel::Info,
48            slow_query_threshold_us: 1_000_000, // 1 second
49            log_params: false,
50            log_response: false,
51            max_sql_length: 500,
52            prefix: "prax".to_string(),
53        }
54    }
55}
56
57/// Middleware that logs queries.
58///
59/// # Example
60///
61/// ```rust,ignore
62/// use prax_query::middleware::{LoggingMiddleware, LogLevel};
63///
64/// let logging = LoggingMiddleware::new()
65///     .with_level(LogLevel::Debug)
66///     .with_params(true)
67///     .with_slow_threshold(500_000); // 500ms
68/// ```
69pub struct LoggingMiddleware {
70    config: LoggingConfig,
71    query_count: AtomicU64,
72}
73
74impl LoggingMiddleware {
75    /// Create a new logging middleware with default settings.
76    pub fn new() -> Self {
77        Self {
78            config: LoggingConfig::default(),
79            query_count: AtomicU64::new(0),
80        }
81    }
82
83    /// Create with custom configuration.
84    pub fn with_config(config: LoggingConfig) -> Self {
85        Self {
86            config,
87            query_count: AtomicU64::new(0),
88        }
89    }
90
91    /// Set the log level.
92    pub fn with_level(mut self, level: LogLevel) -> Self {
93        self.config.level = level;
94        self
95    }
96
97    /// Enable parameter logging.
98    pub fn with_params(mut self, enabled: bool) -> Self {
99        self.config.log_params = enabled;
100        self
101    }
102
103    /// Enable response logging.
104    pub fn with_response(mut self, enabled: bool) -> Self {
105        self.config.log_response = enabled;
106        self
107    }
108
109    /// Set slow query threshold in microseconds.
110    pub fn with_slow_threshold(mut self, threshold_us: u64) -> Self {
111        self.config.slow_query_threshold_us = threshold_us;
112        self
113    }
114
115    /// Set the log prefix.
116    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
117        self.config.prefix = prefix.into();
118        self
119    }
120
121    /// Get the total query count.
122    pub fn query_count(&self) -> u64 {
123        self.query_count.load(Ordering::Relaxed)
124    }
125
126    fn truncate_sql(&self, sql: &str) -> String {
127        if self.config.max_sql_length == 0 || sql.len() <= self.config.max_sql_length {
128            sql.to_string()
129        } else {
130            format!("{}...", &sql[..self.config.max_sql_length])
131        }
132    }
133
134    fn log_before(&self, ctx: &QueryContext, query_id: u64) {
135        if self.config.level < LogLevel::Debug {
136            return;
137        }
138
139        let sql = self.truncate_sql(ctx.sql());
140        let query_type = format!("{:?}", ctx.query_type());
141
142        if self.config.log_params && self.config.level >= LogLevel::Trace {
143            tracing::debug!(
144                target: "prax::query",
145                query_id = query_id,
146                query_type = %query_type,
147                sql = %sql,
148                params = ?ctx.params(),
149                model = ?ctx.metadata().model,
150                operation = ?ctx.metadata().operation,
151                request_id = ?ctx.metadata().request_id,
152                "[{}] Starting query",
153                self.config.prefix
154            );
155        } else {
156            tracing::debug!(
157                target: "prax::query",
158                query_id = query_id,
159                query_type = %query_type,
160                sql = %sql,
161                "[{}] Starting query",
162                self.config.prefix
163            );
164        }
165    }
166
167    fn log_after(&self, ctx: &QueryContext, response: &QueryResponse, query_id: u64) {
168        let duration_us = response.execution_time_us;
169        let is_slow = duration_us >= self.config.slow_query_threshold_us;
170
171        if is_slow && self.config.level >= LogLevel::Warn {
172            let sql = self.truncate_sql(ctx.sql());
173            tracing::warn!(
174                target: "prax::query",
175                query_id = query_id,
176                duration_us = duration_us,
177                duration_ms = duration_us / 1000,
178                sql = %sql,
179                threshold_us = self.config.slow_query_threshold_us,
180                "[{}] Slow query detected",
181                self.config.prefix
182            );
183        } else if self.config.level >= LogLevel::Info {
184            let sql = self.truncate_sql(ctx.sql());
185
186            if self.config.log_response && self.config.level >= LogLevel::Trace {
187                tracing::info!(
188                    target: "prax::query",
189                    query_id = query_id,
190                    duration_us = duration_us,
191                    rows_affected = ?response.rows_affected,
192                    from_cache = response.from_cache,
193                    sql = %sql,
194                    response = ?response.data,
195                    "[{}] Query completed",
196                    self.config.prefix
197                );
198            } else {
199                tracing::info!(
200                    target: "prax::query",
201                    query_id = query_id,
202                    duration_us = duration_us,
203                    rows_affected = ?response.rows_affected,
204                    from_cache = response.from_cache,
205                    "[{}] Query completed",
206                    self.config.prefix
207                );
208            }
209        }
210    }
211
212    fn log_error(&self, ctx: &QueryContext, error: &crate::QueryError, query_id: u64) {
213        if self.config.level >= LogLevel::Error {
214            let sql = self.truncate_sql(ctx.sql());
215            tracing::error!(
216                target: "prax::query",
217                query_id = query_id,
218                sql = %sql,
219                error = %error,
220                "[{}] Query failed",
221                self.config.prefix
222            );
223        }
224    }
225}
226
227impl Default for LoggingMiddleware {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233impl Middleware for LoggingMiddleware {
234    fn handle<'a>(
235        &'a self,
236        ctx: QueryContext,
237        next: Next<'a>,
238    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
239        Box::pin(async move {
240            let query_id = self.query_count.fetch_add(1, Ordering::SeqCst);
241
242            // Log before
243            self.log_before(&ctx, query_id);
244
245            // Execute query
246            let result = next.run(ctx.clone()).await;
247
248            // Log after
249            match &result {
250                Ok(response) => self.log_after(&ctx, response, query_id),
251                Err(error) => self.log_error(&ctx, error, query_id),
252            }
253
254            result
255        })
256    }
257
258    fn name(&self) -> &'static str {
259        "LoggingMiddleware"
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_level_ordering() {
        assert!(LogLevel::Off < LogLevel::Error);
        assert!(LogLevel::Error < LogLevel::Warn);
        assert!(LogLevel::Warn < LogLevel::Info);
        assert!(LogLevel::Info < LogLevel::Debug);
        assert!(LogLevel::Debug < LogLevel::Trace);
    }

    #[test]
    fn test_log_level_default() {
        assert_eq!(LogLevel::default(), LogLevel::Info);
    }

    #[test]
    fn test_default_config() {
        let config = LoggingConfig::default();
        assert_eq!(config.level, LogLevel::Info);
        assert_eq!(config.slow_query_threshold_us, 1_000_000);
        assert!(!config.log_params);
        assert!(!config.log_response);
        assert_eq!(config.max_sql_length, 500);
        assert_eq!(config.prefix, "prax");
    }

    #[test]
    fn test_logging_middleware_builder() {
        let middleware = LoggingMiddleware::new()
            .with_level(LogLevel::Debug)
            .with_params(true)
            .with_response(true)
            .with_slow_threshold(500_000)
            .with_prefix("app");

        assert_eq!(middleware.config.level, LogLevel::Debug);
        assert!(middleware.config.log_params);
        assert!(middleware.config.log_response);
        assert_eq!(middleware.config.slow_query_threshold_us, 500_000);
        assert_eq!(middleware.config.prefix, "app");
    }

    #[test]
    fn test_truncate_sql() {
        let middleware = LoggingMiddleware::new();

        let short = "SELECT * FROM users";
        assert_eq!(middleware.truncate_sql(short), short);

        let config = LoggingConfig {
            max_sql_length: 10,
            ..Default::default()
        };
        let middleware = LoggingMiddleware::with_config(config);

        let long = "SELECT * FROM users WHERE id = 1";
        assert_eq!(middleware.truncate_sql(long), "SELECT * F...");

        // Input exactly at the limit is left untouched.
        assert_eq!(middleware.truncate_sql("SELECT 1;;"), "SELECT 1;;");
    }

    #[test]
    fn test_truncate_sql_unlimited() {
        let config = LoggingConfig {
            max_sql_length: 0,
            ..Default::default()
        };
        let middleware = LoggingMiddleware::with_config(config);

        let long = "x".repeat(10_000);
        assert_eq!(middleware.truncate_sql(&long), long.as_str());
    }

    #[test]
    fn test_query_count() {
        let middleware = LoggingMiddleware::new();
        assert_eq!(middleware.query_count(), 0);
    }
}