askar_crypto/jwk/parts.rs

1use core::{
2    fmt::{self, Debug, Formatter},
3    marker::PhantomData,
4};
5
6#[cfg(feature = "arbitrary")]
7use arbitrary::Arbitrary;
8use base64::Engine;
9use serde::{
10    de::{Deserialize, Deserializer, MapAccess, Visitor},
11    ser::{Serialize, SerializeMap, Serializer},
12};
13
14use super::ops::{KeyOps, KeyOpsSet};
15use crate::error::Error;
16
/// A parsed JWK
///
/// When produced by parsing, every string member borrows from the input JSON
/// text; nothing is copied or decoded. Binary members such as `x`, `y`, `d`
/// and `k` are kept in their encoded text form and can be decoded into a
/// caller-provided buffer with [`OptAttr::decode_base64`] (unpadded base64url).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct JwkParts<'a> {
    /// Key type (`kty`); the only member that parsing requires
    pub kty: &'a str,
    /// Key ID (`kid`)
    pub kid: OptAttr<'a>,
    /// Key algorithm (`alg`)
    pub alg: OptAttr<'a>,
    /// Curve type (`crv`)
    pub crv: OptAttr<'a>,
    /// Curve key public x coordinate (`x`)
    pub x: OptAttr<'a>,
    /// Curve key public y coordinate (`y`)
    pub y: OptAttr<'a>,
    /// Curve key private key bytes (`d`)
    pub d: OptAttr<'a>,
    /// Used by symmetric keys like AES (`k`)
    pub k: OptAttr<'a>,
    /// Recognized key operations, populated from the `key_ops` member and/or
    /// the `use` member. `None` when neither member was present, or when `use`
    /// had an unrecognized value and `key_ops` was absent.
    pub key_ops: Option<KeyOpsSet>,
}
40
41impl<'de> JwkParts<'de> {
42    /// Parse a JWK from a string reference
43    pub fn try_from_str(jwk: &'de str) -> Result<Self, Error> {
44        let (parts, _read) =
45            serde_json_core::from_str(jwk).map_err(err_map!(Invalid, "Error parsing JWK"))?;
46        Ok(parts)
47    }
48
49    /// Parse a JWK from a byte slice
50    pub fn from_slice(jwk: &'de [u8]) -> Result<Self, Error> {
51        let (parts, _read) =
52            serde_json_core::from_slice(jwk).map_err(err_map!(Invalid, "Error parsing JWK"))?;
53        Ok(parts)
54    }
55}
56
/// An optional JWK string attribute.
///
/// A transparent wrapper around `Option<&str>` that adds base64url decoding
/// and direct comparisons against `&str` and `Option<&str>`. The default value
/// is an absent attribute.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
#[repr(transparent)]
pub struct OptAttr<'a>(Option<&'a str>);
61
62impl OptAttr<'_> {
63    pub fn is_none(&self) -> bool {
64        self.0.is_none()
65    }
66
67    pub fn is_some(&self) -> bool {
68        self.0.is_some()
69    }
70
71    pub fn as_opt_str(&self) -> Option<&str> {
72        self.0
73    }
74
75    pub fn decode_base64(&self, output: &mut [u8]) -> Result<usize, Error> {
76        if let Some(s) = self.0 {
77            let max_input = (output.len() * 4 + 2) / 3; // ceil(4*n/3)
78            if s.len() > max_input {
79                Err(err_msg!(Invalid, "Base64 length exceeds max"))
80            } else {
81                base64::engine::general_purpose::URL_SAFE_NO_PAD
82                    .decode_slice_unchecked(s, output)
83                    .map_err(|_| err_msg!(Invalid, "Base64 decoding error"))
84            }
85        } else {
86            Err(err_msg!(Invalid, "Empty attribute"))
87        }
88    }
89}
90
91impl Debug for OptAttr<'_> {
92    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
93        match self.0 {
94            None => f.write_str("None"),
95            Some(s) => f.write_fmt(format_args!("{:?}", s)),
96        }
97    }
98}
99
100impl AsRef<str> for OptAttr<'_> {
101    fn as_ref(&self) -> &str {
102        self.0.unwrap_or_default()
103    }
104}
105
106impl<'o> From<&'o str> for OptAttr<'o> {
107    fn from(s: &'o str) -> Self {
108        Self(Some(s))
109    }
110}
111
112impl<'o> From<Option<&'o str>> for OptAttr<'o> {
113    fn from(s: Option<&'o str>) -> Self {
114        Self(s)
115    }
116}
117
118impl PartialEq<Option<&str>> for OptAttr<'_> {
119    fn eq(&self, other: &Option<&str>) -> bool {
120        self.0 == *other
121    }
122}
123
124impl PartialEq<&str> for OptAttr<'_> {
125    fn eq(&self, other: &&str) -> bool {
126        match self.0 {
127            None => false,
128            Some(s) => (*other) == s,
129        }
130    }
131}
132
/// Serde visitor that assembles a [`JwkParts`] from a map, borrowing string
/// values for the `'de` lifetime of the input.
struct JwkMapVisitor<'de>(PhantomData<&'de ()>);
134
135impl<'de> Visitor<'de> for JwkMapVisitor<'de> {
136    type Value = JwkParts<'de>;
137
138    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
139        formatter.write_str("an object representing a JWK")
140    }
141
142    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
143    where
144        M: MapAccess<'de>,
145    {
146        let mut kty = None;
147        let mut kid = None;
148        let mut alg = None;
149        let mut crv = None;
150        let mut x = None;
151        let mut y = None;
152        let mut d = None;
153        let mut k = None;
154        let mut key_ops = None;
155
156        while let Some(key) = access.next_key::<&str>()? {
157            match key {
158                "kty" => kty = Some(access.next_value()?),
159                "kid" => kid = Some(access.next_value()?),
160                "alg" => alg = Some(access.next_value()?),
161                "crv" => crv = Some(access.next_value()?),
162                "x" => x = Some(access.next_value()?),
163                "y" => y = Some(access.next_value()?),
164                "d" => d = Some(access.next_value()?),
165                "k" => k = Some(access.next_value()?),
166                "use" => {
167                    let ops = match access.next_value()? {
168                        "enc" => {
169                            KeyOps::Encrypt | KeyOps::Decrypt | KeyOps::WrapKey | KeyOps::UnwrapKey
170                        }
171                        "sig" => KeyOps::Sign | KeyOps::Verify,
172                        _ => KeyOpsSet::new(),
173                    };
174                    if !ops.is_empty() {
175                        key_ops = Some(key_ops.unwrap_or_default() | ops);
176                    }
177                }
178                "key_ops" => key_ops = Some(access.next_value()?),
179                _ => (),
180            }
181        }
182
183        if let Some(kty) = kty {
184            Ok(JwkParts {
185                kty,
186                kid: kid.into(),
187                alg: alg.into(),
188                crv: crv.into(),
189                x: x.into(),
190                y: y.into(),
191                d: d.into(),
192                k: k.into(),
193                key_ops,
194            })
195        } else {
196            Err(serde::de::Error::missing_field("kty"))
197        }
198    }
199}
200
201impl<'de> Deserialize<'de> for JwkParts<'de> {
202    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203    where
204        D: Deserializer<'de>,
205    {
206        deserializer.deserialize_map(JwkMapVisitor(PhantomData))
207    }
208}
209
210impl Serialize for JwkParts<'_> {
211    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
212    where
213        S: Serializer,
214    {
215        let mut map = serializer.serialize_map(None)?;
216        if let Some(alg) = self.alg.as_opt_str() {
217            map.serialize_entry("alg", alg)?;
218        }
219        if let Some(crv) = self.crv.as_opt_str() {
220            map.serialize_entry("crv", crv)?;
221        }
222        if let Some(d) = self.d.as_opt_str() {
223            map.serialize_entry("d", d)?;
224        }
225        if let Some(k) = self.k.as_opt_str() {
226            map.serialize_entry("k", k)?;
227        }
228        if let Some(kid) = self.kid.as_opt_str() {
229            map.serialize_entry("kid", kid)?;
230        }
231        map.serialize_entry("kty", self.kty)?;
232        if let Some(x) = self.x.as_opt_str() {
233            map.serialize_entry("x", x)?;
234        }
235        if let Some(y) = self.y.as_opt_str() {
236            map.serialize_entry("y", y)?;
237        }
238        if let Some(ops) = self.key_ops {
239            map.serialize_entry("key_ops", &ops)?;
240        }
241        map.end()
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample Ed25519 private key in JWK form.
    const SAMPLE_OKP: &str = r#"{
            "kty": "OKP",
            "crv": "Ed25519",
            "x": "11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo",
            "d": "nWGxne_9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A",
            "key_ops": ["sign", "verify"],
            "kid": "FdFYFzERwC2uCBB46pZQi4GG85LujR8obt-KWRBICVQ"
        }"#;

    #[test]
    fn parse_sample_okp() {
        let parsed = JwkParts::try_from_str(SAMPLE_OKP).unwrap();

        assert_eq!(parsed.kty, "OKP");
        assert_eq!(
            parsed.kid,
            Some("FdFYFzERwC2uCBB46pZQi4GG85LujR8obt-KWRBICVQ")
        );
        assert_eq!(parsed.crv, Some("Ed25519"));
        assert_eq!(parsed.x, Some("11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo"));
        assert_eq!(parsed.y, None);
        assert_eq!(parsed.d, Some("nWGxne_9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A"));
        assert_eq!(parsed.k, None);
        assert_eq!(parsed.key_ops, Some(KeyOps::Sign | KeyOps::Verify));

        // Round-trip: serialize, parse the output again, and compare.
        let mut buffer = [0u8; 512];
        let written = serde_json_core::to_slice(&parsed, &mut buffer[..]).unwrap();
        let reparsed = JwkParts::from_slice(&buffer[..written]).unwrap();
        assert_eq!(reparsed, parsed);
    }
}