use sha2::{Digest, Sha224};
use std::convert::TryFrom;
use thiserror::Error;
#[derive(Error, Clone, Debug, Eq, PartialEq)]
pub enum PrincipalError {
#[error("Buffer is too long.")]
BufferTooLong(),
#[error(r#"Invalid textual format: expected "{0}""#)]
AbnormalTextualFormat(String),
#[error("Text must be a base 32 string.")]
InvalidTextualFormatNotBase32(),
#[error("Text cannot be converted to a Principal; too small.")]
TextTooSmall(),
#[error("A custom tool returned an error instead of a Principal: {0}")]
ExternalError(String),
}
/// Binary form of the anonymous principal: nothing but its class tag byte.
const ID_ANONYMOUS_BYTES: &[u8] = &[PrincipalClass::Anonymous as u8];

/// Class tag carried in the final byte of a principal's binary form.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
pub(crate) enum PrincipalClass {
    /// Tag `0`; the byte conversion also maps every unrecognised tag here.
    Unassigned = 0,
    /// Tag `1`.
    OpaqueId = 1,
    /// Tag `2`; appended after the SHA-224 digest of a public key.
    SelfAuthenticating = 2,
    /// Tag `3`.
    DerivedId = 3,
    /// Tag `4`; only valid as a one-byte principal.
    Anonymous = 4,
}
impl TryFrom<u8> for PrincipalClass {
    type Error = PrincipalError;

    /// Maps a class tag byte to its [`PrincipalClass`].
    ///
    /// Bytes `1..=4` map to their named classes; every other byte (including `0`)
    /// becomes `Unassigned`, so this conversion currently never fails.
    fn try_from(byte: u8) -> Result<Self, Self::Error> {
        Ok(match byte {
            1 => PrincipalClass::OpaqueId,
            2 => PrincipalClass::SelfAuthenticating,
            3 => PrincipalClass::DerivedId,
            4 => PrincipalClass::Anonymous,
            _ => PrincipalClass::Unassigned,
        })
    }
}
/// A principal identifier. The wrapped [`PrincipalInner`] records which class its bytes belong to.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Principal(PrincipalInner);

/// Internal representation of a [`Principal`], one variant per class.
///
/// Keep the variant order stable: the derived `PartialOrd`/`Ord` compare by
/// declaration order before comparing any payload bytes.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum PrincipalInner {
    /// The empty byte string.
    ManagementCanister,
    /// Bytes whose final byte is the `OpaqueId` tag.
    OpaqueId(Vec<u8>),
    /// Bytes whose final byte is the `SelfAuthenticating` tag.
    SelfAuthenticating(Vec<u8>),
    /// Bytes whose final byte is the `DerivedId` tag.
    DerivedId(Vec<u8>),
    /// The one-byte anonymous principal; its bytes are not stored.
    Anonymous,
    /// Bytes whose final byte is `0` or an unrecognised tag.
    Unassigned(Vec<u8>),
}
impl Principal {
pub fn management_canister() -> Self {
Self(PrincipalInner::ManagementCanister)
}
pub fn self_authenticating<P: AsRef<[u8]>>(public_key: P) -> Self {
let mut bytes: Vec<u8> = Vec::with_capacity(Sha224::output_size() + 1);
let hash = Sha224::digest(public_key.as_ref());
bytes.extend(&hash);
bytes.push(PrincipalClass::SelfAuthenticating as u8);
Self(PrincipalInner::SelfAuthenticating(bytes))
}
pub fn anonymous() -> Self {
Self(PrincipalInner::Anonymous)
}
pub fn from_text<S: AsRef<str>>(text: S) -> Result<Self, PrincipalError> {
let mut s = text.as_ref().to_string();
s.make_ascii_lowercase();
s.retain(|c| c != '-');
match base32::decode(base32::Alphabet::RFC4648 { padding: false }, &s) {
Some(mut bytes) => {
if bytes.len() < 4 {
return Err(PrincipalError::TextTooSmall());
}
let result = Self::try_from(bytes.split_off(4))?;
let expected = format!("{}", result);
if text.as_ref() != expected {
return Err(PrincipalError::AbnormalTextualFormat(expected));
}
Ok(result)
}
None => Err(PrincipalError::InvalidTextualFormatNotBase32()),
}
}
pub fn to_text(&self) -> String {
format!("{}", self)
}
pub fn as_slice(&self) -> &[u8] {
self.as_ref()
}
}
impl std::fmt::Display for Principal {
    /// Writes the textual form of the principal.
    ///
    /// The encoded data is the big-endian CRC32 of the bytes, followed by the bytes
    /// themselves. It is base32-encoded (RFC 4648, no padding), lower-cased, and
    /// emitted in groups of five characters separated by `-`. The last group may be
    /// shorter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw: &[u8] = self.0.as_ref();
        let crc = {
            let mut hasher = crc32fast::Hasher::new();
            hasher.update(raw);
            hasher.finalize()
        };

