use core::fmt;
use std::str::FromStr;
use http::StatusCode;
use lazy_static::lazy_static;
use serde::{de::Visitor, Deserialize, Deserializer, Serialize};
use strum_macros::AsRefStr;
use thiserror::Error;
use crate::{format::*, impl_into_status_code, state::InternedId};
#[derive(Debug, Error, AsRefStr)]
pub enum KeySourceError {
#[error("invalid key source string")]
InvalidKeySource,
#[error("invalid committee index: {0}")]
InvalidCommitteeIndex(#[source] std::num::ParseIntError),
}
// Maps every `KeySourceError` variant to `400 Bad Request`.
// NOTE(review): the macro presumably generates a `KeySourceError` ->
// `StatusCode` conversion and brings the variants into scope unqualified;
// confirm against its definition in the crate root.
impl_into_status_code!(KeySourceError, |value| match value {
InvalidKeySource => StatusCode::BAD_REQUEST,
InvalidCommitteeIndex(_) => StatusCode::BAD_REQUEST,
});
/// Where a key comes from, as written in its textual form.
///
/// The string forms accepted by `FromStr` and produced by `Display` are
/// noted on each variant.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum KeySource {
    /// The literal string `local`.
    Local,
    /// A 59-character string beginning with `APrivateKey1`.
    PrivateKeyLiteral(String),
    /// A 63-character string beginning with `aleo1`.
    PublicKeyLiteral(String),
    /// A program id of the form `<name>.aleo`.
    ProgramLiteral(String),
    /// `committee.<index>`, or `committee.$` when no index is given
    /// (see `with_index` for filling it in).
    Committee(Option<usize>),
    /// `<name>.<index>`, or `<name>.$` when no index is given.
    Named(InternedId, Option<usize>),
}
impl<'de> Deserialize<'de> for KeySource {
    /// Deserializes a `KeySource` from a string by delegating to
    /// `KeySource::from_str`; parse failures are reported as custom errors.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct KeySourceVisitor;

        impl<'de> Visitor<'de> for KeySourceVisitor {
            type Value = KeySource;

            // Shown when the input is not a string; lists every form that
            // `KeySource::from_str` accepts (there is no file-based source).
            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str(
                    "a key source string: `local`, an aleo private/public key, a program id \
                     (`<name>.aleo`), `committee.<index|$>`, or `<name>.<index|$>`",
                )
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                KeySource::from_str(v).map_err(E::custom)
            }
        }

        deserializer.deserialize_str(KeySourceVisitor)
    }
}
impl Serialize for KeySource {
    /// Serializes the source as its `Display` string form.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // Same output as `serialize_str(&self.to_string())`, but serializers
        // that override `collect_str` can format directly into their output
        // instead of going through a temporary `String`.
        serializer.collect_str(self)
    }
}
lazy_static! {
    /// Interned id for the name `accounts`.
    pub static ref ACCOUNTS_KEY_ID: InternedId = InternedId::from_str("accounts")
        .expect("the literal \"accounts\" must be a valid InternedId");
}
impl FromStr for KeySource {
    type Err = KeySourceError;

