Skip to main content

mssql_client/
instrumentation.rs

1//! OpenTelemetry instrumentation for database operations.
2//!
3//! This module provides first-class OpenTelemetry tracing support when the
4//! `otel` feature is enabled. It follows the OpenTelemetry semantic conventions
5//! for database operations.
6//!
7//! ## Features
8//!
9//! When the `otel` feature is enabled, the following instrumentation is available:
10//!
11//! - **Connection spans**: Track connection establishment time and success/failure
12//! - **Query spans**: Track SQL execution with sanitized statement attributes
13//! - **Transaction spans**: Track transaction boundaries (begin, commit, rollback)
14//! - **Error events**: Record errors with appropriate attributes
15//!
16//! ## Usage
17//!
18//! Build the driver with the `otel` feature (`cargo add mssql-client --features
19//! otel`). Spans and metrics are then emitted automatically for connections,
20//! queries, executes, transactions, and pool operations; with the feature off
21//! the instrumentation compiles to no-ops at zero cost.
22//!
23//! This crate emits OpenTelemetry telemetry but does not configure an exporter —
24//! install a tracer/meter provider in your application using the
25//! `opentelemetry`, `opentelemetry_sdk`, and an exporter crate (e.g.
26//! `opentelemetry-otlp` to an OTLP collector such as Jaeger), then drive the
27//! `tracing` <-> OpenTelemetry bridge with `tracing-opentelemetry`. See those
28//! crates' docs for provider setup.
29//!
30//! ## Semantic Conventions
31//!
32//! Follows OpenTelemetry database semantic conventions:
33//! - `db.system`: "mssql"
34//! - `db.name`: Database name
35//! - `db.statement`: SQL statement (sanitized if configured)
36//! - `db.operation`: Query operation type (SELECT, INSERT, etc.)
37//! - `server.address`: Server hostname
38//! - `server.port`: Server port
39//!
40//! `db.statement` is sanitized by default ([`SanitizationConfig`]) so literal
41//! parameter values are replaced with placeholders before being recorded; opt
42//! out with [`SanitizationConfig::no_sanitization`] only when capturing raw SQL
43//! is acceptable.
44//!
45//! ## Troubleshooting
46//!
47//! - **No spans appear** — confirm the `otel` feature is enabled and a tracer
48//!   provider is installed before the first operation.
49//! - **`db.rows_affected` missing** — it is only recorded on mutating
50//!   statements, not on `SELECT`.
51//! - **High attribute cardinality** — keep sanitization on and avoid adding
52//!   per-row custom attributes.
53
54#[cfg(feature = "otel")]
55use opentelemetry::{
56    KeyValue, global,
57    trace::{Span, SpanKind, Status, Tracer},
58};
59
/// Value reported under the `db.system` attribute to identify SQL Server.
pub const DB_SYSTEM: &str = "mssql";
62
/// Names given to the spans this driver creates, all prefixed with `mssql.`.
pub mod span_names {
    // --- Connection lifecycle ---
    /// Establishing a connection to the server.
    pub const CONNECT: &str = "mssql.connect";

    // --- Statement execution ---
    /// Running a statement that returns rows.
    pub const QUERY: &str = "mssql.query";
    /// Running a command (statement executed for its effect).
    pub const EXECUTE: &str = "mssql.execute";

    // --- Transaction control ---
    /// Starting a transaction.
    pub const BEGIN_TRANSACTION: &str = "mssql.begin_transaction";
    /// Committing a transaction.
    pub const COMMIT: &str = "mssql.commit";
    /// Rolling a transaction back.
    pub const ROLLBACK: &str = "mssql.rollback";
    /// Creating or using a savepoint.
    pub const SAVEPOINT: &str = "mssql.savepoint";

    // --- Bulk operations ---
    /// Bulk-inserting rows.
    pub const BULK_INSERT: &str = "mssql.bulk_insert";
}
82
/// Attribute keys recorded on spans and metrics, following the OpenTelemetry
/// semantic conventions where one exists.
pub mod attributes {
    // --- Database identity ---
    /// Database product identifier (`db.system`).
    pub const DB_SYSTEM: &str = "db.system";
    /// Name of the database being accessed.
    pub const DB_NAME: &str = "db.name";

