Skip to main content

forge_runtime/observability/
db.rs

1use opentelemetry::global;
2use opentelemetry::metrics::{Gauge, Histogram};
3use sqlx::PgPool;
4use std::sync::OnceLock;
5use std::time::{Duration, Instant};
6use tracing::{Instrument, info_span};
7
/// OpenTelemetry attribute key naming the database system.
const DB_SYSTEM: &str = "db.system";
/// Value stored under the `db.system` key by every metric and span in this module.
const DB_SYSTEM_POSTGRESQL: &str = "postgresql";
/// OpenTelemetry attribute key naming the database operation being performed.
const DB_OPERATION_NAME: &str = "db.operation.name";

/// Operations whose elapsed time is strictly greater than this are logged as
/// slow queries, so performance regressions surface before they hit production.
const SLOW_QUERY_THRESHOLD: Duration = Duration::from_millis(500);
14
15static DB_METRICS: OnceLock<DbMetrics> = OnceLock::new();
16
17struct DbMetrics {
18    query_duration: Histogram<f64>,
19    pool_connections_active: Gauge<u64>,
20    pool_connections_idle: Gauge<u64>,
21    pool_connections_max: Gauge<u64>,
22}
23
24fn get_metrics() -> &'static DbMetrics {
25    DB_METRICS.get_or_init(|| {
26        let meter = global::meter("forge.db");
27
28        DbMetrics {
29            query_duration: meter
30                .f64_histogram("db.client.operation.duration")
31                .with_description("Duration of database operations")
32                .with_unit("s")
33                .build(),
34            pool_connections_active: meter
35                .u64_gauge("db.client.connection.count")
36                .with_description("Number of active database connections")
37                .build(),
38            pool_connections_idle: meter
39                .u64_gauge("db.client.connection.idle_count")
40                .with_description("Number of idle database connections")
41                .build(),
42            pool_connections_max: meter
43                .u64_gauge("db.client.connection.max")
44                .with_description("Maximum number of database connections")
45                .build(),
46        }
47    })
48}
49
50/// Record pool connection metrics from a PgPool.
51pub fn record_pool_metrics(pool: &PgPool) {
52    let metrics = get_metrics();
53    let pool_size = pool.size();
54    let idle_count = pool.num_idle();
55    let max_connections = pool.options().get_max_connections();
56
57    metrics.pool_connections_active.record(
58        (pool_size - idle_count as u32) as u64,
59        &[opentelemetry::KeyValue::new(
60            DB_SYSTEM,
61            DB_SYSTEM_POSTGRESQL,
62        )],
63    );
64    metrics.pool_connections_idle.record(
65        idle_count as u64,
66        &[opentelemetry::KeyValue::new(
67            DB_SYSTEM,
68            DB_SYSTEM_POSTGRESQL,
69        )],
70    );
71    metrics.pool_connections_max.record(
72        max_connections as u64,
73        &[opentelemetry::KeyValue::new(
74            DB_SYSTEM,
75            DB_SYSTEM_POSTGRESQL,
76        )],
77    );
78}
79
80/// Record a query execution with its duration.
81pub fn record_query_duration(operation: &str, duration: Duration) {
82    let metrics = get_metrics();
83    metrics.query_duration.record(
84        duration.as_secs_f64(),
85        &[
86            opentelemetry::KeyValue::new(DB_SYSTEM, DB_SYSTEM_POSTGRESQL),
87            opentelemetry::KeyValue::new(DB_OPERATION_NAME, operation.to_string()),
88        ],
89    );
90}
91
92/// Extract the table name from a simple SQL query.
93/// Returns None for complex queries or when table cannot be determined.
94pub fn extract_table_name(sql: &str) -> Option<&str> {
95    let sql = sql.trim();
96    let upper = sql.to_uppercase();
97
98    // Handle common patterns
99    if upper.starts_with("SELECT") {
100        // SELECT ... FROM table_name ...
101        if let Some(from_pos) = upper.find(" FROM ") {
102            let after_from = &sql[from_pos + 6..];
103            return extract_first_identifier(after_from.trim_start());
104        }
105    } else if upper.starts_with("INSERT INTO ") {
106        let after_into = &sql[12..];
107        return extract_first_identifier(after_into.trim_start());
108    } else if upper.starts_with("UPDATE ") {
109        let after_update = &sql[7..];
110        return extract_first_identifier(after_update.trim_start());
111    } else if upper.starts_with("DELETE FROM ") {
112        let after_from = &sql[12..];
113        return extract_first_identifier(after_from.trim_start());
114    } else if upper.starts_with("CREATE TABLE ") {
115        let after_table = if upper.starts_with("CREATE TABLE IF NOT EXISTS ") {
116            &sql[27..]
117        } else {
118            &sql[13..]
119        };
120        return extract_first_identifier(after_table.trim_start());
121    }
122
123    None
124}
125
/// Returns the leading identifier of `s`: everything before the first
/// whitespace character, `(`, `,` or `;`.
///
/// Returns `None` when `s` is empty or begins with one of those delimiters.
fn extract_first_identifier(s: &str) -> Option<&str> {
    let is_delimiter = |c: char| matches!(c, '(' | ',' | ';') || c.is_whitespace();
    let identifier = s.split(is_delimiter).next().unwrap_or_default();
    (!identifier.is_empty()).then_some(identifier)
}
133
134/// Execute a database operation with tracing instrumentation.
135/// This creates a span and records the duration as a metric.
136pub async fn instrumented_query<F, T, E>(operation: &str, table: Option<&str>, f: F) -> Result<T, E>
137where
138    F: std::future::Future<Output = Result<T, E>>,
139{
140    let span = if let Some(tbl) = table {
141        info_span!(
142            "db.query",
143            db.system = DB_SYSTEM_POSTGRESQL,
144            db.operation.name = operation,
145            db.collection.name = tbl,
146        )
147    } else {
148        info_span!(
149            "db.query",
150            db.system = DB_SYSTEM_POSTGRESQL,
151            db.operation.name = operation,
152        )
153    };
154
155    let start = Instant::now();
156    let result = f.instrument(span).await;
157    let elapsed = start.elapsed();
158    record_query_duration(operation, elapsed);
159
160    if elapsed > SLOW_QUERY_THRESHOLD {
161        tracing::warn!(
162            db.operation.name = operation,
163            db.collection.name = table,
164            duration_ms = elapsed.as_millis() as u64,
165            "Slow query detected"
166        );
167    }
168
169    result
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that each `(sql, expected)` pair yields the expected table name.
    fn check(cases: &[(&str, Option<&str>)]) {
        for &(sql, expected) in cases {
            assert_eq!(extract_table_name(sql), expected);
        }
    }

    #[test]
    fn test_extract_table_select() {
        check(&[
            ("SELECT * FROM users WHERE id = 1", Some("users")),
            ("SELECT id, name FROM accounts", Some("accounts")),
            ("select * from Orders", Some("Orders")),
        ]);
    }

    #[test]
    fn test_extract_table_insert() {
        check(&[(
            "INSERT INTO users (id, name) VALUES (1, 'test')",
            Some("users"),
        )]);
    }

    #[test]
    fn test_extract_table_update() {
        check(&[(
            "UPDATE users SET name = 'test' WHERE id = 1",
            Some("users"),
        )]);
    }

    #[test]
    fn test_extract_table_delete() {
        check(&[("DELETE FROM users WHERE id = 1", Some("users"))]);
    }

    #[test]
    fn test_extract_table_create() {
        check(&[
            ("CREATE TABLE users (id UUID PRIMARY KEY)", Some("users")),
            ("CREATE TABLE IF NOT EXISTS accounts (id INT)", Some("accounts")),
        ]);
    }

    #[test]
    fn test_extract_table_complex_query() {
        // A join reports the first table that follows FROM.
        check(&[(
            "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id",
            Some("users"),
        )]);
    }
}