//! Bindings to the host's "spark" key-value store: get/set/delete/list
//! operations plus cross-node pulls, with host error codes mapped onto
//! typed error enums.

use crate::{ffi, mem};

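/// An entry returned by [`get`]: the stored value bytes plus the TTL, in
/// seconds, reported by the host for that key. A TTL of `0` means the entry
/// has no expiry.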
#[derive(Debug, Clone)]
pub struct SparkEntry {
    pub value: Vec<u8>,

    pub ttl_secs: u32,
}

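/// Errors returned by [`set`], mapped from the host's integer error codes.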
#[derive(Debug, thiserror::Error)]
pub enum SparkError {
    #[error("spark: invalid TTL")]
    InvalidTtl,

    #[error("spark: value too large")]
    TooLarge,

    #[error("spark: write limit exceeded")]
    WriteLimit,

    #[error("spark: disk quota exceeded")]
    QuotaExceeded,

    #[error("spark: not available")]
    NotAvailable,

    #[error("spark: internal error")]
    Internal,

    #[error("spark: read limit exceeded")]
    ReadLimit,

    #[error("spark: invalid key")]
    BadKey,

    #[error("spark: no capability")]
    NoCapability,

    #[error("spark: unknown error code {0}")]
    Unknown(i32),
}

impl SparkError {
    fn from_code(code: i32) -> Self {
        match code {
            1 => Self::InvalidTtl,
            2 => Self::TooLarge,
            3 => Self::WriteLimit,
            4 => Self::QuotaExceeded,
            5 => Self::NotAvailable,
            6 => Self::Internal,
            7 => Self::ReadLimit,
            8 => Self::BadKey,
            9 => Self::NoCapability,
            other => Self::Unknown(other),
        }
    }
}

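/// Looks up `key` in the spark store.
///
/// Returns `None` if the host reports no entry (a zero result) or if the
/// returned payload is shorter than the 4-byte TTL prefix. Otherwise the
/// payload is decoded as a little-endian `u32` TTL followed by the value
/// bytes.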
pub fn get(key: &str) -> Option<SparkEntry> {
    let (key_ptr, key_len) = mem::host_arg_str(key);
    let result = unsafe { ffi::spark_get(key_ptr, key_len) };
    if result == 0 {
        return None;
    }
    let (ptr, len) = mem::decode_ptr_len(result);
    if len < 4 {
        return None;
    }
    let bytes = unsafe { mem::read_bytes(ptr, len) };
    let ttl_secs = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
    let value = bytes[4..].to_vec();
    Some(SparkEntry { value, ttl_secs })
}

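/// Like [`get`], but decodes the value as UTF-8. Returns `None` if the key is
/// missing or the stored bytes are not valid UTF-8.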
pub fn get_string(key: &str) -> Option<String> {
    let entry = get(key)?;
    String::from_utf8(entry.value).ok()
}

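/// Stores `value` under `key` with a TTL of `ttl_secs` seconds. A zero return
/// code from the host means success; any other code is mapped to a
/// [`SparkError`].
///
/// Illustrative usage (assumes this module is reachable as `spark` from the
/// calling crate):
///
/// ```ignore
/// spark::set("greeting", b"hello", 300)?;
/// if let Some(entry) = spark::get("greeting") {
///     assert_eq!(entry.value, b"hello");
/// }
/// ```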
pub fn set(key: &str, value: &[u8], ttl_secs: u32) -> Result<(), SparkError> {
    let (key_ptr, key_len) = mem::host_arg_str(key);
    let (val_ptr, val_len) = mem::host_arg_bytes(value);
    let code = unsafe { ffi::spark_set(key_ptr, key_len, val_ptr, val_len, ttl_secs as i32) };
    if code == 0 {
        Ok(())
    } else {
        Err(SparkError::from_code(code))
    }
}

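/// Deletes `key` from the spark store. The host call has no return value, so
/// failures (including deleting a missing key) are not reported.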
pub fn delete(key: &str) {
    let (key_ptr, key_len) = mem::host_arg_str(key);
    unsafe { ffi::spark_delete(key_ptr, key_len) }
}

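/// Lists all keys currently in the spark store. The host returns the keys as
/// a JSON array of strings; if the call yields no payload or the JSON cannot
/// be parsed, an empty `Vec` is returned.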
pub fn list() -> Vec<String> {
    let result = unsafe { ffi::spark_list() };
    let Some(json_bytes) = (unsafe { mem::read_packed_bytes(result) }) else {
        return Vec::new();
    };
    serde_json::from_slice(&json_bytes).unwrap_or_default()
}

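/// Errors returned by [`pull`]. The numeric codes mirror those of
/// [`SparkError`]; codes not handled in `from_code` are surfaced as
/// [`SparkPullError::Unknown`].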
#[derive(Debug, thiserror::Error)]
pub enum SparkPullError {
    #[error("spark pull: not available")]
    NotAvailable,

    #[error("spark pull: internal error")]
    Internal,

    #[error("spark pull: no capability")]
    NoCapability,

    #[error("spark pull: invalid key or origin")]
    BadKey,

    #[error("spark pull: rate limited")]
    WriteLimit,

    #[error("spark pull: unknown error code {0}")]
    Unknown(i32),
}

impl SparkPullError {
    fn from_code(code: i32) -> Self {
        match code {
            3 => Self::WriteLimit,
            5 => Self::NotAvailable,
            6 => Self::Internal,
            8 => Self::BadKey,
            9 => Self::NoCapability,
            other => Self::Unknown(other),
        }
    }
}

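/// Pulls the given keys from `origin_node` into the local spark store.
///
/// The key list is sent to the host as a JSON array of strings. A
/// non-negative return code is the number of entries pulled; a negative code
/// is negated and mapped to a [`SparkPullError`].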
pub fn pull(origin_node: &str, keys: &[&str]) -> Result<u32, SparkPullError> {
    let keys_json = serde_json::to_string(keys).unwrap_or_else(|_| String::from("[]"));
    let (origin_ptr, origin_len) = mem::host_arg_str(origin_node);
    let (keys_ptr, keys_len) = mem::host_arg_str(&keys_json);
    let code = unsafe { ffi::spark_pull(origin_ptr, origin_len, keys_ptr, keys_len) };
    if code >= 0 {
        Ok(code as u32)
    } else {
        Err(SparkPullError::from_code(-code))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ffi::test_host;

    #[test]
    fn get_strips_ttl_prefix_and_returns_value() {
        test_host::reset();
        test_host::with_mock(|m| {
            m.spark_store.insert("k".into(), (b"hello".to_vec(), 60));
        });
        let entry = get("k").expect("get should hit the store");
        assert_eq!(entry.value, b"hello");
        assert_eq!(entry.ttl_secs, 60);
    }

    #[test]
    fn get_handles_zero_ttl_no_expiry() {
        test_host::reset();
        test_host::with_mock(|m| {
            m.spark_store.insert("k".into(), (b"forever".to_vec(), 0));
        });
        let entry = get("k").unwrap();
        assert_eq!(entry.ttl_secs, 0);
        assert_eq!(entry.value, b"forever");
    }

    #[test]
    fn get_returns_none_for_missing_key() {
        test_host::reset();
        assert!(get("missing").is_none());
    }

    #[test]
    fn get_string_decodes_utf8() {
        test_host::reset();
        test_host::with_mock(|m| {
            m.spark_store
                .insert("k".into(), ("héllo".as_bytes().to_vec(), 30));
        });
        assert_eq!(get_string("k").as_deref(), Some("héllo"));
    }

    #[test]
    fn get_string_returns_none_for_invalid_utf8() {
        test_host::reset();
        test_host::with_mock(|m| {
            m.spark_store.insert("k".into(), (vec![0xff, 0xfe], 30));
        });
        assert!(get_string("k").is_none());
    }

    #[test]
    fn set_writes_to_store() {
        test_host::reset();
        set("greeting", b"hi", 120).expect("set should succeed");
        let stored = test_host::read_mock(|m| m.spark_store.get("greeting").cloned());
        assert_eq!(stored, Some((b"hi".to_vec(), 120)));
    }

    #[test]
    fn set_captures_args() {
        test_host::reset();
        set("k", b"v", 30).unwrap();
        let captured = test_host::read_mock(|m| m.last_spark_set.clone());
        assert_eq!(captured, Some(("k".into(), b"v".to_vec(), 30)));
    }

    #[test]
    fn set_maps_error_codes() {
        let cases = [
            (1, SparkError::InvalidTtl),
            (2, SparkError::TooLarge),
            (3, SparkError::WriteLimit),
            (4, SparkError::QuotaExceeded),
            (5, SparkError::NotAvailable),
            (6, SparkError::Internal),
            (7, SparkError::ReadLimit),
            (8, SparkError::BadKey),
            (9, SparkError::NoCapability),
        ];
        for (code, expected) in cases {
            test_host::reset();
            test_host::with_mock(|m| m.spark_set_error = code);
            let err = set("k", b"v", 30).unwrap_err();
            assert!(
                std::mem::discriminant(&err) == std::mem::discriminant(&expected),
                "code {} should map to {:?}, got {:?}",
                code,
                expected,
                err,
            );
        }
    }

    #[test]
    fn set_unknown_error_code() {
        test_host::reset();
        test_host::with_mock(|m| m.spark_set_error = 99);
        match set("k", b"v", 30).unwrap_err() {
            SparkError::Unknown(99) => {}
            other => panic!("expected Unknown(99), got {:?}", other),
        }
    }

    #[test]
    fn delete_removes_from_store() {
        test_host::reset();
        test_host::with_mock(|m| {
            m.spark_store.insert("k".into(), (b"v".to_vec(), 60));
        });
        delete("k");
        assert!(test_host::read_mock(|m| m.spark_store.is_empty()));
        assert_eq!(test_host::read_mock(|m| m.spark_deletes.clone()), vec!["k"]);
    }

    #[test]
    fn list_returns_keys() {
        test_host::reset();
        test_host::with_mock(|m| {
            m.spark_store.insert("a".into(), (b"1".to_vec(), 10));
            m.spark_store.insert("b".into(), (b"2".to_vec(), 20));
        });
        let mut keys = list();
        keys.sort();
        assert_eq!(keys, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn list_empty_when_no_keys() {
        test_host::reset();
        assert!(list().is_empty());
    }

    #[test]
    fn pull_serializes_keys_as_json() {
        test_host::reset();
        test_host::with_mock(|m| m.spark_pull_result = 3);
        let count = pull("origin-node", &["a", "b", "c"]).unwrap();
        assert_eq!(count, 3);
        let calls = test_host::read_mock(|m| m.spark_pull_calls.clone());
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, "origin-node");
        assert_eq!(calls[0].1, r#"["a","b","c"]"#);
    }

    #[test]
    fn pull_zero_count_is_ok() {
        test_host::reset();
        test_host::with_mock(|m| m.spark_pull_result = 0);
        assert_eq!(pull("o", &[]).unwrap(), 0);
    }

    #[test]
    fn pull_error_from_code_mapping() {
        match SparkPullError::from_code(3) {
            SparkPullError::WriteLimit => {}
            other => panic!("3 should map to WriteLimit, got {:?}", other),
        }
        match SparkPullError::from_code(5) {
            SparkPullError::NotAvailable => {}
            other => panic!("5 should map to NotAvailable, got {:?}", other),
        }
        match SparkPullError::from_code(6) {
            SparkPullError::Internal => {}
            other => panic!("6 should map to Internal, got {:?}", other),
        }
        match SparkPullError::from_code(8) {
            SparkPullError::BadKey => {}
            other => panic!("8 should map to BadKey, got {:?}", other),
        }
        match SparkPullError::from_code(9) {
            SparkPullError::NoCapability => {}
            other => panic!("9 should map to NoCapability, got {:?}", other),
        }
        match SparkPullError::from_code(99) {
            SparkPullError::Unknown(99) => {}
            other => panic!("99 should map to Unknown(99), got {:?}", other),
        }
    }

    #[test]
    fn pull_negative_code_maps_to_typed_error() {
        let cases = [
            (-3, SparkPullError::WriteLimit),
            (-5, SparkPullError::NotAvailable),
            (-6, SparkPullError::Internal),
            (-8, SparkPullError::BadKey),
            (-9, SparkPullError::NoCapability),
        ];
        for (host_code, expected) in cases {
            test_host::reset();
            test_host::with_mock(|m| m.spark_pull_result = host_code);
            let err = pull("origin", &["k"]).unwrap_err();
            assert!(
                std::mem::discriminant(&err) == std::mem::discriminant(&expected),
                "host code {} should map to {:?}, got {:?}",
                host_code,
                expected,
                err,
            );
        }
    }

    #[test]
    fn pull_unknown_negative_code_is_unknown() {
        test_host::reset();
        test_host::with_mock(|m| m.spark_pull_result = -42);
        match pull("origin", &["k"]).unwrap_err() {
            SparkPullError::Unknown(42) => {}
            other => panic!("expected Unknown(42), got {:?}", other),
        }
    }

    #[test]
    fn pull_positive_count_is_success() {
        test_host::reset();
        test_host::with_mock(|m| m.spark_pull_result = 7);
        assert_eq!(pull("origin", &["a", "b"]).unwrap(), 7);
    }
}