1use serde_json::{json, Map, Value as JsonValue};
2use sqlx::{AnyPool, Column, Row};
3use std::env;
4use std::sync::Once;
5use std::time::Instant;
6
/// Process-wide guard ensuring sqlx's `Any` drivers are installed only once
/// (see `ensure_any_drivers`).
static SQLX_ANY_DRIVERS: Once = Once::new();
8
9pub fn query(args: &JsonValue) -> JsonValue {
10 let connection = match required_string(args, "connection") {
11 Ok(v) => v,
12 Err(e) => return error_payload("validation_error", "missing_connection", &e),
13 };
14 let sql = match required_string(args, "sql") {
15 Ok(v) => v,
16 Err(e) => return error_payload("validation_error", "missing_sql", &e),
17 };
18
19 let params = match optional_params(args) {
20 Ok(v) => v,
21 Err(e) => return error_payload("validation_error", "sql_params_invalid", &e),
22 };
23 let max_rows = optional_u64(args, "max_rows");
24 let max_payload_bytes = optional_u64(args, "max_payload_bytes");
25
26 let resolved = match resolve_connection(&connection) {
27 Ok(v) => v,
28 Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
29 };
30
31 let started = Instant::now();
32 let rt = match tokio::runtime::Builder::new_current_thread()
33 .enable_all()
34 .build()
35 {
36 Ok(rt) => rt,
37 Err(e) => {
38 return error_payload(
39 "runtime_error",
40 "tokio_runtime_init_failed",
41 &e.to_string(),
42 )
43 }
44 };
45
46 let result: Result<JsonValue, String> = rt.block_on(async {
47 ensure_any_drivers();
48 let pool = AnyPool::connect(&resolved)
49 .await
50 .map_err(|e| format!("connect failed: {e}"))?;
51
52 let query = match bind_params(sqlx::query(&sql), ¶ms) {
53 Ok(query) => query,
54 Err(e) => return Err(e),
55 };
56
57 let rows = query
58 .fetch_all(&pool)
59 .await
60 .map_err(|e| format!("query failed: {e}"))?;
61
62 if let Some(max_rows) = max_rows {
63 if rows.len() > max_rows as usize {
64 return Err(format!(
65 "row count {} exceeds max_rows {}",
66 rows.len(),
67 max_rows
68 ));
69 }
70 }
71
72 let mut out_rows = Vec::with_capacity(rows.len());
73 for row in rows {
74 let mut obj = Map::new();
75 for (idx, col) in row.columns().iter().enumerate() {
76 let key = col.name().to_string();
77 let value = decode_row_value(&row, idx);
78 obj.insert(key, value);
79 }
80 out_rows.push(JsonValue::Object(obj));
81 }
82
83 if let Some(max_payload_bytes) = max_payload_bytes {
84 let payload_bytes = serde_json::to_vec(&out_rows)
85 .map(|v| v.len())
86 .map_err(|e| format!("serialize query rows failed: {e}"))?;
87 if payload_bytes > max_payload_bytes as usize {
88 return Err(format!(
89 "payload size {} exceeds max_payload_bytes {}",
90 payload_bytes,
91 max_payload_bytes
92 ));
93 }
94 }
95
96 Ok(json!({
97 "ok": true,
98 "connection": connection,
99 "row_count": out_rows.len(),
100 "rows": out_rows,
101 "elapsed_ms": started.elapsed().as_millis() as u64,
102 }))
103 });
104
105 match result {
106 Ok(v) => v,
107 Err(e) => error_payload("query_error", "sql_query_failed", &e),
108 }
109}
110
111pub fn execute(args: &JsonValue) -> JsonValue {
112 let connection = match required_string(args, "connection") {
113 Ok(v) => v,
114 Err(e) => return error_payload("validation_error", "missing_connection", &e),
115 };
116 let sql = match required_string(args, "sql") {
117 Ok(v) => v,
118 Err(e) => return error_payload("validation_error", "missing_sql", &e),
119 };
120
121 let params = match optional_params(args) {
122 Ok(v) => v,
123 Err(e) => return error_payload("validation_error", "sql_params_invalid", &e),
124 };
125
126 let resolved = match resolve_connection(&connection) {
127 Ok(v) => v,
128 Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
129 };
130
131 let started = Instant::now();
132 let rt = match tokio::runtime::Builder::new_current_thread()
133 .enable_all()
134 .build()
135 {
136 Ok(rt) => rt,
137 Err(e) => {
138 return error_payload(
139 "runtime_error",
140 "tokio_runtime_init_failed",
141 &e.to_string(),
142 )
143 }
144 };
145
146 let result: Result<JsonValue, String> = rt.block_on(async {
147 ensure_any_drivers();
148 let pool = AnyPool::connect(&resolved)
149 .await
150 .map_err(|e| format!("connect failed: {e}"))?;
151
152 let query = match bind_params(sqlx::query(&sql), ¶ms) {
153 Ok(query) => query,
154 Err(e) => return Err(e),
155 };
156
157 let outcome = query
158 .execute(&pool)
159 .await
160 .map_err(|e| format!("execute failed: {e}"))?;
161
162 Ok(json!({
163 "ok": true,
164 "connection": connection,
165 "rows_affected": outcome.rows_affected(),
166 "elapsed_ms": started.elapsed().as_millis() as u64,
167 }))
168 });
169
170 match result {
171 Ok(v) => v,
172 Err(e) => error_payload("query_error", "sql_execute_failed", &e),
173 }
174}
175
176pub fn health(args: &JsonValue) -> JsonValue {
177 let connection = match required_string(args, "connection") {
178 Ok(v) => v,
179 Err(e) => return error_payload("validation_error", "missing_connection", &e),
180 };
181
182 let resolved = match resolve_connection(&connection) {
183 Ok(v) => v,
184 Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
185 };
186
187 let started = Instant::now();
188 let rt = match tokio::runtime::Builder::new_current_thread()
189 .enable_all()
190 .build()
191 {
192 Ok(rt) => rt,
193 Err(e) => {
194 return error_payload(
195 "runtime_error",
196 "tokio_runtime_init_failed",
197 &e.to_string(),
198 )
199 }
200 };
201
202 let result: Result<JsonValue, String> = rt.block_on(async {
203 ensure_any_drivers();
204 let pool = AnyPool::connect(&resolved)
205 .await
206 .map_err(|e| format!("connect failed: {e}"))?;
207
208 sqlx::query("select 1")
209 .execute(&pool)
210 .await
211 .map_err(|e| format!("health check failed: {e}"))?;
212
213 Ok(json!({
214 "ok": true,
215 "connection": connection,
216 "latency_ms": started.elapsed().as_millis() as u64,
217 }))
218 });
219
220 match result {
221 Ok(v) => v,
222 Err(e) => error_payload("connection_error", "sql_health_failed", &e),
223 }
224}
225
226pub fn transaction(args: &JsonValue) -> JsonValue {
227 let connection = match required_string(args, "connection") {
228 Ok(v) => v,
229 Err(e) => return error_payload("validation_error", "missing_connection", &e),
230 };
231
232 let steps = match parse_transaction_steps(args) {
233 Ok(v) => v,
234 Err(e) => return error_payload("validation_error", "sql_transaction_invalid", &e),
235 };
236
237 let resolved = match resolve_connection(&connection) {
238 Ok(v) => v,
239 Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
240 };
241
242 let started = Instant::now();
243 let rt = match tokio::runtime::Builder::new_current_thread()
244 .enable_all()
245 .build()
246 {
247 Ok(rt) => rt,
248 Err(e) => {
249 return error_payload(
250 "runtime_error",
251 "tokio_runtime_init_failed",
252 &e.to_string(),
253 )
254 }
255 };
256
257 let result: Result<JsonValue, JsonValue> = rt.block_on(async {
258 ensure_any_drivers();
259 let pool = AnyPool::connect(&resolved)
260 .await
261 .map_err(|e| error_payload("connection_error", "sql_connect_failed", &format!("connect failed: {e}")))?;
262
263 let mut tx = pool
264 .begin()
265 .await
266 .map_err(|e| error_payload("query_error", "sql_transaction_begin_failed", &format!("begin failed: {e}")))?;
267
268 let mut results = Vec::with_capacity(steps.len());
269
270 for (idx, step) in steps.iter().enumerate() {
271 let query = bind_params(sqlx::query(&step.sql), &step.params)
272 .map_err(|e| transaction_failure_payload(&connection, &results, idx, started, "sql_params_invalid", &e))?;
273
274 if step.mode == TransactionStepMode::Query {
275 let rows = query
276 .fetch_all(&mut *tx)
277 .await
278 .map_err(|e| transaction_failure_payload(&connection, &results, idx, started, "sql_transaction_step_failed", &format!("query step failed: {e}")))?;
279
280 let mut out_rows = Vec::with_capacity(rows.len());
281 for row in rows {
282 let mut obj = Map::new();
283 for (col_idx, col) in row.columns().iter().enumerate() {
284 obj.insert(col.name().to_string(), decode_row_value(&row, col_idx));
285 }
286 out_rows.push(JsonValue::Object(obj));
287 }
288
289 results.push(json!({
290 "mode": "query",
291 "row_count": out_rows.len(),
292 "rows": out_rows,
293 }));
294 } else {
295 let outcome = query
296 .execute(&mut *tx)
297 .await
298 .map_err(|e| transaction_failure_payload(&connection, &results, idx, started, "sql_transaction_step_failed", &format!("execute step failed: {e}")))?;
299
300 results.push(json!({
301 "mode": "execute",
302 "rows_affected": outcome.rows_affected(),
303 }));
304 }
305 }
306
307 tx.commit()
308 .await
309 .map_err(|e| error_payload("query_error", "sql_transaction_commit_failed", &format!("commit failed: {e}")))?;
310
311 Ok(json!({
312 "ok": true,
313 "connection": connection,
314 "committed": true,
315 "results": results,
316 "elapsed_ms": started.elapsed().as_millis() as u64,
317 }))
318 });
319
320 match result {
321 Ok(v) => v,
322 Err(v) => v,
323 }
324}
325
326fn required_string(args: &JsonValue, key: &str) -> Result<String, String> {
327 args.get(key)
328 .and_then(|v| v.as_str())
329 .map(ToOwned::to_owned)
330 .or_else(|| {
331 args.get("__input")
332 .and_then(|v| v.as_object())
333 .and_then(|obj| obj.get(key))
334 .and_then(|v| v.as_str())
335 .map(ToOwned::to_owned)
336 })
337 .ok_or_else(|| format!("missing required '{}'", key))
338}
339
340fn optional_params(args: &JsonValue) -> Result<Vec<JsonValue>, String> {
341 let candidate = args
342 .get("params")
343 .cloned()
344 .or_else(|| {
345 args.get("__input")
346 .and_then(|v| v.as_object())
347 .and_then(|obj| obj.get("params").cloned())
348 });
349
350 match candidate {
351 None => Ok(Vec::new()),
352 Some(JsonValue::Array(items)) => Ok(items),
353 Some(_) => Err("params must be an array".to_string()),
354 }
355}
356
357fn optional_u64(args: &JsonValue, key: &str) -> Option<u64> {
358 args.get(key)
359 .and_then(|v| v.as_u64().or_else(|| v.as_str().and_then(|s| s.parse::<u64>().ok())))
360 .or_else(|| {
361 args.get("__input")
362 .and_then(|v| v.as_object())
363 .and_then(|obj| obj.get(key))
364 .and_then(|v| {
365 v.as_u64()
366 .or_else(|| v.as_str().and_then(|s| s.parse::<u64>().ok()))
367 })
368 })
369}
370
/// How a transaction step is run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TransactionStepMode {
    /// Fetch the step's rows and return them in the step result.
    Query,
    /// Run the statement and report only `rows_affected`; used when `mode` is omitted.
    Execute,
}
376
377#[derive(Debug, Clone)]
378struct TransactionStep {
379 mode: TransactionStepMode,
380 sql: String,
381 params: Vec<JsonValue>,
382}
383
384fn parse_transaction_steps(args: &JsonValue) -> Result<Vec<TransactionStep>, String> {
385 let raw_steps = args
386 .get("steps")
387 .cloned()
388 .or_else(|| {
389 args.get("__input")
390 .and_then(|v| v.as_object())
391 .and_then(|obj| obj.get("steps").cloned())
392 })
393 .ok_or_else(|| "missing required 'steps'".to_string())?;
394
395 let list = raw_steps
396 .as_array()
397 .ok_or_else(|| "steps must be an array".to_string())?;
398
399 if list.is_empty() {
400 return Err("steps must contain at least one entry".to_string());
401 }
402
403 let mut out = Vec::with_capacity(list.len());
404 for (idx, value) in list.iter().enumerate() {
405 let obj = value
406 .as_object()
407 .ok_or_else(|| format!("step[{idx}] must be an object"))?;
408
409 let sql = obj
410 .get("sql")
411 .and_then(|v| v.as_str())
412 .map(ToOwned::to_owned)
413 .ok_or_else(|| format!("step[{idx}] missing required 'sql'"))?;
414
415 let mode = match obj.get("mode").and_then(|v| v.as_str()) {
416 Some("query") => TransactionStepMode::Query,
417 Some("execute") | None => TransactionStepMode::Execute,
418 Some(other) => {
419 return Err(format!(
420 "step[{idx}] has invalid mode '{}', expected query|execute",
421 other
422 ))
423 }
424 };
425
426 let params = match obj.get("params") {
427 None => Vec::new(),
428 Some(JsonValue::Array(items)) => items.clone(),
429 Some(_) => return Err(format!("step[{idx}] params must be an array")),
430 };
431
432 out.push(TransactionStep { mode, sql, params });
433 }
434
435 Ok(out)
436}
437
438fn bind_params<'q>(
439 mut query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
440 params: &[JsonValue],
441) -> Result<sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>, String> {
442 for value in params {
443 query = match value {
444 JsonValue::Null => query.bind(Option::<String>::None),
445 JsonValue::Bool(v) => query.bind(*v),
446 JsonValue::Number(n) => {
447 if let Some(v) = n.as_i64() {
448 query.bind(v)
449 } else if let Some(v) = n.as_u64() {
450 if let Ok(as_i64) = i64::try_from(v) {
451 query.bind(as_i64)
452 } else {
453 query.bind(v as f64)
454 }
455 } else if let Some(v) = n.as_f64() {
456 query.bind(v)
457 } else {
458 return Err("unsupported numeric param representation".to_string());
459 }
460 }
461 JsonValue::String(v) => query.bind(v.clone()),
462 JsonValue::Array(_) | JsonValue::Object(_) => {
463 return Err("only scalar params are supported (null/bool/number/string)".to_string())
464 }
465 };
466 }
467
468 Ok(query)
469}
470
471fn resolve_connection(connection: &str) -> Result<String, String> {
472 if connection.contains("://") || connection.starts_with("sqlite:") {
473 return Ok(connection.to_string());
474 }
475
476 let env_key = format!(
477 "GRAPHEME_SQL_CONNECTION_{}",
478 connection
479 .chars()
480 .map(|c| if c.is_ascii_alphanumeric() { c.to_ascii_uppercase() } else { '_' })
481 .collect::<String>()
482 );
483
484 if let Ok(url) = env::var(&env_key) {
485 if !url.trim().is_empty() {
486 return Ok(url);
487 }
488 }
489
490 if let Ok(map_raw) = env::var("GRAPHEME_SQL_CONNECTIONS") {
491 if let Ok(map_json) = serde_json::from_str::<JsonValue>(&map_raw) {
492 if let Some(url) = map_json
493 .get(connection)
494 .and_then(|v| v.as_str())
495 .map(ToOwned::to_owned)
496 {
497 return Ok(url);
498 }
499 }
500 }
501
502 Err(format!(
503 "connection '{}' is unresolved; set {} or GRAPHEME_SQL_CONNECTIONS",
504 connection, env_key
505 ))
506}
507
508fn decode_row_value(row: &sqlx::any::AnyRow, idx: usize) -> JsonValue {
509 if let Ok(v) = row.try_get::<Option<i64>, _>(idx) {
510 return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
511 }
512 if let Ok(v) = row.try_get::<Option<f64>, _>(idx) {
513 return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
514 }
515 if let Ok(v) = row.try_get::<Option<bool>, _>(idx) {
516 return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
517 }
518 if let Ok(v) = row.try_get::<Option<String>, _>(idx) {
519 return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
520 }
521
522 JsonValue::Null
523}
524
525fn error_payload(kind: &str, code: &str, message: &str) -> JsonValue {
526 json!({
527 "ok": false,
528 "error": {
529 "kind": kind,
530 "code": code,
531 "message": message,
532 "retryable": false
533 }
534 })
535}
536
537fn transaction_failure_payload(
538 connection: &str,
539 results: &[JsonValue],
540 failed_step: usize,
541 started: Instant,
542 code: &str,
543 message: &str,
544) -> JsonValue {
545 json!({
546 "ok": false,
547 "connection": connection,
548 "committed": false,
549 "failed_step": failed_step,
550 "results": results,
551 "elapsed_ms": started.elapsed().as_millis() as u64,
552 "error": {
553 "kind": "query_error",
554 "code": code,
555 "message": message,
556 "retryable": false,
557 }
558 })
559}
560
561fn ensure_any_drivers() {
562 SQLX_ANY_DRIVERS.call_once(sqlx::any::install_default_drivers);
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::fs;
    use std::path::PathBuf;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Owns a temporary SQLite database file.
    ///
    /// Dropping it deletes the file plus any journal side files SQLite may leave
    /// behind, so cleanup happens even when an assertion fails part-way through.
    struct TempDb {
        path: PathBuf,
    }

    impl Drop for TempDb {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.path);
            for suffix in &["-journal", "-wal", "-shm"] {
                let mut side = self.path.clone().into_os_string();
                side.push(suffix);
                let _ = fs::remove_file(side);
            }
        }
    }

    /// Returns a `sqlite://…?mode=rwc` URL for a fresh file plus its cleanup guard.
    fn sqlite_temp_connection(tag: &str) -> (String, TempDb) {
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock")
            .as_nanos();
        // The process id keeps concurrently running test binaries from colliding.
        let path = std::env::temp_dir().join(format!(
            "grapheme-sql-{tag}-{}-{ts}.db",
            std::process::id()
        ));
        (
            format!("sqlite://{}?mode=rwc", path.display()),
            TempDb { path },
        )
    }

    fn ok_flag(out: &JsonValue) -> Option<bool> {
        out.get("ok").and_then(|v| v.as_bool())
    }

    fn error_code(out: &JsonValue) -> Option<&str> {
        out.get("error")
            .and_then(|v| v.get("code"))
            .and_then(|v| v.as_str())
    }

    fn error_message(out: &JsonValue) -> &str {
        out.get("error")
            .and_then(|v| v.get("message"))
            .and_then(|v| v.as_str())
            .unwrap_or_default()
    }

    #[test]
    fn health_accepts_direct_sqlite_url_connection() {
        let out = health(&json!({ "connection": "sqlite::memory:" }));
        assert_eq!(ok_flag(&out), Some(true));
    }

    #[test]
    fn health_reports_unresolved_connection_id() {
        let out = health(&json!({ "connection": "missing_health_conn" }));
        assert_eq!(ok_flag(&out), Some(false));
        assert_eq!(error_code(&out), Some("sql_connection_unresolved"));
    }

    #[test]
    fn query_returns_rows_for_basic_select() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1 as ok"
        }));
        assert_eq!(ok_flag(&out), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(1));
    }

    #[test]
    fn execute_reports_rows_affected() {
        let out = execute(&json!({
            "connection": "sqlite::memory:",
            "sql": "create table if not exists t (id integer)"
        }));
        assert_eq!(ok_flag(&out), Some(true));
        assert!(out.get("rows_affected").and_then(|v| v.as_u64()).is_some());
    }

    #[test]
    fn query_reports_unresolved_connection_id() {
        let out = query(&json!({
            "connection": "missing_conn",
            "sql": "select 1"
        }));
        assert_eq!(error_code(&out), Some("sql_connection_unresolved"));
    }

    #[test]
    fn query_supports_scalar_positional_params() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select ?1 as n, ?2 as t, ?3 as b, ?4 as z",
            "params": [42, "hello", true, null]
        }));

        assert_eq!(ok_flag(&out), Some(true));
        let rows = out
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("rows should be present");
        assert_eq!(rows.len(), 1);

        let row = rows.first().and_then(|v| v.as_object()).expect("row object");
        assert_eq!(row.get("n").and_then(|v| v.as_i64()), Some(42));
        assert_eq!(row.get("t").and_then(|v| v.as_str()), Some("hello"));
        // SQLite has no boolean storage class, so `true` may come back as 1.
        let b = row.get("b").cloned().unwrap_or(JsonValue::Null);
        assert!(matches!(b, JsonValue::Bool(true) | JsonValue::Number(_)));
        if let JsonValue::Number(n) = b {
            assert_eq!(n.as_i64(), Some(1));
        }
        assert_eq!(row.get("z"), Some(&JsonValue::Null));
    }

    #[test]
    fn execute_supports_positional_params() {
        let out = execute(&json!({
            "connection": "sqlite::memory:",
            "sql": "select ?1",
            "params": [7]
        }));

        assert_eq!(ok_flag(&out), Some(true));
    }

    #[test]
    fn query_rejects_non_array_params() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1",
            "params": {"a": 1}
        }));

        assert_eq!(error_code(&out), Some("sql_params_invalid"));
    }

    #[test]
    fn query_enforces_max_rows_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1 as n union all select 2 as n",
            "max_rows": 1
        }));

        assert_eq!(ok_flag(&out), Some(false));
        assert_eq!(error_code(&out), Some("sql_query_failed"));
        assert!(error_message(&out).contains("exceeds max_rows"));
    }

    #[test]
    fn query_enforces_max_payload_bytes_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 'abcdefghijklmnopqrstuvwxyz' as payload",
            "max_payload_bytes": 8
        }));

        assert_eq!(ok_flag(&out), Some(false));
        assert_eq!(error_code(&out), Some("sql_query_failed"));
        assert!(error_message(&out).contains("exceeds max_payload_bytes"));
    }

    #[test]
    fn query_handles_high_row_count_when_within_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "with recursive cnt(x) as (select 1 union all select x + 1 from cnt where x < 128) select x from cnt",
            "max_rows": 128
        }));

        assert_eq!(ok_flag(&out), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(128));
    }

    #[test]
    fn query_handles_high_payload_near_limit_boundary() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "with recursive cnt(x) as (select 1 union all select x + 1 from cnt where x < 64) select x, 'aaaaaaaaaaaaaaaa' as payload from cnt",
            "max_rows": 64,
            "max_payload_bytes": 4096
        }));

        assert_eq!(ok_flag(&out), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(64));
    }

    #[test]
    fn transaction_runs_execute_and_query_steps() {
        let out = transaction(&json!({
            "connection": "sqlite::memory:",
            "steps": [
                {
                    "sql": "create table if not exists t (id integer, label text)",
                    "mode": "execute"
                },
                {
                    "sql": "insert into t (id, label) values (?1, ?2)",
                    "mode": "execute",
                    "params": [1, "ok"]
                },
                {
                    "sql": "select label from t where id = ?1",
                    "mode": "query",
                    "params": [1]
                }
            ]
        }));

        assert_eq!(ok_flag(&out), Some(true));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(true));
        let results = out
            .get("results")
            .and_then(|v| v.as_array())
            .expect("results array");
        assert_eq!(results.len(), 3);
        let query_result_rows = results[2]
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("query rows");
        assert_eq!(query_result_rows.len(), 1);
    }

    #[test]
    fn transaction_rejects_unknown_step_mode() {
        let out = transaction(&json!({
            "connection": "sqlite::memory:",
            "steps": [{ "sql": "select 1", "mode": "delete" }]
        }));

        assert_eq!(ok_flag(&out), Some(false));
        assert_eq!(error_code(&out), Some("sql_transaction_invalid"));
    }

    #[test]
    fn transaction_rejects_empty_steps() {
        let out = transaction(&json!({
            "connection": "sqlite::memory:",
            "steps": []
        }));

        assert_eq!(error_code(&out), Some("sql_transaction_invalid"));
    }

    #[test]
    fn transaction_rolls_back_on_step_failure() {
        let out = transaction(&json!({
            "connection": "sqlite::memory:",
            "steps": [
                {
                    "sql": "create table if not exists t (id integer)",
                    "mode": "execute"
                },
                {
                    "sql": "this is invalid sql",
                    "mode": "execute"
                }
            ]
        }));

        assert_eq!(ok_flag(&out), Some(false));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(out.get("failed_step").and_then(|v| v.as_u64()), Some(1));
        assert_eq!(error_code(&out), Some("sql_transaction_step_failed"));
    }

    #[test]
    fn transaction_rollback_is_deterministic_for_persisted_connection() {
        // `_db` (not `_`) keeps the guard alive until the end of the test.
        let (connection, _db) = sqlite_temp_connection("rollback-deterministic");

        let setup = execute(&json!({
            "connection": connection,
            "sql": "create table if not exists t (id integer, label text)"
        }));
        assert_eq!(ok_flag(&setup), Some(true));

        let out = transaction(&json!({
            "connection": connection,
            "steps": [
                {
                    "sql": "insert into t (id, label) values (?1, ?2)",
                    "mode": "execute",
                    "params": [1, "should_rollback"]
                },
                {
                    "sql": "this is invalid sql",
                    "mode": "execute"
                }
            ]
        }));

        assert_eq!(ok_flag(&out), Some(false));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(false));

        let verify = query(&json!({
            "connection": connection,
            "sql": "select count(*) as count from t"
        }));
        assert_eq!(ok_flag(&verify), Some(true));
        let rows = verify
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("rows array");
        let count = rows
            .first()
            .and_then(|v| v.get("count"))
            .and_then(|v| v.as_i64());
        assert_eq!(count, Some(0));
    }
}