1use crate::{
2 db::store::{EntityName, StoreRegistry},
3 error::{ErrorClass, ErrorOrigin, InternalError},
4 prelude::*,
5 serialize::deserialize,
6 traits::Storable,
7};
8use canic_cdk::structures::{BTreeMap, DefaultMemoryImpl, memory::VirtualMemory, storable::Bound};
9use derive_more::{Deref, DerefMut};
10use std::{
11 borrow::Cow,
12 fmt::{self, Display},
13};
14use thiserror::Error as ThisError;
15
16#[derive(Deref, DerefMut)]
21pub struct DataStoreRegistry(StoreRegistry<DataStore>);
22
23impl DataStoreRegistry {
24 #[must_use]
25 #[allow(clippy::new_without_default)]
26 pub fn new() -> Self {
28 Self(StoreRegistry::new())
29 }
30}
31
32#[derive(Deref, DerefMut)]
37pub struct DataStore(BTreeMap<RawDataKey, RawRow, VirtualMemory<DefaultMemoryImpl>>);
38
39impl DataStore {
40 #[must_use]
41 pub fn init(memory: VirtualMemory<DefaultMemoryImpl>) -> Self {
43 Self(BTreeMap::init(memory))
44 }
45
46 pub fn memory_bytes(&self) -> u64 {
48 self.iter()
49 .map(|entry| u64::from(DataKey::STORED_SIZE) + entry.value().len() as u64)
50 .sum()
51 }
52}
53
54#[derive(Debug, ThisError)]
59pub enum RawRowError {
60 #[error("row exceeds max size: {len} bytes (limit {MAX_ROW_BYTES})")]
61 TooLarge { len: usize },
62}
63
64impl RawRowError {
65 #[must_use]
66 pub const fn class(&self) -> ErrorClass {
67 ErrorClass::Unsupported
68 }
69
70 #[must_use]
71 pub const fn origin(&self) -> ErrorOrigin {
72 ErrorOrigin::Store
73 }
74}
75
76impl From<RawRowError> for InternalError {
77 fn from(err: RawRowError) -> Self {
78 Self::new(err.class(), err.origin(), err.to_string())
79 }
80}
81
82#[derive(Debug, ThisError)]
87pub enum RowDecodeError {
88 #[error("row exceeds max size: {len} bytes (limit {MAX_ROW_BYTES})")]
89 TooLarge { len: usize },
90 #[error("row failed to deserialize")]
91 Deserialize,
92}
93
/// Maximum size of a serialized row, in bytes (4 MiB).
///
/// Enforced by `RawRow::try_new` / `RawRow::try_decode` and advertised as the
/// `Storable` bound for `RawRow`.
pub const MAX_ROW_BYTES: u32 = 4 << 20;
100
/// Serialized entity row as stored in a [`DataStore`].
///
/// Construct through `RawRow::try_new` to enforce the [`MAX_ROW_BYTES`] cap.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RawRow(Vec<u8>);
103
104impl RawRow {
105 pub fn try_new(bytes: Vec<u8>) -> Result<Self, RawRowError> {
106 if bytes.len() > MAX_ROW_BYTES as usize {
107 return Err(RawRowError::TooLarge { len: bytes.len() });
108 }
109 Ok(Self(bytes))
110 }
111
112 #[must_use]
113 pub fn as_bytes(&self) -> &[u8] {
114 &self.0
115 }
116
117 #[must_use]
118 pub const fn len(&self) -> usize {
119 self.0.len()
120 }
121
122 #[must_use]
123 pub const fn is_empty(&self) -> bool {
124 self.0.is_empty()
125 }
126
127 pub fn try_decode<E: EntityKind>(&self) -> Result<E, RowDecodeError> {
128 if self.0.len() > MAX_ROW_BYTES as usize {
129 return Err(RowDecodeError::TooLarge { len: self.0.len() });
130 }
131
132 deserialize::<E>(&self.0).map_err(|_| RowDecodeError::Deserialize)
133 }
134}
135
136impl Storable for RawRow {
137 fn to_bytes(&self) -> Cow<'_, [u8]> {
138 Cow::Borrowed(&self.0)
139 }
140
141 fn from_bytes(bytes: Cow<'_, [u8]>) -> Self {
142 Self(bytes.into_owned())
143 }
144
145 fn into_bytes(self) -> Vec<u8> {
146 self.0
147 }
148
149 const BOUND: Bound = Bound::Bounded {
150 max_size: MAX_ROW_BYTES,
151 is_fixed_size: false,
152 };
153}
154
155pub type DataRow = (DataKey, RawRow);
156
157#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
162pub struct DataKey {
163 entity: EntityName,
164 key: Key,
165}
166
167impl DataKey {
168 #[allow(clippy::cast_possible_truncation)]
169 pub const STORED_SIZE: u32 = EntityName::STORED_SIZE + Key::STORED_SIZE as u32;
170
171 #[must_use]
172 pub fn new<E: EntityKind>(key: impl Into<Key>) -> Self {
174 Self {
175 entity: EntityName::from_static(E::ENTITY_NAME),
176 key: key.into(),
177 }
178 }
179
180 #[must_use]
181 pub const fn lower_bound<E: EntityKind>() -> Self {
182 Self {
183 entity: EntityName::from_static(E::ENTITY_NAME),
184 key: Key::lower_bound(),
185 }
186 }
187
188 #[must_use]
189 pub const fn upper_bound<E: EntityKind>() -> Self {
190 Self {
191 entity: EntityName::from_static(E::ENTITY_NAME),
192 key: Key::upper_bound(),
193 }
194 }
195
196 #[must_use]
198 pub const fn key(&self) -> Key {
199 self.key
200 }
201
202 #[must_use]
204 pub const fn entity_name(&self) -> &EntityName {
205 &self.entity
206 }
207
208 #[must_use]
211 pub const fn entry_size_bytes(value_len: u64) -> u64 {
212 Self::STORED_SIZE as u64 + value_len
213 }
214
215 #[must_use]
216 pub fn max_storable() -> Self {
218 Self {
219 entity: EntityName::max_storable(),
220 key: Key::max_storable(),
221 }
222 }
223
224 #[must_use]
225 pub fn to_raw(&self) -> RawDataKey {
226 let mut buf = [0u8; Self::STORED_SIZE as usize];
227
228 buf[0] = self.entity.len;
229 let entity_end = EntityName::STORED_SIZE_USIZE;
230 buf[1..entity_end].copy_from_slice(&self.entity.bytes);
231
232 let key_bytes = self.key.to_bytes();
233 debug_assert_eq!(
234 key_bytes.len(),
235 Key::STORED_SIZE,
236 "Key serialization must be exactly fixed-size"
237 );
238 let key_offset = EntityName::STORED_SIZE_USIZE;
239 buf[key_offset..key_offset + Key::STORED_SIZE].copy_from_slice(&key_bytes);
240
241 RawDataKey(buf)
242 }
243
244 pub fn try_from_raw(raw: &RawDataKey) -> Result<Self, &'static str> {
245 let bytes = &raw.0;
246
247 let mut offset = 0;
248 let entity = EntityName::from_bytes(&bytes[offset..offset + EntityName::STORED_SIZE_USIZE])
249 .map_err(|_| "corrupted DataKey: invalid EntityName bytes")?;
250 offset += EntityName::STORED_SIZE_USIZE;
251
252 let key = Key::try_from_bytes(&bytes[offset..offset + Key::STORED_SIZE])
253 .map_err(|_| "corrupted DataKey: invalid Key bytes")?;
254
255 Ok(Self { entity, key })
256 }
257}
258
259impl Display for DataKey {
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 write!(f, "#{} ({})", self.entity, self.key)
262 }
263}
264
265impl From<DataKey> for Key {
266 fn from(key: DataKey) -> Self {
267 key.key()
268 }
269}
270
271#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
276pub struct RawDataKey([u8; DataKey::STORED_SIZE as usize]);
277
278impl RawDataKey {
279 #[must_use]
280 pub const fn as_bytes(&self) -> &[u8; DataKey::STORED_SIZE as usize] {
281 &self.0
282 }
283}
284
285impl Storable for RawDataKey {
286 fn to_bytes(&self) -> Cow<'_, [u8]> {
287 Cow::Borrowed(&self.0)
288 }
289
290 fn from_bytes(bytes: Cow<'_, [u8]>) -> Self {
291 let mut out = [0u8; DataKey::STORED_SIZE as usize];
292 if bytes.len() == out.len() {
293 out.copy_from_slice(bytes.as_ref());
294 }
295 Self(out)
296 }
297
298 fn into_bytes(self) -> Vec<u8> {
299 self.0.to_vec()
300 }
301
302 const BOUND: Bound = Bound::Bounded {
303 max_size: DataKey::STORED_SIZE,
304 is_fixed_size: true,
305 };
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Storable;
    use crate::{
        model::index::IndexModel,
        serialize::serialize,
        traits::{
            CanisterKind, EntityKind, FieldValues, Path, SanitizeAuto, SanitizeCustom, StoreKind,
            ValidateAuto, ValidateCustom, View, Visitable,
        },
    };
    use serde::{Deserialize, Serialize};
    use std::borrow::Cow;

    /// Minimal entity used to exercise row and key encoding.
    #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
    struct DummyEntity {
        id: u64,
    }

    impl Path for DummyEntity {
        const PATH: &'static str = "dummy_entity";
    }

    impl View for DummyEntity {
        type ViewType = Self;

        fn to_view(&self) -> Self::ViewType {
            self.clone()
        }

        fn from_view(view: Self::ViewType) -> Self {
            view
        }
    }

    impl FieldValues for DummyEntity {
        fn get_value(&self, field: &str) -> Option<Value> {
            match field {
                "id" => Some(Value::Uint(self.id)),
                _ => None,
            }
        }
    }

    impl SanitizeAuto for DummyEntity {}
    impl SanitizeCustom for DummyEntity {}
    impl ValidateAuto for DummyEntity {}
    impl ValidateCustom for DummyEntity {}
    impl Visitable for DummyEntity {}

    #[derive(Clone, Copy, Debug)]
    struct DummyStore;

    impl Path for DummyStore {
        const PATH: &'static str = "dummy_store";
    }

    #[derive(Clone, Copy, Debug)]
    struct DummyCanister;

    impl Path for DummyCanister {
        const PATH: &'static str = "dummy_canister";
    }

    impl CanisterKind for DummyCanister {}

    impl StoreKind for DummyStore {
        type Canister = DummyCanister;
    }

    impl EntityKind for DummyEntity {
        type PrimaryKey = u64;
        type Store = DummyStore;
        type Canister = DummyCanister;

        const ENTITY_NAME: &'static str = "dummy";
        const PRIMARY_KEY: &'static str = "id";
        const FIELDS: &'static [&'static str] = &["id"];
        const INDEXES: &'static [&'static IndexModel] = &[];

        fn key(&self) -> Key {
            Key::Uint(self.id)
        }

        fn primary_key(&self) -> Self::PrimaryKey {
            self.id
        }

        fn set_primary_key(&mut self, key: Self::PrimaryKey) {
            self.id = key;
        }
    }

    /// Feeds a zero-filled buffer of `len` bytes through
    /// `RawDataKey::from_bytes` and asserts that decoding reports corruption.
    fn assert_len_mismatch_is_corrupted(len: usize) {
        let buf = vec![0u8; len];
        let raw = RawDataKey::from_bytes(Cow::Borrowed(&buf));
        let err = DataKey::try_from_raw(&raw).unwrap_err();
        assert!(
            err.contains("corrupted"),
            "expected corruption error, got: {err}"
        );
    }

    #[test]
    fn data_key_is_exactly_fixed_size() {
        let data_key = DataKey::max_storable();
        let size = data_key.to_raw().as_bytes().len();

        assert_eq!(
            size,
            DataKey::STORED_SIZE as usize,
            "DataKey must serialize to exactly STORED_SIZE bytes"
        );
    }

    #[test]
    fn data_key_ordering_matches_bytes() {
        let keys = vec![
            DataKey {
                entity: EntityName::from_static("a"),
                key: Key::Int(0),
            },
            DataKey {
                entity: EntityName::from_static("aa"),
                key: Key::Int(0),
            },
            DataKey {
                entity: EntityName::from_static("b"),
                key: Key::Int(0),
            },
            DataKey {
                entity: EntityName::from_static("a"),
                key: Key::Uint(1),
            },
        ];

        let mut sorted_by_ord = keys.clone();
        sorted_by_ord.sort();

        let mut sorted_by_bytes = keys;
        sorted_by_bytes.sort_by(|a, b| a.to_raw().as_bytes().cmp(b.to_raw().as_bytes()));

        assert_eq!(
            sorted_by_ord, sorted_by_bytes,
            "DataKey Ord and byte ordering diverged"
        );
    }

    #[test]
    fn data_key_rejects_undersized_bytes() {
        assert_len_mismatch_is_corrupted(DataKey::STORED_SIZE as usize - 1);
    }

    #[test]
    fn data_key_rejects_oversized_bytes() {
        assert_len_mismatch_is_corrupted(DataKey::STORED_SIZE as usize + 1);
    }

    #[test]
    fn data_key_rejects_invalid_entity_len() {
        let mut raw = DataKey::max_storable().to_raw();
        raw.0[0] = 0;
        assert!(DataKey::try_from_raw(&raw).is_err());
    }

    #[test]
    fn data_key_rejects_non_ascii_entity_bytes() {
        let data_key = DataKey {
            entity: EntityName::from_static("a"),
            key: Key::Int(1),
        };
        let mut raw = data_key.to_raw();
        raw.0[1] = 0xFF;
        assert!(DataKey::try_from_raw(&raw).is_err());
    }

    #[test]
    fn data_key_rejects_entity_padding() {
        let data_key = DataKey {
            entity: EntityName::from_static("user"),
            key: Key::Int(1),
        };
        let mut raw = data_key.to_raw();
        let padding_offset = 1 + data_key.entity.len();
        raw.0[padding_offset] = b'x';
        assert!(DataKey::try_from_raw(&raw).is_err());
    }

    #[test]
    // `(seed >> 24) as u8` deliberately keeps only the low byte of the LCG state.
    #[allow(clippy::cast_possible_truncation)]
    fn data_key_decode_fuzz_roundtrip_is_canonical() {
        const RUNS: u64 = 1_000;

        let mut seed = 0xDEAD_BEEF_u64;
        for _ in 0..RUNS {
            let mut bytes = [0u8; DataKey::STORED_SIZE as usize];
            for b in &mut bytes {
                seed = seed.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
                *b = (seed >> 24) as u8;
            }

            let raw = RawDataKey(bytes);
            if let Ok(decoded) = DataKey::try_from_raw(&raw) {
                let re = decoded.to_raw();
                assert_eq!(
                    raw.as_bytes(),
                    re.as_bytes(),
                    "decoded DataKey must be canonical"
                );
            }
        }
    }

    #[test]
    fn raw_data_key_roundtrip_via_bytes() {
        let data_key = DataKey::new::<DummyEntity>(Key::Uint(7));
        let raw = data_key.to_raw();
        let bytes = Storable::to_bytes(&raw);
        let raw = RawDataKey::from_bytes(bytes);
        let decoded = DataKey::try_from_raw(&raw).expect("decode should succeed");

        assert_eq!(decoded, data_key);
    }

    #[test]
    fn data_key_accessors_return_components() {
        let data_key = DataKey::new::<DummyEntity>(Key::Uint(7));

        assert_eq!(data_key.key(), Key::Uint(7));
        assert_eq!(data_key.entity_name(), &EntityName::from_static("dummy"));
        assert_eq!(Key::from(data_key), Key::Uint(7));
    }

    #[test]
    fn entry_size_adds_fixed_key_width() {
        let key_width = u64::from(DataKey::STORED_SIZE);
        assert_eq!(DataKey::entry_size_bytes(0), key_width);
        assert_eq!(DataKey::entry_size_bytes(10), key_width + 10);
    }

    #[test]
    fn raw_row_roundtrip_via_bytes() {
        let entity = DummyEntity { id: 42 };
        let bytes = serialize(&entity).expect("serialize");
        let raw = RawRow::try_new(bytes).expect("raw row");

        let encoded = Storable::to_bytes(&raw);
        let raw = RawRow::from_bytes(encoded);
        let decoded = raw.try_decode::<DummyEntity>().expect("decode");

        assert_eq!(decoded, entity);
    }

    #[test]
    fn raw_row_accepts_payload_at_limit() {
        let bytes = vec![0u8; MAX_ROW_BYTES as usize];
        let raw = RawRow::try_new(bytes).expect("payload at the limit must be accepted");
        assert_eq!(raw.len(), MAX_ROW_BYTES as usize);
    }

    #[test]
    fn raw_row_rejects_oversized_payload() {
        let bytes = vec![0u8; MAX_ROW_BYTES as usize + 1];
        let err = RawRow::try_new(bytes).unwrap_err();
        assert!(matches!(err, RawRowError::TooLarge { .. }));
    }

    #[test]
    fn raw_row_decode_rejects_oversized_storable_bytes() {
        // `Storable::from_bytes` bypasses `try_new`, so `try_decode` must re-check.
        let bytes = vec![0u8; MAX_ROW_BYTES as usize + 1];
        let raw = RawRow::from_bytes(Cow::Owned(bytes));
        let err = raw.try_decode::<DummyEntity>().unwrap_err();
        assert!(matches!(err, RowDecodeError::TooLarge { .. }));
    }

    #[test]
    fn raw_row_rejects_truncated_payload() {
        let entity = DummyEntity { id: 7 };
        let mut bytes = serialize(&entity).expect("serialize");
        bytes.truncate(bytes.len().saturating_sub(1));
        let raw = RawRow::try_new(bytes).expect("raw row");

        let err = raw.try_decode::<DummyEntity>().unwrap_err();
        assert!(matches!(err, RowDecodeError::Deserialize));
    }
}