        let mut framed = Vec::with_capacity(4 + raw.len());
        framed.extend_from_slice(&crc.to_be_bytes());
        framed.extend_from_slice(raw);

        let encoded = base32::encode(base32::Alphabet::RFC4648 { padding: false }, &framed)
            .to_ascii_lowercase();

        // The RFC 4648 alphabet is ASCII, so byte offset 5 is always a char boundary.
        let mut remaining = encoded.as_str();
        while remaining.len() > 5 {
            let (group, tail) = remaining.split_at(5);
            write!(f, "{}-", group)?;
            remaining = tail;
        }
        f.write_str(remaining)
    }
}
impl std::str::FromStr for Principal {
type Err = PrincipalError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Principal::from_text(s)
}
}
impl TryFrom<&str> for Principal {
type Error = PrincipalError;
fn try_from(s: &str) -> Result<Self, Self::Error> {
Principal::from_text(s)
}
}
impl TryFrom<Vec<u8>> for Principal {
    type Error = PrincipalError;

    /// Classifies `bytes` by their final (class tag) byte and takes ownership of them.
    ///
    /// * An empty buffer is the management canister.
    /// * The anonymous tag is accepted only as a single byte. Any extra bytes yield
    ///   `BufferTooLong`.
    /// * Every other class keeps the bytes unchanged.
    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
        let class = match bytes.last() {
            None => return Ok(Principal(PrincipalInner::ManagementCanister)),
            Some(&tag) => PrincipalClass::try_from(tag)?,
        };

        let inner = match class {
            PrincipalClass::OpaqueId => PrincipalInner::OpaqueId(bytes),
            PrincipalClass::SelfAuthenticating => PrincipalInner::SelfAuthenticating(bytes),
            PrincipalClass::DerivedId => PrincipalInner::DerivedId(bytes),
            PrincipalClass::Unassigned => PrincipalInner::Unassigned(bytes),
            PrincipalClass::Anonymous if bytes.len() == 1 => PrincipalInner::Anonymous,
            PrincipalClass::Anonymous => return Err(PrincipalError::BufferTooLong()),
        };
        Ok(Principal(inner))
    }
}
impl TryFrom<&Vec<u8>> for Principal {
    type Error = PrincipalError;

    /// Converts a borrowed byte vector by viewing it as a slice.
    fn try_from(bytes: &Vec<u8>) -> Result<Self, Self::Error> {
        let view: &[u8] = &bytes[..];
        Self::try_from(view)
    }
}
impl TryFrom<&[u8]> for Principal {
    type Error = PrincipalError;

