aiscript_vm/stdlib/db/
pg.rs

1use std::{cell::RefCell, collections::HashMap};
2
3use aiscript_arena::{Gc, GcRefLock, RefLock};
4use sqlx::{Column, Postgres, Row, TypeInfo, ValueRef};
5
6use tokio::runtime::Handle;
7
8use crate::{
9    NativeFn, Value, VmError,
10    module::ModuleKind,
11    object::{Class, Instance, Object},
12    vm::{Context, State},
13};
14
// Per-thread slot holding the transaction opened by `begin_transaction`.
// `None` means no transaction is active on this thread. The `Transaction`
// methods (`query`, `query_as`, `commit`, `rollback`) all operate on this
// slot, so at most one transaction can be open per thread at a time.
15thread_local! {
16    static ACTIVE_TRANSACTION: RefCell<Option<sqlx::Transaction<'static, Postgres>>> = const { RefCell::new(None) };
17}
18
19// Create the PostgreSQL module with native functions
20pub fn create_pg_module(ctx: Context) -> ModuleKind {
21    let name = ctx.intern(b"std.db.pg");
22
23    let exports = [
24        ("query", Value::NativeFunction(NativeFn(pg_query))),
25        ("query_as", Value::NativeFunction(NativeFn(pg_query_as))),
26        (
27            "begin_transaction",
28            Value::NativeFunction(NativeFn(transaction::begin_transaction)),
29        ),
30    ]
31    .into_iter()
32    .map(|(name, f)| (ctx.intern_static(name), f))
33    .collect();
34
35    ModuleKind::Native { name, exports }
36}
37
38fn column_to_value<'gc>(
39    ctx: Context<'gc>,
40    row: &sqlx::postgres::PgRow,
41    i: usize,
42    type_info: &sqlx::postgres::PgTypeInfo,
43) -> Result<Value<'gc>, VmError> {
44    // Handle NULL values first
45    if row.try_get_raw(i).map_or(true, |v| v.is_null()) {
46        return Ok(Value::Nil);
47    }
48
49    let value = match type_info.name() {
50        // Integer types
51        "INT2" | "SMALLINT" => row.try_get::<i16, _>(i).map(|v| Value::Number(v as f64)),
52        "INT4" | "INTEGER" => row.try_get::<i32, _>(i).map(|v| Value::Number(v as f64)),
53        "INT8" | "BIGINT" => row.try_get::<i64, _>(i).map(|v| Value::Number(v as f64)),
54
55        // Serial types (same as integer types)
56        "SERIAL2" | "SMALLSERIAL" => row.try_get::<i16, _>(i).map(|v| Value::Number(v as f64)),
57        "SERIAL4" | "SERIAL" => row.try_get::<i32, _>(i).map(|v| Value::Number(v as f64)),
58        "SERIAL8" | "BIGSERIAL" => row.try_get::<i64, _>(i).map(|v| Value::Number(v as f64)),
59
60        // Floating-point types
61        "FLOAT4" | "REAL" => row.try_get::<f32, _>(i).map(|v| Value::Number(v as f64)),
62        "FLOAT8" | "DOUBLE PRECISION" => row.try_get::<f64, _>(i).map(Value::Number),
63
64        // Decimal/numeric types
65        // "NUMERIC" | "DECIMAL" => row
66        //     .try_get::<sqlx::types::Decimal, _>(i)
67        //     .map(|v| Value::Number(v.to_string().parse::<f64>().unwrap_or(0.0))),
68
69        // Character types
70        "VARCHAR" | "CHAR" | "TEXT" | "BPCHAR" | "NAME" => row
71            .try_get::<String, _>(i)
72            .map(|v| Value::String(ctx.intern(v.as_bytes()))),
73
74        // Boolean type
75        "BOOL" | "BOOLEAN" => row.try_get::<bool, _>(i).map(Value::Boolean),
76
77        // UUID type
78        "UUID" => row
79            .try_get::<sqlx::types::Uuid, _>(i)
80            .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
81
82        // Date/Time types
83        "DATE" => row
84            .try_get::<sqlx::types::chrono::NaiveDate, _>(i)
85            .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
86        "TIME" => row
87            .try_get::<sqlx::types::chrono::NaiveTime, _>(i)
88            .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
89        "TIMESTAMP" => row
90            .try_get::<sqlx::types::chrono::NaiveDateTime, _>(i)
91            .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
92        "TIMESTAMPTZ" => row
93            .try_get::<sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>, _>(i)
94            .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
95
96        // JSON types
97        "JSON" | "JSONB" => row
98            .try_get::<serde_json::Value, _>(i)
99            .map(|v| Value::String(ctx.intern(v.to_string().as_bytes()))),
100
101        // Array types
102        t if t.starts_with("_") => {
103            match &t[1..] {
104                // Integer arrays
105                "INT2" | "SMALLINT" => row.try_get::<Vec<i16>, _>(i).map(|v| {
106                    Value::array(
107                        &ctx,
108                        v.into_iter().map(|n| Value::Number(n as f64)).collect(),
109                    )
110                }),
111                "INT4" | "INTEGER" => row.try_get::<Vec<i32>, _>(i).map(|v| {
112                    Value::array(
113                        &ctx,
114                        v.into_iter().map(|n| Value::Number(n as f64)).collect(),
115                    )
116                }),
117                "INT8" | "BIGINT" => row.try_get::<Vec<i64>, _>(i).map(|v| {
118                    Value::array(
119                        &ctx,
120                        v.into_iter().map(|n| Value::Number(n as f64)).collect(),
121                    )
122                }),
123
124                // Float arrays
125                "FLOAT4" | "REAL" => row.try_get::<Vec<f32>, _>(i).map(|v| {
126                    Value::array(
127                        &ctx,
128                        v.into_iter().map(|n| Value::Number(n as f64)).collect(),
129                    )
130                }),
131                "FLOAT8" | "DOUBLE PRECISION" => row
132                    .try_get::<Vec<f64>, _>(i)
133                    .map(|v| Value::array(&ctx, v.into_iter().map(Value::Number).collect())),
134
135                // Text arrays
136                "VARCHAR" | "TEXT" => row.try_get::<Vec<String>, _>(i).map(|v| {
137                    Value::array(
138                        &ctx,
139                        v.into_iter()
140                            .map(|s| Value::String(ctx.intern(s.as_bytes())))
141                            .collect(),
142                    )
143                }),
144
145                // Boolean arrays
146                "BOOL" | "BOOLEAN" => row
147                    .try_get::<Vec<bool>, _>(i)
148                    .map(|v| Value::array(&ctx, v.into_iter().map(Value::Boolean).collect())),
149
150                // Default to string representation for unknown array types
151                _ => row.try_get::<Vec<String>, _>(i).map(|v| {
152                    Value::array(
153                        &ctx,
154                        v.into_iter()
155                            .map(|s| Value::String(ctx.intern(s.as_bytes())))
156                            .collect(),
157                    )
158                }),
159            }
160        }
161
162        // Binary data
163        "BYTEA" => row
164            .try_get::<Vec<u8>, _>(i)
165            .map(|v| Value::String(ctx.intern(&v))),
166
167        // Default to string for unknown types
168        _ => row
169            .try_get::<String, _>(i)
170            .map(|v| Value::String(ctx.intern(v.as_bytes()))),
171    }
172    .unwrap_or_else(|_| {
173        // If conversion fails, try to get as string
174        row.try_get::<String, _>(i)
175            .map(|v| Value::String(ctx.intern(v.as_bytes())))
176            .unwrap_or(Value::Nil)
177    });
178    Ok(value)
179}
180
181// Convert database row to AIScript object
182fn row_to_object<'gc>(ctx: Context<'gc>, row: &sqlx::postgres::PgRow) -> Value<'gc> {
183    let mut obj = Object::default();
184
185    for (i, column) in row.columns().iter().enumerate() {
186        let column_name = ctx.intern(column.name().as_bytes());
187        let value = column_to_value(ctx, row, i, column.type_info()).unwrap_or(Value::Nil);
188        obj.fields.insert(column_name, value);
189    }
190
191    Value::Object(Gc::new(&ctx, RefLock::new(obj)))
192}
193
194fn execute_query<'a, E>(
195    executor: E,
196    query: &str,
197    bindings: Vec<Value<'_>>,
198) -> Result<Vec<sqlx::postgres::PgRow>, VmError>
199where
200    E: sqlx::Executor<'a, Database = sqlx::Postgres>,
201{
202    Handle::current()
203        .block_on(async {
204            let mut query_builder = sqlx::query(query);
205
206            // Bind parameters
207            for value in bindings {
208                match value {
209                    Value::Number(n) => {
210                        query_builder = query_builder.bind(n);
211                    }
212                    Value::String(s) => {
213                        let s_str = s.to_str().unwrap();
214                        // Try to parse special types from string
215                        if let Ok(uuid) = sqlx::types::Uuid::parse_str(s_str) {
216                            query_builder = query_builder.bind(uuid);
217                        } else if let Ok(date) =
218                            sqlx::types::chrono::NaiveDate::parse_from_str(s_str, "%Y-%m-%d")
219                        {
220                            query_builder = query_builder.bind(date);
221                        } else if let Ok(datetime) =
222                            sqlx::types::chrono::NaiveDateTime::parse_from_str(
223                                s_str,
224                                "%Y-%m-%dT%H:%M:%S",
225                            )
226                        {
227                            query_builder = query_builder.bind(datetime);
228                        } else {
229                            query_builder = query_builder.bind(s_str);
230                        }
231                    }
232                    Value::Boolean(b) => {
233                        query_builder = query_builder.bind(b);
234                    }
235                    Value::Nil => {
236                        query_builder = query_builder.bind(Option::<String>::None);
237                    }
238                    Value::List(arr) => {
239                        let arr = &arr.borrow().data;
240                        if let Some(first) = arr.first() {
241                            match first {
242                                Value::Number(_) => {
243                                    let nums: Vec<f64> = arr
244                                        .iter()
245                                        .filter_map(|v| match v {
246                                            Value::Number(n) => Some(*n),
247                                            _ => None,
248                                        })
249                                        .collect();
250                                    query_builder = query_builder.bind(nums);
251                                }
252                                Value::String(_) => {
253                                    let strings: Vec<String> = arr
254                                        .iter()
255                                        .filter_map(|v| match v {
256                                            Value::String(s) => {
257                                                Some(s.to_str().unwrap().to_string())
258                                            }
259                                            _ => None,
260                                        })
261                                        .collect();
262                                    query_builder = query_builder.bind(strings);
263                                }
264                                Value::Boolean(_) => {
265                                    let bools: Vec<bool> = arr
266                                        .iter()
267                                        .filter_map(|v| match v {
268                                            Value::Boolean(b) => Some(*b),
269                                            _ => None,
270                                        })
271                                        .collect();
272                                    query_builder = query_builder.bind(bools);
273                                }
274                                _ => {
275                                    return Err(sqlx::Error::Protocol(
276                                        "Unsupported array element type".into(),
277                                    ));
278                                }
279                            }
280                        } else {
281                            query_builder = query_builder.bind::<Vec<String>>(vec![]);
282                        }
283                    }
284                    _ => return Err(sqlx::Error::Protocol("Unsupported parameter type".into())),
285                }
286            }
287
288            query_builder.fetch_all(executor).await
289        })
290        .map_err(|e| VmError::RuntimeError(format!("Database query error: {}", e)))
291}
292
293fn execute_typed_query<'gc, 'a, E>(
294    ctx: Context<'gc>,
295    executor: E,
296    class: GcRefLock<'gc, Class<'gc>>,
297    query: &str,
298    bindings: Vec<Value<'gc>>,
299) -> Result<Value<'gc>, VmError>
300where
301    E: sqlx::Executor<'a, Database = sqlx::Postgres>,
302{
303    // Execute the query
304    let rows = execute_query(executor, query, bindings)?;
305
306    // TODO: Validate first row's columns against class fields?
307    // if let Some(first_row) = rows.first() {
308    //     validate_query_columns(ctx, class, first_row)?;
309    // }
310
311    // Convert rows to class instances
312    let mut results = Vec::new();
313    for row in rows {
314        // Create new instance
315        let mut instance = Instance::new(class);
316
317        // Set fields from row data
318        for (i, column) in row.columns().iter().enumerate() {
319            let field_name = ctx.intern(column.name().as_bytes());
320            let value = column_to_value(ctx, &row, i, column.type_info())?;
321            instance.fields.insert(field_name, value);
322        }
323
324        results.push(Value::Instance(Gc::new(&ctx, RefLock::new(instance))));
325    }
326
327    Ok(Value::array(&ctx, results))
328}
329
330// Native function implementations
331fn pg_query<'gc>(state: &mut State<'gc>, args: Vec<Value<'gc>>) -> Result<Value<'gc>, VmError> {
332    if args.is_empty() {
333        return Err(VmError::RuntimeError(
334            "query() requires at least a SQL query string.".into(),
335        ));
336    }
337
338    let sql = args[0].as_string()?;
339    let ctx = state.get_context();
340    let conn = state.pg_connection.as_ref().unwrap();
341    // Execute query in runtime
342    let rows = execute_query(
343        conn,
344        sql.to_str().unwrap(),
345        args.into_iter().skip(1).collect(),
346    )?;
347
348    // Convert rows to array of objects
349    let mut results = Vec::new();
350    for row in rows {
351        results.push(row_to_object(ctx, &row));
352    }
353
354    Ok(Value::array(state, results))
355}
356
357fn pg_query_as<'gc>(state: &mut State<'gc>, args: Vec<Value<'gc>>) -> Result<Value<'gc>, VmError> {
358    if args.len() < 2 {
359        return Err(VmError::RuntimeError(
360            "query_as() requires a class and SQL query string.".into(),
361        ));
362    }
363
364    // First argument should be a class
365    let class = match args[0] {
366        Value::Class(class) => class,
367        _ => {
368            return Err(VmError::RuntimeError(
369                "First argument to query_as() must be a class.".into(),
370            ));
371        }
372    };
373
374    let sql = args[1].as_string()?;
375    let ctx = state.get_context();
376    let conn = state.pg_connection.as_ref().unwrap();
377
378    execute_typed_query(
379        ctx,
380        conn,
381        class,
382        sql.to_str().unwrap(),
383        args.into_iter().skip(2).collect(),
384    )
385}
386
387mod transaction {
388    use super::*;
389
390    fn create_transaction_class(ctx: Context) -> Gc<RefLock<Class>> {
391        let methods = [
392            (ctx.intern(b"query"), Value::NativeFunction(NativeFn(query))),
393            (
394                ctx.intern(b"query_as"),
395                Value::NativeFunction(NativeFn(query_as)),
396            ),
397            (
398                ctx.intern(b"commit"),
399                Value::NativeFunction(NativeFn(commit)),
400            ),
401            (
402                ctx.intern(b"rollback"),
403                Value::NativeFunction(NativeFn(rollback)),
404            ),
405        ]
406        .into_iter()
407        .collect();
408        Gc::new(
409            &ctx,
410            RefLock::new(Class {
411                name: ctx.intern(b"Transaction"),
412                methods,
413                static_methods: HashMap::default(),
414            }),
415        )
416    }
417
418    pub(super) fn begin_transaction<'gc>(
419        state: &mut State<'gc>,
420        _args: Vec<Value<'gc>>,
421    ) -> Result<Value<'gc>, VmError> {
422        // Check if there's already an active transaction
423        let has_active = ACTIVE_TRANSACTION.with(|tx| tx.borrow().is_some());
424        if has_active {
425            return Err(VmError::RuntimeError("Transaction already active".into()));
426        }
427
428        let ctx = state.get_context();
429        let conn = state.pg_connection.as_ref().unwrap();
430        let tx = Handle::current()
431            .block_on(async move { conn.begin().await })
432            .map_err(|e| VmError::RuntimeError(format!("Failed to begin transaction: {}", e)))?;
433
434        // Store transaction in thread local
435        ACTIVE_TRANSACTION.with(|cell| {
436            *cell.borrow_mut() = Some(tx);
437        });
438
439        // Create and return new instance
440        let instance = Instance::new(create_transaction_class(ctx));
441        Ok(Value::Instance(Gc::new(&ctx, RefLock::new(instance))))
442    }
443
444    fn query<'gc>(state: &mut State<'gc>, args: Vec<Value<'gc>>) -> Result<Value<'gc>, VmError> {
445        if args.is_empty() {
446            return Err(VmError::RuntimeError(
447                "query() requires a SQL query string.".into(),
448            ));
449        }
450
451        let query = args[0].as_string()?;
452        let ctx = state.get_context();
453
454        // Execute query with the active transaction
455        let result = ACTIVE_TRANSACTION.with(|cell| {
456            if let Some(tx) = (*cell.borrow_mut()).as_mut() {
457                let rows = execute_query(
458                    &mut **tx,
459                    query.to_str().unwrap(),
460                    args.into_iter().skip(1).collect(),
461                );
462                Some(rows)
463            } else {
464                None
465            }
466        });
467
468        match result {
469            Some(Ok(rows)) => {
470                // Convert rows to array of objects
471                let mut results = Vec::new();
472                for row in rows {
473                    results.push(row_to_object(ctx, &row));
474                }
475                Ok(Value::array(&ctx, results))
476            }
477            Some(Err(e)) => Err(VmError::RuntimeError(format!("Database error: {e}"))),
478            None => Err(VmError::RuntimeError("No active transaction".into())),
479        }
480    }
481
482    fn query_as<'gc>(state: &mut State<'gc>, args: Vec<Value<'gc>>) -> Result<Value<'gc>, VmError> {
483        if args.len() < 2 {
484            return Err(VmError::RuntimeError(
485                "query_as() requires a class and SQL query string.".into(),
486            ));
487        }
488
489        // First argument should be a class
490        let class = match args[0] {
491            Value::Class(class) => class,
492            _ => {
493                return Err(VmError::RuntimeError(
494                    "First argument to query_as() must be a class.".into(),
495                ));
496            }
497        };
498
499        let query = args[1].as_string()?;
500        let ctx = state.get_context();
501
502        // Execute query using the active transaction
503        let result = ACTIVE_TRANSACTION.with(|cell| {
504            if let Some(tx) = (*cell.borrow_mut()).as_mut() {
505                let bindings = args.into_iter().skip(2).collect();
506                Some(execute_typed_query(
507                    ctx,
508                    &mut **tx,
509                    class,
510                    query.to_str().unwrap(),
511                    bindings,
512                ))
513            } else {
514                None
515            }
516        });
517
518        match result {
519            Some(result) => result,
520            None => Err(VmError::RuntimeError("No active transaction".into())),
521        }
522    }
523
524    fn commit<'gc>(_state: &mut State<'gc>, _args: Vec<Value<'gc>>) -> Result<Value<'gc>, VmError> {
525        let result = ACTIVE_TRANSACTION.with(|cell| {
526            cell.borrow_mut()
527                .take() // Set ACTIVE_TRANSACTION to None
528                .map(|tx| Handle::current().block_on(async { tx.commit().await }))
529        });
530
531        match result {
532            Some(Ok(())) => Ok(Value::Nil),
533            Some(Err(e)) => Err(VmError::RuntimeError(format!(
534                "Failed to commit transaction: {e}"
535            ))),
536            None => Err(VmError::RuntimeError("No active transaction".into())),
537        }
538    }
539
540    fn rollback<'gc>(
541        _state: &mut State<'gc>,
542        _args: Vec<Value<'gc>>,
543    ) -> Result<Value<'gc>, VmError> {
544        let result = ACTIVE_TRANSACTION.with(|cell| {
545            cell.borrow_mut()
546                .take() // Set ACTIVE_TRANSACTION to None
547                .map(|tx| Handle::current().block_on(async { tx.rollback().await }))
548        });
549
550        match result {
551            Some(Ok(())) => Ok(Value::Nil),
552            Some(Err(e)) => Err(VmError::RuntimeError(format!(
553                "Failed to rollback transaction: {e}"
554            ))),
555            None => Err(VmError::RuntimeError("No active transaction".into())),
556        }
557    }
558}