1use core::{
2 fmt::{self, Debug, Formatter},
3 marker::PhantomData,
4};
5
6#[cfg(feature = "arbitrary")]
7use arbitrary::Arbitrary;
8use base64::Engine;
9use serde::{
10 de::{Deserialize, Deserializer, MapAccess, Visitor},
11 ser::{Serialize, SerializeMap, Serializer},
12};
13
14use super::ops::{KeyOps, KeyOpsSet};
15use crate::error::Error;
16
17#[derive(Clone, Copy, Debug, PartialEq, Eq)]
19#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
20pub struct JwkParts<'a> {
21 pub kty: &'a str,
23 pub kid: OptAttr<'a>,
25 pub alg: OptAttr<'a>,
27 pub crv: OptAttr<'a>,
29 pub x: OptAttr<'a>,
31 pub y: OptAttr<'a>,
33 pub d: OptAttr<'a>,
35 pub k: OptAttr<'a>,
37 pub key_ops: Option<KeyOpsSet>,
39}
40
41impl<'de> JwkParts<'de> {
42 pub fn try_from_str(jwk: &'de str) -> Result<Self, Error> {
44 let (parts, _read) =
45 serde_json_core::from_str(jwk).map_err(err_map!(Invalid, "Error parsing JWK"))?;
46 Ok(parts)
47 }
48
49 pub fn from_slice(jwk: &'de [u8]) -> Result<Self, Error> {
51 let (parts, _read) =
52 serde_json_core::from_slice(jwk).map_err(err_map!(Invalid, "Error parsing JWK"))?;
53 Ok(parts)
54 }
55}
56
/// An optional, borrowed string attribute of a JWK.
///
/// A thin wrapper around `Option<&str>`; the default value is the empty
/// (absent) attribute.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
#[repr(transparent)]
pub struct OptAttr<'a>(Option<&'a str>);
61
62impl OptAttr<'_> {
63 pub fn is_none(&self) -> bool {
64 self.0.is_none()
65 }
66
67 pub fn is_some(&self) -> bool {
68 self.0.is_some()
69 }
70
71 pub fn as_opt_str(&self) -> Option<&str> {
72 self.0
73 }
74
75 pub fn decode_base64(&self, output: &mut [u8]) -> Result<usize, Error> {
76 if let Some(s) = self.0 {
77 let max_input = (output.len() * 4 + 2) / 3; if s.len() > max_input {
79 Err(err_msg!(Invalid, "Base64 length exceeds max"))
80 } else {
81 base64::engine::general_purpose::URL_SAFE_NO_PAD
82 .decode_slice_unchecked(s, output)
83 .map_err(|_| err_msg!(Invalid, "Base64 decoding error"))
84 }
85 } else {
86 Err(err_msg!(Invalid, "Empty attribute"))
87 }
88 }
89}
90
91impl Debug for OptAttr<'_> {
92 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
93 match self.0 {
94 None => f.write_str("None"),
95 Some(s) => f.write_fmt(format_args!("{:?}", s)),
96 }
97 }
98}
99
100impl AsRef<str> for OptAttr<'_> {
101 fn as_ref(&self) -> &str {
102 self.0.unwrap_or_default()
103 }
104}
105
106impl<'o> From<&'o str> for OptAttr<'o> {
107 fn from(s: &'o str) -> Self {
108 Self(Some(s))
109 }
110}
111
112impl<'o> From<Option<&'o str>> for OptAttr<'o> {
113 fn from(s: Option<&'o str>) -> Self {
114 Self(s)
115 }
116}
117
118impl PartialEq<Option<&str>> for OptAttr<'_> {
119 fn eq(&self, other: &Option<&str>) -> bool {
120 self.0 == *other
121 }
122}
123
124impl PartialEq<&str> for OptAttr<'_> {
125 fn eq(&self, other: &&str) -> bool {
126 match self.0 {
127 None => false,
128 Some(s) => (*other) == s,
129 }
130 }
131}
132
/// Serde visitor that builds a `JwkParts` from a map; it carries no data and
/// uses `PhantomData` only to name the `'de` lifetime of the borrowed input.
struct JwkMapVisitor<'de>(PhantomData<&'de ()>);
134
135impl<'de> Visitor<'de> for JwkMapVisitor<'de> {
136 type Value = JwkParts<'de>;
137
138 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
139 formatter.write_str("an object representing a JWK")
140 }
141
142 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
143 where
144 M: MapAccess<'de>,
145 {
146 let mut kty = None;
147 let mut kid = None;
148 let mut alg = None;
149 let mut crv = None;
150 let mut x = None;
151 let mut y = None;
152 let mut d = None;
153 let mut k = None;
154 let mut key_ops = None;
155
156 while let Some(key) = access.next_key::<&str>()? {
157 match key {
158 "kty" => kty = Some(access.next_value()?),
159 "kid" => kid = Some(access.next_value()?),
160 "alg" => alg = Some(access.next_value()?),
161 "crv" => crv = Some(access.next_value()?),
162 "x" => x = Some(access.next_value()?),
163 "y" => y = Some(access.next_value()?),
164 "d" => d = Some(access.next_value()?),
165 "k" => k = Some(access.next_value()?),
166 "use" => {
167 let ops = match access.next_value()? {
168 "enc" => {
169 KeyOps::Encrypt | KeyOps::Decrypt | KeyOps::WrapKey | KeyOps::UnwrapKey
170 }
171 "sig" => KeyOps::Sign | KeyOps::Verify,
172 _ => KeyOpsSet::new(),
173 };
174 if !ops.is_empty() {
175 key_ops = Some(key_ops.unwrap_or_default() | ops);
176 }
177 }
178 "key_ops" => key_ops = Some(access.next_value()?),
179 _ => (),
180 }
181 }
182
183 if let Some(kty) = kty {
184 Ok(JwkParts {
185 kty,
186 kid: kid.into(),
187 alg: alg.into(),
188 crv: crv.into(),
189 x: x.into(),
190 y: y.into(),
191 d: d.into(),
192 k: k.into(),
193 key_ops,
194 })
195 } else {
196 Err(serde::de::Error::missing_field("kty"))
197 }
198 }
199}
200
201impl<'de> Deserialize<'de> for JwkParts<'de> {
202 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203 where
204 D: Deserializer<'de>,
205 {
206 deserializer.deserialize_map(JwkMapVisitor(PhantomData))
207 }
208}
209
210impl Serialize for JwkParts<'_> {
211 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
212 where
213 S: Serializer,
214 {
215 let mut map = serializer.serialize_map(None)?;
216 if let Some(alg) = self.alg.as_opt_str() {
217 map.serialize_entry("alg", alg)?;
218 }
219 if let Some(crv) = self.crv.as_opt_str() {
220 map.serialize_entry("crv", crv)?;
221 }
222 if let Some(d) = self.d.as_opt_str() {
223 map.serialize_entry("d", d)?;
224 }
225 if let Some(k) = self.k.as_opt_str() {
226 map.serialize_entry("k", k)?;
227 }
228 if let Some(kid) = self.kid.as_opt_str() {
229 map.serialize_entry("kid", kid)?;
230 }
231 map.serialize_entry("kty", self.kty)?;
232 if let Some(x) = self.x.as_opt_str() {
233 map.serialize_entry("x", x)?;
234 }
235 if let Some(y) = self.y.as_opt_str() {
236 map.serialize_entry("y", y)?;
237 }
238 if let Some(ops) = self.key_ops {
239 map.serialize_entry("key_ops", &ops)?;
240 }
241 map.end()
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a sample OKP key, checks every recognised member, then
    /// serializes the parts and parses them back to confirm a lossless
    /// round trip.
    #[test]
    fn parse_sample_okp() {
        let jwk = r#"{
            "kty": "OKP",
            "crv": "Ed25519",
            "x": "11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo",
            "d": "nWGxne_9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A",
            "key_ops": ["sign", "verify"],
            "kid": "FdFYFzERwC2uCBB46pZQi4GG85LujR8obt-KWRBICVQ"
        }"#;
        let parts = JwkParts::try_from_str(jwk).unwrap();
        assert_eq!(parts.kty, "OKP");
        assert_eq!(
            parts.kid,
            Some("FdFYFzERwC2uCBB46pZQi4GG85LujR8obt-KWRBICVQ")
        );
        assert_eq!(parts.crv, Some("Ed25519"));
        assert_eq!(parts.x, Some("11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo"));
        assert_eq!(parts.y, None);
        assert_eq!(parts.d, Some("nWGxne_9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A"));
        assert_eq!(parts.k, None);
        assert_eq!(parts.key_ops, Some(KeyOps::Sign | KeyOps::Verify));

        // Round trip: serialize into a fixed buffer and parse the bytes back.
        let mut buf = [0u8; 512];
        let len = serde_json_core::to_slice(&parts, &mut buf[..]).unwrap();
        let parts_2 = JwkParts::from_slice(&buf[..len]).unwrap();
        assert_eq!(parts_2, parts);
    }

    /// An object without a `kty` member must be rejected.
    #[test]
    fn parse_requires_kty() {
        let jwk = r#"{"crv": "Ed25519"}"#;
        assert!(JwkParts::try_from_str(jwk).is_err());
    }
}