Skip to main content

lutra_runner_duckdb/
lib.rs

1//! DuckDB Lutra runner
2
3mod case;
4mod params;
5mod schema;
6
7pub use lutra_runner::{Run, RunSync};
8
9use lutra_bin::{ir, rr};
10use std::{collections::HashMap, path, sync::Arc};
11use thiserror::Error;
12
13/// DuckDB runner for executing Lutra programs
14///
15/// Uses the synchronous duckdb crate. DuckDB operations are CPU-bound
16/// and don't involve actual async I/O, making this suitable for RunSync.
17pub struct Runner {
18    conn: duckdb::Connection,
19}
20
21impl Runner {
22    pub fn new(
23        conn: duckdb::Connection,
24        file_system: Option<path::PathBuf>,
25    ) -> Result<Self, Error> {
26        if let Some(fs_path) = file_system {
27            let set_fs_access = format!("SET file_search_path = '{}'", fs_path.display());
28            conn.execute(&set_fs_access, [])?;
29        }
30        Ok(Self { conn })
31    }
32
33    /// Open a file-based DuckDB database
34    pub fn open(path: &str, file_system: Option<path::PathBuf>) -> Result<Self, Error> {
35        let conn = duckdb::Connection::open(path)?;
36        Self::new(conn, file_system)
37    }
38
39    /// Create an in-memory DuckDB database
40    pub fn in_memory(file_system: Option<path::PathBuf>) -> Result<Self, Error> {
41        let conn = duckdb::Connection::open_in_memory()?;
42        Self::new(conn, file_system)
43    }
44}
45
46#[derive(Clone)]
47pub struct PreparedProgram {
48    program: Arc<rr::SqlProgram>,
49}
50
51impl lutra_runner::RunSync for Runner {
52    type Error = Error;
53    type Prepared = PreparedProgram;
54
55    fn prepare_sync(&mut self, program: rr::Program) -> Result<Self::Prepared, Self::Error> {
56        // Accept both SqlDuckDB (preferred) and SqlPg (backward compatibility)
57        let program = program
58            .into_sql_duck_db()
59            .map_err(|_| Error::UnsupportedFormat)?;
60
61        // Don't prepare a statement, because we cannot cache it for later anyway.
62        // That's because statement borrows connection, which means we cannot use it for other queries.
63
64        Ok(PreparedProgram {
65            program: Arc::from(program),
66        })
67    }
68
69    fn execute_sync(
70        &mut self,
71        handle: &Self::Prepared,
72        input: &[u8],
73    ) -> Result<Vec<u8>, Self::Error> {
74        let ctx = Context::new(&handle.program.defs);
75
76        // Convert input to SQL params
77        let args = params::to_sql(input, &handle.program.input_ty, &ctx)?;
78
79        // Execute query and get Arrow RecordBatches
80        let mut stmt = self.conn.prepare(&handle.program.sql)?;
81        let arrow = stmt.query_arrow(args.as_params())?;
82        let batches: Vec<_> = arrow.collect();
83
84        // Convert Arrow to Lutra format
85        let output =
86            lutra_arrow::arrow_to_lutra(batches, &handle.program.output_ty, &handle.program.defs)
87                .map_err(|e| Error::ArrowConversion(e.to_string()))?;
88
89        Ok(output.to_vec())
90    }
91
92    fn get_interface_sync(&mut self) -> Result<String, Self::Error> {
93        schema::pull_interface(self)
94    }
95}
96
97#[derive(Error, Debug)]
98pub enum Error {
99    #[error("bad result: {}", .0)]
100    BadDatabaseResponse(&'static str),
101    #[error("duckdb: {}", .0)]
102    DuckDB(#[from] duckdb::Error),
103    #[error("unsupported program format")]
104    UnsupportedFormat,
105    #[error("unsupported data type: {}", .0)]
106    UnsupportedDataType(&'static str),
107    #[error("arrow conversion: {}", .0)]
108    ArrowConversion(String),
109}
110
111pub(crate) struct Context<'a> {
112    pub types: HashMap<&'a ir::Path, &'a ir::Ty>,
113}
114
115impl<'a> Context<'a> {
116    pub fn new(ty_defs: &'a [ir::TyDef]) -> Self {
117        Context {
118            types: ty_defs.iter().map(|def| (&def.name, &def.ty)).collect(),
119        }
120    }
121
122    pub fn get_ty_mat(&self, ty: &'a ir::Ty) -> &'a ir::Ty {
123        match &ty.kind {
124            ir::TyKind::Ident(path) => self.types.get(path).unwrap(),
125            _ => ty,
126        }
127    }
128
129    /// Checks if an enum is an "option" enum. Must match [lutra_compiler::sql::utils::is_option].
130    fn is_option(&self, variants: &[ir::TyEnumVariant]) -> bool {
131        if variants.len() != 2 || !variants[0].ty.is_unit() {
132            return false;
133        }
134        let some_ty = self.get_ty_mat(&variants[1].ty);
135        some_ty.kind.is_primitive() || some_ty.kind.is_array()
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: an in-memory database can be created without a search path.
    #[test]
    fn test_in_memory() {
        let _runner = Runner::in_memory(None).unwrap();
    }

    /// Smoke test: a file-backed database can be created on disk.
    #[test]
    fn test_file_based() {
        // The process id keeps concurrent test runs from sharing one file.
        let file_name = format!("test_lutra_duckdb_{}.duckdb", std::process::id());
        let temp = std::env::temp_dir().join(file_name);

        // Drop the runner (closing the database) before deleting its file, and
        // clean up before asserting so a failed open leaves nothing behind.
        let outcome = Runner::open(temp.to_str().unwrap(), None).map(drop);
        std::fs::remove_file(&temp).ok();
        outcome.unwrap();
    }
}
156}