use crate::{fmt, str::FromStr, String, ToString};
use borsh::{io, BorshDeserialize, BorshSerialize};
/// A public key tagged with the signature scheme it belongs to.
///
/// The payload of each variant is the raw key material with no key-type
/// marker attached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PublicKey {
    /// Ed25519 key stored as 32 raw bytes.
    Ed25519([u8; 32]),
    /// Secp256k1 key stored as 64 raw bytes.
    Secp256k1([u8; 64]),
}

impl PublicKey {
    /// Returns the raw key bytes of whichever variant is held.
    ///
    /// The slice is 32 bytes long for [`PublicKey::Ed25519`] and 64 bytes long
    /// for [`PublicKey::Secp256k1`]; it carries no key-type marker.
    #[must_use]
    pub fn key_data(&self) -> &[u8] {
        match self {
            Self::Ed25519(bytes) => bytes.as_slice(),
            Self::Secp256k1(bytes) => bytes.as_slice(),
        }
    }
}
impl BorshSerialize for PublicKey {
    /// Writes a one-byte key-type tag (`0` for Ed25519, `1` for Secp256k1)
    /// followed by the raw key bytes, with no length prefix.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `writer`.
    fn serialize<W: io::Write>(&self, writer: &mut W) -> Result<(), io::Error> {
        // Resolve the variant once so the write sequence is shared.
        let (tag, bytes) = match self {
            Self::Ed25519(bytes) => (0u8, bytes.as_slice()),
            Self::Secp256k1(bytes) => (1u8, bytes.as_slice()),
        };
        BorshSerialize::serialize(&tag, writer)?;
        writer.write_all(bytes)
    }
}
impl BorshDeserialize for PublicKey {
    /// Reads a one-byte key-type tag and then the fixed-size key payload
    /// that the tag selects (32 bytes for Ed25519, 64 bytes for Secp256k1).
    ///
    /// # Errors
    ///
    /// Returns the reader's error if the input cannot be read, or the error
    /// from `KeyType::try_from` when the tag byte is not a known key type.
    fn deserialize_reader<R: io::Read>(rd: &mut R) -> io::Result<Self> {
        let tag = u8::deserialize_reader(rd)?;
        let key = match KeyType::try_from(tag)? {
            KeyType::Ed25519 => Self::Ed25519(<[u8; 32]>::deserialize_reader(rd)?),
            KeyType::Secp256k1 => Self::Secp256k1(<[u8; 64]>::deserialize_reader(rd)?),
        };
        Ok(key)
    }
}
impl serde::Serialize for PublicKey {
    /// Serializes the key as a single string, using its `Display` output
    /// (`<key type>:<base58 key data>`).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // `collect_str` lets the serializer consume the `Display` output directly.
        serializer.collect_str(self)
    }
}
impl<'de> serde::Deserialize<'de> for PublicKey {
fn deserialize<D>(deserializer: D) -> Result<Self, <D as serde::Deserializer<'de>>::Error>
where
D: serde::Deserializer<'de>,
{
let s = <String as serde::Deserialize>::deserialize(deserializer)?;
s.parse()
.map_err(|_| serde::de::Error::custom("PublicKey decode error"))
}
}
impl FromStr for PublicKey {
    type Err = DecodeBs58Error;

