1#[cfg(feature = "fmt")]
5use std::fmt::Write;
6
7use std::marker::PhantomData;
8
9use base64ct::Encoding;
10use bytes::Bytes;
11use serde::{
12 de::{self, DeserializeOwned},
13 ser, Serialize,
14};
15
16#[cfg(feature = "fmt")]
17use super::fmt::{self, IndentWriter};
18
19#[derive(Debug, thiserror::Error)]
21pub enum DecodeError {
22 #[error(transparent)]
24 Base64(#[from] base64ct::Error),
25
26 #[error(transparent)]
28 Json(#[from] serde_json::Error),
29
30 #[error("data is not valid: {0}")]
32 InvalidData(#[source] Box<dyn std::error::Error + Send + Sync>),
33}
34
/// Raw bytes that are (de)serialized as an unpadded base64url string.
///
/// `Debug` is implemented by hand so the bytes print in their encoded form.
#[derive(Hash, Eq, PartialEq, Clone)]
pub struct Base64Data<T>(pub T);
39
40impl<T> Base64Data<T>
41where
42 T: AsRef<[u8]>,
43{
44 pub(crate) fn serialized_value(&self) -> Result<String, serde_json::Error> {
45 Ok(base64ct::Base64UrlUnpadded::encode_string(self.0.as_ref()))
46 }
47}
48
49impl<T> std::fmt::Debug for Base64Data<T>
50where
51 T: AsRef<[u8]>,
52{
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_tuple("Base64Data")
55 .field(&self.serialized_value().unwrap())
56 .finish()
57 }
58}
59
60impl<T> From<T> for Base64Data<T> {
61 fn from(value: T) -> Self {
62 Base64Data(value)
63 }
64}
65
66impl<T> ser::Serialize for Base64Data<T>
67where
68 T: AsRef<[u8]>,
69{
70 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
71 where
72 S: serde::Serializer,
73 {
74 let target = self
75 .serialized_value()
76 .map_err(|err| unreachable!("serialization error: {}", err))?;
77 serializer.serialize_str(&target)
78 }
79}
80
81impl<T> AsRef<[u8]> for Base64Data<T>
82where
83 T: AsRef<[u8]>,
84{
85 fn as_ref(&self) -> &[u8] {
86 self.0.as_ref()
87 }
88}
89
90struct Base64DataVisitor<T>(PhantomData<T>);
91
92impl<'de, T> de::Visitor<'de> for Base64DataVisitor<T>
93where
94 T: for<'a> TryFrom<&'a [u8]>,
95{
96 type Value = Base64Data<T>;
97
98 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
99 formatter.write_str("base64url encoded data")
100 }
101
102 fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
103 where
104 E: de::Error,
105 {
106 let data = base64ct::Base64UrlUnpadded::decode_vec(v)
107 .map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"invalid base64url encoding"))?;
108
109 let realized = T::try_from(data.as_ref())
110 .map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"can't parse internal data"))?;
111 Ok(Base64Data(realized))
112 }
113}
114
115impl<'de, T> de::Deserialize<'de> for Base64Data<T>
116where
117 T: for<'a> TryFrom<&'a [u8]>,
118{
119 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
120 where
121 D: serde::Deserializer<'de>,
122 {
123 deserializer.deserialize_str(Base64DataVisitor(PhantomData))
124 }
125}
126
#[cfg(feature = "fmt")]
impl<T> fmt::JWTFormat for Base64Data<T>
where
    T: AsRef<[u8]>,
{
    // Renders the bytes as quoted unpadded base64url text, tagged with a `b64` prefix.
    fn fmt<W: fmt::Write>(&self, f: &mut IndentWriter<'_, W>) -> fmt::Result {
        let encoded = base64ct::Base64UrlUnpadded::encode_string(self.0.as_ref());
        write!(f, "b64\"{encoded}\"")
    }
}
140
/// A signature value that is serialized as an unpadded base64url string.
///
/// `Debug` is implemented by hand so the signature prints in its encoded form.
#[derive(Hash, Ord, PartialOrd, Eq, PartialEq, Clone)]
pub struct Base64Signature<T>(pub T);
145
146impl<T> Base64Signature<T>
147where
148 T: signature::SignatureEncoding,
149{
150 pub(crate) fn serialized_value(&self) -> Result<String, serde_json::Error> {
151 Ok(base64ct::Base64UrlUnpadded::encode_string(
152 self.0.to_bytes().as_ref(),
153 ))
154 }
155}
156
157impl<T> std::fmt::Debug for Base64Signature<T>
158where
159 T: signature::SignatureEncoding,
160{
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 f.debug_tuple("Base64Signature")
163 .field(&self.serialized_value().unwrap())
164 .finish()
165 }
166}
167
168impl<T> Base64Signature<T>
169where
170 T: TryFrom<Vec<u8>>,
171 T::Error: std::error::Error + Send + Sync + 'static,
172{
173 pub(crate) fn parse(value: &str) -> Result<Self, DecodeError> {
174 let data = base64ct::Base64UrlUnpadded::decode_vec(value)?;
175 let data = T::try_from(data).map_err(|err| DecodeError::InvalidData(err.into()))?;
176 Ok(Base64Signature(data))
177 }
178}
179
180impl<T> From<T> for Base64Signature<T> {
181 fn from(value: T) -> Self {
182 Base64Signature(value)
183 }
184}
185
186impl<T> ser::Serialize for Base64Signature<T>
187where
188 T: signature::SignatureEncoding,
189{
190 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
191 where
192 S: serde::Serializer,
193 {
194 let target = self
195 .serialized_value()
196 .map_err(|err| unreachable!("serialization error: {}", err))?;
197 serializer.serialize_str(&target)
198 }
199}
200
201impl<T> AsRef<[u8]> for Base64Signature<T>
202where
203 T: AsRef<[u8]>,
204{
205 fn as_ref(&self) -> &[u8] {
206 self.0.as_ref()
207 }
208}
209
210struct Base64SignatureVisitor<T>(PhantomData<T>);
211
212impl<'de, T> de::Visitor<'de> for Base64SignatureVisitor<T>
213where
214 T: for<'a> TryFrom<&'a [u8]>,
215{
216 type Value = Base64Signature<T>;
217
218 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
219 formatter.write_str("base64url encoded data")
220 }
221
222 fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
223 where
224 E: de::Error,
225 {
226 let data = base64ct::Base64UrlUnpadded::decode_vec(v)
227 .map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"invalid base64url encoding"))?;
228
229 let realized = T::try_from(data.as_ref())
230 .map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"can't parse internal data"))?;
231 Ok(Base64Signature(realized))
232 }
233}
234
235impl<'de, T> de::Deserialize<'de> for Base64Signature<T>
236where
237 T: for<'a> TryFrom<&'a [u8]>,
238{
239 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
240 where
241 D: serde::Deserializer<'de>,
242 {
243 deserializer.deserialize_str(Base64SignatureVisitor(PhantomData))
244 }
245}
246
#[cfg(feature = "fmt")]
impl<T> fmt::JWTFormat for Base64Signature<T>
where
    T: AsRef<[u8]>,
{
    // Renders the signature as quoted unpadded base64url text with a `b64` prefix.
    fn fmt<W: fmt::Write>(&self, f: &mut IndentWriter<'_, W>) -> fmt::Result {
        let encoded = base64ct::Base64UrlUnpadded::encode_string(self.0.as_ref());
        write!(f, "b64\"{encoded}\"")
    }
}
260
/// A value that is serialized to JSON and then encoded as unpadded base64url.
#[derive(Hash, Ord, PartialOrd, Eq, PartialEq, Clone, Debug)]
pub struct Base64JSON<T>(pub T);

