use fsqlite_error::FrankenError;
use fsqlite_types::value::SqliteValue;
use crate::{Connection, Row};
use super::params::ParamValue;
/// Convenience methods for running SQL with a slice of [`ParamValue`] bindings.
pub trait ConnectionExt {
    /// Runs `sql` with `params` bound and hands the resulting row to `f`.
    ///
    /// Returns whatever `f` produces for that row.
    ///
    /// # Errors
    ///
    /// Returns a [`FrankenError`] when the query fails or when `f` returns one.
    fn query_row_map<T, F: FnOnce(&Row) -> Result<T, FrankenError>>(
        &self,
        sql: &str,
        params: &[ParamValue],
        f: F,
    ) -> Result<T, FrankenError>;

    /// Runs `sql` with `params` bound, maps each row through `f`, and collects
    /// the mapped values into a `Vec`.
    ///
    /// `f` is `FnMut`, so it may also accumulate state of its own while rows
    /// are visited.
    ///
    /// # Errors
    ///
    /// Returns a [`FrankenError`] when the query fails or when `f` returns one.
    fn query_map_collect<T, F: FnMut(&Row) -> Result<T, FrankenError>>(
        &self,
        sql: &str,
        params: &[ParamValue],
        f: F,
    ) -> Result<Vec<T>, FrankenError>;

    /// Executes `sql` with `params` bound and returns the `usize` count
    /// reported by the connection.
    ///
    /// # Errors
    ///
    /// Returns a [`FrankenError`] when execution fails.
    fn execute_compat(&self, sql: &str, params: &[ParamValue]) -> Result<usize, FrankenError>;
}
/// Clones the wrapped [`SqliteValue`] out of every [`ParamValue`] so the
/// bindings can be handed to the value-based `Connection` methods.
fn to_sqlite_values(params: &[ParamValue]) -> Vec<SqliteValue> {
    params.iter().map(|param| param.0.clone()).collect()
}

impl ConnectionExt for Connection {
    /// Fetches a row through `query_row_with_params` and maps it with `f`.
    fn query_row_map<T, F: FnOnce(&Row) -> Result<T, FrankenError>>(
        &self,
        sql: &str,
        params: &[ParamValue],
        f: F,
    ) -> Result<T, FrankenError> {
        let bound = to_sqlite_values(params);
        let fetched = self.query_row_with_params(sql, &bound)?;
        f(&fetched)
    }

    /// Visits rows through `query_with_params_for_each`, pushing each mapped
    /// value onto the result vector in visiting order.
    fn query_map_collect<T, F: FnMut(&Row) -> Result<T, FrankenError>>(
        &self,
        sql: &str,
        params: &[ParamValue],
        mut f: F,
    ) -> Result<Vec<T>, FrankenError> {
        let bound = to_sqlite_values(params);
        let mut collected = Vec::new();
        self.query_with_params_for_each(sql, &bound, |row| {
            // An `Err` from `f` is returned from this per-row callback as-is.
            let item = f(row)?;
            collected.push(item);
            Ok(())
        })?;
        Ok(collected)
    }

