use antimatter_api::models::tag_type_field::TagTypeField;
use antimatter_api::models::{Tag, TagSetSpanTagsInner};
use ciborium::de::from_reader;
use serde::ser::{Error as SerdeError, Serializer};
use serde::{Deserialize, Deserializer};
use serde_repr::{Deserialize_repr, Serialize_repr};
use serde_tuple::{Deserialize_tuple, Serialize_tuple};
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::io::Read;
/// Version string constant (`"v0"`).
#[doc(hidden)]
pub const VERSION_STRING: &str = "v0";

/// Nonce length constant (12).
/// NOTE(review): presumably the nonce size in bytes — confirm against the cipher code.
#[doc(hidden)]
pub const NONCE_SIZE: usize = 12;

/// Nonce block size constant (6).
/// NOTE(review): its role is not visible in this file — confirm.
#[doc(hidden)]
pub const NONCE_BLOCK_SIZE: usize = 6;

/// Key length constant (32).
/// NOTE(review): presumably the key size in bytes — confirm.
#[doc(hidden)]
pub const KEY_SIZE: usize = 32;

/// Magic prefix stored in `FileHeader::magic` and checked by
/// `FileHeader::is_capsule_bytes` (decimal: 249, 216, 132, 83, 144, 201, 2, 104).
#[doc(hidden)]
pub const BUNDLE_MAGIC_BYTES: [u8; 8] = [0xF9, 0xD8, 0x84, 0x53, 0x90, 0xC9, 0x02, 0x68];