impl<T> Base64JSON<T> {
    /// Wraps `value`; no encoding happens until it is serialized.
    pub fn new(value: T) -> Self {
        Self(value)
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
277
278impl<T> Base64JSON<T>
279where
280 T: Serialize,
281{
282 pub(crate) fn serialized_value(&self) -> Result<String, serde_json::Error> {
283 let inner = serde_json::to_vec(&self.0)?;
284 Ok(base64ct::Base64UrlUnpadded::encode_string(&inner))
285 }
286
287 pub(crate) fn serialized_bytes(&self) -> Result<Bytes, serde_json::Error> {
288 self.serialized_value().map(Bytes::from)
289 }
290}
291
292pub(crate) struct ParsedBase64JSON<T> {
293 pub(crate) data: T,
294 pub(crate) bytes: Bytes,
295}
296
297impl<T> Base64JSON<T>
298where
299 T: DeserializeOwned,
300{
301 pub(crate) fn parse(raw: &str) -> Result<ParsedBase64JSON<T>, DecodeError>
302 where
303 T: de::DeserializeOwned,
304 {
305 let data = base64ct::Base64UrlUnpadded::decode_vec(raw)?;
306 let value = serde_json::from_slice(&data)?;
307 Ok(ParsedBase64JSON {
308 data: value,
309 bytes: Bytes::from(raw.to_owned()),
310 })
311 }
312}
313
314impl<T> AsRef<T> for Base64JSON<T> {
315 fn as_ref(&self) -> &T {
316 &self.0
317 }
318}
319
320impl<T> From<T> for Base64JSON<T> {
321 fn from(value: T) -> Self {
322 Base64JSON(value)
323 }
324}
325
#[cfg(feature = "fmt")]
impl<T> fmt::JWTFormat for Base64JSON<T>
where
    T: Serialize,
{
    // Renders the inner value as JSON (via the writer's `write_json`) wrapped in
    // `base64url(...)`, so the decoded content is readable while still marked as encoded.
    fn fmt<W: fmt::Write>(&self, f: &mut IndentWriter<'_, W>) -> fmt::Result {
        write!(f, "base64url(")?;
        f.write_json(&self.0)?;
        f.write_str(")")
    }
}
337
338struct Base64JSONVisitor<T>(PhantomData<T>);
339
340impl<'de, T> de::Visitor<'de> for Base64JSONVisitor<T>
341where
342 T: de::DeserializeOwned,
343{
344 type Value = Base64JSON<T>;
345
346 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
347 formatter.write_str("a base64url encoded json document")
348 }
349
350 fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
351 where
352 E: de::Error,
353 {
354 let data = base64ct::Base64UrlUnpadded::decode_vec(v)
355 .map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"invalid base64url encoding"))?;
356
357 let data = serde_json::from_slice(&data)
358 .map_err(|err| E::custom(format!("invalid JSON: {err}")))?;
359 Ok(Base64JSON(data))
360 }
361}
362
363impl<'de, T> de::Deserialize<'de> for Base64JSON<T>
364where
365 T: de::DeserializeOwned,
366{
367 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
368 where
369 D: serde::Deserializer<'de>,
370 {
371 deserializer.deserialize_str(Base64JSONVisitor(PhantomData))
372 }
373}
374
375impl<T> ser::Serialize for Base64JSON<T>
376where
377 T: ser::Serialize,
378{
379 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
380 where
381 S: serde::Serializer,
382 {
383 use serde::ser::Error;
384 let inner = self
385 .serialized_value()
386 .map_err(|err| S::Error::custom(format!("Error producing inner JSON: {err}")))?;
387 serializer.serialize_str(&inner)
388 }
389}
390
#[cfg(test)]
mod test {
    // Round-trip checks: each wrapper serializes to the expected base64url JSON string
    // and deserializes back to an equal value.
    use serde_json::{json, Value};

    use super::*;
    use crate::algorithms::SignatureBytes;

    #[test]
    fn test_base64_data() {
        let original = Base64Data::from(vec![1, 2, 3, 4]);

        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, r#""AQIDBA""#);

        let decoded: Base64Data<Vec<u8>> = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_base64_signature() {
        let original = Base64Signature::from(SignatureBytes::from(vec![1, 2, 3, 4]));

        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, r#""AQIDBA""#);

        let decoded: Base64Signature<SignatureBytes> = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_base64_json() {
        let original = Base64JSON::from(json!({"foo": "bar"}));

        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, r#""eyJmb28iOiJiYXIifQ""#);

        let decoded: Base64JSON<Value> = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }
}
425}