    // --- Statement details ---
    /// SQL text of the statement, after any configured sanitization.
    pub const DB_STATEMENT: &str = "db.statement";
    /// Operation keyword, e.g. `SELECT` or `INSERT`.
    pub const DB_OPERATION: &str = "db.operation";
    /// Row count reported for the operation.
    pub const DB_ROWS_AFFECTED: &str = "db.rows_affected";

    // --- Server endpoint ---
    /// Host name of the server.
    pub const SERVER_ADDRESS: &str = "server.address";
    /// TCP port of the server.
    pub const SERVER_PORT: &str = "server.port";

    // --- Driver-specific ---
    /// Isolation level of the transaction (SQL Server specific key).
    pub const DB_ISOLATION_LEVEL: &str = "db.mssql.isolation_level";
    /// Identifier of the connection.
    pub const DB_CONNECTION_ID: &str = "db.connection_id";

    // --- Errors ---
    /// Classification of an error that occurred.
    pub const ERROR_TYPE: &str = "error.type";
}
106
/// Controls how SQL text is prepared before it is recorded as `db.statement`.
///
/// The default turns sanitization on, limits recorded text to 2048 bytes and
/// uses `?` as the placeholder.
#[derive(Debug, Clone)]
pub struct SanitizationConfig {
    /// When `true`, literal values are replaced with `placeholder`.
    pub enabled: bool,
    /// Maximum recorded statement length in bytes; longer text is truncated.
    pub max_length: usize,
    /// Text substituted for each replaced literal.
    pub placeholder: String,
}

impl Default for SanitizationConfig {
    fn default() -> Self {
        let placeholder = String::from("?");
        Self {
            placeholder,
            max_length: 2048,
            enabled: true,
        }
    }
}
127
128impl SanitizationConfig {
129    /// Create a configuration that doesn't sanitize statements.
130    #[must_use]
131    pub fn no_sanitization() -> Self {
132        Self {
133            enabled: false,
134            max_length: usize::MAX,
135            placeholder: String::new(),
136        }
137    }
138
139    /// Sanitize a SQL statement according to the configuration.
140    #[must_use]
141    pub fn sanitize(&self, sql: &str) -> String {
142        if !self.enabled {
143            return truncate_string(sql, self.max_length);
144        }
145
146        // Simple sanitization: replace string literals and numbers
147        let sanitized = sanitize_sql(sql, &self.placeholder);
148        truncate_string(&sanitized, self.max_length)
149    }
150}
151
/// Sanitize SQL by replacing literal values with placeholders.
///
/// The following are each replaced by one `placeholder`:
/// - Every single- or double-quoted literal. This includes doubled-quote
///   escapes such as `'O''Brien'` and an unterminated trailing literal.
/// - Every numeric literal: integers, decimals (`1.5`), exponent forms
///   (`1.5E10`, `2e-3`) and T-SQL binary constants (`0xDEADBEEF`).
///
/// Digits that continue an identifier (`table1`, `@p2`) are left untouched.
///
/// This is a lexical approximation, not a SQL parser. Comments are not
/// recognised, so a quote inside a comment hides the rest of the statement.
/// Double-quoted identifiers are also replaced. Both cases err on the side of
/// hiding data.
fn sanitize_sql(sql: &str, placeholder: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // Whether the last character copied from `sql` can be part of an
    // identifier. This is tracked on the input rather than on `result`, so an
    // alphanumeric placeholder cannot make a following digit look like part of
    // an identifier (which would leak it).
    let mut prev_is_ident = false;

    while let Some(c) = chars.next() {
        if c == '\'' || c == '"' {
            skip_quoted_literal(c, &mut chars);
            result.push_str(placeholder);
            prev_is_ident = false;
        } else if c.is_ascii_digit() && !prev_is_ident {
            skip_numeric_literal(c, &mut chars);
            result.push_str(placeholder);
            prev_is_ident = false;
        } else {
            result.push(c);
            prev_is_ident = c.is_alphanumeric() || c == '_';
        }
    }

    result
}

/// Advance `chars` past the body and closing quote of a literal opened by
/// `quote`, which has already been consumed.
///
/// A doubled quote is an escaped quote and does not end the literal. An
/// unterminated literal runs to the end of the input.
fn skip_quoted_literal(quote: char, chars: &mut std::iter::Peekable<std::str::Chars<'_>>) {
    while let Some(c) = chars.next() {
        if c == quote {
            if chars.peek() == Some(&quote) {
                chars.next();
            } else {
                return;
            }
        }
    }
}