/// 58-symbol alphabet used to pack and unpack the suffixes of `dm-` / `ca-` IDs.
/// It omits the look-alike characters `0`, `I`, `O` and `l`.
#[doc(hidden)]
pub const BASE58_CHARSET: &str = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz";
/// Error type returned by the capsule helpers in this module.
///
/// Message-carrying variants hold a free-form detail string: `Display` prefixes
/// it with a variant-specific description, while `AsRef<str>` returns the bare
/// string. Unit variants map to a fixed phrase in both.
#[derive(Clone, Debug)]
pub enum CapsuleError {
    /// Uncategorized error; displayed as the bare message.
    Generic(String),
    /// Displayed as "DEK not found: <msg>".
    DEKNotFound(String),
    /// Displayed as "DEK has an unexpected type: <msg>".
    DEKUnexpectedType(String),
    /// Displayed as "DEK has the wrong length: <msg>".
    DEKWrongLength(String),
    /// Displayed as "failed to encode CBOR: <msg>".
    CBOREncodeFailed(String),
    /// Displayed as "failed to decode CBOR: <msg>".
    CBORDecodeFailed(String),
    /// Displayed as "failed to encrypt data: <msg>".
    EncryptionFailure(String),
    /// Displayed as "failed to decrypt data: <msg>".
    DecryptionFailure(String),
    /// Displayed as "bad magic value detected: <msg>".
    BadMagic(String),
    /// Displayed as "unsupported capsule version: <msg>".
    UnsupportedVersion(String),
    /// Displayed as "capsule is already sealed: <msg>".
    CapsuleAlreadySealed(String),
    /// Displayed as "failed to write to stream: <msg>".
    StreamWriteFailure(String),
    /// Displayed as "failed to read from stream: <msg>".
    StreamReadFailure(String),
    /// Displayed as "failed file IO operation: <msg>"; produced by
    /// `FileHeader::is_capsule` when reading the prefix fails.
    FileIOError(String),
    /// Displayed as "insufficient permissions: <msg>".
    InsufficientPermissions(String),
    /// Displayed as "failed to decrypt the disaster recovery header: <msg>".
    DRDecryptError(String),
    /// Displayed as "failed to open capsule: <msg>".
    CapsuleOpenError(String),
    /// Displayed as "failed to apply updates to the capsule: <msg>".
    CapsuleUpdateError(String),
    /// Displayed as "end of row".
    EndOfRow,
    /// Displayed as "end of capsule".
    EndOfCapsule,
    /// Displayed as "capsule access denied by policy".
    CapsuleAccessDeniedByPolicy,
    /// Displayed as "row access denied by policy".
    RowAccessDeniedByPolicy,
}
impl AsRef<str> for CapsuleError {
fn as_ref(&self) -> &str {
match self {
CapsuleError::Generic(msg) => msg,
CapsuleError::DEKNotFound(msg) => msg,
CapsuleError::DEKUnexpectedType(msg) => msg,
CapsuleError::DEKWrongLength(msg) => msg,
CapsuleError::CBOREncodeFailed(msg) => msg,
CapsuleError::CBORDecodeFailed(msg) => msg,
CapsuleError::EncryptionFailure(msg) => msg,
CapsuleError::DecryptionFailure(msg) => msg,
CapsuleError::BadMagic(msg) => msg,
CapsuleError::UnsupportedVersion(msg) => msg,
CapsuleError::CapsuleAlreadySealed(msg) => msg,
CapsuleError::StreamWriteFailure(msg) => msg,
CapsuleError::StreamReadFailure(msg) => msg,
CapsuleError::FileIOError(msg) => msg,
CapsuleError::InsufficientPermissions(msg) => msg,
CapsuleError::DRDecryptError(msg) => msg,
CapsuleError::CapsuleUpdateError(msg) => msg,
CapsuleError::CapsuleOpenError(msg) => msg,
CapsuleError::EndOfRow => "end of row",
CapsuleError::EndOfCapsule => "end of capsule",
CapsuleError::CapsuleAccessDeniedByPolicy => "capsule access denied by policy",
CapsuleError::RowAccessDeniedByPolicy => "row access denied by policy",
}
}
}
/// Header map from string keys to raw byte values.
/// NOTE(review): by its name, the header kept in plaintext — confirm against the writer.
#[doc(hidden)]
pub type PlaintextHeader = HashMap<String, Vec<u8>>;
/// Header map from string keys to raw byte values.
/// NOTE(review): by its name, the header stored encrypted — confirm against the writer.
#[doc(hidden)]
pub type EncryptedHeader = HashMap<String, Vec<u8>>;
/// A header value that is either text or raw bytes.
#[doc(hidden)]
pub enum HeaderValue {
    /// Textual value.
    Str(String),
    /// Raw binary value.
    Bytes(Vec<u8>),
}
impl fmt::Display for CapsuleError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CapsuleError::Generic(msg) => {
write!(f, "{}", msg)
}
CapsuleError::DEKNotFound(msg) => {
write!(f, "DEK not found: {}", msg)
}
CapsuleError::DEKUnexpectedType(msg) => {
write!(f, "DEK has an unexpected type: {}", msg)
}
CapsuleError::DEKWrongLength(msg) => {
write!(f, "DEK has the wrong length: {}", msg)
}
CapsuleError::CBOREncodeFailed(msg) => {
write!(f, "failed to encode CBOR: {}", msg)
}
CapsuleError::CBORDecodeFailed(msg) => {
write!(f, "failed to decode CBOR: {}", msg)
}
CapsuleError::EncryptionFailure(msg) => {
write!(f, "failed to encrypt data: {}", msg)
}
CapsuleError::DecryptionFailure(msg) => {
write!(f, "failed to decrypt data: {}", msg)
}
CapsuleError::BadMagic(msg) => {
write!(f, "bad magic value detected: {}", msg)
}
CapsuleError::UnsupportedVersion(msg) => {
write!(f, "unsupported capsule version: {}", msg)
}
CapsuleError::CapsuleAlreadySealed(msg) => {
write!(f, "capsule is already sealed: {}", msg)
}
CapsuleError::StreamWriteFailure(msg) => {
write!(f, "failed to write to stream: {}", msg)
}
CapsuleError::StreamReadFailure(msg) => {
write!(f, "failed to read from stream: {}", msg)
}
CapsuleError::FileIOError(msg) => {
write!(f, "failed file IO operation: {}", msg)
}
CapsuleError::InsufficientPermissions(msg) => {
write!(f, "insufficient permissions: {}", msg)
}
CapsuleError::DRDecryptError(msg) => {
write!(f, "failed to decrypt the disaster recovery header: {}", msg)
}
CapsuleError::CapsuleOpenError(msg) => {
write!(f, "failed to open capsule: {}", msg)
}
CapsuleError::CapsuleUpdateError(msg) => {
write!(f, "failed to apply updates to the capsule: {}", msg)
}
CapsuleError::EndOfRow => {
write!(f, "end of row")
}
CapsuleError::EndOfCapsule => {
write!(f, "end of capsule")
}
CapsuleError::CapsuleAccessDeniedByPolicy => {
write!(f, "capsule access denied by policy")
}
CapsuleError::RowAccessDeniedByPolicy => {
write!(f, "row access denied by policy")
}
}
}
}
/// A named column and the tags attached to it.
///
/// Encoded as a positional tuple (`Serialize_tuple` / `Deserialize_tuple`), so
/// the field order below is part of the serialized format. The derived
/// `PartialEq` compares `tags` using `CapsuleTag`'s own equality.
#[derive(Clone, Serialize_tuple, Deserialize_tuple, Debug, PartialEq)]
pub struct Column {
    /// Column name.
    pub name: String,
    /// Tags attached to the column.
    pub tags: Vec<CapsuleTag>,
    /// NOTE(review): presumably tells the classifier to skip this column; the
    /// consumer is not in this file — confirm.
    pub skip_classification: bool,
}
/// A raw data payload together with the span tags that annotate it.
///
/// Encoded as a positional tuple, so field order is part of the format.
#[doc(hidden)]
#[derive(Clone, Serialize_tuple, Deserialize_tuple, Debug)]
pub struct DataElement {
    /// Payload bytes, encoded as a CBOR byte string via `serde_bytes`.
    #[serde(with = "serde_bytes")]
    pub data: Vec<u8>,
    /// Span tags for the payload.
    /// NOTE(review): offsets presumably index into `data` — confirm.
    pub tags: Vec<SpanTag>,
}
/// A single cell whose content is streamed from `data`, plus its span tags.
pub struct CellReader {
    /// Source of the cell's bytes; `CellReader` implements `Read` by delegating here.
    pub data: Box<dyn Read + Send>,
    /// Span tags that annotate the cell's content.
    pub tags: Vec<SpanTag>,
}
/// A row made of streaming cells, plus the tags attached to the row itself.
pub struct RowReader {
    /// The row's cells.
    pub cells: Vec<CellReader>,
    /// Row-level tags.
    pub tags: Vec<CapsuleTag>,
}
impl CellReader {
    /// Wraps `data` in a `CellReader` carrying the given span `tags`.
    ///
    /// # Errors
    /// Currently never fails. The `Result` return type is kept for API compatibility.
    pub fn new<R: Read + Send + 'static>(
        tags: Vec<SpanTag>,
        data: R,
    ) -> Result<Self, CapsuleError> {
        Ok(Self {
            data: Box::new(data),
            tags,
        })
    }

    /// Reads all remaining cell data and returns it.
    ///
    /// On success, the reader is re-seeded with an in-memory copy of those
    /// bytes. Later reads, or another `copy_data` call, therefore see the same
    /// content again.
    ///
    /// # Errors
    /// Returns `CapsuleError::Generic` if reading fails. In that case the
    /// reader is not re-seeded: bytes consumed before the failure are not
    /// restored.
    pub fn copy_data(&mut self) -> Result<Vec<u8>, CapsuleError> {
        let mut result: Vec<u8> = Vec::new();
        self.data
            .read_to_end(&mut result)
            .map_err(|e| CapsuleError::Generic(format!("reading cell data: {}", e)))?;
        // The old, now-drained reader is dropped by the assignment; one copy is
        // kept for replay and the other is returned.
        self.data = Box::new(std::io::Cursor::new(result.clone()));
        Ok(result)
    }
}
impl Read for CellReader {
    /// Forwards the read unchanged to the boxed inner reader.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
        let source = &mut self.data;
        source.read(buf)
    }
}
/// Fixed preamble at the start of a capsule file: magic bytes followed by a
/// version byte.
///
/// Encoded as a positional CBOR tuple (`Serialize_tuple` / `Deserialize_tuple`),
/// so field order is part of the on-disk format.
#[doc(hidden)]
#[derive(Debug, Clone, Serialize_tuple, Deserialize_tuple)]
pub struct FileHeader {
    /// Must equal `BUNDLE_MAGIC_BYTES` for the data to be recognized as a capsule.
    pub magic: [u8; BUNDLE_MAGIC_BYTES.len()],
    /// Format version.
    pub version: u8,
}

