1use antimatter_api::models::tag_type_field::TagTypeField;
4use antimatter_api::models::{Tag, TagSetSpanTagsInner};
5use ciborium::de::from_reader;
6use serde::ser::{Error as SerdeError, Serializer};
7use serde::{Deserialize, Deserializer};
8use serde_repr::{Deserialize_repr, Serialize_repr};
9use serde_tuple::{Deserialize_tuple, Serialize_tuple};
10use std::collections::HashMap;
11use std::error::Error;
12use std::fmt;
13use std::io::Read;
14
/// Version identifier string (`"v0"`).
/// NOTE(review): its consumers are outside this file.
#[doc(hidden)]
pub const VERSION_STRING: &str = "v0";

/// Nonce length constant (12). Presumably a byte count — usage not visible here, confirm.
#[doc(hidden)]
pub const NONCE_SIZE: usize = 12;

/// Nonce block size constant (6).
/// NOTE(review): meaning and unit are not visible in this file — confirm with its users.
#[doc(hidden)]
pub const NONCE_BLOCK_SIZE: usize = 6;

/// Key length constant (32). Presumably bytes (a 256-bit key) — confirm against the cipher.
#[doc(hidden)]
pub const KEY_SIZE: usize = 32;

/// Magic bytes stored in every `FileHeader`; `FileHeader::is_capsule_bytes`
/// compares a decoded header's `magic` field against this value.
#[doc(hidden)]
pub const BUNDLE_MAGIC_BYTES: [u8; 8] = [0xF9, 0xD8, 0x84, 0x53, 0x90, 0xC9, 0x02, 0x68];

/// 58-symbol alphabet used to pack and unpack IDs.
///
/// It contains the digits `1`-`9`, the upper-case letters without `I`/`O`, and the
/// lower-case letters without `l`. A symbol's position in this string is the 6-bit value
/// written by `base58_to_packed_bytes` and read back by `unpack_base58_bytes`.
#[doc(hidden)]
pub const BASE58_CHARSET: &str = concat!(
    "123456789",
    "ABCDEFGHJKLMNPQRSTUVWXYZ",
    "abcdefghijkmnopqrstuvwxyz"
);
27
/// Errors produced by capsule operations.
///
/// Most variants carry a free-form detail message. The four unit variants signal
/// end-of-row / end-of-capsule conditions and policy denials.
#[derive(Debug, Clone)]
pub enum CapsuleError {
    /// Catch-all error; the message is shown verbatim.
    Generic(String),
    /// A data-encryption key could not be found.
    DEKNotFound(String),
    /// A data-encryption key had an unexpected type.
    DEKUnexpectedType(String),
    /// A data-encryption key had the wrong length.
    DEKWrongLength(String),
    /// CBOR encoding failed.
    CBOREncodeFailed(String),
    /// CBOR decoding failed.
    CBORDecodeFailed(String),
    /// Encrypting data failed.
    EncryptionFailure(String),
    /// Decrypting data failed.
    DecryptionFailure(String),
    /// The magic value did not match.
    BadMagic(String),
    /// The capsule version is not supported.
    UnsupportedVersion(String),
    /// The capsule has already been sealed.
    CapsuleAlreadySealed(String),
    /// Writing to a stream failed.
    StreamWriteFailure(String),
    /// Reading from a stream failed.
    StreamReadFailure(String),
    /// A file I/O operation failed.
    FileIOError(String),
    /// The caller lacks the required permissions.
    InsufficientPermissions(String),
    /// The disaster-recovery header could not be decrypted.
    DRDecryptError(String),
    /// The capsule could not be opened.
    CapsuleOpenError(String),
    /// Updates could not be applied to the capsule.
    CapsuleUpdateError(String),
    /// The end of the current row was reached.
    EndOfRow,
    /// The end of the capsule was reached.
    EndOfCapsule,
    /// Policy denied access to the whole capsule.
    CapsuleAccessDeniedByPolicy,
    /// Policy denied access to a row.
    RowAccessDeniedByPolicy,
}
89
90impl AsRef<str> for CapsuleError {
91 fn as_ref(&self) -> &str {
92 match self {
93 CapsuleError::Generic(msg) => msg,
94 CapsuleError::DEKNotFound(msg) => msg,
95 CapsuleError::DEKUnexpectedType(msg) => msg,
96 CapsuleError::DEKWrongLength(msg) => msg,
97 CapsuleError::CBOREncodeFailed(msg) => msg,
98 CapsuleError::CBORDecodeFailed(msg) => msg,
99 CapsuleError::EncryptionFailure(msg) => msg,
100 CapsuleError::DecryptionFailure(msg) => msg,
101 CapsuleError::BadMagic(msg) => msg,
102 CapsuleError::UnsupportedVersion(msg) => msg,
103 CapsuleError::CapsuleAlreadySealed(msg) => msg,
104 CapsuleError::StreamWriteFailure(msg) => msg,
105 CapsuleError::StreamReadFailure(msg) => msg,
106 CapsuleError::FileIOError(msg) => msg,
107 CapsuleError::InsufficientPermissions(msg) => msg,
108 CapsuleError::DRDecryptError(msg) => msg,
109 CapsuleError::CapsuleUpdateError(msg) => msg,
110 CapsuleError::CapsuleOpenError(msg) => msg,
111 CapsuleError::EndOfRow => "end of row",
112 CapsuleError::EndOfCapsule => "end of capsule",
113 CapsuleError::CapsuleAccessDeniedByPolicy => "capsule access denied by policy",
114 CapsuleError::RowAccessDeniedByPolicy => "row access denied by policy",
115 }
116 }
117}
118
/// Plaintext header map: string keys to raw byte values.
#[doc(hidden)]
pub type PlaintextHeader = HashMap<String, Vec<u8>>;