    /// Forwards to `execute_with_params` with the unwrapped bindings.
    fn execute_compat(&self, sql: &str, params: &[ParamValue]) -> Result<usize, FrankenError> {
        self.execute_with_params(sql, &to_sqlite_values(params))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::compat::RowExt;
use crate::compat::{OpenFlags, open_with_flags};
use rusqlite::params;
/// A parameterless `SELECT` yields its value through `query_row_map`.
#[test]
fn query_row_map_returns_value() {
    let db = Connection::open(":memory:").unwrap();
    let answer = db
        .query_row_map("SELECT 42", &[], |row| row.get_typed::<i64>(0))
        .unwrap();
    assert_eq!(answer, 42);
}
/// Two bound integers are visible to the statement as `?1` and `?2`.
#[test]
fn query_row_map_with_params() {
    let db = Connection::open(":memory:").unwrap();
    let operands = [ParamValue::from(10_i64), ParamValue::from(32_i64)];
    let sum = db
        .query_row_map("SELECT ?1 + ?2", &operands, |row| row.get_typed::<i64>(0))
        .unwrap();
    assert_eq!(sum, 42);
}
/// `query_map_collect` returns one mapped value per row, in query order.
#[test]
fn query_map_collect_returns_vec() {
    let db = Connection::open(":memory:").unwrap();
    // Schema first, then the three rows, executed in this order.
    for statement in [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)",
        "INSERT INTO t (val) VALUES ('a')",
        "INSERT INTO t (val) VALUES ('b')",
        "INSERT INTO t (val) VALUES ('c')",
    ] {
        db.execute(statement).unwrap();
    }
    let vals = db
        .query_map_collect("SELECT val FROM t ORDER BY id", &[], |row| {
            row.get_typed::<String>(0)
        })
        .unwrap();
    assert_eq!(vals, vec!["a", "b", "c"]);
}
/// The mapper may return `()` and record rows in captured state instead.
#[test]
fn query_map_collect_supports_side_effect_only_row_processing() {
    let db = Connection::open(":memory:").unwrap();
    for statement in [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)",
        "INSERT INTO t (val) VALUES ('a')",
        "INSERT INTO t (val) VALUES ('b')",
        "INSERT INTO t (val) VALUES ('c')",
    ] {
        db.execute(statement).unwrap();
    }
    let mut observed: Vec<String> = Vec::new();
    let units: Vec<()> = db
        .query_map_collect("SELECT val FROM t ORDER BY id", &[], |row| {
            let val: String = row.get_typed(0)?;
            observed.push(val);
            Ok(())
        })
        .unwrap();
    // One unit value per visited row, and the side channel saw every value.
    assert_eq!(units.len(), 3);
    assert_eq!(observed, vec!["a", "b", "c"]);
}
/// `EXPLAIN` output can be collected like ordinary rows; column 1 is read as
/// the opcode name and must include `OpenRead`.
#[test]
fn query_map_collect_supports_explain_statements() {
    let db = Connection::open(":memory:").unwrap();
    for statement in [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)",
        "INSERT INTO t (val) VALUES ('a')",
        "INSERT INTO t (val) VALUES ('b')",
    ] {
        db.execute(statement).unwrap();
    }
    let rowid = [ParamValue::from(1_i64)];
    let opcodes: Vec<String> = db
        .query_map_collect("EXPLAIN SELECT val FROM t WHERE id = ?1", &rowid, |row| {
            row.get_typed(1)
        })
        .unwrap();
    assert!(!opcodes.is_empty());
    assert!(opcodes.iter().any(|name| name.as_str() == "OpenRead"));
}
/// `execute_compat` binds an integer and a text parameter and reports one
/// affected row for a single-row insert.
#[test]
fn execute_params_with_values() {
    let db = Connection::open(":memory:").unwrap();
    db.execute("CREATE TABLE t (id INTEGER, name TEXT)").unwrap();
    let bindings = [ParamValue::from(1_i64), ParamValue::from("alice")];
    let inserted = db
        .execute_compat("INSERT INTO t VALUES (?1, ?2)", &bindings)
        .unwrap();
    assert_eq!(inserted, 1);
}
/// Builds a fixture database with rusqlite, then reads it back through this
/// crate's `Connection`.
///
/// Conversation 1 owns two messages (idx 0 and 1); conversations 2 through
/// 25 000 each own a single message at idx 0. An equality lookup on
/// `conversation_id = 1`, forced through `sqlite_autoindex_messages_1`, must
/// return exactly the two conversation-1 rows on both a default connection
/// and a read-only one.
#[test]
fn query_map_collect_composite_unique_index_returns_only_matching_duplicate_run() {
    /// Decodes one `(id, idx, content)` result row.
    fn message_row(row: &Row) -> Result<(i64, i64, String), FrankenError> {
        Ok((row.get_typed(0)?, row.get_typed(1)?, row.get_typed(2)?))
    }

    let dir = tempfile::tempdir().expect("create temp dir");
    let db_path = dir.path().join("messages.db");

    // The rusqlite connection lives only in this scope, so it is closed before
    // the file is reopened below.
    {
        let mut conn = rusqlite::Connection::open(&db_path).expect("open sqlite db");
        conn.execute_batch(
            "CREATE TABLE messages (
id INTEGER PRIMARY KEY,
conversation_id INTEGER NOT NULL,
idx INTEGER NOT NULL,
role TEXT,
author TEXT,
created_at INTEGER,
content TEXT,
UNIQUE(conversation_id, idx)
);",
        )
        .expect("create schema");

        let tx = conn.transaction().expect("begin tx");

        // The two rows the queries below are expected to return.
        for (sql, context) in [
            (
                "INSERT INTO messages (id, conversation_id, idx, role, author, created_at, content)
VALUES (1, 1, 0, 'user', 'u', 1000, 'first')",
                "insert first",
            ),
            (
                "INSERT INTO messages (id, conversation_id, idx, role, author, created_at, content)
VALUES (2, 1, 1, 'assistant', 'a', 1001, 'second')",
                "insert second",
            ),
        ] {
            tx.execute(sql, []).expect(context);
        }

        // Filler rows: ids continue from 3, so each id is conversation_id + 1.
        for conversation_id in 2_i64..=25_000_i64 {
            let next_id = conversation_id + 1;
            tx.execute(
                "INSERT INTO messages (id, conversation_id, idx, role, author, created_at, content)
VALUES (?1, ?2, 0, 'assistant', 'bulk', ?3, ?4)",
                params![
                    next_id,
                    conversation_id,
                    1_700_000_000_i64 + conversation_id,
                    format!("bulk-{conversation_id}")
                ],
            )
            .expect("insert bulk row");
        }
        tx.commit().expect("commit fixture");
    }

    let indexed_scan = "SELECT id, idx, content
