#[cfg(any(feature = "native-sqlite", feature = "_has-encryption"))]
use crate::errors::{DynoxideError, Result};
#[cfg(any(feature = "native-sqlite", feature = "_has-encryption"))]
use crate::storage_backend::clock::{Clock, SystemClock};
#[cfg(any(feature = "native-sqlite", feature = "_has-encryption"))]
use crate::storage_backend::sql_builders::{self, escape_table_name};
use crate::types::AttributeValue;
#[cfg(any(feature = "native-sqlite", feature = "_has-encryption"))]
use rusqlite::{Connection, params};
#[cfg(any(feature = "native-sqlite", feature = "_has-encryption"))]
use std::{cell::RefCell, collections::HashMap, sync::Arc};
#[cfg(any(feature = "native-sqlite", feature = "_has-encryption"))]
/// Current on-disk schema version. `Storage::initialize` inserts it as
/// `_config.schema_version` when no version row exists yet, and runs the
/// `migrate_vN_to_vN+1` steps for databases that report an older version.
const SCHEMA_VERSION: &str = "8";
/// Number of scan buckets: `hash_bucket` reads (up to) three hex digits of a
/// hash prefix, so bucket numbers fall in `0..HASH_BUCKETS`.
pub(crate) const HASH_BUCKETS: u32 = 4096;
/// Computes the six-hex-character hash prefix for a partition-key value.
///
/// The key bytes are appended to the salt `b"Outliers"` and the result is hashed
/// with MD5. Key bytes are the UTF-8 text for `S`, the `num_to_buffer` encoding
/// for `N`, the raw bytes for `B`, and nothing for any other variant. The first
/// six lowercase hex digits of the digest are returned.
pub fn compute_hash_prefix(pk_value: &AttributeValue) -> String {
    let mut salted: Vec<u8> = b"Outliers".to_vec();
    match pk_value {
        AttributeValue::S(text) => salted.extend_from_slice(text.as_bytes()),
        AttributeValue::N(number) => salted.extend(num_to_buffer(number)),
        AttributeValue::B(blob) => salted.extend_from_slice(blob),
        _ => {}
    }
    let hex_digest = format!("{:032x}", md5::compute(&salted));
    hex_digest[..6].to_owned()
}
/// Maps a hash prefix to its scan bucket.
///
/// The first (up to) three bytes of `hash_prefix` are parsed as hexadecimal,
/// giving a bucket in `0..4096`. Returns `0` instead of panicking in three cases:
/// - the prefix is empty;
/// - the prefix is not valid hex;
/// - the prefix cannot be sliced at that byte offset (a multi-byte UTF-8
///   character straddles it).
pub fn hash_bucket(hash_prefix: &str) -> u32 {
    let end = hash_prefix.len().min(3);
    hash_prefix
        .get(..end)
        .and_then(|prefix| u32::from_str_radix(prefix, 16).ok())
        .unwrap_or(0)
}
/// Encodes a DynamoDB number string into the byte form that
/// `compute_hash_prefix` feeds to MD5 for `N` partition keys.
///
/// NOTE(review): this layout, together with the "Outliers" salt used by the
/// caller, looks like a port of dynalite's `numToBuffer`. Confirm against that
/// source before editing: any change here changes every numeric hash prefix.
///
/// Layout produced below:
/// - zero, empty/whitespace-only or unparseable input -> `[0x80]`;
/// - byte 0: `floor((exponent + append_zero) / 2) - 64`, truncated to `u8`,
///   bitwise-inverted for negative numbers;
/// - following bytes: the mantissa digits packed two per byte
///   (`tens * 10 + units`). A leading `0` digit is inserted when the exponent
///   is odd. Each completed byte is then mapped to `byte + 1` (positive) or
///   `101 - byte` (negative);
/// - negative numbers with fewer than 20 digit bytes get a trailing `102`.
fn num_to_buffer(num_str: &str) -> Vec<u8> {
let trimmed = num_str.trim();
if trimmed.is_empty() {
return vec![0x80];
}
use bigdecimal::BigDecimal;
use std::str::FromStr;
let bd = match BigDecimal::from_str(trimmed) {
Ok(v) => v,
Err(_) => return vec![0x80],
};
// Zero has no mantissa digits; it is encoded as the single byte 0x80.
if bd.sign() == bigdecimal::num_bigint::Sign::NoSign {
return vec![0x80];
}
// Encode the magnitude; the sign is applied through the byte transforms below.
let is_negative = bd.sign() == bigdecimal::num_bigint::Sign::Minus;
let bd_abs = if is_negative { -&bd } else { bd.clone() };
// mantissa: decimal digits without trailing zeros; value = 0.<mantissa> * 10^exponent.
let (mantissa, exponent) = extract_mantissa_and_exponent(&bd_abs);
if mantissa.is_empty() {
return vec![0x80];
}
// With an odd exponent a leading 0 digit is padded in, so digit pairs align.
let append_zero: i64 = if exponent % 2 != 0 { 1 } else { 0 };
// ceil((digits + pad) / 2): number of two-digit bytes after the exponent byte.
let byte_len_no_exp = ((mantissa.len() as i64 + append_zero + 1) / 2) as usize;
let mut byte_array: Vec<u8>;
// Short negative numbers get one extra trailing byte with value 102.
if byte_len_no_exp < 20 && is_negative {
byte_array = vec![0u8; byte_len_no_exp + 2];
byte_array[byte_len_no_exp + 1] = 102;
} else {
byte_array = vec![0u8; byte_len_no_exp + 1];
}
// Exponent byte. floor_div rounds toward negative infinity, which matters for
// negative exponents. The `as u8` casts deliberately keep only the low 8 bits.
let exp_sum = exponent + append_zero;
let exp_byte_val = floor_div(exp_sum, 2) - 64;
if is_negative {
byte_array[0] = (exp_byte_val ^ !0i64) as u8;
} else {
byte_array[0] = exp_byte_val as u8;
}
// Pack mantissa digits two per byte, starting at index 1.
let mut mi: i64 = 0; let mlen = mantissa.len() as i64;
let mut appended_zero = false;
while mi < mlen {
// When padding, the first pass writes the 0 tens digit and decrements `mi`,
// so mantissa[0] is revisited as the units digit of the same byte.
let bai = ((mi + append_zero) / 2 + 1) as usize; if append_zero != 0 && mi == 0 && !appended_zero {
byte_array[bai] = 0;
appended_zero = true;
mi -= 1; } else if (mi + append_zero) % 2 == 0 {
byte_array[bai] = mantissa[mi as usize] * 10;
} else {
byte_array[bai] += mantissa[mi as usize];
}
// A byte is complete once its units digit is written, or at the last digit.
// Pair values are 0..=99, so neither wrapping op below actually wraps.
if ((mi + append_zero) % 2 != 0) || (mi == mlen - 1) {
if is_negative {
byte_array[bai] = 101u8.wrapping_sub(byte_array[bai]);
} else {
byte_array[bai] = byte_array[bai].wrapping_add(1);
}
}
mi += 1; }
byte_array
}
/// Integer division that rounds toward negative infinity.
///
/// Rust's `/` truncates toward zero, so the truncated quotient is lowered by one
/// when there is a non-zero remainder whose sign differs from the divisor's.
/// Panics if `b` is zero.
fn floor_div(a: i64, b: i64) -> i64 {
    let quotient = a / b;
    let remainder = a % b;
    let signs_differ = (remainder < 0) != (b < 0);
    if remainder != 0 && signs_differ {
        quotient - 1
    } else {
        quotient
    }
}
/// Splits a decimal into its significant digits and a base-10 exponent.
///
/// Trailing zeros are removed by normalization. The result satisfies
/// `|bd| = 0.d1 d2 … dn × 10^exponent`, where `d1..dn` are the returned digits.
fn extract_mantissa_and_exponent(bd: &bigdecimal::BigDecimal) -> (Vec<u8>, i64) {
    let (coefficient, scale) = bd.normalized().as_bigint_and_exponent();
    let text = coefficient.to_string();
    let mut digits: Vec<u8> = Vec::with_capacity(text.len());
    for ch in text.trim_start_matches('-').chars() {
        digits.push(ch.to_digit(10).unwrap() as u8);
    }
    let exponent = digits.len() as i64 - scale;
    (digits, exponent)
}
/// Returns `true` when the bucket of `hash_prefix` (see `hash_bucket`) belongs to
/// scan segment `segment` out of `total_segments`.
///
/// Segment `s` owns the inclusive bucket range
/// `ceil(HASH_BUCKETS*s/total) ..= ceil(HASH_BUCKETS*(s+1)/total) - 1`,
/// which may be empty when there are more segments than buckets.
pub fn hash_in_segment(hash_prefix: &str, segment: u32, total_segments: u32) -> bool {
    let bucket = hash_bucket(hash_prefix);
    let first = ceiling_div(HASH_BUCKETS * segment, total_segments);
    let last = ceiling_div(HASH_BUCKETS * (segment + 1), total_segments) - 1;
    (first..=last).contains(&bucket)
}
/// Divides `a` by `b`, rounding any fractional result up. Panics if `b` is zero.
pub(crate) fn ceiling_div(a: u32, b: u32) -> u32 {
    let whole = a / b;
    // A non-zero remainder implies b >= 2, so `whole + 1` cannot overflow.
    if a % b == 0 {
        whole
    } else {
        whole + 1
    }
}
/// Options for `Storage::scan_items`, `scan_gsi_items` and `scan_lsi_items`.
/// Each method passes these to the matching `sql_builders` function.
#[derive(Debug, Default)]
pub struct ScanParams<'a> {
/// Row limit for the page; presumably `None` means unlimited (see `sql_builders`).
pub limit: Option<usize>,
/// Partition key of the pagination cursor (presumably an exclusive start key).
pub exclusive_start_pk: Option<&'a str>,
/// Sort key of the pagination cursor (presumably an exclusive start key).
pub exclusive_start_sk: Option<&'a str>,
/// Scan segment to read; presumably matched against `hash_in_segment` — confirm in `sql_builders`.
pub segment: Option<u32>,
/// Total number of scan segments that accompanies `segment`.
pub total_segments: Option<u32>,
/// Base-table partition key of the cursor; presumably used by index scans — confirm.
pub exclusive_start_base_pk: Option<&'a str>,
/// Base-table sort key of the cursor; presumably used by index scans — confirm.
pub exclusive_start_base_sk: Option<&'a str>,
}
/// Values for a new `_tables` row. `Storage::insert_table_metadata` converts
/// them into SQL via `sql_builders::insert_table_metadata`.
#[derive(Debug, Default)]
pub struct CreateTableMetadata<'a> {
/// Table name; also the key under which metadata is cached.
pub table_name: &'a str,
/// Serialized key schema (presumably JSON — confirm in `sql_builders`).
pub key_schema: &'a str,
/// Serialized attribute definitions (presumably JSON — confirm).
pub attribute_definitions: &'a str,
/// JSON array of GSI definitions; entries carry an `IndexName` field
/// (migrate_v4_to_v5 reads them this way).
pub gsi_definitions: Option<&'a str>,
/// JSON array of LSI definitions; entries carry an `IndexName` field.
pub lsi_definitions: Option<&'a str>,
/// Serialized provisioned throughput, if any.
pub provisioned_throughput: Option<&'a str>,
/// Creation timestamp; unit not visible here (presumably from `Clock` — confirm).
pub created_at: i64,
/// Serialized SSE specification, if any.
pub sse_specification: Option<&'a str>,
/// Table class, if any.
pub table_class: Option<&'a str>,
/// Deletion protection flag (the `_tables` column is an INTEGER 0/1).
pub deletion_protection_enabled: bool,
/// Billing mode, if any.
pub billing_mode: Option<&'a str>,
/// Serialized on-demand throughput, if any.
pub on_demand_throughput: Option<&'a str>,
}
/// Options for `Storage::query_items`, `query_gsi_items` and `query_lsi_items`.
/// Each method passes these to the matching `sql_builders` function.
#[derive(Debug, Default)]
pub struct QueryParams<'a> {
/// Sort-key condition; presumably an SQL fragment whose placeholders are bound
/// from `sk_params` — confirm in `sql_builders`.
pub sk_condition: Option<&'a str>,
/// Bind values for `sk_condition`.
pub sk_params: &'a [&'a str],
/// Sort direction; presumably `true` = ascending sort key (DynamoDB's `ScanIndexForward`).
pub forward: bool,
/// Row limit for the page; presumably `None` means unlimited.
pub limit: Option<usize>,
/// Sort key of the pagination cursor (presumably exclusive).
pub exclusive_start_sk: Option<&'a str>,
/// Base-table partition key of the cursor; presumably used by index queries.
pub exclusive_start_base_pk: Option<&'a str>,
/// Base-table sort key of the cursor; presumably used by index queries.
pub exclusive_start_base_sk: Option<&'a str>,
}
#[cfg(any(feature = "native-sqlite", feature = "_has-encryption"))]
/// SQLite-backed store. It holds table metadata (`_tables`), per-table item
/// tables, secondary-index tables and stream records (`_stream_records`).
pub struct Storage {
/// The single SQLite connection every method runs on.
conn: Connection,
/// Cache of `TableMetadata` keyed by table name. It is filled by
/// `get_table_metadata`, pruned per entry by most methods that update
/// `_tables`, and cleared entirely by `restore_from_connection`. `RefCell`
/// permits this through `&self`, and also makes `Storage` `!Sync`.
metadata_cache: RefCell<HashMap<String, TableMetadata>>,
/// Time source; `SystemClock` unless replaced with `with_clock`.
clock: Arc<dyn Clock>,
}
#[cfg(any(feature = "native-sqlite", feature = "_has-encryption"))]
impl Storage {
/// Opens the SQLite database file at `path` and initializes its schema.
///
/// # Errors
/// Propagates errors from opening the connection. If schema initialization
/// fails because the file is not a readable SQLite database (for example, it is
/// encrypted), the error is replaced by a descriptive `InternalServerError`.
pub fn new(path: &str) -> Result<Self> {
    let mut storage = Self {
        conn: Connection::open(path)?,
        metadata_cache: RefCell::new(HashMap::new()),
        clock: Arc::new(SystemClock),
    };
    match storage.initialize() {
        Ok(()) => Ok(storage),
        Err(err) => Err(Self::maybe_encrypted_error(err)),
    }
}
/// Builder-style replacement of the time source (`SystemClock` by default).
pub fn with_clock(self, clock: Arc<dyn Clock>) -> Self {
    let mut storage = self;
    storage.clock = clock;
    storage
}
/// Borrows the time source currently in use.
pub(crate) fn clock(&self) -> &dyn Clock {
    &*self.clock
}
/// Replaces SQLite's `NotADatabase` error with a message hinting that the file
/// may be encrypted. Every other error is returned unchanged.
fn maybe_encrypted_error(err: DynoxideError) -> DynoxideError {
    let not_a_database = match &err {
        DynoxideError::SqliteError(sqlite_err) => matches!(
            sqlite_err.sqlite_error_code(),
            Some(rusqlite::ErrorCode::NotADatabase)
        ),
        _ => false,
    };
    if !not_a_database {
        return err;
    }
    DynoxideError::InternalServerError(
        "Database file is encrypted or not a valid SQLite database. \
         If encrypted, enable the `encryption` or `encryption-cc` feature \
         and use Database::new_encrypted() with the correct key."
            .to_string(),
    )
}
/// Opens an encrypted database at `path` and initializes its schema.
///
/// `key` is wrapped as `x'<key>'` and set through `PRAGMA key`. This is
/// presumably the SQLCipher raw-key form, in which case `key` should be hex;
/// confirm. A query on `sqlite_master` follows immediately so that a wrong key
/// presumably fails here rather than on first use.
///
/// The local copy of the key material is zeroized on both the success and the
/// error path. Copies made inside rusqlite/SQLite are outside this function's
/// control.
#[cfg(feature = "_has-encryption")]
pub fn new_encrypted(path: &str, key: &str) -> Result<Self> {
    use zeroize::Zeroize;
    let conn = Connection::open(path)?;
    // Reserve the exact size up front. A reallocation while building the string
    // would free an intermediate copy of the key without wiping it.
    let mut pragma_val = String::with_capacity(key.len() + 3);
    pragma_val.push_str("x'");
    pragma_val.push_str(key);
    pragma_val.push('\'');
    let keyed = conn.pragma_update(None, "key", &pragma_val);
    // Wipe before propagating, so an early error return cannot skip it.
    pragma_val.zeroize();
    keyed?;
    conn.execute_batch("SELECT count(*) FROM sqlite_master;")?;
    let mut storage = Self {
        conn,
        metadata_cache: RefCell::new(HashMap::new()),
        clock: Arc::new(SystemClock),
    };
    storage.initialize()?;
    Ok(storage)
}
/// Creates a fresh in-memory database with the schema initialized.
pub fn memory() -> Result<Self> {
    let mut storage = Self {
        conn: Connection::open_in_memory()?,
        metadata_cache: RefCell::default(),
        clock: Arc::new(SystemClock),
    };
    storage.initialize()?;
    Ok(storage)
}
fn initialize(&mut self) -> Result<()> {
self.conn.pragma_update(None, "journal_mode", "WAL")?;
self.conn.create_scalar_function(
"fnv1a_hash",
1,
rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC
| rusqlite::functions::FunctionFlags::SQLITE_UTF8,
|ctx: &rusqlite::functions::Context| -> rusqlite::Result<i64> {
let pk_ref = ctx.get_raw(0);
let pk_bytes = match pk_ref {
rusqlite::types::ValueRef::Text(bytes) => bytes,
_ => {
return Err(rusqlite::Error::InvalidFunctionParameterType(
0,
rusqlite::types::Type::Text,
));
}
};
let mut hash: u32 = 2166136261;
for &byte in pk_bytes {
hash ^= byte as u32;
hash = hash.wrapping_mul(16777619);
}
Ok(hash as i64)
},
)?;
self.conn.execute_batch(sql_builders::INIT_SCHEMA)?;
let _ = self
.conn
.execute_batch("ALTER TABLE _stream_records ADD COLUMN user_identity TEXT");
self.conn.execute(
"INSERT OR IGNORE INTO _config (key, value) VALUES ('schema_version', ?1)",
params![SCHEMA_VERSION],
)?;
let version: i32 = self
.conn
.query_row(
"SELECT value FROM _config WHERE key = 'schema_version'",
[],
|r| r.get::<_, String>(0),
)
.unwrap_or_else(|_| "1".to_string())
.parse()
.unwrap_or(1);
if version < 2 {
self.migrate_v1_to_v2()?;
}
if version < 3 {
self.migrate_v2_to_v3()?;
}
if version < 4 {
self.migrate_v3_to_v4()?;
}
if version < 5 {
self.migrate_v4_to_v5()?;
}
if version < 6 {
self.migrate_v5_to_v6()?;
}
if version < 7 {
self.migrate_v6_to_v7()?;
}
if version < 8 {
self.migrate_v7_to_v8()?;
}
Ok(())
}
/// v1 → v2: adds a `cached_at REAL` column to every table listed in `_tables`,
/// then records schema version 2.
fn migrate_v1_to_v2(&self) -> Result<()> {
    let table_names: Vec<String> = {
        let mut stmt = self.conn.prepare("SELECT table_name FROM _tables")?;
        let rows = stmt.query_map([], |row| row.get(0))?;
        rows.collect::<std::result::Result<Vec<String>, _>>()?
    };
    for name in &table_names {
        // Double any embedded quotes so the name is a valid quoted identifier.
        let quoted = format!("\"{}\"", name.replace('"', "\"\""));
        // Errors ignored on purpose (e.g. the column already exists).
        let _ = self
            .conn
            .execute(&format!("ALTER TABLE {quoted} ADD COLUMN cached_at REAL"), []);
    }
    self.conn.execute(
        "INSERT OR REPLACE INTO _config (key, value) VALUES ('schema_version', '2')",
        [],
    )?;
    Ok(())
}
/// v2 → v3: adds `_tables.tags`, then records schema version 3.
fn migrate_v2_to_v3(&self) -> Result<()> {
    let add_tags = "ALTER TABLE _tables ADD COLUMN tags TEXT";
    // Errors ignored on purpose (e.g. the column already exists).
    let _ = self.conn.execute(add_tags, []);
    self.conn.execute(
        "INSERT OR REPLACE INTO _config (key, value) VALUES ('schema_version', '3')",
        [],
    )?;
    Ok(())
}
/// v3 → v4: adds the SSE, table-class and deletion-protection columns to
/// `_tables`, then records schema version 4.
fn migrate_v3_to_v4(&self) -> Result<()> {
    const ADD_COLUMNS: [&str; 3] = [
        "ALTER TABLE _tables ADD COLUMN sse_specification TEXT",
        "ALTER TABLE _tables ADD COLUMN table_class TEXT",
        "ALTER TABLE _tables ADD COLUMN deletion_protection_enabled INTEGER DEFAULT 0",
    ];
    for statement in ADD_COLUMNS {
        // Each ALTER is best effort (e.g. the column may already exist).
        let _ = self.conn.execute(statement, []);
    }
    self.conn.execute(
        "INSERT OR REPLACE INTO _config (key, value) VALUES ('schema_version', '4')",
        [],
    )?;
    Ok(())
}
/// v4 → v5: for every GSI/LSI listed in `_tables`, creates an index on the index
/// table's base-key columns, then records schema version 5. The columns are
/// `table_pk, table_sk` for GSIs and `base_pk, base_sk` for LSIs.
///
/// Best effort per index: unparsable definitions, entries without an
/// `IndexName`, and failing `CREATE INDEX` statements are skipped.
fn migrate_v4_to_v5(&self) -> Result<()> {
    let tables: Vec<(String, Option<String>, Option<String>)> = {
        let mut stmt = self
            .conn
            .prepare("SELECT table_name, gsi_definitions, lsi_definitions FROM _tables")?;
        let rows = stmt.query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?;
        rows.collect::<std::result::Result<Vec<_>, _>>()?
    };
    for (table_name, gsi_json, lsi_json) in &tables {
        // GSIs are processed before LSIs for each table.
        let index_kinds = [
            (gsi_json, "gsi", "table_pk, table_sk"),
            (lsi_json, "lsi", "base_pk, base_sk"),
        ];
        for (definitions, kind, key_columns) in index_kinds {
            let Some(json) = definitions else {
                continue;
            };
            let Ok(indexes) = serde_json::from_str::<Vec<serde_json::Value>>(json) else {
                continue;
            };
            for index in &indexes {
                let Some(index_name) = index.get("IndexName").and_then(|v| v.as_str()) else {
                    continue;
                };
                let data_table =
                    escape_table_name(&format!("{table_name}::{kind}::{index_name}"));
                let sqlite_index =
                    escape_table_name(&format!("{table_name}::{kind}::{index_name}::base_key"));
                let _ = self.conn.execute_batch(&format!(
                    "CREATE INDEX IF NOT EXISTS \"{sqlite_index}\" ON \"{data_table}\" ({key_columns})"
                ));
            }
        }
    }
    self.conn.execute(
        "INSERT OR REPLACE INTO _config (key, value) VALUES ('schema_version', '5')",
        [],
    )?;
    Ok(())
}
/// v5 → v6: adds a `hash_prefix` column (defaulting to `''`) to every table
/// listed in `_tables`, then records schema version 6.
fn migrate_v5_to_v6(&self) -> Result<()> {
    let names: Vec<String> = {
        let mut stmt = self.conn.prepare("SELECT table_name FROM _tables")?;
        let rows = stmt.query_map([], |row| row.get(0))?;
        rows.collect::<std::result::Result<Vec<String>, _>>()?
    };
    for name in &names {
        let ddl = format!(
            "ALTER TABLE \"{}\" ADD COLUMN hash_prefix TEXT NOT NULL DEFAULT ''",
            escape_table_name(name)
        );
        // Errors ignored on purpose (e.g. the column already exists).
        let _ = self.conn.execute(&ddl, []);
    }
    self.conn.execute(
        "INSERT OR REPLACE INTO _config (key, value) VALUES ('schema_version', '6')",
        [],
    )?;
    Ok(())
}
/// v6 → v7: adds `_tables.on_demand_throughput`, then records schema version 7.
fn migrate_v6_to_v7(&self) -> Result<()> {
    let add_column = "ALTER TABLE _tables ADD COLUMN on_demand_throughput TEXT";
    // Errors ignored on purpose (e.g. the column already exists).
    let _ = self.conn.execute(add_column, []);
    self.conn.execute(
        "INSERT OR REPLACE INTO _config (key, value) VALUES ('schema_version', '7')",
        [],
    )?;
    Ok(())
}
/// v7 → v8: adds `_tables.table_id` and gives every table without one a fresh
/// random (v4) UUID, then records schema version 8.
fn migrate_v7_to_v8(&self) -> Result<()> {
    // Errors ignored on purpose (e.g. the column already exists).
    let _ = self
        .conn
        .execute("ALTER TABLE _tables ADD COLUMN table_id TEXT", []);
    let missing_ids: Vec<String> = {
        let mut select = self
            .conn
            .prepare("SELECT table_name FROM _tables WHERE table_id IS NULL")?;
        let found = select.query_map([], |row| row.get::<_, String>(0))?;
        found.collect::<rusqlite::Result<Vec<String>>>()?
    };
    for table_name in &missing_ids {
        let fresh_id = uuid::Uuid::new_v4().to_string();
        self.conn.execute(
            "UPDATE _tables SET table_id = ?1 WHERE table_name = ?2",
            params![fresh_id, table_name],
        )?;
    }
    self.conn.execute(
        "INSERT OR REPLACE INTO _config (key, value) VALUES ('schema_version', '8')",
        [],
    )?;
    Ok(())
}
pub fn conn(&self) -> &Connection {
&self.conn
}
pub fn conn_mut(&mut self) -> &mut Connection {
&mut self.conn
}
pub fn insert_table_metadata(&self, m: &CreateTableMetadata) -> Result<()> {
let table_name = m.table_name;
let (sql, params) = sql_builders::insert_table_metadata(m);
self.conn
.execute(&sql, rusqlite::params_from_iter(params.iter()))?;
self.metadata_cache.borrow_mut().remove(table_name);
Ok(())
}
pub fn get_table_metadata(&self, table_name: &str) -> Result<Option<TableMetadata>> {
if let Some(cached) = self.metadata_cache.borrow().get(table_name) {
return Ok(Some(cached.clone()));
}
let (sql, params) = sql_builders::get_table_metadata(table_name);
let mut stmt = self.conn.prepare(&sql)?;
let result = stmt.query_row(rusqlite::params_from_iter(params.iter()), row_to_metadata);
match result {
Ok(meta) => {
self.metadata_cache
.borrow_mut()
.insert(table_name.to_string(), meta.clone());
Ok(Some(meta))
}
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(DynoxideError::from(e)),
}
}
pub fn delete_table_metadata(&self, table_name: &str) -> Result<bool> {
let (sql, params) = sql_builders::delete_table_metadata(table_name);
let affected = self
.conn
.execute(&sql, rusqlite::params_from_iter(params.iter()))?;
self.metadata_cache.borrow_mut().remove(table_name);
Ok(affected > 0)
}
/// Rewrites the attribute and GSI definitions of `table_name` (a `None` GSI list
/// is stored as NULL), then drops its cached metadata.
pub fn update_table_metadata(
    &self,
    table_name: &str,
    attribute_definitions: &str,
    gsi_definitions: Option<&str>,
) -> Result<()> {
    const SQL: &str = "UPDATE _tables SET attribute_definitions = ?1, gsi_definitions = ?2 WHERE table_name = ?3";
    self.conn
        .execute(SQL, params![attribute_definitions, gsi_definitions, table_name])?;
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Stores a new provisioned-throughput value, then drops cached metadata.
pub fn update_provisioned_throughput(
    &self,
    table_name: &str,
    provisioned_throughput: &str,
) -> Result<()> {
    const SQL: &str = "UPDATE _tables SET provisioned_throughput = ?1 WHERE table_name = ?2";
    self.conn
        .execute(SQL, params![provisioned_throughput, table_name])?;
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Sets provisioned throughput to NULL, then drops cached metadata.
pub fn clear_provisioned_throughput(&self, table_name: &str) -> Result<()> {
    const SQL: &str = "UPDATE _tables SET provisioned_throughput = NULL WHERE table_name = ?1";
    self.conn.execute(SQL, params![table_name])?;
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Stores a new billing mode, then drops cached metadata.
pub fn update_billing_mode(&self, table_name: &str, billing_mode: &str) -> Result<()> {
    const SQL: &str = "UPDATE _tables SET billing_mode = ?1 WHERE table_name = ?2";
    self.conn.execute(SQL, params![billing_mode, table_name])?;
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Stores a new table class, then drops cached metadata.
pub fn update_table_class(&self, table_name: &str, table_class: &str) -> Result<()> {
    const SQL: &str = "UPDATE _tables SET table_class = ?1 WHERE table_name = ?2";
    self.conn.execute(SQL, params![table_class, table_name])?;
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Stores a new on-demand throughput value, then drops cached metadata.
pub fn update_on_demand_throughput(
    &self,
    table_name: &str,
    on_demand_throughput: &str,
) -> Result<()> {
    const SQL: &str = "UPDATE _tables SET on_demand_throughput = ?1 WHERE table_name = ?2";
    self.conn
        .execute(SQL, params![on_demand_throughput, table_name])?;
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Reads the JSON tag list stored for `table_name`. A NULL column yields an
/// empty list.
///
/// # Errors
/// Fails if the table row does not exist (`query_row` finds no row), or with
/// `InternalServerError` if the stored JSON cannot be parsed.
pub fn get_tags(&self, table_name: &str) -> Result<Vec<crate::types::Tag>> {
    let raw: Option<String> = self.conn.query_row(
        "SELECT tags FROM _tables WHERE table_name = ?1",
        params![table_name],
        |row| row.get(0),
    )?;
    let Some(json) = raw else {
        return Ok(Vec::new());
    };
    serde_json::from_str(&json)
        .map_err(|e| DynoxideError::InternalServerError(format!("Bad tags JSON: {e}")))
}
/// Merges `new_tags` into the table's existing tags. New values win on key
/// clashes, and the merged list is stored sorted by key.
///
/// # Errors
/// Returns `ValidationException` if the merged set would exceed 50 tags, and
/// propagates read/serialize/write failures.
pub fn set_tags(&self, table_name: &str, new_tags: &[crate::types::Tag]) -> Result<()> {
    use std::collections::BTreeMap;
    let existing = self.get_tags(table_name)?;
    let mut tag_map: BTreeMap<String, String> =
        existing.into_iter().map(|t| (t.key, t.value)).collect();
    tag_map.extend(new_tags.iter().map(|t| (t.key.clone(), t.value.clone())));
    if tag_map.len() > 50 {
        return Err(DynoxideError::ValidationException(
            "One or more parameter values were invalid: \
             Too many tags: tag limit is 50"
                .to_string(),
        ));
    }
    let merged: Vec<crate::types::Tag> = tag_map
        .into_iter()
        .map(|(key, value)| crate::types::Tag { key, value })
        .collect();
    let json = serde_json::to_string(&merged)
        .map_err(|e| DynoxideError::InternalServerError(e.to_string()))?;
    self.conn.execute(
        "UPDATE _tables SET tags = ?1 WHERE table_name = ?2",
        params![json, table_name],
    )?;
    // The other `_tables` writers evict the cached metadata; do the same here so
    // a cached `TableMetadata` cannot disagree with the row. Harmless if the
    // cached struct doesn't carry tags.
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Stores the deletion-protection flag as INTEGER 0/1, then drops cached metadata.
pub fn update_deletion_protection(&self, table_name: &str, enabled: bool) -> Result<()> {
    let flag = i32::from(enabled);
    self.conn.execute(
        "UPDATE _tables SET deletion_protection_enabled = ?1 WHERE table_name = ?2",
        params![flag, table_name],
    )?;
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Removes every tag whose key is in `keys`. The remaining order is kept, and an
/// empty result is stored as SQL NULL rather than `[]`.
///
/// # Errors
/// Propagates read/serialize/write failures.
pub fn remove_tags(&self, table_name: &str, keys: &[String]) -> Result<()> {
    let mut tags = self.get_tags(table_name)?;
    tags.retain(|t| !keys.contains(&t.key));
    let json = if tags.is_empty() {
        None
    } else {
        Some(
            serde_json::to_string(&tags)
                .map_err(|e| DynoxideError::InternalServerError(e.to_string()))?,
        )
    };
    self.conn.execute(
        "UPDATE _tables SET tags = ?1 WHERE table_name = ?2",
        params![json, table_name],
    )?;
    // Keep the metadata cache consistent with the `_tables` row, as the other
    // `_tables` writers do.
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Lists table names as produced by `sql_builders::list_table_names`.
pub fn list_table_names(&self) -> Result<Vec<String>> {
    let (sql, args) = sql_builders::list_table_names();
    let mut stmt = self.conn.prepare(&sql)?;
    let mut names: Vec<String> = Vec::new();
    for name in stmt.query_map(rusqlite::params_from_iter(args.iter()), |row| row.get(0))? {
        names.push(name?);
    }
    Ok(names)
}
/// Returns whether the count query built by `sql_builders::table_exists` is positive.
pub fn table_exists(&self, table_name: &str) -> Result<bool> {
    let (sql, args) = sql_builders::table_exists(table_name);
    let matches: i32 = self
        .conn
        .query_row(&sql, rusqlite::params_from_iter(args.iter()), |row| row.get(0))?;
    Ok(matches > 0)
}
// NOTE(review): allowed dead code, presumably because some feature builds never call it.
#[allow(dead_code)]
pub(crate) fn invalidate_metadata_cache(&self, table_name: &str) {
    let mut cache = self.metadata_cache.borrow_mut();
    cache.remove(table_name);
}
/// Creates the item table for `table_name` (DDL from `sql_builders`).
pub fn create_data_table(&self, table_name: &str) -> Result<()> {
    let (ddl, args) = sql_builders::create_data_table(table_name);
    let bind = rusqlite::params_from_iter(args.iter());
    self.conn.execute(&ddl, bind)?;
    Ok(())
}
/// Drops the item table for `table_name`.
pub fn drop_data_table(&self, table_name: &str) -> Result<()> {
    let (ddl, args) = sql_builders::drop_data_table(table_name);
    let bind = rusqlite::params_from_iter(args.iter());
    self.conn.execute(&ddl, bind)?;
    Ok(())
}
/// Creates the table for GSI `index_name`. The builder's SQL may hold several
/// statements, so it runs through `execute_batch`; its bind values are unused.
pub fn create_gsi_table(&self, table_name: &str, index_name: &str) -> Result<()> {
    let ddl = sql_builders::create_gsi_table(table_name, index_name).0;
    self.conn.execute_batch(&ddl)?;
    Ok(())
}
/// Drops the table for GSI `index_name`.
pub fn drop_gsi_table(&self, table_name: &str, index_name: &str) -> Result<()> {
    let (ddl, args) = sql_builders::drop_gsi_table(table_name, index_name);
    let bind = rusqlite::params_from_iter(args.iter());
    self.conn.execute(&ddl, bind)?;
    Ok(())
}
/// Inserts one row into the table of GSI `index_name` using a cached statement.
#[allow(clippy::too_many_arguments)]
pub fn insert_gsi_item(
    &self,
    table_name: &str,
    index_name: &str,
    gsi_pk: &str,
    gsi_sk: &str,
    table_pk: &str,
    table_sk: &str,
    item_json: &str,
) -> Result<()> {
    let mut insert = self
        .conn
        .prepare_cached(&sql_builders::gsi_insert_sql(table_name, index_name))?;
    let args = sql_builders::gsi_insert_params(gsi_pk, gsi_sk, table_pk, table_sk, item_json);
    insert.execute(rusqlite::params_from_iter(args.iter()))?;
    Ok(())
}
/// Inserts many rows into the table of GSI `index_name`, reusing one cached
/// statement. Stops at the first failing row.
pub fn insert_gsi_items(
    &self,
    table_name: &str,
    index_name: &str,
    rows: &[crate::storage_backend::GsiItemRow],
) -> Result<()> {
    let mut insert = self
        .conn
        .prepare_cached(&sql_builders::gsi_insert_sql(table_name, index_name))?;
    for entry in rows {
        let args = sql_builders::gsi_insert_params(
            &entry.gsi_pk,
            &entry.gsi_sk,
            &entry.table_pk,
            &entry.table_sk,
            &entry.item_json,
        );
        insert.execute(rusqlite::params_from_iter(args.iter()))?;
    }
    Ok(())
}
/// Deletes the GSI row for the base-table key `(table_pk, table_sk)`.
pub fn delete_gsi_item(
    &self,
    table_name: &str,
    index_name: &str,
    table_pk: &str,
    table_sk: &str,
) -> Result<()> {
    let (sql, args) = sql_builders::delete_gsi_item(table_name, index_name, table_pk, table_sk);
    let mut delete = self.conn.prepare_cached(&sql)?;
    delete.execute(rusqlite::params_from_iter(args.iter()))?;
    Ok(())
}
/// Runs the GSI query built by `sql_builders::query_gsi_items` and returns the
/// first three columns of each row.
pub fn query_gsi_items(
    &self,
    table_name: &str,
    index_name: &str,
    gsi_pk: &str,
    params: &QueryParams,
) -> Result<Vec<(String, String, String)>> {
    let (sql, bind) = sql_builders::query_gsi_items(table_name, index_name, gsi_pk, params);
    let mut stmt = self.conn.prepare(&sql)?;
    let mapped = stmt.query_map(rusqlite::params_from_iter(bind.iter()), |row| {
        Ok((row.get(0)?, row.get(1)?, row.get(2)?))
    })?;
    let mut out: Vec<(String, String, String)> = Vec::new();
    for row in mapped {
        out.push(row?);
    }
    Ok(out)
}
/// Runs the GSI scan built by `sql_builders::scan_gsi_items` and returns the
/// first three columns of each row.
pub fn scan_gsi_items(
    &self,
    table_name: &str,
    index_name: &str,
    params: &ScanParams,
) -> Result<Vec<(String, String, String)>> {
    let (sql, bind) = sql_builders::scan_gsi_items(table_name, index_name, params);
    let mut stmt = self.conn.prepare(&sql)?;
    let mapped = stmt.query_map(rusqlite::params_from_iter(bind.iter()), |row| {
        Ok((row.get(0)?, row.get(1)?, row.get(2)?))
    })?;
    let mut out: Vec<(String, String, String)> = Vec::new();
    for row in mapped {
        out.push(row?);
    }
    Ok(out)
}
/// Creates the table for LSI `index_name`. The builder's SQL runs through
/// `execute_batch`; its bind values are unused.
pub fn create_lsi_table(&self, table_name: &str, index_name: &str) -> Result<()> {
    let ddl = sql_builders::create_lsi_table(table_name, index_name).0;
    self.conn.execute_batch(&ddl)?;
    Ok(())
}
/// Drops the table for LSI `index_name`.
pub fn drop_lsi_table(&self, table_name: &str, index_name: &str) -> Result<()> {
    let (ddl, args) = sql_builders::drop_lsi_table(table_name, index_name);
    let bind = rusqlite::params_from_iter(args.iter());
    self.conn.execute(&ddl, bind)?;
    Ok(())
}
/// Inserts one row into the table of LSI `index_name` using a cached statement.
#[allow(clippy::too_many_arguments)]
pub fn insert_lsi_item(
    &self,
    table_name: &str,
    index_name: &str,
    pk: &str,
    sk: &str,
    base_pk: &str,
    base_sk: &str,
    item_json: &str,
) -> Result<()> {
    let mut insert = self
        .conn
        .prepare_cached(&sql_builders::lsi_insert_sql(table_name, index_name))?;
    let args = sql_builders::lsi_insert_params(pk, sk, base_pk, base_sk, item_json);
    insert.execute(rusqlite::params_from_iter(args.iter()))?;
    Ok(())
}
/// Deletes the LSI row for the base-table key `(base_pk, base_sk)`.
pub fn delete_lsi_item(
    &self,
    table_name: &str,
    index_name: &str,
    base_pk: &str,
    base_sk: &str,
) -> Result<()> {
    let (sql, args) = sql_builders::delete_lsi_item(table_name, index_name, base_pk, base_sk);
    let mut delete = self.conn.prepare_cached(&sql)?;
    delete.execute(rusqlite::params_from_iter(args.iter()))?;
    Ok(())
}
/// Runs the LSI query built by `sql_builders::query_lsi_items` and returns the
/// first three columns of each row.
pub fn query_lsi_items(
    &self,
    table_name: &str,
    index_name: &str,
    pk: &str,
    params: &QueryParams,
) -> Result<Vec<(String, String, String)>> {
    let (sql, bind) = sql_builders::query_lsi_items(table_name, index_name, pk, params);
    let mut stmt = self.conn.prepare(&sql)?;
    let mapped = stmt.query_map(rusqlite::params_from_iter(bind.iter()), |row| {
        Ok((row.get(0)?, row.get(1)?, row.get(2)?))
    })?;
    let mut out: Vec<(String, String, String)> = Vec::new();
    for row in mapped {
        out.push(row?);
    }
    Ok(out)
}
/// Runs the LSI scan built by `sql_builders::scan_lsi_items` and returns the
/// first three columns of each row.
pub fn scan_lsi_items(
    &self,
    table_name: &str,
    index_name: &str,
    params: &ScanParams,
) -> Result<Vec<(String, String, String)>> {
    let (sql, bind) = sql_builders::scan_lsi_items(table_name, index_name, params);
    let mut stmt = self.conn.prepare(&sql)?;
    let mapped = stmt.query_map(rusqlite::params_from_iter(bind.iter()), |row| {
        Ok((row.get(0)?, row.get(1)?, row.get(2)?))
    })?;
    let mut out: Vec<(String, String, String)> = Vec::new();
    for row in mapped {
        out.push(row?);
    }
    Ok(out)
}
/// Runs `sql_builders::BEGIN` on the connection.
pub fn begin_transaction(&self) -> Result<()> {
    Ok(self.conn.execute_batch(sql_builders::BEGIN)?)
}
/// Runs `sql_builders::COMMIT` on the connection.
pub fn commit(&self) -> Result<()> {
    Ok(self.conn.execute_batch(sql_builders::COMMIT)?)
}
/// Runs `sql_builders::ROLLBACK` on the connection.
pub fn rollback(&self) -> Result<()> {
    Ok(self.conn.execute_batch(sql_builders::ROLLBACK)?)
}
/// Switches to bulk-load pragmas: no fsync, a larger page cache, in-memory temp
/// storage and a 256 MiB mmap window.
pub fn enable_bulk_loading(&self) -> Result<()> {
    let pragmas = "PRAGMA synchronous = OFF;\n\
                   PRAGMA cache_size = -64000;\n\
                   PRAGMA temp_store = MEMORY;\n\
                   PRAGMA mmap_size = 268435456;";
    Ok(self.conn.execute_batch(pragmas)?)
}
/// Restores the regular pragma settings changed by `enable_bulk_loading`.
pub fn disable_bulk_loading(&self) -> Result<()> {
    let pragmas = "PRAGMA synchronous = NORMAL;\n\
                   PRAGMA cache_size = -2000;\n\
                   PRAGMA temp_store = DEFAULT;\n\
                   PRAGMA mmap_size = 0;";
    Ok(self.conn.execute_batch(pragmas)?)
}
/// Upserts an item with an empty hash prefix and returns the previous item JSON,
/// if any.
pub fn put_item(
    &self,
    table_name: &str,
    pk: &str,
    sk: &str,
    item_json: &str,
    item_size: usize,
) -> Result<Option<String>> {
    let no_hash_prefix = "";
    self.put_item_with_hash(table_name, pk, sk, item_json, item_size, no_hash_prefix)
}
/// Upserts an item together with its hash prefix, returning the previous item
/// JSON (read just before the write), if any.
pub fn put_item_with_hash(
    &self,
    table_name: &str,
    pk: &str,
    sk: &str,
    item_json: &str,
    item_size: usize,
    hash_prefix: &str,
) -> Result<Option<String>> {
    let previous = self.get_item(table_name, pk, sk)?;
    let (upsert, args) =
        sql_builders::put_item_with_hash(table_name, pk, sk, item_json, item_size, hash_prefix);
    self.conn
        .execute(&upsert, rusqlite::params_from_iter(args.iter()))?;
    Ok(previous)
}
/// Upserts many base-table rows with one cached `INSERT OR REPLACE` statement.
/// Stops at the first failing row.
pub fn put_base_items(
    &self,
    table_name: &str,
    rows: &[crate::storage_backend::BaseItemRow],
) -> Result<()> {
    let sql = format!(
        "INSERT OR REPLACE INTO \"{}\" (pk, sk, item_json, item_size, cached_at, hash_prefix) \
         VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
        escape_table_name(table_name)
    );
    let mut upsert = self.conn.prepare_cached(&sql)?;
    for entry in rows {
        upsert.execute(params![
            entry.pk,
            entry.sk,
            entry.item_json,
            entry.item_size as i64,
            entry.cached_at,
            entry.hash_prefix
        ])?;
    }
    Ok(())
}
/// Returns the stored item JSON for `(pk, sk)`, or `None` if absent.
pub fn get_item(&self, table_name: &str, pk: &str, sk: &str) -> Result<Option<String>> {
    let (sql, args) = sql_builders::get_item(table_name, pk, sk);
    let lookup: rusqlite::Result<String> = self
        .conn
        .query_row(&sql, rusqlite::params_from_iter(args.iter()), |row| row.get(0));
    match lookup {
        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
        other => Ok(Some(other?)),
    }
}
/// Returns the single integer produced by `sql_builders::get_partition_size`.
pub fn get_partition_size(&self, table_name: &str, pk: &str) -> Result<i64> {
    let (sql, args) = sql_builders::get_partition_size(table_name, pk);
    self.conn
        .query_row(&sql, rusqlite::params_from_iter(args.iter()), |row| {
            row.get::<_, i64>(0)
        })
        .map_err(DynoxideError::from)
}
/// Returns the single integer produced by `sql_builders::get_lsi_partition_size`.
pub fn get_lsi_partition_size(
    &self,
    table_name: &str,
    index_name: &str,
    pk: &str,
) -> Result<i64> {
    let (sql, args) = sql_builders::get_lsi_partition_size(table_name, index_name, pk);
    self.conn
        .query_row(&sql, rusqlite::params_from_iter(args.iter()), |row| {
            row.get::<_, i64>(0)
        })
        .map_err(DynoxideError::from)
}
/// Deletes the item at `(pk, sk)` and returns its previous JSON, if any.
pub fn delete_item(&self, table_name: &str, pk: &str, sk: &str) -> Result<Option<String>> {
    let previous = self.get_item(table_name, pk, sk)?;
    let (delete, args) = sql_builders::delete_item(table_name, pk, sk);
    self.conn
        .execute(&delete, rusqlite::params_from_iter(args.iter()))?;
    Ok(previous)
}
pub fn query_items(
&self,
table_name: &str,
pk: &str,
params: &QueryParams,
) -> Result<Vec<(String, String, String)>> {
let (sql, params_vec) = sql_builders::query_items(table_name, pk, params);
let mut stmt = self.conn.prepare(&sql)?;
let rows = stmt
.query_map(rusqlite::params_from_iter(params_vec.iter()), |row| {
Ok((row.get(0)?, row.get(1)?, row.get(2)?))
})?
.collect::<std::result::Result<Vec<_>, _>>()?;
Ok(rows)
}
pub fn scan_items(
&self,
table_name: &str,
params: &ScanParams,
) -> Result<Vec<(String, String, String)>> {
let (sql, params_vec) = sql_builders::scan_items(table_name, params);
let mut stmt = self.conn.prepare(&sql)?;
let rows = stmt
.query_map(rusqlite::params_from_iter(params_vec.iter()), |row| {
Ok((row.get(0)?, row.get(1)?, row.get(2)?))
})?
.collect::<std::result::Result<Vec<_>, _>>()?;
Ok(rows)
}
pub fn count_items(&self, table_name: &str) -> Result<i64> {
let (sql, params) = sql_builders::count_items(table_name);
let count: i64 =
self.conn
.query_row(&sql, rusqlite::params_from_iter(params.iter()), |row| {
row.get(0)
})?;
Ok(count)
}
pub fn db_path(&self) -> Option<String> {
self.conn
.path()
.filter(|p| !p.is_empty())
.map(|p| p.to_owned())
}
pub fn db_size_bytes(&self) -> Result<u64> {
let size: i64 = self.conn.query_row(
"SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
[],
|row| row.get(0),
)?;
Ok(size as u64)
}
pub fn table_count(&self) -> Result<usize> {
let count: i64 = self
.conn
.query_row("SELECT COUNT(*) FROM _tables", [], |row| row.get(0))?;
Ok(count as usize)
}
pub fn table_stats(&self) -> Result<Vec<TableStats>> {
let table_names = self.list_table_names()?;
let mut stats = Vec::with_capacity(table_names.len());
for name in table_names {
let sql = format!(
"SELECT COUNT(*), COALESCE(SUM(item_size), 0) FROM \"{}\"",
escape_table_name(&name)
);
let (item_count, size_bytes): (i64, i64) = self
.conn
.query_row(&sql, [], |row| Ok((row.get(0)?, row.get(1)?)))?;
stats.push(TableStats {
table_name: name,
item_count,
size_bytes: size_bytes as u64,
});
}
Ok(stats)
}
pub fn database_info(&self) -> Result<DatabaseInfo> {
let path = self.db_path();
let size_bytes = self.db_size_bytes()?;
let table_count = self.table_count()?;
let stats = self.table_stats()?;
let mut table_details = Vec::with_capacity(stats.len());
for s in stats {
let metadata = self.get_table_metadata(&s.table_name)?;
table_details.push(TableInfoEntry { stats: s, metadata });
}
Ok(DatabaseInfo {
path,
size_bytes,
table_count,
tables: table_details,
})
}
pub fn vacuum_into(&self, path: &str) -> Result<()> {
if path.contains('\0') {
return Err(DynoxideError::ValidationException(
"path contains null byte".to_string(),
));
}
self.conn
.execute_batch(&format!("VACUUM INTO '{}'", path.replace('\'', "''")))?;
Ok(())
}
/// Rebuilds the database file in place with `VACUUM`.
pub fn vacuum(&self) -> Result<()> {
    Ok(self.conn.execute_batch("VACUUM")?)
}
pub fn restore_from(&mut self, path: &str) -> Result<()> {
let source = Connection::open(path)?;
self.restore_from_connection(&source)
}
/// Copies the whole database into a new in-memory connection (SQLite backup API,
/// 100 pages per step, no pause between steps).
pub fn backup_to_memory(&self) -> Result<Connection> {
    let mut snapshot = Connection::open_in_memory()?;
    rusqlite::backup::Backup::new(&self.conn, &mut snapshot)?.run_to_completion(
        100,
        std::time::Duration::ZERO,
        None,
    )?;
    Ok(snapshot)
}
/// Overwrites this database with the contents of `source` (SQLite backup API),
/// then clears the metadata cache.
pub fn restore_from_connection(&mut self, source: &Connection) -> Result<()> {
    {
        let backup = rusqlite::backup::Backup::new(source, &mut self.conn)?;
        backup.run_to_completion(100, std::time::Duration::ZERO, None)?;
    }
    self.metadata_cache.get_mut().clear();
    Ok(())
}
/// Size in bytes (`page_count * page_size`) of the database behind `conn`.
pub fn connection_size_bytes(conn: &Connection) -> Result<u64> {
    let bytes: i64 = conn.query_row(
        "SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
        [],
        |row| row.get(0),
    )?;
    Ok(bytes as u64)
}
/// Marks the table's stream as enabled with the given view type and label,
/// then drops cached metadata.
pub fn enable_stream(&self, table_name: &str, view_type: &str, label: &str) -> Result<()> {
    const SQL: &str = "UPDATE _tables SET stream_enabled = 1, stream_view_type = ?1, stream_label = ?2 WHERE table_name = ?3";
    self.conn.execute(SQL, params![view_type, label, table_name])?;
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Marks the table's stream as disabled (view type and label are kept), then
/// drops cached metadata.
pub fn disable_stream(&self, table_name: &str) -> Result<()> {
    const SQL: &str = "UPDATE _tables SET stream_enabled = 0 WHERE table_name = ?1";
    self.conn.execute(SQL, params![table_name])?;
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Appends a stream record with no user identity.
#[allow(clippy::too_many_arguments)]
pub fn insert_stream_record(
    &self,
    table_name: &str,
    event_name: &str,
    keys_json: &str,
    new_image: Option<&str>,
    old_image: Option<&str>,
    sequence_number: &str,
    shard_id: &str,
    created_at: i64,
) -> Result<()> {
    let no_identity: Option<&str> = None;
    self.insert_stream_record_with_identity(
        table_name,
        event_name,
        keys_json,
        new_image,
        old_image,
        sequence_number,
        shard_id,
        created_at,
        no_identity,
    )
}
/// Appends a row to `_stream_records`; absent images or identity are stored as NULL.
#[allow(clippy::too_many_arguments)]
pub fn insert_stream_record_with_identity(
    &self,
    table_name: &str,
    event_name: &str,
    keys_json: &str,
    new_image: Option<&str>,
    old_image: Option<&str>,
    sequence_number: &str,
    shard_id: &str,
    created_at: i64,
    user_identity: Option<&str>,
) -> Result<()> {
    const SQL: &str = "INSERT INTO _stream_records (table_name, event_name, keys_json, new_image, old_image, sequence_number, shard_id, created_at, user_identity)\n\
                       VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)";
    self.conn.execute(
        SQL,
        params![
            table_name,
            event_name,
            keys_json,
            new_image,
            old_image,
            sequence_number,
            shard_id,
            created_at,
            user_identity
        ],
    )?;
    Ok(())
}
/// Returns one more than the highest numeric `sequence_number` recorded for
/// `table_name`, or `1` when the table has no stream records.
///
/// # Errors
/// Propagates SQLite failures. Previously every error was mapped to `Ok(1)`,
/// which could hand out an already-used sequence number.
pub fn next_stream_sequence_number(&self, table_name: &str) -> Result<i64> {
    // An aggregate without GROUP BY always yields exactly one row. MAX over no
    // rows is NULL, hence the COALESCE. So any error here is a real failure,
    // not "no records yet".
    let next: i64 = self.conn.query_row(
        "SELECT COALESCE(MAX(CAST(sequence_number AS INTEGER)), 0) + 1 FROM _stream_records WHERE table_name = ?1",
        params![table_name],
        |row| row.get(0),
    )?;
    Ok(next)
}
pub fn get_stream_records(
&self,
table_name: &str,
shard_id: &str,
after_sequence: i64,
limit: usize,
) -> Result<Vec<StreamRecord>> {
let mut stmt = self.conn.prepare(
"SELECT event_name, keys_json, new_image, old_image, sequence_number, created_at, user_identity
FROM _stream_records
WHERE table_name = ?1 AND shard_id = ?2 AND CAST(sequence_number AS INTEGER) > ?3
ORDER BY CAST(sequence_number AS INTEGER) ASC
LIMIT ?4",
)?;
let rows = stmt
.query_map(
params![table_name, shard_id, after_sequence, limit as i64],
|row| {
Ok(StreamRecord {
event_name: row.get(0)?,
keys_json: row.get(1)?,
new_image: row.get(2)?,
old_image: row.get(3)?,
sequence_number: row.get(4)?,
created_at: row.get(5)?,
user_identity: row.get(6)?,
})
},
)?
.collect::<std::result::Result<Vec<_>, _>>()?;
Ok(rows)
}
/// Lists metadata for every table whose `stream_enabled` flag is 1, sorted by
/// table name.
///
/// Reads straight from `_tables`; the metadata cache is neither consulted nor
/// updated.
pub fn list_stream_enabled_tables(&self) -> Result<Vec<TableMetadata>> {
    let query = format!(
        "SELECT {} FROM _tables WHERE stream_enabled = 1 ORDER BY table_name",
        sql_builders::TABLE_METADATA_COLUMNS
    );
    let mut stmt = self.conn.prepare(&query)?;
    let mut tables = Vec::new();
    // Stop at the first row that fails to decode, like a fallible collect.
    for decoded in stmt.query_map([], row_to_metadata)? {
        tables.push(decoded?);
    }
    Ok(tables)
}
/// Stores the TTL attribute name and enabled flag for `table_name`, then
/// evicts the table's cached metadata.
///
/// Passing `attribute_name = None` writes NULL to `ttl_attribute`. The flag is
/// stored as 0/1.
pub fn update_ttl_config(
    &self,
    table_name: &str,
    attribute_name: Option<&str>,
    enabled: bool,
) -> Result<()> {
    let enabled_flag = i32::from(enabled);
    let update_sql =
        "UPDATE _tables SET ttl_attribute = ?1, ttl_enabled = ?2 WHERE table_name = ?3";
    let _rows_changed = self
        .conn
        .execute(update_sql, params![attribute_name, enabled_flag, table_name])?;
    // The cached entry would otherwise keep reporting the old TTL settings.
    self.metadata_cache.borrow_mut().remove(table_name);
    Ok(())
}
/// Lists metadata for every table whose `ttl_enabled` flag is 1, sorted by
/// table name.
///
/// Reads straight from `_tables`; the metadata cache is neither consulted nor
/// updated.
pub fn list_ttl_enabled_tables(&self) -> Result<Vec<TableMetadata>> {
    let query = format!(
        "SELECT {} FROM _tables WHERE ttl_enabled = 1 ORDER BY table_name",
        sql_builders::TABLE_METADATA_COLUMNS
    );
    let mut stmt = self.conn.prepare(&query)?;
    let mut tables = Vec::new();
    // Stop at the first row that fails to decode, like a fallible collect.
    for decoded in stmt.query_map([], row_to_metadata)? {
        tables.push(decoded?);
    }
    Ok(tables)
}
/// Returns the lowest and highest `sequence_number` stored for `shard_id` of
/// `table_name`, as `(first, last)`. Both are `None` when the shard has no
/// records.
///
/// Sequence numbers are stored as text, so they are ordered numerically with
/// `CAST(sequence_number AS INTEGER)`. This matches the ordering used by
/// `get_stream_records` and `next_stream_sequence_number`. A plain
/// `MIN`/`MAX` over the text column compares lexicographically (`"10" < "9"`)
/// and reports the wrong range once sequence numbers differ in length. The
/// original text of the chosen rows is returned unchanged.
///
/// # Errors
/// Propagates SQLite failures instead of reporting them as an empty shard.
pub fn get_shard_sequence_range(
    &self,
    table_name: &str,
    shard_id: &str,
) -> Result<(Option<String>, Option<String>)> {
    // Each scalar subquery yields NULL for an empty shard, and the outer
    // SELECT always produces exactly one row, so there is no "no rows" case.
    let range = self.conn.query_row(
        "SELECT \
           (SELECT sequence_number FROM _stream_records \
             WHERE table_name = ?1 AND shard_id = ?2 \
             ORDER BY CAST(sequence_number AS INTEGER) ASC LIMIT 1), \
           (SELECT sequence_number FROM _stream_records \
             WHERE table_name = ?1 AND shard_id = ?2 \
             ORDER BY CAST(sequence_number AS INTEGER) DESC LIMIT 1)",
        params![table_name, shard_id],
        |row| {
            Ok((
                row.get::<_, Option<String>>(0)?,
                row.get::<_, Option<String>>(1)?,
            ))
        },
    )?;
    Ok(range)
}
/// Sets `cached_at = timestamp` on the item (`pk`, `sk`) in the data table for
/// `table_name`.
///
/// The table name is passed through `escape_table_name` before being quoted
/// into the SQL. NOTE(review): that presumably doubles embedded `"` — confirm
/// in `sql_builders`. Updating a missing item is not an error.
pub fn touch_cached_at(
    &self,
    table_name: &str,
    pk: &str,
    sk: &str,
    timestamp: f64,
) -> Result<()> {
    let quoted = escape_table_name(table_name);
    let update_sql = format!("UPDATE \"{quoted}\" SET cached_at = ?1 WHERE pk = ?2 AND sk = ?3");
    let _rows_changed = self.conn.execute(&update_sql, params![timestamp, pk, sk])?;
    Ok(())
}
/// Returns up to `limit` `(pk, sk, item_size)` tuples from the data table for
/// `table_name`, oldest `cached_at` first.
///
/// Items whose `cached_at` is NULL are never returned.
pub fn get_lru_items(
    &self,
    table_name: &str,
    limit: usize,
) -> Result<Vec<(String, String, i64)>> {
    let select_sql = format!(
        "SELECT pk, sk, item_size FROM \"{}\" WHERE cached_at IS NOT NULL ORDER BY cached_at ASC LIMIT ?1",
        escape_table_name(table_name)
    );
    let mut stmt = self.conn.prepare(&select_sql)?;
    let mut oldest_first: Vec<(String, String, i64)> = Vec::new();
    let rows = stmt.query_map(params![limit as i64], |row| {
        Ok((
            row.get::<_, String>(0)?,
            row.get::<_, String>(1)?,
            row.get::<_, i64>(2)?,
        ))
    })?;
    for entry in rows {
        oldest_first.push(entry?);
    }
    Ok(oldest_first)
}
}
/// One row of `_stream_records`, as decoded by `Storage::get_stream_records`.
#[derive(Debug, Clone)]
pub struct StreamRecord {
/// Event name as stored by the writer. Presumably INSERT/MODIFY/REMOVE — confirm with callers.
pub event_name: String,
/// Item key attributes as text. Presumably serialized JSON, per the `_json` suffix.
pub keys_json: String,
/// Item image after the change; `None` when the column is NULL.
pub new_image: Option<String>,
/// Item image before the change; `None` when the column is NULL.
pub old_image: Option<String>,
/// Sequence number stored as text; storage queries order it numerically via
/// `CAST(sequence_number AS INTEGER)`.
pub sequence_number: String,
/// Creation timestamp as given to the insert call. Units are not visible here — TODO confirm.
pub created_at: i64,
/// Identity passed to `insert_stream_record_with_identity`; `None` for records
/// written through `insert_stream_record`.
pub user_identity: Option<String>,
}
/// Size statistics for a single table.
/// NOTE(review): computed outside this chunk — confirm exactly what is counted.
#[derive(Debug, Clone)]
pub struct TableStats {
/// Name of the table the statistics describe.
pub table_name: String,
/// Number of items in the table (presumably the data table's row count).
pub item_count: i64,
/// Size of the table's data in bytes (how it is measured is not visible here).
pub size_bytes: u64,
}
/// Summary of a whole database: its location, total size and per-table details.
#[derive(Debug, Clone)]
pub struct DatabaseInfo {
/// Database file path; presumably `None` for in-memory databases — confirm.
pub path: Option<String>,
/// Total database size in bytes.
pub size_bytes: u64,
/// Number of tables described; presumably equal to `tables.len()` — confirm.
pub table_count: usize,
/// Per-table statistics and metadata.
pub tables: Vec<TableInfoEntry>,
}
/// Statistics for one table paired with its catalogue metadata, if available.
#[derive(Debug, Clone)]
pub struct TableInfoEntry {
/// Size statistics for the table.
pub stats: TableStats,
/// Decoded `_tables` row; presumably `None` when no metadata row exists — confirm.
pub metadata: Option<TableMetadata>,
}
/// Decoded row of the `_tables` catalogue (see `row_to_metadata` for the
/// column-to-field mapping).
#[derive(Debug, Clone)]
pub struct TableMetadata {
/// Table name (the `_tables` primary key).
pub table_name: String,
/// Key schema as JSON text, e.g. `[{"AttributeName":"pk","KeyType":"HASH"}]`.
pub key_schema: String,
/// Attribute definitions as JSON text, e.g. `[{"AttributeName":"pk","AttributeType":"S"}]`.
pub attribute_definitions: String,
/// Global secondary index definitions; presumably JSON text — confirm.
pub gsi_definitions: Option<String>,
/// Local secondary index definitions; presumably JSON text — confirm.
pub lsi_definitions: Option<String>,
/// Whether a stream is enabled (stored as an INTEGER flag, non-zero = true).
pub stream_enabled: bool,
/// Stream view type, e.g. `"NEW_AND_OLD_IMAGES"`. It is not cleared when the stream is disabled.
pub stream_view_type: Option<String>,
/// Stream label supplied when the stream was enabled.
pub stream_label: Option<String>,
/// Name of the TTL attribute, if configured.
pub ttl_attribute: Option<String>,
/// Whether TTL is enabled (stored as an INTEGER flag, non-zero = true).
pub ttl_enabled: bool,
/// Creation timestamp. Units are not visible here — TODO confirm.
pub created_at: i64,
/// Table status; new tables default to `"ACTIVE"`.
pub table_status: String,
/// Billing mode text (older schemas defaulted the column to `'PAY_PER_REQUEST'`).
pub billing_mode: Option<String>,
/// Provisioned throughput settings; presumably JSON text — confirm.
pub provisioned_throughput: Option<String>,
/// Server-side encryption settings; presumably JSON text — confirm.
pub sse_specification: Option<String>,
/// Table class, if set.
pub table_class: Option<String>,
/// Deletion protection flag; `row_to_metadata` decodes an unreadable value as `false`.
pub deletion_protection_enabled: bool,
/// On-demand throughput settings (column added by the v7 migration); presumably JSON text.
pub on_demand_throughput: Option<String>,
/// Stable table identifier, backfilled for pre-v8 tables by the schema migration.
pub table_id: Option<String>,
}
/// Decodes one row selected with `sql_builders::TABLE_METADATA_COLUMNS` into a
/// [`TableMetadata`].
///
/// Columns are read by position (0..=18), so this mapping must stay in sync
/// with the column list in `sql_builders`. Integer flag columns become
/// `value != 0`. `deletion_protection_enabled` (column 16) is lenient: if it
/// cannot be read as an `i32` (e.g. NULL), it decodes as `false` instead of
/// failing the row.
#[cfg(any(feature = "native-sqlite", feature = "_has-encryption"))]
fn row_to_metadata(row: &rusqlite::Row) -> rusqlite::Result<TableMetadata> {
    // Strict flag decoding: a read failure aborts the whole row.
    let flag = |idx: usize| -> rusqlite::Result<bool> { Ok(row.get::<_, i32>(idx)? != 0) };
    Ok(TableMetadata {
        table_name: row.get(0)?,
        key_schema: row.get(1)?,
        attribute_definitions: row.get(2)?,
        gsi_definitions: row.get(3)?,
        lsi_definitions: row.get(4)?,
        stream_enabled: flag(5)?,
        stream_view_type: row.get(6)?,
        stream_label: row.get(7)?,
        ttl_attribute: row.get(8)?,
        ttl_enabled: flag(9)?,
        created_at: row.get(10)?,
        table_status: row.get(11)?,
        billing_mode: row.get(12)?,
        provisioned_throughput: row.get(13)?,
        sse_specification: row.get(14)?,
        table_class: row.get(15)?,
        deletion_protection_enabled: matches!(row.get::<_, i32>(16), Ok(v) if v != 0),
        on_demand_throughput: row.get(17)?,
        table_id: row.get(18)?,
    })
}
// Unit tests for the SQLite-backed `Storage`: schema initialisation and
// migrations, table metadata and its cache, item/query/scan primitives, GSI
// tables, and the partition-key hashing helpers. Compiled only for test builds
// with the `native-sqlite` or `_has-encryption` feature enabled.
#[cfg(all(test, any(feature = "native-sqlite", feature = "_has-encryption")))]
mod tests {
use super::*;
// Fresh in-memory Storage for each test; panics if it cannot be created.
fn test_storage() -> Storage {
Storage::memory().expect("Failed to create in-memory storage")
}
// A newly initialised database records the current SCHEMA_VERSION in `_config`.
#[test]
fn test_initialize_creates_metadata_tables() {
let storage = test_storage();
let version: String = storage
.conn()
.query_row(
"SELECT value FROM _config WHERE key = 'schema_version'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(version, SCHEMA_VERSION);
}
// Hand-builds a schema-version-6 database file (no `on_demand_throughput`
// column), reopens it with `Storage::new`, and checks that migration brings it
// to SCHEMA_VERSION, adds the column, and leaves it NULL for existing rows.
#[test]
fn test_migrate_v6_to_v7_adds_on_demand_throughput_column() {
let tmp = tempfile::NamedTempFile::new().unwrap();
let path = tmp.path().to_str().unwrap().to_string();
// Inner scope: the raw connection is dropped before Storage reopens the file.
{
let conn = Connection::open(&path).unwrap();
conn.execute_batch(
"CREATE TABLE _config (key TEXT PRIMARY KEY, value TEXT NOT NULL);
CREATE TABLE _tables (
table_name TEXT PRIMARY KEY,
key_schema TEXT NOT NULL,
attribute_definitions TEXT NOT NULL,
gsi_definitions TEXT,
lsi_definitions TEXT,
stream_enabled INTEGER DEFAULT 0,
stream_view_type TEXT,
stream_label TEXT,
ttl_attribute TEXT,
ttl_enabled INTEGER DEFAULT 0,
created_at INTEGER NOT NULL,
table_status TEXT NOT NULL DEFAULT 'ACTIVE',
billing_mode TEXT DEFAULT 'PAY_PER_REQUEST',
provisioned_throughput TEXT,
tags TEXT,
sse_specification TEXT,
table_class TEXT,
deletion_protection_enabled INTEGER DEFAULT 0
);",
)
.unwrap();
conn.execute(
"INSERT INTO _config (key, value) VALUES ('schema_version', '6')",
[],
)
.unwrap();
conn.execute(
"INSERT INTO _tables (table_name, key_schema, attribute_definitions, created_at) \
VALUES ('LegacyTable', ?1, ?2, 0)",
params![
r#"[{"AttributeName":"pk","KeyType":"HASH"}]"#,
r#"[{"AttributeName":"pk","AttributeType":"S"}]"#,
],
)
.unwrap();
}
// Opening through Storage runs the migrations.
let storage = Storage::new(&path).unwrap();
let version: String = storage
.conn()
.query_row(
"SELECT value FROM _config WHERE key = 'schema_version'",
[],
|r| r.get(0),
)
.unwrap();
assert_eq!(version, SCHEMA_VERSION);
let meta = storage.get_table_metadata("LegacyTable").unwrap().unwrap();
assert_eq!(meta.table_name, "LegacyTable");
assert!(meta.on_demand_throughput.is_none());
// Selecting the column directly fails if the migration did not add it.
let col: Option<String> = storage
.conn()
.query_row(
"SELECT on_demand_throughput FROM _tables WHERE table_name = 'LegacyTable'",
[],
|r| r.get(0),
)
.unwrap();
assert!(col.is_none());
}
// Starting from a v7 file, migration must give existing tables a non-empty
// `table_id`, and the same id must be read back after reopening the file.
#[test]
fn test_migrate_v7_to_v8_backfills_table_id() {
let tmp = tempfile::NamedTempFile::new().unwrap();
let path = tmp.path().to_str().unwrap().to_string();
// Inner scope: the raw connection is dropped before Storage reopens the file.
{
let conn = Connection::open(&path).unwrap();
conn.execute_batch(
"CREATE TABLE _config (key TEXT PRIMARY KEY, value TEXT NOT NULL);
CREATE TABLE _tables (
table_name TEXT PRIMARY KEY,
key_schema TEXT NOT NULL,
attribute_definitions TEXT NOT NULL,
gsi_definitions TEXT,
lsi_definitions TEXT,
stream_enabled INTEGER DEFAULT 0,
stream_view_type TEXT,
stream_label TEXT,
ttl_attribute TEXT,
ttl_enabled INTEGER DEFAULT 0,
created_at INTEGER NOT NULL,
table_status TEXT NOT NULL DEFAULT 'ACTIVE',
billing_mode TEXT DEFAULT 'PAY_PER_REQUEST',
provisioned_throughput TEXT,
tags TEXT,
sse_specification TEXT,
table_class TEXT,
deletion_protection_enabled INTEGER DEFAULT 0,
on_demand_throughput TEXT
);",
)
.unwrap();
conn.execute(
"INSERT INTO _config (key, value) VALUES ('schema_version', '7')",
[],
)
.unwrap();
conn.execute(
"INSERT INTO _tables (table_name, key_schema, attribute_definitions, created_at) \
VALUES ('LegacyTable', ?1, ?2, 0)",
params![
r#"[{"AttributeName":"pk","KeyType":"HASH"}]"#,
r#"[{"AttributeName":"pk","AttributeType":"S"}]"#,
],
)
.unwrap();
}
let storage = Storage::new(&path).unwrap();
let version: String = storage
.conn()
.query_row(
"SELECT value FROM _config WHERE key = 'schema_version'",
[],
|r| r.get(0),
)
.unwrap();
assert_eq!(version, SCHEMA_VERSION);
let meta = storage.get_table_metadata("LegacyTable").unwrap().unwrap();
let id = meta.table_id.expect("legacy table should be backfilled");
assert!(!id.is_empty());
// Close and reopen: the backfilled id must be persisted, not regenerated.
drop(storage);
let storage2 = Storage::new(&path).unwrap();
let meta2 = storage2.get_table_metadata("LegacyTable").unwrap().unwrap();
assert_eq!(meta2.table_id.as_deref(), Some(id.as_str()));
}
// test_storage() is in-memory, where SQLite reports journal mode "memory"
// instead of "wal", so both values are accepted.
#[test]
fn test_wal_mode_enabled() {
let storage = test_storage();
let mode: String = storage
.conn()
.query_row("PRAGMA journal_mode", [], |row| row.get(0))
.unwrap();
assert!(mode == "wal" || mode == "memory", "Got mode: {mode}");
}
// Insert, read back (status defaults to ACTIVE) and delete a table's metadata.
#[test]
fn test_table_metadata_crud() {
let storage = test_storage();
assert!(!storage.table_exists("TestTable").unwrap());
assert!(storage.list_table_names().unwrap().is_empty());
storage
.insert_table_metadata(&CreateTableMetadata {
table_name: "TestTable",
key_schema: r#"[{"AttributeName":"pk","KeyType":"HASH"}]"#,
attribute_definitions: r#"[{"AttributeName":"pk","AttributeType":"S"}]"#,
created_at: 1000000,
..Default::default()
})
.unwrap();
assert!(storage.table_exists("TestTable").unwrap());
assert_eq!(storage.list_table_names().unwrap(), vec!["TestTable"]);
let meta = storage.get_table_metadata("TestTable").unwrap().unwrap();
assert_eq!(meta.table_name, "TestTable");
assert_eq!(meta.table_status, "ACTIVE");
assert_eq!(meta.created_at, 1000000);
assert!(storage.delete_table_metadata("TestTable").unwrap());
assert!(!storage.table_exists("TestTable").unwrap());
}
// Smoke test: a data table can be created, written, read and dropped.
// NOTE(review): nothing asserts the table is actually gone after the drop.
#[test]
fn test_create_and_drop_data_table() {
let storage = test_storage();
storage.create_data_table("MyTable").unwrap();
storage
.put_item("MyTable", "pk1", "", r#"{"pk":{"S":"pk1"}}"#, 10)
.unwrap();
let item = storage.get_item("MyTable", "pk1", "").unwrap();
assert!(item.is_some());
storage.drop_data_table("MyTable").unwrap();
}
// put_item returns the previous item JSON (None on first insert), and
// delete_item returns the removed item JSON.
#[test]
fn test_item_crud() {
let storage = test_storage();
storage.create_data_table("Items").unwrap();
let old = storage
.put_item(
"Items",
"user#1",
"profile",
r#"{"name":{"S":"Alice"}}"#,
20,
)
.unwrap();
assert!(old.is_none());
let item = storage.get_item("Items", "user#1", "profile").unwrap();
assert_eq!(item.unwrap(), r#"{"name":{"S":"Alice"}}"#);
let old = storage
.put_item("Items", "user#1", "profile", r#"{"name":{"S":"Bob"}}"#, 18)
.unwrap();
assert_eq!(old.unwrap(), r#"{"name":{"S":"Alice"}}"#);
let deleted = storage.delete_item("Items", "user#1", "profile").unwrap();
assert_eq!(deleted.unwrap(), r#"{"name":{"S":"Bob"}}"#);
assert!(
storage
.get_item("Items", "user#1", "profile")
.unwrap()
.is_none()
);
}
// query_items orders by sort key (index 1 of each result tuple): ascending
// when `forward`, descending otherwise, truncated by `limit`.
#[test]
fn test_query_items() {
let storage = test_storage();
storage.create_data_table("Orders").unwrap();
for i in 1..=5 {
let sk = format!("order#{i:03}");
let json = format!(r#"{{"id":{{"N":"{i}"}}}}"#);
storage
.put_item("Orders", "user#1", &sk, &json, 10)
.unwrap();
}
let results = storage
.query_items(
"Orders",
"user#1",
&QueryParams {
forward: true,
..Default::default()
},
)
.unwrap();
assert_eq!(results.len(), 5);
assert_eq!(results[0].1, "order#001");
let results = storage
.query_items(
"Orders",
"user#1",
&QueryParams {
forward: true,
limit: Some(2),
..Default::default()
},
)
.unwrap();
assert_eq!(results.len(), 2);
let results = storage
.query_items(
"Orders",
"user#1",
&QueryParams {
forward: false,
limit: Some(2),
..Default::default()
},
)
.unwrap();
assert_eq!(results.len(), 2);
// Reverse order: the highest sort key comes first.
assert_eq!(results[0].1, "order#005"); }
// scan_items honours `limit` and resumes strictly after the exclusive start
// key: starting after (pk "a", sk "1"), the first result has pk "b".
#[test]
fn test_scan_items() {
let storage = test_storage();
storage.create_data_table("ScanTest").unwrap();
storage.put_item("ScanTest", "a", "1", r#"{}"#, 2).unwrap();
storage.put_item("ScanTest", "b", "2", r#"{}"#, 2).unwrap();
storage.put_item("ScanTest", "c", "3", r#"{}"#, 2).unwrap();
let results = storage.scan_items("ScanTest", &Default::default()).unwrap();
assert_eq!(results.len(), 3);
let results = storage
.scan_items(
"ScanTest",
&ScanParams {
limit: Some(2),
..Default::default()
},
)
.unwrap();
assert_eq!(results.len(), 2);
let results = storage
.scan_items(
"ScanTest",
&ScanParams {
limit: Some(2),
exclusive_start_pk: Some("a"),
exclusive_start_sk: Some("1"),
..Default::default()
},
)
.unwrap();
assert_eq!(results.len(), 2);
// Pagination resumed after "a", so pk (index 0) of the first result is "b".
assert_eq!(results[0].0, "b"); }
// count_items reflects the number of stored items, starting from zero.
#[test]
fn test_count_items() {
let storage = test_storage();
storage.create_data_table("CountTest").unwrap();
assert_eq!(storage.count_items("CountTest").unwrap(), 0);
storage.put_item("CountTest", "a", "", r#"{}"#, 2).unwrap();
storage.put_item("CountTest", "b", "", r#"{}"#, 2).unwrap();
assert_eq!(storage.count_items("CountTest").unwrap(), 2);
}
// The GSI table is expected to be named "<table>::gsi::<index>". The unwrapped
// raw INSERT fails unless that table exists with these columns.
#[test]
fn test_gsi_table_lifecycle() {
let storage = test_storage();
storage.create_gsi_table("Orders", "ByDate").unwrap();
let gsi_name = "Orders::gsi::ByDate";
let sql = format!(
"INSERT INTO \"{}\" (gsi_pk, gsi_sk, table_pk, table_sk, item_json) VALUES (?1, ?2, ?3, ?4, ?5)",
gsi_name.replace('"', "\"\"")
);
storage
.conn()
.execute(
&sql,
params!["2024-01-01", "001", "user#1", "order#001", r#"{}"#],
)
.unwrap();
storage.drop_gsi_table("Orders", "ByDate").unwrap();
}
// Looking up or deleting metadata for an unknown table is not an error.
#[test]
fn test_nonexistent_table_metadata() {
let storage = test_storage();
assert!(storage.get_table_metadata("Nonexistent").unwrap().is_none());
assert!(!storage.delete_table_metadata("Nonexistent").unwrap());
}
// A successful metadata lookup populates metadata_cache; repeat lookups agree.
#[test]
fn test_metadata_cache_hit() {
let storage = test_storage();
storage
.insert_table_metadata(&CreateTableMetadata {
table_name: "CacheTest",
key_schema: r#"[{"AttributeName":"pk","KeyType":"HASH"}]"#,
attribute_definitions: r#"[{"AttributeName":"pk","AttributeType":"S"}]"#,
created_at: 1000000,
..Default::default()
})
.unwrap();
let meta1 = storage.get_table_metadata("CacheTest").unwrap().unwrap();
assert_eq!(meta1.table_name, "CacheTest");
let meta2 = storage.get_table_metadata("CacheTest").unwrap().unwrap();
assert_eq!(meta2.table_name, "CacheTest");
assert_eq!(meta1.created_at, meta2.created_at);
assert!(storage.metadata_cache.borrow().contains_key("CacheTest"));
}
// Deleting a table's metadata evicts its cache entry.
#[test]
fn test_metadata_cache_invalidated_on_delete() {
let storage = test_storage();
storage
.insert_table_metadata(&CreateTableMetadata {
table_name: "DelCache",
key_schema: r#"[{"AttributeName":"pk","KeyType":"HASH"}]"#,
attribute_definitions: r#"[{"AttributeName":"pk","AttributeType":"S"}]"#,
created_at: 1000000,
..Default::default()
})
.unwrap();
storage.get_table_metadata("DelCache").unwrap();
assert!(storage.metadata_cache.borrow().contains_key("DelCache"));
storage.delete_table_metadata("DelCache").unwrap();
assert!(!storage.metadata_cache.borrow().contains_key("DelCache"));
}
// enable_stream evicts the cache entry, so the next lookup sees stream_enabled.
#[test]
fn test_metadata_cache_invalidated_on_stream_enable() {
let storage = test_storage();
storage
.insert_table_metadata(&CreateTableMetadata {
table_name: "StreamCache",
key_schema: r#"[{"AttributeName":"pk","KeyType":"HASH"}]"#,
attribute_definitions: r#"[{"AttributeName":"pk","AttributeType":"S"}]"#,
created_at: 1000000,
..Default::default()
})
.unwrap();
let meta = storage.get_table_metadata("StreamCache").unwrap().unwrap();
assert!(!meta.stream_enabled);
storage
.enable_stream("StreamCache", "NEW_AND_OLD_IMAGES", "2024-01-01T00:00:00")
.unwrap();
assert!(!storage.metadata_cache.borrow().contains_key("StreamCache"));
let meta = storage.get_table_metadata("StreamCache").unwrap().unwrap();
assert!(meta.stream_enabled);
}
// update_ttl_config evicts the cache entry, so the next lookup sees the TTL settings.
#[test]
fn test_metadata_cache_invalidated_on_ttl_update() {
let storage = test_storage();
storage
.insert_table_metadata(&CreateTableMetadata {
table_name: "TtlCache",
key_schema: r#"[{"AttributeName":"pk","KeyType":"HASH"}]"#,
attribute_definitions: r#"[{"AttributeName":"pk","AttributeType":"S"}]"#,
created_at: 1000000,
..Default::default()
})
.unwrap();
let meta = storage.get_table_metadata("TtlCache").unwrap().unwrap();
assert!(!meta.ttl_enabled);
storage
.update_ttl_config("TtlCache", Some("expires_at"), true)
.unwrap();
assert!(!storage.metadata_cache.borrow().contains_key("TtlCache"));
let meta = storage.get_table_metadata("TtlCache").unwrap().unwrap();
assert!(meta.ttl_enabled);
assert_eq!(meta.ttl_attribute, Some("expires_at".to_string()));
}
// Zero, including negative zero, encodes as the single byte 0x80.
#[test]
fn test_num_to_buffer_zero() {
assert_eq!(num_to_buffer("0"), vec![0x80]);
assert_eq!(num_to_buffer("-0"), vec![0x80]);
}
// Pinned bucket values for specific string keys. hash_bucket reads the first
// three hex digits of the prefix (0..=4095). NOTE(review): these look like
// vectors taken from a reference implementation — confirm their source.
#[test]
fn test_hash_prefix_string_keys() {
let h1 = compute_hash_prefix(&AttributeValue::S("3635".into()));
let h2 = compute_hash_prefix(&AttributeValue::S("228".into()));
let h3 = compute_hash_prefix(&AttributeValue::S("1668".into()));
let h4 = compute_hash_prefix(&AttributeValue::S("3435".into()));
assert_eq!(
hash_bucket(&h1),
0,
"3635 should be bucket 0, got hash {h1}"
);
assert_eq!(hash_bucket(&h2), 0, "228 should be bucket 0, got hash {h2}");
assert_eq!(
hash_bucket(&h3),
1,
"1668 should be bucket 1, got hash {h3}"
);
assert_eq!(
hash_bucket(&h4),
4,
"3435 should be bucket 4, got hash {h4}"
);
}
// Pinned bucket values for number keys, which are hashed via num_to_buffer.
#[test]
fn test_hash_prefix_number_keys() {
let h1 = compute_hash_prefix(&AttributeValue::N("251".into()));
assert_eq!(hash_bucket(&h1), 1, "251 should be bucket 1, got hash {h1}");
let h2 = compute_hash_prefix(&AttributeValue::N("2388".into()));
assert_eq!(
hash_bucket(&h2),
4095,
"2388 should be bucket 4095, got hash {h2}"
);
}
// With 4096 segments, each three-hex-digit bucket is its own segment. With 2
// segments, buckets 0x000-0x7ff fall in segment 0 and 0x800-0xfff in segment 1.
#[test]
fn test_hash_in_segment() {
assert!(hash_in_segment("000000", 0, 4096));
assert!(!hash_in_segment("000000", 1, 4096));
assert!(hash_in_segment("001000", 1, 4096));
assert!(!hash_in_segment("001000", 0, 4096));
assert!(hash_in_segment("fff000", 4095, 4096));
assert!(!hash_in_segment("fff000", 0, 4096));
assert!(hash_in_segment("000000", 0, 2));
assert!(hash_in_segment("7ff000", 0, 2));
assert!(hash_in_segment("800000", 1, 2));
assert!(hash_in_segment("fff000", 1, 2));
}
}