1use std::sync::Arc;
4
5use rustc_hash::FxHashMap;
6
7use crate::connection::Connection;
8use crate::error::{Result, SqlError};
9use crate::executor::compile::RowSourceIter;
10use crate::executor::helpers::expr_display_name;
11use crate::executor::{self, CompiledPlan};
12use crate::parser::{QueryBody, SelectColumn, SelectQuery, SelectStmt, Statement};
13use crate::schema::SchemaManager;
14use crate::types::{ExecutionResult, QueryResult, Value};
15
/// A SQL statement compiled once on a [`Connection`] and executable many times.
///
/// Stores the parsed AST, the optional compiled plan, and the result-column
/// metadata captured when the statement was prepared.
pub struct PreparedStatement<'c, 'db> {
    /// Connection the statement was prepared on.
    conn: &'c Connection<'db>,
    /// Original SQL text; used to recompile when the schema generation changed.
    sql: String,
    /// Parsed statement produced at prepare time.
    ast: Arc<Statement>,
    /// Compiled plan, when the executor produced one for this statement.
    compiled: Option<Arc<dyn CompiledPlan>>,
    /// Schema generation observed at prepare time; compared against the live
    /// generation to detect a stale plan.
    schema_gen: u64,
    /// Number of parameter values every execution must supply.
    param_count: usize,
    /// Result column names derived at prepare time (empty when none were derived).
    columns: Vec<String>,
    /// Column name -> position in `columns`; the first occurrence wins on duplicates.
    column_index: FxHashMap<String, usize>,
    /// `true` when the statement is a `Select` or an `Explain`.
    readonly: bool,
    /// `true` when the statement is an `Explain`.
    is_explain: bool,
}
29
/// Result of compiling one SQL string: the pieces a [`PreparedStatement`] keeps.
struct Compiled {
    /// Parsed statement.
    ast: Arc<Statement>,
    /// Plan returned by `executor::compile`, if any.
    plan: Option<Arc<dyn CompiledPlan>>,
    /// Schema generation read just before the plan was compiled.
    schema_gen: u64,
    /// Number of parameters reported by the parser.
    param_count: usize,
    /// Result column names derived from the statement and the schema.
    columns: Vec<String>,
}
37
38impl<'c, 'db> PreparedStatement<'c, 'db> {
39 pub(crate) fn new(conn: &'c Connection<'db>, sql: &str) -> Result<Self> {
40 let c = compile_for_sql(conn, sql)?;
41 let readonly = matches!(*c.ast, Statement::Select(_) | Statement::Explain(_));
42 let is_explain = matches!(*c.ast, Statement::Explain(_));
43 let mut column_index =
44 FxHashMap::with_capacity_and_hasher(c.columns.len(), Default::default());
45 for (i, name) in c.columns.iter().enumerate() {
46 column_index.entry(name.clone()).or_insert(i);
47 }
48 Ok(Self {
49 conn,
50 sql: sql.to_string(),
51 ast: c.ast,
52 compiled: c.plan,
53 schema_gen: c.schema_gen,
54 param_count: c.param_count,
55 columns: c.columns,
56 column_index,
57 readonly,
58 is_explain,
59 })
60 }
61
62 pub fn sql(&self) -> &str {
63 &self.sql
64 }
65
66 pub fn param_count(&self) -> usize {
68 self.param_count
69 }
70
71 pub fn parameter_count(&self) -> usize {
73 self.param_count
74 }
75
76 pub fn column_count(&self) -> usize {
78 self.columns.len()
79 }
80
81 pub fn column_names(&self) -> &[String] {
83 &self.columns
84 }
85
86 pub fn column_name(&self, i: usize) -> Option<&str> {
88 self.columns.get(i).map(|s| s.as_str())
89 }
90
91 pub fn column_index(&self, name: &str) -> Option<usize> {
93 self.column_index.get(name).copied()
94 }
95
96 pub fn readonly(&self) -> bool {
98 self.readonly
99 }
100
101 pub fn is_explain(&self) -> bool {
103 self.is_explain
104 }
105
106 pub fn execute(&self, params: &[Value]) -> Result<u64> {
108 match self.run(params)? {
109 ExecutionResult::RowsAffected(n) => Ok(n),
110 ExecutionResult::Query(_) | ExecutionResult::Ok => Ok(0),
111 }
112 }
113
114 pub fn query(&self, params: &[Value]) -> Result<Rows<'_>> {
116 if self.readonly {
118 if let Some(stream) = self.stream_fast_path(params)? {
119 return Ok(Rows::streaming(stream));
120 }
121 }
122 let (columns, rows) = match self.run(params)? {
123 ExecutionResult::Query(qr) => (qr.columns, qr.rows),
124 ExecutionResult::RowsAffected(_) | ExecutionResult::Ok => {
125 (self.columns.clone(), Vec::new())
126 }
127 };
128 Ok(Rows::materialized(columns, rows))
129 }
130
131 pub fn query_collect(&self, params: &[Value]) -> Result<QueryResult> {
133 if self.readonly {
134 if let Some(qr) = self.collect_fast_path(params)? {
135 return Ok(qr);
136 }
137 if let Some(mut stream) = self.stream_fast_path(params)? {
138 let columns = stream.columns().to_vec();
139 let mut rows = Vec::with_capacity(stream.size_hint());
140 while let Some(row) = stream.next_row()? {
141 rows.push(row);
142 }
143 return Ok(QueryResult { columns, rows });
144 }
145 }
146 match self.run(params)? {
147 ExecutionResult::Query(qr) => Ok(qr),
148 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
149 columns: vec!["rows_affected".into()],
150 rows: vec![vec![Value::Integer(n as i64)]],
151 }),
152 ExecutionResult::Ok => Ok(QueryResult {
153 columns: vec![],
154 rows: vec![],
155 }),
156 }
157 }
158
159 pub fn query_row<T, F>(&self, params: &[Value], f: F) -> Result<T>
161 where
162 F: FnOnce(&Row<'_>) -> Result<T>,
163 {
164 let mut rows = self.query(params)?;
165 match rows.next()? {
166 Some(row) => f(&row),
167 None => Err(SqlError::QueryReturnedNoRows),
168 }
169 }
170
171 pub fn exists(&self, params: &[Value]) -> Result<bool> {
173 if self.readonly {
174 if let Some(mut stream) = self.stream_fast_path(params)? {
175 return Ok(stream.next_row()?.is_some());
176 }
177 }
178 match self.run(params)? {
179 ExecutionResult::Query(qr) => Ok(!qr.rows.is_empty()),
180 ExecutionResult::RowsAffected(n) => Ok(n > 0),
181 ExecutionResult::Ok => Ok(false),
182 }
183 }
184
185 fn stream_fast_path(&self, params: &[Value]) -> Result<Option<Box<dyn RowSourceIter + 'db>>> {
186 if params.len() != self.param_count {
187 return Err(SqlError::ParameterCountMismatch {
188 expected: self.param_count,
189 got: params.len(),
190 });
191 }
192 if self.conn.inner.borrow().schema.generation() != self.schema_gen {
193 return Ok(None);
194 }
195 match &self.compiled {
196 Some(plan) => Ok(try_stream_via_plan(self, plan.as_ref(), params)),
197 None => Ok(None),
198 }
199 }
200
201 fn collect_fast_path(&self, params: &[Value]) -> Result<Option<QueryResult>> {
203 if params.len() != self.param_count {
204 return Err(SqlError::ParameterCountMismatch {
205 expected: self.param_count,
206 got: params.len(),
207 });
208 }
209 let inner = self.conn.inner.borrow();
210 if inner.schema.generation() != self.schema_gen || inner.active_txn_is_some() {
211 return Ok(None);
212 }
213 match &self.compiled {
214 Some(plan) => plan
215 .try_collect(self.conn.db, &inner.schema, &self.ast, params)
216 .transpose(),
217 None => Ok(None),
218 }
219 }
220
221 fn run(&self, params: &[Value]) -> Result<ExecutionResult> {
222 if params.len() != self.param_count {
223 return Err(SqlError::ParameterCountMismatch {
224 expected: self.param_count,
225 got: params.len(),
226 });
227 }
228 let mut inner = self.conn.inner.borrow_mut();
229 if inner.schema.generation() == self.schema_gen {
230 return inner.execute_prepared(self.conn.db, &self.ast, self.compiled.as_ref(), params);
231 }
232 let c = compile_inside(&mut inner, &self.sql)?;
233 if c.param_count != self.param_count {
234 return Err(SqlError::ParameterCountMismatch {
235 expected: self.param_count,
236 got: c.param_count,
237 });
238 }
239 inner.execute_prepared(self.conn.db, &c.ast, c.plan.as_ref(), params)
240 }
241}
242
/// Cursor over the rows produced by a query.
///
/// Rows are handed out one at a time through `next`; each returned [`Row`]
/// borrows from this cursor.
pub struct Rows<'a> {
    /// Where rows come from: an in-memory list or a live stream.
    source: RowSource<'a>,
    /// Result column names.
    columns: Vec<String>,
    /// Storage for the values of the row most recently returned by `next`.
    buf: Vec<Value>,
}
249
/// Backing storage of a [`Rows`] cursor.
enum RowSource<'a> {
    /// Rows that were already fully collected in memory.
    Materialized(std::vec::IntoIter<Vec<Value>>),
    /// Rows pulled on demand from a `RowSourceIter`.
    Streaming(Box<dyn RowSourceIter + 'a>),
}
254
255impl<'a> Rows<'a> {
256 fn materialized(columns: Vec<String>, rows: Vec<Vec<Value>>) -> Self {
257 Self {
258 source: RowSource::Materialized(rows.into_iter()),
259 columns,
260 buf: Vec::new(),
261 }
262 }
263
264 fn streaming(source: Box<dyn RowSourceIter + 'a>) -> Self {
265 let columns = source.columns().to_vec();
266 Self {
267 source: RowSource::Streaming(source),
268 columns,
269 buf: Vec::new(),
270 }
271 }
272
273 #[allow(clippy::should_implement_trait)]
275 pub fn next(&mut self) -> Result<Option<Row<'_>>> {
276 let next: Option<Vec<Value>> = match &mut self.source {
277 RowSource::Materialized(iter) => iter.next(),
278 RowSource::Streaming(stream) => stream.next_row()?,
279 };
280 match next {
281 Some(values) => {
282 self.buf = values;
283 Ok(Some(Row {
284 columns: &self.columns,
285 values: &self.buf,
286 }))
287 }
288 None => Ok(None),
289 }
290 }
291
292 pub fn column_count(&self) -> usize {
293 self.columns.len()
294 }
295
296 pub fn column_names(&self) -> &[String] {
297 &self.columns
298 }
299
300 pub fn collect(mut self) -> Result<QueryResult> {
302 let mut rows = Vec::new();
303 while let Some(row) = self.next()? {
304 rows.push(row.to_vec());
305 }
306 Ok(QueryResult {
307 columns: self.columns,
308 rows,
309 })
310 }
311}
312
/// Borrowed view of a single result row.
pub struct Row<'a> {
    /// Result column names.
    columns: &'a [String],
    /// Values of this row.
    values: &'a [Value],
}
318
319impl<'a> Row<'a> {
320 pub fn get(&self, i: usize) -> Option<&Value> {
322 self.values.get(i)
323 }
324
325 pub fn get_by_name(&self, name: &str) -> Option<&Value> {
327 self.columns
328 .iter()
329 .position(|c| c == name)
330 .and_then(|i| self.values.get(i))
331 }
332
333 pub fn column_count(&self) -> usize {
334 self.values.len()
335 }
336
337 pub fn column_name(&self, i: usize) -> Option<&str> {
339 self.columns.get(i).map(|s| s.as_str())
340 }
341
342 pub fn as_slice(&self) -> &[Value] {
343 self.values
344 }
345
346 pub fn to_vec(&self) -> Vec<Value> {
347 self.values.to_vec()
348 }
349}
350
351fn compile_for_sql(conn: &Connection<'_>, sql: &str) -> Result<Compiled> {
352 let mut inner = conn.inner.borrow_mut();
353 compile_inside(&mut inner, sql)
354}
355
356fn compile_inside(
357 inner: &mut crate::connection::ConnectionInner<'_>,
358 sql: &str,
359) -> Result<Compiled> {
360 let (ast, param_count) = inner.get_or_parse(sql)?;
361 let schema_gen = inner.schema.generation();
362 let plan = executor::compile(&inner.schema, &ast);
363 if let Some(p) = &plan {
364 if let Some(entry) = inner.stmt_cache.get_mut(sql) {
365 entry.compiled = Some(Arc::clone(p));
366 }
367 }
368 let columns = derive_columns(&ast, &inner.schema);
369 Ok(Compiled {
370 ast,
371 plan,
372 schema_gen,
373 param_count,
374 columns,
375 })
376}
377
378fn derive_columns(stmt: &Statement, schema: &SchemaManager) -> Vec<String> {
379 match stmt {
380 Statement::Select(sq) => derive_select_columns(sq, schema),
381 Statement::Explain(_) => vec!["plan".into()],
382 _ => Vec::new(),
383 }
384}
385
386fn derive_select_columns(sq: &SelectQuery, schema: &SchemaManager) -> Vec<String> {
387 derive_body_columns(&sq.body, schema)
388}
389
390fn derive_body_columns(body: &QueryBody, schema: &SchemaManager) -> Vec<String> {
391 match body {
392 QueryBody::Select(sel) => derive_from_select_stmt(sel, schema),
393 QueryBody::Compound(cs) => derive_body_columns(&cs.left, schema),
394 QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => Vec::new(),
395 }
396}
397
398fn try_stream_via_plan<'db>(
399 stmt: &PreparedStatement<'_, 'db>,
400 plan: &dyn CompiledPlan,
401 params: &[Value],
402) -> Option<Box<dyn RowSourceIter + 'db>> {
403 let inner = stmt.conn.inner.borrow();
404 if inner.active_txn_is_some() {
405 return None;
406 }
407 plan.try_stream(stmt.conn.db, &inner.schema, &stmt.ast, params)
408}
409
410fn derive_from_select_stmt(sel: &SelectStmt, schema: &SchemaManager) -> Vec<String> {
411 let lower = sel.from.to_ascii_lowercase();
412 let table_columns = schema.get(&lower).map(|ts| ts.columns.as_slice());
413 let mut out = Vec::new();
414 for col in &sel.columns {
415 match col {
416 SelectColumn::AllColumns | SelectColumn::AllFromOld | SelectColumn::AllFromNew => {
417 if let Some(cols) = table_columns {
418 for c in cols {
419 out.push(c.name.clone());
420 }
421 }
422 }
423 SelectColumn::Expr { alias: Some(a), .. } => out.push(a.clone()),
424 SelectColumn::Expr { expr, alias: None } => out.push(expr_display_name(expr)),
425 }
426 }
427 out
428}