#![deny(missing_docs)]
pub use spg_engine::{Engine, EngineError, QueryResult};
pub use spg_storage::Value;
use std::collections::BTreeMap;
use std::fs::{File, OpenOptions};
use std::io::{Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread::{self, JoinHandle};
use std::time::Duration;
use spg_manifest::{CatalogManifest, ColdSegmentEntry, manifest_path as spg_manifest_path};
/// High bit of a record's 4-byte little-endian length word. When set, the record
/// uses at least the 8-byte ("v2") header instead of the bare 4-byte length prefix.
const WAL_V2_SENTINEL: u32 = 0x8000_0000;
/// Second-highest bit of the length word. Together with [`WAL_V2_SENTINEL`] it marks
/// a "v3" record, whose header is 9 bytes and ends with a type byte at offset 8.
const WAL_V3_FLAG: u32 = 0x4000_0000;
/// v3 type byte for a record whose payload is one auto-committed SQL statement.
const WAL_V3_TYPE_AUTO_COMMIT_SQL: u8 = 0x01;
/// Returns the WAL size (in bytes) at which an automatic checkpoint is triggered.
///
/// The value is taken from the `SPG_EMBEDDED_CHECKPOINT_BYTES` environment variable
/// when it parses as a non-zero `u64`; otherwise 4 MiB is used.
fn default_checkpoint_threshold_bytes() -> u64 {
    const FALLBACK: u64 = 4 * 1024 * 1024;
    match std::env::var("SPG_EMBEDDED_CHECKPOINT_BYTES") {
        Ok(raw) => match raw.parse::<u64>() {
            // Zero is rejected, like an unparsable value.
            Ok(n) if n > 0 => n,
            _ => FALLBACK,
        },
        Err(_) => FALLBACK,
    }
}
/// Encodes `sql` as one v3 auto-commit WAL record.
///
/// Layout: 4-byte little-endian length word (payload length with the v2 and v3 flag
/// bits set), 4-byte little-endian CRC-32 over `type byte || payload`, the type byte,
/// then the UTF-8 payload.
///
/// NOTE(review): the payload length is cast to `u32` and OR-ed with the two flag bits,
/// so a statement of 2^30 bytes or more would collide with the flags — confirm callers
/// never log statements that large.
fn encode_v3_auto_commit(sql: &str) -> Vec<u8> {
    let body = sql.as_bytes();

    // The checksummed region is exactly what follows the CRC field in the record.
    let mut checked = Vec::with_capacity(1 + body.len());
    checked.push(WAL_V3_TYPE_AUTO_COMMIT_SQL);
    checked.extend_from_slice(body);
    let checksum = spg_crypto::crc32::crc32(&checked);

    let tagged_len = (body.len() as u32) | WAL_V2_SENTINEL | WAL_V3_FLAG;

    let mut record = Vec::with_capacity(4 + 4 + checked.len());
    record.extend_from_slice(&tagged_len.to_le_bytes());
    record.extend_from_slice(&checksum.to_le_bytes());
    record.extend_from_slice(&checked);
    record
}
/// Replays a raw WAL byte image into `engine` and returns how many records were executed.
///
/// Each record starts with a 4-byte little-endian length word whose top bits select the
/// header size: 4 bytes (no flags), 8 bytes (v2 sentinel) or 9 bytes (v2 sentinel + v3 flag,
/// with a type byte at offset 8). An incomplete record at the end of the buffer ends the
/// replay without an error.
///
/// Returns `Err` with a descriptive message when a v3 record has an unknown type byte,
/// when a payload is not valid UTF-8, or when the engine rejects a statement.
///
/// NOTE(review): header bytes 4..8 of v2/v3 records are skipped here without being
/// checked; `encode_v3_auto_commit` stores a CRC-32 in that position.
fn replay_wal_into_engine(wal_bytes: &[u8], engine: &mut Engine) -> Result<usize, String> {
let mut applied = 0usize;
let mut cur = 0usize;
while cur < wal_bytes.len() {
// Not even a full length word left: treat as a torn tail and stop.
if wal_bytes.len() - cur < 4 {
break;
}
let raw_len = u32::from_le_bytes(wal_bytes[cur..cur + 4].try_into().unwrap());
let is_v2 = raw_len & WAL_V2_SENTINEL != 0;
let is_v3 = is_v2 && (raw_len & WAL_V3_FLAG != 0);
// Strip the flag bits to recover the payload length.
let len_mask = if is_v3 {
!(WAL_V2_SENTINEL | WAL_V3_FLAG)
} else {
!WAL_V2_SENTINEL
};
let rec_len = (raw_len & len_mask) as usize;
let header_len = if is_v3 {
9
} else if is_v2 {
8
} else {
4
};
// Header present but payload incomplete: also a torn tail.
if wal_bytes.len() - cur < header_len + rec_len {
break;
}
if is_v3 {
let type_byte = wal_bytes[cur + 8];
match type_byte {
WAL_V3_TYPE_AUTO_COMMIT_SQL => {}
// Type 0x02 records are stepped over without being applied or counted.
0x02 => {
cur += header_len + rec_len;
continue;
}
other => {
return Err(format!(
"WAL replay: unknown v3 type byte {other:#04x} at offset {cur}"
));
}
}
}
let sql_bytes = &wal_bytes[cur + header_len..cur + header_len + rec_len];
let sql = std::str::from_utf8(sql_bytes).map_err(|e| format!("WAL replay: non-UTF-8 SQL at offset {cur}: {e}"))?;
engine
.execute(sql)
.map_err(|e| format!("WAL replay: apply {sql:?} at offset {cur} rejected: {e:?}"))?;
applied += 1;
cur += header_len + rec_len;
}
Ok(applied)
}
/// Reports whether `sql` starts with a keyword that is never written to the WAL.
///
/// Only the first token is inspected: leading whitespace is skipped and the token ends
/// at the first whitespace, `;` or `(`. The comparison is ASCII case-insensitive.
fn sql_is_read_only(sql: &str) -> bool {
    const NON_LOGGED_KEYWORDS: [&str; 10] = [
        "select",
        "show",
        "explain",
        "begin",
        "commit",
        "rollback",
        "checkpoint",
        "compact",
        "wait",
        "with",
    ];

    let trimmed = sql.trim_start();
    let token_end = trimmed
        .find(|c: char| c.is_whitespace() || c == ';' || c == '(')
        .unwrap_or(trimmed.len());
    let first_token = &trimmed[..token_end];

    NON_LOGGED_KEYWORDS
        .iter()
        .any(|kw| first_token.eq_ignore_ascii_case(kw))
}
/// An embedded database: a SQL [`Engine`] plus optional on-disk persistence state.
///
/// `persistence` is `None` for databases created by [`Database::open_in_memory`] or
/// [`Database::restore`], and `Some` for databases opened with [`Database::open_path`].
#[derive(Debug)]
pub struct Database {
engine: Engine,
persistence: Option<PersistenceCtx>,
}
/// On-disk state kept by a file-backed [`Database`].
#[derive(Debug)]
#[allow(dead_code)] struct PersistenceCtx {
/// Path of the snapshot file that checkpoints overwrite.
db_path: PathBuf,
/// Path of the write-ahead log (`<db file name>.wal`).
wal_path: PathBuf,
/// WAL file handle, opened in append mode.
wal: File,
/// Tracked WAL length in bytes; reset to 0 by a checkpoint.
wal_len: u64,
/// WAL length at which `execute` triggers an automatic checkpoint.
checkpoint_threshold_bytes: u64,
/// Directory where frozen cold-segment files are written.
cold_segments_dir: PathBuf,
/// Known cold-segment files, keyed by segment id.
cold_segment_paths: BTreeMap<u32, PathBuf>,
}
impl Database {
/// Creates an empty database with no on-disk persistence.
#[must_use]
pub fn open_in_memory() -> Self {
    let engine = Engine::new();
    Self {
        engine,
        persistence: None,
    }
}
pub fn open_path(db_path: impl AsRef<Path>) -> Result<Self, EngineError> {
let db_path = db_path.as_ref().to_path_buf();
let wal_path = {
let mut p = db_path.clone();
let name = p
.file_name()
.map(|n| {
let mut s = n.to_os_string();
s.push(".wal");
s
})
.unwrap_or_else(|| std::ffi::OsString::from(".wal"));
p.set_file_name(name);
p
};
if let Some(parent) = db_path.parent()
&& !parent.as_os_str().is_empty()
{
std::fs::create_dir_all(parent).map_err(io_err)?;
}
let mut engine = if db_path.exists() {
let bytes = std::fs::read(&db_path).map_err(io_err)?;
let engine = Engine::restore_envelope(&bytes).map_err(|e| {
EngineError::Storage(spg_storage::StorageError::Corrupt(format!(
"restore from {}: {e}",
db_path.display()
)))
})?;
engine
} else {
Engine::new()
};
let cold_segments_dir = {
let parent = db_path.parent().unwrap_or_else(|| Path::new("."));
let stem = db_path
.file_stem()
.unwrap_or_else(|| std::ffi::OsStr::new("db"))
.to_string_lossy()
.into_owned();
parent.join(format!("{stem}.spg")).join("segments")
};
let mut cold_segment_paths: BTreeMap<u32, PathBuf> = BTreeMap::new();
let manifest_pth = spg_manifest_path(&db_path);
if manifest_pth.exists() && db_path.exists() {
let m_bytes = std::fs::read(&manifest_pth).map_err(io_err)?;
if let Ok(m) = CatalogManifest::deserialize(&m_bytes) {
let snap_bytes = std::fs::read(&db_path).map_err(io_err)?;
let snap_crc = spg_crypto::crc32::crc32(&snap_bytes);
if snap_crc == m.catalog_crc32 {
for entry in &m.cold_segments {
if let Ok(seg_bytes) = std::fs::read(&entry.path) {
let computed = spg_crypto::crc32::crc32(&seg_bytes);
if computed != entry.crc32 {
eprintln!(
"spg-embedded: manifest skip segment {}: CRC mismatch",
entry.segment_id
);
continue;
}
if engine
.catalog()
.cold_segment(entry.segment_id)
.is_some()
{
continue;
}
let mut new_cat = engine.catalog().clone();
if let Err(e) = new_cat
.load_segment_bytes_at(entry.segment_id, seg_bytes)
{
eprintln!(
"spg-embedded: manifest load segment {} failed: {e}",
entry.segment_id
);
continue;
}
engine.replace_catalog(new_cat);
cold_segment_paths
.insert(entry.segment_id, entry.path.clone());
} else {
eprintln!(
"spg-embedded: manifest skip segment {}: file unreadable",
entry.segment_id
);
}
}
}
}
}
if wal_path.exists() {
let wal_bytes = std::fs::read(&wal_path).map_err(io_err)?;
if !wal_bytes.is_empty() {
replay_wal_into_engine(&wal_bytes, &mut engine)
.map_err(|m| EngineError::Storage(spg_storage::StorageError::Corrupt(m)))?;
}
}
let wal = OpenOptions::new()
.create(true)
.append(true)
.read(true)
.open(&wal_path)
.map_err(io_err)?;
let wal_len = wal.metadata().map_err(io_err)?.len();
Ok(Self {
engine,
persistence: Some(PersistenceCtx {
db_path,
wal_path,
wal,
wal_len,
checkpoint_threshold_bytes: default_checkpoint_threshold_bytes(),
cold_segments_dir,
cold_segment_paths,
}),
})
}
pub fn freeze_oldest_to_cold(
&mut self,
table_name: &str,
index_name: &str,
max_rows: usize,
) -> Result<spg_storage::FreezeReport, EngineError> {
let report = self
.engine
.freeze_oldest_to_cold(table_name, index_name, max_rows)?;
if let Some(p) = &mut self.persistence {
std::fs::create_dir_all(&p.cold_segments_dir).map_err(io_err)?;
let final_path = p
.cold_segments_dir
.join(format!("seg_{}.spg", report.segment_id));
let tmp_path = p
.cold_segments_dir
.join(format!("seg_{}.spg.tmp", report.segment_id));
std::fs::write(&tmp_path, &report.segment_bytes).map_err(io_err)?;
std::fs::rename(&tmp_path, &final_path).map_err(io_err)?;
p.cold_segment_paths.insert(report.segment_id, final_path);
}
Ok(report)
}
/// Sets the WAL size at which `execute` checkpoints automatically.
///
/// A value of `0` is raised to `1`. Has no effect on in-memory databases.
pub fn set_checkpoint_threshold_bytes(&mut self, bytes: u64) {
    let Some(ctx) = self.persistence.as_mut() else {
        return;
    };
    ctx.checkpoint_threshold_bytes = if bytes == 0 { 1 } else { bytes };
}
pub fn checkpoint(&mut self) -> Result<(), EngineError> {
let snapshot = self.engine.snapshot();
let Some(p) = &mut self.persistence else {
return Ok(());
};
let tmp = {
let mut t = p.db_path.clone();
let mut name = t
.file_name()
.map(std::ffi::OsStr::to_os_string)
.unwrap_or_default();
name.push(".tmp");
t.set_file_name(name);
t
};
std::fs::write(&tmp, &snapshot).map_err(io_err)?;
std::fs::rename(&tmp, &p.db_path).map_err(io_err)?;
if !p.cold_segment_paths.is_empty() {
let snap_crc = spg_crypto::crc32::crc32(&snapshot);
let entries: Vec<ColdSegmentEntry> = p
.cold_segment_paths
.iter()
.filter_map(|(&segment_id, path)| {
let bytes = std::fs::read(path).ok()?;
Some(ColdSegmentEntry {
segment_id,
path: path.clone(),
crc32: spg_crypto::crc32::crc32(&bytes),
})
})
.collect();
let manifest = CatalogManifest {
catalog_crc32: snap_crc,
cold_segments: entries,
wal_baseline_offset: 0,
};
let m_bytes = manifest.serialize();
let m_path = spg_manifest_path(&p.db_path);
if let Some(dir) = m_path.parent() {
std::fs::create_dir_all(dir).map_err(io_err)?;
}
let m_tmp = {
let mut t = m_path.clone();
let mut name = t
.file_name()
.map(std::ffi::OsStr::to_os_string)
.unwrap_or_default();
name.push(".tmp");
t.set_file_name(name);
t
};
std::fs::write(&m_tmp, &m_bytes).map_err(io_err)?;
std::fs::rename(&m_tmp, &m_path).map_err(io_err)?;
}
p.wal.set_len(0).map_err(io_err)?;
p.wal.seek(SeekFrom::Start(0)).map_err(io_err)?;
p.wal.sync_data().map_err(io_err)?;
p.wal_len = 0;
Ok(())
}
/// Builds an in-memory database from snapshot bytes.
///
/// # Errors
///
/// Returns a `Corrupt` storage error (prefixed with `restore:`) when the engine cannot
/// restore from `snapshot`.
pub fn restore(snapshot: &[u8]) -> Result<Self, EngineError> {
    match Engine::restore_envelope(snapshot) {
        Ok(engine) => Ok(Self {
            engine,
            persistence: None,
        }),
        Err(e) => Err(EngineError::Storage(spg_storage::StorageError::Corrupt(
            format!("restore: {e}"),
        ))),
    }
}
/// Returns the engine's snapshot bytes; the result can be passed to [`Database::restore`].
#[must_use]
pub fn snapshot(&self) -> Vec<u8> {
self.engine.snapshot()
}
/// Executes one SQL statement and, for file-backed databases, logs it to the WAL.
///
/// A statement is appended to the WAL (and synced) only when all of these hold: the
/// database is file-backed, the statement's first keyword is not in the non-logged set
/// (see `sql_is_read_only`), and the engine reported `CommandOk` with
/// `modified_catalog: true`. When the tracked WAL length reaches the checkpoint
/// threshold, a checkpoint runs before returning.
///
/// # Errors
///
/// Returns the engine's error, or an I/O error from the WAL append or the checkpoint.
/// The engine has already applied the statement when a WAL or checkpoint error occurs.
pub fn execute(&mut self, sql: &str) -> Result<QueryResult, EngineError> {
    let result = self.engine.execute(sql)?;

    let Some(p) = self.persistence.as_mut() else {
        return Ok(result);
    };
    if sql_is_read_only(sql) {
        return Ok(result);
    }
    let changed_catalog = matches!(
        &result,
        QueryResult::CommandOk { modified_catalog: true, .. }
    );
    if !changed_catalog {
        return Ok(result);
    }

    let record = encode_v3_auto_commit(sql);
    p.wal.write_all(&record).map_err(io_err)?;
    p.wal.sync_data().map_err(io_err)?;
    p.wal_len = p.wal_len.saturating_add(record.len() as u64);

    let threshold_reached = p.wal_len >= p.checkpoint_threshold_bytes;
    if threshold_reached {
        self.checkpoint()?;
    }
    Ok(result)
}
/// Runs a query and converts every row to `T` via [`FromSpgRow`].
///
/// # Errors
///
/// Returns the error from [`Database::query`], or the first row-conversion error.
pub fn query_typed<T: FromSpgRow>(&mut self, sql: &str) -> Result<Vec<T>, EngineError> {
    let mut typed = Vec::new();
    for row in self.query(sql)? {
        typed.push(T::from_spg_row(&row)?);
    }
    Ok(typed)
}
/// Executes `sql` on the engine and returns the value vectors of the resulting rows.
///
/// The statement goes straight to the engine; nothing is appended to the WAL here.
///
/// # Errors
///
/// Returns the engine's error, or `Unsupported` when the result is anything other than
/// a row set.
pub fn query(&mut self, sql: &str) -> Result<Vec<Vec<Value>>, EngineError> {
    let outcome = self.engine.execute(sql)?;
    if let QueryResult::Rows { rows, .. } = outcome {
        return Ok(rows.into_iter().map(|row| row.values).collect());
    }
    Err(EngineError::Unsupported(
        "query() expects a SELECT — use execute() for DML/DDL".into(),
    ))
}
/// Borrows the underlying engine.
#[must_use]
pub const fn engine(&self) -> &Engine {
&self.engine
}
/// Mutably borrows the underlying engine.
///
/// Statements executed directly on the returned engine do not pass through
/// [`Database::execute`], which is where WAL logging happens.
pub const fn engine_mut(&mut self) -> &mut Engine {
&mut self.engine
}
/// Runs `body` between `BEGIN` and `COMMIT`, issuing `ROLLBACK` when `body` fails.
///
/// # Errors
///
/// Returns the error from `BEGIN`, from `body`, or from `COMMIT`. A failure of the
/// `ROLLBACK` issued after a failed `body` is ignored so the original error is returned.
pub fn with_transaction<R, F>(&mut self, body: F) -> Result<R, EngineError>
where
    F: FnOnce(&mut Self) -> Result<R, EngineError>,
{
    self.execute("BEGIN")?;
    let outcome = body(self);
    if outcome.is_ok() {
        self.execute("COMMIT")?;
    } else {
        // Best effort: the caller cares about the body's error, not the rollback's.
        let _ = self.execute("ROLLBACK");
    }
    outcome
}
}
/// The default database is an empty in-memory one.
impl Default for Database {
fn default() -> Self {
Self::open_in_memory()
}
}
/// Point-in-time counters reported by [`Database::metrics`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct EmbeddedMetrics {
/// Sum of `row_count()` over all tables in the catalog.
pub hot_rows: u64,
/// Sum of `hot_bytes()` over all tables in the catalog.
pub hot_bytes: u64,
/// Number of cold segments in the catalog.
pub cold_segments: u64,
/// Number of tables in the catalog.
pub tables: u64,
/// Tracked WAL length in bytes; `0` for in-memory databases.
pub wal_bytes: u64,
/// `true` when the database is file-backed.
pub persistent: bool,
}
/// Handle to the background freezer thread started by
/// [`Database::spawn_background_freezer`]. Dropping it stops and joins the thread.
#[must_use = "the background freezer keeps running until this handle is dropped"]
#[derive(Debug)]
pub struct FreezerHandle {
/// Flag the worker polls; set to `true` to request shutdown.
shutdown: Arc<AtomicBool>,
/// Join handle of the worker; `None` once it has been joined.
join: Option<JoinHandle<()>>,
}
impl FreezerHandle {
    /// Requests shutdown and waits for the worker thread to exit.
    ///
    /// Calling this again after the thread has been joined only re-sets the flag.
    pub fn stop(&mut self) {
        self.shutdown.store(true, Ordering::Release);
        let Some(worker) = self.join.take() else {
            return;
        };
        // A panic inside the worker is deliberately not propagated.
        drop(worker.join());
    }
}
/// Dropping the handle stops the freezer and joins its thread.
impl Drop for FreezerHandle {
fn drop(&mut self) {
self.stop();
}
}
/// Tuning knobs for the background freezer thread.
#[derive(Debug, Clone)]
pub struct FreezerOptions {
/// Minimum time between two freezer passes.
pub tick: Duration,
/// A pass freezes rows only while the catalog's hot-tier size exceeds this many bytes.
pub hot_tier_bytes: u64,
/// Upper bound on the number of rows frozen per pass.
pub batch_rows: usize,
/// Compaction runs after a freeze when the cold-segment count exceeds this value.
pub compact_when_segments_exceed: usize,
/// Target size handed to the engine's cold-segment compaction.
pub compact_target_bytes: u64,
}
/// Defaults: 1 s tick, 4 GiB hot tier, 1000-row batches, compaction above 64 segments
/// with a 64 MiB target.
impl Default for FreezerOptions {
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;
        const GIB: u64 = 1024 * MIB;
        Self {
            tick: Duration::from_secs(1),
            hot_tier_bytes: 4 * GIB,
            batch_rows: 1000,
            compact_when_segments_exceed: 64,
            compact_target_bytes: 64 * MIB,
        }
    }
}
impl Database {
/// Returns the number of cold segments currently in the engine's catalog.
#[must_use]
pub fn cold_segment_count(&self) -> usize {
self.engine.catalog().cold_segment_count()
}
/// Collects current size counters from the catalog and the persistence state.
///
/// Row and byte totals use saturating addition.
#[must_use]
pub fn metrics(&self) -> EmbeddedMetrics {
    let cat = self.engine.catalog();

    let mut total_rows: u64 = 0;
    let mut total_bytes: u64 = 0;
    for name in cat.table_names() {
        let Some(table) = cat.get(&name) else {
            continue;
        };
        total_rows = total_rows.saturating_add(table.row_count() as u64);
        total_bytes = total_bytes.saturating_add(table.hot_bytes());
    }

    let (wal_bytes, persistent) = self
        .persistence
        .as_ref()
        .map_or((0, false), |ctx| (ctx.wal_len, true));

    EmbeddedMetrics {
        hot_rows: total_rows,
        hot_bytes: total_bytes,
        cold_segments: cat.cold_segment_count() as u64,
        tables: cat.table_count() as u64,
        wal_bytes,
        persistent,
    }
}
/// Starts the background freezer on a thread named `spg-embedded-freezer`.
///
/// The thread shares `db` through the mutex and runs until the returned handle is
/// stopped or dropped.
///
/// # Panics
///
/// Panics when the operating system refuses to spawn the thread.
pub fn spawn_background_freezer(
    db: Arc<Mutex<Database>>,
    opts: FreezerOptions,
) -> FreezerHandle {
    let shutdown = Arc::new(AtomicBool::new(false));
    let worker = {
        let stop_flag = Arc::clone(&shutdown);
        thread::Builder::new()
            .name("spg-embedded-freezer".into())
            .spawn(move || background_freezer_loop(db, opts, stop_flag))
            .expect("spawn background freezer thread")
    };
    FreezerHandle {
        shutdown,
        join: Some(worker),
    }
}
}
/// Body of the background freezer thread.
///
/// The loop sleeps in short slices so a shutdown request is noticed quickly, and runs
/// one pass each time `opts.tick` has elapsed. A pass, under the database lock:
/// 1. does nothing unless the catalog's hot-tier size exceeds `opts.hot_tier_bytes`;
/// 2. picks a table/index via `pick_freeze_target` and freezes up to `opts.batch_rows`
///    of its rows;
/// 3. compacts cold segments when their count exceeds
///    `opts.compact_when_segments_exceed`.
///
/// Freeze and compaction failures are reported on stderr and the loop continues. The
/// thread exits when `shutdown` is set or when the mutex is poisoned.
fn background_freezer_loop(
    db: Arc<Mutex<Database>>,
    opts: FreezerOptions,
    shutdown: Arc<AtomicBool>,
) {
    // Poll interval: never longer than 50 ms, and never zero. A zero or
    // sub-millisecond `tick` would otherwise turn this loop into a busy spin.
    let slice = opts
        .tick
        .clamp(Duration::from_millis(1), Duration::from_millis(50));
    let mut last_tick = std::time::Instant::now();
    loop {
        if shutdown.load(Ordering::Acquire) {
            return;
        }
        thread::sleep(slice);
        if last_tick.elapsed() < opts.tick {
            continue;
        }
        last_tick = std::time::Instant::now();

        let Ok(mut guard) = db.lock() else {
            // Another thread panicked while holding the lock; the database state can
            // no longer be trusted, so stop instead of freezing from it.
            eprintln!("spg-embedded: background freezer exiting: database mutex poisoned");
            return;
        };
        if guard.engine.catalog().hot_tier_bytes() <= opts.hot_tier_bytes {
            continue;
        }
        let Some((table, index)) = pick_freeze_target(&guard) else {
            continue;
        };
        let row_count = guard
            .engine
            .catalog()
            .get(&table)
            .map_or(0, spg_storage::Table::row_count);
        let to_freeze = opts.batch_rows.min(row_count);
        if to_freeze == 0 {
            continue;
        }
        if let Err(e) = guard.freeze_oldest_to_cold(&table, &index, to_freeze) {
            eprintln!(
                "spg-embedded: background freeze on {table}.{index} failed: {e:?}"
            );
            continue;
        }
        let count = guard.engine.catalog().cold_segment_count();
        if count > opts.compact_when_segments_exceed {
            if let Err(e) = guard
                .engine
                .compact_cold_segments_with_target(opts.compact_target_bytes)
            {
                eprintln!(
                    "spg-embedded: background compact failed (segments={count}, \
                     threshold={}): {e:?}",
                    opts.compact_when_segments_exceed,
                );
            }
        }
    }
}
/// Chooses the table (and one of its indexes) that the background freezer should
/// freeze next.
///
/// Candidates are non-empty tables that have a B-tree index on a `SmallInt`, `Int` or
/// `BigInt` column; the first such index of each table is used. Among candidates the
/// table with the most hot bytes wins, and on a tie the one listed first is kept.
/// Returns `None` when no table qualifies.
fn pick_freeze_target(db: &Database) -> Option<(String, String)> {
    let cat = db.engine.catalog();
    let mut winner: Option<(String, String, u64)> = None;

    for name in cat.table_names() {
        let Some(table) = cat.get(&name) else {
            continue;
        };
        if table.row_count() == 0 {
            continue;
        }

        let columns = &table.schema().columns;
        let Some(index) = table.indices().iter().find(|i| {
            matches!(i.kind, spg_storage::IndexKind::BTree(_))
                && columns.get(i.column_position).is_some_and(|col| {
                    matches!(
                        col.ty,
                        spg_storage::DataType::SmallInt
                            | spg_storage::DataType::Int
                            | spg_storage::DataType::BigInt
                    )
                })
        }) else {
            continue;
        };

        let hot = table.hot_bytes();
        // Strictly greater, so the earliest table keeps its place on ties.
        let improves = winner.as_ref().map_or(true, |current| hot > current.2);
        if improves {
            winner = Some((name, index.name.clone(), hot));
        }
    }

    winner.map(|(table, index, _)| (table, index))
}
/// Rebuilds a database from the first `to_seq` SQL records of a WAL and writes its
/// snapshot to `out_db_path`.
///
/// Records are applied to a fresh engine in file order. Records that carry no SQL
/// (an empty payload or a non-SQL v3 type) are skipped and do not count towards
/// `to_seq`. Returns the number of records actually applied, which is smaller than
/// `to_seq` when the WAL holds fewer SQL records.
///
/// # Errors
///
/// Returns an error when the WAL cannot be read, a record is truncated, a payload is
/// not valid UTF-8, the engine rejects a statement, or the snapshot cannot be written.
pub fn revert_wal_to_seq(
    wal_path: impl AsRef<Path>,
    to_seq: u64,
    out_db_path: impl AsRef<Path>,
) -> Result<u64, EngineError> {
    let wal_bytes = std::fs::read(wal_path.as_ref()).map_err(io_err)?;
    let mut engine = Engine::new();
    let mut applied = 0u64;
    let mut cur = 0usize;
    while cur < wal_bytes.len() && applied < to_seq {
        // Remember where this record begins: `cur` is advanced before the payload is
        // validated, and error messages must point at the record itself.
        let record_start = cur;
        let (sql_bytes, total) = decode_wal_record(&wal_bytes[record_start..])?;
        cur += total;
        if sql_bytes.is_empty() {
            continue;
        }
        let sql = core::str::from_utf8(&sql_bytes).map_err(|e| {
            EngineError::Storage(spg_storage::StorageError::Corrupt(format!(
                "WAL record at offset {record_start}: non-UTF-8 SQL: {e}"
            )))
        })?;
        engine.execute(sql)?;
        applied += 1;
    }
    let snapshot = engine.snapshot();
    std::fs::write(out_db_path.as_ref(), &snapshot).map_err(io_err)?;
    Ok(applied)
}
/// Decodes the WAL record at the start of `tail`.
///
/// Returns the record's SQL bytes and its total encoded length (header plus payload).
/// For v3 records whose type byte is not the auto-commit SQL type, the SQL bytes are
/// empty.
///
/// # Errors
///
/// Returns a `Corrupt` storage error when `tail` is shorter than the 4-byte length word
/// or shorter than the full record.
fn decode_wal_record(tail: &[u8]) -> Result<(Vec<u8>, usize), EngineError> {
    let corrupt = |msg: String| EngineError::Storage(spg_storage::StorageError::Corrupt(msg));

    if tail.len() < 4 {
        return Err(corrupt(format!(
            "WAL truncated record: {} < 4 header bytes",
            tail.len()
        )));
    }
    let raw_len = u32::from_le_bytes([tail[0], tail[1], tail[2], tail[3]]);
    let v2 = raw_len & WAL_V2_SENTINEL != 0;
    let v3 = v2 && (raw_len & WAL_V3_FLAG != 0);

    // The flag bits pick both the header size and which bits to strip from the length.
    let (header_len, len_mask) = match (v2, v3) {
        (_, true) => (9usize, !(WAL_V2_SENTINEL | WAL_V3_FLAG)),
        (true, false) => (8usize, !WAL_V2_SENTINEL),
        (false, false) => (4usize, !WAL_V2_SENTINEL),
    };
    let rec_len = (raw_len & len_mask) as usize;
    let total = header_len + rec_len;

    if tail.len() < total {
        return Err(corrupt(format!(
            "WAL truncated record: header+payload {} > available {}",
            total,
            tail.len()
        )));
    }

    // Only v3 records have a type byte (offset 8); older records are always SQL.
    let carries_sql = !v3 || tail[8] == WAL_V3_TYPE_AUTO_COMMIT_SQL;
    let sql_bytes = if carries_sql {
        tail[header_len..total].to_vec()
    } else {
        Vec::new()
    };
    Ok((sql_bytes, total))
}
/// File-backed databases attempt a final checkpoint when dropped; a failure is only
/// reported on stderr.
impl Drop for Database {
    fn drop(&mut self) {
        if self.persistence.is_none() {
            return;
        }
        let Err(e) = self.checkpoint() else {
            return;
        };
        eprintln!(
            "spg-embedded: final checkpoint on Drop failed: {e:?} \
             (WAL is intact; next open_path will replay)"
        );
    }
}
/// Wraps an I/O error as a `Corrupt` storage error whose message is `io: <error>`.
fn io_err(e: std::io::Error) -> EngineError {
    let message = format!("io: {e}");
    EngineError::Storage(spg_storage::StorageError::Corrupt(message))
}
// Compile-time check that `Database` is `Send`. Never called; the build fails if
// `Database` stops being `Send`.
#[allow(dead_code)]
fn _database_is_send() {
fn assert_send<T: Send>() {}
assert_send::<Database>();
}
/// Conversion from one result row (a slice of [`Value`]s) into a typed value.
///
/// Used by [`Database::query_typed`]; the `spg_row!` macro generates implementations.
pub trait FromSpgRow: Sized {
/// Builds `Self` from `row`, or returns an error describing why the row does not fit.
fn from_spg_row(row: &[Value]) -> Result<Self, EngineError>;
}
/// Declares a struct and implements [`FromSpgRow`] for it.
///
/// The struct is emitted with `#[derive(Debug, Clone)]` in addition to any attributes
/// given. Fields are filled positionally: the first field from the first column, and so
/// on, each converted through [`FromSpgValue`]. A row with too few columns, or a column
/// that fails conversion, yields `EngineError::Unsupported` naming the struct and field.
/// Columns beyond the last field are ignored.
///
/// Every field, including the last, must be followed by a comma.
#[macro_export]
macro_rules! spg_row {
(
$(#[$meta:meta])*
$vis:vis struct $name:ident {
$(
$(#[$fmeta:meta])*
$fvis:vis $field:ident : $ty:ty,
)*
}
) => {
$(#[$meta])*
#[derive(Debug, Clone)]
$vis struct $name {
$(
$(#[$fmeta])*
$fvis $field : $ty,
)*
}
impl $crate::FromSpgRow for $name {
fn from_spg_row(row: &[$crate::Value]) -> ::core::result::Result<Self, $crate::EngineError> {
// One shared iterator, advanced once per field in declaration order.
let mut __spg_row_iter = row.iter();
$(
let $field: $ty = {
let v = __spg_row_iter
.next()
.ok_or_else(|| $crate::EngineError::Unsupported(
::std::format!(
"spg_row! {}: missing column for field `{}`",
::core::stringify!($name),
::core::stringify!($field)
)
))?;
<$ty as $crate::FromSpgValue>::from_spg_value(v)
.map_err(|e| $crate::EngineError::Unsupported(
::std::format!(
"spg_row! {}: column `{}`: {}",
::core::stringify!($name),
::core::stringify!($field),
e
)
))?
};
)*
Ok(Self { $($field,)* })
}
}
};
}
/// Conversion from a single column [`Value`] into a Rust type.
pub trait FromSpgValue: Sized {
/// Converts `v`, or returns a static message explaining the mismatch.
fn from_spg_value(v: &Value) -> Result<Self, &'static str>;
}
// Implements `FromSpgValue` for each listed integer type. Any of the three integer
// `Value` variants is accepted as long as the number fits the target type via
// `try_from`; NULL and non-integer values are rejected.
macro_rules! impl_from_value_int {
($($t:ty),* $(,)?) => {
$(
impl FromSpgValue for $t {
fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
match v {
Value::SmallInt(n) => <$t>::try_from(*n).map_err(|_| "SmallInt does not fit target int type"),
Value::Int(n) => <$t>::try_from(*n).map_err(|_| "Int does not fit target int type"),
Value::BigInt(n) => <$t>::try_from(*n).map_err(|_| "BigInt does not fit target int type"),
Value::Null => Err("NULL in non-Option int column"),
_ => Err("non-integer value in int column"),
}
}
}
)*
};
}
// Integer targets supported for typed rows.
impl_from_value_int!(i16, i32, i64);
/// `Value::Float` converts to `f32` with an `as` cast; NULL and other variants are errors.
impl FromSpgValue for f32 {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Float(wide) = v {
            Ok(*wide as f32)
        } else if matches!(v, Value::Null) {
            Err("NULL in non-Option float column")
        } else {
            Err("non-float value in float column")
        }
    }
}
/// `Value::Float` converts to `f64` unchanged; NULL and other variants are errors.
impl FromSpgValue for f64 {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Float(x) = v {
            Ok(*x)
        } else if matches!(v, Value::Null) {
            Err("NULL in non-Option float column")
        } else {
            Err("non-float value in float column")
        }
    }
}
/// `Value::Bool` converts to `bool`; NULL and other variants are errors.
impl FromSpgValue for bool {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Bool(flag) = v {
            Ok(*flag)
        } else if matches!(v, Value::Null) {
            Err("NULL in non-Option bool column")
        } else {
            Err("non-bool value in bool column")
        }
    }
}
/// `Value::Text` converts to an owned `String`; NULL and other variants are errors.
impl FromSpgValue for String {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Text(text) = v {
            Ok(text.clone())
        } else if matches!(v, Value::Null) {
            Err("NULL in non-Option text column")
        } else {
            Err("non-text value in String column")
        }
    }
}
/// `Value::Vector` converts to an owned `Vec<f32>`; NULL and other variants are errors.
impl FromSpgValue for Vec<f32> {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if let Value::Vector(components) = v {
            Ok(components.clone())
        } else if matches!(v, Value::Null) {
            Err("NULL in non-Option vector column")
        } else {
            Err("non-vector value in Vec<f32> column")
        }
    }
}
/// NULL converts to `None`; any other value is converted by `T` and wrapped in `Some`.
impl<T: FromSpgValue> FromSpgValue for Option<T> {
    fn from_spg_value(v: &Value) -> Result<Self, &'static str> {
        if matches!(v, Value::Null) {
            return Ok(None);
        }
        T::from_spg_value(v).map(Some)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory database and runs each setup statement, panicking on failure.
    fn db_with(setup: &[&str]) -> Database {
        let mut db = Database::open_in_memory();
        for stmt in setup {
            db.execute(stmt).unwrap();
        }
        db
    }

