Skip to main content

citadel_sql/
prepared.rs

1//! Prepared statements: parse + compile once, execute many times with parameters.
2
3use std::sync::Arc;
4
5use rustc_hash::FxHashMap;
6
7use crate::connection::Connection;
8use crate::error::{Result, SqlError};
9use crate::executor::compile::RowSourceIter;
10use crate::executor::helpers::expr_display_name;
11use crate::executor::{self, CompiledPlan};
12use crate::parser::{QueryBody, SelectColumn, SelectQuery, SelectStmt, Statement};
13use crate::schema::SchemaManager;
14use crate::types::{ExecutionResult, QueryResult, Value};
15
16/// A prepared SQL statement bound to a `Connection`.
17pub struct PreparedStatement<'c, 'db> {
18    conn: &'c Connection<'db>,
19    sql: String,
20    ast: Arc<Statement>,
21    compiled: Option<Arc<dyn CompiledPlan>>,
22    schema_gen: u64,
23    param_count: usize,
24    columns: Vec<String>,
25    column_index: FxHashMap<String, usize>,
26    readonly: bool,
27    is_explain: bool,
28}
29
30struct Compiled {
31    ast: Arc<Statement>,
32    plan: Option<Arc<dyn CompiledPlan>>,
33    schema_gen: u64,
34    param_count: usize,
35    columns: Vec<String>,
36}
37
38impl<'c, 'db> PreparedStatement<'c, 'db> {
39    pub(crate) fn new(conn: &'c Connection<'db>, sql: &str) -> Result<Self> {
40        let c = compile_for_sql(conn, sql)?;
41        let readonly = matches!(*c.ast, Statement::Select(_) | Statement::Explain(_));
42        let is_explain = matches!(*c.ast, Statement::Explain(_));
43        let mut column_index =
44            FxHashMap::with_capacity_and_hasher(c.columns.len(), Default::default());
45        for (i, name) in c.columns.iter().enumerate() {
46            column_index.entry(name.clone()).or_insert(i);
47        }
48        Ok(Self {
49            conn,
50            sql: sql.to_string(),
51            ast: c.ast,
52            compiled: c.plan,
53            schema_gen: c.schema_gen,
54            param_count: c.param_count,
55            columns: c.columns,
56            column_index,
57            readonly,
58            is_explain,
59        })
60    }
61
62    /// The original SQL text.
63    pub fn sql(&self) -> &str {
64        &self.sql
65    }
66
67    /// Number of positional parameters (`$1`, `$2`, ...) this statement expects.
68    pub fn param_count(&self) -> usize {
69        self.param_count
70    }
71
72    /// Alias of [`Self::param_count`] matching rusqlite's name.
73    pub fn parameter_count(&self) -> usize {
74        self.param_count
75    }
76
77    /// Number of output columns. Zero for non-SELECT statements.
78    pub fn column_count(&self) -> usize {
79        self.columns.len()
80    }
81
82    /// Output column names in declaration order.
83    pub fn column_names(&self) -> &[String] {
84        &self.columns
85    }
86
87    /// Output column name at index `i`, if any.
88    pub fn column_name(&self, i: usize) -> Option<&str> {
89        self.columns.get(i).map(|s| s.as_str())
90    }
91
92    /// Position of the column named `name`, if present.
93    pub fn column_index(&self, name: &str) -> Option<usize> {
94        self.column_index.get(name).copied()
95    }
96
97    /// True if the statement is read-only (SELECT or EXPLAIN).
98    pub fn readonly(&self) -> bool {
99        self.readonly
100    }
101
102    /// True if the statement is an EXPLAIN.
103    pub fn is_explain(&self) -> bool {
104        self.is_explain
105    }
106
107    /// Execute the statement; returns rows affected (0 for SELECT/DDL).
108    pub fn execute(&self, params: &[Value]) -> Result<u64> {
109        match self.run(params)? {
110            ExecutionResult::RowsAffected(n) => Ok(n),
111            ExecutionResult::Query(_) | ExecutionResult::Ok => Ok(0),
112        }
113    }
114
115    /// Execute and return a stepping `Rows<'_>` iterator.
116    ///
117    /// Streams rows directly from the B+ tree for simple `SELECT [cols] FROM t`
118    /// shapes (no WHERE/ORDER BY/aggregate/join). Materializes internally for
119    /// everything else — same user-visible API either way.
120    ///
121    /// DML statements execute the mutation and yield an immediately-exhausted `Rows`
122    /// (matching rusqlite semantics).
123    pub fn query(&self, params: &[Value]) -> Result<Rows<'_>> {
124        if params.len() != self.param_count {
125            return Err(SqlError::ParameterCountMismatch {
126                expected: self.param_count,
127                got: params.len(),
128            });
129        }
130        if self.conn.inner.borrow().schema.generation() == self.schema_gen {
131            if let Some(plan) = &self.compiled {
132                if let Some(stream) = try_stream_via_plan(self, plan.as_ref(), params) {
133                    return Ok(Rows::streaming(stream));
134                }
135            }
136        }
137        let (columns, rows) = match self.run(params)? {
138            ExecutionResult::Query(qr) => (qr.columns, qr.rows),
139            ExecutionResult::RowsAffected(_) | ExecutionResult::Ok => {
140                (self.columns.clone(), Vec::new())
141            }
142        };
143        Ok(Rows::materialized(columns, rows))
144    }
145
146    /// Execute and return the fully-materialized `QueryResult`.
147    ///
148    /// Equivalent to `self.query(params)?.collect()` but slightly more direct.
149    pub fn query_collect(&self, params: &[Value]) -> Result<QueryResult> {
150        match self.run(params)? {
151            ExecutionResult::Query(qr) => Ok(qr),
152            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
153                columns: vec!["rows_affected".into()],
154                rows: vec![vec![Value::Integer(n as i64)]],
155            }),
156            ExecutionResult::Ok => Ok(QueryResult {
157                columns: vec![],
158                rows: vec![],
159            }),
160        }
161    }
162
163    /// Run the query and pass the first row to `f`.
164    ///
165    /// Returns `SqlError::QueryReturnedNoRows` if the query produced zero rows.
166    /// Extra rows after the first are ignored (matches rusqlite's `query_row`).
167    pub fn query_row<T, F>(&self, params: &[Value], f: F) -> Result<T>
168    where
169        F: FnOnce(&Row<'_>) -> Result<T>,
170    {
171        let mut rows = self.query(params)?;
172        match rows.next()? {
173            Some(row) => f(&row),
174            None => Err(SqlError::QueryReturnedNoRows),
175        }
176    }
177
178    /// True if the query returns at least one row (DML returns `n > 0`).
179    ///
180    /// For streamable SELECTs this short-circuits on the first matching row
181    /// without materializing the rest.
182    pub fn exists(&self, params: &[Value]) -> Result<bool> {
183        if params.len() != self.param_count {
184            return Err(SqlError::ParameterCountMismatch {
185                expected: self.param_count,
186                got: params.len(),
187            });
188        }
189        if self.conn.inner.borrow().schema.generation() == self.schema_gen {
190            if let Some(plan) = &self.compiled {
191                if let Some(mut stream) = try_stream_via_plan(self, plan.as_ref(), params) {
192                    return Ok(stream.next_row()?.is_some());
193                }
194            }
195        }
196        match self.run(params)? {
197            ExecutionResult::Query(qr) => Ok(!qr.rows.is_empty()),
198            ExecutionResult::RowsAffected(n) => Ok(n > 0),
199            ExecutionResult::Ok => Ok(false),
200        }
201    }
202
203    fn run(&self, params: &[Value]) -> Result<ExecutionResult> {
204        if params.len() != self.param_count {
205            return Err(SqlError::ParameterCountMismatch {
206                expected: self.param_count,
207                got: params.len(),
208            });
209        }
210        let mut inner = self.conn.inner.borrow_mut();
211        if inner.schema.generation() == self.schema_gen {
212            return inner.execute_prepared(self.conn.db, &self.ast, self.compiled.as_ref(), params);
213        }
214        let c = compile_inside(&mut inner, &self.sql)?;
215        if c.param_count != self.param_count {
216            return Err(SqlError::ParameterCountMismatch {
217                expected: self.param_count,
218                got: c.param_count,
219            });
220        }
221        inner.execute_prepared(self.conn.db, &c.ast, c.plan.as_ref(), params)
222    }
223}
224
225/// Stepping iterator over query rows. Obtained from [`PreparedStatement::query`].
226///
227/// Uses a lending-iterator pattern: `next()` returns `Result<Option<Row<'_>>>`
228/// where the `Row` borrows from `&mut self`. Incompatible with `std::iter::Iterator`
229/// because the row's lifetime is tied to the stepper — same design as rusqlite's
230/// [`Rows`](https://docs.rs/rusqlite/latest/rusqlite/struct.Rows.html).
231pub struct Rows<'a> {
232    source: RowSource<'a>,
233    columns: Vec<String>,
234    buf: Vec<Value>,
235}
236
237enum RowSource<'a> {
238    Materialized(std::vec::IntoIter<Vec<Value>>),
239    Streaming(Box<dyn RowSourceIter + 'a>),
240}
241
242impl<'a> Rows<'a> {
243    fn materialized(columns: Vec<String>, rows: Vec<Vec<Value>>) -> Self {
244        Self {
245            source: RowSource::Materialized(rows.into_iter()),
246            columns,
247            buf: Vec::new(),
248        }
249    }
250
251    fn streaming(source: Box<dyn RowSourceIter + 'a>) -> Self {
252        let columns = source.columns().to_vec();
253        Self {
254            source: RowSource::Streaming(source),
255            columns,
256            buf: Vec::new(),
257        }
258    }
259
260    /// Step to the next row, if any.
261    #[allow(clippy::should_implement_trait)]
262    pub fn next(&mut self) -> Result<Option<Row<'_>>> {
263        let next: Option<Vec<Value>> = match &mut self.source {
264            RowSource::Materialized(iter) => iter.next(),
265            RowSource::Streaming(stream) => stream.next_row()?,
266        };
267        match next {
268            Some(values) => {
269                self.buf = values;
270                Ok(Some(Row {
271                    columns: &self.columns,
272                    values: &self.buf,
273                }))
274            }
275            None => Ok(None),
276        }
277    }
278
279    /// Number of output columns.
280    pub fn column_count(&self) -> usize {
281        self.columns.len()
282    }
283
284    /// Output column names.
285    pub fn column_names(&self) -> &[String] {
286        &self.columns
287    }
288
289    /// Drain all remaining rows into a [`QueryResult`].
290    pub fn collect(mut self) -> Result<QueryResult> {
291        let mut rows = Vec::new();
292        while let Some(row) = self.next()? {
293            rows.push(row.to_vec());
294        }
295        Ok(QueryResult {
296            columns: self.columns,
297            rows,
298        })
299    }
300}
301
302/// A single row produced by [`Rows::next`].
303///
304/// Borrows column metadata and values from the parent `Rows`. Drop before calling
305/// `Rows::next` again.
306pub struct Row<'a> {
307    columns: &'a [String],
308    values: &'a [Value],
309}
310
311impl<'a> Row<'a> {
312    /// Value at column index `i`, if present.
313    pub fn get(&self, i: usize) -> Option<&Value> {
314        self.values.get(i)
315    }
316
317    /// Value of the column named `name`, if present.
318    pub fn get_by_name(&self, name: &str) -> Option<&Value> {
319        self.columns
320            .iter()
321            .position(|c| c == name)
322            .and_then(|i| self.values.get(i))
323    }
324
325    /// Number of columns in this row.
326    pub fn column_count(&self) -> usize {
327        self.values.len()
328    }
329
330    /// Name of the column at index `i`, if present.
331    pub fn column_name(&self, i: usize) -> Option<&str> {
332        self.columns.get(i).map(|s| s.as_str())
333    }
334
335    /// Borrow the raw values slice.
336    pub fn as_slice(&self) -> &[Value] {
337        self.values
338    }
339
340    /// Clone into an owned `Vec<Value>`.
341    pub fn to_vec(&self) -> Vec<Value> {
342        self.values.to_vec()
343    }
344}
345
346fn compile_for_sql(conn: &Connection<'_>, sql: &str) -> Result<Compiled> {
347    let mut inner = conn.inner.borrow_mut();
348    compile_inside(&mut inner, sql)
349}
350
351fn compile_inside(
352    inner: &mut crate::connection::ConnectionInner<'_>,
353    sql: &str,
354) -> Result<Compiled> {
355    let (ast, param_count) = inner.get_or_parse(sql)?;
356    let schema_gen = inner.schema.generation();
357    let plan = executor::compile(&inner.schema, &ast);
358    if let Some(p) = &plan {
359        if let Some(entry) = inner.stmt_cache.get_mut(sql) {
360            entry.compiled = Some(Arc::clone(p));
361        }
362    }
363    let columns = derive_columns(&ast, &inner.schema);
364    Ok(Compiled {
365        ast,
366        plan,
367        schema_gen,
368        param_count,
369        columns,
370    })
371}
372
373fn derive_columns(stmt: &Statement, schema: &SchemaManager) -> Vec<String> {
374    match stmt {
375        Statement::Select(sq) => derive_select_columns(sq, schema),
376        Statement::Explain(_) => vec!["plan".into()],
377        _ => Vec::new(),
378    }
379}
380
381fn derive_select_columns(sq: &SelectQuery, schema: &SchemaManager) -> Vec<String> {
382    derive_body_columns(&sq.body, schema)
383}
384
385fn derive_body_columns(body: &QueryBody, schema: &SchemaManager) -> Vec<String> {
386    match body {
387        QueryBody::Select(sel) => derive_from_select_stmt(sel, schema),
388        QueryBody::Compound(cs) => derive_body_columns(&cs.left, schema),
389    }
390}
391
392fn try_stream_via_plan<'db>(
393    stmt: &PreparedStatement<'_, 'db>,
394    plan: &dyn CompiledPlan,
395    params: &[Value],
396) -> Option<Box<dyn RowSourceIter + 'db>> {
397    let inner = stmt.conn.inner.borrow();
398    if inner.active_txn_is_some() {
399        return None;
400    }
401    plan.try_stream(stmt.conn.db, &inner.schema, &stmt.ast, params)
402}
403
404fn derive_from_select_stmt(sel: &SelectStmt, schema: &SchemaManager) -> Vec<String> {
405    let lower = sel.from.to_ascii_lowercase();
406    let table_columns = schema.get(&lower).map(|ts| ts.columns.as_slice());
407    let mut out = Vec::new();
408    for col in &sel.columns {
409        match col {
410            SelectColumn::AllColumns => {
411                if let Some(cols) = table_columns {
412                    for c in cols {
413                        out.push(c.name.clone());
414                    }
415                }
416            }
417            SelectColumn::Expr { alias: Some(a), .. } => out.push(a.clone()),
418            SelectColumn::Expr { expr, alias: None } => out.push(expr_display_name(expr)),
419        }
420    }
421    out
422}