#[cfg(feature = "otel")]
use opentelemetry::{
KeyValue, global,
trace::{Span, SpanKind, Status, Tracer},
};
/// Value reported under the `db.system` attribute key on spans and metrics.
pub const DB_SYSTEM: &str = "mssql";
/// Span names emitted by this module's tracing helpers.
///
/// `CONNECT`, `QUERY`, `BEGIN_TRANSACTION`, `COMMIT`, `ROLLBACK` and
/// `SAVEPOINT` are used by `InstrumentationContext` in this file. `EXECUTE`
/// and `BULK_INSERT` are not referenced here; presumably other parts of the
/// crate use them (verify before removing).
pub mod span_names {
/// Span for establishing a connection.
pub const CONNECT: &str = "mssql.connect";
/// Span for running a SQL statement.
pub const QUERY: &str = "mssql.query";
/// Span for executing a statement (not used in this file).
pub const EXECUTE: &str = "mssql.execute";
/// Span for the `"BEGIN"` transaction operation.
pub const BEGIN_TRANSACTION: &str = "mssql.begin_transaction";
/// Span for the `"COMMIT"` transaction operation.
pub const COMMIT: &str = "mssql.commit";
/// Span for the `"ROLLBACK"` transaction operation.
pub const ROLLBACK: &str = "mssql.rollback";
/// Fallback span name for any other transaction operation.
pub const SAVEPOINT: &str = "mssql.savepoint";
/// Span for bulk insert operations (not used in this file).
pub const BULK_INSERT: &str = "mssql.bulk_insert";
}
/// Attribute keys attached to spans and metrics.
///
/// Keys such as `db.system`, `server.address`, `server.port` and `error.type`
/// match OpenTelemetry semantic-convention names. `db.mssql.isolation_level`
/// is vendor-specific. `DB_ISOLATION_LEVEL`, `DB_CONNECTION_ID` and
/// `ERROR_TYPE` are not referenced in this file.
pub mod attributes {
/// Database system identifier; always [`crate::telemetry`]-independent value `"mssql"` here.
pub const DB_SYSTEM: &str = "db.system";
/// Database (catalog) name, set when one is configured.
pub const DB_NAME: &str = "db.name";
/// Statement text, sanitized according to `SanitizationConfig`.
pub const DB_STATEMENT: &str = "db.statement";
/// Operation keyword such as `SELECT` or `COMMIT`.
pub const DB_OPERATION: &str = "db.operation";
/// Server host name or address.
pub const SERVER_ADDRESS: &str = "server.address";
/// Server TCP port.
pub const SERVER_PORT: &str = "server.port";
/// Number of rows affected by a successful operation.
pub const DB_ROWS_AFFECTED: &str = "db.rows_affected";
/// Transaction isolation level (vendor-specific key).
pub const DB_ISOLATION_LEVEL: &str = "db.mssql.isolation_level";
/// Connection identifier.
pub const DB_CONNECTION_ID: &str = "db.connection_id";
/// Error classification for failed operations.
pub const ERROR_TYPE: &str = "error.type";
}
/// Controls how SQL text is prepared before it is attached to a span as
/// `db.statement`.
#[derive(Debug, Clone)]
pub struct SanitizationConfig {
    /// When `true`, literal values in the statement are replaced by `placeholder`.
    pub enabled: bool,
    /// Maximum length, in bytes, of the recorded statement; longer text is truncated.
    pub max_length: usize,
    /// Text substituted for each literal value when sanitization is enabled.
    pub placeholder: String,
}

impl Default for SanitizationConfig {
    /// Sanitization enabled, a 2048-byte limit, and `?` as the placeholder.
    fn default() -> Self {
        let placeholder = String::from("?");
        Self {
            placeholder,
            max_length: 2048,
            enabled: true,
        }
    }
}
impl SanitizationConfig {
    /// A configuration that records statements as written.
    ///
    /// Literal replacement is off, the length limit is `usize::MAX`, and the
    /// placeholder is empty.
    #[must_use]
    pub fn no_sanitization() -> Self {
        Self {
            placeholder: String::new(),
            max_length: usize::MAX,
            enabled: false,
        }
    }