    /// Copies the slice into an owned buffer and classifies it.
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        let owned: Vec<u8> = bytes.into();
        Self::try_from(owned)
    }
}
impl AsRef<[u8]> for Principal {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
impl AsRef<[u8]> for PrincipalInner {
    /// Returns the raw principal bytes.
    ///
    /// Data-carrying variants return their stored bytes. The management canister is
    /// the empty slice, and the anonymous principal is [`ID_ANONYMOUS_BYTES`].
    fn as_ref(&self) -> &[u8] {
        match self {
            PrincipalInner::ManagementCanister => &[],
            PrincipalInner::Anonymous => ID_ANONYMOUS_BYTES,
            PrincipalInner::OpaqueId(bytes)
            | PrincipalInner::SelfAuthenticating(bytes)
            | PrincipalInner::DerivedId(bytes)
            | PrincipalInner::Unassigned(bytes) => bytes,
        }
    }
}
#[cfg(feature = "serde")]
impl serde::Serialize for Principal {
    /// Human-readable formats receive the textual form; all others receive the raw bytes.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        if !serializer.is_human_readable() {
            return serializer.serialize_bytes(self.0.as_ref());
        }
        let text = self.to_text();
        serializer.serialize_str(&text)
    }
}
#[cfg(feature = "serde")]
mod deserialize {
    use super::Principal;
    use std::convert::TryFrom;

    /// Serde visitor that builds a [`Principal`] from either its textual form or its bytes.
    pub(super) struct PrincipalVisitor;

    impl<'de> serde::de::Visitor<'de> for PrincipalVisitor {
        type Value = super::Principal;

        fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            formatter.write_str("bytes or string")
        }