impl FileHeader {
    /// Builds a header with the standard magic bytes and the given `version`.
    pub fn new(version: u8) -> Self {
        FileHeader {
            magic: BUNDLE_MAGIC_BYTES,
            version,
        }
    }

    /// Decodes a CBOR-encoded `FileHeader` from `r`.
    ///
    /// The magic bytes are not validated here; `is_capsule_bytes` does that.
    ///
    /// # Errors
    /// Returns `CapsuleError::Generic` if decoding fails.
    pub fn from_reader<R: Read>(r: R) -> Result<Self, CapsuleError> {
        from_reader::<FileHeader, R>(r)
            .map_err(|e| CapsuleError::Generic(format!("parsing FileHeader: {}", e)))
    }

    /// Returns `true` if `content` starts with a decodable `FileHeader` whose
    /// magic equals `BUNDLE_MAGIC_BYTES`. Any decode failure yields `false`.
    pub fn is_capsule_bytes(content: &[u8]) -> bool {
        matches!(
            from_reader::<FileHeader, &[u8]>(content),
            Ok(header) if header.magic == BUNDLE_MAGIC_BYTES
        )
    }

    /// Peeks at the start of `r` to decide whether it holds a capsule.
    ///
    /// Returns a reader that replays the consumed prefix before the rest of
    /// `r`, so the caller loses no data, together with the detection result.
    /// Inputs shorter than the probe length are reported as not a capsule.
    ///
    /// # Errors
    /// Returns `CapsuleError::FileIOError` if reading the prefix fails.
    pub fn is_capsule<R: Read + 'static>(
        mut r: R,
    ) -> Result<(Box<dyn Read + 'static>, bool), CapsuleError> {
        // NOTE(review): 18 bytes appears to be the CBOR size of a FileHeader
        // with this magic and a version below 24: a 1-byte outer array header,
        // a 16-byte magic array and a 1-byte version. A version of 24 or more
        // would encode to 19 bytes — confirm if that can happen.
        const PROBE_LEN: usize = 18;
        let mut header_bytes: Vec<u8> = Vec::with_capacity(PROBE_LEN);
        let n = r
            .by_ref()
            .take(PROBE_LEN as u64)
            .read_to_end(&mut header_bytes)
            .map_err(|e| CapsuleError::FileIOError(format!("reading capsule file: {}", e)))?;
        if n < PROBE_LEN {
            // The input ended before the probe filled, so the prefix is all there is.
            return Ok((Box::new(std::io::Cursor::new(header_bytes)), false));
        }
        // Inspect the prefix before moving it into the replay cursor, so no copy is needed.
        let looks_like_capsule = Self::is_capsule_bytes(&header_bytes);
        Ok((
            Box::new(std::io::Cursor::new(header_bytes).chain(r)),
            looks_like_capsule,
        ))
    }
}
/// Bundle header, version 2.
///
/// Encoded as a positional CBOR tuple, so field order is part of the format.
#[doc(hidden)]
#[derive(Serialize_tuple, Deserialize_tuple, Clone)]
pub struct BundleHeaderV2 {
    /// Domain ID (`dm-...`), packed on the wire by `serialize_domain_id` /
    /// `deserialize_domain_id`.
    #[serde(
        serialize_with = "serialize_domain_id",
        deserialize_with = "deserialize_domain_id"
    )]
    pub domain_id: String,
    /// NOTE(review): presumably a creation timestamp; units are not visible here.
    pub created: i64,
    /// Flag stored in the header. NOTE(review): its meaning is not visible here.
    pub is_bundle: bool,
}

