use sqlite3_ext::{vtab::*, *};
// Column indexes into the schema declared by `connect`
// (`value, start HIDDEN, stop HIDDEN, step HIDDEN`); `value` is column 0.
// `best_index` derives a bit position from `column - COLUMN_START`, so the
// three hidden columns must stay consecutive and in this order.
/// Column index of the hidden `start` column.
const COLUMN_START: i32 = 1;
/// Column index of the hidden `stop` column.
const COLUMN_STOP: i32 = 2;
/// Column index of the hidden `step` column.
const COLUMN_STEP: i32 = 3;
/// Stateless virtual table type registered under the name `generate_series`.
///
/// All per-query state lives in the cursor; this struct carries no fields.
#[sqlite3_ext_vtab(EponymousModule)]
struct GenerateSeries {}
impl VTab<'_> for GenerateSeries {
    type Aux = ();
    type Cursor = Cursor;

    /// Declares the table schema: one visible `value` column followed by the
    /// hidden `start`, `stop` and `step` argument columns.
    fn connect(db: &VTabConnection, _aux: &Self::Aux, _args: &[&str]) -> Result<(String, Self)> {
        db.set_risk_level(RiskLevel::Innocuous);
        Ok((
            "CREATE TABLE x ( value, start HIDDEN, stop HIDDEN, step HIDDEN )".to_owned(),
            GenerateSeries {},
        ))
    }

    /// Builds the query plan bitmask that is handed to `Cursor::filter`.
    ///
    /// Plan bits:
    /// * `1`  - an equality constraint on `start` is passed to `filter`
    /// * `2`  - an equality constraint on `stop` is passed to `filter`
    /// * `4`  - an equality constraint on `step` is passed to `filter`
    /// * `8`  - rows are produced in descending order
    /// * `16` - rows are produced in ascending order
    ///
    /// The selected constraints are assigned `filter` arguments in
    /// start, stop, step order.
    ///
    /// # Errors
    ///
    /// * `Error::Module` when there is no equality constraint on `start`.
    /// * `SQLITE_CONSTRAINT` when `start`, `stop` or `step` only has unusable
    ///   constraints, which makes this plan unusable.
    fn best_index(&self, index_info: &mut IndexInfo) -> Result<()> {
        let mut query_plan: i32 = 0;
        let mut unusable_mask: i32 = 0;
        let mut has_start = false;
        let mut arg_index: [Option<usize>; 3] = [None, None, None];

        for (i, constraint) in index_info.constraints().enumerate() {
            let column = constraint.column();
            if column < COLUMN_START {
                continue;
            }
            let bit = (column - COLUMN_START) as usize;
            assert!(bit <= 2);
            let is_eq = constraint.op() == ConstraintOp::Eq;
            // Only an equality constraint actually supplies the start value.
            // Something like `start > 5` must not count as "start was given",
            // otherwise filter() would run without a lower bound argument.
            if column == COLUMN_START && is_eq {
                has_start = true;
            }
            if !constraint.usable() {
                unusable_mask |= 1 << bit;
            } else if is_eq {
                query_plan |= 1 << bit;
                arg_index[bit] = Some(i);
            }
        }

        // Hand the chosen constraints to filter() in start, stop, step order.
        let mut constraints: Vec<_> = index_info.constraints().collect();
        let mut next_idx = 0;
        for j in arg_index.iter().filter_map(|slot| *slot) {
            constraints[j].set_argv_index(Some(next_idx));
            constraints[j].set_omit(true);
            next_idx += 1;
        }

        if !has_start {
            return Err(Error::Module(
                "first argument to \"generate_series()\" missing or unusable".to_owned(),
            ));
        }
        // start, stop and step are inputs: an unusable constraint on any of
        // them that is not also covered by a usable equality rules this plan out.
        if (unusable_mask & !query_plan) != 0 {
            return Err(SQLITE_CONSTRAINT);
        }

        if (query_plan & 3) == 3 {
            // Both bounds are known. The cost is 1 when a step is supplied
            // as well and 2 otherwise.
            let cost = if (query_plan & 4) != 0 { 1.0 } else { 2.0 };
            index_info.set_estimated_cost(cost);
            index_info.set_estimated_rows(1000);
            if let Some(order) = index_info.order_by().next() {
                if order.column() == 0 {
                    query_plan |= if order.desc() { 8 } else { 16 };
                    index_info.set_order_by_consumed(true);
                }
            }
        } else {
            // A bound is missing, so the series may be enormous; report a huge
            // row estimate for this plan.
            index_info.set_estimated_rows(i64::MAX / 2);
        }
        index_info.set_index_num(query_plan);
        Ok(())
    }

    /// Creates a cursor with all fields at their default values; `filter`
    /// initializes it before rows are read.
    fn open(&self) -> Result<Self::Cursor> {
        Ok(Cursor::default())
    }
}
/// Iteration state for one scan of the series.
#[derive(Default, Debug)]
struct Cursor {
/// True when rows are produced from `max_value` down towards `min_value`.
desc: bool,
/// Row counter: set to 1 by `filter` and incremented by every `next`.
rowid: i64,
/// The value of the current row.
value: i64,
/// Lower bound of the series (the `start` argument).
min_value: i64,
/// Upper bound of the series (the `stop` argument, or `i64::MAX` when no stop was supplied).
max_value: i64,
/// Distance between consecutive values; `filter` replaces 0 with 1 and negates negative steps.
step: i64,
}
impl Cursor {
    /// Puts the cursor into a state for which `eof()` is immediately true,
    /// regardless of what an earlier scan left behind in the fields.
    fn mark_exhausted(&mut self) {
        self.desc = false;
        self.min_value = 1;
        self.max_value = 0;
        self.value = 1;
    }
}

