Skip to main content

citadel_sql/
prepared.rs

//! Prepared statements: parse + compile once, execute many times with parameters.
2
3use std::sync::Arc;
4
5use rustc_hash::FxHashMap;
6
7use crate::connection::Connection;
8use crate::error::{Result, SqlError};
9use crate::executor::compile::RowSourceIter;
10use crate::executor::helpers::expr_display_name;
11use crate::executor::{self, CompiledPlan};
12use crate::parser::{QueryBody, SelectColumn, SelectQuery, SelectStmt, Statement};
13use crate::schema::SchemaManager;
14use crate::types::{ExecutionResult, QueryResult, Value};
15
16/// A prepared SQL statement bound to a `Connection`.
17pub struct PreparedStatement<'c, 'db> {
18    conn: &'c Connection<'db>,
19    sql: String,
20    ast: Arc<Statement>,
21    compiled: Option<Arc<dyn CompiledPlan>>,
22    schema_gen: u64,
23    param_count: usize,
24    columns: Vec<String>,
25    column_index: FxHashMap<String, usize>,
26    readonly: bool,
27    is_explain: bool,
28}
29
30struct Compiled {
31    ast: Arc<Statement>,
32    plan: Option<Arc<dyn CompiledPlan>>,
33    schema_gen: u64,
34    param_count: usize,
35    columns: Vec<String>,
36}
37
38impl<'c, 'db> PreparedStatement<'c, 'db> {
39    pub(crate) fn new(conn: &'c Connection<'db>, sql: &str) -> Result<Self> {
40        let c = compile_for_sql(conn, sql)?;
41        let readonly = matches!(*c.ast, Statement::Select(_) | Statement::Explain(_));
42        let is_explain = matches!(*c.ast, Statement::Explain(_));
43        let mut column_index =
44            FxHashMap::with_capacity_and_hasher(c.columns.len(), Default::default());
45        for (i, name) in c.columns.iter().enumerate() {
46            column_index.entry(name.clone()).or_insert(i);
47        }
48        Ok(Self {
49            conn,
50            sql: sql.to_string(),
51            ast: c.ast,
52            compiled: c.plan,
53            schema_gen: c.schema_gen,
54            param_count: c.param_count,
55            columns: c.columns,
56            column_index,
57            readonly,
58            is_explain,
59        })
60    }
61
62    pub fn sql(&self) -> &str {
63        &self.sql
64    }
65
66    /// Number of positional parameters (`$1`, `$2`, ...) this statement expects.
67    pub fn param_count(&self) -> usize {
68        self.param_count
69    }
70
71    /// Alias of [`Self::param_count`] matching rusqlite's name.
72    pub fn parameter_count(&self) -> usize {
73        self.param_count
74    }
75
76    /// Number of output columns. Zero for non-SELECT statements.
77    pub fn column_count(&self) -> usize {
78        self.columns.len()
79    }
80
81    /// Output column names in declaration order.
82    pub fn column_names(&self) -> &[String] {
83        &self.columns
84    }
85
86    /// Output column name at index `i`, if any.
87    pub fn column_name(&self, i: usize) -> Option<&str> {
88        self.columns.get(i).map(|s| s.as_str())
89    }
90
91    /// Position of the column named `name`, if present.
92    pub fn column_index(&self, name: &str) -> Option<usize> {
93        self.column_index.get(name).copied()
94    }
95
96    /// True if the statement is read-only (SELECT or EXPLAIN).
97    pub fn readonly(&self) -> bool {
98        self.readonly
99    }
100
101    /// True if the statement is an EXPLAIN.
102    pub fn is_explain(&self) -> bool {
103        self.is_explain
104    }
105
106    /// Execute the statement; returns rows affected (0 for SELECT/DDL).
107    pub fn execute(&self, params: &[Value]) -> Result<u64> {
108        match self.run(params)? {
109            ExecutionResult::RowsAffected(n) => Ok(n),
110            ExecutionResult::Query(_) | ExecutionResult::Ok => Ok(0),
111        }
112    }
113
114    /// Execute and return a stepping `Rows<'_>` iterator.
115    pub fn query(&self, params: &[Value]) -> Result<Rows<'_>> {
116        if params.len() != self.param_count {
117            return Err(SqlError::ParameterCountMismatch {
118                expected: self.param_count,
119                got: params.len(),
120            });
121        }
122        if self.conn.inner.borrow().schema.generation() == self.schema_gen {
123            if let Some(plan) = &self.compiled {
124                if let Some(stream) = try_stream_via_plan(self, plan.as_ref(), params) {
125                    return Ok(Rows::streaming(stream));
126                }
127            }
128        }
129        let (columns, rows) = match self.run(params)? {
130            ExecutionResult::Query(qr) => (qr.columns, qr.rows),
131            ExecutionResult::RowsAffected(_) | ExecutionResult::Ok => {
132                (self.columns.clone(), Vec::new())
133            }
134        };
135        Ok(Rows::materialized(columns, rows))
136    }
137
138    /// Execute and return the fully-materialized `QueryResult`.
139    pub fn query_collect(&self, params: &[Value]) -> Result<QueryResult> {
140        match self.run(params)? {
141            ExecutionResult::Query(qr) => Ok(qr),
142            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
143                columns: vec!["rows_affected".into()],
144                rows: vec![vec![Value::Integer(n as i64)]],
145            }),
146            ExecutionResult::Ok => Ok(QueryResult {
147                columns: vec![],
148                rows: vec![],
149            }),
150        }
151    }
152
153    /// Run the query and pass the first row to `f`.
154    pub fn query_row<T, F>(&self, params: &[Value], f: F) -> Result<T>
155    where
156        F: FnOnce(&Row<'_>) -> Result<T>,
157    {
158        let mut rows = self.query(params)?;
159        match rows.next()? {
160            Some(row) => f(&row),
161            None => Err(SqlError::QueryReturnedNoRows),
162        }
163    }
164
165    /// True if the query returns at least one row (DML returns `n > 0`).
166    pub fn exists(&self, params: &[Value]) -> Result<bool> {
167        if params.len() != self.param_count {
168            return Err(SqlError::ParameterCountMismatch {
169                expected: self.param_count,
170                got: params.len(),
171            });
172        }
173        if self.conn.inner.borrow().schema.generation() == self.schema_gen {
174            if let Some(plan) = &self.compiled {
175                if let Some(mut stream) = try_stream_via_plan(self, plan.as_ref(), params) {
176                    return Ok(stream.next_row()?.is_some());
177                }
178            }
179        }
180        match self.run(params)? {
181            ExecutionResult::Query(qr) => Ok(!qr.rows.is_empty()),
182            ExecutionResult::RowsAffected(n) => Ok(n > 0),
183            ExecutionResult::Ok => Ok(false),
184        }
185    }
186
187    fn run(&self, params: &[Value]) -> Result<ExecutionResult> {
188        if params.len() != self.param_count {
189            return Err(SqlError::ParameterCountMismatch {
190                expected: self.param_count,
191                got: params.len(),
192            });
193        }
194        let mut inner = self.conn.inner.borrow_mut();
195        if inner.schema.generation() == self.schema_gen {
196            return inner.execute_prepared(self.conn.db, &self.ast, self.compiled.as_ref(), params);
197        }
198        let c = compile_inside(&mut inner, &self.sql)?;
199        if c.param_count != self.param_count {
200            return Err(SqlError::ParameterCountMismatch {
201                expected: self.param_count,
202                got: c.param_count,
203            });
204        }
205        inner.execute_prepared(self.conn.db, &c.ast, c.plan.as_ref(), params)
206    }
207}
208
209/// Stepping iterator over query rows. Obtained from [`PreparedStatement::query`].
210pub struct Rows<'a> {
211    source: RowSource<'a>,
212    columns: Vec<String>,
213    buf: Vec<Value>,
214}
215
216enum RowSource<'a> {
217    Materialized(std::vec::IntoIter<Vec<Value>>),
218    Streaming(Box<dyn RowSourceIter + 'a>),
219}
220
221impl<'a> Rows<'a> {
222    fn materialized(columns: Vec<String>, rows: Vec<Vec<Value>>) -> Self {
223        Self {
224            source: RowSource::Materialized(rows.into_iter()),
225            columns,
226            buf: Vec::new(),
227        }
228    }
229
230    fn streaming(source: Box<dyn RowSourceIter + 'a>) -> Self {
231        let columns = source.columns().to_vec();
232        Self {
233            source: RowSource::Streaming(source),
234            columns,
235            buf: Vec::new(),
236        }
237    }
238
239    /// Step to the next row, if any.
240    #[allow(clippy::should_implement_trait)]
241    pub fn next(&mut self) -> Result<Option<Row<'_>>> {
242        let next: Option<Vec<Value>> = match &mut self.source {
243            RowSource::Materialized(iter) => iter.next(),
244            RowSource::Streaming(stream) => stream.next_row()?,
245        };
246        match next {
247            Some(values) => {
248                self.buf = values;
249                Ok(Some(Row {
250                    columns: &self.columns,
251                    values: &self.buf,
252                }))
253            }
254            None => Ok(None),
255        }
256    }
257
258    pub fn column_count(&self) -> usize {
259        self.columns.len()
260    }
261
262    pub fn column_names(&self) -> &[String] {
263        &self.columns
264    }
265
266    /// Drain all remaining rows into a [`QueryResult`].
267    pub fn collect(mut self) -> Result<QueryResult> {
268        let mut rows = Vec::new();
269        while let Some(row) = self.next()? {
270            rows.push(row.to_vec());
271        }
272        Ok(QueryResult {
273            columns: self.columns,
274            rows,
275        })
276    }
277}
278
279/// A single row produced by [`Rows::next`].
280pub struct Row<'a> {
281    columns: &'a [String],
282    values: &'a [Value],
283}
284
285impl<'a> Row<'a> {
286    /// Value at column index `i`, if present.
287    pub fn get(&self, i: usize) -> Option<&Value> {
288        self.values.get(i)
289    }
290
291    /// Value of the column named `name`, if present.
292    pub fn get_by_name(&self, name: &str) -> Option<&Value> {
293        self.columns
294            .iter()
295            .position(|c| c == name)
296            .and_then(|i| self.values.get(i))
297    }
298
299    pub fn column_count(&self) -> usize {
300        self.values.len()
301    }
302
303    /// Name of the column at index `i`, if present.
304    pub fn column_name(&self, i: usize) -> Option<&str> {
305        self.columns.get(i).map(|s| s.as_str())
306    }
307
308    pub fn as_slice(&self) -> &[Value] {
309        self.values
310    }
311
312    pub fn to_vec(&self) -> Vec<Value> {
313        self.values.to_vec()
314    }
315}
316
317fn compile_for_sql(conn: &Connection<'_>, sql: &str) -> Result<Compiled> {
318    let mut inner = conn.inner.borrow_mut();
319    compile_inside(&mut inner, sql)
320}
321
322fn compile_inside(
323    inner: &mut crate::connection::ConnectionInner<'_>,
324    sql: &str,
325) -> Result<Compiled> {
326    let (ast, param_count) = inner.get_or_parse(sql)?;
327    let schema_gen = inner.schema.generation();
328    let plan = executor::compile(&inner.schema, &ast);
329    if let Some(p) = &plan {
330        if let Some(entry) = inner.stmt_cache.get_mut(sql) {
331            entry.compiled = Some(Arc::clone(p));
332        }
333    }
334    let columns = derive_columns(&ast, &inner.schema);
335    Ok(Compiled {
336        ast,
337        plan,
338        schema_gen,
339        param_count,
340        columns,
341    })
342}
343
344fn derive_columns(stmt: &Statement, schema: &SchemaManager) -> Vec<String> {
345    match stmt {
346        Statement::Select(sq) => derive_select_columns(sq, schema),
347        Statement::Explain(_) => vec!["plan".into()],
348        _ => Vec::new(),
349    }
350}
351
352fn derive_select_columns(sq: &SelectQuery, schema: &SchemaManager) -> Vec<String> {
353    derive_body_columns(&sq.body, schema)
354}
355
356fn derive_body_columns(body: &QueryBody, schema: &SchemaManager) -> Vec<String> {
357    match body {
358        QueryBody::Select(sel) => derive_from_select_stmt(sel, schema),
359        QueryBody::Compound(cs) => derive_body_columns(&cs.left, schema),
360        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => Vec::new(),
361    }
362}
363
364fn try_stream_via_plan<'db>(
365    stmt: &PreparedStatement<'_, 'db>,
366    plan: &dyn CompiledPlan,
367    params: &[Value],
368) -> Option<Box<dyn RowSourceIter + 'db>> {
369    let inner = stmt.conn.inner.borrow();
370    if inner.active_txn_is_some() {
371        return None;
372    }
373    plan.try_stream(stmt.conn.db, &inner.schema, &stmt.ast, params)
374}
375
376fn derive_from_select_stmt(sel: &SelectStmt, schema: &SchemaManager) -> Vec<String> {
377    let lower = sel.from.to_ascii_lowercase();
378    let table_columns = schema.get(&lower).map(|ts| ts.columns.as_slice());
379    let mut out = Vec::new();
380    for col in &sel.columns {
381        match col {
382            SelectColumn::AllColumns | SelectColumn::AllFromOld | SelectColumn::AllFromNew => {
383                if let Some(cols) = table_columns {
384                    for c in cols {
385                        out.push(c.name.clone());
386                    }
387                }
388            }
389            SelectColumn::Expr { alias: Some(a), .. } => out.push(a.clone()),
390            SelectColumn::Expr { expr, alias: None } => out.push(expr_display_name(expr)),
391        }
392    }
393    out
394}