Skip to main content

citadel_sql/
prepared.rs

1//! Prepared statements: parse + compile once, execute many times with parameters.
2
3use std::sync::Arc;
4
5use rustc_hash::FxHashMap;
6
7use crate::connection::Connection;
8use crate::error::{Result, SqlError};
9use crate::executor::compile::RowSourceIter;
10use crate::executor::helpers::expr_display_name;
11use crate::executor::{self, CompiledPlan};
12use crate::parser::{QueryBody, SelectColumn, SelectQuery, SelectStmt, Statement};
13use crate::schema::SchemaManager;
14use crate::types::{ExecutionResult, QueryResult, Value};
15
16/// A prepared SQL statement bound to a `Connection`.
17pub struct PreparedStatement<'c, 'db> {
18    conn: &'c Connection<'db>,
19    sql: String,
20    ast: Arc<Statement>,
21    compiled: Option<Arc<dyn CompiledPlan>>,
22    schema_gen: u64,
23    param_count: usize,
24    columns: Vec<String>,
25    column_index: FxHashMap<String, usize>,
26    readonly: bool,
27    is_explain: bool,
28}
29
30struct Compiled {
31    ast: Arc<Statement>,
32    plan: Option<Arc<dyn CompiledPlan>>,
33    schema_gen: u64,
34    param_count: usize,
35    columns: Vec<String>,
36}
37
38impl<'c, 'db> PreparedStatement<'c, 'db> {
39    pub(crate) fn new(conn: &'c Connection<'db>, sql: &str) -> Result<Self> {
40        let c = compile_for_sql(conn, sql)?;
41        let readonly = matches!(*c.ast, Statement::Select(_) | Statement::Explain(_));
42        let is_explain = matches!(*c.ast, Statement::Explain(_));
43        let mut column_index =
44            FxHashMap::with_capacity_and_hasher(c.columns.len(), Default::default());
45        for (i, name) in c.columns.iter().enumerate() {
46            column_index.entry(name.clone()).or_insert(i);
47        }
48        Ok(Self {
49            conn,
50            sql: sql.to_string(),
51            ast: c.ast,
52            compiled: c.plan,
53            schema_gen: c.schema_gen,
54            param_count: c.param_count,
55            columns: c.columns,
56            column_index,
57            readonly,
58            is_explain,
59        })
60    }
61
62    pub fn sql(&self) -> &str {
63        &self.sql
64    }
65
66    /// Number of positional parameters (`$1`, `$2`, ...) this statement expects.
67    pub fn param_count(&self) -> usize {
68        self.param_count
69    }
70
71    /// Alias of [`Self::param_count`] matching rusqlite's name.
72    pub fn parameter_count(&self) -> usize {
73        self.param_count
74    }
75
76    /// Number of output columns. Zero for non-SELECT statements.
77    pub fn column_count(&self) -> usize {
78        self.columns.len()
79    }
80
81    /// Output column names in declaration order.
82    pub fn column_names(&self) -> &[String] {
83        &self.columns
84    }
85
86    /// Output column name at index `i`, if any.
87    pub fn column_name(&self, i: usize) -> Option<&str> {
88        self.columns.get(i).map(|s| s.as_str())
89    }
90
91    /// Position of the column named `name`, if present.
92    pub fn column_index(&self, name: &str) -> Option<usize> {
93        self.column_index.get(name).copied()
94    }
95
96    /// True if the statement is read-only (SELECT or EXPLAIN).
97    pub fn readonly(&self) -> bool {
98        self.readonly
99    }
100
101    /// True if the statement is an EXPLAIN.
102    pub fn is_explain(&self) -> bool {
103        self.is_explain
104    }
105
106    /// Execute the statement; returns rows affected (0 for SELECT/DDL).
107    pub fn execute(&self, params: &[Value]) -> Result<u64> {
108        match self.run(params)? {
109            ExecutionResult::RowsAffected(n) => Ok(n),
110            ExecutionResult::Query(_) | ExecutionResult::Ok => Ok(0),
111        }
112    }
113
114    /// Execute and return a stepping `Rows<'_>` iterator.
115    pub fn query(&self, params: &[Value]) -> Result<Rows<'_>> {
116        if let Some(stream) = self.stream_fast_path(params)? {
117            return Ok(Rows::streaming(stream));
118        }
119        let (columns, rows) = match self.run(params)? {
120            ExecutionResult::Query(qr) => (qr.columns, qr.rows),
121            ExecutionResult::RowsAffected(_) | ExecutionResult::Ok => {
122                (self.columns.clone(), Vec::new())
123            }
124        };
125        Ok(Rows::materialized(columns, rows))
126    }
127
128    /// Execute and return the fully-materialized `QueryResult`.
129    pub fn query_collect(&self, params: &[Value]) -> Result<QueryResult> {
130        if let Some(qr) = self.collect_fast_path(params)? {
131            return Ok(qr);
132        }
133        if let Some(mut stream) = self.stream_fast_path(params)? {
134            let columns = stream.columns().to_vec();
135            let mut rows = Vec::with_capacity(stream.size_hint());
136            while let Some(row) = stream.next_row()? {
137                rows.push(row);
138            }
139            return Ok(QueryResult { columns, rows });
140        }
141        match self.run(params)? {
142            ExecutionResult::Query(qr) => Ok(qr),
143            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
144                columns: vec!["rows_affected".into()],
145                rows: vec![vec![Value::Integer(n as i64)]],
146            }),
147            ExecutionResult::Ok => Ok(QueryResult {
148                columns: vec![],
149                rows: vec![],
150            }),
151        }
152    }
153
154    /// Run the query and pass the first row to `f`.
155    pub fn query_row<T, F>(&self, params: &[Value], f: F) -> Result<T>
156    where
157        F: FnOnce(&Row<'_>) -> Result<T>,
158    {
159        let mut rows = self.query(params)?;
160        match rows.next()? {
161            Some(row) => f(&row),
162            None => Err(SqlError::QueryReturnedNoRows),
163        }
164    }
165
166    /// True if the query returns at least one row (DML returns `n > 0`).
167    pub fn exists(&self, params: &[Value]) -> Result<bool> {
168        if let Some(mut stream) = self.stream_fast_path(params)? {
169            return Ok(stream.next_row()?.is_some());
170        }
171        match self.run(params)? {
172            ExecutionResult::Query(qr) => Ok(!qr.rows.is_empty()),
173            ExecutionResult::RowsAffected(n) => Ok(n > 0),
174            ExecutionResult::Ok => Ok(false),
175        }
176    }
177
178    fn stream_fast_path(&self, params: &[Value]) -> Result<Option<Box<dyn RowSourceIter + 'db>>> {
179        if params.len() != self.param_count {
180            return Err(SqlError::ParameterCountMismatch {
181                expected: self.param_count,
182                got: params.len(),
183            });
184        }
185        if self.conn.inner.borrow().schema.generation() != self.schema_gen {
186            return Ok(None);
187        }
188        match &self.compiled {
189            Some(plan) => Ok(try_stream_via_plan(self, plan.as_ref(), params)),
190            None => Ok(None),
191        }
192    }
193
194    /// Zero-copy collect for full scans; `None` falls through to the slower paths.
195    fn collect_fast_path(&self, params: &[Value]) -> Result<Option<QueryResult>> {
196        if params.len() != self.param_count {
197            return Err(SqlError::ParameterCountMismatch {
198                expected: self.param_count,
199                got: params.len(),
200            });
201        }
202        let inner = self.conn.inner.borrow();
203        if inner.schema.generation() != self.schema_gen || inner.active_txn_is_some() {
204            return Ok(None);
205        }
206        match &self.compiled {
207            Some(plan) => plan
208                .try_collect(self.conn.db, &inner.schema, &self.ast, params)
209                .transpose(),
210            None => Ok(None),
211        }
212    }
213
214    fn run(&self, params: &[Value]) -> Result<ExecutionResult> {
215        if params.len() != self.param_count {
216            return Err(SqlError::ParameterCountMismatch {
217                expected: self.param_count,
218                got: params.len(),
219            });
220        }
221        let mut inner = self.conn.inner.borrow_mut();
222        if inner.schema.generation() == self.schema_gen {
223            return inner.execute_prepared(self.conn.db, &self.ast, self.compiled.as_ref(), params);
224        }
225        let c = compile_inside(&mut inner, &self.sql)?;
226        if c.param_count != self.param_count {
227            return Err(SqlError::ParameterCountMismatch {
228                expected: self.param_count,
229                got: c.param_count,
230            });
231        }
232        inner.execute_prepared(self.conn.db, &c.ast, c.plan.as_ref(), params)
233    }
234}
235
236/// Stepping iterator over query rows. Obtained from [`PreparedStatement::query`].
237pub struct Rows<'a> {
238    source: RowSource<'a>,
239    columns: Vec<String>,
240    buf: Vec<Value>,
241}
242
243enum RowSource<'a> {
244    Materialized(std::vec::IntoIter<Vec<Value>>),
245    Streaming(Box<dyn RowSourceIter + 'a>),
246}
247
248impl<'a> Rows<'a> {
249    fn materialized(columns: Vec<String>, rows: Vec<Vec<Value>>) -> Self {
250        Self {
251            source: RowSource::Materialized(rows.into_iter()),
252            columns,
253            buf: Vec::new(),
254        }
255    }
256
257    fn streaming(source: Box<dyn RowSourceIter + 'a>) -> Self {
258        let columns = source.columns().to_vec();
259        Self {
260            source: RowSource::Streaming(source),
261            columns,
262            buf: Vec::new(),
263        }
264    }
265
266    /// Step to the next row, if any.
267    #[allow(clippy::should_implement_trait)]
268    pub fn next(&mut self) -> Result<Option<Row<'_>>> {
269        let next: Option<Vec<Value>> = match &mut self.source {
270            RowSource::Materialized(iter) => iter.next(),
271            RowSource::Streaming(stream) => stream.next_row()?,
272        };
273        match next {
274            Some(values) => {
275                self.buf = values;
276                Ok(Some(Row {
277                    columns: &self.columns,
278                    values: &self.buf,
279                }))
280            }
281            None => Ok(None),
282        }
283    }
284
285    pub fn column_count(&self) -> usize {
286        self.columns.len()
287    }
288
289    pub fn column_names(&self) -> &[String] {
290        &self.columns
291    }
292
293    /// Drain all remaining rows into a [`QueryResult`].
294    pub fn collect(mut self) -> Result<QueryResult> {
295        let mut rows = Vec::new();
296        while let Some(row) = self.next()? {
297            rows.push(row.to_vec());
298        }
299        Ok(QueryResult {
300            columns: self.columns,
301            rows,
302        })
303    }
304}
305
306/// A single row produced by [`Rows::next`].
307pub struct Row<'a> {
308    columns: &'a [String],
309    values: &'a [Value],
310}
311
312impl<'a> Row<'a> {
313    /// Value at column index `i`, if present.
314    pub fn get(&self, i: usize) -> Option<&Value> {
315        self.values.get(i)
316    }
317
318    /// Value of the column named `name`, if present.
319    pub fn get_by_name(&self, name: &str) -> Option<&Value> {
320        self.columns
321            .iter()
322            .position(|c| c == name)
323            .and_then(|i| self.values.get(i))
324    }
325
326    pub fn column_count(&self) -> usize {
327        self.values.len()
328    }
329
330    /// Name of the column at index `i`, if present.
331    pub fn column_name(&self, i: usize) -> Option<&str> {
332        self.columns.get(i).map(|s| s.as_str())
333    }
334
335    pub fn as_slice(&self) -> &[Value] {
336        self.values
337    }
338
339    pub fn to_vec(&self) -> Vec<Value> {
340        self.values.to_vec()
341    }
342}
343
344fn compile_for_sql(conn: &Connection<'_>, sql: &str) -> Result<Compiled> {
345    let mut inner = conn.inner.borrow_mut();
346    compile_inside(&mut inner, sql)
347}
348
349fn compile_inside(
350    inner: &mut crate::connection::ConnectionInner<'_>,
351    sql: &str,
352) -> Result<Compiled> {
353    let (ast, param_count) = inner.get_or_parse(sql)?;
354    let schema_gen = inner.schema.generation();
355    let plan = executor::compile(&inner.schema, &ast);
356    if let Some(p) = &plan {
357        if let Some(entry) = inner.stmt_cache.get_mut(sql) {
358            entry.compiled = Some(Arc::clone(p));
359        }
360    }
361    let columns = derive_columns(&ast, &inner.schema);
362    Ok(Compiled {
363        ast,
364        plan,
365        schema_gen,
366        param_count,
367        columns,
368    })
369}
370
371fn derive_columns(stmt: &Statement, schema: &SchemaManager) -> Vec<String> {
372    match stmt {
373        Statement::Select(sq) => derive_select_columns(sq, schema),
374        Statement::Explain(_) => vec!["plan".into()],
375        _ => Vec::new(),
376    }
377}
378
379fn derive_select_columns(sq: &SelectQuery, schema: &SchemaManager) -> Vec<String> {
380    derive_body_columns(&sq.body, schema)
381}
382
383fn derive_body_columns(body: &QueryBody, schema: &SchemaManager) -> Vec<String> {
384    match body {
385        QueryBody::Select(sel) => derive_from_select_stmt(sel, schema),
386        QueryBody::Compound(cs) => derive_body_columns(&cs.left, schema),
387        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => Vec::new(),
388    }
389}
390
391fn try_stream_via_plan<'db>(
392    stmt: &PreparedStatement<'_, 'db>,
393    plan: &dyn CompiledPlan,
394    params: &[Value],
395) -> Option<Box<dyn RowSourceIter + 'db>> {
396    let inner = stmt.conn.inner.borrow();
397    if inner.active_txn_is_some() {
398        return None;
399    }
400    plan.try_stream(stmt.conn.db, &inner.schema, &stmt.ast, params)
401}
402
403fn derive_from_select_stmt(sel: &SelectStmt, schema: &SchemaManager) -> Vec<String> {
404    let lower = sel.from.to_ascii_lowercase();
405    let table_columns = schema.get(&lower).map(|ts| ts.columns.as_slice());
406    let mut out = Vec::new();
407    for col in &sel.columns {
408        match col {
409            SelectColumn::AllColumns | SelectColumn::AllFromOld | SelectColumn::AllFromNew => {
410                if let Some(cols) = table_columns {
411                    for c in cols {
412                        out.push(c.name.clone());
413                    }
414                }
415            }
416            SelectColumn::Expr { alias: Some(a), .. } => out.push(a.clone()),
417            SelectColumn::Expr { expr, alias: None } => out.push(expr_display_name(expr)),
418        }
419    }
420    out
421}