// earl_protocol_sql/sandbox.rs
1use std::sync::Once;
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use serde_json::{Map, Value};
6use sqlx::{Column, Row, TypeInfo};
7
8static INSTALL_DRIVERS: Once = Once::new();
9
10/// Execute a SQL query against the given connection URL and return the result rows as JSON objects.
11pub async fn execute_query(
12    connection_url: &str,
13    query: &str,
14    params: &[Value],
15    read_only: bool,
16    max_rows: usize,
17    timeout: Duration,
18) -> Result<Vec<Map<String, Value>>> {
19    INSTALL_DRIVERS.call_once(|| {
20        sqlx::any::install_default_drivers();
21    });
22
23    // For SQLite, enforce read-only mode via the connection URL before connecting.
24    let connection_url: std::borrow::Cow<'_, str> = if read_only
25        && connection_url.to_ascii_lowercase().starts_with("sqlite")
26        && !connection_url.to_ascii_lowercase().contains("mode=ro")
27    {
28        let separator = if connection_url.contains('?') {
29            "&"
30        } else {
31            "?"
32        };
33        std::borrow::Cow::Owned(format!("{connection_url}{separator}mode=ro"))
34    } else {
35        std::borrow::Cow::Borrowed(connection_url)
36    };
37
38    let pool = sqlx::any::AnyPoolOptions::new()
39        .max_connections(1)
40        .acquire_timeout(timeout)
41        .connect(&connection_url)
42        .await
43        .context("failed connecting to SQL database")?;
44
45    // For PostgreSQL and MySQL, enforce read-only mode via explicit read-only
46    // transactions instead of SET commands (which can be overridden by users).
47    // SQLite is already handled via mode=ro connection parameter above.
48    let url_lower = connection_url.to_ascii_lowercase();
49    let use_read_only_transaction =
50        read_only && (url_lower.starts_with("postgres") || url_lower.starts_with("mysql"));
51
52    if use_read_only_transaction {
53        let begin_stmt = if url_lower.starts_with("postgres") {
54            "BEGIN READ ONLY"
55        } else {
56            "START TRANSACTION READ ONLY"
57        };
58        sqlx::query(begin_stmt)
59            .execute(&pool)
60            .await
61            .context("failed starting read-only transaction")?;
62    }
63
64    let mut sqlx_query = sqlx::query(query);
65    for param in params {
66        sqlx_query = bind_json_param(sqlx_query, param);
67    }
68
69    let query_result = tokio::time::timeout(timeout, sqlx_query.fetch_all(&pool))
70        .await
71        .map_err(|_| anyhow::anyhow!("SQL query timed out after {timeout:?}"));
72
73    // Always rollback the read-only transaction, even on error.
74    if use_read_only_transaction {
75        let _ = sqlx::query("ROLLBACK").execute(&pool).await;
76    }
77
78    let rows = query_result?.context("SQL query execution failed")?;
79
80    let mut results = Vec::with_capacity(rows.len().min(max_rows));
81    for row in rows.iter().take(max_rows) {
82        results.push(row_to_json(row)?);
83    }
84
85    pool.close().await;
86
87    Ok(results)
88}
89
90fn bind_json_param<'q>(
91    query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
92    value: &'q Value,
93) -> sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>> {
94    match value {
95        Value::Null => query.bind(None::<String>),
96        Value::Bool(b) => query.bind(*b),
97        Value::Number(n) => {
98            if let Some(i) = n.as_i64() {
99                query.bind(i)
100            } else if let Some(f) = n.as_f64() {
101                query.bind(f)
102            } else {
103                query.bind(n.to_string())
104            }
105        }
106        Value::String(s) => query.bind(s.as_str()),
107        _ => query.bind(serde_json::to_string(value).unwrap_or_default()),
108    }
109}
110
111fn row_to_json(row: &sqlx::any::AnyRow) -> Result<Map<String, Value>> {
112    let mut map = Map::new();
113    for col in row.columns() {
114        let name = col.name().to_string();
115        let type_name = col.type_info().name();
116        let value = match type_name {
117            "INTEGER" | "INT" | "INT4" | "INT8" | "BIGINT" => row
118                .try_get::<i64, _>(col.ordinal())
119                .map(|v| Value::Number(v.into()))
120                .unwrap_or(Value::Null),
121            "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "NUMERIC" => row
122                .try_get::<f64, _>(col.ordinal())
123                .ok()
124                .and_then(|v| serde_json::Number::from_f64(v).map(Value::Number))
125                .unwrap_or(Value::Null),
126            "BOOLEAN" | "BOOL" => row
127                .try_get::<bool, _>(col.ordinal())
128                .map(Value::Bool)
129                .unwrap_or(Value::Null),
130            _ => {
131                // For unknown types (e.g. SQLite "NULL" for literal expressions),
132                // try decoding as each type in order: i64 -> f64 -> bool -> String.
133                let ordinal = col.ordinal();
134                if let Ok(v) = row.try_get::<i64, _>(ordinal) {
135                    Value::Number(v.into())
136                } else if let Ok(v) = row.try_get::<f64, _>(ordinal) {
137                    serde_json::Number::from_f64(v)
138                        .map(Value::Number)
139                        .unwrap_or(Value::Null)
140                } else if let Ok(v) = row.try_get::<bool, _>(ordinal) {
141                    Value::Bool(v)
142                } else if let Ok(v) = row.try_get::<String, _>(ordinal) {
143                    Value::String(v)
144                } else {
145                    Value::Null
146                }
147            }
148        };
149        map.insert(name, value);
150    }
151    Ok(map)
152}