1#[cfg(feature = "b64")]
2use crate::error::DecodeError;
3use crate::error::TryFromError;
4
5use std::convert::{TryFrom, TryInto};
6use std::fmt;
7
8use rand::rngs::OsRng;
9use rand::RngCore;
10
11#[cfg(feature = "b64")]
12use base64::engine::general_purpose::URL_SAFE_NO_PAD;
13#[cfg(feature = "b64")]
14use base64::Engine;
15
/// A fixed-size token holding exactly `S` bytes.
///
/// The bytes are stored inline in a `[u8; S]` array. Cloning, equality and
/// hashing are derived, so two tokens compare equal exactly when all of their
/// bytes match.
///
/// NOTE(review): the derived `PartialEq` is not guaranteed to run in constant
/// time; if tokens are compared against attacker-supplied values, consider a
/// constant-time comparison instead.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct Token<const S: usize> {
    /// The raw token bytes.
    bytes: [u8; S],
}
21
22impl<const S: usize> Token<S> {
23 pub const LEN: usize = S;
24
25 pub const STR_LEN: usize = crate::calculate_b64_len(S);
26
27 pub fn new() -> Self {
29 let mut bytes = [0u8; S];
30
31 OsRng.fill_bytes(&mut bytes);
32
33 Self { bytes }
34 }
35
36 pub fn from_slice(slice: &[u8]) -> Self {
39 slice.try_into().unwrap()
40 }
41
42 pub fn to_bytes(&self) -> [u8; S] {
43 self.bytes
44 }
45}
46
47#[cfg(not(feature = "b64"))]
48impl<const S: usize> fmt::Debug for Token<S> {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.debug_tuple("Token").field(&self.as_ref()).finish()
51 }
52}
53
#[cfg(feature = "b64")]
impl<const S: usize> fmt::Debug for Token<S> {
    /// Formats as `Token("...")`, quoting the token's `Display` string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = self.to_string();
        let mut tuple = f.debug_tuple("Token");
        tuple.field(&encoded);
        tuple.finish()
    }
}
60
#[cfg(feature = "b64")]
impl<const S: usize> fmt::Display for Token<S> {
    /// Writes the token's bytes using the `URL_SAFE_NO_PAD` base64 engine
    /// (URL-safe alphabet, no `=` padding).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded =
            base64::display::Base64Display::new(&self.bytes[..], &URL_SAFE_NO_PAD);
        fmt::Display::fmt(&encoded, f)
    }
}
68
69impl<const S: usize> From<[u8; S]> for Token<S> {
70 fn from(bytes: [u8; S]) -> Self {
71 Self { bytes }
72 }
73}
74
75impl<const S: usize> TryFrom<&[u8]> for Token<S> {
76 type Error = TryFromError;
77
78 fn try_from(v: &[u8]) -> Result<Self, Self::Error> {
79 <[u8; S]>::try_from(v)
80 .map_err(TryFromError::from_any)
81 .map(Self::from)
82 }
83}
84
#[cfg(feature = "b64")]
impl<const S: usize> crate::FromStr for Token<S> {
    type Err = DecodeError;

    /// Parses a URL-safe, unpadded base64 string into a token.
    ///
    /// # Errors
    ///
    /// - `DecodeError::InvalidLength` if `s` is not exactly `Self::STR_LEN`
    ///   bytes long (checked before any decoding happens).
    /// - The error produced by `DecodeError::inv_bytes` if base64 decoding
    ///   fails.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `STR_LEN` is `crate::calculate_b64_len(S)`, the only accepted length.
        if s.len() != Self::STR_LEN {
            return Err(DecodeError::InvalidLength);
        }

        let mut decoded = [0u8; S];
        match URL_SAFE_NO_PAD.decode_slice_unchecked(s, &mut decoded) {
            Ok(_) => Ok(Self::from(decoded)),
            Err(err) => Err(DecodeError::inv_bytes(err)),
        }
    }
}
101
102impl<const S: usize> AsRef<[u8]> for Token<S> {
103 fn as_ref(&self) -> &[u8] {
104 &self.bytes
105 }
106}
107
#[cfg(all(feature = "b64", feature = "serde"))]
mod impl_serde {
    use super::*;

    use std::str::FromStr;

