use crate::{ffi, mem};
/// A value read from the spark store together with its time-to-live.
///
/// Produced by [`get`], which decodes `ttl_secs` from a 4-byte little-endian
/// prefix in the host buffer and puts the remaining bytes in `value`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SparkEntry {
    /// The stored payload, with the TTL prefix already removed.
    pub value: Vec<u8>,
    /// TTL reported by the host, in seconds.
    /// NOTE(review): the test suite treats 0 as "no expiry" — confirm against host docs.
    pub ttl_secs: u32,
}
#[derive(Debug, thiserror::Error)]
pub enum SparkError {
#[error("spark: invalid TTL")]
InvalidTtl,
#[error("spark: value too large")]
TooLarge,
#[error("spark: write limit exceeded")]
WriteLimit,
#[error("spark: disk quota exceeded")]
QuotaExceeded,
#[error("spark: not available")]
NotAvailable,
#[error("spark: internal error")]
Internal,
#[error("spark: read limit exceeded")]
ReadLimit,
#[error("spark: invalid key")]
BadKey,
#[error("spark: no capability")]
NoCapability,
#[error("spark: unknown error code {0}")]
Unknown(i32),
}
impl SparkError {
fn from_code(code: i32) -> Self {
match code {
1 => Self::InvalidTtl,
2 => Self::TooLarge,
3 => Self::WriteLimit,
4 => Self::QuotaExceeded,
5 => Self::NotAvailable,
6 => Self::Internal,
7 => Self::ReadLimit,
8 => Self::BadKey,
9 => Self::NoCapability,
other => Self::Unknown(other),
}
}
}
/// Looks up `key` in the spark store.
///
/// The host answers with a packed pointer/length value. A packed value of 0
/// means "no entry". Otherwise the buffer holds a 4-byte little-endian TTL
/// followed by the stored bytes.
///
/// Returns `None` when the key is absent or the buffer is too short to hold
/// the TTL prefix.
pub fn get(key: &str) -> Option<SparkEntry> {
    let (key_ptr, key_len) = mem::host_arg_str(key);
    // SAFETY: assumes `host_arg_str` yields a pointer/length pair valid for the
    // duration of this call — NOTE(review): confirm the `mem` contract.
    let packed = unsafe { ffi::spark_get(key_ptr, key_len) };

    let (buf_ptr, buf_len) = match packed {
        0 => return None,
        nonzero => mem::decode_ptr_len(nonzero),
    };
    if buf_len < 4 {
        return None;
    }

    // SAFETY: assumes the host-provided pointer/length describe a readable
    // buffer — NOTE(review): confirm ownership/freeing rules in `mem`.
    let raw = unsafe { mem::read_bytes(buf_ptr, buf_len) };
    let (ttl_prefix, payload) = raw.split_at(4);
    let ttl_secs =
        u32::from_le_bytes([ttl_prefix[0], ttl_prefix[1], ttl_prefix[2], ttl_prefix[3]]);

    Some(SparkEntry {
        value: payload.to_vec(),
        ttl_secs,
    })
}
/// Looks up `key` and decodes its value as UTF-8, discarding the TTL.
///
/// Returns `None` if the key is absent (see [`get`]) or if the stored bytes
/// are not valid UTF-8.
pub fn get_string(key: &str) -> Option<String> {
    get(key).and_then(|entry| String::from_utf8(entry.value).ok())
}
/// Stores `value` under `key` with a time-to-live of `ttl_secs` seconds.
///
/// # Errors
///
/// - [`SparkError::InvalidTtl`] if `ttl_secs` exceeds `i32::MAX`. The host is
///   not contacted in this case: the host call takes the TTL as an `i32`, and
///   a plain `as` cast would silently wrap such values to negative numbers.
/// - Any non-zero status code returned by the host, mapped through
///   `SparkError::from_code`.
pub fn set(key: &str, value: &[u8], ttl_secs: u32) -> Result<(), SparkError> {
    // `i32::MAX as u32` is lossless, so this comparison is exact.
    if ttl_secs > i32::MAX as u32 {
        return Err(SparkError::InvalidTtl);
    }
    let (key_ptr, key_len) = mem::host_arg_str(key);
    let (val_ptr, val_len) = mem::host_arg_bytes(value);
    // Lossless: the range was checked above.
    let ttl = ttl_secs as i32;
    let code = unsafe { ffi::spark_set(key_ptr, key_len, val_ptr, val_len, ttl) };
    match code {
        0 => Ok(()),
        err => Err(SparkError::from_code(err)),
    }
}
/// Removes `key` from the spark store.
///
/// The host call returns nothing, so this function cannot report whether the
/// key existed.
pub fn delete(key: &str) {
    let (ptr, len) = mem::host_arg_str(key);
    // SAFETY: assumes `host_arg_str` yields a pointer/length pair valid for the
    // duration of this call — NOTE(review): confirm the `mem` contract.
    unsafe {
        ffi::spark_delete(ptr, len);
    }
}
/// Lists the keys currently held in the spark store.
///
/// The host returns the keys as a JSON array of strings. The result is empty
/// when the host returns no payload or when the payload is not valid JSON of
/// that shape (best-effort: errors are not surfaced).
pub fn list() -> Vec<String> {
    let packed = unsafe { ffi::spark_list() };
    match unsafe { mem::read_packed_bytes(packed) } {
        Some(json) => serde_json::from_slice::<Vec<String>>(&json).unwrap_or_default(),
        None => Vec::new(),
    }
}
#[derive(Debug, thiserror::Error)]
pub enum SparkPullError {
#[error("spark pull: not available")]
NotAvailable,
#[error("spark pull: internal error")]
Internal,
#[error("spark pull: no capability")]
NoCapability,
#[error("spark pull: invalid key or origin")]
BadKey,
#[error("spark pull: rate limited")]
WriteLimit,
#[error("spark pull: unknown error code {0}")]
Unknown(i32),
}
impl SparkPullError {
fn from_code(code: i32) -> Self {
match code {
3 => Self::WriteLimit,
5 => Self::NotAvailable,
6 => Self::Internal,
8 => Self::BadKey,
9 => Self::NoCapability,
other => Self::Unknown(other),
}
}
}
/// Asks the host to pull `keys` from the spark store on `origin_node`.
///
/// `keys` is sent to the host as a JSON array of strings. On success, returns
/// the non-negative count reported by the host (presumably the number of keys
/// pulled — NOTE(review): confirm host semantics).
///
/// # Errors
///
/// A negative host result `-c` is mapped via `SparkPullError::from_code(c)`.
/// `i32::MIN`, whose negation is not representable, is reported as
/// [`SparkPullError::Unknown`] carrying the raw value.
pub fn pull(origin_node: &str, keys: &[&str]) -> Result<u32, SparkPullError> {
    // Serialising a slice of `&str` cannot fail; the "[]" fallback is purely defensive.
    let keys_json = serde_json::to_string(keys).unwrap_or_else(|_| String::from("[]"));
    let (origin_ptr, origin_len) = mem::host_arg_str(origin_node);
    let (keys_ptr, keys_len) = mem::host_arg_str(&keys_json);
    let code = unsafe { ffi::spark_pull(origin_ptr, origin_len, keys_ptr, keys_len) };
    if code >= 0 {
        // Lossless: 0 <= code <= i32::MAX.
        Ok(code as u32)
    } else {
        // A bare `-code` overflows for i32::MIN (panics in debug, wraps in release).
        Err(code
            .checked_neg()
            .map_or(SparkPullError::Unknown(code), SparkPullError::from_code))
    }
}
// Unit tests for the spark bindings. They exercise the public functions against
// the mock host in `crate::ffi::test_host`. Tests that touch the mock call
// `test_host::reset()` first, presumably to clear state left by earlier tests
// (NOTE(review): confirm how the mock's state is shared between tests).
#[cfg(test)]
mod tests {
use super::*;
use crate::ffi::test_host;
// `get` returns the seeded value with its TTL decoded and separated from the payload.
// NOTE(review): relies on the mock packing the TTL as the 4-byte LE prefix `get` expects.
#[test]
fn get_strips_ttl_prefix_and_returns_value() {
test_host::reset();
test_host::with_mock(|m| {
m.spark_store.insert("k".into(), (b"hello".to_vec(), 60));
});
let entry = get("k").expect("get should hit the store");
assert_eq!(entry.value, b"hello");
assert_eq!(entry.ttl_secs, 60);
}
// A TTL of 0 round-trips unchanged (the test name treats 0 as "no expiry").
#[test]
fn get_handles_zero_ttl_no_expiry() {
test_host::reset();
test_host::with_mock(|m| {
m.spark_store.insert("k".into(), (b"forever".to_vec(), 0));
});
let entry = get("k").unwrap();
assert_eq!(entry.ttl_secs, 0);
assert_eq!(entry.value, b"forever");
}
// An absent key yields `None` rather than an error or panic.
#[test]
fn get_returns_none_for_missing_key() {
test_host::reset();
assert!(get("missing").is_none());
}
// Multi-byte UTF-8 is decoded correctly by `get_string`.
#[test]
fn get_string_decodes_utf8() {
test_host::reset();
test_host::with_mock(|m| {
m.spark_store
.insert("k".into(), ("héllo".as_bytes().to_vec(), 30));
});
assert_eq!(get_string("k").as_deref(), Some("héllo"));
}
// Invalid UTF-8 makes `get_string` return `None` instead of panicking.
#[test]
fn get_string_returns_none_for_invalid_utf8() {
test_host::reset();
test_host::with_mock(|m| {
m.spark_store.insert("k".into(), (vec![0xff, 0xfe], 30));
});
assert!(get_string("k").is_none());
}
// A successful `set` lands in the mock store with the given TTL.
#[test]
fn set_writes_to_store() {
test_host::reset();
set("greeting", b"hi", 120).expect("set should succeed");
let stored = test_host::read_mock(|m| m.spark_store.get("greeting").cloned());
assert_eq!(stored, Some((b"hi".to_vec(), 120)));
}
// The mock records the exact (key, value, ttl) arguments of the last `set`.
#[test]
fn set_captures_args() {
test_host::reset();
set("k", b"v", 30).unwrap();
let captured = test_host::read_mock(|m| m.last_spark_set.clone());
assert_eq!(captured, Some(("k".into(), b"v".to_vec(), 30)));
}
// Every known host status code maps to its `SparkError` variant. Only the
// variant is compared (via `std::mem::discriminant`), not any payload.
#[test]
fn set_maps_error_codes() {
let cases = [
(1, SparkError::InvalidTtl),
(2, SparkError::TooLarge),
(3, SparkError::WriteLimit),
(4, SparkError::QuotaExceeded),
(5, SparkError::NotAvailable),
(6, SparkError::Internal),
(7, SparkError::ReadLimit),
(8, SparkError::BadKey),
(9, SparkError::NoCapability),
];
for (code, expected) in cases {
// Reset per case so one iteration's injected error can't leak into the next.
test_host::reset();
test_host::with_mock(|m| m.spark_set_error = code);
let err = set("k", b"v", 30).unwrap_err();
assert!(
std::mem::discriminant(&err) == std::mem::discriminant(&expected),
"code {} should map to {:?}, got {:?}",
code,
expected,
err,
);
}
}
// Unrecognised codes are preserved verbatim in `SparkError::Unknown`.
#[test]
fn set_unknown_error_code() {
test_host::reset();
test_host::with_mock(|m| m.spark_set_error = 99);
match set("k", b"v", 30).unwrap_err() {
SparkError::Unknown(99) => {}
other => panic!("expected Unknown(99), got {:?}", other),
}
}
// `delete` removes the key from the store and the mock logs the deleted key.
#[test]
fn delete_removes_from_store() {
test_host::reset();
test_host::with_mock(|m| {
m.spark_store.insert("k".into(), (b"v".to_vec(), 60));
});
delete("k");
assert!(test_host::read_mock(|m| m.spark_store.is_empty()));
assert_eq!(test_host::read_mock(|m| m.spark_deletes.clone()), vec!["k"]);
}
// `list` returns every stored key; sorted first because the order isn't asserted.
#[test]
fn list_returns_keys() {
test_host::reset();
test_host::with_mock(|m| {
m.spark_store.insert("a".into(), (b"1".to_vec(), 10));
m.spark_store.insert("b".into(), (b"2".to_vec(), 20));
});
let mut keys = list();
keys.sort();
assert_eq!(keys, vec!["a".to_string(), "b".to_string()]);
}
// An empty store yields an empty list.
#[test]
fn list_empty_when_no_keys() {
test_host::reset();
assert!(list().is_empty());
}
// `pull` forwards the origin and a compact JSON array of keys, and returns the host count.
#[test]
fn pull_serializes_keys_as_json() {
test_host::reset();
test_host::with_mock(|m| m.spark_pull_result = 3);
let count = pull("origin-node", &["a", "b", "c"]).unwrap();
assert_eq!(count, 3);
let calls = test_host::read_mock(|m| m.spark_pull_calls.clone());
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].0, "origin-node");
assert_eq!(calls[0].1, r#"["a","b","c"]"#);
}
// A host result of 0 is success, not an error.
#[test]
fn pull_zero_count_is_ok() {
test_host::reset();
test_host::with_mock(|m| m.spark_pull_result = 0);
assert_eq!(pull("o", &[]).unwrap(), 0);
}
// Direct check of `SparkPullError::from_code` on positive codes; no host call, so no reset.
#[test]
fn pull_error_from_code_mapping() {
match SparkPullError::from_code(3) {
SparkPullError::WriteLimit => {}
other => panic!("3 should map to WriteLimit, got {:?}", other),
}
match SparkPullError::from_code(5) {
SparkPullError::NotAvailable => {}
other => panic!("5 should map to NotAvailable, got {:?}", other),
}
match SparkPullError::from_code(6) {
SparkPullError::Internal => {}
other => panic!("6 should map to Internal, got {:?}", other),
}
match SparkPullError::from_code(8) {
SparkPullError::BadKey => {}
other => panic!("8 should map to BadKey, got {:?}", other),
}
match SparkPullError::from_code(9) {
SparkPullError::NoCapability => {}
other => panic!("9 should map to NoCapability, got {:?}", other),
}
match SparkPullError::from_code(99) {
SparkPullError::Unknown(99) => {}
other => panic!("99 should map to Unknown(99), got {:?}", other),
}
}
// End-to-end: negative host results are negated before mapping to typed errors.
#[test]
fn pull_negative_code_maps_to_typed_error() {
let cases = [
(-3, SparkPullError::WriteLimit),
(-5, SparkPullError::NotAvailable),
(-6, SparkPullError::Internal),
(-8, SparkPullError::BadKey),
(-9, SparkPullError::NoCapability),
];
for (host_code, expected) in cases {
test_host::reset();
test_host::with_mock(|m| m.spark_pull_result = host_code);
let err = pull("origin", &["k"]).unwrap_err();
assert!(
std::mem::discriminant(&err) == std::mem::discriminant(&expected),
"host code {} should map to {:?}, got {:?}",
host_code,
expected,
err,
);
}
}
// An unrecognised negative result surfaces as `Unknown` with the negated (positive) code.
#[test]
fn pull_unknown_negative_code_is_unknown() {
test_host::reset();
test_host::with_mock(|m| m.spark_pull_result = -42);
match pull("origin", &["k"]).unwrap_err() {
SparkPullError::Unknown(42) => {}
other => panic!("expected Unknown(42), got {:?}", other),
}
}
// A positive host result is returned unchanged as the count.
#[test]
fn pull_positive_count_is_success() {
test_host::reset();
test_host::with_mock(|m| m.spark_pull_result = 7);
assert_eq!(pull("origin", &["a", "b"]).unwrap(), 7);
}
}