Skip to main content

mssql_client/
instrumentation.rs

1//! OpenTelemetry instrumentation for database operations.
2//!
3//! This module provides first-class OpenTelemetry tracing support when the
4//! `otel` feature is enabled. It follows the OpenTelemetry semantic conventions
5//! for database operations.
6//!
7//! ## Features
8//!
9//! When the `otel` feature is enabled, the following instrumentation is available:
10//!
11//! - **Connection spans**: Track connection establishment time and success/failure
12//! - **Query spans**: Track SQL execution with sanitized statement attributes
13//! - **Transaction spans**: Track transaction boundaries (begin, commit, rollback)
14//! - **Error events**: Record errors with appropriate attributes
15//!
16//! ## Usage
17//!
18//! Build the driver with the `otel` feature (`cargo add mssql-client --features
19//! otel`). Spans and metrics are then emitted automatically for connections,
20//! queries, executes, transactions, and pool operations; with the feature off
21//! the instrumentation compiles to no-ops at zero cost.
22//!
23//! This crate emits OpenTelemetry telemetry but does not configure an exporter —
24//! install a tracer/meter provider in your application using the
25//! `opentelemetry`, `opentelemetry_sdk`, and an exporter crate (e.g.
26//! `opentelemetry-otlp` to an OTLP collector such as Jaeger), then drive the
27//! `tracing` <-> OpenTelemetry bridge with `tracing-opentelemetry`. See those
28//! crates' docs for provider setup.
29//!
30//! ## Semantic Conventions
31//!
32//! Follows OpenTelemetry database semantic conventions:
33//! - `db.system`: "mssql"
34//! - `db.name`: Database name
35//! - `db.statement`: SQL statement (sanitized if configured)
36//! - `db.operation`: Query operation type (SELECT, INSERT, etc.)
37//! - `server.address`: Server hostname
38//! - `server.port`: Server port
39//!
40//! `db.statement` is sanitized by default ([`SanitizationConfig`]) so literal
41//! parameter values are replaced with placeholders before being recorded; opt
42//! out with [`SanitizationConfig::no_sanitization`] only when capturing raw SQL
43//! is acceptable.
44//!
45//! ## Troubleshooting
46//!
47//! - **No spans appear** — confirm the `otel` feature is enabled and a tracer
48//!   provider is installed before the first operation.
49//! - **`db.rows_affected` missing** — it is only recorded on mutating
50//!   statements, not on `SELECT`.
51//! - **High attribute cardinality** — keep sanitization on and avoid adding
52//!   per-row custom attributes.
53
54#[cfg(feature = "otel")]
55use opentelemetry::{
56    KeyValue, global,
57    trace::{Span, SpanKind, Status, Tracer},
58};
59
/// Database system identifier for MSSQL.
///
/// This is the value recorded under the `db.system` attribute key
/// ([`attributes::DB_SYSTEM`]) on the spans and metrics built in this module.
pub const DB_SYSTEM: &str = "mssql";
62
/// Span names for database operations.
///
/// Every name is a static string carrying the `mssql.` prefix.
pub mod span_names {
    /// Span name for connection establishment.
    pub const CONNECT: &str = "mssql.connect";
    /// Span name for query execution.
    pub const QUERY: &str = "mssql.query";
    /// Span name for command execution (used for stored-procedure call spans).
    pub const EXECUTE: &str = "mssql.execute";
    /// Span name for beginning a transaction.
    pub const BEGIN_TRANSACTION: &str = "mssql.begin_transaction";
    /// Span name for committing a transaction.
    pub const COMMIT: &str = "mssql.commit";
    /// Span name for rolling back a transaction.
    pub const ROLLBACK: &str = "mssql.rollback";
    /// Span name for savepoint operations.
    ///
    /// The transaction span helper also falls back to this name for any
    /// operation string other than `BEGIN`, `COMMIT`, or `ROLLBACK`.
    pub const SAVEPOINT: &str = "mssql.savepoint";
    /// Span name for bulk insert operations.
    pub const BULK_INSERT: &str = "mssql.bulk_insert";
}
82
/// Attribute keys following OpenTelemetry semantic conventions.
///
/// These are the *keys* only; the values are supplied where spans and metrics
/// are built.
pub mod attributes {
    /// Database system type (the value used by this crate is [`super::DB_SYSTEM`]).
    pub const DB_SYSTEM: &str = "db.system";
    /// Database name.
    pub const DB_NAME: &str = "db.name";
    /// SQL statement (may be sanitized).
    pub const DB_STATEMENT: &str = "db.statement";
    /// Database operation type (e.g. `SELECT`, `EXECUTE`, `COMMIT`).
    pub const DB_OPERATION: &str = "db.operation";
    /// Server hostname.
    pub const SERVER_ADDRESS: &str = "server.address";
    /// Server port.
    pub const SERVER_PORT: &str = "server.port";
    /// Number of rows affected.
    pub const DB_ROWS_AFFECTED: &str = "db.rows_affected";
    /// Transaction isolation level (MSSQL-specific key).
    pub const DB_ISOLATION_LEVEL: &str = "db.mssql.isolation_level";
    /// Connection ID.
    pub const DB_CONNECTION_ID: &str = "db.connection_id";
    /// Error type.
    pub const ERROR_TYPE: &str = "error.type";
}
106
/// Configuration for SQL statement sanitization.
///
/// Controls how a SQL statement is transformed before it is recorded as the
/// `db.statement` attribute: whether literal values are masked, and how long
/// the recorded text may be.
#[derive(Debug, Clone)]
pub struct SanitizationConfig {
    /// Whether to sanitize SQL statements.
    ///
    /// When `false`, statements are recorded verbatim apart from truncation
    /// to [`max_length`](Self::max_length).
    pub enabled: bool,
    /// Maximum length of statement to record, measured in bytes.
    ///
    /// Longer statements are cut and end with `...`.
    pub max_length: usize,
    /// Placeholder to use for sanitized values.
    ///
    /// Substituted for each string literal and numeric literal that is masked.
    pub placeholder: String,
}
117
118impl Default for SanitizationConfig {
119    fn default() -> Self {
120        Self {
121            enabled: true,
122            max_length: 2048,
123            placeholder: "?".to_string(),
124        }
125    }
126}
127
128impl SanitizationConfig {
129    /// Create a configuration that doesn't sanitize statements.
130    #[must_use]
131    pub fn no_sanitization() -> Self {
132        Self {
133            enabled: false,
134            max_length: usize::MAX,
135            placeholder: String::new(),
136        }
137    }
138
139    /// Sanitize a SQL statement according to the configuration.
140    #[must_use]
141    pub fn sanitize(&self, sql: &str) -> String {
142        if !self.enabled {
143            return truncate_string(sql, self.max_length);
144        }
145
146        // Simple sanitization: replace string literals and numbers
147        let sanitized = sanitize_sql(sql, &self.placeholder);
148        truncate_string(&sanitized, self.max_length)
149    }
150}
151
/// Sanitize SQL by replacing literal values with placeholders.
///
/// The statement is scanned once, character by character:
///
/// - Text enclosed in `'...'` or `"..."` is dropped and replaced by a single
///   `placeholder`. A doubled quote character inside the literal (`''`) is
///   treated as an escaped quote, not as the end of the literal.
/// - A numeric literal is replaced by a single `placeholder`. Decimal
///   (`42`, `3.14`), exponent (`1e5`, `1.5E-3`), and hexadecimal/binary
///   (`0xDEADBEEF`) forms are consumed in full so no part of the value leaks.
/// - Digits that directly follow a letter, digit, or `_` already written to the
///   output are kept, so identifiers such as `col1` or `@p2` stay intact.
/// - An unterminated literal at end of input still yields one `placeholder`.
///
/// This is a lightweight scanner, not a SQL parser: comments and bracketed
/// identifiers get no special treatment.
fn sanitize_sql(sql: &str, placeholder: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // `Some(q)` while scanning inside a literal that was opened by quote `q`.
    let mut open_quote: Option<char> = None;

    while let Some(c) = chars.next() {
        if let Some(quote) = open_quote {
            if c == quote {
                if chars.peek() == Some(&quote) {
                    // Doubled quote: an escaped quote inside the literal.
                    chars.next();
                } else {
                    open_quote = None;
                    result.push_str(placeholder);
                }
            }
            // Everything inside a literal is discarded.
            continue;
        }

        if c == '\'' || c == '"' {
            open_quote = Some(c);
            continue;
        }

        // A digit starts a numeric literal only when it does not continue an
        // identifier-like token that is already in the output.
        if c.is_ascii_digit() && !result.ends_with(|ch: char| ch.is_alphanumeric() || ch == '_') {
            skip_numeric_literal(c, &mut chars);
            result.push_str(placeholder);
            continue;
        }

        result.push(c);
    }

    // Input ended inside a literal: still emit its placeholder.
    if open_quote.is_some() {
        result.push_str(placeholder);
    }

    result
}

