1use core::fmt::Write;
2
3use serde::{
4 ser::{Serialize, SerializeMap},
5 Serializer,
6};
7
8use super::{ops::KeyOpsSet, ToJwk};
9use crate::{
10 alg::KeyAlg,
11 buffer::{WriteBuffer, Writer},
12 error::Error,
13};
14
15fn write_hex_buffer(mut buffer: impl Write, value: &[u8]) -> Result<(), Error> {
16 write!(
17 buffer,
18 "{}",
19 base64::display::Base64Display::new(
20 value,
21 &base64::engine::general_purpose::URL_SAFE_NO_PAD
22 )
23 )
24 .map_err(|_| err_msg!(Unexpected, "Error writing to JWK buffer"))
25}
26
/// Selects which representation of a key a JWK encoder produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JwkEncoderMode {
    /// Public key representation.
    PublicKey,
    /// Secret (private) key representation.
    SecretKey,
    /// Representation used when computing a JWK thumbprint.
    Thumbprint,
}
37
38pub trait JwkEncoder {
40 fn alg(&self) -> Option<KeyAlg>;
42
43 fn add_str(&mut self, key: &str, value: &str) -> Result<(), Error>;
45
46 fn add_as_base64(&mut self, key: &str, value: &[u8]) -> Result<(), Error>;
48
49 fn mode(&self) -> JwkEncoderMode;
51
52 fn is_public(&self) -> bool {
54 matches!(self.mode(), JwkEncoderMode::PublicKey)
55 }
56
57 fn is_secret(&self) -> bool {
59 matches!(self.mode(), JwkEncoderMode::SecretKey)
60 }
61
62 fn is_thumbprint(&self) -> bool {
64 matches!(self.mode(), JwkEncoderMode::Thumbprint)
65 }
66}
67
68#[derive(Debug)]
70pub struct JwkBufferEncoder<'b, B: WriteBuffer> {
71 mode: JwkEncoderMode,
72 buffer: &'b mut B,
73 empty: bool,
74 alg: Option<KeyAlg>,
75 key_ops: Option<KeyOpsSet>,
76 kid: Option<&'b str>,
77}
78
79impl<'b, B: WriteBuffer> JwkBufferEncoder<'b, B> {
80 pub fn new(buffer: &'b mut B, mode: JwkEncoderMode) -> Self {
82 Self {
83 mode,
84 buffer,
85 empty: true,
86 alg: None,
87 key_ops: None,
88 kid: None,
89 }
90 }
91
92 fn start_attr(&mut self, key: &str) -> Result<(), Error> {
93 let buffer = &mut *self.buffer;
94 if self.empty {
95 buffer.buffer_write(b"{\"")?;
96 self.empty = false;
97 } else {
98 buffer.buffer_write(b",\"")?;
99 }
100 buffer.buffer_write(key.as_bytes())?;
101 buffer.buffer_write(b"\":")?;
102 Ok(())
103 }
104
105 pub fn alg(self, alg: Option<KeyAlg>) -> Self {
107 Self { alg, ..self }
108 }
109
110 pub fn key_ops(self, key_ops: Option<KeyOpsSet>) -> Self {
112 Self { key_ops, ..self }
113 }
114
115 pub fn kid(self, kid: Option<&'b str>) -> Self {
117 Self { kid, ..self }
118 }
119
120 pub fn finalize(mut self) -> Result<(), Error> {
122 if let Some(ops) = self.key_ops {
123 self.start_attr("key_ops")?;
124 let buffer = &mut *self.buffer;
125 for (idx, op) in ops.into_iter().enumerate() {
126 if idx > 0 {
127 buffer.buffer_write(b",\"")?;
128 } else {
129 buffer.buffer_write(b"\"")?;
130 }
131 buffer.buffer_write(op.as_str().as_bytes())?;
132 buffer.buffer_write(b"\"")?;
133 }
134 buffer.buffer_write(b"]")?;
135 }
136 if let Some(kid) = self.kid {
137 self.add_str("kid", kid)?;
138 }
139 if !self.empty {
140 self.buffer.buffer_write(b"}")?;
141 }
142 Ok(())
143 }
144}
145
146impl<B: WriteBuffer> JwkEncoder for JwkBufferEncoder<'_, B> {
147 #[inline]
148 fn alg(&self) -> Option<KeyAlg> {
149 self.alg
150 }
151
152 fn add_str(&mut self, key: &str, value: &str) -> Result<(), Error> {
153 self.start_attr(key)?;
154 let buffer = &mut *self.buffer;
155 buffer.buffer_write(b"\"")?;
156 buffer.buffer_write(value.as_bytes())?;
157 buffer.buffer_write(b"\"")?;
158 Ok(())
159 }
160
161 fn add_as_base64(&mut self, key: &str, value: &[u8]) -> Result<(), Error> {
162 self.start_attr(key)?;
163 let buffer = &mut *self.buffer;
164 buffer.buffer_write(b"\"")?;
165 write_hex_buffer(Writer::from_buffer(&mut *buffer), value)?;
166 buffer.buffer_write(b"\"")?;
167 Ok(())
168 }
169
170 #[inline]
171 fn mode(&self) -> JwkEncoderMode {
172 self.mode
173 }
174}
175
176#[derive(Debug)]
178pub struct JwkSerialize<'s, K: ToJwk> {
179 mode: JwkEncoderMode,
180 key: &'s K,
181 alg: Option<KeyAlg>,
182 key_ops: Option<KeyOpsSet>,
183 kid: Option<&'s str>,
184}
185
186impl<'s, K: ToJwk> JwkSerialize<'s, K> {
187 pub fn new(key: &'s K, mode: JwkEncoderMode) -> Self {
189 Self {
190 alg: None,
191 mode,
192 key,
193 key_ops: None,
194 kid: None,
195 }
196 }
197
198 pub fn as_public(key: &'s K) -> Self {
200 Self {
201 mode: JwkEncoderMode::PublicKey,
202 key,
203 alg: None,
204 key_ops: None,
205 kid: None,
206 }
207 }
208
209 pub fn as_secret(key: &'s K) -> Self {
211 Self {
212 mode: JwkEncoderMode::SecretKey,
213 key,
214 alg: None,
215 key_ops: None,
216 kid: None,
217 }
218 }
219
220 pub fn as_thumbprint(key: &'s K) -> Self {
222 Self {
223 mode: JwkEncoderMode::Thumbprint,
224 key,
225 alg: None,
226 key_ops: None,
227 kid: None,
228 }
229 }
230
231 pub fn alg(self, alg: Option<KeyAlg>) -> Self {
233 Self { alg, ..self }
234 }
235
236 pub fn key_ops(self, key_ops: Option<KeyOpsSet>) -> Self {
238 Self { key_ops, ..self }
239 }
240
241 pub fn kid(self, kid: Option<&'s str>) -> Self {
243 Self { kid, ..self }
244 }
245}
246
247impl<K: ToJwk> Serialize for JwkSerialize<'_, K> {
248 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
249 where
250 S: Serializer,
251 {
252 struct Enc<'m, M> {
253 alg: Option<KeyAlg>,
254 mode: JwkEncoderMode,
255 map: &'m mut M,
256 }
257
258 impl<M: SerializeMap> JwkEncoder for Enc<'_, M> {
259 fn alg(&self) -> Option<KeyAlg> {
260 self.alg
261 }
262
263 fn add_str(&mut self, key: &str, value: &str) -> Result<(), Error> {
264 self.map
265 .serialize_entry(key, value)
266 .map_err(|_| err_msg!(Unexpected, "Error serializing JWK"))
267 }
268
269 fn add_as_base64(&mut self, key: &str, value: &[u8]) -> Result<(), Error> {
270 let mut buf = [0u8; 256];
273 let mut w = Writer::from_slice(&mut buf);
274 write_hex_buffer(&mut w, value)?;
275 self.map
276 .serialize_entry(key, core::str::from_utf8(w.as_ref()).unwrap())
277 .map_err(|_| err_msg!(Unexpected, "Error serializing JWK"))
278 }
279
280 fn mode(&self) -> JwkEncoderMode {
281 self.mode
282 }
283 }
284
285 let mut map = serializer.serialize_map(None)?;
286 let mut enc = Enc {
287 alg: self.alg,
288 mode: self.mode,
289 map: &mut map,
290 };
291 self.key
292 .encode_jwk(&mut enc)
293 .map_err(|err| <S::Error as serde::ser::Error>::custom(err.message()))?;
294 if let Some(ops) = self.key_ops {
295 map.serialize_entry("key_ops", &ops)?;
296 }
297 if let Some(kid) = self.kid {
298 map.serialize_entry("kid", kid)?;
299 }
300 map.end()
301 }
302}
303
#[cfg(test)]
mod tests {
    /// Serialize an Ed25519 secret key with `kid` and `key_ops` through serde
    /// and check every member after parsing the JSON back.
    #[cfg(feature = "ed25519")]
    #[test]
    fn serialize_jwk() {
        use super::JwkSerialize;
        use crate::{
            alg::ed25519::Ed25519KeyPair,
            jwk::{JwkParts, KeyOps},
            repr::KeySecretBytes,
        };

        const KID: &str = "FdFYFzERwC2uCBB46pZQi4GG85LujR8obt-KWRBICVQ";
        let ops = KeyOps::Sign | KeyOps::Verify;

        let keypair = Ed25519KeyPair::from_secret_bytes(&hex!(
            "9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60"
        ))
        .unwrap();

        let jwk = JwkSerialize::as_secret(&keypair)
            .kid(Some(KID))
            .key_ops(Some(ops));
        let mut output = [0u8; 512];
        let written = serde_json_core::to_slice(&jwk, &mut output).unwrap();

        let parts = JwkParts::from_slice(&output[..written]).unwrap();
        assert_eq!(parts.kty, "OKP");
        assert_eq!(parts.kid, Some(KID));
        assert_eq!(parts.crv, Some("Ed25519"));
        assert_eq!(parts.x, Some("11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo"));
        assert_eq!(parts.y, None);
        assert_eq!(parts.d, Some("nWGxne_9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A"));
        assert_eq!(parts.k, None);
        assert_eq!(parts.key_ops, Some(ops));
    }
}
341}