        /// Parses the textual form via `Principal::from_text`.
        fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
            Principal::from_text(v).map_err(E::custom)
        }

        /// Interprets borrowed bytes as the principal's raw binary form.
        fn visit_bytes<E: serde::de::Error>(self, value: &[u8]) -> Result<Self::Value, E> {
            Principal::try_from(value).map_err(E::custom)
        }

        /// Owned buffers must start with a `2u8` tag byte, which is stripped before
        /// conversion. Empty or untagged buffers are rejected.
        ///
        /// NOTE(review): the tag and error text suggest this path exists for Candid's
        /// deserializer. Any other deserializer that hands over an owned, untagged
        /// buffer will fail here. Confirm before changing.
        fn visit_byte_buf<E: serde::de::Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
            match v.split_first() {
                Some((&2u8, payload)) => Principal::try_from(payload).map_err(E::custom),
                _ => Err(E::custom("Not called by Candid")),
            }
        }
    }
}
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for Principal {
    /// Mirrors `Serialize`: human-readable formats carry the textual form, and all
    /// others carry the raw bytes.
    ///
    /// Always requesting bytes (the previous behaviour) made human-readable formats
    /// such as JSON hand the UTF-8 bytes of the text to `visit_bytes`. Those bytes
    /// were then silently misread as a raw principal, so a round trip produced a
    /// different principal.
    ///
    /// NOTE(review): the tagged-bytes path (`visit_byte_buf`) is only reachable when
    /// the deserializer reports `is_human_readable() == false`. Confirm the Candid
    /// deserializer in use does so.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Principal, D::Error> {
        if deserializer.is_human_readable() {
            deserializer.deserialize_str(deserialize::PrincipalVisitor)
        } else {
            deserializer.deserialize_bytes(deserialize::PrincipalVisitor)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Binary form of a sample opaque-id canister principal.
    const CANISTER_BYTES: [u8; 9] = [0xef, 0xcd, 0xab, 0, 0, 0, 0, 0, 1];
    /// Textual form of `CANISTER_BYTES`.
    const CANISTER_TEXT: &str = "2chl6-4hpzw-vqaaa-aaaaa-c";

    #[cfg(feature = "serde")]
    #[test]
    fn serializes() {
        let seed = [
            0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22,
            0x11, 0x00, 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, 0x77, 0x66, 0x55, 0x44,
            0x33, 0x22, 0x11, 0x00,
        ];
        let principal: Principal = Principal::self_authenticating(&seed);
        assert_eq!(
            serde_cbor::from_slice::<Principal>(
                serde_cbor::to_vec(&principal)
                    .expect("Failed to serialize")
                    .as_slice()
            )
            .unwrap(),
            principal
        );
    }

    #[test]
    fn parse_management_canister_ok() {
        assert_eq!(
            Principal::from_str("aaaaa-aa").unwrap(),
            Principal(PrincipalInner::ManagementCanister)
        );
    }

    #[test]
    fn parse_management_canister_to_text_ok() {
        assert_eq!(Principal::from_str("aaaaa-aa").unwrap().as_slice(), &[]);
    }

    #[test]
    fn create_management_cid_from_empty_blob_ok() {
        assert_eq!(Principal::management_canister().to_text(), "aaaaa-aa");
    }

    #[test]
    fn create_management_cid_from_text_ok() {
        assert_eq!(
            Principal::from_str("aaaaa-aa").unwrap().to_text(),
            "aaaaa-aa",
        );
    }

    #[test]
    fn display_canister_id() {
        assert_eq!(
            Principal::try_from(CANISTER_BYTES.to_vec()).unwrap().to_text(),
            CANISTER_TEXT,
        );
    }

    #[test]
    fn display_canister_id_from_bytes_as_bytes() {
        assert_eq!(
            Principal::try_from(CANISTER_BYTES.to_vec()).unwrap().as_slice(),
            &CANISTER_BYTES,
        );
    }

    #[test]
    fn display_canister_id_from_text_as_bytes() {
        assert_eq!(
            Principal::from_str(CANISTER_TEXT).unwrap().as_slice(),
            &CANISTER_BYTES,
        );
    }

    #[cfg(feature = "serde")]
    #[test]
    fn check_serialize_deserialize() {
        let id = Principal::from_str(CANISTER_TEXT).unwrap();
        let vec = serde_cbor::to_vec(&id).unwrap();
        let value = serde_cbor::from_slice(vec.as_slice()).unwrap();
        assert_eq!(id, value);
    }

    #[test]
    fn text_form() {
        let cid = Principal::try_from(vec![1, 8, 64, 255]).unwrap();
        let text = cid.to_text();
        let cid2 = Principal::from_str(&text).unwrap();
        assert_eq!(cid, cid2);
        assert_eq!(text, "jkies-sibbb-ap6");
    }

    #[test]
    fn anonymous_round_trips() {
        let anon = Principal::anonymous();
        assert_eq!(anon.as_slice(), &[4]);
        assert_eq!(anon.to_text(), "2vxsx-fae");
        assert_eq!(Principal::from_str("2vxsx-fae").unwrap(), anon);
        assert_eq!(Principal::try_from(vec![4]).unwrap(), anon);
    }

    #[test]
    fn anonymous_tag_with_extra_bytes_is_rejected() {
        assert_eq!(
            Principal::try_from(vec![1, 4]),
            Err(PrincipalError::BufferTooLong())
        );
    }

    #[test]
    fn self_authenticating_layout() {
        let id = Principal::self_authenticating(b"public key");
        let bytes = id.as_slice();
        // 28-byte SHA-224 digest followed by the SelfAuthenticating tag.
        assert_eq!(bytes.len(), 29);
        assert_eq!(bytes[28], PrincipalClass::SelfAuthenticating as u8);
        assert_eq!(Principal::try_from(bytes).unwrap(), id);
        assert_eq!(Principal::from_str(&id.to_text()).unwrap(), id);
    }

    #[test]
    fn unassigned_class_bytes_are_preserved() {
        let id = Principal::try_from(vec![0xaa, 0x00]).unwrap();
        assert_eq!(id.as_slice(), &[0xaa, 0x00]);
        assert_eq!(Principal::from_str(&id.to_text()).unwrap(), id);
    }

    #[test]
    fn rejects_non_canonical_text() {
        let expected = Err(PrincipalError::AbnormalTextualFormat("aaaaa-aa".to_string()));
        assert_eq!(Principal::from_str("AAAAA-AA"), expected);
        assert_eq!(Principal::from_str("aaaaaaa"), expected);
    }

    #[test]
    fn rejects_too_short_text() {
        assert_eq!(
            Principal::from_str("aaaaa"),
            Err(PrincipalError::TextTooSmall())
        );
    }

    #[test]
    fn rejects_non_base32_text() {
        assert_eq!(
            Principal::from_str("aaaaa-a0"),
            Err(PrincipalError::InvalidTextualFormatNotBase32())
        );
    }
}