1use std::{
2 collections::HashMap,
3 fs::File,
4 io::{Read, Write},
5 path::{Path, PathBuf},
6 str::FromStr,
7};
8
9use base64ct::Encoding;
10use ecdsa::SigningKey;
11use elliptic_curve::{rand_core::OsRng, JwkEcKey};
12use p521::ecdsa;
13use serde::Serialize;
14use sha2::Digest;
15
16#[derive(Debug)]
17pub struct TangyLib {
18 keys: std::collections::HashMap<String, MyJwkEcKey>,
19 signing_keys: Vec<MyJwkEcKey>,
20 default_adv: String,
21}
22
/// Where [`TangyLib::init`] obtains its keys from.
///
/// Both variants only borrow their data, so the enum is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeySource<'a> {
    /// A directory scanned for `*.jwk` files; a fresh key pair is written here when it holds none.
    LocalDir(&'a Path),
    /// Serialized JWKs supplied directly by the caller.
    Vector(&'a Vec<&'a str>),
}
28
29impl TangyLib {
30 pub fn init(source: KeySource) -> Result<Self, std::io::Error> {
31 let mut loaded_keys = match source {
32 KeySource::LocalDir(dir) => load_keys_from_dir(dir)?,
33 KeySource::Vector(keys) => load_keys_from_vec(keys)?,
34 };
35
36 let ecmr_exists = loaded_keys
37 .iter()
38 .any(|(_, v)| v.alg.is_some() && v.alg.as_ref().unwrap() == "ECMR");
39 let es512_exists = loaded_keys
40 .iter()
41 .any(|(_, v)| v.alg.is_some() && v.alg.as_ref().unwrap() == "ES512");
42
43 if (!ecmr_exists && es512_exists) || (ecmr_exists && !es512_exists) {
44 return Err(std::io::Error::new(
45 std::io::ErrorKind::NotFound,
46 "Key loading file error",
47 ));
48 }
49
50 if !ecmr_exists && !es512_exists {
51 match source {
52 KeySource::LocalDir(dir) => {
53 let keys = create_new_key_set();
54 for k in keys.iter() {
55 let jwk: MyJwkEcKey = serde_json::from_str(k).map_err(|e| {
56 std::io::Error::new(
57 std::io::ErrorKind::InvalidData,
58 format!("Unable to create new JWK: {e}"),
59 )
60 })?;
61 let thumbprint = jwk.thumbprint();
62 if let Ok(mut file) =
63 std::fs::File::create(dir.join(format!("{}.jwk", thumbprint)))
64 {
65 file.write_all(k.as_bytes())?;
66 set_file_permissions(&file)?;
67 }
68 loaded_keys.insert(thumbprint, jwk);
69 }
70 }
71 KeySource::Vector(_) => {
72 return Err(std::io::Error::new(
73 std::io::ErrorKind::NotFound,
74 "ES512 and ECMR keys not present in input vector",
75 ));
76 }
77 }
78 }
79
80 let signing_keys: Vec<MyJwkEcKey> = loaded_keys
84 .iter()
85 .filter_map(|(_, v)| {
86 if let Some(alg) = v.alg.as_ref() {
87 if alg == "ES512" {
88 return Some(v.clone());
89 }
90 }
91 None
92 })
93 .collect();
94
95 if signing_keys.is_empty() {
96 return Err(std::io::Error::new(
97 std::io::ErrorKind::NotFound,
98 "Signing key not found",
99 ));
100 }
101
102 let mut tangy = Self {
103 keys: loaded_keys,
104 signing_keys,
105 default_adv: "".into(),
106 };
107
108 tangy.default_adv = tangy.adv_internal(None)?;
109
110 Ok(tangy)
111 }
112
113 pub fn adv(&self, skid: Option<&str>) -> Result<String, std::io::Error> {
114 if skid.is_none() {
115 return Ok(self.default_adv.to_owned());
116 }
117 self.adv_internal(skid)
118 }
119
120 pub fn adv_internal(&self, skid: Option<&str>) -> Result<String, std::io::Error> {
121 #[derive(serde::Serialize)]
122 struct Siguature {
123 protected: String,
124 signature: String,
125 }
126 #[derive(serde::Serialize)]
127 struct Advertise {
128 payload: String,
129 #[serde(skip_serializing_if = "Option::is_none")]
130 protected: Option<String>,
131 #[serde(skip_serializing_if = "Option::is_none")]
132 signature: Option<String>,
133 #[serde(skip_serializing_if = "Option::is_none")]
134 signatures: Option<Vec<Siguature>>,
135 }
136
137 #[derive(serde::Serialize)]
138 struct Payload {
139 keys: Vec<MyJwkEcKey>,
140 }
141
142 let keys: Vec<&MyJwkEcKey> = self.keys.values().collect();
143
144 let signing_keys = if let Some(kid) = skid {
145 let key = self.keys.get(kid);
146 if key.is_none() {
147 return Err(std::io::Error::new(
148 std::io::ErrorKind::NotFound,
149 format!("Requested signing key {} not found", kid),
150 ));
151 }
152 if key.unwrap().key_ops.is_none() {
153 return Err(std::io::Error::new(
154 std::io::ErrorKind::NotFound,
155 format!("Requested signing key {} cannot be used for signing", kid),
156 ));
157 }
158 if !key
159 .as_ref()
160 .unwrap()
161 .key_ops
162 .as_ref()
163 .unwrap()
164 .contains(&"sign".to_string())
165 {
166 return Err(std::io::Error::new(
167 std::io::ErrorKind::NotFound,
168 format!("Requested signing key {} cannot be used for signing", kid),
169 ));
170 }
171 vec![key.unwrap().clone()]
172 } else {
173 self.signing_keys.to_vec()
174 };
175
176 let payload = base64ct::Base64Url::encode_string(
177 serde_json::to_string(&Payload {
178 keys: keys
179 .iter()
180 .map(|v| {
181 let mut k = v.to_public_key();
182 let mut ops = k.key_ops.take();
183 if let Some(ops) = &mut ops {
184 ops.retain(|v| *v != "sign");
185 }
186 k.key_ops = ops;
187 k
188 })
189 .collect(),
190 })
191 .map_err(|e| {
192 std::io::Error::new(
193 std::io::ErrorKind::InvalidData,
194 format!("Unable to encode adv message: {e}"),
195 )
196 })?
197 .as_bytes(),
198 );
199
200 let protected = base64ct::Base64Url::encode_string(
202 r#"{"alg":"ES512","cty":"jwk-set+json"}"#.as_bytes(),
203 );
204
205 let mut signing_keys: Vec<SigningKey> = signing_keys
206 .iter()
207 .map(|k| {
208 SigningKey::from_bytes(
209 &k.to_jwk_ec_key(false)
210 .to_secret_key::<p521::NistP521>()
211 .unwrap()
212 .to_bytes(),
213 )
214 .unwrap()
215 })
216 .collect();
217
218 let to_sign = format!("{}.{}", &protected, &payload);
221
222 let signatures: Vec<_> = signing_keys
223 .iter_mut()
224 .map(|k| ecdsa::signature::SignerMut::sign(k, to_sign.as_bytes()))
225 .collect();
226
227 let mut buf = [0; 1024];
228
229 if signatures.len() == 1 {
230 Ok(serde_json::to_string(&Advertise {
231 payload,
232 protected: Some(protected),
233 signature: Some(
234 base64ct::Base64Url::encode(&signatures[0].to_bytes(), &mut buf)
235 .map_err(|e| {
236 std::io::Error::new(
237 std::io::ErrorKind::InvalidData,
238 format!("Unable to encode string to base64: {e}"),
239 )
240 })?
241 .to_string(),
242 ),
243 signatures: None,
244 })
245 .unwrap())
246 } else {
247 Ok(serde_json::to_string(&Advertise {
248 payload,
249 protected: None,
250 signature: None,
251 signatures: Some(
252 signatures
253 .iter()
254 .map(|s| Siguature {
255 protected: protected.to_owned(),
256 signature: base64ct::Base64Url::encode(&s.to_bytes(), &mut buf)
257 .unwrap()
258 .to_string(),
259 })
260 .collect(),
261 ),
262 })
263 .unwrap())
264 }
265 }
266
267 pub fn rec(&self, kid: &str, request: &str) -> Result<String, std::io::Error> {
268 let key = self.keys.iter().find_map(|(k, v)| {
269 if k == kid {
270 return Some(v);
271 }
272 None
273 });
274
275 if key.is_none() {
276 return Err(std::io::Error::new(
277 std::io::ErrorKind::NotFound,
278 "Requested key not found".to_string(),
279 ));
280 }
281
282 let request_key: MyJwkEcKey = serde_json::from_str(request).unwrap();
283 let request_key = request_key
284 .to_jwk_ec_key(true)
285 .to_public_key::<p521::NistP521>()
286 .unwrap();
287
288 let p = diffie_hellman_public_key(
289 &key.as_ref()
290 .unwrap()
291 .to_jwk_ec_key(false)
292 .to_secret_key::<p521::NistP521>()
293 .unwrap()
294 .to_nonzero_scalar(),
295 request_key.as_affine(),
296 );
297
298 Ok(serde_json::to_string(&p).unwrap())
299 }
300}
301
302pub fn create_new_key_set() -> Vec<String> {
303 let es512_jwk = create_new_jwk("ES512", &["sign", "verify"]);
304 let ecmr_jwk = create_new_jwk("ECMR", &["deriveKey"]);
305 vec![es512_jwk, ecmr_jwk]
306}
307
/// Restricts a key file to mode `0o440`: readable by owner and group, writable by nobody.
#[cfg(target_os = "linux")]
fn set_file_permissions(file: &File) -> Result<(), std::io::Error> {
    use std::os::unix::fs::PermissionsExt;

    file.metadata().and_then(|meta| {
        let mut read_only = meta.permissions();
        read_only.set_mode(0o440);
        file.set_permissions(read_only)
    })
}
315
316#[cfg(not(target_os = "linux"))]
317fn set_file_permissions(file: &File) -> Result<(), std::io::Error> {
318 Ok(())
319}
320
321fn load_keys_from_dir(db_path: &Path) -> Result<HashMap<String, MyJwkEcKey>, std::io::Error> {
322 if !db_path.exists() {
323 return Err(std::io::Error::new(
324 std::io::ErrorKind::NotFound,
325 format!(
326 "Key database \"{}\" does not exist",
327 db_path.to_string_lossy()
328 ),
329 ));
330 }
331
332 if !db_path.is_dir() {
333 return Err(std::io::Error::new(
334 std::io::ErrorKind::Unsupported,
335 format!(
336 "Key database \"{}\" is not a directory",
337 db_path.to_string_lossy()
338 ),
339 ));
340 }
341
342 let jwk_files: Vec<PathBuf> = db_path
343 .read_dir()?
344 .filter_map(|f| f.ok())
345 .map(|e| e.path())
346 .filter(|f| f.extension() == Some(std::ffi::OsStr::new("jwk")))
347 .collect();
348
349 let keys: Vec<String> = jwk_files
350 .iter()
351 .filter_map(|j| {
352 let mut file_content = String::new();
353 match std::fs::File::open(j) {
354 Ok(mut f) => {
355 if f.read_to_string(&mut file_content).is_err() {
356 return None;
357 }
358 }
359 Err(_) => return None,
360 };
361 Some(file_content)
362 })
363 .collect();
364
365 load_keys_from_vec(&keys)
366}
367
368fn load_keys_from_vec<T: AsRef<str>>(
369 keys: &[T],
370) -> Result<HashMap<String, MyJwkEcKey>, std::io::Error> {
371 Ok(keys
372 .iter()
373 .filter_map(|key| {
374 let jwk: MyJwkEcKey = if let Ok(jwk) = serde_json::from_str(key.as_ref()) {
375 jwk
376 } else {
377 return None;
378 };
379
380 let thumbprint = thumprint(&jwk.crv, &jwk.kty, &jwk.x, &jwk.y);
381
382 Some((thumbprint, jwk))
383 })
384 .collect())
385}
386
387fn create_new_jwk(alg: &str, key_ops: &[&str]) -> String {
409 let priv_key = elliptic_curve::SecretKey::<p521::NistP521>::random(&mut OsRng);
410 let jwk = priv_key.to_jwk();
411 let encoded_point = jwk.to_encoded_point::<p521::NistP521>().unwrap();
412 let mut buf = [0; 1000];
413
414 let x = base64ct::Base64Url::encode(encoded_point.x().unwrap(), &mut buf)
415 .unwrap()
416 .to_string();
417
418 let y = base64ct::Base64Url::encode(encoded_point.y().unwrap(), &mut buf)
419 .unwrap()
420 .to_string();
421
422 serde_json::to_string(&MyJwkEcKey {
423 alg: Some(alg.into()),
424 kty: "EC".into(),
425 crv: jwk.crv().into(),
426 x: x.to_owned(),
427 y: y.to_owned(),
428 d: Some(
429 base64ct::Base64Url::encode(priv_key.to_bytes().as_slice(), &mut buf)
430 .unwrap()
431 .to_string(),
432 ),
433 key_ops: Some(key_ops.iter().map(|k| k.to_string()).collect()),
434 use_: None,
435 kid: None,
436 x5u: None,
437 x5c: None,
438 x5t: None,
439 x5t_s256: None,
440 })
441 .unwrap()
442}
443
444fn thumprint(crv: &str, kty: &str, x: &str, y: &str) -> String {
445 #[derive(Serialize)]
446 struct Required {
447 crv: String,
448 kty: String,
449 x: String,
450 y: String,
451 }
452
453 let required_fields = Required {
454 crv: crv.to_owned(),
455 kty: kty.to_owned(),
456 x: x.to_owned(),
457 y: y.to_owned(),
458 };
459
460 let mut hasher = sha2::Sha256::new();
461 hasher.update(serde_json::to_string(&required_fields).unwrap().as_bytes());
462 base64ct::Base64UrlUnpadded::encode_string(&hasher.finalize())
463}
464
465fn diffie_hellman_public_key(
466 secret_key: &elliptic_curve::NonZeroScalar<p521::NistP521>,
467 public_key: &elliptic_curve::AffinePoint<p521::NistP521>,
468) -> MyJwkEcKey {
469 let public_point = elliptic_curve::ProjectivePoint::<p521::NistP521>::from(*public_key);
470 let secret_point = (public_point * secret_key.as_ref()).to_affine();
471 let generated_public_key = p521::PublicKey::from_affine(secret_point).unwrap();
472
473 let mut formatted_public_key: MyJwkEcKey =
474 serde_json::from_str(generated_public_key.to_jwk_string().as_str()).unwrap();
475
476 formatted_public_key.alg = Some("ECMR".into());
477 formatted_public_key.crv = "P-521".into();
478 formatted_public_key.kty = "EC".into();
479 formatted_public_key.key_ops = Some(vec!["deriveKey".into()]);
480
481 formatted_public_key
482}
483
484#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
485pub struct MyJwkEcKey {
486 pub crv: String,
490
491 pub x: String,
495
496 pub y: String,
500
501 #[serde(skip_serializing_if = "Option::is_none")]
510 pub d: Option<String>,
511
512 pub kty: String,
519
520 #[serde(skip_serializing_if = "Option::is_none", rename = "use")]
525 pub use_: Option<String>,
526
527 #[serde(skip_serializing_if = "Option::is_none")]
532 pub key_ops: Option<Vec<String>>,
533
534 #[serde(skip_serializing_if = "Option::is_none")]
539 pub alg: Option<String>,
540
541 #[serde(skip_serializing_if = "Option::is_none")]
546 pub kid: Option<String>,
547
548 #[serde(skip_serializing_if = "Option::is_none")]
553 pub x5u: Option<String>,
554
555 #[serde(skip_serializing_if = "Option::is_none")]
560 pub x5c: Option<String>,
561
562 #[serde(skip_serializing_if = "Option::is_none")]
567 pub x5t: Option<String>,
568
569 #[serde(skip_serializing_if = "Option::is_none")]
574 pub x5t_s256: Option<String>,
575}
576
577impl MyJwkEcKey {
578 pub fn to_jwk_ec_key(&self, public: bool) -> JwkEcKey {
579 let mut content: HashMap<String, String> = HashMap::new();
580 content.insert("crv".into(), self.crv.to_owned());
581 content.insert("kty".into(), self.kty.to_owned());
582 content.insert("x".into(), self.x.to_owned());
583 content.insert("y".into(), self.y.to_owned());
584 if self.d.is_some() && !public {
585 content.insert("d".into(), self.d.as_ref().unwrap().to_owned());
586 }
587 JwkEcKey::from_str(&serde_json::to_string(&content).unwrap()).unwrap()
588 }
589
590 pub fn to_public_key(&self) -> MyJwkEcKey {
591 let mut ret = self.clone();
592 ret.d = None;
593 ret
594 }
595
596 pub fn thumbprint(&self) -> String {
597 #[derive(Serialize)]
600 struct Required {
601 crv: String,
602 kty: String,
603 x: String,
604 y: String,
605 }
606
607 let required_fields = Required {
608 crv: self.crv.to_owned(),
609 kty: self.kty.to_owned(),
610 x: self.x.to_owned(),
611 y: self.y.to_owned(),
612 };
613
614 let mut hasher = sha2::Sha256::new();
615 hasher.update(serde_json::to_string(&required_fields).unwrap().as_bytes());
616 base64ct::Base64UrlUnpadded::encode_string(&hasher.finalize())
617 }
618}
619
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    const JWK_ES512: &str = r#"
    {
        "kty": "EC",
        "alg": "ES512",
        "crv": "P-521",
        "x": "AX5mUTAH1qr3YSSwuMV_HV0yupJhMIAqwly710a7qLbXR6up3flnaPsJbaSVATrIF6QcXc9PPyFW1IQHmDOWGSPj",
        "y": "ADT1K8Q-O1Q5lyU3StXnPMQwgnYWS8hnTRGjjcFssitZy_tUWSuhUPFhzaUJKhXRNbcyELeDX-kPCMbBKX1vb8Lq",
        "d": "AbDO5xCtQHUbHld-Fq61sSCvyjr9EpNj3_sklNmo54xmKeYu_cW_s7fzQxm6SsqFwrTmiiFz2OaD1ODsXI-DdoKt",
        "key_ops": ["sign","verify"]
    }
    "#;

    const JWK_ES512_THUMBPRINT: &str = "tpUdnaei02Z6bSS3_rKEU0BDPl8tyZFy16CKCTWNlbA";

    const JWK_ECMR: &str = r#"
    {
        "kty": "EC",
        "alg": "ECMR",
        "crv": "P-521",
        "x": "ASa1DOpfB9-Qe1zkbG6HAZ_DC2FNUBeR6e3kgLgHF8xC8JZM1EsiGjkvTRk0paH_Oat8OSGSRPD0-PsXFAvNuXCd",
        "y": "AaO_WH8pzC__37gCuCJdgtIbO6IK4XLfyjAjuJovvfksoMigvFwpyLKwWhIfE8lQqPR7CMxG2LRLXJIubFjSDMDH",
        "d": "AQTm4JamDPZufHlRCC12Ssjh6xTwu630neCLr7EUtUuZoFHk9zga-kzwaGajH1MQb8ffc3CeV-7InHKmR8HvytTE",
        "key_ops":["deriveKey"]
    }
    "#;

    const JWK_ECMR_THUMBPRINT: &str = "UFgqx9-PLx_h6h4hd6sysNHMC6cDyjBQOYZHFvObLbo";

    /// The members of an advertised key that the tests inspect.
    #[derive(Deserialize)]
    struct AdvertisedKey {
        kty: String,
        crv: String,
        x: String,
        y: String,
        d: Option<String>,
    }

    /// Decodes the JWK set carried in the payload of a serialized advertisement.
    fn advertised_keys(advertisement: &str) -> Vec<AdvertisedKey> {
        #[derive(Deserialize)]
        struct Adv {
            payload: String,
        }

        #[derive(Deserialize)]
        struct Payload {
            keys: Vec<AdvertisedKey>,
        }

        let adv: Adv = serde_json::from_str(advertisement).unwrap();
        // JWS payloads are base64url without padding (RFC 7515).
        let payload_json = base64ct::Base64UrlUnpadded::decode_vec(&adv.payload).unwrap();
        let payload: Payload = serde_json::from_slice(&payload_json).unwrap();
        payload.keys
    }

    /// Asserts that every advertised key is a public-only P-521 EC key.
    fn assert_public_p521(keys: &[AdvertisedKey]) {
        for key in keys {
            assert_eq!(key.kty, "EC");
            assert_eq!(key.crv, "P-521");
            assert!(!key.x.is_empty() && !key.y.is_empty());
            assert!(
                key.d.is_none(),
                "private key material leaked into the advertisement"
            );
        }
    }

    /// Counts the `*.jwk` files directly inside `dir`.
    fn count_jwk_files(dir: &Path) -> usize {
        std::fs::read_dir(dir)
            .unwrap()
            .filter(|entry| {
                entry.as_ref().unwrap().path().extension() == Some(std::ffi::OsStr::new("jwk"))
            })
            .count()
    }

    #[test]
    fn source_local_dir() {
        let tmp_dir = tempdir::TempDir::new("local_dir_test").unwrap();
        let first = TangyLib::init(KeySource::LocalDir(tmp_dir.path())).unwrap();
        assert_eq!(first.keys.len(), 2);
        // The generated pair is persisted...
        assert_eq!(count_jwk_files(tmp_dir.path()), 2);

        // ...and reloaded on the next start instead of being regenerated.
        let second = TangyLib::init(KeySource::LocalDir(tmp_dir.path())).unwrap();
        assert_eq!(count_jwk_files(tmp_dir.path()), 2);
        let mut first_ids: Vec<&String> = first.keys.keys().collect();
        let mut second_ids: Vec<&String> = second.keys.keys().collect();
        first_ids.sort();
        second_ids.sort();
        assert_eq!(first_ids, second_ids);
    }

    #[test]
    fn source_vector() {
        let v = vec![JWK_ES512, JWK_ECMR];
        let t = TangyLib::init(KeySource::Vector(&v));
        assert!(t.is_ok());
    }

    #[test]
    fn adv() {
        let v = vec![JWK_ES512, JWK_ECMR];
        let t = TangyLib::init(KeySource::Vector(&v)).unwrap();
        let keys = advertised_keys(&t.adv(None).unwrap());
        assert_eq!(keys.len(), 2);
        assert_public_p521(&keys);
    }

    #[test]
    fn adv_skid() {
        let v = vec![JWK_ES512, JWK_ECMR];
        let t = TangyLib::init(KeySource::Vector(&v)).unwrap();
        let keys = advertised_keys(&t.adv(Some(JWK_ES512_THUMBPRINT)).unwrap());
        assert_eq!(keys.len(), 2);
        assert_public_p521(&keys);
    }

    #[test]
    fn adv_skid_rejects_non_signing_key() {
        let v = vec![JWK_ES512, JWK_ECMR];
        let t = TangyLib::init(KeySource::Vector(&v)).unwrap();
        let err = t.adv(Some(JWK_ECMR_THUMBPRINT)).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::NotFound);
    }
}