    /// Returns `sql` prepared for recording.
    ///
    /// When `enabled` is set, literal values are first replaced with
    /// `placeholder`. In either case the result is truncated to `max_length`
    /// bytes.
    #[must_use]
    pub fn sanitize(&self, sql: &str) -> String {
        // Borrow the input untouched when sanitization is off, so only the
        // truncation step allocates.
        let prepared: std::borrow::Cow<'_, str> = if self.enabled {
            std::borrow::Cow::Owned(sanitize_sql(sql, &self.placeholder))
        } else {
            std::borrow::Cow::Borrowed(sql)
        };
        truncate_string(&prepared, self.max_length)
    }
}
/// Replaces literal values in `sql` with `placeholder` so the statement can be
/// recorded without leaking data.
///
/// * Text between `'…'` or `"…"` (with doubled-quote escapes such as `''`)
///   becomes one placeholder. An unterminated literal at the end of input is
///   also replaced. Double-quoted identifiers are therefore replaced too.
/// * Numeric literals become one placeholder. This covers integers, decimals,
///   exponent forms such as `1.5e10` or `2E-3`, and binary constants such as
///   `0xDEADBEEF`. A digit that directly follows an identifier character
///   (e.g. `table1`, `@p1`) is kept, because it is part of a name.
/// * Everything else is copied unchanged.
///
/// Comments are not recognised. An apostrophe inside a `--` comment opens a
/// literal, and the rest of the statement is replaced.
fn sanitize_sql(sql: &str, placeholder: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    let mut in_string = false;
    let mut string_char = ' ';
    while let Some(c) = chars.next() {
        if in_string {
            if c == string_char {
                // A doubled quote is an escaped quote, not the terminator.
                if chars.peek() == Some(&string_char) {
                    chars.next();
                    continue;
                }
                in_string = false;
                result.push_str(placeholder);
            }
            continue;
        }
        if c == '\'' || c == '"' {
            in_string = true;
            string_char = c;
            continue;
        }
        if c.is_ascii_digit() && !result.ends_with(|ch: char| ch.is_alphanumeric() || ch == '_') {
            skip_numeric_literal_tail(c, &mut chars);
            result.push_str(placeholder);
            continue;
        }
        result.push(c);
    }
    if in_string {
        result.push_str(placeholder);
    }
    result
}

