1use super::context::QueryContext;
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use std::sync::atomic::{AtomicU64, Ordering};
6
/// Verbosity threshold for [`LoggingMiddleware`].
///
/// Variants are declared from least to most verbose, and the derived
/// ordering follows that declaration order. `Off` therefore compares lower
/// than every other level, and `Trace` compares highest.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum LogLevel {
    /// Suppress all query logging.
    Off,
    /// Only failed queries.
    Error,
    /// Failures and slow queries.
    Warn,
    /// Adds a completion event for every successful query.
    Info,
    /// Adds a start event for every query.
    Debug,
    /// Most verbose; unlocks the optional parameter/response payloads.
    Trace,
}

impl Default for LogLevel {
    /// The default verbosity is [`LogLevel::Info`].
    fn default() -> Self {
        LogLevel::Info
    }
}
29
30#[derive(Debug, Clone)]
32pub struct LoggingConfig {
33 pub level: LogLevel,
35 pub slow_query_threshold_us: u64,
37 pub log_params: bool,
39 pub log_response: bool,
41 pub max_sql_length: usize,
43 pub prefix: String,
45}
46
47impl Default for LoggingConfig {
48 fn default() -> Self {
49 Self {
50 level: LogLevel::Info,
51 slow_query_threshold_us: 1_000_000, log_params: false,
53 log_response: false,
54 max_sql_length: 500,
55 prefix: "prax".to_string(),
56 }
57 }
58}
59
60pub struct LoggingMiddleware {
73 config: LoggingConfig,
74 query_count: AtomicU64,
75}
76
77impl LoggingMiddleware {
78 pub fn new() -> Self {
80 Self {
81 config: LoggingConfig::default(),
82 query_count: AtomicU64::new(0),
83 }
84 }
85
86 pub fn with_config(config: LoggingConfig) -> Self {
88 Self {
89 config,
90 query_count: AtomicU64::new(0),
91 }
92 }
93
94 pub fn with_level(mut self, level: LogLevel) -> Self {
96 self.config.level = level;
97 self
98 }
99
100 pub fn with_params(mut self, enabled: bool) -> Self {
102 self.config.log_params = enabled;
103 self
104 }
105
106 pub fn with_response(mut self, enabled: bool) -> Self {
108 self.config.log_response = enabled;
109 self
110 }
111
112 pub fn with_slow_threshold(mut self, threshold_us: u64) -> Self {
114 self.config.slow_query_threshold_us = threshold_us;
115 self
116 }
117
118 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
120 self.config.prefix = prefix.into();
121 self
122 }
123
124 pub fn query_count(&self) -> u64 {
126 self.query_count.load(Ordering::Relaxed)
127 }
128
129 fn truncate_sql(&self, sql: &str) -> String {
130 if self.config.max_sql_length == 0 || sql.len() <= self.config.max_sql_length {
131 sql.to_string()
132 } else {
133 format!("{}...", &sql[..self.config.max_sql_length])
134 }
135 }
136
137 fn log_before(&self, ctx: &QueryContext, query_id: u64) {
138 if self.config.level < LogLevel::Debug {
139 return;
140 }
141
142 let sql = self.truncate_sql(ctx.sql());
143 let query_type = format!("{:?}", ctx.query_type());
144
145 if self.config.log_params && self.config.level >= LogLevel::Trace {
146 tracing::debug!(
147 target: "prax::query",
148 query_id = query_id,
149 query_type = %query_type,
150 sql = %sql,
151 params = ?ctx.params(),
152 model = ?ctx.metadata().model,
153 operation = ?ctx.metadata().operation,
154 request_id = ?ctx.metadata().request_id,
155 "[{}] Starting query",
156 self.config.prefix
157 );
158 } else {
159 tracing::debug!(
160 target: "prax::query",
161 query_id = query_id,
162 query_type = %query_type,
163 sql = %sql,
164 "[{}] Starting query",
165 self.config.prefix
166 );
167 }
168 }
169
170 fn log_after(&self, ctx: &QueryContext, response: &QueryResponse, query_id: u64) {
171 let duration_us = response.execution_time_us;
172 let is_slow = duration_us >= self.config.slow_query_threshold_us;
173
174 if is_slow && self.config.level >= LogLevel::Warn {
175 let sql = self.truncate_sql(ctx.sql());
176 tracing::warn!(
177 target: "prax::query",
178 query_id = query_id,
179 duration_us = duration_us,
180 duration_ms = duration_us / 1000,
181 sql = %sql,
182 threshold_us = self.config.slow_query_threshold_us,
183 "[{}] Slow query detected",
184 self.config.prefix
185 );
186 } else if self.config.level >= LogLevel::Info {
187 let sql = self.truncate_sql(ctx.sql());
188
189 if self.config.log_response && self.config.level >= LogLevel::Trace {
190 tracing::info!(
191 target: "prax::query",
192 query_id = query_id,
193 duration_us = duration_us,
194 rows_affected = ?response.rows_affected,
195 from_cache = response.from_cache,
196 sql = %sql,
197 response = ?response.data,
198 "[{}] Query completed",
199 self.config.prefix
200 );
201 } else {
202 tracing::info!(
203 target: "prax::query",
204 query_id = query_id,
205 duration_us = duration_us,
206 rows_affected = ?response.rows_affected,
207 from_cache = response.from_cache,
208 "[{}] Query completed",
209 self.config.prefix
210 );
211 }
212 }
213 }
214
215 fn log_error(&self, ctx: &QueryContext, error: &crate::QueryError, query_id: u64) {
216 if self.config.level >= LogLevel::Error {
217 let sql = self.truncate_sql(ctx.sql());
218 tracing::error!(
219 target: "prax::query",
220 query_id = query_id,
221 sql = %sql,
222 error = %error,
223 "[{}] Query failed",
224 self.config.prefix
225 );
226 }
227 }
228}
229
230impl Default for LoggingMiddleware {
231 fn default() -> Self {
232 Self::new()
233 }
234}
235
236impl Middleware for LoggingMiddleware {
237 fn handle<'a>(
238 &'a self,
239 ctx: QueryContext,
240 next: Next<'a>,
241 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
242 Box::pin(async move {
243 let query_id = self.query_count.fetch_add(1, Ordering::SeqCst);
244
245 self.log_before(&ctx, query_id);
247
248 let result = next.run(ctx.clone()).await;
250
251 match &result {
253 Ok(response) => self.log_after(&ctx, response, query_id),
254 Err(error) => self.log_error(&ctx, error, query_id),
255 }
256
257 result
258 })
259 }
260
261 fn name(&self) -> &'static str {
262 "LoggingMiddleware"
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_level_ordering() {
        // `Off` must sort below every real level so `level >= X` checks disable it.
        assert!(LogLevel::Off < LogLevel::Error);
        assert!(LogLevel::Error < LogLevel::Warn);
        assert!(LogLevel::Warn < LogLevel::Info);
        assert!(LogLevel::Info < LogLevel::Debug);
        assert!(LogLevel::Debug < LogLevel::Trace);
    }

    #[test]
    fn test_log_level_default() {
        assert_eq!(LogLevel::default(), LogLevel::Info);
    }

    #[test]
    fn test_default_config() {
        let config = LoggingConfig::default();
        assert_eq!(config.level, LogLevel::Info);
        assert_eq!(config.slow_query_threshold_us, 1_000_000);
        assert!(!config.log_params);
        assert!(!config.log_response);
        assert_eq!(config.max_sql_length, 500);
        assert_eq!(config.prefix, "prax");
    }

    #[test]
    fn test_logging_middleware_builder() {
        let middleware = LoggingMiddleware::new()
            .with_level(LogLevel::Debug)
            .with_params(true)
            .with_response(true)
            .with_slow_threshold(500_000)
            .with_prefix("app");

        assert_eq!(middleware.config.level, LogLevel::Debug);
        assert!(middleware.config.log_params);
        assert!(middleware.config.log_response);
        assert_eq!(middleware.config.slow_query_threshold_us, 500_000);
        assert_eq!(middleware.config.prefix, "app");
    }

    #[test]
    fn test_truncate_sql() {
        let middleware = LoggingMiddleware::new();

        let short = "SELECT * FROM users";
        assert_eq!(middleware.truncate_sql(short), short);

        let config = LoggingConfig {
            max_sql_length: 10,
            ..Default::default()
        };
        let middleware = LoggingMiddleware::with_config(config);

        let long = "SELECT * FROM users WHERE id = 1";
        assert_eq!(middleware.truncate_sql(long), "SELECT * F...");

        // Exactly at the limit: left untouched, no ellipsis.
        assert_eq!(middleware.truncate_sql("SELECT * F"), "SELECT * F");
    }

    #[test]
    fn test_truncate_sql_disabled() {
        let config = LoggingConfig {
            max_sql_length: 0,
            ..Default::default()
        };
        let middleware = LoggingMiddleware::with_config(config);

        let long = "SELECT ".repeat(200);
        assert_eq!(middleware.truncate_sql(&long), long);
    }

    #[test]
    fn test_query_count() {
        let middleware = LoggingMiddleware::new();
        assert_eq!(middleware.query_count(), 0);
    }
}