impl BundleHeaderV2 {
    /// Decodes a `BundleHeaderV2` from the CBOR stream `input`.
    ///
    /// # Errors
    /// Returns `CapsuleError::Generic` when the bytes cannot be decoded.
    pub fn from_reader<R>(input: &mut R) -> Result<Self, CapsuleError>
    where
        R: Read,
    {
        let decoded: Result<Self, _> = ciborium::from_reader(input);
        decoded
            .map_err(|err| CapsuleError::Generic(format!("deserializing bundle header: {}", err)))
    }
}
/// Bundle header, version 3.
///
/// Encoded as a positional CBOR tuple, so field order is part of the format.
#[doc(hidden)]
#[derive(Serialize_tuple, Deserialize_tuple, Clone)]
pub struct BundleHeaderV3 {
    /// Domain ID (`dm-...`), packed on the wire by `serialize_domain_id` /
    /// `deserialize_domain_id`.
    #[serde(
        serialize_with = "serialize_domain_id",
        deserialize_with = "deserialize_domain_id"
    )]
    pub domain_id: String,
    /// NOTE(review): presumably a creation timestamp; units are not visible here.
    pub created: i64,
    /// Flag stored in the header. NOTE(review): its meaning is not visible here.
    pub is_bundle: bool,
}

impl BundleHeaderV3 {
    /// Decodes a `BundleHeaderV3` from the CBOR stream `input`.
    ///
    /// # Errors
    /// Returns `CapsuleError::Generic` when the bytes cannot be decoded.
    pub fn from_reader<R>(input: &mut R) -> Result<Self, CapsuleError>
    where
        R: Read,
    {
        let decoded: Result<Self, _> = ciborium::from_reader(input);
        decoded
            .map_err(|err| CapsuleError::Generic(format!("deserializing bundle header: {}", err)))
    }
}
/// Per-capsule header.
///
/// Encoded as a positional CBOR tuple, so field order is part of the format.
#[doc(hidden)]
#[derive(Serialize_tuple, Deserialize_tuple, Clone)]
pub struct CapsuleHeader {
    /// Encrypted key bytes (per the field name), encoded as a CBOR byte string
    /// via `serde_bytes`.
    #[serde(with = "serde_bytes")]
    pub encrypted_dek: Vec<u8>,
    /// NOTE(review): presumably identifies the key that wrapped
    /// `encrypted_dek` — confirm.
    pub key_id: u64,
    /// Domain ID (`dm-...`), packed by `serialize_domain_id` / `deserialize_domain_id`.
    #[serde(
        serialize_with = "serialize_domain_id",
        deserialize_with = "deserialize_domain_id"
    )]
    pub domain_id: String,
    /// Capsule ID (`ca-...`), packed by `serialize_capsule_id` / `deserialize_capsule_id`.
    #[serde(
        serialize_with = "serialize_capsule_id",
        deserialize_with = "deserialize_capsule_id"
    )]
    pub capsule_id: String,
    /// Optional token bytes. Omitted from the encoding when `None`, and
    /// defaulted to `None` when absent.
    #[serde(skip_serializing_if = "Option::is_none", with = "serde_bytes", default)]
    pub disaster_recovery_token: Option<Vec<u8>>,
}

