1use crate::{
2 traits::FieldValue,
3 types::{Account, Principal, Subaccount, Timestamp, Ulid, Unit},
4 value::Value,
5};
6use candid::{CandidType, Principal as WrappedPrincipal};
7use derive_more::Display;
8use serde::{Deserialize, Serialize};
9use std::cmp::Ordering;
10
/// Fixed-width, order-preserving key.
///
/// Every variant is encoded by `Key::to_bytes` into exactly `Key::STORED_SIZE`
/// bytes: a one-byte variant tag followed by a zero-padded payload. `Ord` ranks
/// keys first by that tag and then by the wrapped value; the module tests check
/// that this agrees with lexicographic ordering of the encoded bytes.
///
/// NOTE(review): the stored-format tags are assigned in `impl Key`, not by the
/// declaration order below, but the derived serde/Candid/Hash impls may depend
/// on variant names or order — confirm before renaming or reordering variants.
#[derive(CandidType, Clone, Copy, Debug, Deserialize, Display, Eq, Hash, PartialEq, Serialize)]
pub enum Key {
    /// Account key; its encoding is debug-asserted to be 62 payload bytes.
    Account(Account),
    /// Signed integer; stored big-endian with the sign bit flipped so byte
    /// order matches numeric order.
    Int(i64),
    /// Principal; stored as a one-byte length followed by the raw bytes.
    /// Encoding panics if the length is outside `1..=Principal::MAX_LENGTH_IN_BYTES`.
    Principal(Principal),
    /// Subaccount; stored as its 32-byte array.
    Subaccount(Subaccount),
    /// Timestamp; stored as the big-endian `u64` returned by `get()`.
    Timestamp(Timestamp),
    /// Unsigned integer; stored big-endian.
    Uint(u64),
    /// ULID; stored as its 16 raw bytes.
    Ulid(Ulid),
    /// Payload-less key; has the highest tag, so it sorts after every other variant.
    Unit,
}
30
31impl Key {
32 const TAG_ACCOUNT: u8 = 0;
34 const TAG_INT: u8 = 1;
35 const TAG_PRINCIPAL: u8 = 2;
36 const TAG_SUBACCOUNT: u8 = 3;
37 const TAG_TIMESTAMP: u8 = 4;
38 const TAG_UINT: u8 = 5;
39 const TAG_ULID: u8 = 6;
40 const TAG_UNIT: u8 = 7;
41
42 pub const STORED_SIZE: usize = 64;
44
45 const TAG_SIZE: usize = 1;
47 const TAG_OFFSET: usize = 0;
48
49 const PAYLOAD_OFFSET: usize = Self::TAG_SIZE;
50 const PAYLOAD_SIZE: usize = Self::STORED_SIZE - Self::TAG_SIZE;
51
52 const INT_SIZE: usize = 8;
54 const UINT_SIZE: usize = 8;
55 const TIMESTAMP_SIZE: usize = 8;
56 const ULID_SIZE: usize = 16;
57 const SUBACCOUNT_SIZE: usize = 32;
58 const ACCOUNT_MAX_SIZE: usize = 62;
59
60 const fn tag(&self) -> u8 {
61 match self {
62 Self::Account(_) => Self::TAG_ACCOUNT,
63 Self::Int(_) => Self::TAG_INT,
64 Self::Principal(_) => Self::TAG_PRINCIPAL,
65 Self::Subaccount(_) => Self::TAG_SUBACCOUNT,
66 Self::Timestamp(_) => Self::TAG_TIMESTAMP,
67 Self::Uint(_) => Self::TAG_UINT,
68 Self::Ulid(_) => Self::TAG_ULID,
69 Self::Unit => Self::TAG_UNIT,
70 }
71 }
72
73 #[must_use]
74 pub fn max_storable() -> Self {
76 Self::Account(Account::max_storable())
77 }
78
79 #[must_use]
80 pub const fn lower_bound() -> Self {
81 Self::Int(i64::MIN)
82 }
83
84 #[must_use]
85 pub const fn upper_bound() -> Self {
86 Self::Unit
87 }
88
89 const fn variant_rank(&self) -> u8 {
90 self.tag()
91 }
92
93 #[must_use]
94 pub fn to_bytes(&self) -> [u8; Self::STORED_SIZE] {
95 let mut buf = [0u8; Self::STORED_SIZE];
96
97 buf[Self::TAG_OFFSET] = self.tag();
99 let payload = &mut buf[Self::PAYLOAD_OFFSET..];
100
101 debug_assert_eq!(payload.len(), Self::PAYLOAD_SIZE);
102
103 #[allow(clippy::cast_possible_truncation)]
105 match self {
106 Self::Account(v) => {
107 let bytes = v.to_bytes();
108 debug_assert_eq!(bytes.len(), Self::ACCOUNT_MAX_SIZE);
109 payload[..bytes.len()].copy_from_slice(&bytes);
110 }
111
112 Self::Int(v) => {
113 let biased = (*v).cast_unsigned() ^ (1u64 << 63);
115 payload[..Self::INT_SIZE].copy_from_slice(&biased.to_be_bytes());
116 }
117
118 Self::Uint(v) => {
119 payload[..Self::UINT_SIZE].copy_from_slice(&v.to_be_bytes());
120 }
121
122 Self::Timestamp(v) => {
123 payload[..Self::TIMESTAMP_SIZE].copy_from_slice(&v.get().to_be_bytes());
124 }
125
126 Self::Principal(v) => {
127 let bytes = v.to_bytes();
128 let len = bytes.len();
129 assert!(
130 (1..=Principal::MAX_LENGTH_IN_BYTES as usize).contains(&len),
131 "invalid Key principal length"
132 );
133 payload[0] = len as u8;
134 if len > 0 {
135 payload[1..=len].copy_from_slice(&bytes[..len]);
136 }
137 }
138
139 Self::Subaccount(v) => {
140 let bytes = v.to_array();
141 debug_assert_eq!(bytes.len(), Self::SUBACCOUNT_SIZE);
142 payload[..Self::SUBACCOUNT_SIZE].copy_from_slice(&bytes);
143 }
144
145 Self::Ulid(v) => {
146 payload[..Self::ULID_SIZE].copy_from_slice(&v.to_bytes());
147 }
148
149 Self::Unit => {}
150 }
151
152 buf
153 }
154
155 pub fn try_from_bytes(bytes: &[u8]) -> Result<Self, &'static str> {
156 if bytes.len() != Self::STORED_SIZE {
157 return Err("corrupted Key: invalid size");
158 }
159
160 let tag = bytes[Self::TAG_OFFSET];
161 let payload = &bytes[Self::PAYLOAD_OFFSET..];
162
163 let ensure_zero_padding = |used: usize, context: &str| {
164 if payload[used..].iter().all(|&b| b == 0) {
165 Ok(())
166 } else {
167 Err(match context {
168 "account" => "corrupted Key: non-zero account padding",
169 "int" => "corrupted Key: non-zero int padding",
170 "principal" => "corrupted Key: non-zero principal padding",
171 "subaccount" => "corrupted Key: non-zero subaccount padding",
172 "timestamp" => "corrupted Key: non-zero timestamp padding",
173 "uint" => "corrupted Key: non-zero uint padding",
174 "ulid" => "corrupted Key: non-zero ulid padding",
175 "unit" => "corrupted Key: non-zero unit padding",
176 _ => "corrupted Key: non-zero padding",
177 })
178 }
179 };
180
181 match tag {
182 Self::TAG_ACCOUNT => {
183 let end = Account::STORED_SIZE as usize;
184 ensure_zero_padding(end, "account")?;
185 let account = Account::try_from_bytes(&payload[..end])?;
186 Ok(Self::Account(account))
187 }
188
189 Self::TAG_INT => {
190 let mut buf = [0u8; Self::INT_SIZE];
191 buf.copy_from_slice(&payload[..Self::INT_SIZE]);
192 let biased = u64::from_be_bytes(buf);
193 ensure_zero_padding(Self::INT_SIZE, "int")?;
194 Ok(Self::Int((biased ^ (1u64 << 63)).cast_signed()))
195 }
196
197 Self::TAG_PRINCIPAL => {
198 let len = payload[0] as usize;
199 if !(1..=Principal::MAX_LENGTH_IN_BYTES as usize).contains(&len) {
200 return Err("corrupted Key: invalid principal length");
201 }
202 let end = 1 + len;
203 ensure_zero_padding(end, "principal")?;
204 Ok(Self::Principal(Principal::from_slice(&payload[1..end])))
205 }
206
207 Self::TAG_SUBACCOUNT => {
208 let mut buf = [0u8; Self::SUBACCOUNT_SIZE];
209 buf.copy_from_slice(&payload[..Self::SUBACCOUNT_SIZE]);
210 ensure_zero_padding(Self::SUBACCOUNT_SIZE, "subaccount")?;
211 Ok(Self::Subaccount(Subaccount::from_array(buf)))
212 }
213
214 Self::TAG_TIMESTAMP => {
215 let mut buf = [0u8; Self::TIMESTAMP_SIZE];
216 buf.copy_from_slice(&payload[..Self::TIMESTAMP_SIZE]);
217 ensure_zero_padding(Self::TIMESTAMP_SIZE, "timestamp")?;
218 Ok(Self::Timestamp(Timestamp::from(u64::from_be_bytes(buf))))
219 }
220
221 Self::TAG_UINT => {
222 let mut buf = [0u8; Self::UINT_SIZE];
223 buf.copy_from_slice(&payload[..Self::UINT_SIZE]);
224 ensure_zero_padding(Self::UINT_SIZE, "uint")?;
225 Ok(Self::Uint(u64::from_be_bytes(buf)))
226 }
227
228 Self::TAG_ULID => {
229 let mut buf = [0u8; Self::ULID_SIZE];
230 buf.copy_from_slice(&payload[..Self::ULID_SIZE]);
231 ensure_zero_padding(Self::ULID_SIZE, "ulid")?;
232 Ok(Self::Ulid(Ulid::from_bytes(buf)))
233 }
234
235 Self::TAG_UNIT => {
236 ensure_zero_padding(0, "unit")?;
237 Ok(Self::Unit)
238 }
239
240 _ => Err("corrupted Key: invalid tag"),
241 }
242 }
243}
244
245impl FieldValue for Key {
246 fn to_value(&self) -> Value {
247 match self {
248 Self::Account(v) => Value::Account(*v),
249 Self::Int(v) => Value::Int(*v),
250 Self::Principal(v) => Value::Principal(*v),
251 Self::Subaccount(v) => Value::Subaccount(*v),
252 Self::Timestamp(v) => Value::Timestamp(*v),
253 Self::Uint(v) => Value::Uint(*v),
254 Self::Ulid(v) => Value::Ulid(*v),
255 Self::Unit => Value::Unit,
256 }
257 }
258}
259
260impl From<()> for Key {
261 fn from((): ()) -> Self {
262 Self::Unit
263 }
264}
265
266impl From<Unit> for Key {
267 fn from(_: Unit) -> Self {
268 Self::Unit
269 }
270}
271
272impl PartialEq<()> for Key {
273 fn eq(&self, (): &()) -> bool {
274 matches!(self, Self::Unit)
275 }
276}
277
278impl PartialEq<Key> for () {
279 fn eq(&self, other: &Key) -> bool {
280 other == self
281 }
282}
283
/// Generates `impl From<$source> for Key` for each `$source => $variant` pair.
///
/// The source value is converted with `.into()` into the variant's payload type,
/// so for example `i8 => Int` widens the value to `i64`.
macro_rules! impl_from_key {
    ( $( $source:ty => $variant:ident ),* $(,)? ) => {
        $(
            impl From<$source> for Key {
                fn from(value: $source) -> Self {
                    Key::$variant(value.into())
                }
            }
        )*
    };
}
296
/// Generates `Key == $other` and `$other == Key` for each `$other => $variant` pair.
///
/// A key equals a bare value only when it is the matching variant and its
/// payload compares equal; other variants are never equal.
macro_rules! impl_eq_key {
    ( $( $other:ty => $variant:ident ),* $(,)? ) => {
        $(
            impl PartialEq<$other> for Key {
                fn eq(&self, rhs: &$other) -> bool {
                    if let Key::$variant(inner) = self {
                        inner == rhs
                    } else {
                        false
                    }
                }
            }

            impl PartialEq<Key> for $other {
                fn eq(&self, rhs: &Key) -> bool {
                    <Key as PartialEq<$other>>::eq(rhs, self)
                }
            }
        )*
    };
}
315
316impl_from_key! {
317 Account => Account,
318 i8 => Int,
319 i16 => Int,
320 i32 => Int,
321 i64 => Int,
322 Principal => Principal,
323 WrappedPrincipal => Principal,
324 Subaccount => Subaccount,
325 Timestamp => Timestamp,
326 u8 => Uint,
327 u16 => Uint,
328 u32 => Uint,
329 u64 => Uint,
330 Ulid => Ulid,
331}
332
333impl_eq_key! {
334 Account => Account,
335 i64 => Int,
336 Principal => Principal,
337 Subaccount => Subaccount,
338 Timestamp => Timestamp,
339 u64 => Uint,
340 Ulid => Ulid,
341}
342
343impl Ord for Key {
344 fn cmp(&self, other: &Self) -> Ordering {
345 match (self, other) {
346 (Self::Account(a), Self::Account(b)) => Ord::cmp(a, b),
347 (Self::Int(a), Self::Int(b)) => Ord::cmp(a, b),
348 (Self::Principal(a), Self::Principal(b)) => Ord::cmp(a, b),
349 (Self::Uint(a), Self::Uint(b)) => Ord::cmp(a, b),
350 (Self::Ulid(a), Self::Ulid(b)) => Ord::cmp(a, b),
351 (Self::Subaccount(a), Self::Subaccount(b)) => Ord::cmp(a, b),
352 (Self::Timestamp(a), Self::Timestamp(b)) => Ord::cmp(a, b),
353
354 _ => Ord::cmp(&self.variant_rank(), &other.variant_rank()), }
356 }
357}
358
359impl PartialOrd for Key {
360 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
361 Some(Ord::cmp(self, other))
362 }
363}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn key_max_size_is_bounded() {
        // `to_bytes` always returns exactly STORED_SIZE bytes, so the meaningful
        // check is that the largest key encodes without panicking and decodes
        // back to itself (i.e. its payload really fits).
        let key = Key::max_storable();
        let bytes = key.to_bytes();

        assert_eq!(bytes.len(), Key::STORED_SIZE);
        assert_eq!(
            Key::try_from_bytes(&bytes),
            Ok(key),
            "max storable Key must round trip"
        );
    }

    #[test]
    fn key_storable_round_trip() {
        let keys = [
            Key::Account(Account::dummy(1)),
            Key::Int(-42),
            Key::Int(i64::MIN),
            Key::Int(i64::MAX),
            Key::Principal(Principal::from_slice(&[1, 2, 3])),
            Key::Principal(Principal::from_slice(&[0xFF; 29])),
            Key::Subaccount(Subaccount::from_array([7; 32])),
            Key::Timestamp(Timestamp::from_seconds(42)),
            Key::Uint(42),
            Key::Uint(u64::MAX),
            Key::Ulid(Ulid::from_bytes([9; 16])),
            Key::Unit,
        ];

        for key in keys {
            let bytes = key.to_bytes();
            let decoded = Key::try_from_bytes(&bytes).unwrap();

            assert_eq!(decoded, key, "Key round trip failed for {key:?}");
        }
    }

    #[test]
    fn key_is_exactly_fixed_size() {
        let keys = [
            Key::Account(Account::dummy(1)),
            Key::Int(0),
            Key::Principal(Principal::anonymous()),
            Key::Subaccount(Subaccount::from_array([0; 32])),
            Key::Timestamp(Timestamp::from_seconds(0)),
            Key::Uint(0),
            Key::Ulid(Ulid::from_bytes([0; 16])),
            Key::Unit,
        ];

        for key in keys {
            let len = key.to_bytes().len();
            assert_eq!(
                len,
                Key::STORED_SIZE,
                "Key serialized length must be exactly {}",
                Key::STORED_SIZE
            );
        }
    }

    #[test]
    fn key_ordering_is_total_and_stable() {
        let keys = vec![
            Key::Account(Account::new(
                Principal::from_slice(&[1]),
                None::<Subaccount>,
            )),
            Key::Account(Account::new(Principal::from_slice(&[1]), Some([0u8; 32]))),
            Key::Int(i64::MIN),
            Key::Int(-1),
            Key::Int(0),
            Key::Int(i64::MAX),
            Key::Principal(Principal::from_slice(&[1])),
            Key::Subaccount(Subaccount::from_array([1; 32])),
            Key::Uint(0),
            Key::Uint(1),
            Key::Uint(u64::MAX),
            Key::Timestamp(Timestamp::from_seconds(1)),
            Key::Ulid(Ulid::from_bytes([9; 16])),
            Key::Unit,
        ];

        let mut sorted_by_ord = keys.clone();
        sorted_by_ord.sort();

        let mut sorted_by_bytes = keys;
        sorted_by_bytes.sort_by_key(Key::to_bytes);

        assert_eq!(
            sorted_by_ord, sorted_by_bytes,
            "Key Ord and byte ordering diverged"
        );
    }

    #[test]
    fn upper_bound_is_not_exceeded() {
        let keys = [
            Key::Account(Account::dummy(1)),
            Key::Int(i64::MAX),
            Key::Principal(Principal::from_slice(&[0xFF; 29])),
            Key::Subaccount(Subaccount::from_array([0xFF; 32])),
            Key::Timestamp(Timestamp::from_seconds(42)),
            Key::Uint(u64::MAX),
            Key::Ulid(Ulid::from_bytes([0xFF; 16])),
            Key::Unit,
        ];

        for key in keys {
            assert!(
                key <= Key::upper_bound(),
                "{key:?} must not exceed Key::upper_bound()"
            );
        }
    }

    #[test]
    fn key_from_bytes_rejects_undersized() {
        let bytes = vec![0u8; Key::STORED_SIZE - 1];
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    fn key_from_bytes_rejects_oversized() {
        let bytes = vec![0u8; Key::STORED_SIZE + 1];
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    fn key_from_bytes_rejects_unknown_tag() {
        for tag in [Key::TAG_UNIT + 1, u8::MAX] {
            let mut bytes = [0u8; Key::STORED_SIZE];
            bytes[Key::TAG_OFFSET] = tag;
            assert!(
                Key::try_from_bytes(&bytes).is_err(),
                "tag {tag} must be rejected"
            );
        }
    }

    #[test]
    fn key_from_bytes_rejects_zero_principal_len() {
        let mut bytes = Key::Principal(Principal::from_slice(&[1])).to_bytes();
        bytes[Key::TAG_OFFSET] = Key::TAG_PRINCIPAL;
        bytes[Key::PAYLOAD_OFFSET] = 0;
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    #[allow(clippy::cast_possible_truncation)] // MAX_LENGTH_IN_BYTES is well below u8::MAX
    fn key_from_bytes_rejects_oversized_principal_len() {
        let mut bytes = Key::Principal(Principal::from_slice(&[1])).to_bytes();
        bytes[Key::TAG_OFFSET] = Key::TAG_PRINCIPAL;
        bytes[Key::PAYLOAD_OFFSET] = (Principal::MAX_LENGTH_IN_BYTES as u8) + 1;
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    fn key_from_bytes_rejects_principal_padding() {
        let mut bytes = Key::Principal(Principal::from_slice(&[1])).to_bytes();
        bytes[Key::TAG_OFFSET] = Key::TAG_PRINCIPAL;
        bytes[Key::PAYLOAD_OFFSET] = 1;
        bytes[Key::PAYLOAD_OFFSET + 2] = 1;
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    fn key_from_bytes_rejects_account_padding() {
        let mut bytes = Key::Account(Account::new(
            Principal::from_slice(&[1]),
            None::<Subaccount>,
        ))
        .to_bytes();
        bytes[Key::TAG_OFFSET] = Key::TAG_ACCOUNT;
        bytes[Key::PAYLOAD_OFFSET + Account::STORED_SIZE as usize] = 1;
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    fn key_from_bytes_rejects_int_padding() {
        let mut bytes = Key::Int(0).to_bytes();
        bytes[Key::TAG_OFFSET] = Key::TAG_INT;
        bytes[Key::PAYLOAD_OFFSET + Key::INT_SIZE] = 1;
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    fn key_from_bytes_rejects_uint_padding() {
        let mut bytes = Key::Uint(0).to_bytes();
        bytes[Key::TAG_OFFSET] = Key::TAG_UINT;
        bytes[Key::PAYLOAD_OFFSET + Key::UINT_SIZE] = 1;
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    fn key_from_bytes_rejects_timestamp_padding() {
        let mut bytes = Key::Timestamp(Timestamp::from_seconds(0)).to_bytes();
        bytes[Key::TAG_OFFSET] = Key::TAG_TIMESTAMP;
        bytes[Key::PAYLOAD_OFFSET + Key::TIMESTAMP_SIZE] = 1;
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    fn key_from_bytes_rejects_subaccount_padding() {
        let mut bytes = Key::Subaccount(Subaccount::from_array([0; 32])).to_bytes();
        bytes[Key::TAG_OFFSET] = Key::TAG_SUBACCOUNT;
        bytes[Key::PAYLOAD_OFFSET + Key::SUBACCOUNT_SIZE] = 1;
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    fn key_from_bytes_rejects_ulid_padding() {
        let mut bytes = Key::Ulid(Ulid::from_bytes([0; 16])).to_bytes();
        bytes[Key::TAG_OFFSET] = Key::TAG_ULID;
        bytes[Key::PAYLOAD_OFFSET + Key::ULID_SIZE] = 1;
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    fn key_from_bytes_rejects_unit_padding() {
        let mut bytes = Key::Unit.to_bytes();
        bytes[Key::TAG_OFFSET] = Key::TAG_UNIT;
        bytes[Key::PAYLOAD_OFFSET] = 1;
        assert!(Key::try_from_bytes(&bytes).is_err());
    }

    #[test]
    fn principal_encoding_respects_max_size() {
        // A maximum-length principal must fit in the payload and survive decoding.
        let max = Principal::from_slice(&[0xFF; 29]);
        let key = Key::Principal(max);

        let bytes = key.to_bytes();
        assert_eq!(bytes.len(), Key::STORED_SIZE);
        assert_eq!(Key::try_from_bytes(&bytes), Ok(key));
    }

    #[test]
    #[allow(clippy::cast_possible_truncation)] // intentionally keep only bits 24..32 of the LCG state
    fn key_decode_fuzz_roundtrip_is_canonical() {
        const RUNS: u64 = 1_000;

        let mut seed = 0x1234_5678_u64;
        for _ in 0..RUNS {
            let mut bytes = [0u8; Key::STORED_SIZE];
            for b in &mut bytes {
                seed = seed.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
                *b = (seed >> 24) as u8;
            }

            if let Ok(decoded) = Key::try_from_bytes(&bytes) {
                let re = decoded.to_bytes();
                assert_eq!(bytes, re, "decoded key must be canonical");
            }
        }
    }
}