earl_protocol_sql/
sandbox.rs1use std::sync::Once;
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use serde_json::{Map, Value};
6use sqlx::{Column, Row, TypeInfo};
7
8static INSTALL_DRIVERS: Once = Once::new();
9
10pub async fn execute_query(
12 connection_url: &str,
13 query: &str,
14 params: &[Value],
15 read_only: bool,
16 max_rows: usize,
17 timeout: Duration,
18) -> Result<Vec<Map<String, Value>>> {
19 INSTALL_DRIVERS.call_once(|| {
20 sqlx::any::install_default_drivers();
21 });
22
23 let connection_url: std::borrow::Cow<'_, str> = if read_only
25 && connection_url.to_ascii_lowercase().starts_with("sqlite")
26 && !connection_url.to_ascii_lowercase().contains("mode=ro")
27 {
28 let separator = if connection_url.contains('?') {
29 "&"
30 } else {
31 "?"
32 };
33 std::borrow::Cow::Owned(format!("{connection_url}{separator}mode=ro"))
34 } else {
35 std::borrow::Cow::Borrowed(connection_url)
36 };
37
38 let pool = sqlx::any::AnyPoolOptions::new()
39 .max_connections(1)
40 .acquire_timeout(timeout)
41 .connect(&connection_url)
42 .await
43 .context("failed connecting to SQL database")?;
44
45 let url_lower = connection_url.to_ascii_lowercase();
49 let use_read_only_transaction =
50 read_only && (url_lower.starts_with("postgres") || url_lower.starts_with("mysql"));
51
52 if use_read_only_transaction {
53 let begin_stmt = if url_lower.starts_with("postgres") {
54 "BEGIN READ ONLY"
55 } else {
56 "START TRANSACTION READ ONLY"
57 };
58 sqlx::query(begin_stmt)
59 .execute(&pool)
60 .await
61 .context("failed starting read-only transaction")?;
62 }
63
64 let mut sqlx_query = sqlx::query(query);
67 for param in params {
68 sqlx_query = bind_json_param(sqlx_query, param);
69 }
70
71 let query_result = tokio::time::timeout(timeout, sqlx_query.fetch_all(&pool))
72 .await
73 .map_err(|_| anyhow::anyhow!("SQL query timed out after {timeout:?}"));
74
75 if use_read_only_transaction {
77 let _ = sqlx::query("ROLLBACK").execute(&pool).await;
78 }
79
80 let rows = query_result?.context("SQL query execution failed")?;
81
82 let mut results = Vec::with_capacity(rows.len().min(max_rows));
83 for row in rows.iter().take(max_rows) {
84 results.push(row_to_json(row)?);
85 }
86
87 pool.close().await;
88
89 Ok(results)
90}
91
92fn bind_json_param<'q>(
93 query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
94 value: &'q Value,
95) -> sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>> {
96 match value {
97 Value::Null => query.bind(None::<String>),
98 Value::Bool(b) => query.bind(*b),
99 Value::Number(n) => {
100 if let Some(i) = n.as_i64() {
101 query.bind(i)
102 } else if let Some(f) = n.as_f64() {
103 query.bind(f)
104 } else {
105 query.bind(n.to_string())
106 }
107 }
108 Value::String(s) => query.bind(s.as_str()),
109 _ => query.bind(serde_json::to_string(value).unwrap_or_default()),
110 }
111}
112
113fn row_to_json(row: &sqlx::any::AnyRow) -> Result<Map<String, Value>> {
114 let mut map = Map::new();
115 for col in row.columns() {
116 let name = col.name().to_string();
117 let type_name = col.type_info().name();
118 let value = match type_name {
119 "INTEGER" | "INT" | "INT4" | "INT8" | "BIGINT" => row
120 .try_get::<i64, _>(col.ordinal())
121 .map(|v| Value::Number(v.into()))
122 .unwrap_or(Value::Null),
123 "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "NUMERIC" => row
124 .try_get::<f64, _>(col.ordinal())
125 .ok()
126 .and_then(|v| serde_json::Number::from_f64(v).map(Value::Number))
127 .unwrap_or(Value::Null),
128 "BOOLEAN" | "BOOL" => row
129 .try_get::<bool, _>(col.ordinal())
130 .map(Value::Bool)
131 .unwrap_or(Value::Null),
132 _ => {
133 let ordinal = col.ordinal();
136 if let Ok(v) = row.try_get::<i64, _>(ordinal) {
137 Value::Number(v.into())
138 } else if let Ok(v) = row.try_get::<f64, _>(ordinal) {
139 serde_json::Number::from_f64(v)
140 .map(Value::Number)
141 .unwrap_or(Value::Null)
142 } else if let Ok(v) = row.try_get::<bool, _>(ordinal) {
143 Value::Bool(v)
144 } else if let Ok(v) = row.try_get::<String, _>(ordinal) {
145 Value::String(v)
146 } else {
147 Value::Null
148 }
149 }
150 };
151 map.insert(name, value);
152 }
153 Ok(map)
154}