/// Advance `chars` past the rest of a numeric literal whose first digit,
/// `first`, has already been consumed.
fn skip_numeric_literal(first: char, chars: &mut std::iter::Peekable<std::str::Chars<'_>>) {
    // T-SQL binary constant: `0x` followed by hex digits. Without this branch
    // the hex payload after `x` would be copied through verbatim.
    if first == '0' && matches!(chars.peek(), Some('x' | 'X')) {
        chars.next();
        while chars.peek().is_some_and(char::is_ascii_hexdigit) {
            chars.next();
        }
        return;
    }

    // Integer / decimal part.
    while chars.peek().is_some_and(|ch| ch.is_ascii_digit() || *ch == '.') {
        chars.next();
    }

    // Optional exponent: `e`/`E`, an optional sign, then at least one digit.
    // Probe on a clone first so text such as `3END` is not taken as an exponent.
    let mut probe = chars.clone();
    if matches!(probe.next(), Some('e' | 'E')) {
        let mut next = probe.next();
        if matches!(next, Some('+' | '-')) {
            next = probe.next();
        }
        if next.is_some_and(|ch| ch.is_ascii_digit()) {
            while probe.peek().is_some_and(char::is_ascii_digit) {
                probe.next();
            }
            *chars = probe;
        }
    }
}
202
/// Truncate `s` so that the result is at most `max_len` bytes long.
///
/// When truncation is needed, an ellipsis (`...`) is appended if it fits within
/// `max_len`. The cut point is moved back to the nearest UTF-8 character
/// boundary, because slicing in the middle of a multi-byte character would panic.
fn truncate_string(s: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if s.len() <= max_len {
        return s.to_string();
    }

    // Reserve room for the ellipsis only when the limit can hold it; otherwise
    // the result would be longer than `max_len`.
    let (budget, suffix) = if max_len >= ELLIPSIS.len() {
        (max_len - ELLIPSIS.len(), ELLIPSIS)
    } else {
        (max_len, "")
    };

    // `budget < s.len()` here and index 0 is always a boundary, so this terminates.
    let mut end = budget;
    while !s.is_char_boundary(end) {
        end -= 1;
    }

    let mut out = String::with_capacity(end + suffix.len());
    out.push_str(&s[..end]);
    out.push_str(suffix);
    out
}
211
/// Extract the operation type from a SQL statement.
///
/// Only the leading keyword of the trimmed statement is inspected, and it is
/// compared ASCII-case-insensitively, so the (possibly very large) statement
/// text is never copied. `BEGIN TRAN`/`BEGIN TRANSACTION` is recognised with
/// any whitespace between the two words. Returns `"OTHER"` when no known
/// keyword matches.
#[must_use]
pub fn extract_operation(sql: &str) -> &'static str {
    // Leading keyword -> reported operation. None of these keywords is a prefix
    // of another, so the order does not matter. "EXEC" also covers "EXECUTE".
    const KEYWORDS: &[(&str, &str)] = &[
        ("SELECT", "SELECT"),
        ("INSERT", "INSERT"),
        ("UPDATE", "UPDATE"),
        ("DELETE", "DELETE"),
        ("EXEC", "EXECUTE"),
        ("COMMIT", "COMMIT"),
        ("ROLLBACK", "ROLLBACK"),
        ("CREATE", "CREATE"),
        ("ALTER", "ALTER"),
        ("DROP", "DROP"),
    ];

    let sql = sql.trim();

    // A bare `BEGIN` (statement block) is not a transaction start.
    let is_begin_tran = strip_prefix_ignore_ascii_case(sql, "BEGIN")
        .filter(|rest| rest.starts_with(char::is_whitespace))
        .and_then(|rest| strip_prefix_ignore_ascii_case(rest.trim_start(), "TRAN"))
        .is_some();
    if is_begin_tran {
        return "BEGIN";
    }

    KEYWORDS
        .iter()
        .find(|(keyword, _)| strip_prefix_ignore_ascii_case(sql, keyword).is_some())
        .map_or("OTHER", |&(_, operation)| operation)
}

