1use std::{cell::RefCell, collections::HashMap};
2
3use aiscript_arena::{Gc, GcRefLock, RefLock};
4use sqlx::{Column, Postgres, Row, TypeInfo, ValueRef};
5
6use tokio::runtime::Handle;
7
8use crate::{
9 NativeFn, Value, VmError,
10 module::ModuleKind,
11 object::{Class, Instance, Object},
12 vm::{Context, State},
13};
14
thread_local! {
    /// Transaction opened by `begin_transaction`, if one is currently open.
    ///
    /// Kept per thread. `begin_transaction` refuses to open a second transaction while
    /// this is `Some`, and `commit`/`rollback` take the value out, leaving `None`.
    static ACTIVE_TRANSACTION: RefCell<Option<sqlx::Transaction<'static, Postgres>>> = const { RefCell::new(None) };
}
18
19pub fn create_pg_module(ctx: Context) -> ModuleKind {
21 let name = ctx.intern(b"std.db.pg");
22
23 let exports = [
24 ("query", Value::NativeFunction(NativeFn(pg_query))),
25 ("query_as", Value::NativeFunction(NativeFn(pg_query_as))),
26 (
27 "begin_transaction",
28 Value::NativeFunction(NativeFn(transaction::begin_transaction)),
29 ),
30 ]
31 .into_iter()
32 .map(|(name, f)| (ctx.intern_static(name), f))
33 .collect();
34
35 ModuleKind::Native { name, exports }
36}
37
38fn column_to_value<'gc>(
39 ctx: Context<'gc>,
40 row: &sqlx::postgres::PgRow,
41 i: usize,
42 type_info: &sqlx::postgres::PgTypeInfo,
43) -> Result<Value<'gc>, VmError> {
44 if row.try_get_raw(i).map_or(true, |v| v.is_null()) {
46 return Ok(Value::Nil);
47 }
48
49 let value = match type_info.name() {
50 "INT2" | "SMALLINT" => row.try_get::<i16, _>(i).map(|v| Value::Number(v as f64)),
52 "INT4" | "INTEGER" => row.try_get::<i32, _>(i).map(|v| Value::Number(v as f64)),
53 "INT8" | "BIGINT" => row.try_get::<i64, _>(i).map(|v| Value::Number(v as f64)),
54
55 "SERIAL2" | "SMALLSERIAL" => row.try_get::<i16, _>(i).map(|v| Value::Number(v as f64)),
57 "SERIAL4" | "SERIAL" => row.try_get::<i32, _>(i).map(|v| Value::Number(v as f64)),
58 "SERIAL8" | "BIGSERIAL" => row.try_get::<i64, _>(i).map(|v| Value::Number(v as f64)),
59
60 "FLOAT4" | "REAL" => row.try_get::<f32, _>(i).map(|v| Value::Number(v as f64)),
62 "FLOAT8" | "DOUBLE PRECISION" => row.try_get::<f64, _>(i).map(Value::Number),
63
64 "VARCHAR" | "CHAR" | "TEXT" | "BPCHAR" | "NAME" => row
71 .try_get::<String, _>(i)
72 .map(|v| Value::String(ctx.intern(v.as_bytes()))),
73
74 "BOOL" | "BOOLEAN" => row.try_get::<bool, _>(i).map(Value::Boolean),
76
77 "UUID" => row
79 .try_get::<sqlx::types::Uuid, _>(i)
80 .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
81
82 "DATE" => row
84 .try_get::<sqlx::types::chrono::NaiveDate, _>(i)
85 .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
86 "TIME" => row
87 .try_get::<sqlx::types::chrono::NaiveTime, _>(i)
88 .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
89 "TIMESTAMP" => row
90 .try_get::<sqlx::types::chrono::NaiveDateTime, _>(i)
91 .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
92 "TIMESTAMPTZ" => row
93 .try_get::<sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>, _>(i)
94 .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
95
96 "JSON" | "JSONB" => row
98 .try_get::<serde_json::Value, _>(i)
99 .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
100
101 t if t.starts_with("_") => {
103 match &t[1..] {
104 "INT2" | "SMALLINT" => row.try_get::<Vec<i16>, _>(i).map(|v| {
106 Value::array(
107 &ctx,
108 v.into_iter().map(|n| Value::Number(n as f64)).collect(),
109 )
110 }),
111 "INT4" | "INTEGER" => row.try_get::<Vec<i32>, _>(i).map(|v| {
112 Value::array(
113 &ctx,
114 v.into_iter().map(|n| Value::Number(n as f64)).collect(),
115 )
116 }),
117 "INT8" | "BIGINT" => row.try_get::<Vec<i64>, _>(i).map(|v| {
118 Value::array(
119 &ctx,
120 v.into_iter().map(|n| Value::Number(n as f64)).collect(),
121 )
122 }),
123
124 "FLOAT4" | "REAL" => row.try_get::<Vec<f32>, _>(i).map(|v| {
126 Value::array(
127 &ctx,
128 v.into_iter().map(|n| Value::Number(n as f64)).collect(),
129 )
130 }),
131 "FLOAT8" | "DOUBLE PRECISION" => row
132 .try_get::<Vec<f64>, _>(i)
133 .map(|v| Value::array(&ctx, v.into_iter().map(Value::Number).collect())),
134
135 "VARCHAR" | "TEXT" => row.try_get::<Vec<String>, _>(i).map(|v| {
137 Value::array(
138 &ctx,
139 v.into_iter()
140 .map(|s| Value::String(ctx.intern(s.as_bytes())))
141 .collect(),
142 )
143 }),
144
145 "BOOL" | "BOOLEAN" => row
147 .try_get::<Vec<bool>, _>(i)
148 .map(|v| Value::array(&ctx, v.into_iter().map(Value::Boolean).collect())),
149
150 _ => row.try_get::<Vec<String>, _>(i).map(|v| {
152 Value::array(
153 &ctx,
154 v.into_iter()
155 .map(|s| Value::String(ctx.intern(s.as_bytes())))
156 .collect(),
157 )
158 }),
159 }
160 }
161
162 "BYTEA" => row
164 .try_get::<Vec<u8>, _>(i)
165 .map(|v| Value::String(ctx.intern(&v))),
166
167 _ => row
169 .try_get::<String, _>(i)
170 .map(|v| Value::String(ctx.intern(v.as_bytes()))),
171 }
172 .unwrap_or_else(|_| {
173 row.try_get::<String, _>(i)
175 .map(|v| Value::String(ctx.intern(v.as_bytes())))
176 .unwrap_or(Value::Nil)
177 });
178 Ok(value)
179}
180
181fn row_to_object<'gc>(ctx: Context<'gc>, row: &sqlx::postgres::PgRow) -> Value<'gc> {
183 let mut obj = Object::default();
184
185 for (i, column) in row.columns().iter().enumerate() {
186 let column_name = ctx.intern(column.name().as_bytes());
187 let value = column_to_value(ctx, row, i, column.type_info()).unwrap_or(Value::Nil);
188 obj.fields.insert(column_name, value);
189 }
190
191 Value::Object(Gc::new(&ctx, RefLock::new(obj)))
192}
193
194fn execute_query<'a, E>(
195 executor: E,
196 query: &str,
197 bindings: Vec<Value<'_>>,
198) -> Result<Vec<sqlx::postgres::PgRow>, VmError>
199where
200 E: sqlx::Executor<'a, Database = sqlx::Postgres>,
201{
202 Handle::current()
203 .block_on(async {
204 let mut query_builder = sqlx::query(query);
205
206 for value in bindings {
208 match value {
209 Value::Number(n) => {
210 query_builder = query_builder.bind(n);
211 }
212 Value::String(s) => {
213 let s_str = s.to_str().unwrap();
214 if let Ok(uuid) = sqlx::types::Uuid::parse_str(s_str) {
216 query_builder = query_builder.bind(uuid);
217 } else if let Ok(date) =
218 sqlx::types::chrono::NaiveDate::parse_from_str(s_str, "%Y-%m-%d")
219 {
220 query_builder = query_builder.bind(date);
221 } else if let Ok(datetime) =
222 sqlx::types::chrono::NaiveDateTime::parse_from_str(
223 s_str,
224 "%Y-%m-%dT%H:%M:%S",
225 )
226 {
227 query_builder = query_builder.bind(datetime);
228 } else {
229 query_builder = query_builder.bind(s_str);
230 }
231 }
232 Value::Boolean(b) => {
233 query_builder = query_builder.bind(b);
234 }
235 Value::Nil => {
236 query_builder = query_builder.bind(Option::<String>::None);
237 }
238 Value::List(arr) => {
239 let arr = &arr.borrow().data;
240 if let Some(first) = arr.first() {
241 match first {
242 Value::Number(_) => {
243 let nums: Vec<f64> = arr
244 .iter()
245 .filter_map(|v| match v {
246 Value::Number(n) => Some(*n),
247 _ => None,
248 })
249 .collect();
250 query_builder = query_builder.bind(nums);
251 }
252 Value::String(_) => {
253 let strings: Vec<String> = arr
254 .iter()
255 .filter_map(|v| match v {
256 Value::String(s) => {
257 Some(s.to_str().unwrap().to_string())
258 }
259 _ => None,
260 })
261 .collect();
262 query_builder = query_builder.bind(strings);
263 }
264 Value::Boolean(_) => {
265 let bools: Vec<bool> = arr
266 .iter()
267 .filter_map(|v| match v {
268 Value::Boolean(b) => Some(*b),
269 _ => None,
270 })
271 .collect();
272 query_builder = query_builder.bind(bools);
273 }
274 _ => {
275 return Err(sqlx::Error::Protocol(
276 "Unsupported array element type".into(),
277 ));
278 }
279 }
280 } else {
281 query_builder = query_builder.bind::<Vec<String>>(vec![]);
282 }
283 }
284 _ => return Err(sqlx::Error::Protocol("Unsupported parameter type".into())),
285 }
286 }
287
288 query_builder.fetch_all(executor).await
289 })
290 .map_err(|e| VmError::RuntimeError(format!("Database query error: {}", e)))
291}
292
293fn execute_typed_query<'gc, 'a, E>(
294 ctx: Context<'gc>,
295 executor: E,
296 class: GcRefLock<'gc, Class<'gc>>,
297 query: &str,
298 bindings: Vec<Value<'gc>>,
299) -> Result<Value<'gc>, VmError>
300where
301 E: sqlx::Executor<'a, Database = sqlx::Postgres>,
302{
303 let rows = execute_query(executor, query, bindings)?;
305
306 let mut results = Vec::new();
313 for row in rows {
314 let mut instance = Instance::new(class);
316
317 for (i, column) in row.columns().iter().enumerate() {
319 let field_name = ctx.intern(column.name().as_bytes());
320 let value = column_to_value(ctx, &row, i, column.type_info())?;
321 instance.fields.insert(field_name, value);
322 }
323
324 results.push(Value::Instance(Gc::new(&ctx, RefLock::new(instance))));
325 }
326
327 Ok(Value::array(&ctx, results))
328}
329
330fn pg_query<'gc>(state: &mut State<'gc>, args: Vec<Value<'gc>>) -> Result<Value<'gc>, VmError> {
332 if args.is_empty() {
333 return Err(VmError::RuntimeError(
334 "query() requires at least a SQL query string.".into(),
335 ));
336 }
337
338 let sql = args[0].as_string()?;
339 let ctx = state.get_context();
340 let conn = state.pg_connection.as_ref().unwrap();
341 let rows = execute_query(
343 conn,
344 sql.to_str().unwrap(),
345 args.into_iter().skip(1).collect(),
346 )?;
347
348 let mut results = Vec::new();
350 for row in rows {
351 results.push(row_to_object(ctx, &row));
352 }
353
354 Ok(Value::array(state, results))
355}
356
357fn pg_query_as<'gc>(state: &mut State<'gc>, args: Vec<Value<'gc>>) -> Result<Value<'gc>, VmError> {
358 if args.len() < 2 {
359 return Err(VmError::RuntimeError(
360 "query_as() requires a class and SQL query string.".into(),
361 ));
362 }
363
364 let class = match args[0] {
366 Value::Class(class) => class,
367 _ => {
368 return Err(VmError::RuntimeError(
369 "First argument to query_as() must be a class.".into(),
370 ));
371 }
372 };
373
374 let sql = args[1].as_string()?;
375 let ctx = state.get_context();
376 let conn = state.pg_connection.as_ref().unwrap();
377
378 execute_typed_query(
379 ctx,
380 conn,
381 class,
382 sql.to_str().unwrap(),
383 args.into_iter().skip(2).collect(),
384 )
385}
386
387mod transaction {
388 use super::*;
389
390 fn create_transaction_class(ctx: Context) -> Gc<RefLock<Class>> {
391 let methods = [
392 (ctx.intern(b"query"), Value::NativeFunction(NativeFn(query))),
393 (
394 ctx.intern(b"query_as"),
395 Value::NativeFunction(NativeFn(query_as)),
396 ),
397 (
398 ctx.intern(b"commit"),
399 Value::NativeFunction(NativeFn(commit)),
400 ),
401 (
402 ctx.intern(b"rollback"),
403 Value::NativeFunction(NativeFn(rollback)),
404 ),
405 ]
406 .into_iter()
407 .collect();
408 Gc::new(
409 &ctx,
410 RefLock::new(Class {
411 name: ctx.intern(b"Transaction"),
412 methods,
413 static_methods: HashMap::default(),
414 }),
415 )
416 }
417
418 pub(super) fn begin_transaction<'gc>(
419 state: &mut State<'gc>,
420 _args: Vec<Value<'gc>>,
421 ) -> Result<Value<'gc>, VmError> {
422 let has_active = ACTIVE_TRANSACTION.with(|tx| tx.borrow().is_some());
424 if has_active {
425 return Err(VmError::RuntimeError("Transaction already active".into()));
426 }
427
428 let ctx = state.get_context();
429 let conn = state.pg_connection.as_ref().unwrap();
430 let tx = Handle::current()
431 .block_on(async move { conn.begin().await })
432 .map_err(|e| VmError::RuntimeError(format!("Failed to begin transaction: {}", e)))?;
433
434 ACTIVE_TRANSACTION.with(|cell| {
436 *cell.borrow_mut() = Some(tx);
437 });
438
439 let instance = Instance::new(create_transaction_class(ctx));
441 Ok(Value::Instance(Gc::new(&ctx, RefLock::new(instance))))
442 }
443
444 fn query<'gc>(state: &mut State<'gc>, args: Vec<Value<'gc>>) -> Result<Value<'gc>, VmError> {
445 if args.is_empty() {
446 return Err(VmError::RuntimeError(
447 "query() requires a SQL query string.".into(),
448 ));
449 }
450
451 let query = args[0].as_string()?;
452 let ctx = state.get_context();
453
454 let result = ACTIVE_TRANSACTION.with(|cell| {
456 if let Some(tx) = (*cell.borrow_mut()).as_mut() {
457 let rows = execute_query(
458 &mut **tx,
459 query.to_str().unwrap(),
460 args.into_iter().skip(1).collect(),
461 );
462 Some(rows)
463 } else {
464 None
465 }
466 });
467
468 match result {
469 Some(Ok(rows)) => {
470 let mut results = Vec::new();
472 for row in rows {
473 results.push(row_to_object(ctx, &row));
474 }
475 Ok(Value::array(&ctx, results))
476 }
477 Some(Err(e)) => Err(VmError::RuntimeError(format!("Database error: {e}"))),
478 None => Err(VmError::RuntimeError("No active transaction".into())),
479 }
480 }
481
482 fn query_as<'gc>(state: &mut State<'gc>, args: Vec<Value<'gc>>) -> Result<Value<'gc>, VmError> {
483 if args.len() < 2 {
484 return Err(VmError::RuntimeError(
485 "query_as() requires a class and SQL query string.".into(),
486 ));
487 }
488
489 let class = match args[0] {
491 Value::Class(class) => class,
492 _ => {
493 return Err(VmError::RuntimeError(
494 "First argument to query_as() must be a class.".into(),
495 ));
496 }
497 };
498
499 let query = args[1].as_string()?;
500 let ctx = state.get_context();
501
502 let result = ACTIVE_TRANSACTION.with(|cell| {
504 if let Some(tx) = (*cell.borrow_mut()).as_mut() {
505 let bindings = args.into_iter().skip(2).collect();
506 Some(execute_typed_query(
507 ctx,
508 &mut **tx,
509 class,
510 query.to_str().unwrap(),
511 bindings,
512 ))
513 } else {
514 None
515 }
516 });
517
518 match result {
519 Some(result) => result,
520 None => Err(VmError::RuntimeError("No active transaction".into())),
521 }
522 }
523
524 fn commit<'gc>(_state: &mut State<'gc>, _args: Vec<Value<'gc>>) -> Result<Value<'gc>, VmError> {
525 let result = ACTIVE_TRANSACTION.with(|cell| {
526 cell.borrow_mut()
527 .take() .map(|tx| Handle::current().block_on(async { tx.commit().await }))
529 });
530
531 match result {
532 Some(Ok(())) => Ok(Value::Nil),
533 Some(Err(e)) => Err(VmError::RuntimeError(format!(
534 "Failed to commit transaction: {e}"
535 ))),
536 None => Err(VmError::RuntimeError("No active transaction".into())),
537 }
538 }
539
540 fn rollback<'gc>(
541 _state: &mut State<'gc>,
542 _args: Vec<Value<'gc>>,
543 ) -> Result<Value<'gc>, VmError> {
544 let result = ACTIVE_TRANSACTION.with(|cell| {
545 cell.borrow_mut()
546 .take() .map(|tx| Handle::current().block_on(async { tx.rollback().await }))
548 });
549
550 match result {
551 Some(Ok(())) => Ok(Value::Nil),
552 Some(Err(e)) => Err(VmError::RuntimeError(format!(
553 "Failed to rollback transaction: {e}"
554 ))),
555 None => Err(VmError::RuntimeError("No active transaction".into())),
556 }
557 }
558}