jaws/
base64data.rs

1//! Base64 Data tools, which interact well with serde and are useful
2//! for building data structures for JWTs.
3
4#[cfg(feature = "fmt")]
5use std::fmt::Write;
6
7use std::marker::PhantomData;
8
9use base64ct::Encoding;
10use bytes::Bytes;
11use serde::{
12    de::{self, DeserializeOwned},
13    ser, Serialize,
14};
15
16#[cfg(feature = "fmt")]
17use super::fmt::{self, IndentWriter};
18
19/// Error type for decoding base64 data in wrappers.
20#[derive(Debug, thiserror::Error)]
21pub enum DecodeError {
22    /// The data being decoded is not base64
23    #[error(transparent)]
24    Base64(#[from] base64ct::Error),
25
26    /// The data being decoded is not valid JSON
27    #[error(transparent)]
28    Json(#[from] serde_json::Error),
29
30    /// The data being decoded is not valid for another reason.
31    #[error("data is not valid: {0}")]
32    InvalidData(#[source] Box<dyn std::error::Error + Send + Sync>),
33}
34
35/// Wrapper type for types which implement AsRef<[u8]> to indicate that
36/// they should serialize as bytes with a Base64 URL-safe encoding.
37#[derive(Clone, PartialEq, Eq, Hash)]
38pub struct Base64Data<T>(pub T);
39
40impl<T> Base64Data<T>
41where
42    T: AsRef<[u8]>,
43{
44    pub(crate) fn serialized_value(&self) -> Result<String, serde_json::Error> {
45        Ok(base64ct::Base64UrlUnpadded::encode_string(self.0.as_ref()))
46    }
47}
48
49impl<T> std::fmt::Debug for Base64Data<T>
50where
51    T: AsRef<[u8]>,
52{
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_tuple("Base64Data")
55            .field(&self.serialized_value().unwrap())
56            .finish()
57    }
58}
59
60impl<T> From<T> for Base64Data<T> {
61    fn from(value: T) -> Self {
62        Base64Data(value)
63    }
64}
65
66impl<T> ser::Serialize for Base64Data<T>
67where
68    T: AsRef<[u8]>,
69{
70    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
71    where
72        S: serde::Serializer,
73    {
74        let target = self
75            .serialized_value()
76            .map_err(|err| unreachable!("serialization error: {}", err))?;
77        serializer.serialize_str(&target)
78    }
79}
80
81impl<T> AsRef<[u8]> for Base64Data<T>
82where
83    T: AsRef<[u8]>,
84{
85    fn as_ref(&self) -> &[u8] {
86        self.0.as_ref()
87    }
88}
89
90struct Base64DataVisitor<T>(PhantomData<T>);
91
92impl<'de, T> de::Visitor<'de> for Base64DataVisitor<T>
93where
94    T: for<'a> TryFrom<&'a [u8]>,
95{
96    type Value = Base64Data<T>;
97
98    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
99        formatter.write_str("base64url encoded data")
100    }
101
102    fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
103    where
104        E: de::Error,
105    {
106        let data = base64ct::Base64UrlUnpadded::decode_vec(v)
107            .map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"invalid base64url encoding"))?;
108
109        let realized = T::try_from(data.as_ref())
110            .map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"can't parse internal data"))?;
111        Ok(Base64Data(realized))
112    }
113}
114
115impl<'de, T> de::Deserialize<'de> for Base64Data<T>
116where
117    T: for<'a> TryFrom<&'a [u8]>,
118{
119    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
120    where
121        D: serde::Deserializer<'de>,
122    {
123        deserializer.deserialize_str(Base64DataVisitor(PhantomData))
124    }
125}
126
127#[cfg(feature = "fmt")]
128impl<T> fmt::JWTFormat for Base64Data<T>
129where
130    T: AsRef<[u8]>,
131{
132    fn fmt<W: fmt::Write>(&self, f: &mut IndentWriter<'_, W>) -> fmt::Result {
133        write!(
134            f,
135            "b64\"{}\"",
136            base64ct::Base64UrlUnpadded::encode_string(self.0.as_ref())
137        )
138    }
139}
140
141/// Wrapper type to indicate that the inner type should be serialized
142/// as bytes with a Base64 URL-safe encoding.
143#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
144pub struct Base64Signature<T>(pub T);
145
146impl<T> Base64Signature<T>
147where
148    T: signature::SignatureEncoding,
149{
150    pub(crate) fn serialized_value(&self) -> Result<String, serde_json::Error> {
151        Ok(base64ct::Base64UrlUnpadded::encode_string(
152            self.0.to_bytes().as_ref(),
153        ))
154    }
155}
156
157impl<T> std::fmt::Debug for Base64Signature<T>
158where
159    T: signature::SignatureEncoding,
160{
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        f.debug_tuple("Base64Signature")
163            .field(&self.serialized_value().unwrap())
164            .finish()
165    }
166}
167
168impl<T> Base64Signature<T>
169where
170    T: TryFrom<Vec<u8>>,
171    T::Error: std::error::Error + Send + Sync + 'static,
172{
173    pub(crate) fn parse(value: &str) -> Result<Self, DecodeError> {
174        let data = base64ct::Base64UrlUnpadded::decode_vec(value)?;
175        let data = T::try_from(data).map_err(|err| DecodeError::InvalidData(err.into()))?;
176        Ok(Base64Signature(data))
177    }
178}
179
180impl<T> From<T> for Base64Signature<T> {
181    fn from(value: T) -> Self {
182        Base64Signature(value)
183    }
184}
185
186impl<T> ser::Serialize for Base64Signature<T>
187where
188    T: signature::SignatureEncoding,
189{
190    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
191    where
192        S: serde::Serializer,
193    {
194        let target = self
195            .serialized_value()
196            .map_err(|err| unreachable!("serialization error: {}", err))?;
197        serializer.serialize_str(&target)
198    }
199}
200
201impl<T> AsRef<[u8]> for Base64Signature<T>
202where
203    T: AsRef<[u8]>,
204{
205    fn as_ref(&self) -> &[u8] {
206        self.0.as_ref()
207    }
208}
209
210struct Base64SignatureVisitor<T>(PhantomData<T>);
211
212impl<'de, T> de::Visitor<'de> for Base64SignatureVisitor<T>
213where
214    T: for<'a> TryFrom<&'a [u8]>,
215{
216    type Value = Base64Signature<T>;
217
218    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
219        formatter.write_str("base64url encoded data")
220    }
221
222    fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
223    where
224        E: de::Error,
225    {
226        let data = base64ct::Base64UrlUnpadded::decode_vec(v)
227            .map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"invalid base64url encoding"))?;
228
229        let realized = T::try_from(data.as_ref())
230            .map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"can't parse internal data"))?;
231        Ok(Base64Signature(realized))
232    }
233}
234
235impl<'de, T> de::Deserialize<'de> for Base64Signature<T>
236where
237    T: for<'a> TryFrom<&'a [u8]>,
238{
239    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
240    where
241        D: serde::Deserializer<'de>,
242    {
243        deserializer.deserialize_str(Base64SignatureVisitor(PhantomData))
244    }
245}
246
247#[cfg(feature = "fmt")]
248impl<T> fmt::JWTFormat for Base64Signature<T>
249where
250    T: AsRef<[u8]>,
251{
252    fn fmt<W: fmt::Write>(&self, f: &mut IndentWriter<'_, W>) -> fmt::Result {
253        write!(
254            f,
255            "b64\"{}\"",
256            base64ct::Base64UrlUnpadded::encode_string(self.0.as_ref())
257        )
258    }
259}
260
261/// Wrapper type to indicate that the inner type should be serialized
262/// as JSON and then Base64 URL-safe encoded and serialized as a string.
263#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
264pub struct Base64JSON<T>(pub T);
265
266impl<T> Base64JSON<T> {
267    /// Create a new Base64JSON wrapper.
268    pub fn new(value: T) -> Self {
269        Base64JSON(value)
270    }
271
272    /// Consume the wrapper and return the inner value.
273    pub fn into_inner(self) -> T {
274        self.0
275    }
276}
277
278impl<T> Base64JSON<T>
279where
280    T: Serialize,
281{
282    pub(crate) fn serialized_value(&self) -> Result<String, serde_json::Error> {
283        let inner = serde_json::to_vec(&self.0)?;
284        Ok(base64ct::Base64UrlUnpadded::encode_string(&inner))
285    }
286
287    pub(crate) fn serialized_bytes(&self) -> Result<Bytes, serde_json::Error> {
288        self.serialized_value().map(Bytes::from)
289    }
290}
291
292pub(crate) struct ParsedBase64JSON<T> {
293    pub(crate) data: T,
294    pub(crate) bytes: Bytes,
295}
296
297impl<T> Base64JSON<T>
298where
299    T: DeserializeOwned,
300{
301    pub(crate) fn parse(raw: &str) -> Result<ParsedBase64JSON<T>, DecodeError>
302    where
303        T: de::DeserializeOwned,
304    {
305        let data = base64ct::Base64UrlUnpadded::decode_vec(raw)?;
306        let value = serde_json::from_slice(&data)?;
307        Ok(ParsedBase64JSON {
308            data: value,
309            bytes: Bytes::from(raw.to_owned()),
310        })
311    }
312}
313
314impl<T> AsRef<T> for Base64JSON<T> {
315    fn as_ref(&self) -> &T {
316        &self.0
317    }
318}
319
320impl<T> From<T> for Base64JSON<T> {
321    fn from(value: T) -> Self {
322        Base64JSON(value)
323    }
324}
325
326#[cfg(feature = "fmt")]
327impl<T> fmt::JWTFormat for Base64JSON<T>
328where
329    T: Serialize,
330{
331    fn fmt<W: fmt::Write>(&self, f: &mut IndentWriter<'_, W>) -> fmt::Result {
332        write!(f, "base64url(")?;
333        f.write_json(&self.0)?;
334        f.write_str(")")
335    }
336}
337
338struct Base64JSONVisitor<T>(PhantomData<T>);
339
340impl<'de, T> de::Visitor<'de> for Base64JSONVisitor<T>
341where
342    T: de::DeserializeOwned,
343{
344    type Value = Base64JSON<T>;
345
346    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
347        formatter.write_str("a base64url encoded json document")
348    }
349
350    fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
351    where
352        E: de::Error,
353    {
354        let data = base64ct::Base64UrlUnpadded::decode_vec(v)
355            .map_err(|_| E::invalid_value(de::Unexpected::Str(v), &"invalid base64url encoding"))?;
356
357        let data = serde_json::from_slice(&data)
358            .map_err(|err| E::custom(format!("invalid JSON: {err}")))?;
359        Ok(Base64JSON(data))
360    }
361}
362
363impl<'de, T> de::Deserialize<'de> for Base64JSON<T>
364where
365    T: de::DeserializeOwned,
366{
367    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
368    where
369        D: serde::Deserializer<'de>,
370    {
371        deserializer.deserialize_str(Base64JSONVisitor(PhantomData))
372    }
373}
374
375impl<T> ser::Serialize for Base64JSON<T>
376where
377    T: ser::Serialize,
378{
379    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
380    where
381        S: serde::Serializer,
382    {
383        use serde::ser::Error;
384        let inner = self
385            .serialized_value()
386            .map_err(|err| S::Error::custom(format!("Error producing inner JSON: {err}")))?;
387        serializer.serialize_str(&inner)
388    }
389}
390
391#[cfg(test)]
392mod test {
393    use serde_json::{json, Value};
394
395    use super::*;
396    use crate::algorithms::SignatureBytes;
397
398    #[test]
399    fn test_base64_data() {
400        let data = Base64Data::from(vec![1, 2, 3, 4]);
401        let serialized = serde_json::to_string(&data).unwrap();
402        assert_eq!(serialized, r#""AQIDBA""#);
403        let deserialized: Base64Data<Vec<u8>> = serde_json::from_str(&serialized).unwrap();
404        assert_eq!(deserialized, data);
405    }
406
407    #[test]
408    fn test_base64_signature() {
409        let data = Base64Signature::from(SignatureBytes::from(vec![1, 2, 3, 4]));
410        let serialized = serde_json::to_string(&data).unwrap();
411        assert_eq!(serialized, r#""AQIDBA""#);
412        let deserialized: Base64Signature<SignatureBytes> =
413            serde_json::from_str(&serialized).unwrap();
414        assert_eq!(deserialized, data);
415    }
416
417    #[test]
418    fn test_base64_json() {
419        let data = Base64JSON::from(json!({"foo": "bar"}));
420        let serialized = serde_json::to_string(&data).unwrap();
421        assert_eq!(serialized, r#""eyJmb28iOiJiYXIifQ""#);
422        let deserialized: Base64JSON<Value> = serde_json::from_str(&serialized).unwrap();
423        assert_eq!(deserialized, data);
424    }
425}