use duckdb::{Connection, types::ToSql};
use oxide_sql_core::builder::value::SqlValue;
use oxide_sql_core::builder::{Delete, Insert, Select, Update, col};
use oxide_sql_core::migrations::dialect::DuckDbDialect;
use oxide_sql_core::migrations::dialect::MigrationDialect;
use oxide_sql_core::migrations::{CreateTableBuilder, integer, varchar};
use oxide_sql_derive::Table;
// Row model for the `items` table exercised by `test_create_table_and_insert`.
// Only the `Table`-derive output is used by the tests: `ItemTable` and column
// accessors such as `Item::name()` (presumably generated by the derive). The
// struct is never constructed in this file, so its fields would otherwise
// trigger `dead_code` warnings.
#[allow(dead_code)]
#[derive(Debug, Clone, Table)]
#[table(name = "items")]
pub struct Item {
    // Primary key. The test creates this column with autoincrement, so inserts omit it.
    #[column(primary_key)]
    pub id: i64,
    pub name: String,
}
// Row model for the `counters` table exercised by `test_autoincrement_sequence`.
// The struct is never constructed in this file; only the `Table`-derive output
// (`CounterTable`, `Counter::label()`, `Counter::id()`) is used, hence the
// `dead_code` allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, Table)]
#[table(name = "counters")]
pub struct Counter {
    // Primary key. The test creates it with autoincrement to check the id sequence.
    #[column(primary_key)]
    pub id: i64,
    pub label: String,
}
// Row model for the `excluded_domains` table in `test_consumer_scenario`.
// The struct is never constructed in this file; only the `Table`-derive output
// (`ExcludedDomainTable` and its column accessors) is used, hence the
// `dead_code` allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, Table)]
#[table(name = "excluded_domains")]
pub struct ExcludedDomain {
    // Primary key. The test creates it with autoincrement, so inserts omit it.
    #[column(primary_key)]
    pub id: i64,
    // The test creates this column as NOT NULL UNIQUE.
    pub domain: String,
}
// Row model for the `excluded_ips` table in `test_consumer_scenario`.
// The struct is never constructed in this file; only the `Table`-derive output
// (`ExcludedIpTable` and its column accessors) is used, hence the `dead_code`
// allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, Table)]
#[table(name = "excluded_ips")]
pub struct ExcludedIp {
    // Primary key. The test creates it with autoincrement, so inserts omit it.
    #[column(primary_key)]
    pub id: i64,
    // CIDR text such as "10.0.0.0/8"; created NOT NULL UNIQUE by the test.
    pub cidr: String,
}
// Row model for the `products` table exercised by `test_select_with_where`.
// The struct is never constructed in this file; only the `Table`-derive output
// (`ProductTable`, `Product::name()`, `Product::price()`) is used, hence the
// `dead_code` allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, Table)]
#[table(name = "products")]
pub struct Product {
    // Primary key. The test creates it with autoincrement, so inserts omit it.
    #[column(primary_key)]
    pub id: i64,
    pub name: String,
    // Integer price that the WHERE / ORDER BY test filters and sorts on.
    pub price: i64,
}
// Row model for the `users` table exercised by `test_update_and_delete`.
// The struct is never constructed in this file; only the `Table`-derive output
// (`UserTable`, `User::id()`, `User::name()`) is used, hence the `dead_code`
// allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, Table)]
#[table(name = "users")]
pub struct User {
    // Primary key. The test creates it with autoincrement and deletes by it.
    #[column(primary_key)]
    pub id: i64,
    pub name: String,
}
/// Converts the builder's [`SqlValue`] parameters into owned, boxed DuckDB
/// bind values, one per input and in the same order.
///
/// Text and blob payloads are cloned so the boxes own their data and do not
/// borrow from `values`.
fn to_duckdb_params(values: &[SqlValue]) -> Vec<Box<dyn ToSql>> {
    // Maps a single builder value onto the DuckDB type that binds it.
    fn convert(value: &SqlValue) -> Box<dyn ToSql> {
        match value {
            SqlValue::Null => Box::new(duckdb::types::Null),
            SqlValue::Bool(flag) => Box::new(*flag),
            SqlValue::Int(number) => Box::new(*number),
            SqlValue::Float(real) => Box::new(*real),
            SqlValue::Text(text) => Box::new(text.clone()),
            SqlValue::Blob(bytes) => Box::new(bytes.clone()),
        }
    }

    values.iter().map(convert).collect()
}
/// Runs a single parameterised statement on `conn` and returns the `usize`
/// reported by `Connection::execute`. The tests treat it as the number of
/// affected rows.
///
/// `params` are converted with `to_duckdb_params` and bound positionally.
fn execute_sql(conn: &Connection, sql: &str, params: &[SqlValue]) -> duckdb::Result<usize> {
    let owned_values = to_duckdb_params(params);

    // DuckDB binds from a slice of trait-object references, so borrow each box.
    let mut bound: Vec<&dyn ToSql> = Vec::with_capacity(owned_values.len());
    for value in &owned_values {
        bound.push(value.as_ref());
    }

    conn.execute(sql, &bound[..])
}
fn execute_batch(conn: &Connection, sql: &str) -> duckdb::Result<()> {
conn.execute_batch(sql)
}
fn query_id_str(conn: &Connection, sql: &str, params: &[SqlValue]) -> Vec<(i64, String)> {
let boxed = to_duckdb_params(params);
let refs: Vec<&dyn ToSql> = boxed.iter().map(|b| b.as_ref()).collect();
let mut stmt = conn.prepare(sql).unwrap();
stmt.query_map(refs.as_slice(), |row| {
Ok((
row.get::<_, i64>(0).unwrap(),
row.get::<_, String>(1).unwrap(),
))
})
.unwrap()
.map(Result::unwrap)
.collect()
}
/// Creates `items` through the DuckDB migration dialect, inserts one row with
/// the `Insert` builder, and checks that `SELECT *` returns it with the first
/// autoincrement id (1).
#[test]
fn test_create_table_and_insert() {
    let conn = Connection::open_in_memory().unwrap();
    let dialect = DuckDbDialect::new();
    let create_op = CreateTableBuilder::new()
        .name("items")
        .column(integer("id").primary_key().autoincrement().build())
        .column(varchar("name", 255).not_null().build())
        .build();
    execute_batch(&conn, &dialect.create_table(&create_op)).unwrap();

    // `id` is deliberately not set: the autoincrement column must supply it.
    let (sql, params) = Insert::<ItemTable, _>::new()
        .set(Item::name(), "widget")
        .build();
    execute_sql(&conn, &sql, &params).unwrap();

    let (sql, params) = Select::<ItemTable, _, _>::new()
        .select_all()
        .from_table()
        .build();
    let rows = query_id_str(&conn, &sql, &params);
    assert_eq!(rows, vec![(1, "widget".to_string())]);
}
/// Inserts three labels without ids and checks that the autoincrement column
/// assigns 1, 2, 3 in insertion order.
#[test]
fn test_autoincrement_sequence() {
    let conn = Connection::open_in_memory().unwrap();
    let dialect = DuckDbDialect::new();
    let create_op = CreateTableBuilder::new()
        .name("counters")
        .column(integer("id").primary_key().autoincrement().build())
        .column(varchar("label", 100).not_null().build())
        .build();
    execute_batch(&conn, &dialect.create_table(&create_op)).unwrap();

    for label in ["alpha", "beta", "gamma"] {
        let (sql, params) = Insert::<CounterTable, _>::new()
            .set(Counter::label(), label)
            .build();
        execute_sql(&conn, &sql, &params).unwrap();
    }

    // Presumably `true` requests ascending order, which the expected ids rely on.
    let (sql, params) = Select::<CounterTable, _, _>::new()
        .select_all()
        .from_table()
        .order_by(Counter::id(), true)
        .build();
    let rows = query_id_str(&conn, &sql, &params);
    assert_eq!(rows.len(), 3);
    assert_eq!(rows[0], (1, "alpha".to_string()));
    assert_eq!(rows[1], (2, "beta".to_string()));
    assert_eq!(rows[2], (3, "gamma".to_string()));
}
/// Mirrors a downstream consumer's setup. It creates two `IF NOT EXISTS`
/// tables with `NOT NULL UNIQUE` text columns, fills them through the `Insert`
/// builder, and reads them back in id order.
#[test]
fn test_consumer_scenario() {
    let conn = Connection::open_in_memory().unwrap();
    let dialect = DuckDbDialect::new();

    let op1 = CreateTableBuilder::new()
        .if_not_exists()
        .name("excluded_domains")
        .column(integer("id").primary_key().autoincrement().build())
        .column(varchar("domain", 255).not_null().unique().build())
        .build();
    execute_batch(&conn, &dialect.create_table(&op1)).unwrap();

    let op2 = CreateTableBuilder::new()
        .if_not_exists()
        .name("excluded_ips")
        .column(integer("id").primary_key().autoincrement().build())
        .column(varchar("cidr", 255).not_null().unique().build())
        .build();
    execute_batch(&conn, &dialect.create_table(&op2)).unwrap();

    for domain in ["example.com", "test.org"] {
        let (sql, params) = Insert::<ExcludedDomainTable, _>::new()
            .set(ExcludedDomain::domain(), domain)
            .build();
        execute_sql(&conn, &sql, &params).unwrap();
    }
    for cidr in ["10.0.0.0/8", "192.168.0.0/16", "172.16.0.0/12"] {
        let (sql, params) = Insert::<ExcludedIpTable, _>::new()
            .set(ExcludedIp::cidr(), cidr)
            .build();
        execute_sql(&conn, &sql, &params).unwrap();
    }

    let (sql, params) = Select::<ExcludedDomainTable, _, _>::new()
        .select_all()
        .from_table()
        .order_by(ExcludedDomain::id(), true)
        .build();
    let domains = query_id_str(&conn, &sql, &params);
    assert_eq!(domains.len(), 2);
    assert_eq!(domains[0].1, "example.com");
    assert_eq!(domains[1].1, "test.org");

    // Ordered by id, so rows come back in insertion order, not lexical CIDR order.
    let (sql, params) = Select::<ExcludedIpTable, _, _>::new()
        .select_all()
        .from_table()
        .order_by(ExcludedIp::id(), true)
        .build();
    let ips = query_id_str(&conn, &sql, &params);
    assert_eq!(ips.len(), 3);
    assert_eq!(ips[0].1, "10.0.0.0/8");
    assert_eq!(ips[1].1, "192.168.0.0/16");
    assert_eq!(ips[2].1, "172.16.0.0/12");
}
/// Checks a bound `WHERE price > 100` filter combined with `ORDER BY price`.
/// Of the four products, only `date` (150) and `cherry` (200) qualify, and
/// they must come back cheapest first. `apple` at exactly 100 is excluded
/// because the comparison is strict.
#[test]
fn test_select_with_where() {
    let conn = Connection::open_in_memory().unwrap();
    let dialect = DuckDbDialect::new();
    let create_op = CreateTableBuilder::new()
        .name("products")
        .column(integer("id").primary_key().autoincrement().build())
        .column(varchar("name", 255).not_null().build())
        .column(integer("price").not_null().build())
        .build();
    execute_batch(&conn, &dialect.create_table(&create_op)).unwrap();

    let products = [
        ("apple", 100_i64),
        ("banana", 50),
        ("cherry", 200),
        ("date", 150),
    ];
    for (name, price) in products {
        let (sql, params) = Insert::<ProductTable, _>::new()
            .set(Product::name(), name)
            .set(Product::price(), price)
            .build();
        execute_sql(&conn, &sql, &params).unwrap();
    }

    let (sql, params) = Select::<ProductTable, _, _>::new()
        .select_all()
        .from_table()
        .where_col(Product::price(), col(Product::price()).gt(100_i64))
        .order_by(Product::price(), true)
        .build();

    // This needs three columns, so `query_id_str` cannot be reused here.
    let boxed = to_duckdb_params(&params);
    let refs: Vec<&dyn ToSql> = boxed.iter().map(|b| b.as_ref()).collect();
    let mut stmt = conn.prepare(&sql).unwrap();
    let rows: Vec<(i64, String, i64)> = stmt
        .query_map(refs.as_slice(), |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, String>(1)?,
                row.get::<_, i64>(2)?,
            ))
        })
        .unwrap()
        .collect::<duckdb::Result<_>>()
        .unwrap();

    assert_eq!(rows.len(), 2);
    assert_eq!(rows[0].1, "date");
    assert_eq!(rows[0].2, 150);
    assert_eq!(rows[1].1, "cherry");
    assert_eq!(rows[1].2, 200);
}
/// Exercises the `Update` and `Delete` builders. It renames `bob` to `robert`,
/// then deletes the row with id 3 (`charlie`). After each step it checks the
/// affected-row count and the full table contents, so a statement that
/// touched the wrong rows fails loudly.
#[test]
fn test_update_and_delete() {
    let conn = Connection::open_in_memory().unwrap();
    let dialect = DuckDbDialect::new();
    let create_op = CreateTableBuilder::new()
        .name("users")
        .column(integer("id").primary_key().autoincrement().build())
        .column(varchar("name", 255).not_null().build())
        .build();
    execute_batch(&conn, &dialect.create_table(&create_op)).unwrap();

    for name in ["alice", "bob", "charlie"] {
        let (sql, params) = Insert::<UserTable, _>::new()
            .set(User::name(), name)
            .build();
        execute_sql(&conn, &sql, &params).unwrap();
    }

    let (sql, params) = Update::<UserTable, _>::new()
        .set(User::name(), "robert")
        .where_col(User::name(), col(User::name()).eq("bob"))
        .build();
    let affected = execute_sql(&conn, &sql, &params).unwrap();
    assert_eq!(affected, 1);

    let (sql, params) = Select::<UserTable, _, _>::new()
        .select_all()
        .from_table()
        .order_by(User::id(), true)
        .build();
    let rows = query_id_str(&conn, &sql, &params);
    // Compare the whole table, not just `rows[1]`. This proves only `bob`
    // changed, and a short result gives a clear diff instead of an index panic.
    assert_eq!(
        rows,
        vec![
            (1, "alice".to_string()),
            (2, "robert".to_string()),
            (3, "charlie".to_string()),
        ]
    );

    let (sql, params) = Delete::<UserTable>::new()
        .where_col(User::id(), col(User::id()).eq(3_i64))
        .build();
    let affected = execute_sql(&conn, &sql, &params).unwrap();
    assert_eq!(affected, 1);

    let (sql, params) = Select::<UserTable, _, _>::new()
        .select_all()
        .from_table()
        .order_by(User::id(), true)
        .build();
    let rows = query_id_str(&conn, &sql, &params);
    assert_eq!(
        rows,
        vec![(1, "alice".to_string()), (2, "robert".to_string())]
    );
}