Skip to main content

earl_protocol_sql/
sandbox.rs

1use std::sync::Once;
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use serde_json::{Map, Value};
6use sqlx::{Column, Row, TypeInfo};
7
8static INSTALL_DRIVERS: Once = Once::new();
9
10/// Execute a SQL query against the given connection URL and return the result rows as JSON objects.
11pub async fn execute_query(
12    connection_url: &str,
13    query: &str,
14    params: &[Value],
15    read_only: bool,
16    max_rows: usize,
17    timeout: Duration,
18) -> Result<Vec<Map<String, Value>>> {
19    INSTALL_DRIVERS.call_once(|| {
20        sqlx::any::install_default_drivers();
21    });
22
23    // For SQLite, enforce read-only mode via the connection URL before connecting.
24    let connection_url: std::borrow::Cow<'_, str> = if read_only
25        && connection_url.to_ascii_lowercase().starts_with("sqlite")
26        && !connection_url.to_ascii_lowercase().contains("mode=ro")
27    {
28        let separator = if connection_url.contains('?') {
29            "&"
30        } else {
31            "?"
32        };
33        std::borrow::Cow::Owned(format!("{connection_url}{separator}mode=ro"))
34    } else {
35        std::borrow::Cow::Borrowed(connection_url)
36    };
37
38    let pool = sqlx::any::AnyPoolOptions::new()
39        .max_connections(1)
40        .acquire_timeout(timeout)
41        .connect(&connection_url)
42        .await
43        .context("failed connecting to SQL database")?;
44
45    // For PostgreSQL and MySQL, enforce read-only mode via explicit read-only
46    // transactions instead of SET commands (which can be overridden by users).
47    // SQLite is already handled via mode=ro connection parameter above.
48    let url_lower = connection_url.to_ascii_lowercase();
49    let use_read_only_transaction =
50        read_only && (url_lower.starts_with("postgres") || url_lower.starts_with("mysql"));
51
52    if use_read_only_transaction {
53        let begin_stmt = if url_lower.starts_with("postgres") {
54            "BEGIN READ ONLY"
55        } else {
56            "START TRANSACTION READ ONLY"
57        };
58        sqlx::query(begin_stmt)
59            .execute(&pool)
60            .await
61            .context("failed starting read-only transaction")?;
62    }
63
64    // codeql[rust/cleartext-storage-database] - False positive: the connection URL (which may
65    // contain credentials) is used to *connect* to the database, not to store data in it.
66    let mut sqlx_query = sqlx::query(query);
67    for param in params {
68        sqlx_query = bind_json_param(sqlx_query, param);
69    }
70
71    let query_result = tokio::time::timeout(timeout, sqlx_query.fetch_all(&pool))
72        .await
73        .map_err(|_| anyhow::anyhow!("SQL query timed out after {timeout:?}"));
74
75    // Always rollback the read-only transaction, even on error.
76    if use_read_only_transaction {
77        let _ = sqlx::query("ROLLBACK").execute(&pool).await;
78    }
79
80    let rows = query_result?.context("SQL query execution failed")?;
81
82    let mut results = Vec::with_capacity(rows.len().min(max_rows));
83    for row in rows.iter().take(max_rows) {
84        results.push(row_to_json(row)?);
85    }
86
87    pool.close().await;
88
89    Ok(results)
90}
91
92fn bind_json_param<'q>(
93    query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
94    value: &'q Value,
95) -> sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>> {
96    match value {
97        Value::Null => query.bind(None::<String>),
98        Value::Bool(b) => query.bind(*b),
99        Value::Number(n) => {
100            if let Some(i) = n.as_i64() {
101                query.bind(i)
102            } else if let Some(f) = n.as_f64() {
103                query.bind(f)
104            } else {
105                query.bind(n.to_string())
106            }
107        }
108        Value::String(s) => query.bind(s.as_str()),
109        _ => query.bind(serde_json::to_string(value).unwrap_or_default()),
110    }
111}
112
113fn row_to_json(row: &sqlx::any::AnyRow) -> Result<Map<String, Value>> {
114    let mut map = Map::new();
115    for col in row.columns() {
116        let name = col.name().to_string();
117        let type_name = col.type_info().name();
118        let value = match type_name {
119            "INTEGER" | "INT" | "INT4" | "INT8" | "BIGINT" => row
120                .try_get::<i64, _>(col.ordinal())
121                .map(|v| Value::Number(v.into()))
122                .unwrap_or(Value::Null),
123            "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "NUMERIC" => row
124                .try_get::<f64, _>(col.ordinal())
125                .ok()
126                .and_then(|v| serde_json::Number::from_f64(v).map(Value::Number))
127                .unwrap_or(Value::Null),
128            "BOOLEAN" | "BOOL" => row
129                .try_get::<bool, _>(col.ordinal())
130                .map(Value::Bool)
131                .unwrap_or(Value::Null),
132            _ => {
133                // For unknown types (e.g. SQLite "NULL" for literal expressions),
134                // try decoding as each type in order: i64 -> f64 -> bool -> String.
135                let ordinal = col.ordinal();
136                if let Ok(v) = row.try_get::<i64, _>(ordinal) {
137                    Value::Number(v.into())
138                } else if let Ok(v) = row.try_get::<f64, _>(ordinal) {
139                    serde_json::Number::from_f64(v)
140                        .map(Value::Number)
141                        .unwrap_or(Value::Null)
142                } else if let Ok(v) = row.try_get::<bool, _>(ordinal) {
143                    Value::Bool(v)
144                } else if let Ok(v) = row.try_get::<String, _>(ordinal) {
145                    Value::String(v)
146                } else {
147                    Value::Null
148                }
149            }
150        };
151        map.insert(name, value);
152    }
153    Ok(map)
154}