1use parking_lot::RwLock;
10use std::sync::Arc;
11use std::time::Instant;
12use tracing::info_span;
13
/// Settings that control how [`DbTracer`] records database activity.
#[derive(Debug, Clone)]
pub struct DbTraceConfig {
    /// Logical service name associated with these traces (default `"kaccy-db"`).
    pub service_name: String,
    /// Whether the SQL text of each query is logged.
    pub include_statements: bool,
    /// Whether SQL text is sanitized before it is logged.
    pub sanitize_params: bool,
    /// Maximum number of bytes of SQL text to log before truncating.
    pub max_statement_length: usize,
    /// Queries slower than this many milliseconds produce a warning; `None` disables the check.
    pub slow_query_threshold_ms: Option<u64>,
}

impl Default for DbTraceConfig {
    /// Logs sanitized statements of up to 1000 bytes and warns on queries over 100 ms.
    fn default() -> Self {
        let service_name = String::from("kaccy-db");
        DbTraceConfig {
            service_name,
            include_statements: true,
            sanitize_params: true,
            max_statement_length: 1_000,
            slow_query_threshold_ms: Some(100),
        }
    }
}
40
41pub struct DbTracer {
43 config: DbTraceConfig,
44}
45
46impl DbTracer {
47 pub fn new(config: DbTraceConfig) -> Self {
49 Self { config }
50 }
51
52 pub fn trace_query<F, R>(&self, operation: &str, table: &str, sql: &str, f: F) -> R
54 where
55 F: FnOnce() -> R,
56 {
57 let start = Instant::now();
58
59 let span = info_span!(
60 "db.query",
61 db.system = "postgresql",
62 db.operation = operation,
63 db.sql.table = table,
64 );
65
66 let _enter = span.enter();
67
68 if self.config.include_statements {
69 let sanitized = self.sanitize_sql(sql);
70 tracing::info!(db.statement = %sanitized, "Executing database query");
71 }
72
73 let result = f();
74
75 let duration_ms = start.elapsed().as_millis() as u64;
76
77 if let Some(threshold) = self.config.slow_query_threshold_ms {
78 if duration_ms > threshold {
79 tracing::warn!(duration_ms, operation, table, "Slow query detected");
80 }
81 }
82
83 tracing::info!(duration_ms, "Query completed");
84
85 result
86 }
87
88 pub fn sanitize_sql(&self, sql: &str) -> String {
90 if !self.config.sanitize_params {
91 return self.truncate_sql(sql);
92 }
93
94 let sanitized = sql
96 .replace(['\'', '"'], "?")
97 .lines()
98 .map(|line| line.trim())
99 .collect::<Vec<_>>()
100 .join(" ");
101
102 self.truncate_sql(&sanitized)
103 }
104
105 fn truncate_sql(&self, sql: &str) -> String {
107 if sql.len() <= self.config.max_statement_length {
108 sql.to_string()
109 } else {
110 format!("{}...", &sql[..self.config.max_statement_length])
111 }
112 }
113}
114
/// Types that can run a closure while recording it as a traced database operation.
pub trait TraceableOperation {
    /// Executes `f`, describing it as `operation` on `table` with statement `sql`,
    /// and returns the value produced by `f`.
    fn with_trace<F: FnOnce() -> R, R>(
        &self,
        operation: &str,
        table: &str,
        sql: &str,
        f: F,
    ) -> R;
}
122
123impl TraceableOperation for DbTracer {
124 fn with_trace<F, R>(&self, operation: &str, table: &str, sql: &str, f: F) -> R
125 where
126 F: FnOnce() -> R,
127 {
128 self.trace_query(operation, table, sql, f)
129 }
130}
131
/// A W3C Trace Context span context that can be converted to and from a
/// `traceparent` header value.
#[derive(Debug, Clone)]
pub struct SpanContext {
    /// Trace identifier (32 lowercase hex characters in a valid `traceparent`).
    pub trace_id: String,
    /// Identifier of this span (16 lowercase hex characters in a valid `traceparent`).
    pub span_id: String,
    /// Identifier of this span's parent, when known. It is not transmitted in `traceparent`.
    pub parent_span_id: Option<String>,
    /// W3C trace-flags byte. Bit `0x01` is the "sampled" flag.
    pub trace_flags: u8,
}

impl SpanContext {
    /// Creates a context with no parent and the sampled flag (`0x01`) set.
    pub fn new(trace_id: String, span_id: String) -> Self {
        Self {
            trace_id,
            span_id,
            parent_span_id: None,
            trace_flags: 0x01,
        }
    }

