1use super::context::QueryContext;
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use std::sync::atomic::{AtomicU64, Ordering};
6
/// Verbosity threshold for query logging.
///
/// Variants are declared from least to most verbose; the derived `Ord`
/// follows declaration order, which is what the `>=` gating checks in the
/// middleware rely on.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum LogLevel {
    /// Nothing is logged.
    Off,
    /// Only failed queries are logged.
    Error,
    /// Additionally reports slow queries.
    Warn,
    /// Additionally reports query completion (the default).
    #[default]
    Info,
    /// Additionally reports the start of each query.
    Debug,
    /// Most verbose level; required for parameter and response logging.
    Trace,
}
25
26
27#[derive(Debug, Clone)]
29pub struct LoggingConfig {
30 pub level: LogLevel,
32 pub slow_query_threshold_us: u64,
34 pub log_params: bool,
36 pub log_response: bool,
38 pub max_sql_length: usize,
40 pub prefix: String,
42}
43
44impl Default for LoggingConfig {
45 fn default() -> Self {
46 Self {
47 level: LogLevel::Info,
48 slow_query_threshold_us: 1_000_000, log_params: false,
50 log_response: false,
51 max_sql_length: 500,
52 prefix: "prax".to_string(),
53 }
54 }
55}
56
57pub struct LoggingMiddleware {
70 config: LoggingConfig,
71 query_count: AtomicU64,
72}
73
74impl LoggingMiddleware {
75 pub fn new() -> Self {
77 Self {
78 config: LoggingConfig::default(),
79 query_count: AtomicU64::new(0),
80 }
81 }
82
83 pub fn with_config(config: LoggingConfig) -> Self {
85 Self {
86 config,
87 query_count: AtomicU64::new(0),
88 }
89 }
90
91 pub fn with_level(mut self, level: LogLevel) -> Self {
93 self.config.level = level;
94 self
95 }
96
97 pub fn with_params(mut self, enabled: bool) -> Self {
99 self.config.log_params = enabled;
100 self
101 }
102
103 pub fn with_response(mut self, enabled: bool) -> Self {
105 self.config.log_response = enabled;
106 self
107 }
108
109 pub fn with_slow_threshold(mut self, threshold_us: u64) -> Self {
111 self.config.slow_query_threshold_us = threshold_us;
112 self
113 }
114
115 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
117 self.config.prefix = prefix.into();
118 self
119 }
120
121 pub fn query_count(&self) -> u64 {
123 self.query_count.load(Ordering::Relaxed)
124 }
125
126 fn truncate_sql(&self, sql: &str) -> String {
127 if self.config.max_sql_length == 0 || sql.len() <= self.config.max_sql_length {
128 sql.to_string()
129 } else {
130 format!("{}...", &sql[..self.config.max_sql_length])
131 }
132 }
133
134 fn log_before(&self, ctx: &QueryContext, query_id: u64) {
135 if self.config.level < LogLevel::Debug {
136 return;
137 }
138
139 let sql = self.truncate_sql(ctx.sql());
140 let query_type = format!("{:?}", ctx.query_type());
141
142 if self.config.log_params && self.config.level >= LogLevel::Trace {
143 tracing::debug!(
144 target: "prax::query",
145 query_id = query_id,
146 query_type = %query_type,
147 sql = %sql,
148 params = ?ctx.params(),
149 model = ?ctx.metadata().model,
150 operation = ?ctx.metadata().operation,
151 request_id = ?ctx.metadata().request_id,
152 "[{}] Starting query",
153 self.config.prefix
154 );
155 } else {
156 tracing::debug!(
157 target: "prax::query",
158 query_id = query_id,
159 query_type = %query_type,
160 sql = %sql,
161 "[{}] Starting query",
162 self.config.prefix
163 );
164 }
165 }
166
167 fn log_after(&self, ctx: &QueryContext, response: &QueryResponse, query_id: u64) {
168 let duration_us = response.execution_time_us;
169 let is_slow = duration_us >= self.config.slow_query_threshold_us;
170
171 if is_slow && self.config.level >= LogLevel::Warn {
172 let sql = self.truncate_sql(ctx.sql());
173 tracing::warn!(
174 target: "prax::query",
175 query_id = query_id,
176 duration_us = duration_us,
177 duration_ms = duration_us / 1000,
178 sql = %sql,
179 threshold_us = self.config.slow_query_threshold_us,
180 "[{}] Slow query detected",
181 self.config.prefix
182 );
183 } else if self.config.level >= LogLevel::Info {
184 let sql = self.truncate_sql(ctx.sql());
185
186 if self.config.log_response && self.config.level >= LogLevel::Trace {
187 tracing::info!(
188 target: "prax::query",
189 query_id = query_id,
190 duration_us = duration_us,
191 rows_affected = ?response.rows_affected,
192 from_cache = response.from_cache,
193 sql = %sql,
194 response = ?response.data,
195 "[{}] Query completed",
196 self.config.prefix
197 );
198 } else {
199 tracing::info!(
200 target: "prax::query",
201 query_id = query_id,
202 duration_us = duration_us,
203 rows_affected = ?response.rows_affected,
204 from_cache = response.from_cache,
205 "[{}] Query completed",
206 self.config.prefix
207 );
208 }
209 }
210 }
211
212 fn log_error(&self, ctx: &QueryContext, error: &crate::QueryError, query_id: u64) {
213 if self.config.level >= LogLevel::Error {
214 let sql = self.truncate_sql(ctx.sql());
215 tracing::error!(
216 target: "prax::query",
217 query_id = query_id,
218 sql = %sql,
219 error = %error,
220 "[{}] Query failed",
221 self.config.prefix
222 );
223 }
224 }
225}
226
227impl Default for LoggingMiddleware {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
233impl Middleware for LoggingMiddleware {
234 fn handle<'a>(
235 &'a self,
236 ctx: QueryContext,
237 next: Next<'a>,
238 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
239 Box::pin(async move {
240 let query_id = self.query_count.fetch_add(1, Ordering::SeqCst);
241
242 self.log_before(&ctx, query_id);
244
245 let result = next.run(ctx.clone()).await;
247
248 match &result {
250 Ok(response) => self.log_after(&ctx, response, query_id),
251 Err(error) => self.log_error(&ctx, error, query_id),
252 }
253
254 result
255 })
256 }
257
258 fn name(&self) -> &'static str {
259 "LoggingMiddleware"
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Levels must sort from least to most verbose, because the middleware
    /// gates its output with ordering comparisons.
    #[test]
    fn test_log_level_ordering() {
        let ascending = [
            LogLevel::Error,
            LogLevel::Warn,
            LogLevel::Info,
            LogLevel::Debug,
            LogLevel::Trace,
        ];

        for pair in ascending.windows(2) {
            assert!(pair[0] < pair[1]);
        }
    }

    /// Builder methods must store their arguments in the configuration.
    #[test]
    fn test_logging_middleware_builder() {
        let built = LoggingMiddleware::new()
            .with_level(LogLevel::Debug)
            .with_params(true)
            .with_slow_threshold(500_000);

        let config = &built.config;
        assert_eq!(config.level, LogLevel::Debug);
        assert!(config.log_params);
        assert_eq!(config.slow_query_threshold_us, 500_000);
    }

    /// Short SQL is returned untouched; SQL over the limit ends with `...`.
    #[test]
    fn test_truncate_sql() {
        let short_sql = "SELECT * FROM users";
        let with_defaults = LoggingMiddleware::new();
        assert_eq!(with_defaults.truncate_sql(short_sql), short_sql);

        let long_sql = "SELECT * FROM users WHERE id = 1";
        let with_small_limit = LoggingMiddleware::with_config(LoggingConfig {
            max_sql_length: 10,
            ..Default::default()
        });
        let truncated = with_small_limit.truncate_sql(long_sql);
        assert!(truncated.ends_with("..."));
    }

    /// A freshly built middleware has handled no queries.
    #[test]
    fn test_query_count() {
        let fresh = LoggingMiddleware::new();
        assert_eq!(fresh.query_count(), 0);
    }
}