snops_common/
key_source.rs

1use core::fmt;
2use std::str::FromStr;
3
4use http::StatusCode;
5use lazy_static::lazy_static;
6use serde::{de::Visitor, Deserialize, Deserializer, Serialize};
7use strum_macros::AsRefStr;
8use thiserror::Error;
9
10use crate::{format::*, impl_into_status_code, state::InternedId};
11
12#[derive(Debug, Error, AsRefStr)]
13pub enum KeySourceError {
14    #[error("invalid key source string")]
15    InvalidKeySource,
16    #[error("invalid committee index: {0}")]
17    InvalidCommitteeIndex(#[source] std::num::ParseIntError),
18}
19
// Map each `KeySourceError` to an HTTP status code. Both variants are failures
// to parse a key source string, so both map to 400 Bad Request.
// NOTE(review): the bare variant names below rely on `impl_into_status_code!`
// (defined elsewhere in the crate) bringing `KeySourceError`'s variants into
// scope — confirm against the macro definition before restructuring the arms.
impl_into_status_code!(KeySourceError, |value| match value {
    InvalidKeySource => StatusCode::BAD_REQUEST,
    InvalidCommitteeIndex(_) => StatusCode::BAD_REQUEST,
});
24
/// Describes where a key comes from.
///
/// Converted to and from its textual form via `FromStr` / `Display`, which
/// the serde `Serialize` / `Deserialize` impls delegate to.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum KeySource {
    /// Private key owned by the agent (string form: `local`).
    Local,
    /// A literal aleo private key, e.g. `APrivateKey1zkp...` (59 characters).
    PrivateKeyLiteral(String),
    /// A literal aleo public key / address, e.g. `aleo1...` (63 characters).
    PublicKeyLiteral(String),
    /// A program id, e.g. `program_name1.aleo`.
    ProgramLiteral(String),
    /// A committee key: `committee.0` selects a specific index, while
    /// `committee.$` (`None`) is a per-replica placeholder that
    /// [`KeySource::with_index`] fills in.
    Committee(Option<usize>),
    /// A key from a named set, e.g. `accounts.0`; `accounts.$` (`None`) is a
    /// per-replica placeholder that [`KeySource::with_index`] fills in.
    Named(InternedId, Option<usize>),
}
40
41impl<'de> Deserialize<'de> for KeySource {
42    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
43    where
44        D: Deserializer<'de>,
45    {
46        struct KeySourceVisitor;
47
48        impl<'de> Visitor<'de> for KeySourceVisitor {
49            type Value = KeySource;
50
51            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
52                formatter.write_str(
53                    "a string that represents an aleo private/public key, or a file from storage",
54                )
55            }
56
57            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
58            where
59                E: serde::de::Error,
60            {
61                KeySource::from_str(v).map_err(E::custom)
62            }
63        }
64
65        deserializer.deserialize_str(KeySourceVisitor)
66    }
67}
68
69impl Serialize for KeySource {
70    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
71        serializer.serialize_str(&self.to_string())
72    }
73}
74
75lazy_static! {
76    pub static ref ACCOUNTS_KEY_ID: InternedId = InternedId::from_str("accounts").unwrap();
77}
78
79impl FromStr for KeySource {
80    type Err = KeySourceError;
81
82    fn from_str(s: &str) -> Result<Self, Self::Err> {
83        // use KeySource::Literal(String) when the string is 59 characters long and
84        // starts with "APrivateKey1zkp" use KeySource::Commitee(Option<usize>)
85        // when the string is "committee.0" or "committee.$"
86        // use KeySource::Named(String, Option<usize>) when the string is "\w+.0" or
87        // "\w+.$"
88
89        if s == "local" {
90            return Ok(KeySource::Local);
91        // aleo private key
92        } else if s.len() == 59 && s.starts_with("APrivateKey1") {
93            return Ok(KeySource::PrivateKeyLiteral(s.to_string()));
94        // aleo public key
95        } else if s.len() == 63 && s.starts_with("aleo1") {
96            return Ok(KeySource::PublicKeyLiteral(s.to_string()));
97
98        // committee key
99        } else if let Some(index) = s.strip_prefix("committee.") {
100            if index == "$" {
101                return Ok(KeySource::Committee(None));
102            }
103            let replica = index
104                .parse()
105                .map_err(KeySourceError::InvalidCommitteeIndex)?;
106            return Ok(KeySource::Committee(Some(replica)));
107        }
108
109        // named key (using regex with capture groups)
110        lazy_static! {
111            static ref NAMED_KEYSOURCE_REGEX: regex::Regex =
112                regex::Regex::new(r"^(?P<name>[A-Za-z0-9][A-Za-z0-9\-_.]{0,63})\.(?P<idx>\d+|\$)$")
113                    .unwrap();
114            static ref NAMED_PROGRAM_REGEX: regex::Regex =
115                regex::Regex::new(r"^[A-Za-z0-9_]{1,256}\.aleo$").unwrap();
116        }
117
118        if NAMED_PROGRAM_REGEX.is_match(s) {
119            return Ok(KeySource::ProgramLiteral(s.to_string()));
120        }
121
122        let groups = NAMED_KEYSOURCE_REGEX
123            .captures(s)
124            .ok_or(KeySourceError::InvalidKeySource)?;
125        let name = InternedId::from_str(groups.name("name").unwrap().as_str())
126            .map_err(|_| KeySourceError::InvalidKeySource)?;
127        let idx = match groups.name("idx").unwrap().as_str() {
128            "$" => None,
129            idx => Some(idx.parse().map_err(KeySourceError::InvalidCommitteeIndex)?),
130        };
131        Ok(KeySource::Named(name, idx))
132    }
133}
134
135impl fmt::Display for KeySource {
136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137        write!(
138            f,
139            "{}",
140            match self {
141                KeySource::Local => "local".to_owned(),
142                KeySource::PrivateKeyLiteral(key) => key.to_owned(),
143                KeySource::ProgramLiteral(key) => key.to_owned(),
144                KeySource::PublicKeyLiteral(key) => key.to_owned(),
145                KeySource::Committee(None) => "committee.$".to_owned(),
146                KeySource::Committee(Some(idx)) => {
147                    format!("committee.{}", idx)
148                }
149                KeySource::Named(name, None) => format!("{}.{}", name, "$"),
150                KeySource::Named(name, Some(idx)) => {
151                    format!("{}.{}", name, idx)
152                }
153            }
154        )
155    }
156}
157
158impl DataFormat for KeySource {
159    type Header = u8;
160    const LATEST_HEADER: Self::Header = 1u8;
161
162    fn write_data<W: std::io::prelude::Write>(
163        &self,
164        writer: &mut W,
165    ) -> Result<usize, DataWriteError> {
166        Ok(match self {
167            KeySource::Local => writer.write_data(&0u8)?,
168            KeySource::PrivateKeyLiteral(key) => {
169                writer.write_data(&1u8)? + writer.write_data(key)?
170            }
171            KeySource::Committee(None) => writer.write_data(&2u8)?,
172            KeySource::Committee(Some(idx)) => {
173                // save a byte by making this a separate case
174                writer.write_data(&3u8)? + writer.write_data(idx)?
175            }
176            KeySource::Named(name, None) => writer.write_data(&4u8)? + writer.write_data(name)?,
177            KeySource::Named(name, Some(idx)) => {
178                // save a byte by making this a separate case
179                writer.write_data(&5u8)? + writer.write_data(name)? + writer.write_data(idx)?
180            }
181            KeySource::PublicKeyLiteral(key) => {
182                writer.write_data(&6u8)? + writer.write_data(key)?
183            }
184            KeySource::ProgramLiteral(key) => writer.write_data(&7u8)? + writer.write_data(key)?,
185        })
186    }
187
188    fn read_data<R: std::io::prelude::Read>(
189        reader: &mut R,
190        header: &Self::Header,
191    ) -> Result<Self, DataReadError> {
192        if *header != Self::LATEST_HEADER {
193            return Err(DataReadError::unsupported(
194                "KeySource",
195                Self::LATEST_HEADER,
196                *header,
197            ));
198        }
199
200        match reader.read_data(&())? {
201            0u8 => Ok(KeySource::Local),
202            1u8 => Ok(KeySource::PrivateKeyLiteral(reader.read_data(&())?)),
203            2u8 => Ok(KeySource::Committee(None)),
204            3u8 => Ok(KeySource::Committee(Some(reader.read_data(&())?))),
205            4u8 => Ok(KeySource::Named(reader.read_data(&())?, None)),
206            5u8 => Ok(KeySource::Named(
207                reader.read_data(&())?,
208                Some(reader.read_data(&())?),
209            )),
210            6u8 => Ok(KeySource::PublicKeyLiteral(reader.read_data(&())?)),
211            7u8 => Ok(KeySource::ProgramLiteral(reader.read_data(&())?)),
212            n => Err(DataReadError::Custom(format!("invalid KeySource tag {n}"))),
213        }
214    }
215}
216
217impl KeySource {
218    /// Add an index to a key source only if it did not have an index before
219    pub fn with_index(&self, idx: usize) -> Self {
220        match self {
221            KeySource::Committee(None) => KeySource::Committee(Some(idx)),
222            KeySource::Named(name, None) => KeySource::Named(*name, Some(idx)),
223            _ => self.clone(),
224        }
225    }
226}