// forge_runtime/observability/db.rs

1use opentelemetry::global;
2use opentelemetry::metrics::{Gauge, Histogram};
3use sqlx::PgPool;
4use std::sync::OnceLock;
5use std::time::{Duration, Instant};
6use tracing::{Instrument, info_span};
7
/// OpenTelemetry attribute key identifying the database system.
const DB_SYSTEM: &str = "db.system";
/// OpenTelemetry attribute key identifying the database operation (e.g. a query kind).
const DB_OPERATION_NAME: &str = "db.operation.name";
/// Value attached under `DB_SYSTEM`; every metric and span in this module reports PostgreSQL.
const DB_SYSTEM_POSTGRESQL: &str = "postgresql";

/// Process-wide instruments, created on first use by `get_metrics`.
static DB_METRICS: OnceLock<DbMetrics> = OnceLock::new();

/// OpenTelemetry instruments used for database client telemetry.
struct DbMetrics {
    /// Operation durations in seconds (`db.client.operation.duration`).
    query_duration: Histogram<f64>,
    /// Connections in use, computed as pool size minus idle (`db.client.connection.count`).
    pool_connections_active: Gauge<u64>,
    /// Idle connections in the pool (`db.client.connection.idle_count`).
    pool_connections_idle: Gauge<u64>,
    /// Configured maximum pool size (`db.client.connection.max`).
    pool_connections_max: Gauge<u64>,
}
20
21fn get_metrics() -> &'static DbMetrics {
22    DB_METRICS.get_or_init(|| {
23        let meter = global::meter("forge.db");
24
25        DbMetrics {
26            query_duration: meter
27                .f64_histogram("db.client.operation.duration")
28                .with_description("Duration of database operations")
29                .with_unit("s")
30                .build(),
31            pool_connections_active: meter
32                .u64_gauge("db.client.connection.count")
33                .with_description("Number of active database connections")
34                .build(),
35            pool_connections_idle: meter
36                .u64_gauge("db.client.connection.idle_count")
37                .with_description("Number of idle database connections")
38                .build(),
39            pool_connections_max: meter
40                .u64_gauge("db.client.connection.max")
41                .with_description("Maximum number of database connections")
42                .build(),
43        }
44    })
45}
46
47/// Record pool connection metrics from a PgPool.
48pub fn record_pool_metrics(pool: &PgPool) {
49    let metrics = get_metrics();
50    let pool_size = pool.size();
51    let idle_count = pool.num_idle();
52    let max_connections = pool.options().get_max_connections();
53
54    metrics.pool_connections_active.record(
55        (pool_size - idle_count as u32) as u64,
56        &[opentelemetry::KeyValue::new(
57            DB_SYSTEM,
58            DB_SYSTEM_POSTGRESQL,
59        )],
60    );
61    metrics.pool_connections_idle.record(
62        idle_count as u64,
63        &[opentelemetry::KeyValue::new(
64            DB_SYSTEM,
65            DB_SYSTEM_POSTGRESQL,
66        )],
67    );
68    metrics.pool_connections_max.record(
69        max_connections as u64,
70        &[opentelemetry::KeyValue::new(
71            DB_SYSTEM,
72            DB_SYSTEM_POSTGRESQL,
73        )],
74    );
75}
76
77/// Record a query execution with its duration.
78pub fn record_query_duration(operation: &str, duration: Duration) {
79    let metrics = get_metrics();
80    metrics.query_duration.record(
81        duration.as_secs_f64(),
82        &[
83            opentelemetry::KeyValue::new(DB_SYSTEM, DB_SYSTEM_POSTGRESQL),
84            opentelemetry::KeyValue::new(DB_OPERATION_NAME, operation.to_string()),
85        ],
86    );
87}
88
/// Extract the table name from a simple SQL statement.
///
/// Keywords are matched ASCII case-insensitively and may be separated by any ASCII
/// whitespace, including newlines. Recognised shapes:
///
/// - `SELECT ... FROM <table> ...`: the first whitespace-delimited `FROM` is used. For
///   joins this yields the leftmost table. A `FROM` inside an expression such as
///   `EXTRACT(YEAR FROM ts)` would be picked up first.
/// - `INSERT INTO <table> ...`
/// - `UPDATE <table> ...`
/// - `DELETE FROM <table> ...`
/// - `CREATE TABLE [IF NOT EXISTS] <table> ...`
///
/// The returned slice borrows from `sql` verbatim. Case is preserved, and schema
/// qualifiers or quotes are not stripped.
///
/// Returns `None` in two cases: for any other statement kind, or when no identifier
/// follows the keyword (e.g. `SELECT * FROM (SELECT ...)`).
pub fn extract_table_name(sql: &str) -> Option<&str> {
    let sql = sql.trim();

    // Matching is done on bytes, so every offset computed here is also a valid offset
    // into `sql`. Uppercasing a copy with Unicode rules would not guarantee this,
    // because some characters change byte length (e.g. 'ı' -> 'I').
    let rest = if starts_with_ignore_ascii_case(sql, "SELECT") {
        &sql[find_keyword_end(sql, "FROM")?..]
    } else {
        strip_keywords(sql, &["INSERT", "INTO"])
            .or_else(|| strip_keywords(sql, &["UPDATE"]))
            .or_else(|| strip_keywords(sql, &["DELETE", "FROM"]))
            // Try the longer form first so `IF` is not mistaken for the table name.
            .or_else(|| strip_keywords(sql, &["CREATE", "TABLE", "IF", "NOT", "EXISTS"]))
            .or_else(|| strip_keywords(sql, &["CREATE", "TABLE"]))?
    };

    extract_first_identifier(rest.trim_start())
}