/// Advance `chars` past the remainder of a numeric literal whose first digit
/// (`first`) has already been consumed.
fn skip_numeric_literal(first: char, chars: &mut std::iter::Peekable<std::str::Chars<'_>>) {
    // Hexadecimal / binary constant: `0x` followed by any number of hex digits.
    if first == '0' && chars.peek().is_some_and(|ch| matches!(ch, 'x' | 'X')) {
        chars.next();
        while chars.peek().is_some_and(char::is_ascii_hexdigit) {
            chars.next();
        }
        return;
    }

    // Integer and fractional digits.
    while chars
        .peek()
        .is_some_and(|ch| ch.is_ascii_digit() || *ch == '.')
    {
        chars.next();
    }

    // Optional exponent: `e`/`E`, optional sign, then at least one digit.
    // Look ahead on a clone so a bare `e` (e.g. the start of an alias) is
    // left untouched.
    if chars.peek().is_some_and(|ch| matches!(ch, 'e' | 'E')) {
        let mut ahead = chars.clone();
        ahead.next();
        if ahead.peek().is_some_and(|ch| matches!(ch, '+' | '-')) {
            ahead.next();
        }
        if ahead.peek().is_some_and(char::is_ascii_digit) {
            while ahead.peek().is_some_and(char::is_ascii_digit) {
                ahead.next();
            }
            *chars = ahead;
        }
    }
}
202
/// Truncate a string to a maximum length.
///
/// `max_len` is a byte length. A string that already fits is returned
/// unchanged. Otherwise the result is a prefix of `s` followed by `...`, with
/// the total never exceeding `max_len` bytes.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so
/// multi-byte text never causes a slicing panic. When `max_len` is smaller
/// than the ellipsis itself, only as many dots as fit are returned.
fn truncate_string(s: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if s.len() <= max_len {
        return s.to_string();
    }

    // No room for any content: return as much of the ellipsis as fits.
    if max_len < ELLIPSIS.len() {
        return ELLIPSIS[..max_len].to_string();
    }

    // Index 0 is always a boundary, so this loop terminates.
    let mut cut = max_len - ELLIPSIS.len();
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }

    let mut truncated = String::with_capacity(cut + ELLIPSIS.len());
    truncated.push_str(&s[..cut]);
    truncated.push_str(ELLIPSIS);
    truncated
}
211
/// Extract the operation type from a SQL statement.
///
/// Surrounding whitespace is ignored and the leading keyword is matched
/// ASCII-case-insensitively as a prefix of the statement. The comparison works
/// on the statement's leading bytes, so no upper-cased copy of the SQL is
/// allocated.
///
/// Returns one of `SELECT`, `INSERT`, `UPDATE`, `DELETE`, `EXECUTE` (for
/// `EXEC`/`EXECUTE`), `BEGIN` (for `BEGIN TRAN...`), `COMMIT`, `ROLLBACK`,
/// `CREATE`, `ALTER`, `DROP`, or `OTHER` when nothing matches.
#[must_use]
pub fn extract_operation(sql: &str) -> &'static str {
    // (statement prefix, reported operation). `EXEC` also covers `EXECUTE`.
    const KEYWORDS: [(&str, &str); 11] = [
        ("SELECT", "SELECT"),
        ("INSERT", "INSERT"),
        ("UPDATE", "UPDATE"),
        ("DELETE", "DELETE"),
        ("EXEC", "EXECUTE"),
        ("BEGIN TRAN", "BEGIN"),
        ("COMMIT", "COMMIT"),
        ("ROLLBACK", "ROLLBACK"),
        ("CREATE", "CREATE"),
        ("ALTER", "ALTER"),
        ("DROP", "DROP"),
    ];

    // Compare bytes rather than `str` slices so a multi-byte character near
    // the start of the statement cannot cause an out-of-boundary slice.
    let statement = sql.trim().as_bytes();

    KEYWORDS
        .iter()
        .find(|(prefix, _)| {
            statement
                .get(..prefix.len())
                .is_some_and(|head| head.eq_ignore_ascii_case(prefix.as_bytes()))
        })
        .map_or("OTHER", |&(_, operation)| operation)
}
243
/// Instrumentation context for database operations.
///
/// Holds the per-connection data that is attached to every span (server
/// address and port, optional database name), the statement sanitization
/// settings, and a shared handle to the operation metrics.
///
/// Only compiled when the `otel` feature is enabled.
#[cfg(feature = "otel")]
#[derive(Clone)]
pub struct InstrumentationContext {
    /// Server address, recorded as `server.address`.
    pub server_address: String,
    /// Server port, recorded as `server.port`.
    pub server_port: u16,
    /// Database name, recorded as `db.name` when present.
    pub database: Option<String>,
    /// Sanitization configuration applied to SQL before it is recorded as
    /// `db.statement` on query spans.
    pub sanitization: SanitizationConfig,
    /// Operation-level metrics (duration histogram, operation/error counters)
    /// bound to this connection's server attributes.
    ///
    /// Held behind an `Arc`, so cloning the context shares one metrics
    /// collector rather than creating a new one.
    metrics: std::sync::Arc<DatabaseMetrics>,
}
260
#[cfg(feature = "otel")]
impl std::fmt::Debug for InstrumentationContext {
    /// Format the connection fields and sanitization settings.
    ///
    /// The `metrics` handle is left out; the output is marked non-exhaustive
    /// to signal the omission.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("InstrumentationContext");
        builder.field("server_address", &self.server_address);
        builder.field("server_port", &self.server_port);
        builder.field("database", &self.database);
        builder.field("sanitization", &self.sanitization);
        builder.finish_non_exhaustive()
    }
}
272
#[cfg(feature = "otel")]
impl InstrumentationContext {
    /// Instrumentation scope name used to look up the global tracer.
    const TRACER_NAME: &'static str = "mssql-client";

