fire_crypto/token/mod.rs

1#[cfg(feature = "b64")]
2use crate::error::DecodeError;
3use crate::error::TryFromError;
4
5use std::convert::{TryFrom, TryInto};
6use std::fmt;
7
8use rand::rngs::OsRng;
9use rand::RngCore;
10
11#[cfg(feature = "b64")]
12use base64::engine::general_purpose::URL_SAFE_NO_PAD;
13#[cfg(feature = "b64")]
14use base64::Engine;
15
/// A fixed-size token of `S` bytes.
///
/// [`Token::new`] fills the bytes from the operating system's random number
/// generator. The `From` / `TryFrom` conversions wrap caller-supplied bytes
/// unchanged, so a `Token` is only random when it was created via `new`.
#[derive(Clone, Eq, Hash, PartialEq)]
pub struct Token<const S: usize> {
	/// The raw token bytes.
	bytes: [u8; S],
}
21
22impl<const S: usize> Token<S> {
23	pub const LEN: usize = S;
24
25	pub const STR_LEN: usize = crate::calculate_b64_len(S);
26
27	/// Creates a new random Token
28	pub fn new() -> Self {
29		let mut bytes = [0u8; S];
30
31		OsRng.fill_bytes(&mut bytes);
32
33		Self { bytes }
34	}
35
36	/// ## Panics
37	/// if the slice is not `S` bytes long.
38	pub fn from_slice(slice: &[u8]) -> Self {
39		slice.try_into().unwrap()
40	}
41
42	pub fn to_bytes(&self) -> [u8; S] {
43		self.bytes
44	}
45}
46
47#[cfg(not(feature = "b64"))]
48impl<const S: usize> fmt::Debug for Token<S> {
49	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50		f.debug_tuple("Token").field(&self.as_ref()).finish()
51	}
52}
53
#[cfg(feature = "b64")]
impl<const S: usize> fmt::Debug for Token<S> {
	// Shows the token as `Token("<base64>")`, reusing the `Display` encoding.
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let encoded = self.to_string();
		let mut tuple = f.debug_tuple("Token");
		tuple.field(&encoded);
		tuple.finish()
	}
}
60
#[cfg(feature = "b64")]
impl<const S: usize> fmt::Display for Token<S> {
	// Writes the bytes as URL-safe base64 without padding directly into the
	// formatter; no intermediate `String` is built.
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let encoded = base64::display::Base64Display::new(
			&self.bytes[..],
			&URL_SAFE_NO_PAD,
		);
		fmt::Display::fmt(&encoded, f)
	}
}
68
69impl<const S: usize> From<[u8; S]> for Token<S> {
70	fn from(bytes: [u8; S]) -> Self {
71		Self { bytes }
72	}
73}
74
75impl<const S: usize> TryFrom<&[u8]> for Token<S> {
76	type Error = TryFromError;
77
78	fn try_from(v: &[u8]) -> Result<Self, Self::Error> {
79		<[u8; S]>::try_from(v)
80			.map_err(TryFromError::from_any)
81			.map(Self::from)
82	}
83}
84
#[cfg(feature = "b64")]
impl<const S: usize> crate::FromStr for Token<S> {
	type Err = DecodeError;

	// Parses the URL-safe, unpadded base64 form produced by `Display`.
	// Returns `DecodeError::InvalidLength` unless `s` is exactly
	// `Self::STR_LEN` bytes long; base64 failures are wrapped through
	// `DecodeError::inv_bytes`.
	fn from_str(s: &str) -> Result<Self, Self::Err> {
		// `decode_slice_unchecked` does not verify that the output buffer is
		// large enough, so this guard is what keeps `decoded` big enough
		// (assuming `calculate_b64_len(S)` is the exact unpadded length).
		if s.len() != Self::STR_LEN {
			return Err(DecodeError::InvalidLength);
		}

		let mut decoded = [0u8; S];
		match URL_SAFE_NO_PAD.decode_slice_unchecked(s, &mut decoded) {
			Ok(_) => Ok(Self::from(decoded)),
			Err(e) => Err(DecodeError::inv_bytes(e)),
		}
	}
}
101
102impl<const S: usize> AsRef<[u8]> for Token<S> {
103	fn as_ref(&self) -> &[u8] {
104		&self.bytes
105	}
106}
107
#[cfg(all(feature = "b64", feature = "serde"))]
mod impl_serde {
	use super::*;

	use std::borrow::Cow;
	use std::str::FromStr;

	use _serde::de::Error;
	use _serde::{Deserialize, Deserializer, Serialize, Serializer};

