1use super::context::QueryContext;
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use std::sync::atomic::{AtomicU64, Ordering};
6
/// Verbosity threshold used by [`LoggingMiddleware`].
///
/// Variants are declared from least to most verbose, and the derived `Ord`
/// follows declaration order. A check such as `level >= LogLevel::Warn` therefore
/// reads as "at least as verbose as `Warn`". `Off` sorts below every other level.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub enum LogLevel {
    /// Emit nothing.
    Off,
    /// Only failed queries.
    Error,
    /// Failures plus slow-query warnings.
    Warn,
    /// Adds a completion event for queries that are not slow (the default).
    #[default]
    Info,
    /// Adds a "starting query" event before execution.
    Debug,
    /// Most verbose; allows parameter and response payloads when enabled.
    Trace,
}
24
25#[derive(Debug, Clone)]
27pub struct LoggingConfig {
28 pub level: LogLevel,
30 pub slow_query_threshold_us: u64,
32 pub log_params: bool,
34 pub log_response: bool,
36 pub max_sql_length: usize,
38 pub prefix: String,
40}
41
42impl Default for LoggingConfig {
43 fn default() -> Self {
44 Self {
45 level: LogLevel::Info,
46 slow_query_threshold_us: 1_000_000, log_params: false,
48 log_response: false,
49 max_sql_length: 500,
50 prefix: "prax".to_string(),
51 }
52 }
53}
54
55pub struct LoggingMiddleware {
68 config: LoggingConfig,
69 query_count: AtomicU64,
70}
71
72impl LoggingMiddleware {
73 pub fn new() -> Self {
75 Self {
76 config: LoggingConfig::default(),
77 query_count: AtomicU64::new(0),
78 }
79 }
80
81 pub fn with_config(config: LoggingConfig) -> Self {
83 Self {
84 config,
85 query_count: AtomicU64::new(0),
86 }
87 }
88
89 pub fn with_level(mut self, level: LogLevel) -> Self {
91 self.config.level = level;
92 self
93 }
94
95 pub fn with_params(mut self, enabled: bool) -> Self {
97 self.config.log_params = enabled;
98 self
99 }
100
101 pub fn with_response(mut self, enabled: bool) -> Self {
103 self.config.log_response = enabled;
104 self
105 }
106
107 pub fn with_slow_threshold(mut self, threshold_us: u64) -> Self {
109 self.config.slow_query_threshold_us = threshold_us;
110 self
111 }
112
113 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
115 self.config.prefix = prefix.into();
116 self
117 }
118
119 pub fn query_count(&self) -> u64 {
121 self.query_count.load(Ordering::Relaxed)
122 }
123
124 fn truncate_sql(&self, sql: &str) -> String {
125 if self.config.max_sql_length == 0 || sql.len() <= self.config.max_sql_length {
126 sql.to_string()
127 } else {
128 format!("{}...", &sql[..self.config.max_sql_length])
129 }
130 }
131
132 fn log_before(&self, ctx: &QueryContext, query_id: u64) {
133 if self.config.level < LogLevel::Debug {
134 return;
135 }
136
137 let sql = self.truncate_sql(ctx.sql());
138 let query_type = format!("{:?}", ctx.query_type());
139
140 if self.config.log_params && self.config.level >= LogLevel::Trace {
141 tracing::debug!(
142 target: "prax::query",
143 query_id = query_id,
144 query_type = %query_type,
145 sql = %sql,
146 params = ?ctx.params(),
147 model = ?ctx.metadata().model,
148 operation = ?ctx.metadata().operation,
149 request_id = ?ctx.metadata().request_id,
150 "[{}] Starting query",
151 self.config.prefix
152 );
153 } else {
154 tracing::debug!(
155 target: "prax::query",
156 query_id = query_id,
157 query_type = %query_type,
158 sql = %sql,
159 "[{}] Starting query",
160 self.config.prefix
161 );
162 }
163 }
164
165 fn log_after(&self, ctx: &QueryContext, response: &QueryResponse, query_id: u64) {
166 let duration_us = response.execution_time_us;
167 let is_slow = duration_us >= self.config.slow_query_threshold_us;
168
169 if is_slow && self.config.level >= LogLevel::Warn {
170 let sql = self.truncate_sql(ctx.sql());
171 tracing::warn!(
172 target: "prax::query",
173 query_id = query_id,
174 duration_us = duration_us,
175 duration_ms = duration_us / 1000,
176 sql = %sql,
177 threshold_us = self.config.slow_query_threshold_us,
178 "[{}] Slow query detected",
179 self.config.prefix
180 );
181 } else if self.config.level >= LogLevel::Info {
182 let sql = self.truncate_sql(ctx.sql());
183
184 if self.config.log_response && self.config.level >= LogLevel::Trace {
185 tracing::info!(
186 target: "prax::query",
187 query_id = query_id,
188 duration_us = duration_us,
189 rows_affected = ?response.rows_affected,
190 from_cache = response.from_cache,
191 sql = %sql,
192 response = ?response.data,
193 "[{}] Query completed",
194 self.config.prefix
195 );
196 } else {
197 tracing::info!(
198 target: "prax::query",
199 query_id = query_id,
200 duration_us = duration_us,
201 rows_affected = ?response.rows_affected,
202 from_cache = response.from_cache,
203 "[{}] Query completed",
204 self.config.prefix
205 );
206 }
207 }
208 }
209
210 fn log_error(&self, ctx: &QueryContext, error: &crate::QueryError, query_id: u64) {
211 if self.config.level >= LogLevel::Error {
212 let sql = self.truncate_sql(ctx.sql());
213 tracing::error!(
214 target: "prax::query",
215 query_id = query_id,
216 sql = %sql,
217 error = %error,
218 "[{}] Query failed",
219 self.config.prefix
220 );
221 }
222 }
223}
224
225impl Default for LoggingMiddleware {
226 fn default() -> Self {
227 Self::new()
228 }
229}
230
231impl Middleware for LoggingMiddleware {
232 fn handle<'a>(
233 &'a self,
234 ctx: QueryContext,
235 next: Next<'a>,
236 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
237 Box::pin(async move {
238 let query_id = self.query_count.fetch_add(1, Ordering::SeqCst);
239
240 self.log_before(&ctx, query_id);
242
243 let result = next.run(ctx.clone()).await;
245
246 match &result {
248 Ok(response) => self.log_after(&ctx, response, query_id),
249 Err(error) => self.log_error(&ctx, error, query_id),
250 }
251
252 result
253 })
254 }
255
256 fn name(&self) -> &'static str {
257 "LoggingMiddleware"
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_level_ordering() {
        assert!(LogLevel::Off < LogLevel::Error);
        assert!(LogLevel::Error < LogLevel::Warn);
        assert!(LogLevel::Warn < LogLevel::Info);
        assert!(LogLevel::Info < LogLevel::Debug);
        assert!(LogLevel::Debug < LogLevel::Trace);
    }

    #[test]
    fn test_log_level_default_is_info() {
        assert_eq!(LogLevel::default(), LogLevel::Info);
    }

    #[test]
    fn test_logging_config_defaults() {
        let config = LoggingConfig::default();
        assert_eq!(config.level, LogLevel::Info);
        assert_eq!(config.slow_query_threshold_us, 1_000_000);
        assert!(!config.log_params);
        assert!(!config.log_response);
        assert_eq!(config.max_sql_length, 500);
        assert_eq!(config.prefix, "prax");
    }

    #[test]
    fn test_logging_middleware_builder() {
        let middleware = LoggingMiddleware::new()
            .with_level(LogLevel::Debug)
            .with_params(true)
            .with_response(true)
            .with_slow_threshold(500_000)
            .with_prefix("app");

        assert_eq!(middleware.config.level, LogLevel::Debug);
        assert!(middleware.config.log_params);
        assert!(middleware.config.log_response);
        assert_eq!(middleware.config.slow_query_threshold_us, 500_000);
        assert_eq!(middleware.config.prefix, "app");
    }

    #[test]
    fn test_truncate_sql() {
        let middleware = LoggingMiddleware::new();

        let short = "SELECT * FROM users";
        assert_eq!(middleware.truncate_sql(short), short);

        let config = LoggingConfig {
            max_sql_length: 10,
            ..Default::default()
        };
        let middleware = LoggingMiddleware::with_config(config);
        let long = "SELECT * FROM users WHERE id = 1";
        assert_eq!(middleware.truncate_sql(long), "SELECT * F...");
    }

    #[test]
    fn test_truncate_sql_zero_limit_disables_truncation() {
        let config = LoggingConfig {
            max_sql_length: 0,
            ..Default::default()
        };
        let middleware = LoggingMiddleware::with_config(config);
        let long = "SELECT * FROM users WHERE id = 1";
        assert_eq!(middleware.truncate_sql(long), long);
    }

    #[test]
    fn test_query_count() {
        let middleware = LoggingMiddleware::new();
        assert_eq!(middleware.query_count(), 0);
        assert_eq!(LoggingMiddleware::default().query_count(), 0);
    }
}