    /// Create a new instrumentation context.
    ///
    /// The context starts with no database name and the default
    /// [`SanitizationConfig`]. Its metrics collector is created with no pool
    /// name and the given server address and port.
    #[must_use]
    pub fn new(server_address: String, server_port: u16) -> Self {
        let metrics = std::sync::Arc::new(DatabaseMetrics::new(None, &server_address, server_port));
        Self {
            server_address,
            server_port,
            database: None,
            sanitization: SanitizationConfig::default(),
            metrics,
        }
    }

    /// Operation-level metrics for this connection.
    #[must_use]
    pub fn metrics(&self) -> &DatabaseMetrics {
        &self.metrics
    }

    /// Set the database name.
    #[must_use]
    pub fn with_database(mut self, database: impl Into<String>) -> Self {
        self.database = Some(database.into());
        self
    }

    /// Set the sanitization configuration.
    #[must_use]
    pub fn with_sanitization(mut self, config: SanitizationConfig) -> Self {
        self.sanitization = config;
        self
    }

    /// Get base attributes for spans.
    ///
    /// Always contains `db.system`, `server.address`, and `server.port`, in
    /// that order; `db.name` is appended when a database name has been set.
    pub fn base_attributes(&self) -> Vec<KeyValue> {
        let mut attrs = vec![
            KeyValue::new(attributes::DB_SYSTEM, DB_SYSTEM),
            KeyValue::new(attributes::SERVER_ADDRESS, self.server_address.clone()),
            KeyValue::new(attributes::SERVER_PORT, i64::from(self.server_port)),
        ];

        if let Some(ref db) = self.database {
            attrs.push(KeyValue::new(attributes::DB_NAME, db.clone()));
        }

        attrs
    }

