1use std::sync::Arc;
4
5use rustc_hash::FxHashMap;
6
7use crate::connection::Connection;
8use crate::error::{Result, SqlError};
9use crate::executor::compile::RowSourceIter;
10use crate::executor::helpers::expr_display_name;
11use crate::executor::{self, CompiledPlan};
12use crate::parser::{QueryBody, SelectColumn, SelectQuery, SelectStmt, Statement};
13use crate::schema::SchemaManager;
14use crate::types::{ExecutionResult, QueryResult, Value};
15
/// A SQL statement parsed and compiled once against a [`Connection`], then
/// executed any number of times with different parameter values.
///
/// Column metadata (`columns`, `column_index`) is captured when the statement
/// is prepared.
pub struct PreparedStatement<'c, 'db> {
    /// Connection the statement was prepared on and executes against.
    conn: &'c Connection<'db>,
    /// Original SQL text; kept so the statement can be recompiled when the
    /// schema generation no longer matches `schema_gen`.
    sql: String,
    /// Parsed statement.
    ast: Arc<Statement>,
    /// Compiled execution plan, if the executor produced one for this statement.
    compiled: Option<Arc<dyn CompiledPlan>>,
    /// Schema generation observed when `ast` / `compiled` were produced.
    schema_gen: u64,
    /// Number of bind parameters the statement expects.
    param_count: usize,
    /// Result column names derived at prepare time.
    columns: Vec<String>,
    /// Column name -> position; for duplicate names the first occurrence wins.
    column_index: FxHashMap<String, usize>,
    /// `true` for SELECT and EXPLAIN statements.
    readonly: bool,
    /// `true` if the statement is an EXPLAIN.
    is_explain: bool,
}
29
/// Result of compiling a SQL string against the connection's current schema.
struct Compiled {
    /// Parsed statement.
    ast: Arc<Statement>,
    /// Compiled plan, if the executor produced one.
    plan: Option<Arc<dyn CompiledPlan>>,
    /// Schema generation at the time of compilation.
    schema_gen: u64,
    /// Number of bind parameters in the statement.
    param_count: usize,
    /// Result column names derived from the statement and schema.
    columns: Vec<String>,
}
37
38impl<'c, 'db> PreparedStatement<'c, 'db> {
39 pub(crate) fn new(conn: &'c Connection<'db>, sql: &str) -> Result<Self> {
40 let c = compile_for_sql(conn, sql)?;
41 let readonly = matches!(*c.ast, Statement::Select(_) | Statement::Explain(_));
42 let is_explain = matches!(*c.ast, Statement::Explain(_));
43 let mut column_index =
44 FxHashMap::with_capacity_and_hasher(c.columns.len(), Default::default());
45 for (i, name) in c.columns.iter().enumerate() {
46 column_index.entry(name.clone()).or_insert(i);
47 }
48 Ok(Self {
49 conn,
50 sql: sql.to_string(),
51 ast: c.ast,
52 compiled: c.plan,
53 schema_gen: c.schema_gen,
54 param_count: c.param_count,
55 columns: c.columns,
56 column_index,
57 readonly,
58 is_explain,
59 })
60 }
61
62 pub fn sql(&self) -> &str {
64 &self.sql
65 }
66
67 pub fn param_count(&self) -> usize {
69 self.param_count
70 }
71
72 pub fn parameter_count(&self) -> usize {
74 self.param_count
75 }
76
77 pub fn column_count(&self) -> usize {
79 self.columns.len()
80 }
81
82 pub fn column_names(&self) -> &[String] {
84 &self.columns
85 }
86
87 pub fn column_name(&self, i: usize) -> Option<&str> {
89 self.columns.get(i).map(|s| s.as_str())
90 }
91
92 pub fn column_index(&self, name: &str) -> Option<usize> {
94 self.column_index.get(name).copied()
95 }
96
97 pub fn readonly(&self) -> bool {
99 self.readonly
100 }
101
102 pub fn is_explain(&self) -> bool {
104 self.is_explain
105 }
106
107 pub fn execute(&self, params: &[Value]) -> Result<u64> {
109 match self.run(params)? {
110 ExecutionResult::RowsAffected(n) => Ok(n),
111 ExecutionResult::Query(_) | ExecutionResult::Ok => Ok(0),
112 }
113 }
114
115 pub fn query(&self, params: &[Value]) -> Result<Rows<'_>> {
124 if params.len() != self.param_count {
125 return Err(SqlError::ParameterCountMismatch {
126 expected: self.param_count,
127 got: params.len(),
128 });
129 }
130 if self.conn.inner.borrow().schema.generation() == self.schema_gen {
131 if let Some(plan) = &self.compiled {
132 if let Some(stream) = try_stream_via_plan(self, plan.as_ref(), params) {
133 return Ok(Rows::streaming(stream));
134 }
135 }
136 }
137 let (columns, rows) = match self.run(params)? {
138 ExecutionResult::Query(qr) => (qr.columns, qr.rows),
139 ExecutionResult::RowsAffected(_) | ExecutionResult::Ok => {
140 (self.columns.clone(), Vec::new())
141 }
142 };
143 Ok(Rows::materialized(columns, rows))
144 }
145
146 pub fn query_collect(&self, params: &[Value]) -> Result<QueryResult> {
150 match self.run(params)? {
151 ExecutionResult::Query(qr) => Ok(qr),
152 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
153 columns: vec!["rows_affected".into()],
154 rows: vec![vec![Value::Integer(n as i64)]],
155 }),
156 ExecutionResult::Ok => Ok(QueryResult {
157 columns: vec![],
158 rows: vec![],
159 }),
160 }
161 }
162
163 pub fn query_row<T, F>(&self, params: &[Value], f: F) -> Result<T>
168 where
169 F: FnOnce(&Row<'_>) -> Result<T>,
170 {
171 let mut rows = self.query(params)?;
172 match rows.next()? {
173 Some(row) => f(&row),
174 None => Err(SqlError::QueryReturnedNoRows),
175 }
176 }
177
178 pub fn exists(&self, params: &[Value]) -> Result<bool> {
183 if params.len() != self.param_count {
184 return Err(SqlError::ParameterCountMismatch {
185 expected: self.param_count,
186 got: params.len(),
187 });
188 }
189 if self.conn.inner.borrow().schema.generation() == self.schema_gen {
190 if let Some(plan) = &self.compiled {
191 if let Some(mut stream) = try_stream_via_plan(self, plan.as_ref(), params) {
192 return Ok(stream.next_row()?.is_some());
193 }
194 }
195 }
196 match self.run(params)? {
197 ExecutionResult::Query(qr) => Ok(!qr.rows.is_empty()),
198 ExecutionResult::RowsAffected(n) => Ok(n > 0),
199 ExecutionResult::Ok => Ok(false),
200 }
201 }
202
203 fn run(&self, params: &[Value]) -> Result<ExecutionResult> {
204 if params.len() != self.param_count {
205 return Err(SqlError::ParameterCountMismatch {
206 expected: self.param_count,
207 got: params.len(),
208 });
209 }
210 let mut inner = self.conn.inner.borrow_mut();
211 if inner.schema.generation() == self.schema_gen {
212 return inner.execute_prepared(self.conn.db, &self.ast, self.compiled.as_ref(), params);
213 }
214 let c = compile_inside(&mut inner, &self.sql)?;
215 if c.param_count != self.param_count {
216 return Err(SqlError::ParameterCountMismatch {
217 expected: self.param_count,
218 got: c.param_count,
219 });
220 }
221 inner.execute_prepared(self.conn.db, &c.ast, c.plan.as_ref(), params)
222 }
223}
224
/// Cursor over the result rows of a statement.
///
/// Rows come either from an already-materialized vector or from a streaming
/// row source.
pub struct Rows<'a> {
    /// Where the rows come from.
    source: RowSource<'a>,
    /// Result column names.
    columns: Vec<String>,
    /// Values of the current row; `next` stores each row here so the returned
    /// `Row` can borrow it.
    buf: Vec<Value>,
}
236
/// Backing storage for a [`Rows`] cursor.
enum RowSource<'a> {
    /// Rows produced eagerly and held in memory.
    Materialized(std::vec::IntoIter<Vec<Value>>),
    /// Rows pulled one at a time from a streaming source.
    Streaming(Box<dyn RowSourceIter + 'a>),
}
241
242impl<'a> Rows<'a> {
243 fn materialized(columns: Vec<String>, rows: Vec<Vec<Value>>) -> Self {
244 Self {
245 source: RowSource::Materialized(rows.into_iter()),
246 columns,
247 buf: Vec::new(),
248 }
249 }
250
251 fn streaming(source: Box<dyn RowSourceIter + 'a>) -> Self {
252 let columns = source.columns().to_vec();
253 Self {
254 source: RowSource::Streaming(source),
255 columns,
256 buf: Vec::new(),
257 }
258 }
259
260 #[allow(clippy::should_implement_trait)]
262 pub fn next(&mut self) -> Result<Option<Row<'_>>> {
263 let next: Option<Vec<Value>> = match &mut self.source {
264 RowSource::Materialized(iter) => iter.next(),
265 RowSource::Streaming(stream) => stream.next_row()?,
266 };
267 match next {
268 Some(values) => {
269 self.buf = values;
270 Ok(Some(Row {
271 columns: &self.columns,
272 values: &self.buf,
273 }))
274 }
275 None => Ok(None),
276 }
277 }
278
279 pub fn column_count(&self) -> usize {
281 self.columns.len()
282 }
283
284 pub fn column_names(&self) -> &[String] {
286 &self.columns
287 }
288
289 pub fn collect(mut self) -> Result<QueryResult> {
291 let mut rows = Vec::new();
292 while let Some(row) = self.next()? {
293 rows.push(row.to_vec());
294 }
295 Ok(QueryResult {
296 columns: self.columns,
297 rows,
298 })
299 }
300}
301
/// Borrowed view of a single result row, as lent by `Rows::next`.
pub struct Row<'a> {
    /// Names of the result columns.
    columns: &'a [String],
    /// Values of this row, positionally aligned with `columns`.
    values: &'a [Value],
}
310
311impl<'a> Row<'a> {
312 pub fn get(&self, i: usize) -> Option<&Value> {
314 self.values.get(i)
315 }
316
317 pub fn get_by_name(&self, name: &str) -> Option<&Value> {
319 self.columns
320 .iter()
321 .position(|c| c == name)
322 .and_then(|i| self.values.get(i))
323 }
324
325 pub fn column_count(&self) -> usize {
327 self.values.len()
328 }
329
330 pub fn column_name(&self, i: usize) -> Option<&str> {
332 self.columns.get(i).map(|s| s.as_str())
333 }
334
335 pub fn as_slice(&self) -> &[Value] {
337 self.values
338 }
339
340 pub fn to_vec(&self) -> Vec<Value> {
342 self.values.to_vec()
343 }
344}
345
346fn compile_for_sql(conn: &Connection<'_>, sql: &str) -> Result<Compiled> {
347 let mut inner = conn.inner.borrow_mut();
348 compile_inside(&mut inner, sql)
349}
350
351fn compile_inside(
352 inner: &mut crate::connection::ConnectionInner<'_>,
353 sql: &str,
354) -> Result<Compiled> {
355 let (ast, param_count) = inner.get_or_parse(sql)?;
356 let schema_gen = inner.schema.generation();
357 let plan = executor::compile(&inner.schema, &ast);
358 if let Some(p) = &plan {
359 if let Some(entry) = inner.stmt_cache.get_mut(sql) {
360 entry.compiled = Some(Arc::clone(p));
361 }
362 }
363 let columns = derive_columns(&ast, &inner.schema);
364 Ok(Compiled {
365 ast,
366 plan,
367 schema_gen,
368 param_count,
369 columns,
370 })
371}
372
373fn derive_columns(stmt: &Statement, schema: &SchemaManager) -> Vec<String> {
374 match stmt {
375 Statement::Select(sq) => derive_select_columns(sq, schema),
376 Statement::Explain(_) => vec!["plan".into()],
377 _ => Vec::new(),
378 }
379}
380
381fn derive_select_columns(sq: &SelectQuery, schema: &SchemaManager) -> Vec<String> {
382 derive_body_columns(&sq.body, schema)
383}
384
385fn derive_body_columns(body: &QueryBody, schema: &SchemaManager) -> Vec<String> {
386 match body {
387 QueryBody::Select(sel) => derive_from_select_stmt(sel, schema),
388 QueryBody::Compound(cs) => derive_body_columns(&cs.left, schema),
389 }
390}
391
392fn try_stream_via_plan<'db>(
393 stmt: &PreparedStatement<'_, 'db>,
394 plan: &dyn CompiledPlan,
395 params: &[Value],
396) -> Option<Box<dyn RowSourceIter + 'db>> {
397 let inner = stmt.conn.inner.borrow();
398 if inner.active_txn_is_some() {
399 return None;
400 }
401 plan.try_stream(stmt.conn.db, &inner.schema, &stmt.ast, params)
402}
403
404fn derive_from_select_stmt(sel: &SelectStmt, schema: &SchemaManager) -> Vec<String> {
405 let lower = sel.from.to_ascii_lowercase();
406 let table_columns = schema.get(&lower).map(|ts| ts.columns.as_slice());
407 let mut out = Vec::new();
408 for col in &sel.columns {
409 match col {
410 SelectColumn::AllColumns => {
411 if let Some(cols) = table_columns {
412 for c in cols {
413 out.push(c.name.clone());
414 }
415 }
416 }
417 SelectColumn::Expr { alias: Some(a), .. } => out.push(a.clone()),
418 SelectColumn::Expr { expr, alias: None } => out.push(expr_display_name(expr)),
419 }
420 }
421 out
422}