/// Advances `chars` past the rest of a numeric literal whose first digit,
/// `first`, has already been consumed.
fn skip_numeric_literal_tail(first: char, chars: &mut std::iter::Peekable<std::str::Chars<'_>>) {
    // T-SQL binary constant: `0x` followed by zero or more hex digits.
    if first == '0' && matches!(chars.peek(), Some('x' | 'X')) {
        chars.next();
        while chars.peek().is_some_and(char::is_ascii_hexdigit) {
            chars.next();
        }
        return;
    }
    while chars.peek().is_some_and(|ch| ch.is_ascii_digit() || *ch == '.') {
        chars.next();
    }
    // Exponent suffix. It is consumed only when `e`/`E` (plus an optional
    // sign) is followed by a digit, so a stray trailing `e` is left in place.
    if matches!(chars.peek(), Some('e' | 'E')) {
        let mut lookahead = chars.clone();
        lookahead.next();
        if matches!(lookahead.peek(), Some('+' | '-')) {
            lookahead.next();
        }
        if lookahead.peek().is_some_and(char::is_ascii_digit) {
            *chars = lookahead;
            while chars.peek().is_some_and(char::is_ascii_digit) {
                chars.next();
            }
        }
    }
}
/// Limits `s` to at most `max_len` bytes.
///
/// Text that already fits is returned unchanged. Otherwise the text is cut
/// and `...` is appended, with the cut made on a UTF-8 character boundary so
/// multi-byte characters are never split. If `max_len` is too small to hold
/// the `...` suffix, the text is cut without it, so the result never exceeds
/// `max_len`.
fn truncate_string(s: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";
    if s.len() <= max_len {
        return s.to_string();
    }
    let (budget, suffix) = if max_len >= ELLIPSIS.len() {
        (max_len - ELLIPSIS.len(), ELLIPSIS)
    } else {
        (max_len, "")
    };
    // Step back to the nearest char boundary. Byte 0 is always one, so this
    // terminates; `budget < s.len()` holds because `s` did not fit.
    let mut end = budget;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}{suffix}", &s[..end])
}
/// Classifies a SQL statement by its leading keyword, for the `db.operation`
/// attribute.
///
/// Matching ignores leading whitespace and ASCII case:
/// * `EXEC` and `EXECUTE` map to `"EXECUTE"`.
/// * `BEGIN TRAN` and `BEGIN TRANSACTION` map to `"BEGIN"`, with any
///   whitespace allowed between the two words.
/// * Anything unrecognised yields `"OTHER"`. This includes statements that
///   start with a comment or a `WITH` clause.
///
/// Only the statement prefix is inspected; no copy of `sql` is allocated.
#[must_use]
pub fn extract_operation(sql: &str) -> &'static str {
    // Leading keyword and the operation it reports, checked in order.
    // `EXEC` also covers `EXECUTE`, which starts with it.
    const KEYWORDS: [(&str, &str); 10] = [
        ("SELECT", "SELECT"),
        ("INSERT", "INSERT"),
        ("UPDATE", "UPDATE"),
        ("DELETE", "DELETE"),
        ("EXEC", "EXECUTE"),
        ("COMMIT", "COMMIT"),
        ("ROLLBACK", "ROLLBACK"),
        ("CREATE", "CREATE"),
        ("ALTER", "ALTER"),
        ("DROP", "DROP"),
    ];

    // ASCII case-insensitive `starts_with`. A prefix length that is not a
    // char boundary cannot match an ASCII keyword, so `get` returning `None`
    // is a correct "no".
    fn starts_with_ignore_case(s: &str, prefix: &str) -> bool {
        s.get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
    }

    // `BEGIN`, at least one whitespace character, then `TRAN…`.
    fn is_begin_transaction(s: &str) -> bool {
        if !starts_with_ignore_case(s, "BEGIN") {
            return false;
        }
        let rest = &s["BEGIN".len()..];
        let second_word = rest.trim_start();
        second_word.len() < rest.len() && starts_with_ignore_case(second_word, "TRAN")
    }

    let sql = sql.trim_start();
    if is_begin_transaction(sql) {
        return "BEGIN";
    }
    KEYWORDS
        .iter()
        .find(|(keyword, _)| starts_with_ignore_case(sql, keyword))
        .map_or("OTHER", |&(_, operation)| operation)
}
/// Per-connection settings used to build OpenTelemetry spans (`otel` feature).
#[cfg(feature = "otel")]
#[derive(Debug, Clone)]
pub struct InstrumentationContext {
/// Reported as `server.address`, and again as `db.connection_string.host` on connect spans.
pub server_address: String,
/// Reported as `server.port`.
pub server_port: u16,
/// Reported as `db.name` when set.
pub database: Option<String>,
/// Controls how statement text is sanitized before it is recorded as `db.statement`.
pub sanitization: SanitizationConfig,
}
#[cfg(feature = "otel")]
impl InstrumentationContext {
    /// Creates a context for `server_address:server_port`, with no database
    /// name and the default [`SanitizationConfig`].
    #[must_use]
    pub fn new(server_address: String, server_port: u16) -> Self {
        Self {
            server_address,
            server_port,
            database: None,
            sanitization: SanitizationConfig::default(),
        }
    }

    /// Sets the database name reported as `db.name`.
    #[must_use]
    pub fn with_database(mut self, database: impl Into<String>) -> Self {
        self.database = Some(database.into());
        self
    }

    /// Replaces the settings used to sanitize statement text in [`Self::query_span`].
    #[must_use]
    pub fn with_sanitization(mut self, config: SanitizationConfig) -> Self {
        self.sanitization = config;
        self
    }

    /// Attributes shared by every span.
    ///
    /// These are `db.system`, `server.address` and `server.port`, plus
    /// `db.name` when a database is set.
    pub fn base_attributes(&self) -> Vec<KeyValue> {
        let mut attrs = vec![
            KeyValue::new(attributes::DB_SYSTEM, DB_SYSTEM),
            KeyValue::new(attributes::SERVER_ADDRESS, self.server_address.clone()),
            KeyValue::new(attributes::SERVER_PORT, i64::from(self.server_port)),
        ];
        if let Some(ref db) = self.database {
            attrs.push(KeyValue::new(attributes::DB_NAME, db.clone()));
        }
        attrs
    }