/// Return the remainder of `s` after `prefix` when `s` starts with `prefix`,
/// compared ASCII-case-insensitively. Returns `None` otherwise.
fn strip_prefix_ignore_ascii_case<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
    let head = s.as_bytes().get(..prefix.len())?;
    if head.eq_ignore_ascii_case(prefix.as_bytes()) {
        // `get` rather than indexing, so a non-boundary split yields None instead of panicking.
        s.get(prefix.len()..)
    } else {
        None
    }
}
243
/// Instrumentation context for database operations.
///
/// Holds the following:
/// - the per-connection span attributes (server endpoint and optional database
///   name);
/// - the [`SanitizationConfig`] applied to recorded SQL;
/// - the connection's operation metrics.
#[cfg(feature = "otel")]
#[derive(Clone)]
pub struct InstrumentationContext {
    /// Host name of the server, recorded as `server.address`.
    pub server_address: String,
    /// TCP port of the server, recorded as `server.port`.
    pub server_port: u16,
    /// Database name, recorded as `db.name` when present.
    pub database: Option<String>,
    /// Rules used to prepare SQL text before it is recorded.
    pub sanitization: SanitizationConfig,
    /// Operation-level metrics (duration histogram, operation/error counters)
    /// bound to this connection's server attributes. Shared through `Arc`, so
    /// clones of the context reuse the same instruments.
    metrics: std::sync::Arc<DatabaseMetrics>,
}

#[cfg(feature = "otel")]
impl std::fmt::Debug for InstrumentationContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `metrics` wraps OpenTelemetry instrument handles, so it is omitted
        // and the output is marked non-exhaustive instead.
        let Self {
            server_address,
            server_port,
            database,
            sanitization,
            ..
        } = self;

        let mut out = f.debug_struct("InstrumentationContext");
        out.field("server_address", server_address);
        out.field("server_port", server_port);
        out.field("database", database);
        out.field("sanitization", sanitization);
        out.finish_non_exhaustive()
    }
}
272
#[cfg(feature = "otel")]
impl InstrumentationContext {
    /// Instrumentation scope name under which all driver spans are created.
    const TRACER_NAME: &'static str = "mssql-client";

    /// Create a new instrumentation context for `server_address:server_port`.
    ///
    /// The context starts with no database name and with the default
    /// [`SanitizationConfig`]. Its [`DatabaseMetrics`] are created here, bound
    /// to the server address and port, with no pool name.
    #[must_use]
    pub fn new(server_address: String, server_port: u16) -> Self {
        let metrics =
            std::sync::Arc::new(DatabaseMetrics::new(None, &server_address, server_port));
        Self {
            server_address,
            server_port,
            database: None,
            sanitization: SanitizationConfig::default(),
            metrics,
        }
    }

    /// Operation-level metrics for this connection.
    #[must_use]
    pub fn metrics(&self) -> &DatabaseMetrics {
        &self.metrics
    }

    /// Set the database name reported as `db.name` on spans.
    ///
    /// Metric attributes are unaffected; they were fixed when the context was
    /// created.
    #[must_use]
    pub fn with_database(mut self, database: impl Into<String>) -> Self {
        self.database = Some(database.into());
        self
    }

    /// Set the sanitization configuration applied to `db.statement`.
    #[must_use]
    pub fn with_sanitization(mut self, config: SanitizationConfig) -> Self {
        self.sanitization = config;
        self
    }

    /// Attributes shared by every span created by this context.
    ///
    /// These are `db.system`, `server.address` and `server.port`, plus `db.name`
    /// when a database has been set.
    #[must_use]
    pub fn base_attributes(&self) -> Vec<KeyValue> {
        let mut attrs = vec![
            KeyValue::new(attributes::DB_SYSTEM, DB_SYSTEM),
            KeyValue::new(attributes::SERVER_ADDRESS, self.server_address.clone()),
            KeyValue::new(attributes::SERVER_PORT, i64::from(self.server_port)),
        ];

        if let Some(ref db) = self.database {
            attrs.push(KeyValue::new(attributes::DB_NAME, db.clone()));
        }

        attrs
    }

    /// Start a client span covering connection establishment.
    ///
    /// Besides the base attributes, the server host is also recorded under
    /// `db.connection_string.host`.
    pub fn connection_span(&self) -> impl Span {
        let mut attrs = self.base_attributes();
        attrs.push(KeyValue::new(
            "db.connection_string.host",
            self.server_address.clone(),
        ));

        Self::start_client_span(span_names::CONNECT, attrs)
    }