    /// Records `parent_span_id` as the parent of this span (builder style).
    pub fn with_parent(mut self, parent_span_id: String) -> Self {
        self.parent_span_id = Some(parent_span_id);
        self
    }

    /// Serializes this context as a version-`00` `traceparent` value:
    /// `00-{trace_id}-{span_id}-{flags:02x}`.
    ///
    /// The W3C parent-id field must identify the span that makes the outgoing
    /// request, which is this span. That is why `span_id` is emitted even when
    /// `parent_span_id` is set, so the receiving service records this span as
    /// its parent.
    pub fn to_traceparent(&self) -> String {
        format!("00-{}-{}-{:02x}", self.trace_id, self.span_id, self.trace_flags)
    }

    /// Parses a version-`00` `traceparent` header value.
    ///
    /// Returns `None` unless `traceparent` consists of exactly four
    /// `-`-separated fields:
    /// * version `00`;
    /// * a 32-character trace-id;
    /// * a 16-character parent-id;
    /// * a 2-character flags field.
    ///
    /// All fields must be lowercase hex, and neither id may consist only of
    /// zeros (the W3C spec treats such ids as invalid). The parsed parent-id
    /// becomes `span_id`, and `parent_span_id` is left `None`.
    pub fn from_traceparent(traceparent: &str) -> Option<Self> {
        let mut fields = traceparent.split('-');
        let version = fields.next()?;
        let trace_id = fields.next()?;
        let span_id = fields.next()?;
        let flags = fields.next()?;

        if fields.next().is_some() || version != "00" {
            return None;
        }
        if !Self::is_lower_hex(trace_id, 32)
            || !Self::is_lower_hex(span_id, 16)
            || !Self::is_lower_hex(flags, 2)
        {
            return None;
        }
        if Self::is_all_zeros(trace_id) || Self::is_all_zeros(span_id) {
            return None;
        }

        // Cannot fail after the hex check above, but stay panic-free anyway.
        let trace_flags = u8::from_str_radix(flags, 16).ok()?;

        Some(Self {
            trace_id: trace_id.to_string(),
            span_id: span_id.to_string(),
            parent_span_id: None,
            trace_flags,
        })
    }

    /// True when `s` is exactly `len` bytes of lowercase hexadecimal digits.
    fn is_lower_hex(s: &str, len: usize) -> bool {
        s.len() == len && s.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
    }

    /// True when every character of `s` is `'0'`.
    fn is_all_zeros(s: &str) -> bool {
        s.bytes().all(|b| b == b'0')
    }
}
183
/// Outcome summary of a single executed database query.
#[derive(Debug, Clone)]
pub struct QueryMetadata {
    /// Operation name, e.g. `"SELECT"`.
    pub operation: String,
    /// Table the query targeted.
    pub table: String,
    /// Number of rows affected, when recorded via [`QueryMetadata::with_rows_affected`].
    pub rows_affected: Option<u64>,
    /// Execution time in milliseconds.
    pub duration_ms: u64,
    /// `true` for queries built with [`QueryMetadata::success`].
    pub success: bool,
    /// Error message for queries built with [`QueryMetadata::failure`].
    pub error: Option<String>,
}

impl QueryMetadata {
    /// Metadata for a query that completed without error.
    pub fn success(operation: String, table: String, duration_ms: u64) -> Self {
        Self::from_outcome(operation, table, duration_ms, None)
    }

    /// Metadata for a query that failed with the message `error`.
    pub fn failure(operation: String, table: String, duration_ms: u64, error: String) -> Self {
        Self::from_outcome(operation, table, duration_ms, Some(error))
    }

    /// Attaches the affected-row count (builder style).
    pub fn with_rows_affected(self, rows: u64) -> Self {
        Self {
            rows_affected: Some(rows),
            ..self
        }
    }