    /// Starts a client span named `mssql.connect` on the global tracer
    /// `"mssql-client"`.
    ///
    /// Besides the base attributes, it records the host again under
    /// `db.connection_string.host`.
    pub fn connection_span(&self) -> impl Span {
        let tracer = global::tracer("mssql-client");
        let mut attrs = self.base_attributes();
        attrs.push(KeyValue::new(
            "db.connection_string.host",
            self.server_address.clone(),
        ));
        tracer
            .span_builder(span_names::CONNECT)
            .with_kind(SpanKind::Client)
            .with_attributes(attrs)
            .start(&tracer)
    }

    /// Starts a `mssql.query` client span for `sql`.
    ///
    /// The span carries `db.operation`, derived via [`extract_operation`],
    /// and the sanitized statement as `db.statement`.
    pub fn query_span(&self, sql: &str) -> impl Span {
        let tracer = global::tracer("mssql-client");
        let mut attrs = self.base_attributes();
        let operation = extract_operation(sql);
        attrs.push(KeyValue::new(attributes::DB_OPERATION, operation));
        attrs.push(KeyValue::new(
            attributes::DB_STATEMENT,
            self.sanitization.sanitize(sql),
        ));
        tracer
            .span_builder(span_names::QUERY)
            .with_kind(SpanKind::Client)
            .with_attributes(attrs)
            .start(&tracer)
    }

    /// Starts a client span for a transaction operation.
    ///
    /// `"BEGIN"`, `"COMMIT"` and `"ROLLBACK"` map to their dedicated span
    /// names. Any other value is reported under `mssql.savepoint`. The raw
    /// `operation` string is recorded as `db.operation`.
    pub fn transaction_span(&self, operation: &str) -> impl Span {
        let tracer = global::tracer("mssql-client");
        let mut attrs = self.base_attributes();
        attrs.push(KeyValue::new(
            attributes::DB_OPERATION,
            operation.to_string(),
        ));
        let span_name = match operation {
            "BEGIN" => span_names::BEGIN_TRANSACTION,
            "COMMIT" => span_names::COMMIT,
            "ROLLBACK" => span_names::ROLLBACK,
            _ => span_names::SAVEPOINT,
        };
        tracer
            .span_builder(span_name)
            .with_kind(SpanKind::Client)
            .with_attributes(attrs)
            .start(&tracer)
    }

    /// Marks `span` as failed, using the error's display text as the status
    /// description, and records the error on the span.
    pub fn record_error(span: &mut impl Span, error: &crate::error::Error) {
        span.set_status(Status::error(error.to_string()));
        span.record_error(error);
    }