/// True when `s` begins with `prefix`, compared ASCII case-insensitively.
fn starts_with_ignore_ascii_case(s: &str, prefix: &str) -> bool {
    s.as_bytes()
        .get(..prefix.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(prefix.as_bytes()))
}

/// Match `keywords` in order at the start of `s` and return the text after the last one.
///
/// Each keyword is compared ASCII case-insensitively and must be followed by at least
/// one ASCII whitespace character. The returned text has its leading whitespace trimmed.
fn strip_keywords<'a>(s: &'a str, keywords: &[&str]) -> Option<&'a str> {
    keywords.iter().try_fold(s, |rest, kw| {
        if !starts_with_ignore_ascii_case(rest, kw) {
            return None;
        }
        // The matched bytes equal an ASCII keyword, so `kw.len()` is a char boundary.
        let tail = &rest[kw.len()..];
        tail.starts_with(|c: char| c.is_ascii_whitespace())
            .then_some(tail.trim_start())
    })
}

/// Find the first `keyword` in `s` that has ASCII whitespace on both sides.
///
/// Matching is ASCII case-insensitive. Returns the byte offset just past the keyword,
/// or `None` if there is no such occurrence.
fn find_keyword_end(s: &str, keyword: &str) -> Option<usize> {
    let bytes = s.as_bytes();
    let kw = keyword.as_bytes();
    // Limit `start` so that both the byte before and the byte after the keyword exist.
    (1..bytes.len().saturating_sub(kw.len()))
        .find(|&start| {
            let end = start + kw.len();
            bytes[start - 1].is_ascii_whitespace()
                && bytes[start..end].eq_ignore_ascii_case(kw)
                && bytes[end].is_ascii_whitespace()
        })
        // The returned offset points at an ASCII whitespace byte, so it is a char boundary.
        .map(|start| start + kw.len())
}

/// Return the leading run of `s` up to the first delimiter, or `None` if that run is empty.
///
/// Delimiters are whitespace, `(`, `,` and `;`; the delimiter itself is not included.
fn extract_first_identifier(s: &str) -> Option<&str> {
    let end = s
        .find(|c: char| c.is_whitespace() || matches!(c, '(' | ',' | ';'))
        .unwrap_or(s.len());
    (end > 0).then(|| &s[..end])
}
130
131/// Execute a database operation with tracing instrumentation.
132/// This creates a span and records the duration as a metric.
133pub async fn instrumented_query<F, T, E>(operation: &str, table: Option<&str>, f: F) -> Result<T, E>
134where
135    F: std::future::Future<Output = Result<T, E>>,
136{
137    let span = if let Some(tbl) = table {
138        info_span!(
139            "db.query",
140            db.system = DB_SYSTEM_POSTGRESQL,
141            db.operation.name = operation,
142            db.collection.name = tbl,
143        )
144    } else {
145        info_span!(
146            "db.query",
147            db.system = DB_SYSTEM_POSTGRESQL,
148            db.operation.name = operation,
149        )
150    };
151
152    let start = Instant::now();
153    let result = f.instrument(span).await;
154    record_query_duration(operation, start.elapsed());
155
156    result
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that each `(sql, expected)` pair maps to `expected` via [`extract_table_name`].
    fn assert_tables(cases: &[(&str, Option<&str>)]) {
        for (sql, expected) in cases {
            assert_eq!(extract_table_name(sql), *expected, "table name for {sql:?}");
        }
    }

    #[test]
    fn test_extract_table_select() {
        assert_tables(&[
            ("SELECT * FROM users WHERE id = 1", Some("users")),
            ("SELECT id, name FROM accounts", Some("accounts")),
            ("select * from Orders", Some("Orders")),
        ]);
    }

    #[test]
    fn test_extract_table_insert() {
        assert_tables(&[(
            "INSERT INTO users (id, name) VALUES (1, 'test')",
            Some("users"),
        )]);
    }

    #[test]
    fn test_extract_table_update() {
        assert_tables(&[(
            "UPDATE users SET name = 'test' WHERE id = 1",
            Some("users"),
        )]);
    }

    #[test]
    fn test_extract_table_delete() {
        assert_tables(&[("DELETE FROM users WHERE id = 1", Some("users"))]);
    }

    #[test]
    fn test_extract_table_create() {
        assert_tables(&[
            ("CREATE TABLE users (id UUID PRIMARY KEY)", Some("users")),
            ("CREATE TABLE IF NOT EXISTS accounts (id INT)", Some("accounts")),
        ]);
    }

    #[test]
    fn test_extract_table_complex_query() {
        // With a join, the first FROM target (the leftmost table) is reported.
        assert_tables(&[(
            "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id",
            Some("users"),
        )]);
    }
}