    /// Start a client-kind span named `name` carrying `attrs`.
    ///
    /// Shared by all public span constructors so they obtain the tracer and
    /// configure the span builder in exactly one place.
    fn start_client_span(name: &'static str, attrs: Vec<KeyValue>) -> impl Span {
        let tracer = global::tracer(Self::TRACER_NAME);
        tracer
            .span_builder(name)
            .with_kind(SpanKind::Client)
            .with_attributes(attrs)
            .start(&tracer)
    }

    /// Create a connection span.
    ///
    /// Carries the base attributes plus `db.connection_string.host`, set to
    /// the server address.
    pub fn connection_span(&self) -> impl Span {
        let mut attrs = self.base_attributes();
        attrs.push(KeyValue::new(
            "db.connection_string.host",
            self.server_address.clone(),
        ));

        Self::start_client_span(span_names::CONNECT, attrs)
    }

    /// Create a query span.
    ///
    /// Carries the base attributes, `db.operation` (derived from the leading
    /// keyword of `sql`), and `db.statement` (`sql` passed through the
    /// configured sanitization).
    pub fn query_span(&self, sql: &str) -> impl Span {
        let mut attrs = self.base_attributes();

        let operation = extract_operation(sql);
        attrs.push(KeyValue::new(attributes::DB_OPERATION, operation));
        attrs.push(KeyValue::new(
            attributes::DB_STATEMENT,
            self.sanitization.sanitize(sql),
        ));

        Self::start_client_span(span_names::QUERY, attrs)
    }