    /// Marks `span` as successful and, when known, records `db.rows_affected`.
    ///
    /// OpenTelemetry integer attributes are signed 64-bit, so counts above
    /// `i64::MAX` are clamped. A plain `as` cast would wrap them to negative.
    pub fn record_success(span: &mut impl Span, rows_affected: Option<u64>) {
        span.set_status(Status::Ok);
        if let Some(rows) = rows_affected {
            let rows = i64::try_from(rows).unwrap_or(i64::MAX);
            span.set_attribute(KeyValue::new(attributes::DB_ROWS_AFFECTED, rows));
        }
    }
}
/// No-op stand-in for `InstrumentationContext` when the `otel` feature is disabled.
///
/// Only the constructor and builder methods (`new`, `with_database`,
/// `with_sanitization`) exist here, and they ignore their arguments. The span
/// helpers are available only with `otel`.
#[cfg(not(feature = "otel"))]
#[derive(Debug, Clone, Default)]
pub struct InstrumentationContext;
#[cfg(not(feature = "otel"))]
impl InstrumentationContext {
/// Returns the stub; the server address and port are discarded.
#[must_use]
pub fn new(_server_address: String, _server_port: u16) -> Self {
Self
}
/// Returns `self` unchanged; the database name is discarded.
#[must_use]
pub fn with_database(self, _database: impl Into<String>) -> Self {
self
}
/// Returns `self` unchanged; the sanitization config is discarded.
#[must_use]
pub fn with_sanitization(self, _config: SanitizationConfig) -> Self {
self
}
}
/// Metric instrument names registered by `DatabaseMetrics::new` (`otel` feature).
pub mod metric_names {
/// Gauge: connections currently in use.
pub const DB_CLIENT_CONNECTIONS_USAGE: &str = "db.client.connections.usage";
/// Gauge: idle connections available.
pub const DB_CLIENT_CONNECTIONS_IDLE: &str = "db.client.connections.idle";
/// Gauge: maximum connections allowed.
pub const DB_CLIENT_CONNECTIONS_MAX: &str = "db.client.connections.max";
/// Counter: connections created.
pub const DB_CLIENT_CONNECTIONS_CREATE_TOTAL: &str = "db.client.connections.create.total";
/// Counter: connections closed.
pub const DB_CLIENT_CONNECTIONS_CLOSE_TOTAL: &str = "db.client.connections.close.total";
/// Histogram (seconds): duration of database operations.
pub const DB_CLIENT_OPERATION_DURATION: &str = "db.client.operation.duration";
/// Counter: database operations performed.
pub const DB_CLIENT_OPERATIONS_TOTAL: &str = "db.client.operations.total";
/// Counter: failed database operations.
pub const DB_CLIENT_ERRORS_TOTAL: &str = "db.client.errors.total";
/// Histogram (seconds): time spent waiting for a connection.
pub const DB_CLIENT_CONNECTIONS_WAIT_TIME: &str = "db.client.connections.wait_time";
}
/// OpenTelemetry instruments for connection-pool and operation metrics (`otel` feature).
#[cfg(feature = "otel")]
pub struct DatabaseMetrics {
/// Gauge of connections in use (`db.client.connections.usage`).
connections_usage: opentelemetry::metrics::Gauge<u64>,
/// Gauge of idle connections (`db.client.connections.idle`).
connections_idle: opentelemetry::metrics::Gauge<u64>,
/// Gauge of the connection limit (`db.client.connections.max`).
connections_max: opentelemetry::metrics::Gauge<u64>,
/// Counter of created connections.
connections_create_total: opentelemetry::metrics::Counter<u64>,
/// Counter of closed connections.
connections_close_total: opentelemetry::metrics::Counter<u64>,
/// Histogram of operation durations, in seconds.
operation_duration: opentelemetry::metrics::Histogram<f64>,
/// Counter of operations.
operations_total: opentelemetry::metrics::Counter<u64>,
/// Counter of failed operations.
errors_total: opentelemetry::metrics::Counter<u64>,
/// Histogram of connection wait times, in seconds.
connections_wait_time: opentelemetry::metrics::Histogram<f64>,
/// Attributes attached to every measurement (system, server, optional pool name).
base_attributes: Vec<opentelemetry::KeyValue>,
}
#[cfg(feature = "otel")]
impl DatabaseMetrics {
    /// Explicit bucket boundaries, in seconds, for the duration histograms.
    ///
    /// These are the boundaries the OpenTelemetry database semantic
    /// conventions recommend for `db.client.operation.duration`. Without this
    /// hint, the SDK uses its default explicit-bucket boundaries
    /// (0, 5, 10, 25, …, 10000), which are sized for milliseconds. With those,
    /// almost every sub-five-second measurement would fall into one bucket.
    const DURATION_BOUNDARIES_SECONDS: [f64; 9] =
        [0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0];

