houseflow_types/
common.rs

1use bytes::{Buf, BufMut};
2use std::{convert::TryInto, str::FromStr};
3use thiserror::Error;
4
/// Fixed-size opaque credential holding exactly `N` raw bytes.
///
/// Textual forms (`Display`, `FromStr`, serde) use hex encoding; binary
/// forms (`encode`/`decode`) use the raw bytes.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Credential<const N: usize> {
    /// Raw credential bytes.
    inner: [u8; N],
}
9
10impl<const N: usize> Credential<N> {
11    pub const SIZE: usize = N;
12
13    pub fn into_bytes(self) -> [u8; N] {
14        self.inner
15    }
16
17    pub fn from_bytes(bytes: [u8; N]) -> Self {
18        Self::from(bytes)
19    }
20
21    pub fn encode(&self, buf: &mut impl BufMut) {
22        buf.put_slice(&self.inner);
23    }
24
25    pub fn decode(buf: &mut impl Buf) -> Result<Self, CredentialError> {
26        if buf.remaining() < N {
27            return Err(CredentialError::InvalidSize {
28                expected: N,
29                received: buf.remaining(),
30            });
31        }
32
33        let mut inner = [0; N];
34        buf.copy_to_slice(&mut inner);
35        Ok(Self { inner })
36    }
37}
38
39impl<const N: usize> AsRef<[u8]> for Credential<N> {
40    fn as_ref(&self) -> &[u8] {
41        &self.inner
42    }
43}
44
45#[derive(Debug, Clone, Error, PartialEq, Eq)]
46#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
47pub enum CredentialError {
48    #[error("Invalid size, expected: {expected}, received: {received}")]
49    InvalidSize { expected: usize, received: usize },
50
51    #[error("Invalid encoding: {0}")]
52    InvalidEncoding(String),
53}
54
55impl<const N: usize> From<[u8; N]> for Credential<N> {
56    fn from(v: [u8; N]) -> Self {
57        Self { inner: v }
58    }
59}
60
61impl<const N: usize> From<Credential<N>> for [u8; N] {
62    fn from(val: Credential<N>) -> Self {
63        val.inner
64    }
65}
66
67impl<const N: usize> Default for Credential<N> {
68    fn default() -> Self {
69        Self { inner: [0; N] }
70    }
71}
72
73impl<const N: usize> From<Credential<N>> for String {
74    fn from(val: Credential<N>) -> Self {
75        hex::encode(val.inner)
76    }
77}
78
79impl<const N: usize> FromStr for Credential<N> {
80    type Err = CredentialError;
81
82    fn from_str(v: &str) -> Result<Self, Self::Err> {
83        // N * 2 because encoding with hex doubles the size
84
85        if v.len() != N * 2 {
86            Err(CredentialError::InvalidSize {
87                expected: N * 2,
88                received: v.len(),
89            })
90        } else {
91            Ok(Self {
92                inner: hex::decode(v)
93                    .map_err(|err| CredentialError::InvalidEncoding(err.to_string()))?
94                    .try_into()
95                    .unwrap(),
96            })
97        }
98    }
99}
100
101impl<const N: usize> std::fmt::Display for Credential<N> {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
103        write!(f, "{}", hex::encode(self.inner))
104    }
105}
106
107impl<const N: usize> std::fmt::Debug for Credential<N> {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
109        write!(f, "Inner: `{}`", hex::encode(self.inner))
110    }
111}
112
113impl<const N: usize> rand::distributions::Distribution<Credential<N>>
114    for rand::distributions::Standard
115{
116    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Credential<N> {
117        Credential {
118            inner: (0..N)
119                .map(|_| rng.gen())
120                .collect::<Vec<u8>>()
121                .try_into()
122                .unwrap(),
123        }
124    }
125}
126
#[cfg(feature = "postgres-types")]
impl<const N: usize> postgres_types::ToSql for Credential<N> {
    /// Only `BPCHAR` (`char(n)`) columns are accepted.
    fn accepts(ty: &postgres_types::Type) -> bool {
        *ty == postgres_types::Type::BPCHAR
    }

    /// Writes the hex (`Display`) form of the credential as the parameter value.
    fn to_sql(
        &self,
        _ty: &postgres_types::Type,
        out: &mut bytes::BytesMut,
    ) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>>
    where
        Self: Sized,
    {
        out.put_slice(self.to_string().as_bytes());
        Ok(postgres_types::IsNull::No)
    }

