1use serde_json::{json, Map, Value as JsonValue};
2use sqlx::{AnyPool, Column, Row};
3use std::env;
4use std::sync::Once;
5use std::time::Instant;
6
/// One-shot guard for registering sqlx's built-in `Any` drivers.
///
/// `ensure_any_drivers` performs the registration through this guard, so it runs at most
/// once per process regardless of how many entry points are called.
static SQLX_ANY_DRIVERS: std::sync::Once = std::sync::Once::new();
8
9pub fn query(args: &JsonValue) -> JsonValue {
10 let connection = match required_string(args, "connection") {
11 Ok(v) => v,
12 Err(e) => return error_payload("validation_error", "missing_connection", &e),
13 };
14 let sql = match required_string(args, "sql") {
15 Ok(v) => v,
16 Err(e) => return error_payload("validation_error", "missing_sql", &e),
17 };
18
19 let params = match optional_params(args) {
20 Ok(v) => v,
21 Err(e) => return error_payload("validation_error", "sql_params_invalid", &e),
22 };
23 let max_rows = optional_u64(args, "max_rows");
24 let max_payload_bytes = optional_u64(args, "max_payload_bytes");
25
26 let resolved = match resolve_connection(&connection) {
27 Ok(v) => v,
28 Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
29 };
30
31 let started = Instant::now();
32 let rt = match tokio::runtime::Builder::new_current_thread()
33 .enable_all()
34 .build()
35 {
36 Ok(rt) => rt,
37 Err(e) => {
38 return error_payload("runtime_error", "tokio_runtime_init_failed", &e.to_string())
39 }
40 };
41
42 let result: Result<JsonValue, String> = rt.block_on(async {
43 ensure_any_drivers();
44 let pool = AnyPool::connect(&resolved)
45 .await
46 .map_err(|e| format!("connect failed: {e}"))?;
47
48 let query = match bind_params(sqlx::query(&sql), ¶ms) {
49 Ok(query) => query,
50 Err(e) => return Err(e),
51 };
52
53 let rows = query
54 .fetch_all(&pool)
55 .await
56 .map_err(|e| format!("query failed: {e}"))?;
57
58 if let Some(max_rows) = max_rows {
59 if rows.len() > max_rows as usize {
60 return Err(format!(
61 "row count {} exceeds max_rows {}",
62 rows.len(),
63 max_rows
64 ));
65 }
66 }
67
68 let mut out_rows = Vec::with_capacity(rows.len());
69 for row in rows {
70 let mut obj = Map::new();
71 for (idx, col) in row.columns().iter().enumerate() {
72 let key = col.name().to_string();
73 let value = decode_row_value(&row, idx);
74 obj.insert(key, value);
75 }
76 out_rows.push(JsonValue::Object(obj));
77 }
78
79 if let Some(max_payload_bytes) = max_payload_bytes {
80 let payload_bytes = serde_json::to_vec(&out_rows)
81 .map(|v| v.len())
82 .map_err(|e| format!("serialize query rows failed: {e}"))?;
83 if payload_bytes > max_payload_bytes as usize {
84 return Err(format!(
85 "payload size {} exceeds max_payload_bytes {}",
86 payload_bytes, max_payload_bytes
87 ));
88 }
89 }
90
91 Ok(json!({
92 "ok": true,
93 "connection": connection,
94 "row_count": out_rows.len(),
95 "rows": out_rows,
96 "elapsed_ms": started.elapsed().as_millis() as u64,
97 }))
98 });
99
100 match result {
101 Ok(v) => v,
102 Err(e) => error_payload("query_error", "sql_query_failed", &e),
103 }
104}
105
106pub fn execute(args: &JsonValue) -> JsonValue {
107 let connection = match required_string(args, "connection") {
108 Ok(v) => v,
109 Err(e) => return error_payload("validation_error", "missing_connection", &e),
110 };
111 let sql = match required_string(args, "sql") {
112 Ok(v) => v,
113 Err(e) => return error_payload("validation_error", "missing_sql", &e),
114 };
115
116 let params = match optional_params(args) {
117 Ok(v) => v,
118 Err(e) => return error_payload("validation_error", "sql_params_invalid", &e),
119 };
120
121 let resolved = match resolve_connection(&connection) {
122 Ok(v) => v,
123 Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
124 };
125
126 let started = Instant::now();
127 let rt = match tokio::runtime::Builder::new_current_thread()
128 .enable_all()
129 .build()
130 {
131 Ok(rt) => rt,
132 Err(e) => {
133 return error_payload("runtime_error", "tokio_runtime_init_failed", &e.to_string())
134 }
135 };
136
137 let result: Result<JsonValue, String> = rt.block_on(async {
138 ensure_any_drivers();
139 let pool = AnyPool::connect(&resolved)
140 .await
141 .map_err(|e| format!("connect failed: {e}"))?;
142
143 let query = match bind_params(sqlx::query(&sql), ¶ms) {
144 Ok(query) => query,
145 Err(e) => return Err(e),
146 };
147
148 let outcome = query
149 .execute(&pool)
150 .await
151 .map_err(|e| format!("execute failed: {e}"))?;
152
153 Ok(json!({
154 "ok": true,
155 "connection": connection,
156 "rows_affected": outcome.rows_affected(),
157 "elapsed_ms": started.elapsed().as_millis() as u64,
158 }))
159 });
160
161 match result {
162 Ok(v) => v,
163 Err(e) => error_payload("query_error", "sql_execute_failed", &e),
164 }
165}
166
167pub fn health(args: &JsonValue) -> JsonValue {
168 let connection = match required_string(args, "connection") {
169 Ok(v) => v,
170 Err(e) => return error_payload("validation_error", "missing_connection", &e),
171 };
172
173 let resolved = match resolve_connection(&connection) {
174 Ok(v) => v,
175 Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
176 };
177
178 let started = Instant::now();
179 let rt = match tokio::runtime::Builder::new_current_thread()
180 .enable_all()
181 .build()
182 {
183 Ok(rt) => rt,
184 Err(e) => {
185 return error_payload("runtime_error", "tokio_runtime_init_failed", &e.to_string())
186 }
187 };
188
189 let result: Result<JsonValue, String> = rt.block_on(async {
190 ensure_any_drivers();
191 let pool = AnyPool::connect(&resolved)
192 .await
193 .map_err(|e| format!("connect failed: {e}"))?;
194
195 sqlx::query("select 1")
196 .execute(&pool)
197 .await
198 .map_err(|e| format!("health check failed: {e}"))?;
199
200 Ok(json!({
201 "ok": true,
202 "connection": connection,
203 "latency_ms": started.elapsed().as_millis() as u64,
204 }))
205 });
206
207 match result {
208 Ok(v) => v,
209 Err(e) => error_payload("connection_error", "sql_health_failed", &e),
210 }
211}
212
213pub fn transaction(args: &JsonValue) -> JsonValue {
214 let connection = match required_string(args, "connection") {
215 Ok(v) => v,
216 Err(e) => return error_payload("validation_error", "missing_connection", &e),
217 };
218
219 let steps = match parse_transaction_steps(args) {
220 Ok(v) => v,
221 Err(e) => return error_payload("validation_error", "sql_transaction_invalid", &e),
222 };
223
224 let resolved = match resolve_connection(&connection) {
225 Ok(v) => v,
226 Err(e) => return error_payload("connection_error", "sql_connection_unresolved", &e),
227 };
228
229 let started = Instant::now();
230 let rt = match tokio::runtime::Builder::new_current_thread()
231 .enable_all()
232 .build()
233 {
234 Ok(rt) => rt,
235 Err(e) => {
236 return error_payload("runtime_error", "tokio_runtime_init_failed", &e.to_string())
237 }
238 };
239
240 let result: Result<JsonValue, JsonValue> = rt.block_on(async {
241 ensure_any_drivers();
242 let pool = AnyPool::connect(&resolved).await.map_err(|e| {
243 error_payload(
244 "connection_error",
245 "sql_connect_failed",
246 &format!("connect failed: {e}"),
247 )
248 })?;
249
250 let mut tx = pool.begin().await.map_err(|e| {
251 error_payload(
252 "query_error",
253 "sql_transaction_begin_failed",
254 &format!("begin failed: {e}"),
255 )
256 })?;
257
258 let mut results = Vec::with_capacity(steps.len());
259
260 for (idx, step) in steps.iter().enumerate() {
261 let query = bind_params(sqlx::query(&step.sql), &step.params).map_err(|e| {
262 transaction_failure_payload(
263 &connection,
264 &results,
265 idx,
266 started,
267 "sql_params_invalid",
268 &e,
269 )
270 })?;
271
272 if step.mode == TransactionStepMode::Query {
273 let rows = query.fetch_all(&mut *tx).await.map_err(|e| {
274 transaction_failure_payload(
275 &connection,
276 &results,
277 idx,
278 started,
279 "sql_transaction_step_failed",
280 &format!("query step failed: {e}"),
281 )
282 })?;
283
284 let mut out_rows = Vec::with_capacity(rows.len());
285 for row in rows {
286 let mut obj = Map::new();
287 for (col_idx, col) in row.columns().iter().enumerate() {
288 obj.insert(col.name().to_string(), decode_row_value(&row, col_idx));
289 }
290 out_rows.push(JsonValue::Object(obj));
291 }
292
293 results.push(json!({
294 "mode": "query",
295 "row_count": out_rows.len(),
296 "rows": out_rows,
297 }));
298 } else {
299 let outcome = query.execute(&mut *tx).await.map_err(|e| {
300 transaction_failure_payload(
301 &connection,
302 &results,
303 idx,
304 started,
305 "sql_transaction_step_failed",
306 &format!("execute step failed: {e}"),
307 )
308 })?;
309
310 results.push(json!({
311 "mode": "execute",
312 "rows_affected": outcome.rows_affected(),
313 }));
314 }
315 }
316
317 tx.commit().await.map_err(|e| {
318 error_payload(
319 "query_error",
320 "sql_transaction_commit_failed",
321 &format!("commit failed: {e}"),
322 )
323 })?;
324
325 Ok(json!({
326 "ok": true,
327 "connection": connection,
328 "committed": true,
329 "results": results,
330 "elapsed_ms": started.elapsed().as_millis() as u64,
331 }))
332 });
333
334 match result {
335 Ok(v) => v,
336 Err(v) => v,
337 }
338}
339
340fn required_string(args: &JsonValue, key: &str) -> Result<String, String> {
341 args.get(key)
342 .and_then(|v| v.as_str())
343 .map(ToOwned::to_owned)
344 .or_else(|| {
345 args.get("__input")
346 .and_then(|v| v.as_object())
347 .and_then(|obj| obj.get(key))
348 .and_then(|v| v.as_str())
349 .map(ToOwned::to_owned)
350 })
351 .ok_or_else(|| format!("missing required '{}'", key))
352}
353
354fn optional_params(args: &JsonValue) -> Result<Vec<JsonValue>, String> {
355 let candidate = args.get("params").cloned().or_else(|| {
356 args.get("__input")
357 .and_then(|v| v.as_object())
358 .and_then(|obj| obj.get("params").cloned())
359 });
360
361 match candidate {
362 None => Ok(Vec::new()),
363 Some(JsonValue::Array(items)) => Ok(items),
364 Some(_) => Err("params must be an array".to_string()),
365 }
366}
367
368fn optional_u64(args: &JsonValue, key: &str) -> Option<u64> {
369 args.get(key)
370 .and_then(|v| {
371 v.as_u64()
372 .or_else(|| v.as_str().and_then(|s| s.parse::<u64>().ok()))
373 })
374 .or_else(|| {
375 args.get("__input")
376 .and_then(|v| v.as_object())
377 .and_then(|obj| obj.get(key))
378 .and_then(|v| {
379 v.as_u64()
380 .or_else(|| v.as_str().and_then(|s| s.parse::<u64>().ok()))
381 })
382 })
383}
384
/// How a transaction step is run: `Query` fetches and returns rows, while `Execute`
/// reports only the affected-row count.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TransactionStepMode {
    /// Fetch every row the statement produces (`"mode": "query"`).
    Query,
    /// Run the statement for its effect (`"mode": "execute"`, also used when `mode` is absent).
    Execute,
}
390
391#[derive(Debug, Clone)]
392struct TransactionStep {
393 mode: TransactionStepMode,
394 sql: String,
395 params: Vec<JsonValue>,
396}
397
398fn parse_transaction_steps(args: &JsonValue) -> Result<Vec<TransactionStep>, String> {
399 let raw_steps = args
400 .get("steps")
401 .cloned()
402 .or_else(|| {
403 args.get("__input")
404 .and_then(|v| v.as_object())
405 .and_then(|obj| obj.get("steps").cloned())
406 })
407 .ok_or_else(|| "missing required 'steps'".to_string())?;
408
409 let list = raw_steps
410 .as_array()
411 .ok_or_else(|| "steps must be an array".to_string())?;
412
413 if list.is_empty() {
414 return Err("steps must contain at least one entry".to_string());
415 }
416
417 let mut out = Vec::with_capacity(list.len());
418 for (idx, value) in list.iter().enumerate() {
419 let obj = value
420 .as_object()
421 .ok_or_else(|| format!("step[{idx}] must be an object"))?;
422
423 let sql = obj
424 .get("sql")
425 .and_then(|v| v.as_str())
426 .map(ToOwned::to_owned)
427 .ok_or_else(|| format!("step[{idx}] missing required 'sql'"))?;
428
429 let mode = match obj.get("mode").and_then(|v| v.as_str()) {
430 Some("query") => TransactionStepMode::Query,
431 Some("execute") | None => TransactionStepMode::Execute,
432 Some(other) => {
433 return Err(format!(
434 "step[{idx}] has invalid mode '{}', expected query|execute",
435 other
436 ))
437 }
438 };
439
440 let params = match obj.get("params") {
441 None => Vec::new(),
442 Some(JsonValue::Array(items)) => items.clone(),
443 Some(_) => return Err(format!("step[{idx}] params must be an array")),
444 };
445
446 out.push(TransactionStep { mode, sql, params });
447 }
448
449 Ok(out)
450}
451
452fn bind_params<'q>(
453 mut query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
454 params: &[JsonValue],
455) -> Result<sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>, String> {
456 for value in params {
457 query = match value {
458 JsonValue::Null => query.bind(Option::<String>::None),
459 JsonValue::Bool(v) => query.bind(*v),
460 JsonValue::Number(n) => {
461 if let Some(v) = n.as_i64() {
462 query.bind(v)
463 } else if let Some(v) = n.as_u64() {
464 if let Ok(as_i64) = i64::try_from(v) {
465 query.bind(as_i64)
466 } else {
467 query.bind(v as f64)
468 }
469 } else if let Some(v) = n.as_f64() {
470 query.bind(v)
471 } else {
472 return Err("unsupported numeric param representation".to_string());
473 }
474 }
475 JsonValue::String(v) => query.bind(v.clone()),
476 JsonValue::Array(_) | JsonValue::Object(_) => {
477 return Err("only scalar params are supported (null/bool/number/string)".to_string())
478 }
479 };
480 }
481
482 Ok(query)
483}
484
485fn resolve_connection(connection: &str) -> Result<String, String> {
486 if connection.contains("://") || connection.starts_with("sqlite:") {
487 return Ok(connection.to_string());
488 }
489
490 let env_key = format!(
491 "GRAPHEME_SQL_CONNECTION_{}",
492 connection
493 .chars()
494 .map(|c| if c.is_ascii_alphanumeric() {
495 c.to_ascii_uppercase()
496 } else {
497 '_'
498 })
499 .collect::<String>()
500 );
501
502 if let Ok(url) = env::var(&env_key) {
503 if !url.trim().is_empty() {
504 return Ok(url);
505 }
506 }
507
508 if let Ok(map_raw) = env::var("GRAPHEME_SQL_CONNECTIONS") {
509 if let Ok(map_json) = serde_json::from_str::<JsonValue>(&map_raw) {
510 if let Some(url) = map_json
511 .get(connection)
512 .and_then(|v| v.as_str())
513 .map(ToOwned::to_owned)
514 {
515 return Ok(url);
516 }
517 }
518 }
519
520 Err(format!(
521 "connection '{}' is unresolved; set {} or GRAPHEME_SQL_CONNECTIONS",
522 connection, env_key
523 ))
524}
525
526fn decode_row_value(row: &sqlx::any::AnyRow, idx: usize) -> JsonValue {
527 if let Ok(v) = row.try_get::<Option<i64>, _>(idx) {
528 return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
529 }
530 if let Ok(v) = row.try_get::<Option<f64>, _>(idx) {
531 return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
532 }
533 if let Ok(v) = row.try_get::<Option<bool>, _>(idx) {
534 return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
535 }
536 if let Ok(v) = row.try_get::<Option<String>, _>(idx) {
537 return v.map(JsonValue::from).unwrap_or(JsonValue::Null);
538 }
539
540 JsonValue::Null
541}
542
543fn error_payload(kind: &str, code: &str, message: &str) -> JsonValue {
544 json!({
545 "ok": false,
546 "error": {
547 "kind": kind,
548 "code": code,
549 "message": message,
550 "retryable": false
551 }
552 })
553}
554
555fn transaction_failure_payload(
556 connection: &str,
557 results: &[JsonValue],
558 failed_step: usize,
559 started: Instant,
560 code: &str,
561 message: &str,
562) -> JsonValue {
563 json!({
564 "ok": false,
565 "connection": connection,
566 "committed": false,
567 "failed_step": failed_step,
568 "results": results,
569 "elapsed_ms": started.elapsed().as_millis() as u64,
570 "error": {
571 "kind": "query_error",
572 "code": code,
573 "message": message,
574 "retryable": false,
575 }
576 })
577}
578
579fn ensure_any_drivers() {
580 SQLX_ANY_DRIVERS.call_once(sqlx::any::install_default_drivers);
581}
582
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::fs;
    use std::path::PathBuf;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// An on-disk SQLite database in the temp dir whose files are removed on drop.
    ///
    /// Cleanup therefore also runs when an assertion panics mid-test. The old helper only
    /// deleted the file when every assertion had passed.
    struct TempSqliteDb {
        url: String,
        path: PathBuf,
    }

    impl TempSqliteDb {
        fn new(tag: &str) -> Self {
            let ts = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("system clock")
                .as_nanos();
            let path = std::env::temp_dir().join(format!("grapheme-sql-{tag}-{ts}.db"));
            let url = format!("sqlite://{}?mode=rwc", path.display());
            Self { url, path }
        }
    }

    impl Drop for TempSqliteDb {
        fn drop(&mut self) {
            // Best-effort cleanup: any of these files may legitimately not exist.
            let _ = fs::remove_file(&self.path);
            for suffix in ["-journal", "-wal", "-shm"] {
                let mut sidecar = self.path.clone().into_os_string();
                sidecar.push(suffix);
                let _ = fs::remove_file(sidecar);
            }
        }
    }

    /// Returns `error.code` from a response payload, if present.
    fn error_code(out: &JsonValue) -> Option<&str> {
        out.get("error")
            .and_then(|v| v.get("code"))
            .and_then(|v| v.as_str())
    }

    /// Returns `error.message` from a response payload, or "" if absent.
    fn error_message(out: &JsonValue) -> &str {
        out.get("error")
            .and_then(|v| v.get("message"))
            .and_then(|v| v.as_str())
            .unwrap_or_default()
    }

    fn is_ok(out: &JsonValue) -> Option<bool> {
        out.get("ok").and_then(|v| v.as_bool())
    }

    #[test]
    fn health_accepts_direct_sqlite_url_connection() {
        let out = health(&json!({ "connection": "sqlite::memory:" }));
        assert_eq!(is_ok(&out), Some(true));
    }

    #[test]
    fn query_returns_rows_for_basic_select() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1 as ok"
        }));
        assert_eq!(is_ok(&out), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(1));
    }

    #[test]
    fn execute_reports_rows_affected() {
        let out = execute(&json!({
            "connection": "sqlite::memory:",
            "sql": "create table if not exists t (id integer)"
        }));
        assert_eq!(is_ok(&out), Some(true));
        assert!(out.get("rows_affected").and_then(|v| v.as_u64()).is_some());
    }

    #[test]
    fn query_reports_unresolved_connection_id() {
        let out = query(&json!({
            "connection": "missing_conn",
            "sql": "select 1"
        }));
        assert_eq!(error_code(&out), Some("sql_connection_unresolved"));
    }

    #[test]
    fn query_supports_scalar_positional_params() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select ?1 as n, ?2 as t, ?3 as b, ?4 as z",
            "params": [42, "hello", true, null]
        }));

        assert_eq!(is_ok(&out), Some(true));
        let rows = out
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("rows should be present");
        assert_eq!(rows.len(), 1);

        let row = rows
            .first()
            .and_then(|v| v.as_object())
            .expect("row object");
        assert_eq!(row.get("n").and_then(|v| v.as_i64()), Some(42));
        assert_eq!(row.get("t").and_then(|v| v.as_str()), Some("hello"));
        // SQLite has no native bool, so `true` may come back as the integer 1.
        let b = row.get("b").cloned().unwrap_or(JsonValue::Null);
        assert!(matches!(b, JsonValue::Bool(true) | JsonValue::Number(_)));
        if let JsonValue::Number(n) = b {
            assert_eq!(n.as_i64(), Some(1));
        }
        assert_eq!(row.get("z"), Some(&JsonValue::Null));
    }

    #[test]
    fn execute_supports_positional_params() {
        let out = execute(&json!({
            "connection": "sqlite::memory:",
            "sql": "select ?1",
            "params": [7]
        }));

        assert_eq!(is_ok(&out), Some(true));
    }

    #[test]
    fn query_rejects_non_array_params() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1",
            "params": {"a": 1}
        }));

        assert_eq!(error_code(&out), Some("sql_params_invalid"));
    }

    #[test]
    fn query_enforces_max_rows_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 1 as n union all select 2 as n",
            "max_rows": 1
        }));

        assert_eq!(is_ok(&out), Some(false));
        assert_eq!(error_code(&out), Some("sql_query_failed"));
        assert!(error_message(&out).contains("exceeds max_rows"));
    }

    #[test]
    fn query_enforces_max_payload_bytes_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "select 'abcdefghijklmnopqrstuvwxyz' as payload",
            "max_payload_bytes": 8
        }));

        assert_eq!(is_ok(&out), Some(false));
        assert_eq!(error_code(&out), Some("sql_query_failed"));
        assert!(error_message(&out).contains("exceeds max_payload_bytes"));
    }

    #[test]
    fn query_handles_high_row_count_when_within_limit() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "with recursive cnt(x) as (select 1 union all select x + 1 from cnt where x < 128) select x from cnt",
            "max_rows": 128
        }));

        assert_eq!(is_ok(&out), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(128));
    }

    #[test]
    fn query_handles_high_payload_near_limit_boundary() {
        let out = query(&json!({
            "connection": "sqlite::memory:",
            "sql": "with recursive cnt(x) as (select 1 union all select x + 1 from cnt where x < 64) select x, 'aaaaaaaaaaaaaaaa' as payload from cnt",
            "max_rows": 64,
            "max_payload_bytes": 4096
        }));

        assert_eq!(is_ok(&out), Some(true));
        assert_eq!(out.get("row_count").and_then(|v| v.as_u64()), Some(64));
    }

    #[test]
    fn transaction_runs_execute_and_query_steps() {
        let out = transaction(&json!({
            "connection": "sqlite::memory:",
            "steps": [
                {
                    "sql": "create table if not exists t (id integer, label text)",
                    "mode": "execute"
                },
                {
                    "sql": "insert into t (id, label) values (?1, ?2)",
                    "mode": "execute",
                    "params": [1, "ok"]
                },
                {
                    "sql": "select label from t where id = ?1",
                    "mode": "query",
                    "params": [1]
                }
            ]
        }));

        assert_eq!(is_ok(&out), Some(true));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(true));
        let results = out
            .get("results")
            .and_then(|v| v.as_array())
            .expect("results array");
        assert_eq!(results.len(), 3);
        let query_result_rows = results[2]
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("query rows");
        assert_eq!(query_result_rows.len(), 1);
    }

    #[test]
    fn transaction_rolls_back_on_step_failure() {
        let out = transaction(&json!({
            "connection": "sqlite::memory:",
            "steps": [
                {
                    "sql": "create table if not exists t (id integer)",
                    "mode": "execute"
                },
                {
                    "sql": "this is invalid sql",
                    "mode": "execute"
                }
            ]
        }));

        assert_eq!(is_ok(&out), Some(false));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(false));
        assert_eq!(out.get("failed_step").and_then(|v| v.as_u64()), Some(1));
        assert_eq!(error_code(&out), Some("sql_transaction_step_failed"));
    }

    #[test]
    fn transaction_rollback_is_deterministic_for_persisted_connection() {
        // Dropped at the end of the test, or during unwinding if an assertion fails.
        let db = TempSqliteDb::new("rollback-deterministic");
        let connection = db.url.as_str();

        let setup = execute(&json!({
            "connection": connection,
            "sql": "create table if not exists t (id integer, label text)"
        }));
        assert_eq!(is_ok(&setup), Some(true));

        let out = transaction(&json!({
            "connection": connection,
            "steps": [
                {
                    "sql": "insert into t (id, label) values (?1, ?2)",
                    "mode": "execute",
                    "params": [1, "should_rollback"]
                },
                {
                    "sql": "this is invalid sql",
                    "mode": "execute"
                }
            ]
        }));

        assert_eq!(is_ok(&out), Some(false));
        assert_eq!(out.get("committed").and_then(|v| v.as_bool()), Some(false));

        let verify = query(&json!({
            "connection": connection,
            "sql": "select count(*) as count from t"
        }));
        assert_eq!(is_ok(&verify), Some(true));
        let rows = verify
            .get("rows")
            .and_then(|v| v.as_array())
            .expect("rows array");
        let count = rows
            .first()
            .and_then(|v| v.get("count"))
            .and_then(|v| v.as_i64());
        assert_eq!(count, Some(0));
    }
}