	// Serializes a token as its base64 `Display` string.
	impl<const SI: usize> Serialize for Token<SI> {
		fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
			serializer.collect_str(self)
		}
	}

	// Deserializes a token from a string and parses it with `FromStr`;
	// parse failures are reported through the deserializer's custom error.
	impl<'de, const S: usize> Deserialize<'de> for Token<S> {
		fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
			let encoded: Cow<'_, str> = Deserialize::deserialize(deserializer)?;
			match Self::from_str(&encoded) {
				Ok(token) => Ok(token),
				Err(err) => Err(D::Error::custom(err)),
			}
		}
	}
}
137
#[cfg(feature = "protobuf")]
mod protobuf {
	use super::*;

	use fire_protobuf::{
		bytes::BytesWrite,
		decode::{DecodeError, DecodeMessage, FieldKind},
		encode::{
			EncodeError, EncodeMessage, FieldOpt, MessageEncoder, SizeBuilder,
		},
		WireType,
	};

	// On the wire a token is handled like its underlying `[u8; SI]` array:
	// sizing, encoding and merging all forward to the array's impls.
	impl<const SI: usize> EncodeMessage for Token<SI> {
		const WIRE_TYPE: WireType = WireType::Len;

		// Always `false`; presumably this makes the encoder emit the field
		// even for an all-zero token — confirm against fire_protobuf.
		fn is_default(&self) -> bool {
			false
		}

		fn encoded_size(
			&mut self,
			field: Option<FieldOpt>,
			builder: &mut SizeBuilder,
		) -> Result<(), EncodeError> {
			self.bytes.encoded_size(field, builder)
		}

		fn encode<B: BytesWrite>(
			&mut self,
			field: Option<FieldOpt>,
			encoder: &mut MessageEncoder<B>,
		) -> Result<(), EncodeError> {
			self.bytes.encode(field, encoder)
		}
	}

	impl<'m, const SI: usize> DecodeMessage<'m> for Token<SI> {
		const WIRE_TYPE: WireType = WireType::Len;

		// The starting value before any field is merged: an all-zero token.
		fn decode_default() -> Self {
			Self { bytes: [0u8; SI] }
		}

		fn merge(
			&mut self,
			kind: FieldKind<'m>,
			is_field: bool,
		) -> Result<(), DecodeError> {
			self.bytes.merge(kind, is_field)
		}
	}
}
194
#[cfg(all(feature = "b64", feature = "postgres"))]
mod impl_postgres {
	use super::*;

	use bytes::BytesMut;
	use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};

	// Tokens are written as their base64 `Display` string, so they fit any
	// column type that `&str` accepts.
	impl<const SI: usize> ToSql for Token<SI> {
		fn to_sql(
			&self,
			ty: &Type,
			out: &mut BytesMut,
		) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
			let encoded = self.to_string();
			encoded.to_sql(ty, out)
		}

		fn accepts(ty: &Type) -> bool {
			<&str as ToSql>::accepts(ty)
		}

		to_sql_checked!();
	}

	// Reads the column as text and parses it with `FromStr`; a parse error is
	// boxed into the driver's error type by `?`.
	impl<'r, const SI: usize> FromSql<'r> for Token<SI> {
		fn from_sql(
			ty: &Type,
			raw: &'r [u8],
		) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
			let text = <&str as FromSql>::from_sql(ty, raw)?;
			Ok(text.parse::<Self>()?)
		}

		fn accepts(ty: &Type) -> bool {
			<&str as FromSql>::accepts(ty)
		}
	}
}
238
#[cfg(all(test, feature = "b64"))]
mod tests {
	use super::*;

	use std::str::FromStr;

	/// Checks, for a random `Token<S>`, that the string and byte conversions
	/// round-trip exactly and that malformed input is rejected.
	fn roundtrip<const S: usize>() {
		let tok = Token::<S>::new();

		let b64 = tok.to_string();
		// `STR_LEN` must describe the actual encoded length.
		assert_eq!(b64.len(), Token::<S>::STR_LEN);

		// Parsing must reproduce the original bytes, not merely a string
		// that happens to re-encode identically.
		let tok_2 = Token::<S>::from_str(&b64).unwrap();
		assert_eq!(tok, tok_2);
		assert_eq!(b64, tok_2.to_string());

		// Byte-level conversions round-trip as well.
		assert_eq!(Token::<S>::from(tok.to_bytes()), tok);
		assert_eq!(Token::<S>::from_slice(tok.as_ref()), tok);

		// Wrong-length input is rejected with an error instead of panicking.
		assert!(Token::<S>::from_str(&b64[1..]).is_err());
		assert!(Token::<S>::from_str(&format!("{}A", b64)).is_err());
		assert!(Token::<S>::try_from(vec![0u8; S + 1].as_slice()).is_err());

		// Correct length, but characters outside the base64 alphabet.
		assert!(Token::<S>::from_str(&"!".repeat(Token::<S>::STR_LEN)).is_err());
	}

	#[test]
	fn test_b64() {
		roundtrip::<1>();
		roundtrip::<2>();
		roundtrip::<3>();
		roundtrip::<13>();
		roundtrip::<24>();
		roundtrip::<200>();
		roundtrip::<213>();
	}
}