    // The previous hand-written `to_sql_checked` skipped the `accepts` check.
    // The macro generates the standard implementation, which returns a
    // `WrongType` error for unsupported types and otherwise calls `to_sql`.
    postgres_types::to_sql_checked!();
}
156
#[cfg(feature = "postgres-types")]
impl<'a, const N: usize> postgres_types::FromSql<'a> for Credential<N> {
    /// Only `BPCHAR` (`char(n)`) columns are accepted.
    fn accepts(ty: &postgres_types::Type) -> bool {
        *ty == postgres_types::Type::BPCHAR
    }

    /// Interprets the raw column bytes as UTF-8 and parses them as hex.
    fn from_sql(
        _ty: &postgres_types::Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let text = std::str::from_utf8(raw)?;
        text.parse::<Self>().map_err(Into::into)
    }
}
172
173#[cfg(feature = "serde")]
174use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
175
/// Serde visitor that builds a `Credential<N>` from its hex string form.
#[cfg(feature = "serde")]
struct CredentialVisitor<const N: usize>;

#[cfg(feature = "serde")]
impl<'de, const N: usize> Visitor<'de> for CredentialVisitor<N> {
    type Value = Credential<N>;

    /// Describes the accepted input. Only `visit_str` is implemented, so the
    /// input must be a hex string of `2 * N` characters. The old message
    /// ("an array of length N") was wrong.
    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(formatter, "a hex encoded string of length {}", N * 2)
    }

    /// Parses the string with `Credential::from_str`, reporting any parse error
    /// as a serde `invalid_value`.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        Credential::from_str(v).map_err(|err| {
            E::invalid_value(
                serde::de::Unexpected::Other(err.to_string().as_str()),
                &"hex encoded credential",
            )
        })
    }
}
199
#[cfg(feature = "serde")]
impl<'de, const N: usize> Deserialize<'de> for Credential<N> {
    /// Deserializes a credential from a hex string via `CredentialVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor: CredentialVisitor<N> = CredentialVisitor;
        deserializer.deserialize_str(visitor)
    }
}
209
#[cfg(feature = "serde")]
impl<const N: usize> Serialize for Credential<N> {
    /// Serializes the credential as its hex string (its `Display` output).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.collect_str(self)
    }
}
219
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    const SIZE: usize = 32;

    #[test]
    fn test_buffer_parse() {
        let mut buf = BytesMut::with_capacity(SIZE);
        let credential: Credential<SIZE> = rand::random();
        credential.encode(&mut buf);
        let parsed_credential = Credential::<SIZE>::decode(&mut buf)
            .expect("reading Credential from buffer returned Error");
        assert_eq!(credential, parsed_credential);
        // decode must consume exactly SIZE bytes.
        assert!(buf.is_empty());
    }

    #[test]
    fn test_buffer_parse_underflow() {
        let mut buf = BytesMut::with_capacity(SIZE);
        let credential: Credential<SIZE> = rand::random();
        credential.encode(&mut buf);
        buf = buf[0..SIZE - 1].into(); // Malform some last bytes of Buf
        let err = Credential::<SIZE>::decode(&mut buf)
            .expect_err("reading malformed Credential from buffer did not return Error");
        assert_eq!(
            err,
            CredentialError::InvalidSize {
                expected: SIZE,
                received: SIZE - 1,
            }
        );
    }

    #[test]
    fn test_string_roundtrip() {
        let credential: Credential<SIZE> = rand::random();
        let encoded = credential.to_string();
        assert_eq!(encoded.len(), SIZE * 2);
        let parsed: Credential<SIZE> = encoded
            .parse()
            .expect("parsing hex-encoded Credential returned Error");
        assert_eq!(credential, parsed);
    }

    #[test]
    fn test_string_known_value() {
        let credential = Credential::<2>::from([0xab, 0x01]);
        assert_eq!(credential.to_string(), "ab01");
        assert_eq!(String::from(credential), "ab01");
        assert_eq!(Credential::<2>::default().to_string(), "0000");
    }

    #[test]
    fn test_string_invalid_size() {
        let err = "abcd"
            .parse::<Credential<SIZE>>()
            .expect_err("parsing short string did not return Error");
        assert_eq!(
            err,
            CredentialError::InvalidSize {
                expected: SIZE * 2,
                received: 4,
            }
        );
    }

    #[test]
    fn test_string_invalid_encoding() {
        let input = "z".repeat(SIZE * 2);
        let err = input
            .parse::<Credential<SIZE>>()
            .expect_err("parsing non-hex string did not return Error");
        assert!(matches!(err, CredentialError::InvalidEncoding(_)));
    }
}