1use crate::primitives::hash::sha256_hmac;
8use crate::primitives::private_key::PrivateKey;
9use crate::primitives::public_key::PublicKey;
10use crate::primitives::symmetric_key::SymmetricKey;
11use crate::wallet::error::WalletError;
12use crate::wallet::types::{anyone_pubkey, Counterparty, CounterpartyType, Protocol};
13
14pub struct KeyDeriver {
19 root_key: PrivateKey,
20}
21
22impl KeyDeriver {
23 pub fn new(private_key: PrivateKey) -> Self {
25 KeyDeriver {
26 root_key: private_key,
27 }
28 }
29
30 pub fn new_anyone() -> Self {
32 KeyDeriver {
33 root_key: crate::wallet::types::anyone_private_key(),
34 }
35 }
36
37 pub fn root_key(&self) -> &PrivateKey {
39 &self.root_key
40 }
41
42 pub fn identity_key(&self) -> PublicKey {
44 self.root_key.to_public_key()
45 }
46
47 pub fn identity_key_hex(&self) -> String {
49 self.identity_key().to_der_hex()
50 }
51
52 pub fn derive_private_key(
54 &self,
55 protocol: &Protocol,
56 key_id: &str,
57 counterparty: &Counterparty,
58 ) -> Result<PrivateKey, WalletError> {
59 let counterparty_pubkey = self.normalize_counterparty(counterparty)?;
60 let invoice_number = Self::compute_invoice_number(protocol, key_id)?;
61 let child = self
62 .root_key
63 .derive_child(&counterparty_pubkey, &invoice_number)?;
64 Ok(child)
65 }
66
67 pub fn derive_public_key(
72 &self,
73 protocol: &Protocol,
74 key_id: &str,
75 counterparty: &Counterparty,
76 for_self: bool,
77 ) -> Result<PublicKey, WalletError> {
78 let counterparty_pubkey = self.normalize_counterparty(counterparty)?;
79 let invoice_number = Self::compute_invoice_number(protocol, key_id)?;
80
81 if for_self {
82 let priv_child = self
83 .root_key
84 .derive_child(&counterparty_pubkey, &invoice_number)?;
85 Ok(priv_child.to_public_key())
86 } else {
87 let pub_child = counterparty_pubkey.derive_child(&self.root_key, &invoice_number)?;
88 Ok(pub_child)
89 }
90 }
91
92 pub fn derive_symmetric_key(
97 &self,
98 protocol: &Protocol,
99 key_id: &str,
100 counterparty: &Counterparty,
101 ) -> Result<SymmetricKey, WalletError> {
102 let effective_counterparty = if counterparty.counterparty_type == CounterpartyType::Anyone {
104 Counterparty {
105 counterparty_type: CounterpartyType::Other,
106 public_key: Some(anyone_pubkey()),
107 }
108 } else {
109 counterparty.clone()
110 };
111
112 let derived_pub =
113 self.derive_public_key(protocol, key_id, &effective_counterparty, false)?;
114 let derived_priv = self.derive_private_key(protocol, key_id, &effective_counterparty)?;
115
116 let shared_secret = derived_priv.derive_shared_secret(&derived_pub)?;
117 let x_bytes = shared_secret
118 .x
119 .to_array(crate::primitives::big_number::Endian::Big, Some(32));
120 let sym_key = SymmetricKey::from_bytes(&x_bytes)?;
121 Ok(sym_key)
122 }
123
124 pub fn reveal_counterparty_secret(
128 &self,
129 counterparty: &Counterparty,
130 ) -> Result<PublicKey, WalletError> {
131 if counterparty.counterparty_type == CounterpartyType::Self_ {
132 return Err(WalletError::InvalidParameter(
133 "counterparty secrets cannot be revealed for counterparty=self".to_string(),
134 ));
135 }
136
137 let counterparty_pubkey = self.normalize_counterparty(counterparty)?;
138
139 let self_pub = self.root_key.to_public_key();
141 let key_derived_by_self = self.root_key.derive_child(&self_pub, "test")?;
142 let key_derived_by_counterparty =
143 self.root_key.derive_child(&counterparty_pubkey, "test")?;
144
145 if key_derived_by_self.to_bytes() == key_derived_by_counterparty.to_bytes() {
146 return Err(WalletError::InvalidParameter(
147 "counterparty secrets cannot be revealed if counterparty key is self".to_string(),
148 ));
149 }
150
151 let shared_secret = self.root_key.derive_shared_secret(&counterparty_pubkey)?;
152 Ok(PublicKey::from_point(shared_secret))
153 }
154
155 pub fn reveal_specific_secret(
160 &self,
161 counterparty: &Counterparty,
162 protocol: &Protocol,
163 key_id: &str,
164 ) -> Result<Vec<u8>, WalletError> {
165 let counterparty_pubkey = self.normalize_counterparty(counterparty)?;
166 let shared_secret = self.root_key.derive_shared_secret(&counterparty_pubkey)?;
167 let invoice_number = Self::compute_invoice_number(protocol, key_id)?;
168 let shared_secret_compressed = shared_secret.to_der(true);
169 let hmac = sha256_hmac(&shared_secret_compressed, invoice_number.as_bytes());
170 Ok(hmac.to_vec())
171 }
172
173 fn normalize_counterparty(
175 &self,
176 counterparty: &Counterparty,
177 ) -> Result<PublicKey, WalletError> {
178 match counterparty.counterparty_type {
179 CounterpartyType::Self_ => Ok(self.root_key.to_public_key()),
180 CounterpartyType::Anyone => Ok(anyone_pubkey()),
181 CounterpartyType::Other => counterparty.public_key.clone().ok_or_else(|| {
182 WalletError::InvalidParameter(
183 "counterparty public key required for type Other".to_string(),
184 )
185 }),
186 CounterpartyType::Uninitialized => Err(WalletError::InvalidParameter(
187 "counterparty type is uninitialized".to_string(),
188 )),
189 }
190 }
191
192 fn compute_invoice_number(protocol: &Protocol, key_id: &str) -> Result<String, WalletError> {
199 if protocol.security_level > 2 {
201 return Err(WalletError::InvalidParameter(
202 "protocol security level must be 0, 1, or 2".to_string(),
203 ));
204 }
205
206 if key_id.is_empty() {
208 return Err(WalletError::InvalidParameter(
209 "key IDs must be 1 character or more".to_string(),
210 ));
211 }
212 if key_id.len() > 800 {
213 return Err(WalletError::InvalidParameter(
214 "key IDs must be 800 characters or less".to_string(),
215 ));
216 }
217
218 let protocol_name = protocol.protocol.trim().to_lowercase();
220 if protocol_name.len() < 5 {
221 return Err(WalletError::InvalidParameter(
222 "protocol names must be 5 characters or more".to_string(),
223 ));
224 }
225 if protocol_name.len() > 400 {
226 if protocol_name.starts_with("specific linkage revelation ") {
227 if protocol_name.len() > 430 {
228 return Err(WalletError::InvalidParameter(
229 "specific linkage revelation protocol names must be 430 characters or less"
230 .to_string(),
231 ));
232 }
233 } else {
234 return Err(WalletError::InvalidParameter(
235 "protocol names must be 400 characters or less".to_string(),
236 ));
237 }
238 }
239 if protocol_name.contains(" ") {
240 return Err(WalletError::InvalidParameter(
241 "protocol names cannot contain multiple consecutive spaces".to_string(),
242 ));
243 }
244 if !protocol_name
245 .chars()
246 .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == ' ')
247 {
248 return Err(WalletError::InvalidParameter(
249 "protocol names can only contain letters, numbers and spaces".to_string(),
250 ));
251 }
252 if protocol_name.ends_with(" protocol") {
253 return Err(WalletError::InvalidParameter(
254 "no need to end your protocol name with \" protocol\"".to_string(),
255 ));
256 }
257
258 Ok(format!(
259 "{}-{}-{}",
260 protocol.security_level, protocol_name, key_id
261 ))
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wallet::types::CounterpartyType;

    /// Builds an `Other` counterparty wrapping `key`.
    fn other(key: PublicKey) -> Counterparty {
        Counterparty {
            counterparty_type: CounterpartyType::Other,
            public_key: Some(key),
        }
    }

    /// Builds a protocol with the given security level and name.
    fn protocol(security_level: u8, name: &str) -> Protocol {
        Protocol {
            security_level: security_level.into(),
            protocol: name.to_string(),
        }
    }

    #[test]
    fn test_identity_key_known_vector() {
        let priv_key = PrivateKey::from_hex("1").unwrap();
        let kd = KeyDeriver::new(priv_key);
        assert_eq!(
            kd.identity_key_hex(),
            "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"
        );
    }

    #[test]
    fn test_anyone_deriver() {
        let kd = KeyDeriver::new_anyone();
        // The anyone key is private key 1, so its identity key is the generator point.
        assert_eq!(
            kd.identity_key_hex(),
            "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"
        );
    }

    #[test]
    fn test_compute_invoice_number_valid() {
        let protocol = protocol(2, "hello world");
        let result = KeyDeriver::compute_invoice_number(&protocol, "1");
        assert_eq!(result.unwrap(), "2-hello world-1");
    }

    #[test]
    fn test_compute_invoice_number_security_level_too_high() {
        let protocol = protocol(3, "hello world");
        let result = KeyDeriver::compute_invoice_number(&protocol, "1");
        assert!(result.is_err());
    }

    #[test]
    fn test_compute_invoice_number_protocol_too_short() {
        let protocol = protocol(0, "abcd");
        let result = KeyDeriver::compute_invoice_number(&protocol, "1");
        assert!(result.is_err());
    }

    #[test]
    fn test_compute_invoice_number_protocol_too_long() {
        let protocol = protocol(0, &"a".repeat(401));
        let result = KeyDeriver::compute_invoice_number(&protocol, "1");
        assert!(result.is_err());
    }

    #[test]
    fn test_compute_invoice_number_specific_linkage_revelation_limit() {
        let prefix = "specific linkage revelation ";
        let at_limit = format!("{}{}", prefix, "a".repeat(430 - prefix.len()));
        let over_limit = format!("{}{}", prefix, "a".repeat(431 - prefix.len()));
        assert!(KeyDeriver::compute_invoice_number(&protocol(0, &at_limit), "1").is_ok());
        assert!(KeyDeriver::compute_invoice_number(&protocol(0, &over_limit), "1").is_err());
    }

    #[test]
    fn test_compute_invoice_number_consecutive_spaces() {
        // Two spaces between the words: this must be rejected.
        let protocol = protocol(0, "hello  world");
        let result = KeyDeriver::compute_invoice_number(&protocol, "1");
        assert!(result.is_err());
    }

    #[test]
    fn test_compute_invoice_number_ends_with_protocol() {
        let protocol = protocol(0, "my cool protocol");
        let result = KeyDeriver::compute_invoice_number(&protocol, "1");
        assert!(result.is_err());
    }

    #[test]
    fn test_compute_invoice_number_uppercase_is_normalized() {
        // Upper-case input is lower-cased (and trimmed) rather than rejected.
        let protocol = protocol(0, "  Hello World ");
        let result = KeyDeriver::compute_invoice_number(&protocol, "1");
        assert_eq!(result.unwrap(), "0-hello world-1");
    }

    #[test]
    fn test_compute_invoice_number_special_chars_rejected() {
        let protocol = protocol(0, "hello-world");
        let result = KeyDeriver::compute_invoice_number(&protocol, "1");
        assert!(result.is_err());
    }

    #[test]
    fn test_compute_invoice_number_key_id_empty() {
        let protocol = protocol(0, "hello world");
        let result = KeyDeriver::compute_invoice_number(&protocol, "");
        assert!(result.is_err());
    }

    #[test]
    fn test_compute_invoice_number_key_id_length_boundary() {
        let protocol = protocol(0, "hello world");
        assert!(KeyDeriver::compute_invoice_number(&protocol, &"x".repeat(800)).is_ok());
        assert!(KeyDeriver::compute_invoice_number(&protocol, &"x".repeat(801)).is_err());
    }

    #[test]
    fn test_normalize_counterparty_self() {
        let priv_key = PrivateKey::from_hex("ff").unwrap();
        let kd = KeyDeriver::new(priv_key.clone());
        let counterparty = Counterparty {
            counterparty_type: CounterpartyType::Self_,
            public_key: None,
        };
        let result = kd.normalize_counterparty(&counterparty).unwrap();
        assert_eq!(result.to_der_hex(), priv_key.to_public_key().to_der_hex());
    }

    #[test]
    fn test_normalize_counterparty_anyone() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("ff").unwrap());
        let counterparty = Counterparty {
            counterparty_type: CounterpartyType::Anyone,
            public_key: None,
        };
        let result = kd.normalize_counterparty(&counterparty).unwrap();
        assert_eq!(
            result.to_der_hex(),
            "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"
        );
    }

    #[test]
    fn test_normalize_counterparty_other_missing_key() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("ff").unwrap());
        let counterparty = Counterparty {
            counterparty_type: CounterpartyType::Other,
            public_key: None,
        };
        assert!(kd.normalize_counterparty(&counterparty).is_err());
    }

    #[test]
    fn test_normalize_counterparty_uninitialized() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("ff").unwrap());
        let counterparty = Counterparty {
            counterparty_type: CounterpartyType::Uninitialized,
            public_key: None,
        };
        assert!(kd.normalize_counterparty(&counterparty).is_err());
    }

    #[test]
    fn test_derive_child_roundtrip() {
        let kd_a = KeyDeriver::new(PrivateKey::from_hex("aa").unwrap());
        let kd_b = KeyDeriver::new(PrivateKey::from_hex("bb").unwrap());
        let counterparty_a = other(kd_a.identity_key());
        let counterparty_b = other(kd_b.identity_key());

        let protocol = protocol(2, "test derivation");
        let key_id = "42";

        let pub_for_self = kd_a
            .derive_public_key(&protocol, key_id, &counterparty_b, true)
            .unwrap();
        let pub_for_other = kd_a
            .derive_public_key(&protocol, key_id, &counterparty_b, false)
            .unwrap();

        let pub_from_b = kd_b
            .derive_public_key(&protocol, key_id, &counterparty_a, false)
            .unwrap();
        assert_eq!(
            pub_for_self.to_der_hex(),
            pub_from_b.to_der_hex(),
            "A.derive_pub(B, for_self=true) should equal B.derive_pub(A, for_self=false)"
        );

        let pub_from_b_self = kd_b
            .derive_public_key(&protocol, key_id, &counterparty_a, true)
            .unwrap();
        assert_eq!(
            pub_for_other.to_der_hex(),
            pub_from_b_self.to_der_hex(),
            "A.derive_pub(B, for_self=false) should equal B.derive_pub(A, for_self=true)"
        );

        // The private key's public half must match the for_self public key.
        let priv_child = kd_a
            .derive_private_key(&protocol, key_id, &counterparty_b)
            .unwrap();
        assert_eq!(
            priv_child.to_public_key().to_der_hex(),
            pub_for_self.to_der_hex()
        );
    }

    #[test]
    fn test_derive_symmetric_key_deterministic() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("abcd").unwrap());
        let protocol = protocol(2, "test symmetric");
        let counterparty = Counterparty {
            counterparty_type: CounterpartyType::Self_,
            public_key: None,
        };
        let key1 = kd
            .derive_symmetric_key(&protocol, "1", &counterparty)
            .unwrap();
        let key2 = kd
            .derive_symmetric_key(&protocol, "1", &counterparty)
            .unwrap();
        assert_eq!(key1.to_hex(), key2.to_hex());
    }

    #[test]
    fn test_derive_symmetric_key_agreed_by_both_parties() {
        let kd_a = KeyDeriver::new(PrivateKey::from_hex("aa").unwrap());
        let kd_b = KeyDeriver::new(PrivateKey::from_hex("bb").unwrap());
        let protocol = protocol(2, "test symmetric");

        let key_a = kd_a
            .derive_symmetric_key(&protocol, "7", &other(kd_b.identity_key()))
            .unwrap();
        let key_b = kd_b
            .derive_symmetric_key(&protocol, "7", &other(kd_a.identity_key()))
            .unwrap();
        assert_eq!(key_a.to_hex(), key_b.to_hex());
    }

    #[test]
    fn test_derive_symmetric_key_anyone_matches_explicit_anyone_key() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("abcd").unwrap());
        let protocol = protocol(1, "test symmetric");
        let anyone = Counterparty {
            counterparty_type: CounterpartyType::Anyone,
            public_key: None,
        };
        let via_anyone = kd.derive_symmetric_key(&protocol, "1", &anyone).unwrap();
        let via_explicit = kd
            .derive_symmetric_key(&protocol, "1", &other(anyone_pubkey()))
            .unwrap();
        assert_eq!(via_anyone.to_hex(), via_explicit.to_hex());
    }

    #[test]
    fn test_reveal_counterparty_secret_rejects_self() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("ff").unwrap());
        let counterparty = Counterparty {
            counterparty_type: CounterpartyType::Self_,
            public_key: None,
        };
        assert!(kd.reveal_counterparty_secret(&counterparty).is_err());
    }

    #[test]
    fn test_reveal_counterparty_secret_rejects_own_key_as_other() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("ff").unwrap());
        let counterparty = other(kd.identity_key());
        assert!(kd.reveal_counterparty_secret(&counterparty).is_err());
    }

    #[test]
    fn test_reveal_counterparty_secret_is_symmetric() {
        let kd_a = KeyDeriver::new(PrivateKey::from_hex("aa").unwrap());
        let kd_b = KeyDeriver::new(PrivateKey::from_hex("bb").unwrap());
        let secret_a = kd_a
            .reveal_counterparty_secret(&other(kd_b.identity_key()))
            .unwrap();
        let secret_b = kd_b
            .reveal_counterparty_secret(&other(kd_a.identity_key()))
            .unwrap();
        assert_eq!(secret_a.to_der_hex(), secret_b.to_der_hex());
    }

    #[test]
    fn test_reveal_specific_secret_is_symmetric() {
        let kd_a = KeyDeriver::new(PrivateKey::from_hex("aa").unwrap());
        let kd_b = KeyDeriver::new(PrivateKey::from_hex("bb").unwrap());
        let protocol = protocol(2, "test reveal");
        let secret_a = kd_a
            .reveal_specific_secret(&other(kd_b.identity_key()), &protocol, "3")
            .unwrap();
        let secret_b = kd_b
            .reveal_specific_secret(&other(kd_a.identity_key()), &protocol, "3")
            .unwrap();
        assert_eq!(secret_a, secret_b);
        assert_eq!(secret_a.len(), 32);
    }
}