    /// Parses a key source from its textual form.
    ///
    /// Forms are tried in this order: the literal `local`; a 59-character
    /// string starting with `APrivateKey1` (private key); a 63-character
    /// string starting with `aleo1` (public key); `committee.$` or
    /// `committee.<index>`; a program id `<name>.aleo`; and finally
    /// `<name>.$` or `<name>.<index>`.
    ///
    /// # Errors
    ///
    /// Returns `KeySourceError::InvalidCommitteeIndex` when a committee or
    /// named index does not parse as `usize`, and
    /// `KeySourceError::InvalidKeySource` when no form matches or the name is
    /// rejected by `InternedId`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        lazy_static! {
            // `<name>.<index>` or `<name>.$`, where the name is 1-64 characters.
            static ref NAMED_KEYSOURCE_REGEX: regex::Regex =
                regex::Regex::new(r"^(?P<name>[A-Za-z0-9][A-Za-z0-9\-_.]{0,63})\.(?P<idx>\d+|\$)$")
                    .unwrap();
            // Program ids such as `credits.aleo`.
            static ref NAMED_PROGRAM_REGEX: regex::Regex =
                regex::Regex::new(r"^[A-Za-z0-9_]{1,256}\.aleo$").unwrap();
        }

        if s == "local" {
            return Ok(KeySource::Local);
        }
        // Key literals are recognised purely by length and prefix.
        if s.len() == 59 && s.starts_with("APrivateKey1") {
            return Ok(KeySource::PrivateKeyLiteral(s.to_string()));
        }
        if s.len() == 63 && s.starts_with("aleo1") {
            return Ok(KeySource::PublicKeyLiteral(s.to_string()));
        }
        if let Some(index) = s.strip_prefix("committee.") {
            return match index {
                "$" => Ok(KeySource::Committee(None)),
                digits => digits
                    .parse::<usize>()
                    .map(|replica| KeySource::Committee(Some(replica)))
                    .map_err(KeySourceError::InvalidCommitteeIndex),
            };
        }
        if NAMED_PROGRAM_REGEX.is_match(s) {
            return Ok(KeySource::ProgramLiteral(s.to_string()));
        }

        let caps = match NAMED_KEYSOURCE_REGEX.captures(s) {
            Some(caps) => caps,
            None => return Err(KeySourceError::InvalidKeySource),
        };
        // Both groups are mandatory in the pattern, so indexing cannot panic.
        let name =
            InternedId::from_str(&caps["name"]).map_err(|_| KeySourceError::InvalidKeySource)?;
        let index = match &caps["idx"] {
            "$" => None,
            digits => Some(
                digits
                    .parse::<usize>()
                    .map_err(KeySourceError::InvalidCommitteeIndex)?,
            ),
        };
        Ok(KeySource::Named(name, index))
    }
}
impl fmt::Display for KeySource {
    /// Writes the textual form of the source: `local`, the key/program
    /// literal itself, `committee.<index|$>`, or `<name>.<index|$>`.
    ///
    /// Writes straight into the formatter rather than building an
    /// intermediate `String` for every variant. Formatter options such as
    /// width/fill are ignored, exactly as before.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            KeySource::Local => f.write_str("local"),
            KeySource::PrivateKeyLiteral(key)
            | KeySource::PublicKeyLiteral(key)
            | KeySource::ProgramLiteral(key) => f.write_str(key),
            KeySource::Committee(None) => f.write_str("committee.$"),
            KeySource::Committee(Some(idx)) => write!(f, "committee.{idx}"),
            KeySource::Named(name, None) => write!(f, "{name}.$"),
            KeySource::Named(name, Some(idx)) => write!(f, "{name}.{idx}"),
        }
    }
}
impl DataFormat for KeySource {
type Header = u8;
const LATEST_HEADER: Self::Header = 1u8;
fn write_data<W: std::io::prelude::Write>(
&self,
writer: &mut W,
) -> Result<usize, DataWriteError> {
Ok(match self {
KeySource::Local => writer.write_data(&0u8)?,
KeySource::PrivateKeyLiteral(key) => {
writer.write_data(&1u8)? + writer.write_data(key)?
}
KeySource::Committee(None) => writer.write_data(&2u8)?,
KeySource::Committee(Some(idx)) => {
writer.write_data(&3u8)? + writer.write_data(idx)?
}
KeySource::Named(name, None) => writer.write_data(&4u8)? + writer.write_data(name)?,
KeySource::Named(name, Some(idx)) => {
writer.write_data(&5u8)? + writer.write_data(name)? + writer.write_data(idx)?
}
KeySource::PublicKeyLiteral(key) => {
writer.write_data(&6u8)? + writer.write_data(key)?
}
KeySource::ProgramLiteral(key) => writer.write_data(&7u8)? + writer.write_data(key)?,
})
}
fn read_data<R: std::io::prelude::Read>(
reader: &mut R,
header: &Self::Header,
) -> Result<Self, DataReadError> {
if *header != Self::LATEST_HEADER {
return Err(DataReadError::unsupported(
"KeySource",
Self::LATEST_HEADER,
*header,
));
}
match reader.read_data(&())? {
0u8 => Ok(KeySource::Local),
1u8 => Ok(KeySource::PrivateKeyLiteral(reader.read_data(&())?)),
2u8 => Ok(KeySource::Committee(None)),
3u8 => Ok(KeySource::Committee(Some(reader.read_data(&())?))),
4u8 => Ok(KeySource::Named(reader.read_data(&())?, None)),
5u8 => Ok(KeySource::Named(
reader.read_data(&())?,
Some(reader.read_data(&())?),
)),
6u8 => Ok(KeySource::PublicKeyLiteral(reader.read_data(&())?)),
7u8 => Ok(KeySource::ProgramLiteral(reader.read_data(&())?)),
n => Err(DataReadError::Custom(format!("invalid KeySource tag {n}"))),
}
}
}
impl KeySource {
    /// Fills in `idx` on an index-less source.
    ///
    /// `Committee(None)` becomes `Committee(Some(idx))` and
    /// `Named(name, None)` becomes `Named(name, Some(idx))`. Every other
    /// source, including ones that already carry an index, is returned as an
    /// unchanged clone.
    pub fn with_index(&self, idx: usize) -> Self {
        match *self {
            KeySource::Committee(None) => KeySource::Committee(Some(idx)),
            KeySource::Named(name, None) => KeySource::Named(name, Some(idx)),
            ref unchanged => unchanged.clone(),
        }
    }
}