passkey_types/utils/bytes.rs

1use std::ops::{Deref, DerefMut};
2
3use serde::{de::Visitor, Deserialize, Deserializer, Serialize};
4use typeshare::typeshare;
5
6use super::encoding;
7
8/// A newtype around `Vec<u8>` which serializes using the transport format's byte representation.
9///
10/// When feature `serialize_bytes_as_base64_string` is set, this type will be serialized into a
11/// `base64url` representation instead. Note that this type should not be used externally when this
12/// feature is set, such as in Kotlin, to avoid a serialization errors. In the future, this feature
13/// flag can be removed when typeshare supports target/language specific serialization:
14/// <https://github.com/1Password/typeshare/issues/63>
15///
16/// This will use an array of numbers for JSON, and a byte string in CBOR for example.
17///
18/// It also supports deserializing from `base64` and `base64url` formatted strings.
19#[typeshare(transparent)]
20#[derive(Debug, Default, PartialEq, Eq, Clone, Hash)]
21#[repr(transparent)]
22pub struct Bytes(Vec<u8>);
23
24impl Deref for Bytes {
25    type Target = Vec<u8>;
26
27    fn deref(&self) -> &Self::Target {
28        &self.0
29    }
30}
31
32impl DerefMut for Bytes {
33    fn deref_mut(&mut self) -> &mut Self::Target {
34        &mut self.0
35    }
36}
37
38impl From<Vec<u8>> for Bytes {
39    fn from(inner: Vec<u8>) -> Self {
40        Bytes(inner)
41    }
42}
43
44impl From<&[u8]> for Bytes {
45    fn from(value: &[u8]) -> Self {
46        Bytes(value.to_vec())
47    }
48}
49
50impl From<Bytes> for Vec<u8> {
51    fn from(src: Bytes) -> Self {
52        src.0
53    }
54}
55
56impl From<Bytes> for String {
57    fn from(src: Bytes) -> Self {
58        encoding::base64url(&src)
59    }
60}
61
/// Error returned when a string given for decoding is neither `base64url` nor `base64` encoded
/// data.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be reported, boxed as
/// `dyn Error`, and propagated with `?` like any other error type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NotBase64Encoded;

impl std::fmt::Display for NotBase64Encoded {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("the string is neither base64url nor base64 encoded data")
    }
}

impl std::error::Error for NotBase64Encoded {}
65
66impl TryFrom<&str> for Bytes {
67    type Error = NotBase64Encoded;
68
69    fn try_from(value: &str) -> Result<Self, Self::Error> {
70        encoding::try_from_base64url(value)
71            .or_else(|| encoding::try_from_base64(value))
72            .ok_or(NotBase64Encoded)
73            .map(Self)
74    }
75}
76
77impl FromIterator<u8> for Bytes {
78    fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
79        Bytes(iter.into_iter().collect())
80    }
81}
82
83impl IntoIterator for Bytes {
84    type Item = u8;
85
86    type IntoIter = std::vec::IntoIter<u8>;
87
88    fn into_iter(self) -> Self::IntoIter {
89        self.0.into_iter()
90    }
91}
92
93impl<'a> IntoIterator for &'a Bytes {
94    type Item = &'a u8;
95
96    type IntoIter = std::slice::Iter<'a, u8>;
97
98    fn into_iter(self) -> Self::IntoIter {
99        self.0.iter()
100    }
101}
102
103impl Serialize for Bytes {
104    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105    where
106        S: serde::Serializer,
107    {
108        if cfg!(feature = "serialize_bytes_as_base64_string") {
109            serializer.serialize_str(&encoding::base64url(&self.0))
110        } else {
111            serializer.serialize_bytes(&self.0)
112        }
113    }
114}
115
116impl<'de> Deserialize<'de> for Bytes {
117    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
118    where
119        D: Deserializer<'de>,
120    {
121        struct Base64Visitor;
122
123        impl<'de> Visitor<'de> for Base64Visitor {
124            type Value = Bytes;
125
126            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
127                write!(f, "A vector of bytes or a base46(url) encoded string")
128            }
129            fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
130            where
131                E: serde::de::Error,
132            {
133                self.visit_str(v)
134            }
135            fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
136            where
137                E: serde::de::Error,
138            {
139                self.visit_str(&v)
140            }
141            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
142            where
143                E: serde::de::Error,
144            {
145                v.try_into().map_err(|_| {
146                    E::invalid_value(
147                        serde::de::Unexpected::Str(v),
148                        &"A base64(url) encoded string",
149                    )
150                })
151            }
152            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
153            where
154                A: serde::de::SeqAccess<'de>,
155            {
156                let mut buf = Vec::with_capacity(seq.size_hint().unwrap_or_default());
157                while let Some(byte) = seq.next_element()? {
158                    buf.push(byte);
159                }
160                Ok(Bytes(buf))
161            }
162            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
163            where
164                E: serde::de::Error,
165            {
166                Ok(Bytes(v.to_vec()))
167            }
168        }
169        deserializer.deserialize_any(Base64Visitor)
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn deserialize_many_formats_into_base64url_vec() {
        let json = r#"{
            "array": [101,195,212,161,191,112,75,189,152,52,121,17,62,113,114,164],
            "base64url": "ZcPUob9wS72YNHkRPnFypA",
            "base64": "ZcPUob9wS72YNHkRPnFypA=="
        }"#;

        let deserialized: HashMap<&str, Bytes> =
            serde_json::from_str(json).expect("failed to deserialize");

        assert_eq!(deserialized["array"], deserialized["base64url"]);
        assert_eq!(deserialized["base64url"], deserialized["base64"]);
    }

    #[test]
    fn deserialization_should_fail() {
        // A sequence must contain byte values. Base64 strings are only accepted in place of
        // the whole value, not as individual elements. The JSON must otherwise be well-formed,
        // or this test would pass on a syntax error without exercising `Bytes` at all.
        let json = r#"{
            "array": ["ZcPUob9wS72YNHkRPnFypA","ZcPUob9wS72YNHkRPnFypA=="]
        }"#;

        let err = serde_json::from_str::<HashMap<&str, Bytes>>(json)
            .expect_err("did not give an error as expected.");
        assert!(err.is_data(), "expected a data error, got: {:?}", err);
    }
}