    /// Create a stored-procedure call span.
    ///
    /// A procedure call has no SQL text, so the procedure name is recorded as
    /// `db.statement` (it is a validated identifier, not a value, so it is not
    /// sanitized) and the operation is `EXECUTE`.
    pub fn procedure_span(&self, proc_name: &str) -> impl Span {
        let mut attrs = self.base_attributes();
        attrs.push(KeyValue::new(attributes::DB_OPERATION, "EXECUTE"));
        attrs.push(KeyValue::new(
            attributes::DB_STATEMENT,
            proc_name.to_string(),
        ));

        Self::start_client_span(span_names::EXECUTE, attrs)
    }

    /// Create a transaction span.
    ///
    /// `operation` is recorded verbatim as `db.operation`. The span name is
    /// chosen from it: `BEGIN`, `COMMIT`, and `ROLLBACK` map to their
    /// dedicated names; any other value uses the savepoint span name.
    pub fn transaction_span(&self, operation: &str) -> impl Span {
        let mut attrs = self.base_attributes();
        attrs.push(KeyValue::new(
            attributes::DB_OPERATION,
            operation.to_string(),
        ));

        let span_name = match operation {
            "BEGIN" => span_names::BEGIN_TRANSACTION,
            "COMMIT" => span_names::COMMIT,
            "ROLLBACK" => span_names::ROLLBACK,
            _ => span_names::SAVEPOINT,
        };

        Self::start_client_span(span_name, attrs)
    }

    /// Record an error on the current span.
    ///
    /// Sets the span status to an error carrying the error's display text and
    /// records the error on the span.
    pub fn record_error(span: &mut impl Span, error: &crate::error::Error) {
        span.set_status(Status::error(error.to_string()));
        span.record_error(error);
    }