    /// Start a client span for executing `sql`.
    ///
    /// Records the operation keyword (see [`extract_operation`]) and the
    /// statement text after applying this context's [`SanitizationConfig`].
    pub fn query_span(&self, sql: &str) -> impl Span {
        let mut attrs = self.base_attributes();
        attrs.push(KeyValue::new(attributes::DB_OPERATION, extract_operation(sql)));
        attrs.push(KeyValue::new(
            attributes::DB_STATEMENT,
            self.sanitization.sanitize(sql),
        ));

        Self::start_client_span(span_names::QUERY, attrs)
    }

    /// Start a client span for a transaction-control operation.
    ///
    /// `"BEGIN"`, `"COMMIT"` and `"ROLLBACK"` map to their dedicated span
    /// names. Any other value is treated as a savepoint operation. The raw
    /// `operation` string is recorded as `db.operation`.
    pub fn transaction_span(&self, operation: &str) -> impl Span {
        let span_name = match operation {
            "BEGIN" => span_names::BEGIN_TRANSACTION,
            "COMMIT" => span_names::COMMIT,
            "ROLLBACK" => span_names::ROLLBACK,
            _ => span_names::SAVEPOINT,
        };

        let mut attrs = self.base_attributes();
        attrs.push(KeyValue::new(
            attributes::DB_OPERATION,
            operation.to_string(),
        ));

        Self::start_client_span(span_name, attrs)
    }

    /// Mark `span` as failed with `error`'s message and record the error on it.
    pub fn record_error(span: &mut impl Span, error: &crate::error::Error) {
        span.set_status(Status::error(error.to_string()));
        span.record_error(error);
    }

    /// Mark `span` as successful, recording `db.rows_affected` when a count is given.
    pub fn record_success(span: &mut impl Span, rows_affected: Option<u64>) {
        span.set_status(Status::Ok);
        if let Some(rows) = rows_affected {
            // Integer attribute values are i64. Saturate instead of letting an
            // `as` cast wrap very large counts to negative numbers.
            let rows = i64::try_from(rows).unwrap_or(i64::MAX);
            span.set_attribute(KeyValue::new(attributes::DB_ROWS_AFFECTED, rows));
        }
    }

    /// Start a span of kind [`SpanKind::Client`] named `span_name`, carrying
    /// `attrs`, on the driver's tracer from the global provider.
    fn start_client_span(span_name: &'static str, attrs: Vec<KeyValue>) -> impl Span {
        let tracer = global::tracer(Self::TRACER_NAME);
        tracer
            .span_builder(span_name)
            .with_kind(SpanKind::Client)
            .with_attributes(attrs)
            .start(&tracer)
    }
}
395
396/// No-op instrumentation context when otel feature is disabled.
397#[cfg(not(feature = "otel"))]
398#[derive(Debug, Clone, Default)]
399pub struct InstrumentationContext;
400
401#[cfg(not(feature = "otel"))]
402impl InstrumentationContext {
403    /// Create a new instrumentation context (no-op).
404    #[must_use]
405    pub fn new(_server_address: String, _server_port: u16) -> Self {
406        Self
407    }
408
409    /// Set the database name (no-op).
410    #[must_use]
411    pub fn with_database(self, _database: impl Into<String>) -> Self {
412        self
413    }
414
415    /// Set the sanitization configuration (no-op).
416    #[must_use]
417    pub fn with_sanitization(self, _config: SanitizationConfig) -> Self {
418        self
419    }
420}
421
422// =============================================================================
423// OpenTelemetry Metrics Support
424// =============================================================================
425
/// Names of the metrics this driver emits, modelled on the OpenTelemetry
/// database-client semantic conventions.
pub mod metric_names {
    // --- Pool occupancy (gauges) ---
    /// Gauge: connections currently checked out and in use.
    pub const DB_CLIENT_CONNECTIONS_USAGE: &str = "db.client.connections.usage";
    /// Gauge: connections sitting idle in the pool.
    pub const DB_CLIENT_CONNECTIONS_IDLE: &str = "db.client.connections.idle";
    /// Gauge: upper bound on the pool size.
    pub const DB_CLIENT_CONNECTIONS_MAX: &str = "db.client.connections.max";