FROM messages INDEXED BY sqlite_autoindex_messages_1
WHERE conversation_id = ?1
ORDER BY idx";
    let expected: Vec<(i64, i64, String)> =
        vec![(1, 0, "first".to_owned()), (2, 1, "second".to_owned())];
    let db_file = db_path.to_str().expect("utf8 path");

    let conn = Connection::open(db_file).expect("open fsqlite db");
    let rows: Vec<(i64, i64, String)> = conn
        .query_map_collect(indexed_scan, &[ParamValue::from(1_i64)], message_row)
        .expect("query composite unique index");
    assert_eq!(
        rows, expected,
        "indexed equality scan should stay within the conversation_id=1 duplicate run",
    );

    let readonly = open_with_flags(db_file, OpenFlags::SQLITE_OPEN_READ_ONLY)
        .expect("open readonly fsqlite db");
    let readonly_rows: Vec<(i64, i64, String)> = readonly
        .query_map_collect(indexed_scan, &[ParamValue::from(1_i64)], message_row)
        .expect("query composite unique index via readonly path");
    assert_eq!(
        readonly_rows, expected,
        "readonly indexed equality scan should stay within the conversation_id=1 duplicate run",
    );
}
/// Opt-in repro against a database outside the repository.
///
/// Reads the path from `FSQLITE_REAL_DB`, prints the statement's explain
/// output to stderr, and expects the forced-index lookup for
/// `conversation_id = 1` to return exactly the rows with ids 1 and 2.
#[test]
#[ignore = "machine-local cass repro; run with FSQLITE_REAL_DB=/path/to/agent_search.db"]
fn query_map_collect_real_cass_db_repro() {
    let real_db = std::env::var("FSQLITE_REAL_DB").expect("FSQLITE_REAL_DB must be set");
    let db = open_with_flags(&real_db, OpenFlags::SQLITE_OPEN_READ_ONLY)
        .expect("open readonly real cass db");
    let sql = "SELECT id, idx, content
FROM messages INDEXED BY sqlite_autoindex_messages_1
WHERE conversation_id = ?1
ORDER BY idx
LIMIT 20";

    let prepared = db.prepare(sql).expect("prepare real cass query");
    eprintln!("real_cass_query_explain:\n{}", prepared.explain());

    let conversation = [ParamValue::from(1_i64)];
    let rows: Vec<(i64, i64, String)> = db
        .query_map_collect(sql, &conversation, |row| {
            let id = row.get_typed::<i64>(0)?;
            let idx = row.get_typed::<i64>(1)?;
            let content = row.get_typed::<String>(2)?;
            Ok((id, idx, content))
        })
        .expect("query real cass db");

    assert_eq!(
        rows.len(),
        2,
        "conversation_id=1 should only have two rows in the canonical cass db"
    );
    assert_eq!(rows[0].0, 1);
    assert_eq!(rows[1].0, 2);
}
/// Opt-in repro against a database outside the repository.
///
/// Reads the path from `FSQLITE_REAL_DB`, prints the statement's explain
/// output to stderr, and expects the lookup `id = 1` to return the single
/// row `(1, 1, 0, "hello")`.
#[test]
#[ignore = "machine-local cass repro; run with FSQLITE_REAL_DB=/path/to/agent_search.db"]
fn query_rowid_lookup_real_cass_db_repro() {
    let real_db = std::env::var("FSQLITE_REAL_DB").expect("FSQLITE_REAL_DB must be set");
    let db = open_with_flags(&real_db, OpenFlags::SQLITE_OPEN_READ_ONLY)
        .expect("open readonly real cass db");
    let sql = "SELECT id, conversation_id, idx, content
FROM messages
WHERE id = ?1";

    let prepared = db.prepare(sql).expect("prepare real cass rowid query");
    eprintln!("real_cass_rowid_query_explain:\n{}", prepared.explain());

    let wanted_id = [ParamValue::from(1_i64)];
    let rows: Vec<(i64, i64, i64, String)> = db
        .query_map_collect(sql, &wanted_id, |row| {
            let id = row.get_typed::<i64>(0)?;
            let conversation_id = row.get_typed::<i64>(1)?;
            let idx = row.get_typed::<i64>(2)?;
            let content = row.get_typed::<String>(3)?;
            Ok((id, conversation_id, idx, content))
        })
        .expect("query real cass db by rowid");

    assert_eq!(rows, vec![(1, 1, 0, "hello".to_owned())]);
}
}