    /// Creates every instrument on the global meter `"mssql-client"`.
    ///
    /// All measurements carry `db.system`, `server.address` and
    /// `server.port`. They also carry `db.client.pool.name` when `pool_name`
    /// is provided.
    pub fn new(pool_name: Option<&str>, server_address: &str, server_port: u16) -> Self {
        let meter = global::meter("mssql-client");
        let connections_usage = meter
            .u64_gauge(metric_names::DB_CLIENT_CONNECTIONS_USAGE)
            .with_description("Number of connections currently in use")
            .with_unit("connections")
            .build();
        let connections_idle = meter
            .u64_gauge(metric_names::DB_CLIENT_CONNECTIONS_IDLE)
            .with_description("Number of idle connections available")
            .with_unit("connections")
            .build();
        let connections_max = meter
            .u64_gauge(metric_names::DB_CLIENT_CONNECTIONS_MAX)
            .with_description("Maximum number of connections allowed")
            .with_unit("connections")
            .build();
        let connections_create_total = meter
            .u64_counter(metric_names::DB_CLIENT_CONNECTIONS_CREATE_TOTAL)
            .with_description("Total number of connections created")
            .with_unit("connections")
            .build();
        let connections_close_total = meter
            .u64_counter(metric_names::DB_CLIENT_CONNECTIONS_CLOSE_TOTAL)
            .with_description("Total number of connections closed")
            .with_unit("connections")
            .build();
        let operation_duration = meter
            .f64_histogram(metric_names::DB_CLIENT_OPERATION_DURATION)
            .with_description("Duration of database operations")
            .with_unit("s")
            .with_boundaries(Self::DURATION_BOUNDARIES_SECONDS.to_vec())
            .build();
        let operations_total = meter
            .u64_counter(metric_names::DB_CLIENT_OPERATIONS_TOTAL)
            .with_description("Total number of database operations")
            .with_unit("operations")
            .build();
        let errors_total = meter
            .u64_counter(metric_names::DB_CLIENT_ERRORS_TOTAL)
            .with_description("Total number of operation errors")
            .with_unit("errors")
            .build();
        let connections_wait_time = meter
            .f64_histogram(metric_names::DB_CLIENT_CONNECTIONS_WAIT_TIME)
            .with_description("Time spent waiting for a connection")
            .with_unit("s")
            .with_boundaries(Self::DURATION_BOUNDARIES_SECONDS.to_vec())
            .build();
        let mut base_attributes = vec![
            KeyValue::new(attributes::DB_SYSTEM, DB_SYSTEM),
            KeyValue::new(attributes::SERVER_ADDRESS, server_address.to_string()),
            KeyValue::new(attributes::SERVER_PORT, i64::from(server_port)),
        ];
        if let Some(name) = pool_name {
            base_attributes.push(KeyValue::new("db.client.pool.name", name.to_string()));
        }
        Self {
            connections_usage,
            connections_idle,
            connections_max,
            connections_create_total,
            connections_close_total,
            operation_duration,
            operations_total,
            errors_total,
            connections_wait_time,
            base_attributes,
        }
    }

    /// Records a snapshot of pool occupancy on the three connection gauges.
    pub fn record_pool_status(&self, in_use: u64, idle: u64, max: u64) {
        self.connections_usage.record(in_use, &self.base_attributes);
        self.connections_idle.record(idle, &self.base_attributes);
        self.connections_max.record(max, &self.base_attributes);
    }

    /// Increments the created-connections counter by one.
    pub fn record_connection_created(&self) {
        self.connections_create_total.add(1, &self.base_attributes);
    }

    /// Increments the closed-connections counter by one.
    pub fn record_connection_closed(&self) {
        self.connections_close_total.add(1, &self.base_attributes);
    }

    /// Records one completed operation.
    ///
    /// The operation counter is incremented and the duration recorded, both
    /// tagged with `db.operation` and `db.operation.success`. The error
    /// counter is also incremented when `success` is `false`.
    pub fn record_operation(&self, operation: &str, duration_seconds: f64, success: bool) {
        let mut attrs = self.base_attributes.clone();
        attrs.push(KeyValue::new(
            attributes::DB_OPERATION,
            operation.to_string(),
        ));
        attrs.push(KeyValue::new("db.operation.success", success));
        self.operations_total.add(1, &attrs);
        self.operation_duration.record(duration_seconds, &attrs);
        if !success {
            self.errors_total.add(1, &attrs);
        }
    }

    /// Records how long a caller waited to acquire a connection, in seconds.
    pub fn record_connection_wait(&self, duration_seconds: f64) {
        self.connections_wait_time
            .record(duration_seconds, &self.base_attributes);
    }
}
/// No-op stand-in for `DatabaseMetrics` when the `otel` feature is disabled.
///
/// Mirrors the public methods of the `otel` variant so call sites compile
/// unchanged. Every method ignores its arguments and records nothing.
#[cfg(not(feature = "otel"))]
#[derive(Debug, Clone, Default)]
pub struct DatabaseMetrics;
#[cfg(not(feature = "otel"))]
impl DatabaseMetrics {
/// Returns the stub; all arguments are discarded.
#[must_use]
pub fn new(_pool_name: Option<&str>, _server_address: &str, _server_port: u16) -> Self {
Self
}
/// No-op.
pub fn record_pool_status(&self, _in_use: u64, _idle: u64, _max: u64) {}
/// No-op.
pub fn record_connection_created(&self) {}
/// No-op.
pub fn record_connection_closed(&self) {}
/// No-op.
pub fn record_operation(&self, _operation: &str, _duration_seconds: f64, _success: bool) {}
/// No-op.
pub fn record_connection_wait(&self, _duration_seconds: f64) {}
}
/// Wall-clock timer for a single database operation.
///
/// Create it with [`OperationTimer::start`]. Consume it with
/// [`OperationTimer::finish`] to report the measurement to [`DatabaseMetrics`].
#[derive(Debug, Clone)]
pub struct OperationTimer {
    /// Instant captured when the timer started.
    start: std::time::Instant,
    /// Operation label reported with the measurement.
    operation: &'static str,
}