    // --- Connection churn (counters) ---
    /// Counter: connections opened so far.
    pub const DB_CLIENT_CONNECTIONS_CREATE_TOTAL: &str = "db.client.connections.create.total";
    /// Counter: connections closed so far.
    pub const DB_CLIENT_CONNECTIONS_CLOSE_TOTAL: &str = "db.client.connections.close.total";

    // --- Operations ---
    /// Histogram: how long queries and executes take.
    pub const DB_CLIENT_OPERATION_DURATION: &str = "db.client.operation.duration";
    /// Counter: operations performed so far.
    pub const DB_CLIENT_OPERATIONS_TOTAL: &str = "db.client.operations.total";
    /// Counter: operations that failed.
    pub const DB_CLIENT_ERRORS_TOTAL: &str = "db.client.errors.total";

    // --- Pool acquisition ---
    /// Histogram: time spent waiting to obtain a pooled connection.
    pub const DB_CLIENT_CONNECTIONS_WAIT_TIME: &str = "db.client.connections.wait_time";
}
447
/// Database metrics collector using OpenTelemetry.
///
/// Holds one instrument for each name in `metric_names`, plus the attribute
/// set attached to every recording.
#[cfg(feature = "otel")]
pub struct DatabaseMetrics {
    /// Connection usage gauge.
    connections_usage: opentelemetry::metrics::Gauge<u64>,
    /// Idle connections gauge.
    connections_idle: opentelemetry::metrics::Gauge<u64>,
    /// Max connections gauge.
    connections_max: opentelemetry::metrics::Gauge<u64>,
    /// Connections created counter.
    connections_create_total: opentelemetry::metrics::Counter<u64>,
    /// Connections closed counter.
    connections_close_total: opentelemetry::metrics::Counter<u64>,
    /// Operation duration histogram.
    operation_duration: opentelemetry::metrics::Histogram<f64>,
    /// Total operations counter.
    operations_total: opentelemetry::metrics::Counter<u64>,
    /// Error counter.
    errors_total: opentelemetry::metrics::Counter<u64>,
    /// Connection wait time histogram.
    connections_wait_time: opentelemetry::metrics::Histogram<f64>,
    /// Base attributes for all metrics.
    base_attributes: Vec<opentelemetry::KeyValue>,
}

/// Prints only the attribute set and omits the instrument handles.
///
/// This keeps `DatabaseMetrics: Debug` with or without the `otel` feature, in
/// line with the no-op variant, which derives `Debug`.
#[cfg(feature = "otel")]
impl std::fmt::Debug for DatabaseMetrics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DatabaseMetrics")
            .field("base_attributes", &self.base_attributes)
            .finish_non_exhaustive()
    }
}
472
#[cfg(feature = "otel")]
impl DatabaseMetrics {
    /// Create a new metrics collector.
    ///
    /// All instruments are obtained from the global meter provider under the
    /// `mssql-client` meter. Every recording carries `db.system`,
    /// `server.address` and `server.port`, plus `db.client.pool.name` when
    /// `pool_name` is given.
    ///
    /// # Arguments
    ///
    /// * `pool_name` - Optional name to identify this pool in metrics
    /// * `server_address` - Server hostname
    /// * `server_port` - Server port
    #[must_use]
    pub fn new(pool_name: Option<&str>, server_address: &str, server_port: u16) -> Self {
        use opentelemetry::{KeyValue, global};

        let meter = global::meter("mssql-client");

        let connections_usage = meter
            .u64_gauge(metric_names::DB_CLIENT_CONNECTIONS_USAGE)
            .with_description("Number of connections currently in use")
            .with_unit("connections")
            .build();

        let connections_idle = meter
            .u64_gauge(metric_names::DB_CLIENT_CONNECTIONS_IDLE)
            .with_description("Number of idle connections available")
            .with_unit("connections")
            .build();

        let connections_max = meter
            .u64_gauge(metric_names::DB_CLIENT_CONNECTIONS_MAX)
            .with_description("Maximum number of connections allowed")
            .with_unit("connections")
            .build();

        let connections_create_total = meter
            .u64_counter(metric_names::DB_CLIENT_CONNECTIONS_CREATE_TOTAL)
            .with_description("Total number of connections created")
            .with_unit("connections")
            .build();

        let connections_close_total = meter
            .u64_counter(metric_names::DB_CLIENT_CONNECTIONS_CLOSE_TOTAL)
            .with_description("Total number of connections closed")
            .with_unit("connections")
            .build();

        let operation_duration = meter
            .f64_histogram(metric_names::DB_CLIENT_OPERATION_DURATION)
            .with_description("Duration of database operations")
            .with_unit("s")
            .build();

        let operations_total = meter
            .u64_counter(metric_names::DB_CLIENT_OPERATIONS_TOTAL)
            .with_description("Total number of database operations")
            .with_unit("operations")
            .build();

        let errors_total = meter
            .u64_counter(metric_names::DB_CLIENT_ERRORS_TOTAL)
            .with_description("Total number of operation errors")
            .with_unit("errors")
            .build();

        let connections_wait_time = meter
            .f64_histogram(metric_names::DB_CLIENT_CONNECTIONS_WAIT_TIME)
            .with_description("Time spent waiting for a connection")
            .with_unit("s")
            .build();

        let mut base_attributes = vec![
            KeyValue::new(attributes::DB_SYSTEM, DB_SYSTEM),
            KeyValue::new(attributes::SERVER_ADDRESS, server_address.to_string()),
            KeyValue::new(attributes::SERVER_PORT, i64::from(server_port)),
        ];

        if let Some(name) = pool_name {
            base_attributes.push(KeyValue::new("db.client.pool.name", name.to_string()));
        }

        Self {
            connections_usage,
            connections_idle,
            connections_max,
            connections_create_total,
            connections_close_total,
            operation_duration,
            operations_total,
            errors_total,
            connections_wait_time,
            base_attributes,
        }
    }

