Skip to main content

grapheme_stdlib/
sql.rs

1use serde_json::{json, Map, Value as JsonValue};
2use sqlx::{AnyPool, Column, Row};
3use std::env;
4use std::sync::Once;
5use std::time::Instant;
6
/// One-shot guard around `sqlx::any::install_default_drivers`.
///
/// Every entry point in this file asks for driver registration before calling
/// `AnyPool::connect`; this `Once` ensures the registration itself runs a single time
/// per process.
static SQLX_ANY_DRIVERS: Once = Once::new();
8
9pub fn query(args: &JsonValue) -> JsonValue {
10    let connection = match required_string(args, "connection") {
11        Ok(v) => v,
12        Err(e) => return error_payload("validation_error", "missing_connection", &e),
13    };
14    let sql = match required_string(args, "sql") {
15        Ok(v) => v,
16        Err(e) => return error_payload("validation_error", "missing_sql", &e),
17    };
18
19    let params = match optional_params(args) {
20        Ok(v) => v,
21        Err(e) => return error_payload("validation_error", "sql_params_invalid", &e),
22    };
23    let max_rows = optional_u64(args, "max_rows");
24    let max_payload_bytes = optional_u64(args, "max_payload_bytes");
25
26    let resolved = match resolve_connection(&connection) {
27        Ok(v) => v,
28        Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
29    };
30
31    let started = Instant::now();
32    let rt = match tokio::runtime::Builder::new_current_thread()
33        .enable_all()
34        .build()
35    {
36        Ok(rt) => rt,
37        Err(e) => {
38            return error_payload(
39                "runtime_error",
40                "tokio_runtime_init_failed",
41                &e.to_string(),
42            )
43        }
44    };
45
46    let result: Result<JsonValue, String> = rt.block_on(async {
47        ensure_any_drivers();
48        let pool = AnyPool::connect(&resolved)
49            .await
50            .map_err(|e| format!("connect failed: {e}"))?;
51
52        let query = match bind_params(sqlx::query(&sql), &params) {
53            Ok(query) => query,
54            Err(e) => return Err(e),
55        };
56
57        let rows = query
58            .fetch_all(&pool)
59            .await
60            .map_err(|e| format!("query failed: {e}"))?;
61
62        if let Some(max_rows) = max_rows {
63            if rows.len() > max_rows as usize {
64                return Err(format!(
65                    "row count {} exceeds max_rows {}",
66                    rows.len(),
67                    max_rows
68                ));
69            }
70        }
71
72        let mut out_rows = Vec::with_capacity(rows.len());
73        for row in rows {
74            let mut obj = Map::new();
75            for (idx, col) in row.columns().iter().enumerate() {
76                let key = col.name().to_string();
77                let value = decode_row_value(&row, idx);
78                obj.insert(key, value);
79            }
80            out_rows.push(JsonValue::Object(obj));
81        }
82
83        if let Some(max_payload_bytes) = max_payload_bytes {
84            let payload_bytes = serde_json::to_vec(&out_rows)
85                .map(|v| v.len())
86                .map_err(|e| format!("serialize query rows failed: {e}"))?;
87            if payload_bytes > max_payload_bytes as usize {
88                return Err(format!(
89                    "payload size {} exceeds max_payload_bytes {}",
90                    payload_bytes,
91                    max_payload_bytes
92                ));
93            }
94        }
95
96        Ok(json!({
97            "ok": true,
98            "connection": connection,
99            "row_count": out_rows.len(),
100            "rows": out_rows,
101            "elapsed_ms": started.elapsed().as_millis() as u64,
102        }))
103    });
104
105    match result {
106        Ok(v) => v,
107        Err(e) => error_payload("query_error", "sql_query_failed", &e),
108    }
109}
110
111pub fn execute(args: &JsonValue) -> JsonValue {
112    let connection = match required_string(args, "connection") {
113        Ok(v) => v,
114        Err(e) => return error_payload("validation_error", "missing_connection", &e),
115    };
116    let sql = match required_string(args, "sql") {
117        Ok(v) => v,
118        Err(e) => return error_payload("validation_error", "missing_sql", &e),
119    };
120
121    let params = match optional_params(args) {
122        Ok(v) => v,
123        Err(e) => return error_payload("validation_error", "sql_params_invalid", &e),
124    };
125
126    let resolved = match resolve_connection(&connection) {
127        Ok(v) => v,
128        Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
129    };
130
131    let started = Instant::now();
132    let rt = match tokio::runtime::Builder::new_current_thread()
133        .enable_all()
134        .build()
135    {
136        Ok(rt) => rt,
137        Err(e) => {
138            return error_payload(
139                "runtime_error",
140                "tokio_runtime_init_failed",
141                &e.to_string(),
142            )
143        }
144    };
145
146    let result: Result<JsonValue, String> = rt.block_on(async {
147        ensure_any_drivers();
148        let pool = AnyPool::connect(&resolved)
149            .await
150            .map_err(|e| format!("connect failed: {e}"))?;
151
152        let query = match bind_params(sqlx::query(&sql), &params) {
153            Ok(query) => query,
154            Err(e) => return Err(e),
155        };
156
157        let outcome = query
158            .execute(&pool)
159            .await
160            .map_err(|e| format!("execute failed: {e}"))?;
161
162        Ok(json!({
163            "ok": true,
164            "connection": connection,
165            "rows_affected": outcome.rows_affected(),
166            "elapsed_ms": started.elapsed().as_millis() as u64,
167        }))
168    });
169
170    match result {
171        Ok(v) => v,
172        Err(e) => error_payload("query_error", "sql_execute_failed", &e),
173    }
174}
175
176pub fn health(args: &JsonValue) -> JsonValue {
177    let connection = match required_string(args, "connection") {
178        Ok(v) => v,
179        Err(e) => return error_payload("validation_error", "missing_connection", &e),
180    };
181
182    let resolved = match resolve_connection(&connection) {
183        Ok(v) => v,
184        Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
185    };
186
187    let started = Instant::now();
188    let rt = match tokio::runtime::Builder::new_current_thread()
189        .enable_all()
190        .build()
191    {
192        Ok(rt) => rt,
193        Err(e) => {
194            return error_payload(
195                "runtime_error",
196                "tokio_runtime_init_failed",
197                &e.to_string(),
198            )
199        }
200    };
201
202    let result: Result<JsonValue, String> = rt.block_on(async {
203        ensure_any_drivers();
204        let pool = AnyPool::connect(&resolved)
205            .await
206            .map_err(|e| format!("connect failed: {e}"))?;
207
208        sqlx::query("select 1")
209            .execute(&pool)
210            .await
211            .map_err(|e| format!("health check failed: {e}"))?;
212
213        Ok(json!({
214            "ok": true,
215            "connection": connection,
216            "latency_ms": started.elapsed().as_millis() as u64,
217        }))
218    });
219
220    match result {
221        Ok(v) => v,
222        Err(e) => error_payload("connection_error", "sql_health_failed", &e),
223    }
224}
225
226pub fn transaction(args: &JsonValue) -> JsonValue {
227    let connection = match required_string(args, "connection") {
228        Ok(v) => v,
229        Err(e) => return error_payload("validation_error", "missing_connection", &e),
230    };
231
232    let steps = match parse_transaction_steps(args) {
233        Ok(v) => v,
234        Err(e) => return error_payload("validation_error", "sql_transaction_invalid", &e),
235    };
236
237    let resolved = match resolve_connection(&connection) {
238        Ok(v) => v,
239        Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
240    };
241
242    let started = Instant::now();
243    let rt = match tokio::runtime::Builder::new_current_thread()
244        .enable_all()
245        .build()
246    {
247        Ok(rt) => rt,
248        Err(e) => {
249            return error_payload(
250                "runtime_error",
251                "tokio_runtime_init_failed",
252                &e.to_string(),
253            )
254        }
255    };
256
257    let result: Result<JsonValue, JsonValue> = rt.block_on(async {
258        ensure_any_drivers();
259        let pool = AnyPool::connect(&resolved)
260            .await
261            .map_err(|e| error_payload("connection_error", "sql_connect_failed", &format!("connect failed: {e}")))?;
262
263        let mut tx = pool
264            .begin()
265            .await
266            .map_err(|e| error_payload("query_error", "sql_transaction_begin_failed", &format!("begin failed: {e}")))?;
267
268        let mut results = Vec::with_capacity(steps.len());
269
270        for (idx, step) in steps.iter().enumerate() {
271            let query = bind_params(sqlx::query(&step.sql), &step.params)
272                .map_err(|e| transaction_failure_payload(&connection, &results, idx, started, "sql_params_invalid", &e))?;
273
274            if step.mode == TransactionStepMode::Query {
275                let rows = query
276                    .fetch_all(&mut *tx)
277                    .await
278                    .map_err(|e| transaction_failure_payload(&connection, &results, idx, started, "sql_transaction_step_failed", &format!("query step failed: {e}")))?;
279
280                let mut out_rows = Vec::with_capacity(rows.len());
281                for row in rows {
282                    let mut obj = Map::new();
283                    for (col_idx, col) in row.columns().iter().enumerate() {
284                        obj.insert(col.name().to_string(), decode_row_value(&row, col_idx));
285                    }
286                    out_rows.push(JsonValue::Object(obj));
287                }
288
289                results.push(json!({
290                    "mode": "query",
291                    "row_count": out_rows.len(),
292                    "rows": out_rows,
293                }));
294            } else {
295                let outcome = query
296                    .execute(&mut *tx)
297                    .await
298                    .map_err(|e| transaction_failure_payload(&connection, &results, idx, started, "sql_transaction_step_failed", &format!("execute step failed: {e}")))?;
299
300                results.push(json!({
301                    "mode": "execute",
302                    "rows_affected": outcome.rows_affected(),
303                }));
304            }
305        }
306
307        tx.commit()
308            .await
309            .map_err(|e| error_payload("query_error", "sql_transaction_commit_failed", &format!("commit failed: {e}")))?;
310
311        Ok(json!({
312            "ok": true,
313            "connection": connection,
314            "committed": true,
315            "results": results,
316            "elapsed_ms": started.elapsed().as_millis() as u64,
317        }))
318    });
319
320    match result {
321        Ok(v) => v,
322        Err(v) => v,
323    }
324}
325
326fn required_string(args: &JsonValue, key: &str) -> Result<String, String> {
327    args.get(key)
328        .and_then(|v| v.as_str())
329        .map(ToOwned::to_owned)
330        .or_else(|| {
331            args.get("__input")
332                .and_then(|v| v.as_object())
333                .and_then(|obj| obj.get(key))
334                .and_then(|v| v.as_str())
335                .map(ToOwned::to_owned)
336        })
337        .ok_or_else(|| format!("missing required '{}'", key))
338}
339
340fn optional_params(args: &JsonValue) -> Result<Vec<JsonValue>, String> {
341    let candidate = args
342        .get("params")
343        .cloned()
344        .or_else(|| {
345            args.get("__input")
346                .and_then(|v| v.as_object())
347                .and_then(|obj| obj.get("params").cloned())
348        });
349
350    match candidate {
351        None => Ok(Vec::new()),
352        Some(JsonValue::Array(items)) => Ok(items),
353        Some(_) => Err("params must be an array".to_string()),
354    }
355}
356
357fn optional_u64(args: &JsonValue, key: &str) -> Option<u64> {
358    args.get(key)
359        .and_then(|v| v.as_u64().or_else(|| v.as_str().and_then(|s| s.parse::<u64>().ok())))
360        .or_else(|| {
361            args.get("__input")
362                .and_then(|v| v.as_object())
363                .and_then(|obj| obj.get(key))
364                .and_then(|v| {
365                    v.as_u64()
366                        .or_else(|| v.as_str().and_then(|s| s.parse::<u64>().ok()))
367                })
368        })
369}
370
/// How a single transaction step is run and reported.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TransactionStepMode {
    /// Fetch every row and report them (`"mode": "query"`).
    Query,
    /// Run for side effects and report `rows_affected` (`"mode": "execute"`; also used when a
    /// step gives no string `mode`).
    Execute,
}
376
377#[derive(Debug, Clone)]
378struct TransactionStep {
379    mode: TransactionStepMode,
380    sql: String,
381    params: Vec<JsonValue>,
382}
383
384fn parse_transaction_steps(args: &JsonValue) -> Result<Vec<TransactionStep>, String> {
385    let raw_steps = args
386        .get("steps")
387        .cloned()
388        .or_else(|| {
389            args.get("__input")
390                .and_then(|v| v.as_object())
391                .and_then(|obj| obj.get("steps").cloned())
392        })
393        .ok_or_else(|| "missing required 'steps'".to_string())?;
394
395    let list = raw_steps
396        .as_array()
397        .ok_or_else(|| "steps must be an array".to_string())?;
398
399    if list.is_empty() {
400        return Err("steps must contain at least one entry".to_string());
401    }
402
403    let mut out = Vec::with_capacity(list.len());
404    for (idx, value) in list.iter().enumerate() {
405        let obj = value
406            .as_object()
407            .ok_or_else(|| format!("step[{idx}] must be an object"))?;
408
409        let sql = obj
410            .get("sql")
411            .and_then(|v| v.as_str())
412            .map(ToOwned::to_owned)
413            .ok_or_else(|| format!("step[{idx}] missing required 'sql'"))?;
414
415        let mode = match obj.get("mode").and_then(|v| v.as_str()) {
416            Some("query") => TransactionStepMode::Query,
417            Some("execute") | None => TransactionStepMode::Execute,
418            Some(other) => {
419                return Err(format!(
420                    "step[{idx}] has invalid mode '{}', expected query|execute",
421                    other
422                ))
423            }
424        };
425
426        let params = match obj.get("params") {
427            None => Vec::new(),
428            Some(JsonValue::Array(items)) => items.clone(),
429            Some(_) => return Err(format!("step[{idx}] params must be an array")),
430        };
431
432        out.push(TransactionStep { mode, sql, params });
433    }
434
435    Ok(out)
436}
437
438fn bind_params<'q>(
439    mut query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
440    params: &[JsonValue],
441) -> Result<sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>, String> {
442    for value in params {
443        query = match value {
444            JsonValue::Null => query.bind(Option::<String>::None),
445            JsonValue::Bool(v) => query.bind(*v),
446            JsonValue::Number(n) => {
447                if let Some(v) = n.as_i64() {
448                    query.bind(v)
449                } else if let Some(v) = n.as_u64() {
450                    if let Ok(as_i64) = i64::try_from(v) {
451                        query.bind(as_i64)
452                    } else {
453                        query.bind(v as f64)
454                    }
455                } else if let Some(v) = n.as_f64() {
456                    query.bind(v)
457                } else {
458                    return Err("unsupported numeric param representation".to_string());
459                }
460            }
461            JsonValue::String(v) => query.bind(v.clone()),
462            JsonValue::Array(_) | JsonValue::Object(_) => {
463                return Err("only scalar params are supported (null/bool/number/string)".to_string())
464            }
465        };
466    }
467
468    Ok(query)
469}
470
471fn resolve_connection(connection: &str) -> Result<String, String> {
472    if connection.contains("://") || connection.starts_with("sqlite:") {
473        return Ok(connection.to_string());
474    }
475
476    let env_key = format!(
477        "GRAPHEME_SQL_CONNECTION_{}",
478        connection
479            .chars()
480            .map(|c| if c.is_ascii_alphanumeric() { c.to_ascii_uppercase() } else { '_' })
481            .collect::<String>()
482    );
483
484    if let Ok(url) = env::var(&env_key) {
485        if !url.trim().is_empty() {
486            return Ok(url);
487        }
488    }
489
490    if let Ok(map_raw) = env::var("GRAPHEME_SQL_CONNECTIONS") {
491        if let Ok(map_json) = serde_json::from_str::<JsonValue>(&map_raw) {
492            if let Some(url) = map_json
493                .get(connection)
494                .and_then(|v| v.as_str())
495                .map(ToOwned::to_owned)
496            {
497                return Ok(url);
498            }
499        }
500    }
501
502    Err(format!(
503        "connection '{}' is unresolved; set {} or GRAPHEME_SQL_CONNECTIONS",
504        connection, env_key
505    ))
506}
507
508fn decode_row_value(row: &sqlx::any::AnyRow, idx: usize) -> JsonValue {
509    if let Ok(v) = row.try_get::<Option<i64>, _>(idx) {
510        return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
511    }
512    if let Ok(v) = row.try_get::<Option<f64>, _>(idx) {
513        return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
514    }
515    if let Ok(v) = row.try_get::<Option<bool>, _>(idx) {
516        return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
517    }
518    if let Ok(v) = row.try_get::<Option<String>, _>(idx) {
519        return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
520    }
521
522    JsonValue::Null
523}
524
525fn error_payload(kind: &str, code: &str, message: &str) -> JsonValue {
526    json!({
527        "ok": false,
528        "error": {
529            "kind": kind,
530            "code": code,
531            "message": message,
532            "retryable": false
533        }
534    })
535}
536
537fn transaction_failure_payload(
538    connection: &str,
539    results: &[JsonValue],
540    failed_step: usize,
541    started: Instant,
542    code: &str,
543    message: &str,
544) -> JsonValue {
545    json!({
546        "ok": false,
547        "connection": connection,
548        "committed": false,
549        "failed_step": failed_step,
550        "results": results,
551        "elapsed_ms": started.elapsed().as_millis() as u64,
552        "error": {
553            "kind": "query_error",
554            "code": code,
555            "message": message,
556            "retryable": false,
557        }
558    })
559}
560
561fn ensure_any_drivers() {
562    SQLX_ANY_DRIVERS.call_once(sqlx::any::install_default_drivers);
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::fs;
    use std::path::PathBuf;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Owns a temporary SQLite database path and deletes it on drop. The file is therefore
    /// cleaned up even when an assertion fails part-way through a test.
    struct TempDb {
        path: PathBuf,
    }

    impl Drop for TempDb {
        fn drop(&mut self) {
            // Best effort. Also remove the sidecar files SQLite may create next to the
            // database, depending on its journal mode. Missing files are fine.
            let _ = fs::remove_file(&self.path);
            for suffix in &["-journal", "-wal", "-shm"] {
                let mut sidecar = self.path.clone().into_os_string();
                sidecar.push(suffix);
                let _ = fs::remove_file(&sidecar);
            }
        }
    }

    /// Builds a `sqlite://…?mode=rwc` URL for a fresh, uniquely named file in the temp dir.
    /// Returns it with a guard that deletes the file when dropped.
    fn sqlite_temp_connection(tag: &str) -> (String, TempDb) {
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock")
            .as_nanos();
        // Include the process id so concurrently running test binaries cannot collide.
        let path = std::env::temp_dir()
            .join(format!("grapheme-sql-{tag}-{}-{ts}.db", std::process::id()));
        let connection = format!("sqlite://{}?mode=rwc", path.display());
        (connection, TempDb { path })
    }

    #[test]
    fn health_accepts_direct_sqlite_url_connection() {
        let out = health(&json!({ "connection": "sqlite::memory:" }));
        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
    }

    #[test]
    fn query_returns_rows_for_basic_select() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1 as ok"
        }));
        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(1));
    }

    #[test]
    fn execute_reports_rows_affected() {
        let out = execute(&json!({
            "connection": "sqlite::memory:",
            "sql": "create table if not exists t (id integer)"
        }));
        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        assert!(out.get("rows_affected").and_then(|v| v.as_u64()).is_some());
    }

    #[test]
    fn query_reports_unresolved_connection_id() {
        let out = query(&json!({
            "connection": "missing_conn",
            "sql": "select 1"
        }));
        assert_eq!(
            out.get("error")
                .and_then(|v| v.get("code"))
                .and_then(|v| v.as_str()),
            Some("sql_connection_unresolved")
        );
    }

    #[test]
    fn query_supports_scalar_positional_params() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select ?1 as n, ?2 as t, ?3 as b, ?4 as z",
            "params": [42, "hello", true, null]
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        let rows = out
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("rows should be present");
        assert_eq!(rows.len(), 1);

        let row = rows.first().and_then(|v| v.as_object()).expect("row object");
        assert_eq!(row.get("n").and_then(|v| v.as_i64()), Some(42));
        assert_eq!(row.get("t").and_then(|v| v.as_str()), Some("hello"));
        let b = row.get("b").cloned().unwrap_or(JsonValue::Null);
        assert!(matches!(b, JsonValue::Bool(true) | JsonValue::Number(_)));
        if let JsonValue::Number(n) = b {
            assert_eq!(n.as_i64(), Some(1));
        }
        assert_eq!(row.get("z"), Some(&JsonValue::Null));
    }

    #[test]
    fn execute_supports_positional_params() {
        let out = execute(&json!({
            "connection": "sqlite::memory:",
            "sql": "select ?1",
            "params": [7]
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
    }

    #[test]
    fn query_rejects_non_array_params() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1",
            "params": {"a": 1}
        }));

        assert_eq!(
            out.get("error")
                .and_then(|v| v.get("code"))
                .and_then(|v| v.as_str()),
            Some("sql_params_invalid")
        );
    }

    #[test]
    fn query_enforces_max_rows_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1 as n union all select 2 as n",
            "max_rows": 1
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(
            out.get("error")
                .and_then(|v| v.get("code"))
                .and_then(|v| v.as_str()),
            Some("sql_query_failed")
        );
        assert!(out
            .get("error")
            .and_then(|v| v.get("message"))
            .and_then(|v| v.as_str())
            .unwrap_or_default()
            .contains("exceeds max_rows"));
    }

    #[test]
    fn query_enforces_max_payload_bytes_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 'abcdefghijklmnopqrstuvwxyz' as payload",
            "max_payload_bytes": 8
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(
            out.get("error")
                .and_then(|v| v.get("code"))
                .and_then(|v| v.as_str()),
            Some("sql_query_failed")
        );
        assert!(out
            .get("error")
            .and_then(|v| v.get("message"))
            .and_then(|v| v.as_str())
            .unwrap_or_default()
            .contains("exceeds max_payload_bytes"));
    }

    #[test]
    fn query_handles_high_row_count_when_within_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "with recursive cnt(x) as (select 1 union all select x + 1 from cnt where x < 128) select x from cnt",
            "max_rows": 128
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(128));
    }

    #[test]
    fn query_handles_high_payload_near_limit_boundary() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "with recursive cnt(x) as (select 1 union all select x + 1 from cnt where x < 64) select x, 'aaaaaaaaaaaaaaaa' as payload from cnt",
            "max_rows": 64,
            "max_payload_bytes": 4096
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(64));
    }

    #[test]
    fn transaction_runs_execute_and_query_steps() {
        let out = transaction(&json!({
            "connection": "sqlite::memory:",
            "steps": [
                {
                    "sql": "create table if not exists t (id integer, label text)",
                    "mode": "execute"
                },
                {
                    "sql": "insert into t (id, label) values (?1, ?2)",
                    "mode": "execute",
                    "params": [1, "ok"]
                },
                {
                    "sql": "select label from t where id = ?1",
                    "mode": "query",
                    "params": [1]
                }
            ]
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(true));
        let results = out
            .get("results")
            .and_then(|v| v.as_array())
            .expect("results array");
        assert_eq!(results.len(), 3);
        let query_result_rows = results[2]
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("query rows");
        assert_eq!(query_result_rows.len(), 1);
    }

    #[test]
    fn transaction_rolls_back_on_step_failure() {
        let out = transaction(&json!({
            "connection": "sqlite::memory:",
            "steps": [
                {
                    "sql": "create table if not exists t (id integer)",
                    "mode": "execute"
                },
                {
                    "sql": "this is invalid sql",
                    "mode": "execute"
                }
            ]
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(out.get("failed_step").and_then(|v| v.as_u64()), Some(1));
        assert_eq!(
            out.get("error")
                .and_then(|v| v.get("code"))
                .and_then(|v| v.as_str()),
            Some("sql_transaction_step_failed")
        );
    }

    #[test]
    fn transaction_rollback_is_deterministic_for_persisted_connection() {
        // `_db` must stay bound (not `_`) so the file lives until the end of the test.
        let (connection, _db) = sqlite_temp_connection("rollback-deterministic");

        let setup = execute(&json!({
            "connection": connection,
            "sql": "create table if not exists t (id integer, label text)"
        }));
        assert_eq!(setup.get("ok").and_then(|v| v.as_bool()), Some(true));

        let out = transaction(&json!({
            "connection": connection,
            "steps": [
                {
                    "sql": "insert into t (id, label) values (?1, ?2)",
                    "mode": "execute",
                    "params": [1, "should_rollback"]
                },
                {
                    "sql": "this is invalid sql",
                    "mode": "execute"
                }
            ]
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(false));

        let verify = query(&json!({
            "connection": connection,
            "sql": "select count(*) as count from t"
        }));
        assert_eq!(verify.get("ok").and_then(|v| v.as_bool()), Some(true));
        let rows = verify
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("rows array");
        let count = rows
            .first()
            .and_then(|v| v.get("count"))
            .and_then(|v| v.as_i64());
        assert_eq!(count, Some(0));
    }
}