    /// Parses `<key type>:<base58 key data>`.
    ///
    /// When the input contains no `:` the whole string is treated as the
    /// base58 data of an Ed25519 key (see `split_key_type_data`).
    ///
    /// # Errors
    ///
    /// Returns a `DecodeBs58Error` if the key type is not recognized, the data
    /// is not valid base58, or the decoded length does not match the key size.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        match split_key_type_data(value)? {
            (KeyType::Ed25519, encoded) => decode_bs58(encoded).map(Self::Ed25519),
            (KeyType::Secp256k1, encoded) => decode_bs58(encoded).map(Self::Secp256k1),
        }
    }
}
impl fmt::Display for PublicKey {
    /// Formats the key as `<key type>:<base58 key data>`, e.g. a string
    /// starting with `ed25519:` or `secp256k1:`.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let key_type = match self {
            Self::Ed25519(_) => KeyType::Ed25519,
            Self::Secp256k1(_) => KeyType::Secp256k1,
        };
        // `Bs58` encodes lazily while being written, so no `String` is built.
        let encoded = Bs58(self.key_data());
        write!(fmt, "{}:{}", key_type, encoded)
    }
}
/// Signature scheme of a `PublicKey`.
///
/// The explicit discriminants are the same values used as the one-byte tag by
/// the `TryFrom<u8>` and `From<KeyType> for u8` conversions in this file.
#[derive(Debug, Copy, Clone, serde::Deserialize, serde::Serialize)]
pub enum KeyType {
/// Ed25519 key; tag value `0`, displayed as `ed25519`.
Ed25519 = 0,
/// Secp256k1 key; tag value `1`, displayed as `secp256k1`.
Secp256k1 = 1,
}
impl TryFrom<u8> for KeyType {
type Error = io::Error;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(Self::Ed25519),
1 => Ok(Self::Secp256k1),
_ => Err(io::Error::new(
io::ErrorKind::InvalidData,
"Wrong key prefix",
)),
}
}
}
impl fmt::Display for KeyType {
    /// Writes the lowercase name of the key type: `ed25519` or `secp256k1`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let name = match self {
            Self::Ed25519 => "ed25519",
            Self::Secp256k1 => "secp256k1",
        };
        f.write_str(name)
    }
}
impl FromStr for KeyType {
type Err = DecodeBs58Error;
fn from_str(value: &str) -> Result<Self, Self::Err> {
let lowercase_key_type = value.to_ascii_lowercase();
match lowercase_key_type.as_str() {
"ed25519" => Ok(Self::Ed25519),
"secp256k1" => Ok(Self::Secp256k1),
_ => Err(Self::Err::BadData(value.to_string())),
}
}
}
impl From<KeyType> for u8 {
fn from(key_type: KeyType) -> Self {
match key_type {
KeyType::Ed25519 => 0,
KeyType::Secp256k1 => 1,
}
}
}
/// Splits `value` at its first `:` into a key type and the remaining data.
///
/// The text before the colon is parsed as a `KeyType`; everything after it is
/// returned untouched. Input without a colon is returned whole and assumed to
/// be Ed25519 data.
///
/// # Errors
///
/// Returns the `KeyType` parse error when the text before the colon is not a
/// recognized key type.
fn split_key_type_data(value: &str) -> Result<(KeyType, &str), DecodeBs58Error> {
    match value.split_once(':') {
        Some((prefix, key_data)) => Ok((prefix.parse::<KeyType>()?, key_data)),
        None => Ok((KeyType::Ed25519, value)),
    }
}
/// Display adapter that renders a byte slice as base58 text without
/// allocating.
struct Bs58<'a>(&'a [u8]);

impl fmt::Display for Bs58<'_> {
    /// Encodes the wrapped bytes into a stack buffer and writes the result.
    ///
    /// # Panics
    ///
    /// Panics if the encoded text does not fit into the 96-byte buffer; the
    /// debug assertion documents the input size this adapter is meant for.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bytes = self.0;
        debug_assert!(bytes.len() <= 65);
        let mut encoded = [0u8; 96];
        let written = bs58::encode(bytes).onto(&mut encoded[..]).unwrap();
        // SAFETY: the encoder only emits characters of the base58 alphabet,
        // which are all ASCII, so the written prefix is valid UTF-8.
        let text = unsafe { crate::str::from_utf8_unchecked(&encoded[..written]) };
        fmt.write_str(text)
    }
}
/// Decodes base58 text into an array of exactly `N` bytes.
///
/// # Errors
///
/// Returns whatever `decode_bs58_impl` reports: a length mismatch or invalid
/// base58 input.
fn decode_bs58<const N: usize>(encoded: &str) -> Result<[u8; N], DecodeBs58Error> {
    let mut decoded = [0u8; N];
    decode_bs58_impl(&mut decoded, encoded).map(|()| decoded)
}
fn decode_bs58_impl(dst: &mut [u8], encoded: &str) -> Result<(), DecodeBs58Error> {
let expected = dst.len();
match bs58::decode(encoded).onto(dst) {
Ok(received) if received == expected => Ok(()),
Ok(received) => Err(DecodeBs58Error::BadLength { expected, received }),
Err(bs58::decode::Error::BufferTooSmall) => Err(DecodeBs58Error::BadLength {
expected,
received: expected.saturating_add(1),
}),
Err(err) => Err(DecodeBs58Error::BadData(err.to_string())),
}
}
/// Error returned when a textual key (or its key-type prefix) cannot be
/// decoded.
#[derive(Debug)]
pub enum DecodeBs58Error {
    /// The decoded data does not have the required number of bytes.
    BadLength {
        /// Number of bytes the destination requires.
        expected: usize,
        /// Number of bytes produced by decoding. When the data was too long
        /// for the destination this is `expected + 1` rather than the true
        /// length.
        received: usize,
    },
    /// The input was rejected; holds either the decoder's error message or
    /// the unrecognized key-type string.
    BadData(String),
}
#[cfg(feature = "std")]
impl fmt::Display for DecodeBs58Error {
    /// Writes a human-readable description of the error.
    ///
    /// * `BadLength` reports the expected and received byte counts.
    /// * `BadData` reports the stored message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The message previously read "Bad length of date"; the length
            // being reported is that of the decoded data.
            Self::BadLength { expected, received } => write!(
                f,
                "Bad length of data: expected: {expected}, received: {received}"
            ),
            Self::BadData(data) => write!(f, "Bad data: {data}"),
        }
    }
}
// Marks `DecodeBs58Error` as a standard error type. Compiled only with the
// `std` feature because it names `std::error::Error`; the trait's default
// methods are used unchanged.
#[cfg(feature = "std")]
impl std::error::Error for DecodeBs58Error {}