    /// Record a pool status snapshot: connections in use, idle, and the maximum.
    pub fn record_pool_status(&self, in_use: u64, idle: u64, max: u64) {
        self.connections_usage.record(in_use, &self.base_attributes);
        self.connections_idle.record(idle, &self.base_attributes);
        self.connections_max.record(max, &self.base_attributes);
    }

    /// Record a connection being created.
    pub fn record_connection_created(&self) {
        self.connections_create_total.add(1, &self.base_attributes);
    }

    /// Record a connection being closed.
    pub fn record_connection_closed(&self) {
        self.connections_close_total.add(1, &self.base_attributes);
    }

    /// Record one completed operation.
    ///
    /// This increments the operation counter and records `duration_seconds` in
    /// the duration histogram. A failed operation also increments the error
    /// counter. All three recordings carry the base attributes plus
    /// `db.operation` and `db.operation.success`.
    pub fn record_operation(&self, operation: &str, duration_seconds: f64, success: bool) {
        use opentelemetry::KeyValue;

        // Size the vector for the two extra attributes up front. A plain
        // `clone()` followed by two pushes would reallocate on every call.
        let mut attrs = Vec::with_capacity(self.base_attributes.len() + 2);
        attrs.extend_from_slice(&self.base_attributes);
        attrs.push(KeyValue::new(
            attributes::DB_OPERATION,
            operation.to_string(),
        ));
        attrs.push(KeyValue::new("db.operation.success", success));

        self.operations_total.add(1, &attrs);
        self.operation_duration.record(duration_seconds, &attrs);

        if !success {
            self.errors_total.add(1, &attrs);
        }
    }

    /// Record time spent waiting for a connection from the pool.
    pub fn record_connection_wait(&self, duration_seconds: f64) {
        self.connections_wait_time
            .record(duration_seconds, &self.base_attributes);
    }
}
607
/// No-op metrics collector when otel feature is disabled.
///
/// Offers the same methods as the `otel` variant, but every method body is
/// empty and every argument is ignored.
#[cfg(not(feature = "otel"))]
#[derive(Debug, Clone, Default)]
pub struct DatabaseMetrics;

#[cfg(not(feature = "otel"))]
impl DatabaseMetrics {
    /// Build the no-op collector; the pool name and server endpoint are ignored.
    #[must_use]
    pub fn new(_pool_name: Option<&str>, _server_address: &str, _server_port: u16) -> Self {
        DatabaseMetrics
    }

    /// Discards a pool status snapshot.
    pub fn record_pool_status(&self, _in_use: u64, _idle: u64, _max: u64) {}

    /// Discards a connection-created event.
    pub fn record_connection_created(&self) {}

    /// Discards a connection-closed event.
    pub fn record_connection_closed(&self) {}

    /// Discards an operation's duration and outcome.
    pub fn record_operation(&self, _operation: &str, _duration_seconds: f64, _success: bool) {}

