passkey_types/utils/
bytes.rs1use std::ops::{Deref, DerefMut};
2
3use serde::{de::Visitor, Deserialize, Deserializer, Serialize};
4use typeshare::typeshare;
5
6use super::encoding;
7
/// Owned byte buffer: a transparent newtype around `Vec<u8>`.
///
/// The impls in this file let it serialize either as raw bytes or as a
/// base64url string (chosen by the `serialize_bytes_as_base64_string` feature)
/// and deserialize from raw bytes, a sequence of `u8` values, or a
/// base64url/base64 encoded string.
#[typeshare(transparent)]
#[derive(Debug, Default, PartialEq, Eq, Clone, Hash)]
#[repr(transparent)]
pub struct Bytes(Vec<u8>);
23
24impl Deref for Bytes {
25 type Target = Vec<u8>;
26
27 fn deref(&self) -> &Self::Target {
28 &self.0
29 }
30}
31
32impl DerefMut for Bytes {
33 fn deref_mut(&mut self) -> &mut Self::Target {
34 &mut self.0
35 }
36}
37
38impl From<Vec<u8>> for Bytes {
39 fn from(inner: Vec<u8>) -> Self {
40 Bytes(inner)
41 }
42}
43
44impl From<&[u8]> for Bytes {
45 fn from(value: &[u8]) -> Self {
46 Bytes(value.to_vec())
47 }
48}
49
50impl From<Bytes> for Vec<u8> {
51 fn from(src: Bytes) -> Self {
52 src.0
53 }
54}
55
56impl From<Bytes> for String {
57 fn from(src: Bytes) -> Self {
58 encoding::base64url(&src)
59 }
60}
61
/// Error returned when a string could not be decoded as either base64url or
/// base64 data.
///
/// Implements [`std::error::Error`] so it can be propagated with `?` into
/// `Box<dyn Error>` and similar error containers.
#[derive(Debug)]
pub struct NotBase64Encoded;

impl std::fmt::Display for NotBase64Encoded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("input is neither base64url nor base64 encoded")
    }
}

impl std::error::Error for NotBase64Encoded {}
65
66impl TryFrom<&str> for Bytes {
67 type Error = NotBase64Encoded;
68
69 fn try_from(value: &str) -> Result<Self, Self::Error> {
70 encoding::try_from_base64url(value)
71 .or_else(|| encoding::try_from_base64(value))
72 .ok_or(NotBase64Encoded)
73 .map(Self)
74 }
75}
76
77impl FromIterator<u8> for Bytes {
78 fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
79 Bytes(iter.into_iter().collect())
80 }
81}
82
83impl IntoIterator for Bytes {
84 type Item = u8;
85
86 type IntoIter = std::vec::IntoIter<u8>;
87
88 fn into_iter(self) -> Self::IntoIter {
89 self.0.into_iter()
90 }
91}
92
93impl<'a> IntoIterator for &'a Bytes {
94 type Item = &'a u8;
95
96 type IntoIter = std::slice::Iter<'a, u8>;
97
98 fn into_iter(self) -> Self::IntoIter {
99 self.0.iter()
100 }
101}
102
103impl Serialize for Bytes {
104 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105 where
106 S: serde::Serializer,
107 {
108 if cfg!(feature = "serialize_bytes_as_base64_string") {
109 serializer.serialize_str(&encoding::base64url(&self.0))
110 } else {
111 serializer.serialize_bytes(&self.0)
112 }
113 }
114}
115
116impl<'de> Deserialize<'de> for Bytes {
117 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
118 where
119 D: Deserializer<'de>,
120 {
121 struct Base64Visitor;
122
123 impl<'de> Visitor<'de> for Base64Visitor {
124 type Value = Bytes;
125
126 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
127 write!(f, "A vector of bytes or a base46(url) encoded string")
128 }
129 fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
130 where
131 E: serde::de::Error,
132 {
133 self.visit_str(v)
134 }
135 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
136 where
137 E: serde::de::Error,
138 {
139 self.visit_str(&v)
140 }
141 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
142 where
143 E: serde::de::Error,
144 {
145 v.try_into().map_err(|_| {
146 E::invalid_value(
147 serde::de::Unexpected::Str(v),
148 &"A base64(url) encoded string",
149 )
150 })
151 }
152 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
153 where
154 A: serde::de::SeqAccess<'de>,
155 {
156 let mut buf = Vec::with_capacity(seq.size_hint().unwrap_or_default());
157 while let Some(byte) = seq.next_element()? {
158 buf.push(byte);
159 }
160 Ok(Bytes(buf))
161 }
162 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
163 where
164 E: serde::de::Error,
165 {
166 Ok(Bytes(v.to_vec()))
167 }
168 }
169 deserializer.deserialize_any(Base64Visitor)
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// The same 16 bytes expressed as a JSON array, a base64url string and a
    /// padded base64 string must all deserialize to equal `Bytes` values.
    #[test]
    fn deserialize_many_formats_into_base64url_vec() {
        let json = r#"{
            "array": [101,195,212,161,191,112,75,189,152,52,121,17,62,113,114,164],
            "base64url": "ZcPUob9wS72YNHkRPnFypA",
            "base64": "ZcPUob9wS72YNHkRPnFypA=="
        }"#;

        let deserialized: HashMap<&str, Bytes> =
            serde_json::from_str(json).expect("failed to deserialize");

        assert_eq!(deserialized["array"], deserialized["base64url"]);
        assert_eq!(deserialized["base64url"], deserialized["base64"]);
    }

    /// An array of strings is not a valid representation of `Bytes`.
    #[test]
    fn deserialization_should_fail() {
        // The document must be syntactically valid JSON (no trailing comma),
        // otherwise the parser rejects it before `Bytes` is ever consulted and
        // the test passes for the wrong reason.
        let json = r#"{
            "array": ["ZcPUob9wS72YNHkRPnFypA","ZcPUob9wS72YNHkRPnFypA=="]
        }"#;

        let err = serde_json::from_str::<HashMap<&str, Bytes>>(json)
            .expect_err("did not give an error as expected.");

        // The failure has to come from the element type, not from JSON syntax.
        assert!(err.is_data(), "expected a data error, got: {err}");
    }
}