lutra_runner_duckdb/
lib.rs1mod case;
4mod params;
5mod schema;
6
7pub use lutra_runner::{Run, RunSync};
8
9use lutra_bin::{ir, rr};
10use std::{collections::HashMap, path, sync::Arc};
11use thiserror::Error;
12
13pub struct Runner {
18 conn: duckdb::Connection,
19}
20
21impl Runner {
22 pub fn new(
23 conn: duckdb::Connection,
24 file_system: Option<path::PathBuf>,
25 ) -> Result<Self, Error> {
26 if let Some(fs_path) = file_system {
27 let set_fs_access = format!("SET file_search_path = '{}'", fs_path.display());
28 conn.execute(&set_fs_access, [])?;
29 }
30 Ok(Self { conn })
31 }
32
33 pub fn open(path: &str, file_system: Option<path::PathBuf>) -> Result<Self, Error> {
35 let conn = duckdb::Connection::open(path)?;
36 Self::new(conn, file_system)
37 }
38
39 pub fn in_memory(file_system: Option<path::PathBuf>) -> Result<Self, Error> {
41 let conn = duckdb::Connection::open_in_memory()?;
42 Self::new(conn, file_system)
43 }
44}
45
46#[derive(Clone)]
47pub struct PreparedProgram {
48 program: Arc<rr::SqlProgram>,
49}
50
51impl lutra_runner::RunSync for Runner {
52 type Error = Error;
53 type Prepared = PreparedProgram;
54
55 fn prepare_sync(&mut self, program: rr::Program) -> Result<Self::Prepared, Self::Error> {
56 let program = program
58 .into_sql_duck_db()
59 .map_err(|_| Error::UnsupportedFormat)?;
60
61 Ok(PreparedProgram {
65 program: Arc::from(program),
66 })
67 }
68
69 fn execute_sync(
70 &mut self,
71 handle: &Self::Prepared,
72 input: &[u8],
73 ) -> Result<Vec<u8>, Self::Error> {
74 let ctx = Context::new(&handle.program.defs);
75
76 let args = params::to_sql(input, &handle.program.input_ty, &ctx)?;
78
79 let mut stmt = self.conn.prepare(&handle.program.sql)?;
81 let arrow = stmt.query_arrow(args.as_params())?;
82 let batches: Vec<_> = arrow.collect();
83
84 let output =
86 lutra_arrow::arrow_to_lutra(batches, &handle.program.output_ty, &handle.program.defs)
87 .map_err(|e| Error::ArrowConversion(e.to_string()))?;
88
89 Ok(output.to_vec())
90 }
91
92 fn get_interface_sync(&mut self) -> Result<String, Self::Error> {
93 schema::pull_interface(self)
94 }
95}
96
97#[derive(Error, Debug)]
98pub enum Error {
99 #[error("bad result: {}", .0)]
100 BadDatabaseResponse(&'static str),
101 #[error("duckdb: {}", .0)]
102 DuckDB(#[from] duckdb::Error),
103 #[error("unsupported program format")]
104 UnsupportedFormat,
105 #[error("unsupported data type: {}", .0)]
106 UnsupportedDataType(&'static str),
107 #[error("arrow conversion: {}", .0)]
108 ArrowConversion(String),
109}
110
111pub(crate) struct Context<'a> {
112 pub types: HashMap<&'a ir::Path, &'a ir::Ty>,
113}
114
115impl<'a> Context<'a> {
116 pub fn new(ty_defs: &'a [ir::TyDef]) -> Self {
117 Context {
118 types: ty_defs.iter().map(|def| (&def.name, &def.ty)).collect(),
119 }
120 }
121
122 pub fn get_ty_mat(&self, ty: &'a ir::Ty) -> &'a ir::Ty {
123 match &ty.kind {
124 ir::TyKind::Ident(path) => self.types.get(path).unwrap(),
125 _ => ty,
126 }
127 }
128
129 fn is_option(&self, variants: &[ir::TyEnumVariant]) -> bool {
131 if variants.len() != 2 || !variants[0].ty.is_unit() {
132 return false;
133 }
134 let some_ty = self.get_ty_mat(&variants[1].ty);
135 some_ty.kind.is_primitive() || some_ty.kind.is_array()
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_in_memory() {
        let _runner = Runner::in_memory(None).unwrap();
    }

    #[test]
    fn test_file_based() {
        // A per-process name keeps concurrent test runs from sharing one database file.
        let temp = std::env::temp_dir().join(format!(
            "test_lutra_duckdb_{}.duckdb",
            std::process::id()
        ));
        // NOTE(review): DuckDB may leave a write-ahead log next to the database file;
        // clean it up as well.
        let wal = temp.with_extension("duckdb.wal");

        // Best-effort removal of leftovers from an earlier aborted run.
        std::fs::remove_file(&temp).ok();
        std::fs::remove_file(&wal).ok();

        let runner = Runner::open(temp.to_str().unwrap(), None).unwrap();
        // Close the connection before deleting the files; removing a file that is
        // still open fails on some platforms (e.g. Windows).
        drop(runner);

        std::fs::remove_file(&temp).ok();
        std::fs::remove_file(&wal).ok();
    }
}