1use crate::{
2 traits::{FieldValue, Storable},
3 types::{Account, Principal, Subaccount, Timestamp, Ulid, Unit},
4 value::Value,
5};
6use candid::{CandidType, Principal as WrappedPrincipal};
7use canic_cdk::structures::storable::Bound;
8use derive_more::Display;
9use serde::{Deserialize, Serialize};
10use std::{borrow::Cow, cmp::Ordering};
11
/// A typed key value: one of several scalar / identifier payloads, or `Unit`.
///
/// `Key` implements `Storable` with a fixed `Key::STORED_SIZE`-byte encoding
/// (a one-byte variant tag followed by a zero-padded payload). It also has a
/// hand-written `Ord` that ranks variants by that same tag before comparing
/// payloads.
///
/// NOTE(review): some derived impls (`Hash`, `Serialize`, `CandidType`, ...)
/// may depend on variant declaration order or names. Review them, and the
/// `TAG_*` constants, before reordering or renaming variants.
#[derive(CandidType, Clone, Copy, Debug, Deserialize, Display, Eq, Hash, PartialEq, Serialize)]
pub enum Key {
    /// An account payload.
    Account(Account),
    /// A signed 64-bit integer.
    Int(i64),
    /// A principal payload.
    Principal(Principal),
    /// A subaccount payload.
    Subaccount(Subaccount),
    /// A timestamp payload.
    Timestamp(Timestamp),
    /// An unsigned 64-bit integer.
    Uint(u64),
    /// A ULID payload.
    Ulid(Ulid),
    /// The payload-less unit key.
    Unit,
}
31
32impl Key {
33 const TAG_ACCOUNT: u8 = 0;
35 const TAG_INT: u8 = 1;
36 const TAG_PRINCIPAL: u8 = 2;
37 const TAG_SUBACCOUNT: u8 = 3;
38 const TAG_TIMESTAMP: u8 = 4;
39 const TAG_UINT: u8 = 5;
40 const TAG_ULID: u8 = 6;
41 const TAG_UNIT: u8 = 7;
42
43 pub const STORED_SIZE: usize = 64;
45
46 const TAG_SIZE: usize = 1;
48 const TAG_OFFSET: usize = 0;
49
50 const PAYLOAD_OFFSET: usize = Self::TAG_SIZE;
51 const PAYLOAD_SIZE: usize = Self::STORED_SIZE - Self::TAG_SIZE;
52
53 const INT_SIZE: usize = 8;
55 const UINT_SIZE: usize = 8;
56 const TIMESTAMP_SIZE: usize = 8;
57 const ULID_SIZE: usize = 16;
58 const SUBACCOUNT_SIZE: usize = 32;
59 const ACCOUNT_MAX_SIZE: usize = 62;
60
61 const fn tag(&self) -> u8 {
62 match self {
63 Self::Account(_) => Self::TAG_ACCOUNT,
64 Self::Int(_) => Self::TAG_INT,
65 Self::Principal(_) => Self::TAG_PRINCIPAL,
66 Self::Subaccount(_) => Self::TAG_SUBACCOUNT,
67 Self::Timestamp(_) => Self::TAG_TIMESTAMP,
68 Self::Uint(_) => Self::TAG_UINT,
69 Self::Ulid(_) => Self::TAG_ULID,
70 Self::Unit => Self::TAG_UNIT,
71 }
72 }
73
74 #[must_use]
75 pub fn max_storable() -> Self {
77 Self::Account(Account::max_storable())
78 }
79
80 #[must_use]
81 pub const fn lower_bound() -> Self {
82 Self::Int(i64::MIN)
83 }
84
85 #[must_use]
86 pub const fn upper_bound() -> Self {
87 Self::Unit
88 }
89
90 const fn variant_rank(&self) -> u8 {
91 self.tag()
92 }
93}
94
95impl FieldValue for Key {
96 fn to_value(&self) -> Value {
97 match self {
98 Self::Account(v) => Value::Account(*v),
99 Self::Int(v) => Value::Int(*v),
100 Self::Principal(v) => Value::Principal(*v),
101 Self::Subaccount(v) => Value::Subaccount(*v),
102 Self::Timestamp(v) => Value::Timestamp(*v),
103 Self::Uint(v) => Value::Uint(*v),
104 Self::Ulid(v) => Value::Ulid(*v),
105 Self::Unit => Value::Unit,
106 }
107 }
108}
109
110impl From<()> for Key {
111 fn from((): ()) -> Self {
112 Self::Unit
113 }
114}
115
116impl From<Unit> for Key {
117 fn from(_: Unit) -> Self {
118 Self::Unit
119 }
120}
121
122impl PartialEq<()> for Key {
123 fn eq(&self, (): &()) -> bool {
124 matches!(self, Self::Unit)
125 }
126}
127
128impl PartialEq<Key> for () {
129 fn eq(&self, other: &Key) -> bool {
130 other == self
131 }
132}
133
// Generates `impl From<T> for Key` for each `T => Variant` entry. The source
// value is converted with `Into` into the payload type of `Key::Variant`.
macro_rules! impl_from_key {
    ( $( $source:ty => $target:ident ),* $(,)? ) => {
        $(
            impl From<$source> for Key {
                fn from(value: $source) -> Self {
                    Self::$target(Into::into(value))
                }
            }
        )*
    };
}
146
// Generates symmetric `PartialEq` impls between `Key` and a payload type `T`.
// A key equals a bare value only when it is the listed variant and its payload
// compares equal; the `T == Key` direction delegates to `Key == T`.
macro_rules! impl_eq_key {
    ( $( $payload:ty => $arm:ident ),* $(,)? ) => {
        $(
            impl PartialEq<$payload> for Key {
                fn eq(&self, other: &$payload) -> bool {
                    if let Self::$arm(inner) = self {
                        inner == other
                    } else {
                        false
                    }
                }
            }

            impl PartialEq<Key> for $payload {
                fn eq(&self, other: &Key) -> bool {
                    <Key as PartialEq<$payload>>::eq(other, self)
                }
            }
        )*
    };
}
165
166impl_from_key! {
167 Account => Account,
168 i8 => Int,
169 i16 => Int,
170 i32 => Int,
171 i64 => Int,
172 Principal => Principal,
173 WrappedPrincipal => Principal,
174 Subaccount => Subaccount,
175 Timestamp => Timestamp,
176 u8 => Uint,
177 u16 => Uint,
178 u32 => Uint,
179 u64 => Uint,
180 Ulid => Ulid,
181}
182
183impl_eq_key! {
184 Account => Account,
185 i64 => Int,
186 Principal => Principal,
187 Subaccount => Subaccount,
188 Timestamp => Timestamp,
189 u64 => Uint,
190 Ulid => Ulid,
191}
192
193impl Ord for Key {
194 fn cmp(&self, other: &Self) -> Ordering {
195 match (self, other) {
196 (Self::Account(a), Self::Account(b)) => Ord::cmp(a, b),
197 (Self::Int(a), Self::Int(b)) => Ord::cmp(a, b),
198 (Self::Principal(a), Self::Principal(b)) => Ord::cmp(a, b),
199 (Self::Uint(a), Self::Uint(b)) => Ord::cmp(a, b),
200 (Self::Ulid(a), Self::Ulid(b)) => Ord::cmp(a, b),
201 (Self::Subaccount(a), Self::Subaccount(b)) => Ord::cmp(a, b),
202 (Self::Timestamp(a), Self::Timestamp(b)) => Ord::cmp(a, b),
203
204 _ => Ord::cmp(&self.variant_rank(), &other.variant_rank()), }
206 }
207}
208
209impl PartialOrd for Key {
210 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
211 Some(Ord::cmp(self, other))
212 }
213}
214
215#[allow(clippy::cast_possible_truncation)]
216impl Storable for Key {
217 fn to_bytes(&self) -> Cow<'_, [u8]> {
218 let mut buf = [0u8; Self::STORED_SIZE];
219
220 buf[Self::TAG_OFFSET] = self.tag();
222 let payload = &mut buf[Self::PAYLOAD_OFFSET..];
223
224 debug_assert_eq!(payload.len(), Self::PAYLOAD_SIZE);
225
226 match self {
228 Self::Account(v) => {
229 let bytes = v.to_bytes();
230 debug_assert_eq!(bytes.len(), Self::ACCOUNT_MAX_SIZE);
231 payload[..bytes.len()].copy_from_slice(&bytes);
232 }
233
234 Self::Int(v) => {
235 let biased = (*v).cast_unsigned() ^ (1u64 << 63);
237 payload[..Self::INT_SIZE].copy_from_slice(&biased.to_be_bytes());
238 }
239
240 Self::Uint(v) => {
241 payload[..Self::UINT_SIZE].copy_from_slice(&v.to_be_bytes());
242 }
243
244 Self::Timestamp(v) => {
245 payload[..Self::TIMESTAMP_SIZE].copy_from_slice(&v.get().to_be_bytes());
246 }
247
248 Self::Principal(v) => {
249 let bytes = v.to_bytes();
250 let len = bytes.len();
251 assert!(
252 (1..=Principal::MAX_LENGTH_IN_BYTES as usize).contains(&len),
253 "invalid Key principal length"
254 );
255 payload[0] = len as u8;
256 if len > 0 {
257 payload[1..=len].copy_from_slice(&bytes[..len]);
258 }
259 }
260
261 Self::Subaccount(v) => {
262 let bytes = v.to_array();
263 debug_assert_eq!(bytes.len(), Self::SUBACCOUNT_SIZE);
264 payload[..Self::SUBACCOUNT_SIZE].copy_from_slice(&bytes);
265 }
266
267 Self::Ulid(v) => {
268 payload[..Self::ULID_SIZE].copy_from_slice(&v.to_bytes());
269 }
270
271 Self::Unit => {}
272 }
273
274 Cow::Owned(buf.to_vec())
275 }
276
277 fn from_bytes(bytes: Cow<'_, [u8]>) -> Self {
278 let bytes = bytes.as_ref();
279 assert_eq!(
280 bytes.len(),
281 Self::STORED_SIZE,
282 "corrupted Key: invalid size"
283 );
284
285 let tag = bytes[Self::TAG_OFFSET];
286 let payload = &bytes[Self::PAYLOAD_OFFSET..];
287
288 let assert_zero_padding = |used: usize, context: &str| {
289 assert!(
290 payload[used..].iter().all(|&b| b == 0),
291 "corrupted Key: {context}"
292 );
293 };
294
295 match tag {
296 Self::TAG_ACCOUNT => {
297 let end = Account::STORED_SIZE as usize;
298 assert_zero_padding(end, "account padding");
299 Self::Account(Account::from_bytes(payload[..end].into()))
300 }
301
302 Self::TAG_INT => {
303 let mut buf = [0u8; Self::INT_SIZE];
304 buf.copy_from_slice(&payload[..Self::INT_SIZE]);
305 let biased = u64::from_be_bytes(buf);
306 assert_zero_padding(Self::INT_SIZE, "int padding");
307 Self::Int((biased ^ (1u64 << 63)).cast_signed())
308 }
309
310 Self::TAG_PRINCIPAL => {
311 let len = payload[0] as usize;
312 assert!(
313 (1..=Principal::MAX_LENGTH_IN_BYTES as usize).contains(&len),
314 "corrupted Key: principal length out of bounds"
315 );
316 let end = 1 + len;
317 assert_zero_padding(end, "principal padding");
318 Self::Principal(Principal::from_slice(&payload[1..end]))
319 }
320
321 Self::TAG_SUBACCOUNT => {
322 let mut buf = [0u8; Self::SUBACCOUNT_SIZE];
323 buf.copy_from_slice(&payload[..Self::SUBACCOUNT_SIZE]);
324 assert_zero_padding(Self::SUBACCOUNT_SIZE, "subaccount padding");
325 Self::Subaccount(Subaccount::from_array(buf))
326 }
327
328 Self::TAG_TIMESTAMP => {
329 let mut buf = [0u8; Self::TIMESTAMP_SIZE];
330 buf.copy_from_slice(&payload[..Self::TIMESTAMP_SIZE]);
331 assert_zero_padding(Self::TIMESTAMP_SIZE, "timestamp padding");
332 Self::Timestamp(Timestamp::from(u64::from_be_bytes(buf)))
333 }
334
335 Self::TAG_UINT => {
336 let mut buf = [0u8; Self::UINT_SIZE];
337 buf.copy_from_slice(&payload[..Self::UINT_SIZE]);
338 assert_zero_padding(Self::UINT_SIZE, "uint padding");
339 Self::Uint(u64::from_be_bytes(buf))
340 }
341
342 Self::TAG_ULID => {
343 let mut buf = [0u8; Self::ULID_SIZE];
344 buf.copy_from_slice(&payload[..Self::ULID_SIZE]);
345 assert_zero_padding(Self::ULID_SIZE, "ulid padding");
346 Self::Ulid(Ulid::from_bytes(buf))
347 }
348
349 Self::TAG_UNIT => {
350 assert_zero_padding(0, "unit padding");
351 Self::Unit
352 }
353
354 _ => panic!("corrupted Key: invalid tag"),
355 }
356 }
357
358 fn into_bytes(self) -> Vec<u8> {
359 self.to_bytes().into_owned()
360 }
361
362 const BOUND: Bound = Bound::Bounded {
363 max_size: Self::STORED_SIZE as u32,
364 is_fixed_size: true,
365 };
366}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Storable;

    #[test]
    fn key_max_size_is_bounded() {
        let key = Key::max_storable();
        let size = Storable::to_bytes(&key).len();

        assert!(
            size <= Key::STORED_SIZE,
            "serialized Key too large: got {size} bytes (limit {})",
            Key::STORED_SIZE
        );
    }

    #[test]
    fn key_storable_round_trip() {
        let keys = [
            Key::Account(Account::dummy(1)),
            Key::Int(-42),
            Key::Principal(Principal::from_slice(&[1, 2, 3])),
            Key::Subaccount(Subaccount::from_array([7; 32])),
            Key::Timestamp(Timestamp::from_seconds(42)),
            Key::Uint(42),
            Key::Ulid(Ulid::from_bytes([9; 16])),
            Key::Unit,
        ];

        for key in keys {
            let bytes = Storable::to_bytes(&key);
            let decoded = Key::from_bytes(bytes);

            assert_eq!(decoded, key, "Key round trip failed for {key:?}");
        }
    }

    #[test]
    fn key_is_exactly_fixed_size() {
        let keys = [
            Key::Account(Account::dummy(1)),
            Key::Int(0),
            Key::Principal(Principal::anonymous()),
            Key::Subaccount(Subaccount::from_array([0; 32])),
            Key::Timestamp(Timestamp::from_seconds(0)),
            Key::Uint(0),
            Key::Ulid(Ulid::from_bytes([0; 16])),
            Key::Unit,
        ];

        for key in keys {
            let len = key.to_bytes().len();
            assert_eq!(
                len,
                Key::STORED_SIZE,
                "Key serialized length must be exactly {}",
                Key::STORED_SIZE
            );
        }
    }

    #[test]
    fn key_ordering_is_total_and_stable() {
        let keys = vec![
            Key::Account(Account::new(
                Principal::from_slice(&[1]),
                None::<Subaccount>,
            )),
            Key::Account(Account::new(Principal::from_slice(&[1]), Some([0u8; 32]))),
            Key::Int(-1),
            Key::Int(0),
            Key::Principal(Principal::from_slice(&[1])),
            Key::Subaccount(Subaccount::from_array([1; 32])),
            Key::Uint(0),
            Key::Uint(1),
            Key::Timestamp(Timestamp::from_seconds(1)),
            Key::Ulid(Ulid::from_bytes([9; 16])),
            Key::Unit,
        ];

        let mut sorted_by_ord = keys.clone();
        sorted_by_ord.sort();

        let mut sorted_by_bytes = keys;
        sorted_by_bytes.sort_by(|a, b| a.to_bytes().cmp(&b.to_bytes()));

        assert_eq!(
            sorted_by_ord, sorted_by_bytes,
            "Key Ord and byte ordering diverged"
        );
    }

    #[test]
    fn int_byte_order_matches_numeric_order() {
        // The sign-bit bias must make the encoding strictly increasing across
        // the whole i64 range, including the negative/positive boundary.
        let encoded: Vec<Vec<u8>> = [i64::MIN, -1, 0, 1, i64::MAX]
            .iter()
            .map(|&v| Key::Int(v).to_bytes().into_owned())
            .collect();

        assert!(
            encoded.windows(2).all(|pair| pair[0] < pair[1]),
            "Int byte encoding must be strictly increasing"
        );
    }

    #[test]
    fn extreme_numeric_keys_round_trip() {
        let keys = [
            Key::Int(i64::MIN),
            Key::Int(i64::MAX),
            Key::Uint(0),
            Key::Uint(u64::MAX),
        ];

        for key in keys {
            let decoded = Key::from_bytes(Storable::to_bytes(&key));
            assert_eq!(decoded, key, "Key round trip failed for {key:?}");
        }
    }

    #[test]
    fn upper_bound_is_greatest_key() {
        let upper = Key::upper_bound();
        let keys = [
            Key::Account(Account::dummy(1)),
            Key::Int(i64::MAX),
            Key::Principal(Principal::from_slice(&[0xFF; 29])),
            Key::Subaccount(Subaccount::from_array([0xFF; 32])),
            Key::Timestamp(Timestamp::from_seconds(42)),
            Key::Uint(u64::MAX),
            Key::Ulid(Ulid::from_bytes([0xFF; 16])),
            Key::Unit,
        ];

        for key in keys {
            assert!(key <= upper, "{key:?} sorts above upper_bound by Ord");
            assert!(
                Storable::to_bytes(&key) <= Storable::to_bytes(&upper),
                "{key:?} sorts above upper_bound by bytes"
            );
        }
    }

    #[test]
    #[should_panic(expected = "corrupted Key: invalid size")]
    fn key_from_bytes_rejects_undersized() {
        let bytes = vec![0u8; Key::STORED_SIZE - 1];
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[should_panic(expected = "corrupted Key: invalid size")]
    fn key_from_bytes_rejects_oversized() {
        let bytes = vec![0u8; Key::STORED_SIZE + 1];
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[should_panic(expected = "corrupted Key: invalid tag")]
    fn key_from_bytes_rejects_unknown_tag() {
        let mut bytes = vec![0u8; Key::STORED_SIZE];
        bytes[Key::TAG_OFFSET] = Key::TAG_UNIT + 1;
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[should_panic(expected = "corrupted Key: principal length out of bounds")]
    fn key_from_bytes_rejects_zero_principal_len() {
        let mut bytes = Key::Principal(Principal::from_slice(&[1]))
            .to_bytes()
            .into_owned();
        bytes[Key::TAG_OFFSET] = Key::TAG_PRINCIPAL;
        bytes[Key::PAYLOAD_OFFSET] = 0;
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[allow(clippy::cast_possible_truncation)]
    #[should_panic(expected = "corrupted Key: principal length out of bounds")]
    fn key_from_bytes_rejects_oversized_principal_len() {
        let mut bytes = Key::Principal(Principal::from_slice(&[1]))
            .to_bytes()
            .into_owned();
        bytes[Key::TAG_OFFSET] = Key::TAG_PRINCIPAL;
        bytes[Key::PAYLOAD_OFFSET] = (Principal::MAX_LENGTH_IN_BYTES as u8) + 1;
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[should_panic(expected = "corrupted Key: principal padding")]
    fn key_from_bytes_rejects_principal_padding() {
        let mut bytes = Key::Principal(Principal::from_slice(&[1]))
            .to_bytes()
            .into_owned();
        bytes[Key::TAG_OFFSET] = Key::TAG_PRINCIPAL;
        bytes[Key::PAYLOAD_OFFSET] = 1;
        bytes[Key::PAYLOAD_OFFSET + 2] = 1;
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[should_panic(expected = "corrupted Key: account padding")]
    fn key_from_bytes_rejects_account_padding() {
        let mut bytes = Key::Account(Account::new(
            Principal::from_slice(&[1]),
            None::<Subaccount>,
        ))
        .to_bytes()
        .into_owned();
        bytes[Key::TAG_OFFSET] = Key::TAG_ACCOUNT;
        bytes[Key::PAYLOAD_OFFSET + Account::STORED_SIZE as usize] = 1;
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[should_panic(expected = "corrupted Key: int padding")]
    fn key_from_bytes_rejects_int_padding() {
        let mut bytes = Key::Int(0).to_bytes().into_owned();
        bytes[Key::TAG_OFFSET] = Key::TAG_INT;
        bytes[Key::PAYLOAD_OFFSET + Key::INT_SIZE] = 1;
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[should_panic(expected = "corrupted Key: uint padding")]
    fn key_from_bytes_rejects_uint_padding() {
        let mut bytes = Key::Uint(0).to_bytes().into_owned();
        bytes[Key::TAG_OFFSET] = Key::TAG_UINT;
        bytes[Key::PAYLOAD_OFFSET + Key::UINT_SIZE] = 1;
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[should_panic(expected = "corrupted Key: timestamp padding")]
    fn key_from_bytes_rejects_timestamp_padding() {
        let mut bytes = Key::Timestamp(Timestamp::from_seconds(0))
            .to_bytes()
            .into_owned();
        bytes[Key::TAG_OFFSET] = Key::TAG_TIMESTAMP;
        bytes[Key::PAYLOAD_OFFSET + Key::TIMESTAMP_SIZE] = 1;
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[should_panic(expected = "corrupted Key: subaccount padding")]
    fn key_from_bytes_rejects_subaccount_padding() {
        let mut bytes = Key::Subaccount(Subaccount::from_array([0; 32]))
            .to_bytes()
            .into_owned();
        bytes[Key::TAG_OFFSET] = Key::TAG_SUBACCOUNT;
        bytes[Key::PAYLOAD_OFFSET + Key::SUBACCOUNT_SIZE] = 1;
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[should_panic(expected = "corrupted Key: ulid padding")]
    fn key_from_bytes_rejects_ulid_padding() {
        let mut bytes = Key::Ulid(Ulid::from_bytes([0; 16])).to_bytes().into_owned();
        bytes[Key::TAG_OFFSET] = Key::TAG_ULID;
        bytes[Key::PAYLOAD_OFFSET + Key::ULID_SIZE] = 1;
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    #[should_panic(expected = "corrupted Key: unit padding")]
    fn key_from_bytes_rejects_unit_padding() {
        let mut bytes = Key::Unit.to_bytes().into_owned();
        bytes[Key::TAG_OFFSET] = Key::TAG_UNIT;
        bytes[Key::PAYLOAD_OFFSET] = 1;
        let _ = Key::from_bytes(bytes.into());
    }

    #[test]
    fn principal_encoding_respects_max_size() {
        let max = Principal::from_slice(&[0xFF; 29]);
        let key = Key::Principal(max);

        let bytes = key.to_bytes();
        assert_eq!(bytes.len(), Key::STORED_SIZE);
    }

    #[test]
    fn principal_max_length_round_trips() {
        // The longest principal fills the length byte plus MAX_LENGTH_IN_BYTES
        // payload bytes; it must decode back to the same key.
        let raw = vec![0xFFu8; Principal::MAX_LENGTH_IN_BYTES as usize];
        let key = Key::Principal(Principal::from_slice(&raw));

        let decoded = Key::from_bytes(Storable::to_bytes(&key));
        assert_eq!(decoded, key, "max-length principal failed to round trip");
    }
}