impl VTabCursor for Cursor {
    /// Starts a new scan.
    ///
    /// `query_plan` is the bitmask produced by `best_index`: bits 1, 2 and 4
    /// say which of start, stop and step are present in `args` (in that
    /// order), bit 8 requests descending output and bit 16 ascending output.
    ///
    /// Every field is re-initialized here because a cursor can be filtered
    /// more than once and must not leak state from a previous scan.
    /// If any argument is NULL the scan produces no rows.
    fn filter(
        &mut self,
        query_plan: i32,
        _: Option<&str>,
        args: &mut [&mut ValueRef],
    ) -> Result<()> {
        let mut query_plan = query_plan;
        self.rowid = 1;

        if args.iter().any(|a| a.is_null()) {
            // Reset the whole range instead of only the upper bound: stale
            // `value`/`desc` from an earlier scan could otherwise yield rows.
            self.step = 1;
            self.mark_exhausted();
            return Ok(());
        }

        let mut args = args.into_iter();
        self.min_value = if query_plan & 1 != 0 {
            (*args.next().unwrap()).get_i64()
        } else {
            // No start supplied: do not reuse the bound of a previous scan.
            0
        };
        self.max_value = if query_plan & 2 != 0 {
            (*args.next().unwrap()).get_i64()
        } else {
            i64::MAX
        };
        self.step = if query_plan & 4 != 0 {
            match (*args.next().unwrap()).get_i64() {
                0 => 1,
                step if step < 0 => {
                    // A negative step means descending output unless the
                    // query explicitly asked for ascending order.
                    if query_plan & 16 == 0 {
                        query_plan |= 8;
                    }
                    // `-i64::MIN` is not representable; clamp instead of overflowing.
                    step.checked_neg().unwrap_or(i64::MAX)
                }
                step => step,
            }
        } else {
            1
        };

        self.desc = query_plan & 8 != 0;
        self.value = if self.desc {
            // Start from the largest value that is reachable from `min_value`
            // in whole steps. The span is computed in i128 because
            // `max_value - min_value` can exceed the i64 range. `step` is
            // always >= 1 here, so the remainder is well defined and its
            // magnitude is below `step`, which makes the cast lossless.
            let span = i128::from(self.max_value) - i128::from(self.min_value);
            let overshoot = (span % i128::from(self.step)) as i64;
            self.max_value - overshoot
        } else {
            self.min_value
        };
        Ok(())
    }

    /// Advances to the following row.
    ///
    /// If the next value does not fit in an `i64` the series is finished;
    /// the cursor is marked exhausted rather than overflowing.
    fn next(&mut self) -> Result<()> {
        let advanced = if self.desc {
            self.value.checked_sub(self.step)
        } else {
            self.value.checked_add(self.step)
        };
        match advanced {
            Some(value) => {
                self.value = value;
                self.rowid += 1;
            }
            None => self.mark_exhausted(),
        }
        Ok(())
    }

