earl_protocol_sql/
sandbox.rs1use std::sync::Once;
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use serde_json::{Map, Value};
6use sqlx::{Column, Row, TypeInfo};
7
// One-time guard used by `execute_query` so that sqlx's default `Any` drivers
// are registered only on the first call, however many queries are executed.
8static INSTALL_DRIVERS: Once = Once::new();
9
10pub async fn execute_query(
12 connection_url: &str,
13 query: &str,
14 params: &[Value],
15 read_only: bool,
16 max_rows: usize,
17 timeout: Duration,
18) -> Result<Vec<Map<String, Value>>> {
19 INSTALL_DRIVERS.call_once(|| {
20 sqlx::any::install_default_drivers();
21 });
22
23 let connection_url: std::borrow::Cow<'_, str> = if read_only
25 && connection_url.to_ascii_lowercase().starts_with("sqlite")
26 && !connection_url.to_ascii_lowercase().contains("mode=ro")
27 {
28 let separator = if connection_url.contains('?') {
29 "&"
30 } else {
31 "?"
32 };
33 std::borrow::Cow::Owned(format!("{connection_url}{separator}mode=ro"))
34 } else {
35 std::borrow::Cow::Borrowed(connection_url)
36 };
37
38 let pool = sqlx::any::AnyPoolOptions::new()
39 .max_connections(1)
40 .acquire_timeout(timeout)
41 .connect(&connection_url)
42 .await
43 .context("failed connecting to SQL database")?;
44
45 let url_lower = connection_url.to_ascii_lowercase();
49 let use_read_only_transaction =
50 read_only && (url_lower.starts_with("postgres") || url_lower.starts_with("mysql"));
51
52 if use_read_only_transaction {
53 let begin_stmt = if url_lower.starts_with("postgres") {
54 "BEGIN READ ONLY"
55 } else {
56 "START TRANSACTION READ ONLY"
57 };
58 sqlx::query(begin_stmt)
59 .execute(&pool)
60 .await
61 .context("failed starting read-only transaction")?;
62 }
63
64 let mut sqlx_query = sqlx::query(query);
65 for param in params {
66 sqlx_query = bind_json_param(sqlx_query, param);
67 }
68
69 let query_result = tokio::time::timeout(timeout, sqlx_query.fetch_all(&pool))
70 .await
71 .map_err(|_| anyhow::anyhow!("SQL query timed out after {timeout:?}"));
72
73 if use_read_only_transaction {
75 let _ = sqlx::query("ROLLBACK").execute(&pool).await;
76 }
77
78 let rows = query_result?.context("SQL query execution failed")?;
79
80 let mut results = Vec::with_capacity(rows.len().min(max_rows));
81 for row in rows.iter().take(max_rows) {
82 results.push(row_to_json(row)?);
83 }
84
85 pool.close().await;
86
87 Ok(results)
88}
89
90fn bind_json_param<'q>(
91 query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
92 value: &'q Value,
93) -> sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>> {
94 match value {
95 Value::Null => query.bind(None::<String>),
96 Value::Bool(b) => query.bind(*b),
97 Value::Number(n) => {
98 if let Some(i) = n.as_i64() {
99 query.bind(i)
100 } else if let Some(f) = n.as_f64() {
101 query.bind(f)
102 } else {
103 query.bind(n.to_string())
104 }
105 }
106 Value::String(s) => query.bind(s.as_str()),
107 _ => query.bind(serde_json::to_string(value).unwrap_or_default()),
108 }
109}
110
111fn row_to_json(row: &sqlx::any::AnyRow) -> Result<Map<String, Value>> {
112 let mut map = Map::new();
113 for col in row.columns() {
114 let name = col.name().to_string();
115 let type_name = col.type_info().name();
116 let value = match type_name {
117 "INTEGER" | "INT" | "INT4" | "INT8" | "BIGINT" => row
118 .try_get::<i64, _>(col.ordinal())
119 .map(|v| Value::Number(v.into()))
120 .unwrap_or(Value::Null),
121 "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "NUMERIC" => row
122 .try_get::<f64, _>(col.ordinal())
123 .ok()
124 .and_then(|v| serde_json::Number::from_f64(v).map(Value::Number))
125 .unwrap_or(Value::Null),
126 "BOOLEAN" | "BOOL" => row
127 .try_get::<bool, _>(col.ordinal())
128 .map(Value::Bool)
129 .unwrap_or(Value::Null),
130 _ => {
131 let ordinal = col.ordinal();
134 if let Ok(v) = row.try_get::<i64, _>(ordinal) {
135 Value::Number(v.into())
136 } else if let Ok(v) = row.try_get::<f64, _>(ordinal) {
137 serde_json::Number::from_f64(v)
138 .map(Value::Number)
139 .unwrap_or(Value::Null)
140 } else if let Ok(v) = row.try_get::<bool, _>(ordinal) {
141 Value::Bool(v)
142 } else if let Ok(v) = row.try_get::<String, _>(ordinal) {
143 Value::String(v)
144 } else {
145 Value::Null
146 }
147 }
148 };
149 map.insert(name, value);
150 }
151 Ok(map)
152}