Skip to main content

grapheme_stdlib/
sql.rs

1use serde_json::{json, Map, Value as JsonValue};
2use sqlx::{AnyPool, Column, Row};
3use std::env;
4use std::sync::Once;
5use std::time::Instant;
6
/// One-shot guard used by `ensure_any_drivers` so sqlx's `Any` driver registration
/// is attempted at most once per process.
static SQLX_ANY_DRIVERS: std::sync::Once = std::sync::Once::new();
8
9pub fn query(args: &JsonValue) -> JsonValue {
10    let connection = match required_string(args, "connection") {
11        Ok(v) => v,
12        Err(e) => return error_payload("validation_error", "missing_connection", &e),
13    };
14    let sql = match required_string(args, "sql") {
15        Ok(v) => v,
16        Err(e) => return error_payload("validation_error", "missing_sql", &e),
17    };
18
19    let params = match optional_params(args) {
20        Ok(v) => v,
21        Err(e) => return error_payload("validation_error", "sql_params_invalid", &e),
22    };
23    let max_rows = optional_u64(args, "max_rows");
24    let max_payload_bytes = optional_u64(args, "max_payload_bytes");
25
26    let resolved = match resolve_connection(&connection) {
27        Ok(v) => v,
28        Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
29    };
30
31    let started = Instant::now();
32    let rt = match tokio::runtime::Builder::new_current_thread()
33        .enable_all()
34        .build()
35    {
36        Ok(rt) => rt,
37        Err(e) => {
38            return error_payload("runtime_error", "tokio_runtime_init_failed", &e.to_string())
39        }
40    };
41
42    let result: Result<JsonValue, String> = rt.block_on(async {
43        ensure_any_drivers();
44        let pool = AnyPool::connect(&resolved)
45            .await
46            .map_err(|e| format!("connect failed: {e}"))?;
47
48        let query = match bind_params(sqlx::query(&sql), &params) {
49            Ok(query) => query,
50            Err(e) => return Err(e),
51        };
52
53        let rows = query
54            .fetch_all(&pool)
55            .await
56            .map_err(|e| format!("query failed: {e}"))?;
57
58        if let Some(max_rows) = max_rows {
59            if rows.len() > max_rows as usize {
60                return Err(format!(
61                    "row count {} exceeds max_rows {}",
62                    rows.len(),
63                    max_rows
64                ));
65            }
66        }
67
68        let mut out_rows = Vec::with_capacity(rows.len());
69        for row in rows {
70            let mut obj = Map::new();
71            for (idx, col) in row.columns().iter().enumerate() {
72                let key = col.name().to_string();
73                let value = decode_row_value(&row, idx);
74                obj.insert(key, value);
75            }
76            out_rows.push(JsonValue::Object(obj));
77        }
78
79        if let Some(max_payload_bytes) = max_payload_bytes {
80            let payload_bytes = serde_json::to_vec(&out_rows)
81                .map(|v| v.len())
82                .map_err(|e| format!("serialize query rows failed: {e}"))?;
83            if payload_bytes > max_payload_bytes as usize {
84                return Err(format!(
85                    "payload size {} exceeds max_payload_bytes {}",
86                    payload_bytes, max_payload_bytes
87                ));
88            }
89        }
90
91        Ok(json!({
92            "ok": true,
93            "connection": connection,
94            "row_count": out_rows.len(),
95            "rows": out_rows,
96            "elapsed_ms": started.elapsed().as_millis() as u64,
97        }))
98    });
99
100    match result {
101        Ok(v) => v,
102        Err(e) => error_payload("query_error", "sql_query_failed", &e),
103    }
104}
105
106pub fn execute(args: &JsonValue) -> JsonValue {
107    let connection = match required_string(args, "connection") {
108        Ok(v) => v,
109        Err(e) => return error_payload("validation_error", "missing_connection", &e),
110    };
111    let sql = match required_string(args, "sql") {
112        Ok(v) => v,
113        Err(e) => return error_payload("validation_error", "missing_sql", &e),
114    };
115
116    let params = match optional_params(args) {
117        Ok(v) => v,
118        Err(e) => return error_payload("validation_error", "sql_params_invalid", &e),
119    };
120
121    let resolved = match resolve_connection(&connection) {
122        Ok(v) => v,
123        Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
124    };
125
126    let started = Instant::now();
127    let rt = match tokio::runtime::Builder::new_current_thread()
128        .enable_all()
129        .build()
130    {
131        Ok(rt) => rt,
132        Err(e) => {
133            return error_payload("runtime_error", "tokio_runtime_init_failed", &e.to_string())
134        }
135    };
136
137    let result: Result<JsonValue, String> = rt.block_on(async {
138        ensure_any_drivers();
139        let pool = AnyPool::connect(&resolved)
140            .await
141            .map_err(|e| format!("connect failed: {e}"))?;
142
143        let query = match bind_params(sqlx::query(&sql), &params) {
144            Ok(query) => query,
145            Err(e) => return Err(e),
146        };
147
148        let outcome = query
149            .execute(&pool)
150            .await
151            .map_err(|e| format!("execute failed: {e}"))?;
152
153        Ok(json!({
154            "ok": true,
155            "connection": connection,
156            "rows_affected": outcome.rows_affected(),
157            "elapsed_ms": started.elapsed().as_millis() as u64,
158        }))
159    });
160
161    match result {
162        Ok(v) => v,
163        Err(e) => error_payload("query_error", "sql_execute_failed", &e),
164    }
165}
166
167pub fn health(args: &JsonValue) -> JsonValue {
168    let connection = match required_string(args, "connection") {
169        Ok(v) => v,
170        Err(e) => return error_payload("validation_error", "missing_connection", &e),
171    };
172
173    let resolved = match resolve_connection(&connection) {
174        Ok(v) => v,
175        Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
176    };
177
178    let started = Instant::now();
179    let rt = match tokio::runtime::Builder::new_current_thread()
180        .enable_all()
181        .build()
182    {
183        Ok(rt) => rt,
184        Err(e) => {
185            return error_payload("runtime_error", "tokio_runtime_init_failed", &e.to_string())
186        }
187    };
188
189    let result: Result<JsonValue, String> = rt.block_on(async {
190        ensure_any_drivers();
191        let pool = AnyPool::connect(&resolved)
192            .await
193            .map_err(|e| format!("connect failed: {e}"))?;
194
195        sqlx::query("select 1")
196            .execute(&pool)
197            .await
198            .map_err(|e| format!("health check failed: {e}"))?;
199
200        Ok(json!({
201            "ok": true,
202            "connection": connection,
203            "latency_ms": started.elapsed().as_millis() as u64,
204        }))
205    });
206
207    match result {
208        Ok(v) => v,
209        Err(e) => error_payload("connection_error", "sql_health_failed", &e),
210    }
211}
212
213pub fn transaction(args: &JsonValue) -> JsonValue {
214    let connection = match required_string(args, "connection") {
215        Ok(v) => v,
216        Err(e) => return error_payload("validation_error", "missing_connection", &e),
217    };
218
219    let steps = match parse_transaction_steps(args) {
220        Ok(v) => v,
221        Err(e) => return error_payload("validation_error", "sql_transaction_invalid", &e),
222    };
223
224    let resolved = match resolve_connection(&connection) {
225        Ok(v) => v,
226        Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
227    };
228
229    let started = Instant::now();
230    let rt = match tokio::runtime::Builder::new_current_thread()
231        .enable_all()
232        .build()
233    {
234        Ok(rt) => rt,
235        Err(e) => {
236            return error_payload("runtime_error", "tokio_runtime_init_failed", &e.to_string())
237        }
238    };
239
240    let result: Result<JsonValue, JsonValue> = rt.block_on(async {
241        ensure_any_drivers();
242        let pool = AnyPool::connect(&resolved).await.map_err(|e| {
243            error_payload(
244                "connection_error",
245                "sql_connect_failed",
246                &format!("connect failed: {e}"),
247            )
248        })?;
249
250        let mut tx = pool.begin().await.map_err(|e| {
251            error_payload(
252                "query_error",
253                "sql_transaction_begin_failed",
254                &format!("begin failed: {e}"),
255            )
256        })?;
257
258        let mut results = Vec::with_capacity(steps.len());
259
260        for (idx, step) in steps.iter().enumerate() {
261            let query = bind_params(sqlx::query(&step.sql), &step.params).map_err(|e| {
262                transaction_failure_payload(
263                    &connection,
264                    &results,
265                    idx,
266                    started,
267                    "sql_params_invalid",
268                    &e,
269                )
270            })?;
271
272            if step.mode == TransactionStepMode::Query {
273                let rows = query.fetch_all(&mut *tx).await.map_err(|e| {
274                    transaction_failure_payload(
275                        &connection,
276                        &results,
277                        idx,
278                        started,
279                        "sql_transaction_step_failed",
280                        &format!("query step failed: {e}"),
281                    )
282                })?;
283
284                let mut out_rows = Vec::with_capacity(rows.len());
285                for row in rows {
286                    let mut obj = Map::new();
287                    for (col_idx, col) in row.columns().iter().enumerate() {
288                        obj.insert(col.name().to_string(), decode_row_value(&row, col_idx));
289                    }
290                    out_rows.push(JsonValue::Object(obj));
291                }
292
293                results.push(json!({
294                    "mode": "query",
295                    "row_count": out_rows.len(),
296                    "rows": out_rows,
297                }));
298            } else {
299                let outcome = query.execute(&mut *tx).await.map_err(|e| {
300                    transaction_failure_payload(
301                        &connection,
302                        &results,
303                        idx,
304                        started,
305                        "sql_transaction_step_failed",
306                        &format!("execute step failed: {e}"),
307                    )
308                })?;
309
310                results.push(json!({
311                    "mode": "execute",
312                    "rows_affected": outcome.rows_affected(),
313                }));
314            }
315        }
316
317        tx.commit().await.map_err(|e| {
318            error_payload(
319                "query_error",
320                "sql_transaction_commit_failed",
321                &format!("commit failed: {e}"),
322            )
323        })?;
324
325        Ok(json!({
326            "ok": true,
327            "connection": connection,
328            "committed": true,
329            "results": results,
330            "elapsed_ms": started.elapsed().as_millis() as u64,
331        }))
332    });
333
334    match result {
335        Ok(v) => v,
336        Err(v) => v,
337    }
338}
339
340fn required_string(args: &JsonValue, key: &str) -> Result<String, String> {
341    args.get(key)
342        .and_then(|v| v.as_str())
343        .map(ToOwned::to_owned)
344        .or_else(|| {
345            args.get("__input")
346                .and_then(|v| v.as_object())
347                .and_then(|obj| obj.get(key))
348                .and_then(|v| v.as_str())
349                .map(ToOwned::to_owned)
350        })
351        .ok_or_else(|| format!("missing required '{}'", key))
352}
353
354fn optional_params(args: &JsonValue) -> Result<Vec<JsonValue>, String> {
355    let candidate = args.get("params").cloned().or_else(|| {
356        args.get("__input")
357            .and_then(|v| v.as_object())
358            .and_then(|obj| obj.get("params").cloned())
359    });
360
361    match candidate {
362        None => Ok(Vec::new()),
363        Some(JsonValue::Array(items)) => Ok(items),
364        Some(_) => Err("params must be an array".to_string()),
365    }
366}
367
368fn optional_u64(args: &JsonValue, key: &str) -> Option<u64> {
369    args.get(key)
370        .and_then(|v| {
371            v.as_u64()
372                .or_else(|| v.as_str().and_then(|s| s.parse::<u64>().ok()))
373        })
374        .or_else(|| {
375            args.get("__input")
376                .and_then(|v| v.as_object())
377                .and_then(|obj| obj.get(key))
378                .and_then(|v| {
379                    v.as_u64()
380                        .or_else(|| v.as_str().and_then(|s| s.parse::<u64>().ok()))
381                })
382        })
383}
384
/// How a transaction step is run: `Query` collects the rows it produces, `Execute`
/// reports only the number of rows affected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TransactionStepMode {
    /// Fetch every row the step returns.
    Query,
    /// Run the step and report rows affected.
    Execute,
}
390
391#[derive(Debug, Clone)]
392struct TransactionStep {
393    mode: TransactionStepMode,
394    sql: String,
395    params: Vec<JsonValue>,
396}
397
398fn parse_transaction_steps(args: &JsonValue) -> Result<Vec<TransactionStep>, String> {
399    let raw_steps = args
400        .get("steps")
401        .cloned()
402        .or_else(|| {
403            args.get("__input")
404                .and_then(|v| v.as_object())
405                .and_then(|obj| obj.get("steps").cloned())
406        })
407        .ok_or_else(|| "missing required 'steps'".to_string())?;
408
409    let list = raw_steps
410        .as_array()
411        .ok_or_else(|| "steps must be an array".to_string())?;
412
413    if list.is_empty() {
414        return Err("steps must contain at least one entry".to_string());
415    }
416
417    let mut out = Vec::with_capacity(list.len());
418    for (idx, value) in list.iter().enumerate() {
419        let obj = value
420            .as_object()
421            .ok_or_else(|| format!("step[{idx}] must be an object"))?;
422
423        let sql = obj
424            .get("sql")
425            .and_then(|v| v.as_str())
426            .map(ToOwned::to_owned)
427            .ok_or_else(|| format!("step[{idx}] missing required 'sql'"))?;
428
429        let mode = match obj.get("mode").and_then(|v| v.as_str()) {
430            Some("query") => TransactionStepMode::Query,
431            Some("execute") | None => TransactionStepMode::Execute,
432            Some(other) => {
433                return Err(format!(
434                    "step[{idx}] has invalid mode '{}', expected query|execute",
435                    other
436                ))
437            }
438        };
439
440        let params = match obj.get("params") {
441            None => Vec::new(),
442            Some(JsonValue::Array(items)) => items.clone(),
443            Some(_) => return Err(format!("step[{idx}] params must be an array")),
444        };
445
446        out.push(TransactionStep { mode, sql, params });
447    }
448
449    Ok(out)
450}
451
452fn bind_params<'q>(
453    mut query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
454    params: &[JsonValue],
455) -> Result<sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>, String> {
456    for value in params {
457        query = match value {
458            JsonValue::Null => query.bind(Option::<String>::None),
459            JsonValue::Bool(v) => query.bind(*v),
460            JsonValue::Number(n) => {
461                if let Some(v) = n.as_i64() {
462                    query.bind(v)
463                } else if let Some(v) = n.as_u64() {
464                    if let Ok(as_i64) = i64::try_from(v) {
465                        query.bind(as_i64)
466                    } else {
467                        query.bind(v as f64)
468                    }
469                } else if let Some(v) = n.as_f64() {
470                    query.bind(v)
471                } else {
472                    return Err("unsupported numeric param representation".to_string());
473                }
474            }
475            JsonValue::String(v) => query.bind(v.clone()),
476            JsonValue::Array(_) | JsonValue::Object(_) => {
477                return Err("only scalar params are supported (null/bool/number/string)".to_string())
478            }
479        };
480    }
481
482    Ok(query)
483}
484
485fn resolve_connection(connection: &str) -> Result<String, String> {
486    if connection.contains("://") || connection.starts_with("sqlite:") {
487        return Ok(connection.to_string());
488    }
489
490    let env_key = format!(
491        "GRAPHEME_SQL_CONNECTION_{}",
492        connection
493            .chars()
494            .map(|c| if c.is_ascii_alphanumeric() {
495                c.to_ascii_uppercase()
496            } else {
497                '_'
498            })
499            .collect::<String>()
500    );
501
502    if let Ok(url) = env::var(&env_key) {
503        if !url.trim().is_empty() {
504            return Ok(url);
505        }
506    }
507
508    if let Ok(map_raw) = env::var("GRAPHEME_SQL_CONNECTIONS") {
509        if let Ok(map_json) = serde_json::from_str::<JsonValue>(&map_raw) {
510            if let Some(url) = map_json
511                .get(connection)
512                .and_then(|v| v.as_str())
513                .map(ToOwned::to_owned)
514            {
515                return Ok(url);
516            }
517        }
518    }
519
520    Err(format!(
521        "connection '{}' is unresolved; set {} or GRAPHEME_SQL_CONNECTIONS",
522        connection, env_key
523    ))
524}
525
526fn decode_row_value(row: &sqlx::any::AnyRow, idx: usize) -> JsonValue {
527    if let Ok(v) = row.try_get::<Option<i64>, _>(idx) {
528        return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
529    }
530    if let Ok(v) = row.try_get::<Option<f64>, _>(idx) {
531        return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
532    }
533    if let Ok(v) = row.try_get::<Option<bool>, _>(idx) {
534        return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
535    }
536    if let Ok(v) = row.try_get::<Option<String>, _>(idx) {
537        return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
538    }
539
540    JsonValue::Null
541}
542
543fn error_payload(kind: &str, code: &str, message: &str) -> JsonValue {
544    json!({
545        "ok": false,
546        "error": {
547            "kind": kind,
548            "code": code,
549            "message": message,
550            "retryable": false
551        }
552    })
553}
554
555fn transaction_failure_payload(
556    connection: &str,
557    results: &[JsonValue],
558    failed_step: usize,
559    started: Instant,
560    code: &str,
561    message: &str,
562) -> JsonValue {
563    json!({
564        "ok": false,
565        "connection": connection,
566        "committed": false,
567        "failed_step": failed_step,
568        "results": results,
569        "elapsed_ms": started.elapsed().as_millis() as u64,
570        "error": {
571            "kind": "query_error",
572            "code": code,
573            "message": message,
574            "retryable": false,
575        }
576    })
577}
578
579fn ensure_any_drivers() {
580    SQLX_ANY_DRIVERS.call_once(sqlx::any::install_default_drivers);
581}
582
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::fs;
    use std::path::PathBuf;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// A uniquely named SQLite database file under the system temp dir.
    ///
    /// The file is removed when the guard is dropped, so it is cleaned up even when a test
    /// assertion panics before the end of the test.
    struct TempSqliteDb {
        /// Connection URL; `mode=rwc` opens the file read/write and creates it if missing.
        url: String,
        path: PathBuf,
    }

    impl Drop for TempSqliteDb {
        fn drop(&mut self) {
            // Best-effort: the file may never have been created if the test failed early.
            let _ = fs::remove_file(&self.path);
        }
    }

    fn sqlite_temp_connection(tag: &str) -> TempSqliteDb {
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock")
            .as_nanos();
        // The process id keeps concurrently running test binaries from colliding.
        let path = std::env::temp_dir().join(format!(
            "grapheme-sql-{tag}-{}-{ts}.db",
            std::process::id()
        ));
        TempSqliteDb {
            url: format!("sqlite://{}?mode=rwc", path.display()),
            path,
        }
    }

    #[test]
    fn health_accepts_direct_sqlite_url_connection() {
        let out = health(&json!({ "connection": "sqlite::memory:" }));
        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
    }

    #[test]
    fn query_returns_rows_for_basic_select() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1 as ok"
        }));
        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(1));
    }

    #[test]
    fn execute_reports_rows_affected() {
        let out = execute(&json!({
            "connection": "sqlite::memory:",
            "sql": "create table if not exists t (id integer)"
        }));
        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        assert!(out.get("rows_affected").and_then(|v| v.as_u64()).is_some());
    }

    #[test]
    fn query_reports_unresolved_connection_id() {
        let out = query(&json!({
            "connection": "missing_conn",
            "sql": "select 1"
        }));
        assert_eq!(
            out.get("error")
                .and_then(|v| v.get("code"))
                .and_then(|v| v.as_str()),
            Some("sql_connection_unresolved")
        );
    }

    #[test]
    fn query_supports_scalar_positional_params() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select ?1 as n, ?2 as t, ?3 as b, ?4 as z",
            "params": [42, "hello", true, null]
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        let rows = out
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("rows should be present");
        assert_eq!(rows.len(), 1);

        let row = rows
            .first()
            .and_then(|v| v.as_object())
            .expect("row object");
        assert_eq!(row.get("n").and_then(|v| v.as_i64()), Some(42));
        assert_eq!(row.get("t").and_then(|v| v.as_str()), Some("hello"));
        // SQLite has no native boolean, so `true` may come back as the integer 1.
        let b = row.get("b").cloned().unwrap_or(JsonValue::Null);
        assert!(matches!(b, JsonValue::Bool(true) | JsonValue::Number(_)));
        if let JsonValue::Number(n) = b {
            assert_eq!(n.as_i64(), Some(1));
        }
        assert_eq!(row.get("z"), Some(&JsonValue::Null));
    }

    #[test]
    fn execute_supports_positional_params() {
        let out = execute(&json!({
            "connection": "sqlite::memory:",
            "sql": "select ?1",
            "params": [7]
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
    }

    #[test]
    fn query_rejects_non_array_params() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1",
            "params": {"a": 1}
        }));

        assert_eq!(
            out.get("error")
                .and_then(|v| v.get("code"))
                .and_then(|v| v.as_str()),
            Some("sql_params_invalid")
        );
    }

    #[test]
    fn query_enforces_max_rows_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1 as n union all select 2 as n",
            "max_rows": 1
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(
            out.get("error")
                .and_then(|v| v.get("code"))
                .and_then(|v| v.as_str()),
            Some("sql_query_failed")
        );
        assert!(out
            .get("error")
            .and_then(|v| v.get("message"))
            .and_then(|v| v.as_str())
            .unwrap_or_default()
            .contains("exceeds max_rows"));
    }

    #[test]
    fn query_enforces_max_payload_bytes_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 'abcdefghijklmnopqrstuvwxyz' as payload",
            "max_payload_bytes": 8
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(
            out.get("error")
                .and_then(|v| v.get("code"))
                .and_then(|v| v.as_str()),
            Some("sql_query_failed")
        );
        assert!(out
            .get("error")
            .and_then(|v| v.get("message"))
            .and_then(|v| v.as_str())
            .unwrap_or_default()
            .contains("exceeds max_payload_bytes"));
    }

    #[test]
    fn query_handles_high_row_count_when_within_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "with recursive cnt(x) as (select 1 union all select x + 1 from cnt where x < 128) select x from cnt",
            "max_rows": 128
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(128));
    }

    #[test]
    fn query_handles_high_payload_near_limit_boundary() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "with recursive cnt(x) as (select 1 union all select x + 1 from cnt where x < 64) select x, 'aaaaaaaaaaaaaaaa' as payload from cnt",
            "max_rows": 64,
            "max_payload_bytes": 4096
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(64));
    }

    #[test]
    fn transaction_runs_execute_and_query_steps() {
        let out = transaction(&json!({
            "connection": "sqlite::memory:",
            "steps": [
                {
                    "sql": "create table if not exists t (id integer, label text)",
                    "mode": "execute"
                },
                {
                    "sql": "insert into t (id, label) values (?1, ?2)",
                    "mode": "execute",
                    "params": [1, "ok"]
                },
                {
                    "sql": "select label from t where id = ?1",
                    "mode": "query",
                    "params": [1]
                }
            ]
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(true));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(true));
        let results = out
            .get("results")
            .and_then(|v| v.as_array())
            .expect("results array");
        assert_eq!(results.len(), 3);
        let query_result_rows = results[2]
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("query rows");
        assert_eq!(query_result_rows.len(), 1);
    }

    #[test]
    fn transaction_rolls_back_on_step_failure() {
        let out = transaction(&json!({
            "connection": "sqlite::memory:",
            "steps": [
                {
                    "sql": "create table if not exists t (id integer)",
                    "mode": "execute"
                },
                {
                    "sql": "this is invalid sql",
                    "mode": "execute"
                }
            ]
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(out.get("failed_step").and_then(|v| v.as_u64()), Some(1));
        assert_eq!(
            out.get("error")
                .and_then(|v| v.get("code"))
                .and_then(|v| v.as_str()),
            Some("sql_transaction_step_failed")
        );
    }

    #[test]
    fn transaction_rollback_is_deterministic_for_persisted_connection() {
        // The guard deletes the database file when it goes out of scope, even on panic.
        let db = sqlite_temp_connection("rollback-deterministic");
        let connection = db.url.as_str();

        let setup = execute(&json!({
            "connection": connection,
            "sql": "create table if not exists t (id integer, label text)"
        }));
        assert_eq!(setup.get("ok").and_then(|v| v.as_bool()), Some(true));

        let out = transaction(&json!({
            "connection": connection,
            "steps": [
                {
                    "sql": "insert into t (id, label) values (?1, ?2)",
                    "mode": "execute",
                    "params": [1, "should_rollback"]
                },
                {
                    "sql": "this is invalid sql",
                    "mode": "execute"
                }
            ]
        }));

        assert_eq!(out.get("ok").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(false));

        let verify = query(&json!({
            "connection": connection,
            "sql": "select count(*) as count from t"
        }));
        assert_eq!(verify.get("ok").and_then(|v| v.as_bool()), Some(true));
        let rows = verify
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("rows array");
        let count = rows
            .first()
            .and_then(|v| v.get("count"))
            .and_then(|v| v.as_i64());
        assert_eq!(count, Some(0));
    }
}