/// Encrypted header map: string keys to raw byte values (same shape as `PlaintextHeader`).
#[doc(hidden)]
pub type EncryptedHeader = HashMap<String, Vec<u8>>;

/// A header value held either as text or as raw bytes.
#[doc(hidden)]
pub enum HeaderValue {
    /// Textual value.
    Str(String),
    /// Raw byte value.
    Bytes(Vec<u8>),
}
130
131impl fmt::Display for CapsuleError {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 match self {
135 CapsuleError::Generic(msg) => {
136 write!(f, "{}", msg)
137 }
138 CapsuleError::DEKNotFound(msg) => {
139 write!(f, "DEK not found: {}", msg)
140 }
141 CapsuleError::DEKUnexpectedType(msg) => {
142 write!(f, "DEK has an unexpected type: {}", msg)
143 }
144 CapsuleError::DEKWrongLength(msg) => {
145 write!(f, "DEK has the wrong length: {}", msg)
146 }
147 CapsuleError::CBOREncodeFailed(msg) => {
148 write!(f, "failed to encode CBOR: {}", msg)
149 }
150 CapsuleError::CBORDecodeFailed(msg) => {
151 write!(f, "failed to decode CBOR: {}", msg)
152 }
153 CapsuleError::EncryptionFailure(msg) => {
154 write!(f, "failed to encrypt data: {}", msg)
155 }
156 CapsuleError::DecryptionFailure(msg) => {
157 write!(f, "failed to decrypt data: {}", msg)
158 }
159 CapsuleError::BadMagic(msg) => {
160 write!(f, "bad magic value detected: {}", msg)
161 }
162 CapsuleError::UnsupportedVersion(msg) => {
163 write!(f, "unsupported capsule version: {}", msg)
164 }
165 CapsuleError::CapsuleAlreadySealed(msg) => {
166 write!(f, "capsule is already sealed: {}", msg)
167 }
168 CapsuleError::StreamWriteFailure(msg) => {
169 write!(f, "failed to write to stream: {}", msg)
170 }
171 CapsuleError::StreamReadFailure(msg) => {
172 write!(f, "failed to read from stream: {}", msg)
173 }
174 CapsuleError::FileIOError(msg) => {
175 write!(f, "failed file IO operation: {}", msg)
176 }
177 CapsuleError::InsufficientPermissions(msg) => {
178 write!(f, "insufficient permissions: {}", msg)
179 }
180 CapsuleError::DRDecryptError(msg) => {
181 write!(f, "failed to decrypt the disaster recovery header: {}", msg)
182 }
183 CapsuleError::CapsuleOpenError(msg) => {
184 write!(f, "failed to open capsule: {}", msg)
185 }
186 CapsuleError::CapsuleUpdateError(msg) => {
187 write!(f, "failed to apply updates to the capsule: {}", msg)
188 }
189 CapsuleError::EndOfRow => {
190 write!(f, "end of row")
191 }
192 CapsuleError::EndOfCapsule => {
193 write!(f, "end of capsule")
194 }
195 CapsuleError::CapsuleAccessDeniedByPolicy => {
196 write!(f, "capsule access denied by policy")
197 }
198 CapsuleError::RowAccessDeniedByPolicy => {
199 write!(f, "row access denied by policy")
200 }
201 }
202 }
203}
204
205#[derive(Clone, Serialize_tuple, Deserialize_tuple, Debug, PartialEq)]
212pub struct Column {
213 pub name: String,
216 pub tags: Vec<CapsuleTag>,
218 pub skip_classification: bool,
220}
221
222#[doc(hidden)]
223#[derive(Clone, Serialize_tuple, Deserialize_tuple, Debug)]
224pub struct DataElement {
225 #[serde(with = "serde_bytes")]
226 pub data: Vec<u8>,
227 pub tags: Vec<SpanTag>,
228}
229
230pub struct CellReader {
235 pub data: Box<dyn Read + Send>,
238 pub tags: Vec<SpanTag>,
241}
242
243pub struct RowReader {
248 pub cells: Vec<CellReader>,
250 pub tags: Vec<CapsuleTag>,
252}
253
254impl CellReader {
255 pub fn new<R: Read + Send + 'static>(
265 tags: Vec<SpanTag>,
266 data: R,
267 ) -> Result<Self, CapsuleError> {
268 Ok(Self {
269 data: Box::new(data),
270 tags,
271 })
272 }
273
274 pub fn copy_data(&mut self) -> Result<Vec<u8>, CapsuleError> {
281 let mut result: Vec<u8> = Vec::new();
282 self.data
283 .read_to_end(&mut result)
284 .map_err(|e| CapsuleError::Generic(format!("reading cell data: {}", e)))?;
285 let _ = std::mem::replace(
286 &mut self.data,
287 Box::new(std::io::Cursor::new(result.clone())),
288 );
289 Ok(result)
290 }
291}
292
293impl Read for CellReader {
294 fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
295 self.data.read(&mut buf[..])
296 }
297}
298
299#[doc(hidden)]
300#[derive(Debug, Clone, Serialize_tuple, Deserialize_tuple)]
301pub struct FileHeader {
302 pub magic: [u8; BUNDLE_MAGIC_BYTES.len()],
303 pub version: u8,
304}
305
306impl FileHeader {
307 pub fn new(version: u8) -> Self {
308 FileHeader {
309 magic: BUNDLE_MAGIC_BYTES,
310 version,
311 }
312 }
313
314 pub fn from_reader<R: Read>(r: R) -> Result<Self, CapsuleError> {
315 from_reader::<FileHeader, R>(r)
316 .map_err(|e| CapsuleError::Generic(format!("parsing FileHeader: {}", e)))
317 }
318
319 pub fn is_capsule_bytes(content: &[u8]) -> bool {
320 let header = from_reader::<FileHeader, &[u8]>(content);
321 match header.is_ok() {
322 true => header.unwrap().magic == BUNDLE_MAGIC_BYTES,
323 false => false,
324 }
325 }
326
327 pub fn is_capsule<R: Read + 'static>(
334 mut r: R,
335 ) -> Result<(Box<dyn Read + 'static>, bool), CapsuleError> {
336 let len = 18;
340 let mut handle = r.by_ref().take(len as u64);
341 let mut header_bytes: Vec<u8> = Vec::new();
342
343 let n = handle
344 .read_to_end(&mut header_bytes)
345 .map_err(|e| CapsuleError::FileIOError(format!("reading capsule file: {}", e)))?;
346
347 if n < len {
348 return Ok((Box::new(std::io::Cursor::new(header_bytes)), false));
350 }
351
352 Ok((
353 Box::new(std::io::Cursor::new(header_bytes.clone()).chain(r)),
354 Self::is_capsule_bytes(&header_bytes),
355 ))
356 }
357}
358
359#[doc(hidden)]
360#[derive(Serialize_tuple, Deserialize_tuple, Clone)]
361pub struct BundleHeaderV2 {
362 #[serde(
364 serialize_with = "serialize_domain_id",
365 deserialize_with = "deserialize_domain_id"
366 )]
367 pub domain_id: String,
368 pub created: i64,
369 pub is_bundle: bool,
370}
371
372impl BundleHeaderV2 {
373 pub fn from_reader<R>(input: &mut R) -> Result<Self, CapsuleError>
374 where
375 R: Read,
376 {
377 ciborium::from_reader(input)
378 .map_err(|e| CapsuleError::Generic(format!("deserializing bundle header: {}", e)))
379 }
380}
381
382#[doc(hidden)]
383#[derive(Serialize_tuple, Deserialize_tuple, Clone)]
384pub struct BundleHeaderV3 {
385 #[serde(
387 serialize_with = "serialize_domain_id",
388 deserialize_with = "deserialize_domain_id"
389 )]
390 pub domain_id: String,
391 pub created: i64,
392 pub is_bundle: bool,
393}
394
395impl BundleHeaderV3 {
396 pub fn from_reader<R>(input: &mut R) -> Result<Self, CapsuleError>
397 where
398 R: Read,
399 {
400 ciborium::from_reader(input)
401 .map_err(|e| CapsuleError::Generic(format!("deserializing bundle header: {}", e)))
402 }
403}
404
405#[doc(hidden)]
406#[derive(Serialize_tuple, Deserialize_tuple, Clone)]
407pub struct CapsuleHeader {
408 #[serde(with = "serde_bytes")]
409 pub encrypted_dek: Vec<u8>,
410 pub key_id: u64,
411 #[serde(
412 serialize_with = "serialize_domain_id",
413 deserialize_with = "deserialize_domain_id"
414 )]
415 pub domain_id: String,
416 #[serde(
417 serialize_with = "serialize_capsule_id",
418 deserialize_with = "deserialize_capsule_id"
419 )]
420 pub capsule_id: String,
421 #[serde(skip_serializing_if = "Option::is_none", with = "serde_bytes", default)]
422 pub disaster_recovery_token: Option<Vec<u8>>,
423}
424
425impl CapsuleHeader {
426 pub fn from_reader<R>(input: &mut R) -> Result<Self, CapsuleError>
427 where
428 R: Read,
429 {
430 ciborium::from_reader(input)
431 .map_err(|e| CapsuleError::Generic(format!("deserializing capsule header: {}", e)))
432 }
433}
434
435#[doc(hidden)]
436#[derive(Serialize_tuple, Deserialize_tuple, Clone, PartialEq)]
437pub struct HookInfo {
438 pub name: String,
439 pub version: String,
440}
441
442#[derive(Eq, Hash, Clone, Serialize_repr, Deserialize_repr, Debug, PartialEq, PartialOrd)]
444#[repr(u8)]
445pub enum TagType {
446 Unary,
448 Str,
450 Number,
452 Boolean,
454 Date,
456}
457
458impl From<TagTypeField> for TagType {
460 fn from(tag_type: TagTypeField) -> Self {
461 match tag_type {
462 TagTypeField::String => TagType::Str,
463 TagTypeField::Number => TagType::Number,
464 TagTypeField::Boolean => TagType::Boolean,
465 TagTypeField::Date => TagType::Date,
466 TagTypeField::Unary => TagType::Unary,
467 }
468 }
469}
470
471impl From<TagType> for TagTypeField {
472 fn from(tag_type: TagType) -> Self {
473 match tag_type {
474 TagType::Str => TagTypeField::String,
475 TagType::Number => TagTypeField::Number,
476 TagType::Boolean => TagTypeField::Boolean,
477 TagType::Date => TagTypeField::Date,
478 TagType::Unary => TagTypeField::Unary,
479 }
480 }
481}
482
483#[derive(Clone, Serialize_tuple, Deserialize_tuple, Debug, Eq, Hash)]
486pub struct CapsuleTag {
487 pub name: String,
489 pub tag_type: TagType,
491 pub value: String,
493 pub source: String,
496 pub hook_version: (i32, i32, i32),
498}
499
500impl CapsuleTag {
501 pub fn from_tag(tag: &Tag) -> Result<CapsuleTag, CapsuleError> {
509 let tuple = convert_to_tuple(&tag.hook_version.clone().unwrap())?;
510 Ok(CapsuleTag {
511 name: tag.name.clone(),
512 tag_type: TagType::from(tag.r#type),
513 value: tag.value.clone(),
514 source: tag.source.clone(),
515 hook_version: tuple,
516 })
517 }
518}
519
520impl PartialEq for CapsuleTag {
521 fn eq(&self, other: &Self) -> bool {
522 self.name == other.name && self.tag_type == other.tag_type && self.value == other.value
523 }
524}
525
526impl From<CapsuleTag> for Tag {
527 fn from(capsule_tag: CapsuleTag) -> Self {
528 Self {
529 name: capsule_tag.name.clone(),
530 r#type: match capsule_tag.tag_type {
531 TagType::Str => TagTypeField::String,
532 TagType::Number => TagTypeField::Number,
533 TagType::Boolean => TagTypeField::Boolean,
534 TagType::Date => TagTypeField::Date,
535 TagType::Unary => TagTypeField::Unary,
536 },
537 value: capsule_tag.value.clone(),
538 source: capsule_tag.source.clone(),
539 hook_version: Some(format!(
540 "{}.{}.{}",
541 capsule_tag.hook_version.0, capsule_tag.hook_version.1, capsule_tag.hook_version.2
542 )),
543 }
544 }
545}
546
547#[derive(Clone, Serialize_tuple, Deserialize_tuple, Debug, PartialEq, Eq)]
549pub struct SpanTag {
550 pub tag: CapsuleTag,
552 pub start: usize,
554 pub end: usize,
556}
557
558impl SpanTag {
559 pub fn from_api_span_inner(inner: &TagSetSpanTagsInner) -> Result<Vec<SpanTag>, CapsuleError> {
567 let mut output: Vec<SpanTag> = Vec::new();
568 for tag in &inner.tags {
569 output.push(SpanTag {
570 tag: CapsuleTag::from_tag(tag)?,
571 start: inner.start as usize,
572 end: inner.end as usize,
573 });
574 }
575 Ok(output)
576 }
577}
578
579impl From<SpanTag> for TagSetSpanTagsInner {
580 fn from(span_tag: SpanTag) -> Self {
581 Self {
582 start: span_tag.start as i64,
583 end: span_tag.end as i64,
584 tags: vec![span_tag.tag.into()],
585 }
586 }
587}
588
/// Outcome of evaluating a policy.
///
/// NOTE(review): the exact semantics of each variant are defined by the policy engine,
/// which is not visible in this file. The per-variant notes below are inferred from the names.
#[doc(hidden)]
#[derive(PartialEq, Debug, Copy, Clone)]
pub enum PolicyDecision {
    /// Presumably: release the data unchanged.
    Allow,
    /// Presumably: redact the matched data.
    Redact,
    /// Presumably: replace the matched data with a token.
    Tokenize,
    /// Presumably: deny access to the current record.
    DenyRecord,
    /// Presumably: deny access to the whole capsule.
    DenyCapsule,
    /// No policy rule matched.
    NoMatch,
}
600
601fn convert_to_tuple(input: &str) -> Result<(i32, i32, i32), CapsuleError> {
605 let parts: Vec<&str> = input.split('.').collect();
606
607 if parts.len() != 3 {
608 return Err(CapsuleError::Generic(
609 "Input string does not contain exactly three parts".to_string(),
610 ));
611 }
612
613 let part1 = parts[0].parse::<i32>();
614 let part2 = parts[1].parse::<i32>();
615 let part3 = parts[2].parse::<i32>();
616
617 match (part1, part2, part3) {
618 (Ok(p1), Ok(p2), Ok(p3)) => Ok((p1, p2, p3)),
619 _ => Err(CapsuleError::Generic(
620 "Failed to parse one or more parts into an integer".to_string(),
621 )),
622 }
623}
624
625fn base58_to_packed_bytes(input: &str) -> Result<Vec<u8>, Box<dyn Error>> {
626 let bits: Vec<u8> = input
627 .chars()
628 .map(|c| {
629 BASE58_CHARSET
630 .find(c)
631 .map(|idx| idx as u8)
632 .ok_or_else(|| "Invalid base58 character".into())
633 })
634 .collect::<Result<Vec<u8>, Box<dyn Error>>>()?;
635
636 let mut bytes = Vec::new();
637 let mut accumulator = 0u16; let mut bits_in_accumulator = 0;
639
640 for bit_value in bits {
641 accumulator <<= 6;
642 accumulator |= bit_value as u16;
643 bits_in_accumulator += 6;
644
645 if bits_in_accumulator >= 8 {
646 bits_in_accumulator -= 8;
647 bytes.push((accumulator >> bits_in_accumulator) as u8);
648 }
649 }
650
651 if bits_in_accumulator > 0 {
653 bytes.push((accumulator << (8 - bits_in_accumulator)) as u8);
654 }
655 Ok(bytes)
656}
657
658fn serialize_base58<S>(prefix: &str, input: &str, serializer: S) -> Result<S::Ok, S::Error>
659where
660 S: Serializer,
661{
662 let stripped = input.strip_prefix(prefix).ok_or_else(|| {
663 S::Error::custom(format!("invalid ID format (must begin with {})", prefix))
664 })?;
665 serializer.serialize_bytes(
666 &base58_to_packed_bytes(stripped)
667 .map_err(S::Error::custom)?
668 .to_vec(),
669 )
670}
671
672#[doc(hidden)]
673pub fn serialize_domain_id<S>(domain_id: &str, serializer: S) -> Result<S::Ok, S::Error>
674where
675 S: Serializer,
676{
677 serialize_base58("dm-", domain_id, serializer)
678}
679
680#[doc(hidden)]
681pub fn serialize_capsule_id<S>(capsule_id: &str, serializer: S) -> Result<S::Ok, S::Error>
682where
683 S: Serializer,
684{
685 serialize_base58("ca-", capsule_id, serializer)
686}
687
688fn unpack_base58_bytes(input: &[u8]) -> Result<String, Box<dyn Error>> {
689 let mut bits = Vec::new();
690 let mut accumulator = 0u16; let mut bits_in_accumulator = 0;
692
693 for &byte in input {
694 accumulator = (accumulator << 8) | (byte as u16);
695 bits_in_accumulator += 8;
696
697 while bits_in_accumulator >= 6 {
698 bits_in_accumulator -= 6;
699 let index = ((accumulator >> bits_in_accumulator) & 0x3F) as usize; bits.push(index);
701 }
702 }
703
704 if bits_in_accumulator > 0 {
705 let index = ((accumulator << (6 - bits_in_accumulator)) & 0x3F) as usize;
706 bits.push(index);
707 }
708
709 let result: String = bits
711 .iter()
712 .map(|&idx| BASE58_CHARSET.chars().nth(idx).ok_or("Invalid 6-bit value"))
713 .collect::<Result<String, &str>>()?;
714
715 Ok(result)
716}
717
718fn deserialize_base58<'de, D>(len: usize, prefix: &str, deserializer: D) -> Result<String, D::Error>
719where
720 D: Deserializer<'de>,
721{
722 let packed: Vec<u8> = Deserialize::deserialize(deserializer)?;
723 let suffix: String = unpack_base58_bytes(packed.as_slice())
724 .map_err(serde::de::Error::custom)?
725 .chars()
726 .take(len)
727 .collect();
728 Ok(format!("{}{}", prefix, suffix))
729}
730
731#[doc(hidden)]
732pub fn deserialize_domain_id<'de, D>(deserializer: D) -> Result<String, D::Error>
733where
734 D: Deserializer<'de>,
735{
736 deserialize_base58(11, "dm-", deserializer)
737}
738
739#[doc(hidden)]
740pub fn deserialize_capsule_id<'de, D>(deserializer: D) -> Result<String, D::Error>
741where
742 D: Deserializer<'de>,
743{
744 deserialize_base58(22, "ca-", deserializer)
745}