pub mod eval;
pub mod func;
use crate::btree::{create_table_root, delete_table, insert_table, TableCursor};
use crate::error::{Error, Result};
use crate::format::record::{decode_record, encode_record};
use crate::pager::{PageSource, WritePager};
use crate::schema::Schema;
use crate::sql::ast::*;
use crate::sql::{self};
use crate::value::Value;
use crate::vfs::{OpenFlags, Vfs};
use alloc::boxed::Box;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use eval::{ColumnInfo, EvalCtx, Params};
#[derive(Debug, Clone, PartialEq)]
pub struct QueryResult {
pub columns: Vec<String>,
pub rows: Vec<Vec<Value>>,
}
enum Backend {
Write(WritePager),
Read(Box<dyn PageSource>),
}
impl Backend {
fn source(&self) -> &dyn PageSource {
match self {
Backend::Write(w) => w,
Backend::Read(r) => r.as_ref(),
}
}
fn writer(&mut self) -> Result<&mut WritePager> {
match self {
Backend::Write(w) => Ok(w),
Backend::Read(_) => Err(Error::Error("database is read-only".into())),
}
}
}
pub struct Connection {
backend: Backend,
schema: Schema,
in_tx: bool,
}
impl Connection {
fn from_pager(db: WritePager) -> Result<Connection> {
let backend = Backend::Write(db);
let schema = Schema::read(backend.source())?;
Ok(Connection {
backend,
schema,
in_tx: false,
})
}
fn from_read_backend(backend: Box<dyn PageSource>) -> Result<Connection> {
let backend = Backend::Read(backend);
let schema = Schema::read(backend.source())?;
Ok(Connection {
backend,
schema,
in_tx: false,
})
}
pub fn open_vfs(vfs: &dyn Vfs, path: &str) -> Result<Connection> {
let main = vfs.open(path, OpenFlags::READ_WRITE)?;
let journal = vfs.open(&journal_path(path), OpenFlags::READ_WRITE_CREATE)?;
Connection::from_pager(WritePager::open(main, Some(journal))?)
}
pub fn open_readonly_vfs(vfs: &dyn Vfs, path: &str) -> Result<Connection> {
let main = vfs.open(path, OpenFlags::READ_ONLY)?;
let wal_path = wal_path(path);
if vfs.exists(&wal_path)? {
let mut wal = vfs.open(&wal_path, OpenFlags::READ_ONLY)?;
let reader = crate::pager::WalReader::open(main, wal.as_mut())?;
return Connection::from_read_backend(Box::new(reader));
}
Connection::from_read_backend(Box::new(WritePager::open(main, None)?))
}
pub fn create_vfs(vfs: &dyn Vfs, path: &str, page_size: u32) -> Result<Connection> {
let main = vfs.open(path, OpenFlags::READ_WRITE_CREATE)?;
let journal = vfs.open(&journal_path(path), OpenFlags::READ_WRITE_CREATE)?;
let mut db = WritePager::create(main, Some(journal), page_size)?;
db.commit()?;
Connection::from_pager(db)
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn open(path: &str) -> Result<Connection> {
Connection::open_vfs(&crate::vfs::std_file::StdVfs::new(), path)
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn open_readonly(path: &str) -> Result<Connection> {
Connection::open_readonly_vfs(&crate::vfs::std_file::StdVfs::new(), path)
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn create(path: &str) -> Result<Connection> {
Connection::create_vfs(&crate::vfs::std_file::StdVfs::new(), path, 4096)
}
pub fn open_memory() -> Result<Connection> {
let vfs = crate::vfs::memory::MemoryVfs::new();
let main = vfs.open("main", OpenFlags::READ_WRITE_CREATE)?;
let mut db = WritePager::create(main, None, 4096)?;
db.commit()?;
Connection::from_pager(db)
}
pub fn schema(&self) -> &Schema {
&self.schema
}
pub fn query(&self, sql: &str) -> Result<QueryResult> {
self.query_params(sql, &Params::default())
}
pub fn query_params(&self, sql: &str, params: &Params) -> Result<QueryResult> {
match sql::parse_one(sql)? {
Statement::Select(sel) => self.run_select(&sel, params),
Statement::Pragma(p) => self.run_pragma(&p),
_ => Err(Error::Unsupported(
"use execute() for non-SELECT statements",
)),
}
}
fn run_pragma(&self, p: &Pragma) -> Result<QueryResult> {
let name = p.name.to_ascii_lowercase();
let header = self.backend.source().header();
let single = |col: &str, v: Value| QueryResult {
columns: alloc::vec![String::from(col)],
rows: alloc::vec![alloc::vec![v]],
};
match name.as_str() {
"page_size" => Ok(single("page_size", Value::Integer(header.page_size as i64))),
"page_count" => Ok(single(
"page_count",
Value::Integer(self.backend.source().page_count() as i64),
)),
"user_version" => Ok(single(
"user_version",
Value::Integer(header.user_version as i64),
)),
"schema_version" => Ok(single(
"schema_version",
Value::Integer(header.schema_cookie as i64),
)),
"encoding" => Ok(single(
"encoding",
Value::Text(
match header.text_encoding {
crate::format::TextEncoding::Utf8 => "UTF-8",
crate::format::TextEncoding::Utf16Le => "UTF-16le",
crate::format::TextEncoding::Utf16Be => "UTF-16be",
}
.into(),
),
)),
"table_info" => self.pragma_table_info(p),
_ => Err(Error::Unsupported("this PRAGMA")),
}
}
fn pragma_table_info(&self, p: &Pragma) -> Result<QueryResult> {
let table = match &p.value {
Some(Expr::Column { column, .. }) => column.clone(),
Some(Expr::Literal(Literal::Str(s))) => s.clone(),
_ => {
return Err(Error::Error(
"PRAGMA table_info requires a table name".into(),
))
}
};
let obj = self
.schema
.table(&table)
.ok_or_else(|| Error::Error(format!("no such table: {table}")))?;
let sql = obj.sql.as_deref().unwrap_or("");
let Statement::CreateTable(ct) = sql::parse_one(sql)? else {
return Err(Error::Corrupt("schema sql is not CREATE TABLE".into()));
};
let ipk = find_integer_primary_key(&ct);
let mut rows = Vec::new();
for (i, col) in ct.columns.iter().enumerate() {
let notnull = col
.constraints
.iter()
.any(|c| matches!(c, ColumnConstraint::NotNull))
|| Some(i) == ipk;
let dflt = col.constraints.iter().find_map(|c| match c {
ColumnConstraint::Default(Expr::Literal(l)) => Some(literal_text(l)),
_ => None,
});
let pk = if Some(i) == ipk
|| col
.constraints
.iter()
.any(|c| matches!(c, ColumnConstraint::PrimaryKey { .. }))
{
1
} else {
0
};
rows.push(alloc::vec![
Value::Integer(i as i64),
Value::Text(col.name.clone()),
Value::Text(col.type_name.clone().unwrap_or_default()),
Value::Integer(notnull as i64),
dflt.map(Value::Text).unwrap_or(Value::Null),
Value::Integer(pk),
]);
}
Ok(QueryResult {
columns: ["cid", "name", "type", "notnull", "dflt_value", "pk"]
.iter()
.map(|s| String::from(*s))
.collect(),
rows,
})
}
pub fn execute(&mut self, sql: &str) -> Result<usize> {
self.execute_params(sql, &Params::default())
}
pub fn execute_params(&mut self, sql: &str, params: &Params) -> Result<usize> {
let stmt = sql::parse_one(sql)?;
match &stmt {
Statement::Begin => {
self.in_tx = true;
return Ok(0);
}
Statement::Commit => {
self.backend.writer()?.commit()?;
self.in_tx = false;
return Ok(0);
}
Statement::Rollback => {
self.backend.writer()?.rollback();
self.in_tx = false;
self.schema = Schema::read(self.backend.source())?;
return Ok(0);
}
_ => {}
}
let affected = match stmt {
Statement::CreateTable(ct) => {
self.exec_create_table(&ct, sql.trim())?;
0
}
Statement::Insert(ins) => self.exec_insert(&ins, params)?,
Statement::Delete(del) => self.exec_delete(&del, params)?,
Statement::Update(upd) => self.exec_update(&upd, params)?,
Statement::Pragma(_) => 0, Statement::Select(_) => return Err(Error::Unsupported("use query() for SELECT")),
Statement::CreateIndex(_) | Statement::Drop(_) => {
return Err(Error::Unsupported("CREATE INDEX / DROP (Phase 7/9)"))
}
Statement::Begin | Statement::Commit | Statement::Rollback => unreachable!(),
};
if !self.in_tx {
self.backend.writer()?.commit()?;
self.schema = Schema::read(self.backend.source())?;
}
Ok(affected)
}
fn exec_create_table(&mut self, ct: &CreateTable, sql_text: &str) -> Result<()> {
if self.schema.table(&ct.name).is_some() {
if ct.if_not_exists {
return Ok(());
}
return Err(Error::Error(format!("table {} already exists", ct.name)));
}
if ct.without_rowid {
return Err(Error::Unsupported("WITHOUT ROWID tables"));
}
let root = create_table_root(self.backend.writer()?)?;
let next = self.next_rowid(crate::schema::SCHEMA_ROOT_PAGE)?;
let row = encode_record(&[
Value::Text("table".into()),
Value::Text(ct.name.clone()),
Value::Text(ct.name.clone()),
Value::Integer(root as i64),
Value::Text(sql_text.into()),
]);
insert_table(
self.backend.writer()?,
crate::schema::SCHEMA_ROOT_PAGE,
next,
&row,
)?;
let cookie = self
.backend
.writer()?
.header()
.schema_cookie
.wrapping_add(1);
self.backend.writer()?.header_mut().schema_cookie = cookie;
self.schema = Schema::read(self.backend.source())?;
Ok(())
}
fn exec_insert(&mut self, ins: &Insert, params: &Params) -> Result<usize> {
let rows = match &ins.source {
InsertSource::Values(rows) => rows.clone(),
InsertSource::DefaultValues => alloc::vec![Vec::new()],
InsertSource::Select(_) => return Err(Error::Unsupported("INSERT ... SELECT")),
};
let meta = self.table_meta(&ins.table, None)?;
let n_cols = meta.columns.len();
let target: Vec<usize> = if ins.columns.is_empty() {
(0..n_cols).collect()
} else {
let mut t = Vec::new();
for name in &ins.columns {
let pos = meta
.columns
.iter()
.position(|c| c.name.eq_ignore_ascii_case(name))
.ok_or_else(|| Error::Error(format!("no such column: {name}")))?;
t.push(pos);
}
t
};
let mut next_auto = self.next_rowid(meta.root)?;
let mut affected = 0;
for row_exprs in &rows {
if !ins.columns.is_empty() && row_exprs.len() != target.len() {
return Err(Error::Error("INSERT column/value count mismatch".into()));
}
let ctx = EvalCtx::rowless(params);
let mut values: Vec<Value> = meta
.defaults
.iter()
.map(|d| match d {
Some(e) => eval::eval(e, &ctx),
None => Ok(Value::Null),
})
.collect::<Result<_>>()?;
for (i, e) in row_exprs.iter().enumerate() {
values[target[i]] = eval::eval(e, &ctx)?;
}
let rowid = match meta.ipk {
Some(ipk) if !matches!(values[ipk], Value::Null) => {
let r = eval::to_i64(&values[ipk]);
next_auto = next_auto.max(r + 1);
r
}
_ => {
let r = next_auto;
next_auto += 1;
r
}
};
if let Some(ipk) = meta.ipk {
values[ipk] = Value::Null;
}
let record = encode_record(&values);
insert_table(self.backend.writer()?, meta.root, rowid, &record)?;
affected += 1;
}
Ok(affected)
}
fn exec_delete(&mut self, del: &Delete, params: &Params) -> Result<usize> {
let meta = self.table_meta(&del.table, None)?;
let victims = self.matching_rowids(&meta, del.where_clause.as_ref(), params)?;
for rowid in &victims {
delete_table(self.backend.writer()?, meta.root, *rowid)?;
}
Ok(victims.len())
}
fn exec_update(&mut self, upd: &Update, params: &Params) -> Result<usize> {
let meta = self.table_meta(&upd.table, None)?;
let mut targets: Vec<(i64, Vec<Value>)> = Vec::new();
{
let mut cur = TableCursor::new(self.backend.source(), meta.root);
let encoding = self.backend.source().header().text_encoding;
let mut ok = cur.first()?;
while ok {
let rowid = cur.rowid()?;
let mut values = decode_record(&cur.payload()?, encoding)?;
values.resize(meta.columns.len(), Value::Null);
if let Some(ipk) = meta.ipk {
values[ipk] = Value::Integer(rowid);
}
let matches = match &upd.where_clause {
Some(p) => {
let ctx = row_ctx(&values, &meta.columns, Some(rowid), params);
eval::truth(&eval::eval(p, &ctx)?) == Some(true)
}
None => true,
};
if matches {
targets.push((rowid, values));
}
ok = cur.next()?;
}
}
let mut affected = 0;
for (rowid, mut values) in targets {
for (col, expr) in &upd.assignments {
let pos = meta
.columns
.iter()
.position(|c| c.name.eq_ignore_ascii_case(col))
.ok_or_else(|| Error::Error(format!("no such column: {col}")))?;
let ctx = row_ctx(&values, &meta.columns, Some(rowid), params);
values[pos] = eval::eval(expr, &ctx)?;
}
let new_rowid = match meta.ipk {
Some(ipk) => eval::to_i64(&values[ipk]),
None => rowid,
};
if let Some(ipk) = meta.ipk {
values[ipk] = Value::Null;
}
let record = encode_record(&values);
delete_table(self.backend.writer()?, meta.root, rowid)?;
insert_table(self.backend.writer()?, meta.root, new_rowid, &record)?;
affected += 1;
}
Ok(affected)
}
fn matching_rowids(
&self,
meta: &TableMeta,
pred: Option<&Expr>,
params: &Params,
) -> Result<Vec<i64>> {
let mut out = Vec::new();
let mut cur = TableCursor::new(self.backend.source(), meta.root);
let encoding = self.backend.source().header().text_encoding;
let mut ok = cur.first()?;
while ok {
let rowid = cur.rowid()?;
let mut values = decode_record(&cur.payload()?, encoding)?;
values.resize(meta.columns.len(), Value::Null);
if let Some(ipk) = meta.ipk {
values[ipk] = Value::Integer(rowid);
}
let keep = match pred {
Some(p) => {
let ctx = row_ctx(&values, &meta.columns, Some(rowid), params);
eval::truth(&eval::eval(p, &ctx)?) == Some(true)
}
None => true,
};
if keep {
out.push(rowid);
}
ok = cur.next()?;
}
Ok(out)
}
fn next_rowid(&self, root: u32) -> Result<i64> {
let mut cur = TableCursor::new(self.backend.source(), root);
if cur.last()? {
Ok(cur.rowid()? + 1)
} else {
Ok(1)
}
}
fn run_select(&self, sel: &Select, params: &Params) -> Result<QueryResult> {
let (columns, input_rows) = self.scan_source(sel, params)?;
let mut rows: Vec<InputRow> = Vec::new();
for r in input_rows {
if let Some(pred) = &sel.where_clause {
let ctx = r.ctx(&columns, params);
if eval::truth(&eval::eval(pred, &ctx)?) != Some(true) {
continue;
}
}
rows.push(r);
}
let aggregated = !sel.group_by.is_empty() || self.has_aggregate(sel);
let (out_labels, mut out) = if aggregated {
self.eval_aggregated(sel, &columns, rows, params)?
} else {
self.eval_simple(sel, &columns, rows, params)?
};
if sel.distinct {
let mut seen: Vec<Vec<Value>> = Vec::new();
out.retain(|row| {
if seen.iter().any(|s| rows_equal(s, &row.values)) {
false
} else {
seen.push(row.values.clone());
true
}
});
}
if !sel.order_by.is_empty() {
out.sort_by(|a, b| {
for (i, term) in sel.order_by.iter().enumerate() {
let ord = eval::compare(&a.sort_keys[i], &b.sort_keys[i]);
let ord = if term.descending { ord.reverse() } else { ord };
if ord != core::cmp::Ordering::Equal {
return ord;
}
}
core::cmp::Ordering::Equal
});
}
let offset = match &sel.offset {
Some(e) => eval::to_i64(&eval::eval(e, &EvalCtx::rowless(params))?).max(0) as usize,
None => 0,
};
let limit = match &sel.limit {
Some(e) => {
Some(eval::to_i64(&eval::eval(e, &EvalCtx::rowless(params))?).max(0) as usize)
}
None => None,
};
let mut final_rows: Vec<Vec<Value>> =
out.into_iter().skip(offset).map(|r| r.values).collect();
if let Some(n) = limit {
final_rows.truncate(n);
}
Ok(QueryResult {
columns: out_labels,
rows: final_rows,
})
}
fn scan_source(
&self,
sel: &Select,
params: &Params,
) -> Result<(Vec<ColumnInfo>, Vec<InputRow>)> {
let Some(from) = &sel.from else {
return Ok((
Vec::new(),
alloc::vec![InputRow {
values: Vec::new(),
rowid: None
}],
));
};
let first_meta = self.table_meta(&from.first.name, from.first.alias.as_deref())?;
let first_rows = self.scan_table(&first_meta)?;
if from.joins.is_empty() {
let input_rows = first_rows
.into_iter()
.map(|(rowid, values)| InputRow {
values,
rowid: Some(rowid),
})
.collect();
return Ok((first_meta.columns, input_rows));
}
let mut columns = first_meta.columns.clone();
let mut rows: Vec<Vec<Value>> = first_rows.into_iter().map(|(_, v)| v).collect();
for join in &from.joins {
let jmeta = self.table_meta(&join.table.name, join.table.alias.as_deref())?;
let jrows: Vec<Vec<Value>> = self
.scan_table(&jmeta)?
.into_iter()
.map(|(_, v)| v)
.collect();
let mut new_columns = columns.clone();
new_columns.extend(jmeta.columns.iter().cloned());
let n_jcols = jmeta.columns.len();
let mut joined: Vec<Vec<Value>> = Vec::new();
for left in &rows {
let mut matched = false;
for right in &jrows {
let mut combined = left.clone();
combined.extend(right.iter().cloned());
let keep = match &join.on {
Some(on) => {
let ctx = row_ctx(&combined, &new_columns, None, params);
eval::truth(&eval::eval(on, &ctx)?) == Some(true)
}
None => true, };
if keep {
joined.push(combined);
matched = true;
}
}
if !matched && join.kind == JoinKind::Left {
let mut combined = left.clone();
combined.extend(core::iter::repeat_n(Value::Null, n_jcols));
joined.push(combined);
}
}
columns = new_columns;
rows = joined;
}
let input_rows = rows
.into_iter()
.map(|values| InputRow {
values,
rowid: None, })
.collect();
Ok((columns, input_rows))
}
fn scan_table(&self, meta: &TableMeta) -> Result<Vec<(i64, Vec<Value>)>> {
let encoding = self.backend.source().header().text_encoding;
let mut rows = Vec::new();
let mut cur = TableCursor::new(self.backend.source(), meta.root);
let mut ok = cur.first()?;
while ok {
let rowid = cur.rowid()?;
let mut values = decode_record(&cur.payload()?, encoding)?;
values.resize(meta.columns.len(), Value::Null);
if let Some(ipk) = meta.ipk {
values[ipk] = Value::Integer(rowid);
}
rows.push((rowid, values));
ok = cur.next()?;
}
Ok(rows)
}
fn eval_simple(
&self,
sel: &Select,
columns: &[ColumnInfo],
rows: Vec<InputRow>,
params: &Params,
) -> Result<(Vec<String>, Vec<OutRow>)> {
let labels = self.output_labels(sel, columns);
let mut out = Vec::with_capacity(rows.len());
for r in &rows {
let ctx = r.ctx(columns, params);
let mut values = Vec::new();
for col in &sel.columns {
project_column(col, columns, &ctx, &mut values)?;
}
let mut sort_keys = Vec::new();
for term in &sel.order_by {
match resolve_order_index(&term.expr, &labels, values.len()) {
Some(idx) => sort_keys.push(values[idx].clone()),
None => sort_keys.push(eval::eval(&term.expr, &ctx)?),
}
}
out.push(OutRow { values, sort_keys });
}
Ok((labels, out))
}
fn eval_aggregated(
&self,
sel: &Select,
columns: &[ColumnInfo],
rows: Vec<InputRow>,
params: &Params,
) -> Result<(Vec<String>, Vec<OutRow>)> {
if sel
.columns
.iter()
.any(|c| matches!(c, ResultColumn::Wildcard | ResultColumn::TableWildcard(_)))
{
return Err(Error::Unsupported("'*' with aggregation"));
}
let mut group_keys: Vec<Vec<Value>> = Vec::new();
let mut groups: Vec<Vec<usize>> = Vec::new();
for (i, r) in rows.iter().enumerate() {
let ctx = r.ctx(columns, params);
let mut key = Vec::new();
for g in &sel.group_by {
key.push(eval::eval(g, &ctx)?);
}
match group_keys.iter().position(|k| rows_equal(k, &key)) {
Some(idx) => groups[idx].push(i),
None => {
group_keys.push(key);
groups.push(alloc::vec![i]);
}
}
}
if sel.group_by.is_empty() {
groups = alloc::vec![(0..rows.len()).collect()];
}
let labels = self.output_labels(sel, columns);
let mut out = Vec::new();
for group in &groups {
let repr = group.first().map(|&i| &rows[i]);
let empty = InputRow {
values: alloc::vec![Value::Null; columns.len()],
rowid: None,
};
let repr_ctx = repr.unwrap_or(&empty).ctx(columns, params);
let mut values = Vec::new();
for col in &sel.columns {
let ResultColumn::Expr { expr, .. } = col else {
unreachable!("wildcards rejected above")
};
let substituted =
self.substitute_aggregates(expr, columns, &rows, group, params)?;
values.push(eval::eval(&substituted, &repr_ctx)?);
}
if let Some(having) = &sel.having {
let h = self.substitute_aggregates(having, columns, &rows, group, params)?;
if eval::truth(&eval::eval(&h, &repr_ctx)?) != Some(true) {
continue;
}
}
let mut sort_keys = Vec::new();
for term in &sel.order_by {
if let Some(idx) = resolve_order_index(&term.expr, &labels, values.len()) {
sort_keys.push(values[idx].clone());
} else {
let s =
self.substitute_aggregates(&term.expr, columns, &rows, group, params)?;
sort_keys.push(eval::eval(&s, &repr_ctx)?);
}
}
out.push(OutRow { values, sort_keys });
}
Ok((labels, out))
}
fn substitute_aggregates(
&self,
expr: &Expr,
columns: &[ColumnInfo],
rows: &[InputRow],
group: &[usize],
params: &Params,
) -> Result<Expr> {
Ok(match expr {
Expr::Function {
name,
distinct,
args,
star,
} if func::is_aggregate(name) => {
let v = self.compute_aggregate(
name, *distinct, args, *star, columns, rows, group, params,
)?;
Expr::Literal(value_to_literal(v))
}
Expr::Function {
name,
distinct,
args,
star,
} => {
let mut new_args = Vec::with_capacity(args.len());
for a in args {
new_args.push(self.substitute_aggregates(a, columns, rows, group, params)?);
}
Expr::Function {
name: name.clone(),
distinct: *distinct,
args: new_args,
star: *star,
}
}
Expr::Binary { op, left, right } => Expr::Binary {
op: *op,
left: Box::new(self.substitute_aggregates(left, columns, rows, group, params)?),
right: Box::new(self.substitute_aggregates(right, columns, rows, group, params)?),
},
Expr::Unary { op, expr } => Expr::Unary {
op: *op,
expr: Box::new(self.substitute_aggregates(expr, columns, rows, group, params)?),
},
Expr::Paren(e) => Expr::Paren(Box::new(
self.substitute_aggregates(e, columns, rows, group, params)?,
)),
other => other.clone(),
})
}
#[allow(clippy::too_many_arguments)]
fn compute_aggregate(
&self,
name: &str,
distinct: bool,
args: &[Expr],
star: bool,
columns: &[ColumnInfo],
rows: &[InputRow],
group: &[usize],
params: &Params,
) -> Result<Value> {
let lname = name.to_ascii_lowercase();
let mut vals: Vec<Value> = Vec::new();
let mut count_rows = 0usize; for &i in group {
count_rows += 1;
if star {
continue;
}
let ctx = rows[i].ctx(columns, params);
let v = eval::eval(&args[0], &ctx)?;
if !matches!(v, Value::Null) {
vals.push(v);
}
}
if distinct {
dedup_values(&mut vals);
}
Ok(match lname.as_str() {
"count" => {
if star {
Value::Integer(count_rows as i64)
} else {
Value::Integer(vals.len() as i64)
}
}
"sum" => {
if vals.is_empty() {
Value::Null
} else if vals.iter().all(|v| matches!(v, Value::Integer(_))) {
let mut acc: i64 = 0;
let mut overflow = false;
for v in &vals {
if let Value::Integer(i) = v {
match acc.checked_add(*i) {
Some(s) => acc = s,
None => {
overflow = true;
break;
}
}
}
}
if overflow {
Value::Real(vals.iter().map(eval::to_f64).sum())
} else {
Value::Integer(acc)
}
} else {
Value::Real(vals.iter().map(eval::to_f64).sum())
}
}
"total" => Value::Real(vals.iter().map(eval::to_f64).sum()),
"avg" => {
if vals.is_empty() {
Value::Null
} else {
let sum: f64 = vals.iter().map(eval::to_f64).sum();
Value::Real(sum / vals.len() as f64)
}
}
"min" => vals
.into_iter()
.reduce(|a, b| {
if eval::compare(&b, &a) == core::cmp::Ordering::Less {
b
} else {
a
}
})
.unwrap_or(Value::Null),
"max" => vals
.into_iter()
.reduce(|a, b| {
if eval::compare(&b, &a) == core::cmp::Ordering::Greater {
b
} else {
a
}
})
.unwrap_or(Value::Null),
"group_concat" => {
if vals.is_empty() {
Value::Null
} else {
let sep = if args.len() >= 2 {
let ctx = EvalCtx::rowless(params);
eval::to_text(&eval::eval(&args[1], &ctx)?)
} else {
",".to_string()
};
let parts: Vec<String> = vals.iter().map(eval::to_text).collect();
Value::Text(parts.join(&sep))
}
}
_ => return Err(Error::Unsupported("aggregate function")),
})
}
fn has_aggregate(&self, sel: &Select) -> bool {
sel.columns.iter().any(|c| match c {
ResultColumn::Expr { expr, .. } => contains_aggregate(expr),
_ => false,
}) || sel.having.as_ref().is_some_and(contains_aggregate)
}
fn output_labels(&self, sel: &Select, columns: &[ColumnInfo]) -> Vec<String> {
let mut labels = Vec::new();
for col in &sel.columns {
match col {
ResultColumn::Wildcard | ResultColumn::TableWildcard(_) => {
for c in columns {
labels.push(c.name.clone());
}
}
ResultColumn::Expr { expr, alias } => {
labels.push(alias.clone().unwrap_or_else(|| expr_label(expr)));
}
}
}
labels
}
fn table_meta(&self, name: &str, alias: Option<&str>) -> Result<TableMeta> {
let obj = self
.schema
.table(name)
.ok_or_else(|| Error::Error(alloc::format!("no such table: {name}")))?;
let sql = obj
.sql
.as_ref()
.ok_or_else(|| Error::Corrupt("table has no CREATE statement".into()))?;
let Statement::CreateTable(ct) = sql::parse_one(sql)? else {
return Err(Error::Corrupt("schema sql is not CREATE TABLE".into()));
};
if ct.without_rowid {
return Err(Error::Unsupported("WITHOUT ROWID tables"));
}
let table_label = alias.unwrap_or(name).to_string();
let columns: Vec<ColumnInfo> = ct
.columns
.iter()
.map(|c| ColumnInfo {
name: c.name.clone(),
table: table_label.clone(),
})
.collect();
let defaults: Vec<Option<Expr>> = ct
.columns
.iter()
.map(|c| {
c.constraints.iter().find_map(|k| match k {
ColumnConstraint::Default(e) => Some(e.clone()),
_ => None,
})
})
.collect();
let ipk = find_integer_primary_key(&ct);
Ok(TableMeta {
root: obj.rootpage,
columns,
defaults,
ipk,
})
}
}
struct TableMeta {
root: u32,
columns: Vec<ColumnInfo>,
defaults: Vec<Option<Expr>>,
ipk: Option<usize>,
}
struct InputRow {
values: Vec<Value>,
rowid: Option<i64>,
}
impl InputRow {
fn ctx<'a>(&'a self, columns: &'a [ColumnInfo], params: &'a Params) -> EvalCtx<'a> {
EvalCtx {
row: &self.values,
columns,
rowid: self.rowid,
params,
anon_counter: core::cell::Cell::new(0),
}
}
}
fn row_ctx<'a>(
values: &'a [Value],
columns: &'a [ColumnInfo],
rowid: Option<i64>,
params: &'a Params,
) -> EvalCtx<'a> {
EvalCtx {
row: values,
columns,
rowid,
params,
anon_counter: core::cell::Cell::new(0),
}
}
fn journal_path(path: &str) -> String {
let mut p = String::from(path);
p.push_str("-journal");
p
}
fn wal_path(path: &str) -> String {
let mut p = String::from(path);
p.push_str("-wal");
p
}
struct OutRow {
values: Vec<Value>,
sort_keys: Vec<Value>,
}
fn project_column(
col: &ResultColumn,
columns: &[ColumnInfo],
ctx: &EvalCtx,
out: &mut Vec<Value>,
) -> Result<()> {
match col {
ResultColumn::Wildcard => {
for v in ctx.row {
out.push(v.clone());
}
}
ResultColumn::TableWildcard(table) => {
for (i, c) in columns.iter().enumerate() {
if c.table.eq_ignore_ascii_case(table) {
out.push(ctx.row[i].clone());
}
}
}
ResultColumn::Expr { expr, .. } => {
out.push(eval::eval(expr, ctx)?);
}
}
Ok(())
}
fn resolve_order_index(expr: &Expr, labels: &[String], ncols: usize) -> Option<usize> {
match expr {
Expr::Literal(Literal::Integer(n)) => {
let idx = (*n as usize).checked_sub(1)?;
(idx < ncols).then_some(idx)
}
Expr::Column {
table: None,
column,
} => labels.iter().position(|l| l.eq_ignore_ascii_case(column)),
_ => None,
}
}
fn contains_aggregate(expr: &Expr) -> bool {
match expr {
Expr::Function { name, args, .. } => {
func::is_aggregate(name) || args.iter().any(contains_aggregate)
}
Expr::Binary { left, right, .. } => contains_aggregate(left) || contains_aggregate(right),
Expr::Unary { expr, .. } | Expr::Paren(expr) => contains_aggregate(expr),
Expr::IsNull { expr, .. } => contains_aggregate(expr),
Expr::Between {
expr, low, high, ..
} => contains_aggregate(expr) || contains_aggregate(low) || contains_aggregate(high),
Expr::InList { expr, list, .. } => {
contains_aggregate(expr) || list.iter().any(contains_aggregate)
}
Expr::Case {
operand,
when_then,
else_result,
} => {
operand.as_deref().is_some_and(contains_aggregate)
|| when_then
.iter()
.any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
|| else_result.as_deref().is_some_and(contains_aggregate)
}
Expr::Cast { expr, .. } => contains_aggregate(expr),
_ => false,
}
}
fn rows_equal(a: &[Value], b: &[Value]) -> bool {
a.len() == b.len()
&& a.iter()
.zip(b)
.all(|(x, y)| eval::compare(x, y) == core::cmp::Ordering::Equal)
}
fn dedup_values(vals: &mut Vec<Value>) {
let mut seen: Vec<Value> = Vec::new();
vals.retain(|v| {
if seen
.iter()
.any(|s| eval::compare(s, v) == core::cmp::Ordering::Equal)
{
false
} else {
seen.push(v.clone());
true
}
});
}
fn value_to_literal(v: Value) -> Literal {
match v {
Value::Null => Literal::Null,
Value::Integer(i) => Literal::Integer(i),
Value::Real(r) => Literal::Real(r),
Value::Text(s) => Literal::Str(s),
Value::Blob(b) => Literal::Blob(b),
}
}
fn literal_text(l: &Literal) -> String {
eval::to_text(&match l {
Literal::Null => Value::Null,
Literal::Integer(i) => Value::Integer(*i),
Literal::Real(r) => Value::Real(*r),
Literal::Str(s) => Value::Text(s.clone()),
Literal::Blob(b) => Value::Blob(b.clone()),
Literal::Boolean(b) => Value::Integer(*b as i64),
})
}
fn expr_label(expr: &Expr) -> String {
match expr {
Expr::Column { column, .. } => column.clone(),
Expr::Literal(Literal::Integer(i)) => i.to_string(),
Expr::Literal(Literal::Str(s)) => s.clone(),
Expr::Function { name, .. } => name.clone(),
Expr::Paren(e) => expr_label(e),
_ => "expr".to_string(),
}
}
fn find_integer_primary_key(ct: &CreateTable) -> Option<usize> {
for (i, c) in ct.columns.iter().enumerate() {
let is_integer = c
.type_name
.as_deref()
.is_some_and(|t| t.eq_ignore_ascii_case("integer"));
let is_pk = c
.constraints
.iter()
.any(|k| matches!(k, ColumnConstraint::PrimaryKey { .. }));
if is_integer && is_pk {
return Some(i);
}
}
for tc in &ct.constraints {
if let TableConstraint::PrimaryKey(cols) = tc {
if cols.len() == 1 {
if let Some(i) = ct.columns.iter().position(|c| c.name == cols[0]) {
if ct.columns[i]
.type_name
.as_deref()
.is_some_and(|t| t.eq_ignore_ascii_case("integer"))
{
return Some(i);
}
}
}
}
}
None
}