impl CapsuleHeader {
    /// Decodes a `CapsuleHeader` from the CBOR stream `input`.
    ///
    /// # Errors
    /// Returns `CapsuleError::Generic` when the bytes cannot be decoded.
    pub fn from_reader<R>(input: &mut R) -> Result<Self, CapsuleError>
    where
        R: Read,
    {
        let decoded: Result<Self, _> = ciborium::from_reader(input);
        decoded
            .map_err(|err| CapsuleError::Generic(format!("deserializing capsule header: {}", err)))
    }
}
/// Name and version string of a hook.
/// NOTE(review): presumably the hook that produced a tag — confirm against callers.
///
/// Encoded as a positional tuple, so field order is part of the format.
#[doc(hidden)]
#[derive(Serialize_tuple, Deserialize_tuple, Clone, PartialEq)]
pub struct HookInfo {
    /// Hook name.
    pub name: String,
    /// Hook version string.
    pub version: String,
}
/// Value type carried by a `CapsuleTag`.
///
/// Serialized as its `u8` discriminant (`#[repr(u8)]` plus `Serialize_repr`),
/// so the variant order is part of the encoded format:
/// `Unary` = 0, `Str` = 1, `Number` = 2, `Boolean` = 3, `Date` = 4.
#[derive(Eq, Hash, Clone, Serialize_repr, Deserialize_repr, Debug, PartialEq, PartialOrd)]
#[repr(u8)]
pub enum TagType {
    Unary,
    Str,
    Number,
    Boolean,
    Date,
}