    /// Create, insert and select round trip on an in-memory database.
    #[test]
    fn in_memory_create_insert_select() {
        let mut db = db_with(&[
            "CREATE TABLE t (id INT NOT NULL, name TEXT)",
            "INSERT INTO t VALUES (1, 'alice')",
            "INSERT INTO t VALUES (2, 'bob')",
        ]);
        let rows = db.query("SELECT id FROM t WHERE id = 1").unwrap();
        assert_eq!(rows.len(), 1);
        let other = &rows[0][0];
        if !matches!(other, Value::Int(1)) {
            panic!("expected Int(1), got {other:?}");
        }
    }

    /// `query()` must reject statements that do not return rows.
    #[test]
    fn query_on_non_select_errors() {
        let mut db = db_with(&["CREATE TABLE t (id INT)"]);
        let r = db.query("INSERT INTO t VALUES (1)");
        assert!(r.is_err(), "query() on INSERT must error");
    }

    /// A snapshot restored into a new database still contains the inserted row.
    #[test]
    fn snapshot_roundtrip() {
        let db = db_with(&[
            "CREATE TABLE t (id INT NOT NULL)",
            "INSERT INTO t VALUES (42)",
        ]);
        let bytes = db.snapshot();
        let mut restored = Database::restore(&bytes).unwrap();
        let rows = restored.query("SELECT id FROM t WHERE id = 42").unwrap();
        assert_eq!(rows.len(), 1);
        let other = &rows[0][0];
        if !matches!(other, Value::Int(42)) {
            panic!("expected Int(42), got {other:?}");
        }
    }

    /// A hand-written `FromSpgRow` implementation compiles and converts a row.
    #[test]
    fn from_spg_row_trait_shape() {
        struct User {
            _id: i32,
        }
        impl FromSpgRow for User {
            fn from_spg_row(row: &[Value]) -> Result<Self, EngineError> {
                if let Some(Value::Int(n)) = row.first() {
                    Ok(Self { _id: *n })
                } else {
                    Err(EngineError::Unsupported("bad id".into()))
                }
            }
        }
        let row = vec![Value::Int(7)];
        let _u = User::from_spg_row(&row).unwrap();
    }
}