    /// Record success with optional row count.
    ///
    /// Sets the span status to OK. When `rows_affected` is provided it is
    /// stored as `db.rows_affected`; values above `i64::MAX` saturate instead
    /// of wrapping to a negative number.
    pub fn record_success(span: &mut impl Span, rows_affected: Option<u64>) {
        span.set_status(Status::Ok);
        if let Some(rows) = rows_affected {
            let rows = i64::try_from(rows).unwrap_or(i64::MAX);
            span.set_attribute(KeyValue::new(attributes::DB_ROWS_AFFECTED, rows));
        }
    }
}
416
417/// No-op instrumentation context when otel feature is disabled.
418#[cfg(not(feature = "otel"))]
419#[derive(Debug, Clone, Default)]
420pub struct InstrumentationContext;
421
422#[cfg(not(feature = "otel"))]
423impl InstrumentationContext {
424    /// Create a new instrumentation context (no-op).
425    #[must_use]
426    pub fn new(_server_address: String, _server_port: u16) -> Self {
427        Self
428    }
429
430    /// Set the database name (no-op).
431    #[must_use]
432    pub fn with_database(self, _database: impl Into<String>) -> Self {
433        self
434    }
435
436    /// Set the sanitization configuration (no-op).
437    #[must_use]
438    pub fn with_sanitization(self, _config: SanitizationConfig) -> Self {
439        self
440    }
441}
442
443// =============================================================================
444// OpenTelemetry Metrics Support
445// =============================================================================
446
/// Metric names following OpenTelemetry semantic conventions.
///
/// Each constant is the instrument name passed to the meter when the
/// corresponding gauge, counter, or histogram is created.
pub mod metric_names {
    /// Gauge: Number of connections currently in use.
    pub const DB_CLIENT_CONNECTIONS_USAGE: &str = "db.client.connections.usage";
    /// Gauge: Number of idle connections in the pool.
    pub const DB_CLIENT_CONNECTIONS_IDLE: &str = "db.client.connections.idle";
    /// Gauge: Maximum connections allowed in the pool.
    pub const DB_CLIENT_CONNECTIONS_MAX: &str = "db.client.connections.max";
    /// Counter: Total number of connections created.
    pub const DB_CLIENT_CONNECTIONS_CREATE_TOTAL: &str = "db.client.connections.create.total";
    /// Counter: Total number of connections closed.
    pub const DB_CLIENT_CONNECTIONS_CLOSE_TOTAL: &str = "db.client.connections.close.total";
    /// Histogram: Duration of database operations (queries, executes), in seconds.
    pub const DB_CLIENT_OPERATION_DURATION: &str = "db.client.operation.duration";
    /// Counter: Total number of operations performed.
    pub const DB_CLIENT_OPERATIONS_TOTAL: &str = "db.client.operations.total";
    /// Counter: Total number of operation errors.
    pub const DB_CLIENT_ERRORS_TOTAL: &str = "db.client.errors.total";
    /// Histogram: Time spent waiting for a connection from the pool, in seconds.
    pub const DB_CLIENT_CONNECTIONS_WAIT_TIME: &str = "db.client.connections.wait_time";
}
468
/// Database metrics collector using OpenTelemetry.
///
/// Bundles the instruments named in [`metric_names`] together with the
/// attribute set attached to every measurement. Only compiled when the `otel`
/// feature is enabled.
#[cfg(feature = "otel")]
pub struct DatabaseMetrics {
    /// Connection usage gauge.
    connections_usage: opentelemetry::metrics::Gauge<u64>,
    /// Idle connections gauge.
    connections_idle: opentelemetry::metrics::Gauge<u64>,
    /// Max connections gauge.
    connections_max: opentelemetry::metrics::Gauge<u64>,
    /// Connections created counter.
    connections_create_total: opentelemetry::metrics::Counter<u64>,
    /// Connections closed counter.
    connections_close_total: opentelemetry::metrics::Counter<u64>,
    /// Operation duration histogram (seconds).
    operation_duration: opentelemetry::metrics::Histogram<f64>,
    /// Total operations counter.
    operations_total: opentelemetry::metrics::Counter<u64>,
    /// Error counter (incremented only for unsuccessful operations).
    errors_total: opentelemetry::metrics::Counter<u64>,
    /// Connection wait time histogram (seconds).
    connections_wait_time: opentelemetry::metrics::Histogram<f64>,
    /// Base attributes for all metrics: `db.system`, `server.address`,
    /// `server.port`, plus `db.client.pool.name` when a pool name was given.
    base_attributes: Vec<opentelemetry::KeyValue>,
}
493
#[cfg(feature = "otel")]
impl DatabaseMetrics {
    /// Build a metrics collector backed by the global `mssql-client` meter.
    ///
    /// # Arguments
    ///
    /// * `pool_name` - Optional pool identifier; when given it is attached to
    ///   every measurement as `db.client.pool.name`
    /// * `server_address` - Server hostname, attached as `server.address`
    /// * `server_port` - Server port, attached as `server.port`
    pub fn new(pool_name: Option<&str>, server_address: &str, server_port: u16) -> Self {
        use opentelemetry::{KeyValue, global};

        let meter = global::meter("mssql-client");

        // Attributes shared by every measurement this collector records.
        let mut base_attributes = vec![
            KeyValue::new(attributes::DB_SYSTEM, DB_SYSTEM),
            KeyValue::new(attributes::SERVER_ADDRESS, server_address.to_string()),
            KeyValue::new(attributes::SERVER_PORT, i64::from(server_port)),
        ];
        base_attributes
            .extend(pool_name.map(|name| KeyValue::new("db.client.pool.name", name.to_string())));

        // Field initialisers run top to bottom, so the instruments are created
        // in declaration order.
        Self {
            connections_usage: meter
                .u64_gauge(metric_names::DB_CLIENT_CONNECTIONS_USAGE)
                .with_description("Number of connections currently in use")
                .with_unit("connections")
                .build(),
            connections_idle: meter
                .u64_gauge(metric_names::DB_CLIENT_CONNECTIONS_IDLE)
                .with_description("Number of idle connections available")
                .with_unit("connections")
                .build(),
            connections_max: meter
                .u64_gauge(metric_names::DB_CLIENT_CONNECTIONS_MAX)
                .with_description("Maximum number of connections allowed")
                .with_unit("connections")
                .build(),
            connections_create_total: meter
                .u64_counter(metric_names::DB_CLIENT_CONNECTIONS_CREATE_TOTAL)
                .with_description("Total number of connections created")
                .with_unit("connections")
                .build(),
            connections_close_total: meter
                .u64_counter(metric_names::DB_CLIENT_CONNECTIONS_CLOSE_TOTAL)
                .with_description("Total number of connections closed")
                .with_unit("connections")
                .build(),
            operation_duration: meter
                .f64_histogram(metric_names::DB_CLIENT_OPERATION_DURATION)
                .with_description("Duration of database operations")
                .with_unit("s")
                .build(),
            operations_total: meter
                .u64_counter(metric_names::DB_CLIENT_OPERATIONS_TOTAL)
                .with_description("Total number of database operations")
                .with_unit("operations")
                .build(),
            errors_total: meter
                .u64_counter(metric_names::DB_CLIENT_ERRORS_TOTAL)
                .with_description("Total number of operation errors")
                .with_unit("errors")
                .build(),
            connections_wait_time: meter
                .f64_histogram(metric_names::DB_CLIENT_CONNECTIONS_WAIT_TIME)
                .with_description("Time spent waiting for a connection")
                .with_unit("s")
                .build(),
            base_attributes,
        }
    }