    /// Shared constructor: a query counts as successful exactly when it has no error.
    fn from_outcome(
        operation: String,
        table: String,
        duration_ms: u64,
        error: Option<String>,
    ) -> Self {
        Self {
            success: error.is_none(),
            operation,
            table,
            rows_affected: None,
            duration_ms,
            error,
        }
    }
}
232
233static TRACER: RwLock<Option<Arc<DbTracer>>> = RwLock::new(None);
235
236pub fn init_tracer(config: DbTraceConfig) {
238 *TRACER.write() = Some(Arc::new(DbTracer::new(config)));
239}
240
241pub fn get_tracer() -> Option<Arc<DbTracer>> {
243 TRACER.read().as_ref().map(Arc::clone)
244}
245
246pub fn trace_db_operation<F, R>(operation: &str, table: &str, sql: &str, f: F) -> R
248where
249 F: FnOnce() -> R,
250{
251 if let Some(tracer) = get_tracer() {
252 tracer.trace_query(operation, table, sql, f)
253 } else {
254 f()
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample values taken from the W3C Trace Context specification examples.
    const SAMPLE_TRACE_ID: &str = "0af7651916cd43dd8448eb211c80319c";
    const SAMPLE_SPAN_ID: &str = "b7ad6b7169203331";
    const SAMPLE_TRACEPARENT: &str = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";

    #[test]
    fn test_db_trace_config_default() {
        let DbTraceConfig {
            service_name,
            include_statements,
            sanitize_params,
            max_statement_length,
            slow_query_threshold_ms,
        } = DbTraceConfig::default();

        assert_eq!(service_name, "kaccy-db");
        assert!(include_statements);
        assert!(sanitize_params);
        assert_eq!(max_statement_length, 1000);
        assert_eq!(slow_query_threshold_ms, Some(100));
    }

    #[test]
    fn test_sanitize_sql() {
        let tracer = DbTracer::new(DbTraceConfig::default());
        let output = tracer.sanitize_sql("SELECT * FROM users WHERE email = 'test@example.com'");

        // Default limit (1000 bytes) plus the "..." truncation suffix.
        assert!(output.len() <= 1000 + 3);
    }

    #[test]
    fn test_truncate_long_sql() {
        let tracer = DbTracer::new(DbTraceConfig {
            max_statement_length: 50,
            ..Default::default()
        });
        let statement = "SELECT * FROM users WHERE id = 1 AND name = 'test' AND email = 'test@example.com' AND age > 18";

        let output = tracer.sanitize_sql(statement);

        // 50-byte limit plus the "..." suffix.
        assert!(output.len() <= 53);
        assert!(output.ends_with("..."));
    }

    #[test]
    fn test_span_context_to_traceparent() {
        let ctx = SpanContext::new(SAMPLE_TRACE_ID.to_string(), SAMPLE_SPAN_ID.to_string());
        assert_eq!(ctx.to_traceparent(), SAMPLE_TRACEPARENT);
    }

    #[test]
    fn test_span_context_from_traceparent() {
        let ctx = SpanContext::from_traceparent(SAMPLE_TRACEPARENT).unwrap();

        assert_eq!(ctx.trace_id, SAMPLE_TRACE_ID);
        assert_eq!(ctx.span_id, SAMPLE_SPAN_ID);
        assert_eq!(ctx.trace_flags, 1);
    }

    #[test]
    fn test_span_context_with_parent() {
        let ctx = SpanContext::new(String::from("trace123"), String::from("span456"))
            .with_parent(String::from("parent789"));

        assert_eq!(ctx.parent_span_id.as_deref(), Some("parent789"));
    }

    #[test]
    fn test_query_metadata_success() {
        let meta = QueryMetadata::success(String::from("SELECT"), String::from("users"), 50);

        assert!(meta.success);
        assert_eq!(meta.operation, "SELECT");
        assert_eq!(meta.table, "users");
        assert_eq!(meta.duration_ms, 50);
        assert!(meta.error.is_none());
    }

    #[test]
    fn test_query_metadata_failure() {
        let meta = QueryMetadata::failure(
            String::from("INSERT"),
            String::from("users"),
            100,
            String::from("Constraint violation"),
        );

        assert!(!meta.success);
        assert_eq!(meta.error.as_deref(), Some("Constraint violation"));
    }

    #[test]
    fn test_query_metadata_with_rows_affected() {
        let meta = QueryMetadata::success(String::from("UPDATE"), String::from("users"), 75)
            .with_rows_affected(5);

        assert_eq!(meta.rows_affected, Some(5));
    }

    #[test]
    fn test_trace_db_operation_without_init() {
        // No test installs a global tracer, so the closure runs untraced.
        let value = trace_db_operation("SELECT", "users", "SELECT * FROM users", || 42);
        assert_eq!(value, 42);
    }

    #[test]
    fn test_tracer_with_trace() {
        let tracer = DbTracer::new(DbTraceConfig::default());

        let value = TraceableOperation::with_trace(
            &tracer,
            "SELECT",
            "users",
            "SELECT * FROM users",
            || "success",
        );
        assert_eq!(value, "success");
    }
}
379}