impl OperationTimer {
    /// Starts timing `operation` from the current instant.
    #[must_use]
    pub fn start(operation: &'static str) -> Self {
        let start = std::time::Instant::now();
        Self { start, operation }
    }

    /// Seconds elapsed since the timer started, as a float.
    #[must_use]
    pub fn elapsed_seconds(&self) -> f64 {
        let elapsed = self.start.elapsed();
        elapsed.as_secs_f64()
    }

    /// The operation label this timer was started with.
    #[must_use]
    pub fn operation(&self) -> &'static str {
        self.operation
    }

    /// Stops the timer and reports the elapsed time and outcome through
    /// [`DatabaseMetrics::record_operation`].
    ///
    /// Without the `otel` feature, `record_operation` is a no-op, so nothing
    /// is recorded.
    pub fn finish(self, metrics: &DatabaseMetrics, success: bool) {
        let seconds = self.elapsed_seconds();
        metrics.record_operation(self.operation, seconds, success);
    }
}
// Unit tests for the feature-independent helpers: operation extraction, SQL
// sanitization, truncation, and `SanitizationConfig`. None require `otel`.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
// Each leading keyword maps to its operation; lowercase input with leading
// whitespace is still classified, and unknown text falls back to "OTHER".
#[test]
fn test_extract_operation() {
assert_eq!(extract_operation("SELECT * FROM users"), "SELECT");
assert_eq!(extract_operation(" select id from users"), "SELECT");
assert_eq!(extract_operation("INSERT INTO users VALUES (1)"), "INSERT");
assert_eq!(extract_operation("UPDATE users SET name = 'foo'"), "UPDATE");
assert_eq!(extract_operation("DELETE FROM users"), "DELETE");
assert_eq!(extract_operation("EXEC sp_help"), "EXECUTE");
assert_eq!(extract_operation("BEGIN TRANSACTION"), "BEGIN");
assert_eq!(extract_operation("COMMIT"), "COMMIT");
assert_eq!(extract_operation("ROLLBACK"), "ROLLBACK");
assert_eq!(extract_operation("CREATE TABLE foo"), "CREATE");
assert_eq!(extract_operation("unknown stuff"), "OTHER");
}
// String literals (including '' escapes) and standalone numbers each become
// a single placeholder; surrounding SQL text is preserved.
#[test]
fn test_sanitize_sql() {
let placeholder = "?";
assert_eq!(
sanitize_sql("SELECT * FROM users WHERE name = 'Alice'", placeholder),
"SELECT * FROM users WHERE name = ?"
);
assert_eq!(
sanitize_sql("INSERT INTO t VALUES ('a', 'b')", placeholder),
"INSERT INTO t VALUES (?, ?)"
);
assert_eq!(
sanitize_sql("SELECT * WHERE name = 'O''Brien'", placeholder),
"SELECT * WHERE name = ?"
);
assert_eq!(
sanitize_sql("SELECT * WHERE id = 123", placeholder),
"SELECT * WHERE id = ?"
);
assert_eq!(
sanitize_sql("SELECT * WHERE id = 42 AND name = 'test'", placeholder),
"SELECT * WHERE id = ? AND name = ?"
);
}
// Text within the limit is returned as-is; longer text is cut and suffixed
// with "..." so the result fits the limit.
#[test]
fn test_truncate_string() {
assert_eq!(truncate_string("hello", 10), "hello");
assert_eq!(truncate_string("hello world", 8), "hello...");
assert_eq!(truncate_string("hi", 2), "hi");
}
// Pins the documented defaults of `SanitizationConfig::default`.
#[test]
fn test_sanitization_config_default() {
let config = SanitizationConfig::default();
assert!(config.enabled);
assert_eq!(config.max_length, 2048);
assert_eq!(config.placeholder, "?");
}
// With sanitization disabled, statements pass through verbatim.
#[test]
fn test_sanitization_config_no_sanitization() {
let config = SanitizationConfig::no_sanitization();
assert!(!config.enabled);
let sql = "SELECT * FROM users WHERE name = 'Alice'";
assert_eq!(config.sanitize(sql), sql);
}
}