houseflow_types/
common.rs1use bytes::{Buf, BufMut};
2use std::{convert::TryInto, str::FromStr};
3use thiserror::Error;
4
/// Fixed-size credential stored as `N` raw bytes.
///
/// The textual form used throughout this module is the hex encoding of those bytes.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Credential<const N: usize> {
    /// The raw credential bytes.
    inner: [u8; N],
}
9
10impl<const N: usize> Credential<N> {
11 pub const SIZE: usize = N;
12
13 pub fn into_bytes(self) -> [u8; N] {
14 self.inner
15 }
16
17 pub fn from_bytes(bytes: [u8; N]) -> Self {
18 Self::from(bytes)
19 }
20
21 pub fn encode(&self, buf: &mut impl BufMut) {
22 buf.put_slice(&self.inner);
23 }
24
25 pub fn decode(buf: &mut impl Buf) -> Result<Self, CredentialError> {
26 if buf.remaining() < N {
27 return Err(CredentialError::InvalidSize {
28 expected: N,
29 received: buf.remaining(),
30 });
31 }
32
33 let mut inner = [0; N];
34 buf.copy_to_slice(&mut inner);
35 Ok(Self { inner })
36 }
37}
38
39impl<const N: usize> AsRef<[u8]> for Credential<N> {
40 fn as_ref(&self) -> &[u8] {
41 &self.inner
42 }
43}
44
45#[derive(Debug, Clone, Error, PartialEq, Eq)]
46#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
47pub enum CredentialError {
48 #[error("Invalid size, expected: {expected}, received: {received}")]
49 InvalidSize { expected: usize, received: usize },
50
51 #[error("Invalid encoding: {0}")]
52 InvalidEncoding(String),
53}
54
55impl<const N: usize> From<[u8; N]> for Credential<N> {
56 fn from(v: [u8; N]) -> Self {
57 Self { inner: v }
58 }
59}
60
61impl<const N: usize> From<Credential<N>> for [u8; N] {
62 fn from(val: Credential<N>) -> Self {
63 val.inner
64 }
65}
66
67impl<const N: usize> Default for Credential<N> {
68 fn default() -> Self {
69 Self { inner: [0; N] }
70 }
71}
72
73impl<const N: usize> From<Credential<N>> for String {
74 fn from(val: Credential<N>) -> Self {
75 hex::encode(val.inner)
76 }
77}
78
79impl<const N: usize> FromStr for Credential<N> {
80 type Err = CredentialError;
81
82 fn from_str(v: &str) -> Result<Self, Self::Err> {
83 if v.len() != N * 2 {
86 Err(CredentialError::InvalidSize {
87 expected: N * 2,
88 received: v.len(),
89 })
90 } else {
91 Ok(Self {
92 inner: hex::decode(v)
93 .map_err(|err| CredentialError::InvalidEncoding(err.to_string()))?
94 .try_into()
95 .unwrap(),
96 })
97 }
98 }
99}
100
101impl<const N: usize> std::fmt::Display for Credential<N> {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
103 write!(f, "{}", hex::encode(self.inner))
104 }
105}
106
107impl<const N: usize> std::fmt::Debug for Credential<N> {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
109 write!(f, "Inner: `{}`", hex::encode(self.inner))
110 }
111}
112
113impl<const N: usize> rand::distributions::Distribution<Credential<N>>
114 for rand::distributions::Standard
115{
116 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Credential<N> {
117 Credential {
118 inner: (0..N)
119 .map(|_| rng.gen())
120 .collect::<Vec<u8>>()
121 .try_into()
122 .unwrap(),
123 }
124 }
125}
126
127#[cfg(feature = "postgres-types")]
128impl<const N: usize> postgres_types::ToSql for Credential<N> {
129 fn accepts(ty: &postgres_types::Type) -> bool {
130 *ty == postgres_types::Type::BPCHAR
131 }
132
133 fn to_sql(
134 &self,
135 _ty: &postgres_types::Type,
136 out: &mut bytes::BytesMut,
137 ) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>>
138 where
139 Self: Sized,
140 {
141 let string = self.to_string();
142 out.put_slice(string.as_bytes());
143 Ok(postgres_types::IsNull::No)
144 }
145
146 fn to_sql_checked(
147 &self,
148 _ty: &postgres_types::Type,
149 out: &mut bytes::BytesMut,
150 ) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
151 let string = self.to_string();
152 out.put_slice(string.as_bytes());
153 Ok(postgres_types::IsNull::No)
154 }
155}
156
157#[cfg(feature = "postgres-types")]
158impl<'a, const N: usize> postgres_types::FromSql<'a> for Credential<N> {
159 fn from_sql(
160 _ty: &postgres_types::Type,
161 raw: &'a [u8],
162 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
163 let str = std::str::from_utf8(raw)?;
164 let credential = Self::from_str(str)?;
165 Ok(credential)
166 }
167
168 fn accepts(ty: &postgres_types::Type) -> bool {
169 *ty == postgres_types::Type::BPCHAR
170 }
171}
172
173#[cfg(feature = "serde")]
174use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
175
176#[cfg(feature = "serde")]
177struct CredentialVisitor<const N: usize>;
178
179#[cfg(feature = "serde")]
180impl<'de, const N: usize> Visitor<'de> for CredentialVisitor<N> {
181 type Value = Credential<N>;
182
183 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
184 formatter.write_str(&format!("an array of length {}", N))
185 }
186
187 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
188 where
189 E: serde::de::Error,
190 {
191 Credential::from_str(v).map_err(|err| {
192 E::invalid_value(
193 serde::de::Unexpected::Other(err.to_string().as_str()),
194 &"hex encoded credential",
195 )
196 })
197 }
198}
199
200#[cfg(feature = "serde")]
201impl<'de, const N: usize> Deserialize<'de> for Credential<N> {
202 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203 where
204 D: Deserializer<'de>,
205 {
206 deserializer.deserialize_str(CredentialVisitor::<N>)
207 }
208}
209
210#[cfg(feature = "serde")]
211impl<const N: usize> Serialize for Credential<N> {
212 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
213 where
214 S: Serializer,
215 {
216 serializer.serialize_str(&self.to_string())
217 }
218}
219
220#[cfg(test)]
221mod tests {
222 use super::*;
223 use bytes::BytesMut;
224 const SIZE: usize = 32;
225
226 #[test]
227 fn test_buffer_parse() {
228 let mut buf = BytesMut::with_capacity(SIZE);
229 let credential: Credential<SIZE> = rand::random();
230 credential.encode(&mut buf);
231 let parsed_credential = Credential::<SIZE>::decode(&mut buf)
232 .expect("reading Credential from buffer returned Error");
233 assert_eq!(credential, parsed_credential);
234 }
235
236 #[test]
237 fn test_buffer_parse_underflow() {
238 let mut buf = BytesMut::with_capacity(SIZE);
239 let credential: Credential<SIZE> = rand::random();
240 credential.encode(&mut buf);
241 buf = buf[0..SIZE - 1].into(); Credential::<SIZE>::decode(&mut buf)
243 .expect_err("reading malformed Credential from buffer did not return Error");
244 }
245}