impl From<TagTypeField> for TagType {
    /// Maps the API tag type onto the capsule tag type (API `String` becomes `Str`).
    fn from(tag_type: TagTypeField) -> Self {
        match tag_type {
            TagTypeField::Unary => Self::Unary,
            TagTypeField::String => Self::Str,
            TagTypeField::Number => Self::Number,
            TagTypeField::Boolean => Self::Boolean,
            TagTypeField::Date => Self::Date,
        }
    }
}

impl From<TagType> for TagTypeField {
    /// Maps the capsule tag type back onto the API tag type (`Str` becomes API `String`).
    fn from(tag_type: TagType) -> Self {
        match tag_type {
            TagType::Unary => Self::Unary,
            TagType::Str => Self::String,
            TagType::Number => Self::Number,
            TagType::Boolean => Self::Boolean,
            TagType::Date => Self::Date,
        }
    }
}
#[derive(Clone, Serialize_tuple, Deserialize_tuple, Debug, Eq, Hash)]
pub struct CapsuleTag {
pub name: String,
pub tag_type: TagType,
pub value: String,
pub source: String,
pub hook_version: (i32, i32, i32),
}
impl CapsuleTag {
pub fn from_tag(tag: &Tag) -> Result<CapsuleTag, CapsuleError> {
let tuple = convert_to_tuple(&tag.hook_version.clone().unwrap())?;
Ok(CapsuleTag {
name: tag.name.clone(),
tag_type: TagType::from(tag.r#type),
value: tag.value.clone(),
source: tag.source.clone(),
hook_version: tuple,
})
}
}
impl PartialEq for CapsuleTag {
fn eq(&self, other: &Self) -> bool {
self.name == other.name && self.tag_type == other.tag_type && self.value == other.value
}
}
impl From<CapsuleTag> for Tag {
fn from(capsule_tag: CapsuleTag) -> Self {
Self {
name: capsule_tag.name.clone(),
r#type: match capsule_tag.tag_type {
TagType::Str => TagTypeField::String,
TagType::Number => TagTypeField::Number,
TagType::Boolean => TagTypeField::Boolean,
TagType::Date => TagTypeField::Date,
TagType::Unary => TagTypeField::Unary,
},
value: capsule_tag.value.clone(),
source: capsule_tag.source.clone(),
hook_version: Some(format!(
"{}.{}.{}",
capsule_tag.hook_version.0, capsule_tag.hook_version.1, capsule_tag.hook_version.2
)),
}
}
}
/// A `CapsuleTag` applied to the offset range given by `start` and `end`.
///
/// NOTE(review): whether `end` is exclusive, and whether offsets count bytes or
/// characters, is not visible in this file — confirm.
///
/// Encoded as a positional tuple, so field order is part of the format.
#[derive(Clone, Serialize_tuple, Deserialize_tuple, Debug, PartialEq, Eq)]
pub struct SpanTag {
    /// The tag applied to the span.
    pub tag: CapsuleTag,
    /// Start offset of the span.
    pub start: usize,
    /// End offset of the span.
    pub end: usize,
}
impl SpanTag {
pub fn from_api_span_inner(inner: &TagSetSpanTagsInner) -> Result<Vec<SpanTag>, CapsuleError> {
let mut output: Vec<SpanTag> = Vec::new();
for tag in &inner.tags {
output.push(SpanTag {
tag: CapsuleTag::from_tag(tag)?,
start: inner.start as usize,
end: inner.end as usize,
});
}
Ok(output)
}
}
impl From<SpanTag> for TagSetSpanTagsInner {
fn from(span_tag: SpanTag) -> Self {
Self {
start: span_tag.start as i64,
end: span_tag.end as i64,
tags: vec![span_tag.tag.into()],
}
}
}
/// Decision produced by policy evaluation (inferred from the name).
///
/// NOTE(review): the semantics of each variant are defined by the policy
/// engine, which is not in this file. The names suggest `DenyRecord` and
/// `DenyCapsule` relate to `CapsuleError::RowAccessDeniedByPolicy` and
/// `CapsuleError::CapsuleAccessDeniedByPolicy` — confirm.
#[doc(hidden)]
#[derive(PartialEq, Debug, Copy, Clone)]
pub enum PolicyDecision {
    Allow,
    Redact,
    Tokenize,
    DenyRecord,
    DenyCapsule,
    NoMatch,
}
/// Parses a dotted version string such as `"1.2.3"` into three `i32` values.
///
/// # Errors
/// Returns `CapsuleError::Generic` if the string does not split into exactly
/// three `.`-separated parts, or if any part is not a valid `i32`.
fn convert_to_tuple(input: &str) -> Result<(i32, i32, i32), CapsuleError> {
    let pieces: Vec<&str> = input.split('.').collect();
    match pieces.as_slice() {
        [first, second, third] => {
            match (
                first.parse::<i32>(),
                second.parse::<i32>(),
                third.parse::<i32>(),
            ) {
                (Ok(a), Ok(b), Ok(c)) => Ok((a, b, c)),
                _ => Err(CapsuleError::Generic(
                    "Failed to parse one or more parts into an integer".to_string(),
                )),
            }
        }
        _ => Err(CapsuleError::Generic(
            "Input string does not contain exactly three parts".to_string(),
        )),
    }
}
/// Packs an ID suffix into bytes at 6 bits per character.
///
/// Despite the name, this is not positional base-58 arithmetic. Each character
/// is replaced by its index in `BASE58_CHARSET` (0..=57, which fits in 6 bits),
/// and those 6-bit groups are concatenated most-significant bit first. If the
/// total bit count is not a multiple of 8, the final byte is zero-padded on
/// the right. `unpack_base58_bytes` performs the reverse mapping.
///
/// # Errors
/// Returns an error if `input` contains a character outside `BASE58_CHARSET`.
fn base58_to_packed_bytes(input: &str) -> Result<Vec<u8>, Box<dyn Error>> {
    // Map each character to its alphabet index. `find` returns a byte offset,
    // which equals the character index because the alphabet is ASCII.
    let bits: Vec<u8> = input
        .chars()
        .map(|c| {
            BASE58_CHARSET
                .find(c)
                .map(|idx| idx as u8)
                .ok_or_else(|| "Invalid base58 character".into())
        })
        .collect::<Result<Vec<u8>, Box<dyn Error>>>()?;
    let mut bytes = Vec::new();
    // Bit buffer: the low `bits_in_accumulator` bits are pending output. Higher
    // bits may still hold already-emitted data; the shifts and the `as u8`
    // truncation below discard them. Fewer than 14 bits are ever pending, so
    // a u16 is wide enough.
    let mut accumulator = 0u16;
    let mut bits_in_accumulator = 0;
    for bit_value in bits {
        accumulator <<= 6;
        accumulator |= bit_value as u16;
        bits_in_accumulator += 6;
        // Emit a byte once 8 bits are pending. Adding 6 bits can never complete
        // two bytes at once, so a single `if` suffices.
        if bits_in_accumulator >= 8 {
            bits_in_accumulator -= 8;
            bytes.push((accumulator >> bits_in_accumulator) as u8);
        }
    }
    // Flush leftover bits, left-aligned in a final zero-padded byte.
    if bits_in_accumulator > 0 {
        bytes.push((accumulator << (8 - bits_in_accumulator)) as u8);
    }
    Ok(bytes)
}
fn serialize_base58<S>(prefix: &str, input: &str, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let stripped = input.strip_prefix(prefix).ok_or_else(|| {
S::Error::custom(format!("invalid ID format (must begin with {})", prefix))
})?;
serializer.serialize_bytes(
&base58_to_packed_bytes(stripped)
.map_err(S::Error::custom)?
.to_vec(),
)
}
/// Serializer hook for domain IDs. It requires the `dm-` prefix and emits the
/// rest of the ID as packed bytes.
///
/// Fails if `domain_id` lacks the prefix or contains a character outside
/// `BASE58_CHARSET`.
#[doc(hidden)]
pub fn serialize_domain_id<S>(domain_id: &str, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    const PREFIX: &str = "dm-";
    serialize_base58(PREFIX, domain_id, serializer)
}
/// Serializer hook for capsule IDs. It requires the `ca-` prefix and emits the
/// rest of the ID as packed bytes.
///
/// Fails if `capsule_id` lacks the prefix or contains a character outside
/// `BASE58_CHARSET`.
#[doc(hidden)]
pub fn serialize_capsule_id<S>(capsule_id: &str, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    const PREFIX: &str = "ca-";
    serialize_base58(PREFIX, capsule_id, serializer)
}
/// Reverses `base58_to_packed_bytes`. It splits `input` into 6-bit groups,
/// most-significant bit first, and maps each group to a `BASE58_CHARSET`
/// character.
///
/// Trailing bits that do not fill a whole group are zero-padded on the right
/// and still emitted as a character. Padding introduced during packing can
/// therefore decode to extra trailing characters; `deserialize_base58` keeps
/// only the expected number of characters.
///
/// # Errors
/// Returns an error if a 6-bit group is in 58..=63, which has no character in
/// the 58-symbol alphabet.
fn unpack_base58_bytes(input: &[u8]) -> Result<String, Box<dyn Error>> {
    let mut bits = Vec::new();
    // Bit buffer: the low `bits_in_accumulator` bits are pending. Fewer than 6
    // bits remain before each new byte, so at most 13 are pending and a u16
    // suffices; stale high bits are shifted out or masked by `& 0x3F`.
    let mut accumulator = 0u16;
    let mut bits_in_accumulator = 0;
    for &byte in input {
        accumulator = (accumulator << 8) | (byte as u16);
        bits_in_accumulator += 8;
        // Drain every complete 6-bit group now available.
        while bits_in_accumulator >= 6 {
            bits_in_accumulator -= 6;
            let index = ((accumulator >> bits_in_accumulator) & 0x3F) as usize;
            bits.push(index);
        }
    }
    // Emit any leftover bits as a final, right-zero-padded group.
    if bits_in_accumulator > 0 {
        let index = ((accumulator << (6 - bits_in_accumulator)) & 0x3F) as usize;
        bits.push(index);
    }
    // Map group values to characters; values >= 58 have no character and fail.
    let result: String = bits
        .iter()
        .map(|&idx| BASE58_CHARSET.chars().nth(idx).ok_or("Invalid 6-bit value"))
        .collect::<Result<String, &str>>()?;
    Ok(result)
}
fn deserialize_base58<'de, D>(len: usize, prefix: &str, deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
let packed: Vec<u8> = Deserialize::deserialize(deserializer)?;
let suffix: String = unpack_base58_bytes(packed.as_slice())
.map_err(serde::de::Error::custom)?
.chars()
.take(len)
.collect();
Ok(format!("{}{}", prefix, suffix))
}
/// Deserializer hook for domain IDs. It rebuilds `dm-` followed by an
/// 11-character suffix from the packed bytes.
#[doc(hidden)]
pub fn deserialize_domain_id<'de, D>(deserializer: D) -> Result<String, D::Error>
where
    D: Deserializer<'de>,
{
    const PREFIX: &str = "dm-";
    const SUFFIX_CHARS: usize = 11;
    deserialize_base58(SUFFIX_CHARS, PREFIX, deserializer)
}
/// Deserializer hook for capsule IDs. It rebuilds `ca-` followed by a
/// 22-character suffix from the packed bytes.
#[doc(hidden)]
pub fn deserialize_capsule_id<'de, D>(deserializer: D) -> Result<String, D::Error>
where
    D: Deserializer<'de>,
{
    const PREFIX: &str = "ca-";
    const SUFFIX_CHARS: usize = 22;
    deserialize_base58(SUFFIX_CHARS, PREFIX, deserializer)
}