    use _serde::de::Error;
    use _serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Serializes a token as the string produced by its `Display` impl.
    impl<const SI: usize> Serialize for Token<SI> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            serializer.collect_str(self)
        }
    }

    /// Deserializes a token from a string, parsed with `FromStr`; parse
    /// failures are reported through `D::Error::custom`.
    impl<'de, const S: usize> Deserialize<'de> for Token<S> {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            // serde's `Deserialize` impl for `Cow<str>` always produces an
            // owned `String`, so deserializing a `String` directly is the
            // same operation without implying a zero-copy borrow.
            let encoded = String::deserialize(deserializer)?;
            Self::from_str(&encoded).map_err(D::Error::custom)
        }
    }
}
137
#[cfg(feature = "protobuf")]
mod protobuf {
    use super::*;

    use fire_protobuf::{
        bytes::BytesWrite,
        decode::{DecodeError, DecodeMessage, FieldKind},
        encode::{
            EncodeError, EncodeMessage, FieldOpt, MessageEncoder, SizeBuilder,
        },
        WireType,
    };

    // A token is encoded and decoded as a length-delimited value; the size,
    // encode and merge steps all forward to the `[u8; SI]` implementations.

    impl<const SI: usize> EncodeMessage for Token<SI> {
        const WIRE_TYPE: WireType = WireType::Len;

        // Always `false`, even for an all-zero token.
        // NOTE(review): presumably this makes the field always be emitted —
        // confirm against fire_protobuf's use of `is_default`.
        fn is_default(&self) -> bool {
            false
        }

        fn encoded_size(
            &mut self,
            field: Option<FieldOpt>,
            builder: &mut SizeBuilder,
        ) -> Result<(), EncodeError> {
            // Delegate to the raw byte array.
            self.bytes.encoded_size(field, builder)
        }

        fn encode<B>(
            &mut self,
            field: Option<FieldOpt>,
            encoder: &mut MessageEncoder<B>,
        ) -> Result<(), EncodeError>
        where
            B: BytesWrite,
        {
            // Delegate to the raw byte array.
            self.bytes.encode(field, encoder)
        }
    }

    impl<'m, const SI: usize> DecodeMessage<'m> for Token<SI> {
        const WIRE_TYPE: WireType = WireType::Len;

        // Returns an all-zero token.
        fn decode_default() -> Self {
            Self::from([0u8; SI])
        }

        fn merge(
            &mut self,
            kind: FieldKind<'m>,
            is_field: bool,
        ) -> Result<(), DecodeError> {
            // Delegate to the raw byte array.
            self.bytes.merge(kind, is_field)
        }
    }
}
194
#[cfg(all(feature = "b64", feature = "postgres"))]
mod impl_postgres {
    use super::*;

    use bytes::BytesMut;
    use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};

    /// Writes a token to Postgres as its `Display` string, accepting the same
    /// column types that `&str` accepts.
    impl<const SI: usize> ToSql for Token<SI> {
        fn to_sql(
            &self,
            ty: &Type,
            out: &mut BytesMut,
        ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
        where
            Self: Sized,
        {
            let encoded = self.to_string();
            encoded.to_sql(ty, out)
        }

        fn accepts(ty: &Type) -> bool
        where
            Self: Sized,
        {
            <&str as ToSql>::accepts(ty)
        }

        to_sql_checked!();
    }

    /// Reads a token from any column type `&str` accepts, parsing the text
    /// with `str::parse`; a parse error is boxed and returned.
    impl<'r, const SI: usize> FromSql<'r> for Token<SI> {
        fn from_sql(
            ty: &Type,
            raw: &'r [u8],
        ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
            let text = <&str as FromSql>::from_sql(ty, raw)?;
            match text.parse::<Self>() {
                Ok(token) => Ok(token),
                Err(err) => Err(err.into()),
            }
        }

        fn accepts(ty: &Type) -> bool {
            <&str as FromSql>::accepts(ty)
        }
    }
}
238
#[cfg(all(test, feature = "b64"))]
mod tests {
    use super::*;

    use std::str::FromStr;

    /// Round-trips a fresh random `Token<S>` through its base64 form and
    /// checks the encoded length, the decoded bytes, the re-encoding, and
    /// that a truncated string is rejected.
    fn b64<const S: usize>() {
        let tok = Token::<S>::new();

        let b64 = tok.to_string();
        assert_eq!(b64.len(), Token::<S>::STR_LEN);

        let tok_2 = Token::<S>::from_str(&b64).unwrap();
        assert_eq!(tok, tok_2);
        assert_eq!(b64, tok_2.to_string());

        // Dropping one (ASCII) character must trip the length check.
        assert!(matches!(
            Token::<S>::from_str(&b64[1..]),
            Err(DecodeError::InvalidLength)
        ));
    }

    #[test]
    fn test_b64() {
        b64::<1>();
        b64::<2>();
        b64::<3>();
        b64::<13>();
        b64::<24>();
        b64::<200>();
        b64::<213>();
    }
}