    /// Report the pool's in-use, idle, and maximum connection counts on the
    /// three connection gauges.
    pub fn record_pool_status(&self, in_use: u64, idle: u64, max: u64) {
        let attrs = &self.base_attributes;
        self.connections_usage.record(in_use, attrs);
        self.connections_idle.record(idle, attrs);
        self.connections_max.record(max, attrs);
    }

    /// Count one newly created connection.
    pub fn record_connection_created(&self) {
        let attrs = &self.base_attributes;
        self.connections_create_total.add(1, attrs);
    }

    /// Count one closed connection.
    pub fn record_connection_closed(&self) {
        let attrs = &self.base_attributes;
        self.connections_close_total.add(1, attrs);
    }

    /// Record one finished operation.
    ///
    /// Increments the operations counter and records `duration_seconds` on the
    /// duration histogram, both tagged with `db.operation` and
    /// `db.operation.success`. Failed operations additionally increment the
    /// error counter with the same attributes.
    pub fn record_operation(&self, operation: &str, duration_seconds: f64, success: bool) {
        use opentelemetry::KeyValue;

        // Base attributes first, then the two per-operation attributes.
        let mut attrs = Vec::with_capacity(self.base_attributes.len() + 2);
        attrs.extend_from_slice(&self.base_attributes);
        attrs.push(KeyValue::new(
            attributes::DB_OPERATION,
            operation.to_string(),
        ));
        attrs.push(KeyValue::new("db.operation.success", success));

        self.operations_total.add(1, &attrs);
        self.operation_duration.record(duration_seconds, &attrs);

        if success {
            return;
        }
        self.errors_total.add(1, &attrs);
    }

    /// Record how long a caller waited for a pooled connection, in seconds.
    pub fn record_connection_wait(&self, duration_seconds: f64) {
        let attrs = &self.base_attributes;
        self.connections_wait_time.record(duration_seconds, attrs);
    }
}
628
/// Stand-in metrics collector used when the `otel` feature is off.
///
/// A unit struct whose methods all do nothing, so call sites can record
/// metrics unconditionally.
#[cfg(not(feature = "otel"))]
#[derive(Debug, Clone, Default)]
pub struct DatabaseMetrics;

#[cfg(not(feature = "otel"))]
impl DatabaseMetrics {
    /// Build the no-op collector; all arguments are discarded.
    #[must_use]
    pub fn new(_pool_name: Option<&str>, _server_address: &str, _server_port: u16) -> Self {
        DatabaseMetrics
    }

    /// Ignore a pool status report.
    pub fn record_pool_status(&self, _in_use: u64, _idle: u64, _max: u64) {}

    /// Ignore a connection-created event.
    pub fn record_connection_created(&self) {}

    /// Ignore a connection-closed event.
    pub fn record_connection_closed(&self) {}

    /// Ignore a finished operation.
    pub fn record_operation(&self, _operation: &str, _duration_seconds: f64, _success: bool) {}