    /// Returns true once the current value has moved past the bound in the
    /// direction of iteration.
    fn eof(&mut self) -> bool {
        if self.desc {
            self.value < self.min_value
        } else {
            self.value > self.max_value
        }
    }

    /// Reports the requested column of the current row: the hidden
    /// `start`/`stop`/`step` columns echo the scan parameters, any other
    /// index yields the current series value.
    fn column(&mut self, idx: usize, c: &ColumnContext) -> Result<()> {
        let ret = match idx as i32 {
            COLUMN_START => self.min_value,
            COLUMN_STOP => self.max_value,
            COLUMN_STEP => self.step,
            _ => self.value,
        };
        c.set_result(ret)
    }

    /// Returns the 1-based position of the current row within this scan.
    fn rowid(&mut self) -> Result<i64> {
        Ok(self.rowid)
    }
}
/// Extension entry point: registers `GenerateSeries` on `db` under the
/// module name `generate_series`, with `()` as the auxiliary data.
///
/// # Errors
///
/// Propagates any error returned by `create_module`.
#[sqlite3_ext_main]
fn init(db: &Connection) -> Result<()> {
db.create_module("generate_series", GenerateSeries::module(), ())?;
Ok(())
}
// Tests that run the module against an in-memory database. They are only
// compiled for test builds with the "static" feature enabled.
#[cfg(all(test, feature = "static"))]
mod test {
use super::*;
// Opens an in-memory database and registers the module on it via `init`.
fn setup() -> Result<Database> {
let conn = Database::open(":memory:")?;
init(&conn)?;
Ok(conn)
}
// Covers the two basic call shapes: start/stop/step, and start only with a LIMIT.
#[test]
fn example() -> Result<()> {
let conn = setup()?;
let results: Vec<i64> = conn
.prepare("SELECT value FROM generate_series(5, 100, 5)")?
.query(())?
.map(|row| Ok(row[0].get_i64()))
.collect()?;
assert_eq!(
results,
vec![5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100]
);
let results: Vec<i64> = conn
.prepare("SELECT value FROM generate_series(20) LIMIT 10")?
.query(())?
.map(|row| Ok(row[0].get_i64()))
.collect()?;
assert_eq!(results, vec![20, 21, 22, 23, 24, 25, 26, 27, 28, 29]);
Ok(())
}
// Generates one #[test] function per case. Each test runs `sql` on a fresh
// database, reads the first column of every row as an i64, and compares the
// collected result (including the Ok/Err wrapper) with `expected`.
macro_rules! case {
($test_name:ident { sql: $sql:expr, expected: $expected:expr, }) => {
#[test]
fn $test_name() -> Result<()> {
let conn = setup()?;
let results = conn
.prepare($sql)?
.query(())?
.map(|row| Ok(row[0].get_i64()))
.collect();
assert_eq!(results, $expected);
Ok(())
}
};
}
// stop below start yields no rows.
case!(max_lt_min {
sql: "SELECT value FROM generate_series(10, 5)",
expected: Ok(vec![]),
});
// start equal to stop yields exactly one row.
case!(max_eq_min {
sql: "SELECT value FROM generate_series(17, 17)",
expected: Ok(vec![17]),
});
// A negative step walks downwards, starting from the highest value reachable from start.
case!(negative_step {
sql: "SELECT value FROM generate_series(5, 10, -2)",
expected: Ok(vec![9, 7, 5]),
});
// ORDER BY value DESC returns the rows in descending order.
case!(order_desc {
sql: "SELECT value FROM generate_series(5, 10) ORDER BY value DESC",
expected: Ok(vec![10, 9, 8, 7, 6, 5]),
});
// An explicit ascending ORDER BY wins over the direction implied by a negative step.
case!(order_asc {
sql: "SELECT value FROM generate_series(5, 10, -2) ORDER BY value",
expected: Ok(vec![5, 7, 9]),
});
// A NULL argument produces an empty result.
case!(null_arg {
sql: "SELECT value FROM generate_series(5, 10, NULL)",
expected: Ok(vec![]),
});
// With only a start value the series is cut off by LIMIT.
case!(only_limit {
sql: "SELECT value FROM generate_series(1) LIMIT 5",
expected: Ok(vec![1, 2, 3, 4, 5]),
});
}