    /// Discards a pool wait-time sample.
    pub fn record_connection_wait(&self, _duration_seconds: f64) {}
}
636
637/// Helper for timing operations.
638#[derive(Debug, Clone)]
639pub struct OperationTimer {
640    start: std::time::Instant,
641    operation: &'static str,
642}
643
644impl OperationTimer {
645    /// Start timing an operation.
646    #[must_use]
647    pub fn start(operation: &'static str) -> Self {
648        Self {
649            start: std::time::Instant::now(),
650            operation,
651        }
652    }
653
654    /// Get the elapsed time in seconds.
655    #[must_use]
656    pub fn elapsed_seconds(&self) -> f64 {
657        self.start.elapsed().as_secs_f64()
658    }
659
660    /// Get the operation name.
661    #[must_use]
662    pub fn operation(&self) -> &'static str {
663        self.operation
664    }
665
666    /// Finish timing and record the metric.
667    #[cfg(feature = "otel")]
668    pub fn finish(self, metrics: &DatabaseMetrics, success: bool) {
669        metrics.record_operation(self.operation, self.elapsed_seconds(), success);
670    }
671
672    /// Finish timing (no-op when otel is disabled).
673    #[cfg(not(feature = "otel"))]
674    pub fn finish(self, _metrics: &DatabaseMetrics, _success: bool) {}
675}
676
#[cfg(test)]
// Test code may unwrap freely: a panic simply fails the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Statements are classified by their leading keyword.
    #[test]
    fn test_extract_operation() {
        let cases = [
            ("SELECT * FROM users", "SELECT"),
            ("  select id from users", "SELECT"),
            ("INSERT INTO users VALUES (1)", "INSERT"),
            ("UPDATE users SET name = 'foo'", "UPDATE"),
            ("DELETE FROM users", "DELETE"),
            ("EXEC sp_help", "EXECUTE"),
            ("BEGIN TRANSACTION", "BEGIN"),
            ("COMMIT", "COMMIT"),
            ("ROLLBACK", "ROLLBACK"),
            ("CREATE TABLE foo", "CREATE"),
            ("unknown stuff", "OTHER"),
        ];

        for (sql, expected) in cases {
            assert_eq!(extract_operation(sql), expected, "statement: {sql:?}");
        }
    }

    /// String and numeric literals are replaced with the placeholder.
    #[test]
    fn test_sanitize_sql() {
        const PLACEHOLDER: &str = "?";

        let cases = [
            // A single string literal.
            (
                "SELECT * FROM users WHERE name = 'Alice'",
                "SELECT * FROM users WHERE name = ?",
            ),
            // Several string literals.
            ("INSERT INTO t VALUES ('a', 'b')", "INSERT INTO t VALUES (?, ?)"),
            // A doubled quote stays inside one literal.
            ("SELECT * WHERE name = 'O''Brien'", "SELECT * WHERE name = ?"),
            // A numeric literal.
            ("SELECT * WHERE id = 123", "SELECT * WHERE id = ?"),
            // Numbers and strings together.
            (
                "SELECT * WHERE id = 42 AND name = 'test'",
                "SELECT * WHERE id = ? AND name = ?",
            ),
        ];

        for (sql, expected) in cases {
            assert_eq!(sanitize_sql(sql, PLACEHOLDER), expected, "input: {sql:?}");
        }
    }

    /// Long strings are cut and suffixed with an ellipsis; short ones pass through.
    #[test]
    fn test_truncate_string() {
        let cases = [
            ("hello", 10, "hello"),
            ("hello world", 8, "hello..."),
            ("hi", 2, "hi"),
        ];

        for (input, limit, expected) in cases {
            assert_eq!(truncate_string(input, limit), expected);
        }
    }

    #[test]
    fn test_sanitization_config_default() {
        let SanitizationConfig {
            enabled,
            max_length,
            placeholder,
        } = SanitizationConfig::default();

        assert!(enabled);
        assert_eq!(max_length, 2048);
        assert_eq!(placeholder, "?");
    }

    #[test]
    fn test_sanitization_config_no_sanitization() {
        let raw = SanitizationConfig::no_sanitization();
        assert!(!raw.enabled);

        let sql = "SELECT * FROM users WHERE name = 'Alice'";
        assert_eq!(raw.sanitize(sql), sql);
    }
}