prax_query/middleware/
logging.rs

1//! Logging middleware for query tracing.
2
3use super::context::QueryContext;
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use std::sync::atomic::{AtomicU64, Ordering};
6
/// Verbosity setting for query logging.
///
/// Variants are declared from least to most verbose, so the derived
/// `PartialOrd`/`Ord` comparisons (for example `level >= LogLevel::Warn`)
/// follow that order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum LogLevel {
    /// Emit no log output at all.
    Off,
    /// Emit failures only.
    Error,
    /// Emit failures plus warnings (slow queries).
    Warn,
    /// Emit an entry for every query.
    Info,
    /// Emit queries together with their parameters.
    Debug,
    /// Emit everything, internal details included.
    Trace,
}
23
24impl Default for LogLevel {
25    fn default() -> Self {
26        Self::Info
27    }
28}
29
30/// Configuration for the logging middleware.
31#[derive(Debug, Clone)]
32pub struct LoggingConfig {
33    /// Minimum log level.
34    pub level: LogLevel,
35    /// Threshold for slow query warnings (microseconds).
36    pub slow_query_threshold_us: u64,
37    /// Whether to log query parameters.
38    pub log_params: bool,
39    /// Whether to log response data.
40    pub log_response: bool,
41    /// Maximum length of logged SQL (0 = unlimited).
42    pub max_sql_length: usize,
43    /// Prefix for log messages.
44    pub prefix: String,
45}
46
47impl Default for LoggingConfig {
48    fn default() -> Self {
49        Self {
50            level: LogLevel::Info,
51            slow_query_threshold_us: 1_000_000, // 1 second
52            log_params: false,
53            log_response: false,
54            max_sql_length: 500,
55            prefix: "prax".to_string(),
56        }
57    }
58}
59
60/// Middleware that logs queries.
61///
62/// # Example
63///
64/// ```rust,ignore
65/// use prax_query::middleware::{LoggingMiddleware, LogLevel};
66///
67/// let logging = LoggingMiddleware::new()
68///     .with_level(LogLevel::Debug)
69///     .with_params(true)
70///     .with_slow_threshold(500_000); // 500ms
71/// ```
72pub struct LoggingMiddleware {
73    config: LoggingConfig,
74    query_count: AtomicU64,
75}
76
77impl LoggingMiddleware {
78    /// Create a new logging middleware with default settings.
79    pub fn new() -> Self {
80        Self {
81            config: LoggingConfig::default(),
82            query_count: AtomicU64::new(0),
83        }
84    }
85
86    /// Create with custom configuration.
87    pub fn with_config(config: LoggingConfig) -> Self {
88        Self {
89            config,
90            query_count: AtomicU64::new(0),
91        }
92    }
93
94    /// Set the log level.
95    pub fn with_level(mut self, level: LogLevel) -> Self {
96        self.config.level = level;
97        self
98    }
99
100    /// Enable parameter logging.
101    pub fn with_params(mut self, enabled: bool) -> Self {
102        self.config.log_params = enabled;
103        self
104    }
105
106    /// Enable response logging.
107    pub fn with_response(mut self, enabled: bool) -> Self {
108        self.config.log_response = enabled;
109        self
110    }
111
112    /// Set slow query threshold in microseconds.
113    pub fn with_slow_threshold(mut self, threshold_us: u64) -> Self {
114        self.config.slow_query_threshold_us = threshold_us;
115        self
116    }
117
118    /// Set the log prefix.
119    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
120        self.config.prefix = prefix.into();
121        self
122    }
123
124    /// Get the total query count.
125    pub fn query_count(&self) -> u64 {
126        self.query_count.load(Ordering::Relaxed)
127    }
128
129    fn truncate_sql(&self, sql: &str) -> String {
130        if self.config.max_sql_length == 0 || sql.len() <= self.config.max_sql_length {
131            sql.to_string()
132        } else {
133            format!("{}...", &sql[..self.config.max_sql_length])
134        }
135    }
136
137    fn log_before(&self, ctx: &QueryContext, query_id: u64) {
138        if self.config.level < LogLevel::Debug {
139            return;
140        }
141
142        let sql = self.truncate_sql(ctx.sql());
143        let query_type = format!("{:?}", ctx.query_type());
144
145        if self.config.log_params && self.config.level >= LogLevel::Trace {
146            tracing::debug!(
147                target: "prax::query",
148                query_id = query_id,
149                query_type = %query_type,
150                sql = %sql,
151                params = ?ctx.params(),
152                model = ?ctx.metadata().model,
153                operation = ?ctx.metadata().operation,
154                request_id = ?ctx.metadata().request_id,
155                "[{}] Starting query",
156                self.config.prefix
157            );
158        } else {
159            tracing::debug!(
160                target: "prax::query",
161                query_id = query_id,
162                query_type = %query_type,
163                sql = %sql,
164                "[{}] Starting query",
165                self.config.prefix
166            );
167        }
168    }
169
170    fn log_after(&self, ctx: &QueryContext, response: &QueryResponse, query_id: u64) {
171        let duration_us = response.execution_time_us;
172        let is_slow = duration_us >= self.config.slow_query_threshold_us;
173
174        if is_slow && self.config.level >= LogLevel::Warn {
175            let sql = self.truncate_sql(ctx.sql());
176            tracing::warn!(
177                target: "prax::query",
178                query_id = query_id,
179                duration_us = duration_us,
180                duration_ms = duration_us / 1000,
181                sql = %sql,
182                threshold_us = self.config.slow_query_threshold_us,
183                "[{}] Slow query detected",
184                self.config.prefix
185            );
186        } else if self.config.level >= LogLevel::Info {
187            let sql = self.truncate_sql(ctx.sql());
188
189            if self.config.log_response && self.config.level >= LogLevel::Trace {
190                tracing::info!(
191                    target: "prax::query",
192                    query_id = query_id,
193                    duration_us = duration_us,
194                    rows_affected = ?response.rows_affected,
195                    from_cache = response.from_cache,
196                    sql = %sql,
197                    response = ?response.data,
198                    "[{}] Query completed",
199                    self.config.prefix
200                );
201            } else {
202                tracing::info!(
203                    target: "prax::query",
204                    query_id = query_id,
205                    duration_us = duration_us,
206                    rows_affected = ?response.rows_affected,
207                    from_cache = response.from_cache,
208                    "[{}] Query completed",
209                    self.config.prefix
210                );
211            }
212        }
213    }
214
215    fn log_error(&self, ctx: &QueryContext, error: &crate::QueryError, query_id: u64) {
216        if self.config.level >= LogLevel::Error {
217            let sql = self.truncate_sql(ctx.sql());
218            tracing::error!(
219                target: "prax::query",
220                query_id = query_id,
221                sql = %sql,
222                error = %error,
223                "[{}] Query failed",
224                self.config.prefix
225            );
226        }
227    }
228}
229
230impl Default for LoggingMiddleware {
231    fn default() -> Self {
232        Self::new()
233    }
234}
235
236impl Middleware for LoggingMiddleware {
237    fn handle<'a>(
238        &'a self,
239        ctx: QueryContext,
240        next: Next<'a>,
241    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
242        Box::pin(async move {
243            let query_id = self.query_count.fetch_add(1, Ordering::SeqCst);
244
245            // Log before
246            self.log_before(&ctx, query_id);
247
248            // Execute query
249            let result = next.run(ctx.clone()).await;
250
251            // Log after
252            match &result {
253                Ok(response) => self.log_after(&ctx, response, query_id),
254                Err(error) => self.log_error(&ctx, error, query_id),
255            }
256
257            result
258        })
259    }
260
261    fn name(&self) -> &'static str {
262        "LoggingMiddleware"
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a middleware whose only non-default setting is the SQL length cap.
    fn middleware_with_sql_limit(max_sql_length: usize) -> LoggingMiddleware {
        LoggingMiddleware::with_config(LoggingConfig {
            max_sql_length,
            ..Default::default()
        })
    }

    #[test]
    fn test_log_level_ordering() {
        // Every level, `Off` included, must sort strictly below the next one.
        let ascending = [
            LogLevel::Off,
            LogLevel::Error,
            LogLevel::Warn,
            LogLevel::Info,
            LogLevel::Debug,
            LogLevel::Trace,
        ];
        for pair in ascending.windows(2) {
            assert!(
                pair[0] < pair[1],
                "{:?} should sort below {:?}",
                pair[0],
                pair[1]
            );
        }
    }

    #[test]
    fn test_logging_middleware_builder() {
        let middleware = LoggingMiddleware::new()
            .with_level(LogLevel::Debug)
            .with_params(true)
            .with_response(true)
            .with_slow_threshold(500_000)
            .with_prefix("db");

        assert_eq!(middleware.config.level, LogLevel::Debug);
        assert!(middleware.config.log_params);
        assert!(middleware.config.log_response);
        assert_eq!(middleware.config.slow_query_threshold_us, 500_000);
        assert_eq!(middleware.config.prefix, "db");
    }

    #[test]
    fn test_truncate_sql() {
        // Text shorter than the default cap is returned untouched.
        let short = "SELECT * FROM users";
        assert_eq!(LoggingMiddleware::new().truncate_sql(short), short);

        // Text longer than the cap is cut at the cap and marked with "...".
        let long = "SELECT * FROM users WHERE id = 1";
        assert_eq!(
            middleware_with_sql_limit(10).truncate_sql(long),
            "SELECT * F..."
        );

        // Text exactly at the cap is not truncated.
        assert_eq!(
            middleware_with_sql_limit(long.len()).truncate_sql(long),
            long
        );

        // A cap of zero disables truncation.
        assert_eq!(middleware_with_sql_limit(0).truncate_sql(long), long);
    }

    #[test]
    fn test_query_count() {
        assert_eq!(LoggingMiddleware::new().query_count(), 0);
        assert_eq!(LoggingMiddleware::default().query_count(), 0);
        assert_eq!(
            LoggingMiddleware::with_config(LoggingConfig::default()).query_count(),
            0
        );
    }
}