Skip to main content

bsv/wallet/
key_deriver.rs

1//! KeyDeriver: Type-42 key derivation for the wallet module.
2//!
3//! Implements BRC-42 key derivation using a root private key,
4//! supporting derivation of private keys, public keys, symmetric keys,
5//! and key linkage revelation.
6
7use crate::primitives::hash::sha256_hmac;
8use crate::primitives::private_key::PrivateKey;
9use crate::primitives::public_key::PublicKey;
10use crate::primitives::symmetric_key::SymmetricKey;
11use crate::wallet::error::WalletError;
12use crate::wallet::types::{anyone_pubkey, Counterparty, CounterpartyType, Protocol};
13
14/// KeyDeriver derives various types of keys using a root private key.
15///
16/// Supports deriving public and private keys, symmetric keys, and
17/// revealing key linkages, all using BRC-42 Type-42 derivation.
18pub struct KeyDeriver {
19    root_key: PrivateKey,
20}
21
22impl KeyDeriver {
23    /// Create a new KeyDeriver from a root private key.
24    pub fn new(private_key: PrivateKey) -> Self {
25        KeyDeriver {
26            root_key: private_key,
27        }
28    }
29
30    /// Create a KeyDeriver using the special "anyone" key (PrivateKey(1)).
31    pub fn new_anyone() -> Self {
32        KeyDeriver {
33            root_key: crate::wallet::types::anyone_private_key(),
34        }
35    }
36
37    /// Returns a reference to the root private key.
38    pub fn root_key(&self) -> &PrivateKey {
39        &self.root_key
40    }
41
42    /// Returns the public key corresponding to the root private key.
43    pub fn identity_key(&self) -> PublicKey {
44        self.root_key.to_public_key()
45    }
46
47    /// Returns the identity key as a compressed DER hex string.
48    pub fn identity_key_hex(&self) -> String {
49        self.identity_key().to_der_hex()
50    }
51
52    /// Derive a private key for the given protocol, key ID, and counterparty.
53    pub fn derive_private_key(
54        &self,
55        protocol: &Protocol,
56        key_id: &str,
57        counterparty: &Counterparty,
58    ) -> Result<PrivateKey, WalletError> {
59        let counterparty_pubkey = self.normalize_counterparty(counterparty)?;
60        let invoice_number = Self::compute_invoice_number(protocol, key_id)?;
61        let child = self
62            .root_key
63            .derive_child(&counterparty_pubkey, &invoice_number)?;
64        Ok(child)
65    }
66
67    /// Derive a public key for the given protocol, key ID, and counterparty.
68    ///
69    /// If `for_self` is true, derives the private child key first and returns
70    /// its public key. If false, derives directly on the counterparty's public key.
71    pub fn derive_public_key(
72        &self,
73        protocol: &Protocol,
74        key_id: &str,
75        counterparty: &Counterparty,
76        for_self: bool,
77    ) -> Result<PublicKey, WalletError> {
78        let counterparty_pubkey = self.normalize_counterparty(counterparty)?;
79        let invoice_number = Self::compute_invoice_number(protocol, key_id)?;
80
81        if for_self {
82            let priv_child = self
83                .root_key
84                .derive_child(&counterparty_pubkey, &invoice_number)?;
85            Ok(priv_child.to_public_key())
86        } else {
87            let pub_child = counterparty_pubkey.derive_child(&self.root_key, &invoice_number)?;
88            Ok(pub_child)
89        }
90    }
91
92    /// Derive a symmetric key from the ECDH shared secret of the derived
93    /// private and public keys.
94    ///
95    /// The symmetric key is the x-coordinate of the shared secret point.
96    pub fn derive_symmetric_key(
97        &self,
98        protocol: &Protocol,
99        key_id: &str,
100        counterparty: &Counterparty,
101    ) -> Result<SymmetricKey, WalletError> {
102        // If counterparty is Anyone, treat as Other with anyone pubkey
103        let effective_counterparty = if counterparty.counterparty_type == CounterpartyType::Anyone {
104            Counterparty {
105                counterparty_type: CounterpartyType::Other,
106                public_key: Some(anyone_pubkey()),
107            }
108        } else {
109            counterparty.clone()
110        };
111
112        let derived_pub =
113            self.derive_public_key(protocol, key_id, &effective_counterparty, false)?;
114        let derived_priv = self.derive_private_key(protocol, key_id, &effective_counterparty)?;
115
116        let shared_secret = derived_priv.derive_shared_secret(&derived_pub)?;
117        let x_bytes = shared_secret
118            .x
119            .to_array(crate::primitives::big_number::Endian::Big, Some(32));
120        let sym_key = SymmetricKey::from_bytes(&x_bytes)?;
121        Ok(sym_key)
122    }
123
124    /// Reveal the counterparty shared secret as a public key point.
125    ///
126    /// Cannot be used for counterparty type "self".
127    pub fn reveal_counterparty_secret(
128        &self,
129        counterparty: &Counterparty,
130    ) -> Result<PublicKey, WalletError> {
131        if counterparty.counterparty_type == CounterpartyType::Self_ {
132            return Err(WalletError::InvalidParameter(
133                "counterparty secrets cannot be revealed for counterparty=self".to_string(),
134            ));
135        }
136
137        let counterparty_pubkey = self.normalize_counterparty(counterparty)?;
138
139        // Double-check: verify it is not actually self
140        let self_pub = self.root_key.to_public_key();
141        let key_derived_by_self = self.root_key.derive_child(&self_pub, "test")?;
142        let key_derived_by_counterparty =
143            self.root_key.derive_child(&counterparty_pubkey, "test")?;
144
145        if key_derived_by_self.to_bytes() == key_derived_by_counterparty.to_bytes() {
146            return Err(WalletError::InvalidParameter(
147                "counterparty secrets cannot be revealed if counterparty key is self".to_string(),
148            ));
149        }
150
151        let shared_secret = self.root_key.derive_shared_secret(&counterparty_pubkey)?;
152        Ok(PublicKey::from_point(shared_secret))
153    }
154
155    /// Reveal a specific secret for the given protocol and key ID.
156    ///
157    /// Computes HMAC-SHA256 of the shared secret (compressed) and the
158    /// invoice number string.
159    pub fn reveal_specific_secret(
160        &self,
161        counterparty: &Counterparty,
162        protocol: &Protocol,
163        key_id: &str,
164    ) -> Result<Vec<u8>, WalletError> {
165        let counterparty_pubkey = self.normalize_counterparty(counterparty)?;
166        let shared_secret = self.root_key.derive_shared_secret(&counterparty_pubkey)?;
167        let invoice_number = Self::compute_invoice_number(protocol, key_id)?;
168        let shared_secret_compressed = shared_secret.to_der(true);
169        let hmac = sha256_hmac(&shared_secret_compressed, invoice_number.as_bytes());
170        Ok(hmac.to_vec())
171    }
172
173    /// Normalize a Counterparty to a concrete PublicKey.
174    fn normalize_counterparty(
175        &self,
176        counterparty: &Counterparty,
177    ) -> Result<PublicKey, WalletError> {
178        match counterparty.counterparty_type {
179            CounterpartyType::Self_ => Ok(self.root_key.to_public_key()),
180            CounterpartyType::Anyone => Ok(anyone_pubkey()),
181            CounterpartyType::Other => counterparty.public_key.clone().ok_or_else(|| {
182                WalletError::InvalidParameter(
183                    "counterparty public key required for type Other".to_string(),
184                )
185            }),
186            CounterpartyType::Uninitialized => Err(WalletError::InvalidParameter(
187                "counterparty type is uninitialized".to_string(),
188            )),
189        }
190    }
191
192    /// Compute the invoice number string from protocol and key ID.
193    ///
194    /// Format: "{security_level}-{protocol_name}-{key_id}"
195    /// Validates security level (0-2), protocol name (5-400 chars, lowercase
196    /// alphanumeric + spaces, no consecutive spaces, must not end with " protocol"),
197    /// and key ID (1-800 chars).
198    fn compute_invoice_number(protocol: &Protocol, key_id: &str) -> Result<String, WalletError> {
199        // Validate security level
200        if protocol.security_level > 2 {
201            return Err(WalletError::InvalidParameter(
202                "protocol security level must be 0, 1, or 2".to_string(),
203            ));
204        }
205
206        // Validate key ID
207        if key_id.is_empty() {
208            return Err(WalletError::InvalidParameter(
209                "key IDs must be 1 character or more".to_string(),
210            ));
211        }
212        if key_id.len() > 800 {
213            return Err(WalletError::InvalidParameter(
214                "key IDs must be 800 characters or less".to_string(),
215            ));
216        }
217
218        // Validate protocol name
219        let protocol_name = protocol.protocol.trim().to_lowercase();
220        if protocol_name.len() < 5 {
221            return Err(WalletError::InvalidParameter(
222                "protocol names must be 5 characters or more".to_string(),
223            ));
224        }
225        if protocol_name.len() > 400 {
226            if protocol_name.starts_with("specific linkage revelation ") {
227                if protocol_name.len() > 430 {
228                    return Err(WalletError::InvalidParameter(
229                        "specific linkage revelation protocol names must be 430 characters or less"
230                            .to_string(),
231                    ));
232                }
233            } else {
234                return Err(WalletError::InvalidParameter(
235                    "protocol names must be 400 characters or less".to_string(),
236                ));
237            }
238        }
239        if protocol_name.contains("  ") {
240            return Err(WalletError::InvalidParameter(
241                "protocol names cannot contain multiple consecutive spaces".to_string(),
242            ));
243        }
244        if !protocol_name
245            .chars()
246            .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == ' ')
247        {
248            return Err(WalletError::InvalidParameter(
249                "protocol names can only contain letters, numbers and spaces".to_string(),
250            ));
251        }
252        if protocol_name.ends_with(" protocol") {
253            return Err(WalletError::InvalidParameter(
254                "no need to end your protocol name with \" protocol\"".to_string(),
255            ));
256        }
257
258        Ok(format!(
259            "{}-{}-{}",
260            protocol.security_level, protocol_name, key_id
261        ))
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wallet::types::CounterpartyType;

    /// Compressed public key of PrivateKey(1), i.e. the secp256k1 generator G.
    const GENERATOR_PUBKEY_HEX: &str =
        "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798";

    /// Counterparty of the given type with no explicit public key.
    fn counterparty_of(counterparty_type: CounterpartyType) -> Counterparty {
        Counterparty {
            counterparty_type,
            public_key: None,
        }
    }

    /// `Other` counterparty identified by `public_key`.
    fn other_counterparty(public_key: PublicKey) -> Counterparty {
        Counterparty {
            counterparty_type: CounterpartyType::Other,
            public_key: Some(public_key),
        }
    }

    #[test]
    fn test_identity_key_known_vector() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("1").unwrap());
        assert_eq!(kd.identity_key_hex(), GENERATOR_PUBKEY_HEX);
    }

    #[test]
    fn test_anyone_deriver() {
        // The anyone key is PrivateKey(1), whose public key is G.
        let kd = KeyDeriver::new_anyone();
        assert_eq!(kd.identity_key_hex(), GENERATOR_PUBKEY_HEX);
    }

    #[test]
    fn test_compute_invoice_number_valid() {
        let protocol = Protocol {
            security_level: 2,
            protocol: "hello world".to_string(),
        };
        let result = KeyDeriver::compute_invoice_number(&protocol, "1");
        assert_eq!(result.unwrap(), "2-hello world-1");
    }

    #[test]
    fn test_compute_invoice_number_security_level_too_high() {
        let protocol = Protocol {
            security_level: 3,
            protocol: "hello world".to_string(),
        };
        assert!(KeyDeriver::compute_invoice_number(&protocol, "1").is_err());
    }

    #[test]
    fn test_compute_invoice_number_protocol_too_short() {
        let protocol = Protocol {
            security_level: 0,
            protocol: "abcd".to_string(),
        };
        assert!(KeyDeriver::compute_invoice_number(&protocol, "1").is_err());
    }

    #[test]
    fn test_compute_invoice_number_protocol_too_long() {
        let protocol = Protocol {
            security_level: 0,
            protocol: "a".repeat(401),
        };
        assert!(KeyDeriver::compute_invoice_number(&protocol, "1").is_err());
    }

    #[test]
    fn test_compute_invoice_number_specific_linkage_revelation_length() {
        // The 30-byte prefix plus 400 bytes totals 430, the allowed maximum.
        let at_limit = Protocol {
            security_level: 0,
            protocol: format!("specific linkage revelation 2 {}", "a".repeat(400)),
        };
        let invoice = KeyDeriver::compute_invoice_number(&at_limit, "1").unwrap();
        assert!(invoice.starts_with("0-specific linkage revelation 2 "));

        let over_limit = Protocol {
            security_level: 0,
            protocol: format!("specific linkage revelation 2 {}", "a".repeat(401)),
        };
        assert!(KeyDeriver::compute_invoice_number(&over_limit, "1").is_err());
    }

    #[test]
    fn test_compute_invoice_number_consecutive_spaces() {
        let protocol = Protocol {
            security_level: 0,
            protocol: "hello  world".to_string(),
        };
        assert!(KeyDeriver::compute_invoice_number(&protocol, "1").is_err());
    }

    #[test]
    fn test_compute_invoice_number_ends_with_protocol() {
        let protocol = Protocol {
            security_level: 0,
            protocol: "my cool protocol".to_string(),
        };
        assert!(KeyDeriver::compute_invoice_number(&protocol, "1").is_err());
    }

    #[test]
    fn test_compute_invoice_number_uppercase_is_lowercased() {
        // Uppercase input is accepted because the name is lowercased first.
        let protocol = Protocol {
            security_level: 0,
            protocol: "Hello World".to_string(),
        };
        let result = KeyDeriver::compute_invoice_number(&protocol, "1");
        assert_eq!(result.unwrap(), "0-hello world-1");
    }

    #[test]
    fn test_compute_invoice_number_trims_whitespace() {
        let protocol = Protocol {
            security_level: 1,
            protocol: "  hello world  ".to_string(),
        };
        let result = KeyDeriver::compute_invoice_number(&protocol, "1");
        assert_eq!(result.unwrap(), "1-hello world-1");
    }

    #[test]
    fn test_compute_invoice_number_special_chars_rejected() {
        let protocol = Protocol {
            security_level: 0,
            protocol: "hello-world".to_string(),
        };
        assert!(KeyDeriver::compute_invoice_number(&protocol, "1").is_err());
    }

    #[test]
    fn test_compute_invoice_number_key_id_empty() {
        let protocol = Protocol {
            security_level: 0,
            protocol: "hello world".to_string(),
        };
        assert!(KeyDeriver::compute_invoice_number(&protocol, "").is_err());
    }

    #[test]
    fn test_compute_invoice_number_key_id_length_boundary() {
        let protocol = Protocol {
            security_level: 0,
            protocol: "hello world".to_string(),
        };
        assert!(KeyDeriver::compute_invoice_number(&protocol, &"x".repeat(800)).is_ok());
        assert!(KeyDeriver::compute_invoice_number(&protocol, &"x".repeat(801)).is_err());
    }

    #[test]
    fn test_normalize_counterparty_self() {
        let priv_key = PrivateKey::from_hex("ff").unwrap();
        let kd = KeyDeriver::new(priv_key.clone());
        let result = kd
            .normalize_counterparty(&counterparty_of(CounterpartyType::Self_))
            .unwrap();
        assert_eq!(result.to_der_hex(), priv_key.to_public_key().to_der_hex());
    }

    #[test]
    fn test_normalize_counterparty_anyone() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("ff").unwrap());
        let result = kd
            .normalize_counterparty(&counterparty_of(CounterpartyType::Anyone))
            .unwrap();
        // Anyone = PrivateKey(1).to_public_key() = G
        assert_eq!(result.to_der_hex(), GENERATOR_PUBKEY_HEX);
    }

    #[test]
    fn test_normalize_counterparty_other_missing_key() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("ff").unwrap());
        let result = kd.normalize_counterparty(&counterparty_of(CounterpartyType::Other));
        assert!(result.is_err());
    }

    #[test]
    fn test_derive_child_roundtrip() {
        // Cross-party property of BRC-42 derivation:
        //   A.derive_public_key(B, for_self = true)  == B.derive_public_key(A, for_self = false)
        //   A.derive_public_key(B, for_self = false) == B.derive_public_key(A, for_self = true)
        let priv_a = PrivateKey::from_hex("aa").unwrap();
        let priv_b = PrivateKey::from_hex("bb").unwrap();
        let pub_a = priv_a.to_public_key();
        let pub_b = priv_b.to_public_key();
        let kd_a = KeyDeriver::new(priv_a);
        let kd_b = KeyDeriver::new(priv_b);

        let protocol = Protocol {
            security_level: 2,
            protocol: "test derivation".to_string(),
        };
        let key_id = "42";
        let counterparty_a = other_counterparty(pub_a);
        let counterparty_b = other_counterparty(pub_b);

        let a_own = kd_a
            .derive_public_key(&protocol, key_id, &counterparty_b, true)
            .unwrap();
        let a_for_b = kd_a
            .derive_public_key(&protocol, key_id, &counterparty_b, false)
            .unwrap();
        let b_own = kd_b
            .derive_public_key(&protocol, key_id, &counterparty_a, true)
            .unwrap();
        let b_for_a = kd_b
            .derive_public_key(&protocol, key_id, &counterparty_a, false)
            .unwrap();

        assert_eq!(
            a_own.to_der_hex(),
            b_for_a.to_der_hex(),
            "A.derive_pub(B, for_self=true) should equal B.derive_pub(A, for_self=false)"
        );
        assert_eq!(
            a_for_b.to_der_hex(),
            b_own.to_der_hex(),
            "A.derive_pub(B, for_self=false) should equal B.derive_pub(A, for_self=true)"
        );
    }

    #[test]
    fn test_derive_symmetric_key_deterministic() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("abcd").unwrap());
        let protocol = Protocol {
            security_level: 2,
            protocol: "test symmetric".to_string(),
        };
        let counterparty = counterparty_of(CounterpartyType::Self_);
        let key1 = kd
            .derive_symmetric_key(&protocol, "1", &counterparty)
            .unwrap();
        let key2 = kd
            .derive_symmetric_key(&protocol, "1", &counterparty)
            .unwrap();
        assert_eq!(key1.to_hex(), key2.to_hex());
    }

    #[test]
    fn test_derive_symmetric_key_anyone_matches_explicit_anyone_pubkey() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("abcd").unwrap());
        let protocol = Protocol {
            security_level: 2,
            protocol: "test symmetric".to_string(),
        };
        let via_type = kd
            .derive_symmetric_key(&protocol, "1", &counterparty_of(CounterpartyType::Anyone))
            .unwrap();
        let via_key = kd
            .derive_symmetric_key(
                &protocol,
                "1",
                &other_counterparty(crate::wallet::types::anyone_pubkey()),
            )
            .unwrap();
        assert_eq!(via_type.to_hex(), via_key.to_hex());
    }

    #[test]
    fn test_derive_symmetric_key_agrees_between_parties() {
        let priv_a = PrivateKey::from_hex("aa").unwrap();
        let priv_b = PrivateKey::from_hex("bb").unwrap();
        let pub_a = priv_a.to_public_key();
        let pub_b = priv_b.to_public_key();
        let kd_a = KeyDeriver::new(priv_a);
        let kd_b = KeyDeriver::new(priv_b);
        let protocol = Protocol {
            security_level: 2,
            protocol: "test symmetric".to_string(),
        };

        let key_ab = kd_a
            .derive_symmetric_key(&protocol, "7", &other_counterparty(pub_b))
            .unwrap();
        let key_ba = kd_b
            .derive_symmetric_key(&protocol, "7", &other_counterparty(pub_a))
            .unwrap();
        assert_eq!(key_ab.to_hex(), key_ba.to_hex());
    }

    #[test]
    fn test_reveal_counterparty_secret_rejects_self() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("ff").unwrap());
        let result = kd.reveal_counterparty_secret(&counterparty_of(CounterpartyType::Self_));
        assert!(result.is_err());
    }

    #[test]
    fn test_reveal_counterparty_secret_rejects_own_key_as_other() {
        let priv_key = PrivateKey::from_hex("ff").unwrap();
        let kd = KeyDeriver::new(priv_key.clone());
        let result = kd.reveal_counterparty_secret(&other_counterparty(priv_key.to_public_key()));
        assert!(result.is_err());
    }

    #[test]
    fn test_reveal_counterparty_secret_agrees_between_parties() {
        let priv_a = PrivateKey::from_hex("aa").unwrap();
        let priv_b = PrivateKey::from_hex("bb").unwrap();
        let pub_a = priv_a.to_public_key();
        let pub_b = priv_b.to_public_key();
        let kd_a = KeyDeriver::new(priv_a);
        let kd_b = KeyDeriver::new(priv_b);

        let secret_ab = kd_a
            .reveal_counterparty_secret(&other_counterparty(pub_b))
            .unwrap();
        let secret_ba = kd_b
            .reveal_counterparty_secret(&other_counterparty(pub_a))
            .unwrap();
        assert_eq!(secret_ab.to_der_hex(), secret_ba.to_der_hex());
    }

    #[test]
    fn test_reveal_specific_secret_agrees_between_parties() {
        let priv_a = PrivateKey::from_hex("aa").unwrap();
        let priv_b = PrivateKey::from_hex("bb").unwrap();
        let pub_a = priv_a.to_public_key();
        let pub_b = priv_b.to_public_key();
        let kd_a = KeyDeriver::new(priv_a);
        let kd_b = KeyDeriver::new(priv_b);
        let protocol = Protocol {
            security_level: 2,
            protocol: "test reveal".to_string(),
        };

        let secret_ab = kd_a
            .reveal_specific_secret(&other_counterparty(pub_b), &protocol, "9")
            .unwrap();
        let secret_ba = kd_b
            .reveal_specific_secret(&other_counterparty(pub_a), &protocol, "9")
            .unwrap();
        // HMAC-SHA256 output is 32 bytes.
        assert_eq!(secret_ab.len(), 32);
        assert_eq!(secret_ab, secret_ba);
    }

    #[test]
    fn test_reveal_specific_secret_rejects_invalid_protocol() {
        let kd = KeyDeriver::new(PrivateKey::from_hex("ff").unwrap());
        let protocol = Protocol {
            security_level: 3,
            protocol: "hello world".to_string(),
        };
        let result =
            kd.reveal_specific_secret(&counterparty_of(CounterpartyType::Self_), &protocol, "1");
        assert!(result.is_err());
    }
}