    /// Ignore a connection wait-time measurement.
    pub fn record_connection_wait(&self, _duration_seconds: f64) {}
}
657
658/// Helper for timing operations.
659#[derive(Debug, Clone)]
660pub struct OperationTimer {
661    start: std::time::Instant,
662    operation: &'static str,
663}
664
665impl OperationTimer {
666    /// Start timing an operation.
667    #[must_use]
668    pub fn start(operation: &'static str) -> Self {
669        Self {
670            start: std::time::Instant::now(),
671            operation,
672        }
673    }
674
675    /// Get the elapsed time in seconds.
676    #[must_use]
677    pub fn elapsed_seconds(&self) -> f64 {
678        self.start.elapsed().as_secs_f64()
679    }
680
681    /// Get the operation name.
682    #[must_use]
683    pub fn operation(&self) -> &'static str {
684        self.operation
685    }
686
687    /// Finish timing and record the metric.
688    #[cfg(feature = "otel")]
689    pub fn finish(self, metrics: &DatabaseMetrics, success: bool) {
690        metrics.record_operation(self.operation, self.elapsed_seconds(), success);
691    }
692
693    /// Finish timing (no-op when otel is disabled).
694    #[cfg(not(feature = "otel"))]
695    pub fn finish(self, _metrics: &DatabaseMetrics, _success: bool) {}
696}
697
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Each statement is classified by its leading keyword.
    #[test]
    fn test_extract_operation() {
        let cases = [
            ("SELECT * FROM users", "SELECT"),
            ("  select id from users", "SELECT"),
            ("INSERT INTO users VALUES (1)", "INSERT"),
            ("UPDATE users SET name = 'foo'", "UPDATE"),
            ("DELETE FROM users", "DELETE"),
            ("EXEC sp_help", "EXECUTE"),
            ("BEGIN TRANSACTION", "BEGIN"),
            ("COMMIT", "COMMIT"),
            ("ROLLBACK", "ROLLBACK"),
            ("CREATE TABLE foo", "CREATE"),
            ("unknown stuff", "OTHER"),
        ];

        for &(sql, expected) in &cases {
            assert_eq!(extract_operation(sql), expected);
        }
    }

    /// String and numeric literals are replaced by the placeholder.
    #[test]
    fn test_sanitize_sql() {
        let placeholder = "?";

        let cases = [
            // String literals
            (
                "SELECT * FROM users WHERE name = 'Alice'",
                "SELECT * FROM users WHERE name = ?",
            ),
            // Multiple strings
            (
                "INSERT INTO t VALUES ('a', 'b')",
                "INSERT INTO t VALUES (?, ?)",
            ),
            // Escaped quotes
            (
                "SELECT * WHERE name = 'O''Brien'",
                "SELECT * WHERE name = ?",
            ),
            // Numbers
            ("SELECT * WHERE id = 123", "SELECT * WHERE id = ?"),
            // Mixed
            (
                "SELECT * WHERE id = 42 AND name = 'test'",
                "SELECT * WHERE id = ? AND name = ?",
            ),
        ];

        for &(sql, expected) in &cases {
            assert_eq!(sanitize_sql(sql, placeholder), expected);
        }
    }

    /// Strings that fit are untouched; longer ones end with an ellipsis.
    #[test]
    fn test_truncate_string() {
        let cases = [
            ("hello", 10, "hello"),
            ("hello world", 8, "hello..."),
            ("hi", 2, "hi"),
        ];

        for &(input, max_len, expected) in &cases {
            assert_eq!(truncate_string(input, max_len), expected);
        }
    }

    /// The default configuration masks literals with `?` and caps at 2048.
    #[test]
    fn test_sanitization_config_default() {
        let SanitizationConfig {
            enabled,
            max_length,
            placeholder,
        } = SanitizationConfig::default();

        assert!(enabled);
        assert_eq!(max_length, 2048);
        assert_eq!(placeholder, "?");
    }

    /// With sanitization off, the statement is recorded verbatim.
    #[test]
    fn test_sanitization_config_no_sanitization() {
        let config = SanitizationConfig::no_sanitization();
        assert!(!config.enabled);

        let sql = "SELECT * FROM users WHERE name = 'Alice'";
        let recorded = config.sanitize(sql);
        assert_eq!(recorded, sql);
    }
}