rustywallet_electrum/
pinning.rs

1//! SSL certificate pinning for enhanced security.
2//!
3//! This module provides certificate pinning functionality to prevent
4//! man-in-the-middle attacks by verifying server certificates against
5//! known fingerprints.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use rustls::{
11    client::{ServerCertVerified, ServerCertVerifier},
12    Certificate, ServerName,
13};
14use sha2::{Digest, Sha256};
15
16use crate::error::{ElectrumError, Result};
17
18/// Certificate fingerprint (SHA-256 hash of DER-encoded certificate).
19#[derive(Debug, Clone, PartialEq, Eq, Hash)]
20pub struct CertFingerprint([u8; 32]);
21
22impl CertFingerprint {
23    /// Create a fingerprint from raw bytes.
24    pub fn from_bytes(bytes: [u8; 32]) -> Self {
25        Self(bytes)
26    }
27
28    /// Create a fingerprint from hex string.
29    pub fn from_hex(hex: &str) -> Result<Self> {
30        let bytes = hex::decode(hex)
31            .map_err(|e| ElectrumError::TlsError(format!("Invalid hex fingerprint: {}", e)))?;
32        
33        if bytes.len() != 32 {
34            return Err(ElectrumError::TlsError(format!(
35                "Fingerprint must be 32 bytes, got {}",
36                bytes.len()
37            )));
38        }
39
40        let mut arr = [0u8; 32];
41        arr.copy_from_slice(&bytes);
42        Ok(Self(arr))
43    }
44
45    /// Calculate fingerprint from a DER-encoded certificate.
46    pub fn from_certificate(cert_der: &[u8]) -> Self {
47        let mut hasher = Sha256::new();
48        hasher.update(cert_der);
49        let result = hasher.finalize();
50        let mut arr = [0u8; 32];
51        arr.copy_from_slice(&result);
52        Self(arr)
53    }
54
55    /// Get the fingerprint as bytes.
56    pub fn as_bytes(&self) -> &[u8; 32] {
57        &self.0
58    }
59
60    /// Get the fingerprint as hex string.
61    pub fn to_hex(&self) -> String {
62        hex::encode(self.0)
63    }
64}
65
66impl std::fmt::Display for CertFingerprint {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        write!(f, "{}", self.to_hex())
69    }
70}
71
72/// Certificate pin store for multiple servers.
73#[derive(Debug, Clone, Default)]
74pub struct CertPinStore {
75    pins: HashMap<String, Vec<CertFingerprint>>,
76}
77
78impl CertPinStore {
79    /// Create a new empty pin store.
80    pub fn new() -> Self {
81        Self::default()
82    }
83
84    /// Add a certificate pin for a server.
85    ///
86    /// Multiple pins can be added for the same server (for certificate rotation).
87    pub fn add_pin(&mut self, server: impl Into<String>, fingerprint: CertFingerprint) {
88        self.pins
89            .entry(server.into())
90            .or_default()
91            .push(fingerprint);
92    }
93
94    /// Add a certificate pin from hex string.
95    pub fn add_pin_hex(&mut self, server: impl Into<String>, hex: &str) -> Result<()> {
96        let fingerprint = CertFingerprint::from_hex(hex)?;
97        self.add_pin(server, fingerprint);
98        Ok(())
99    }
100
101    /// Check if a certificate is pinned for a server.
102    pub fn verify(&self, server: &str, cert_der: &[u8]) -> bool {
103        let fingerprint = CertFingerprint::from_certificate(cert_der);
104        
105        if let Some(pins) = self.pins.get(server) {
106            pins.contains(&fingerprint)
107        } else {
108            // No pins for this server - allow any certificate
109            true
110        }
111    }
112
113    /// Get all pins for a server.
114    pub fn get_pins(&self, server: &str) -> Option<&[CertFingerprint]> {
115        self.pins.get(server).map(|v| v.as_slice())
116    }
117
118    /// Check if any pins are configured for a server.
119    pub fn has_pins(&self, server: &str) -> bool {
120        self.pins.contains_key(server)
121    }
122
123    /// Remove all pins for a server.
124    pub fn remove_pins(&mut self, server: &str) {
125        self.pins.remove(server);
126    }
127
128    /// Get the number of servers with pins.
129    pub fn server_count(&self) -> usize {
130        self.pins.len()
131    }
132}
133
134/// Certificate verifier with pinning support.
135pub struct PinningVerifier {
136    pin_store: CertPinStore,
137    allow_unpinned: bool,
138}
139
140impl PinningVerifier {
141    /// Create a new pinning verifier.
142    ///
143    /// # Arguments
144    /// * `pin_store` - Store containing certificate pins
145    /// * `allow_unpinned` - If true, allow connections to servers without pins
146    pub fn new(pin_store: CertPinStore, allow_unpinned: bool) -> Self {
147        Self {
148            pin_store,
149            allow_unpinned,
150        }
151    }
152
153    /// Create a verifier that requires all servers to have pins.
154    pub fn strict(pin_store: CertPinStore) -> Self {
155        Self::new(pin_store, false)
156    }
157
158    /// Create a verifier that allows unpinned servers.
159    pub fn permissive(pin_store: CertPinStore) -> Self {
160        Self::new(pin_store, true)
161    }
162}
163
164impl ServerCertVerifier for PinningVerifier {
165    fn verify_server_cert(
166        &self,
167        end_entity: &Certificate,
168        _intermediates: &[Certificate],
169        server_name: &ServerName,
170        _scts: &mut dyn Iterator<Item = &[u8]>,
171        _ocsp_response: &[u8],
172        _now: std::time::SystemTime,
173    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
174        let server = match server_name {
175            ServerName::DnsName(name) => name.as_ref().to_string(),
176            _ => return Err(rustls::Error::General("Invalid server name".into())),
177        };
178
179        // Check if server has pins
180        if !self.pin_store.has_pins(&server) {
181            if self.allow_unpinned {
182                return Ok(ServerCertVerified::assertion());
183            } else {
184                return Err(rustls::Error::General(format!(
185                    "No certificate pins for server: {}",
186                    server
187                )));
188            }
189        }
190
191        // Verify against pins
192        if self.pin_store.verify(&server, &end_entity.0) {
193            Ok(ServerCertVerified::assertion())
194        } else {
195            Err(rustls::Error::General(format!(
196                "Certificate fingerprint mismatch for server: {}",
197                server
198            )))
199        }
200    }
201}
202
203/// Builder for creating TLS config with certificate pinning.
204pub struct PinningConfigBuilder {
205    pin_store: CertPinStore,
206    allow_unpinned: bool,
207}
208
209impl PinningConfigBuilder {
210    /// Create a new builder.
211    pub fn new() -> Self {
212        Self {
213            pin_store: CertPinStore::new(),
214            allow_unpinned: true,
215        }
216    }
217
218    /// Add a certificate pin.
219    pub fn pin(mut self, server: impl Into<String>, fingerprint: CertFingerprint) -> Self {
220        self.pin_store.add_pin(server, fingerprint);
221        self
222    }
223
224    /// Add a certificate pin from hex.
225    pub fn pin_hex(mut self, server: impl Into<String>, hex: &str) -> Result<Self> {
226        self.pin_store.add_pin_hex(server, hex)?;
227        Ok(self)
228    }
229
230    /// Set whether to allow connections to unpinned servers.
231    pub fn allow_unpinned(mut self, allow: bool) -> Self {
232        self.allow_unpinned = allow;
233        self
234    }
235
236    /// Build the TLS configuration.
237    pub fn build(self) -> rustls::ClientConfig {
238        let verifier = PinningVerifier::new(self.pin_store, self.allow_unpinned);
239        
240        rustls::ClientConfig::builder()
241            .with_safe_defaults()
242            .with_custom_certificate_verifier(Arc::new(verifier))
243            .with_no_client_auth()
244    }
245
246    /// Get the pin store.
247    pub fn pin_store(&self) -> &CertPinStore {
248        &self.pin_store
249    }
250}
251
252impl Default for PinningConfigBuilder {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
258/// Known certificate fingerprints for popular Electrum servers.
259pub mod known_pins {
260    use super::CertFingerprint;
261
262    /// Get fingerprint for electrum.blockstream.info (if known).
263    /// Note: These fingerprints may change when certificates are rotated.
264    pub fn blockstream() -> Option<CertFingerprint> {
265        // Certificate fingerprints change over time
266        // Users should verify and update these periodically
267        None
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    // SHA-256 of the empty input.
    const SHA256_EMPTY: &str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
    // SHA-256 of b"abc" (FIPS 180-2 test vector).
    const SHA256_ABC: &str = "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad";

    #[test]
    fn test_fingerprint_from_hex() {
        let fp = CertFingerprint::from_hex(SHA256_EMPTY).unwrap();
        assert_eq!(fp.to_hex(), SHA256_EMPTY);
    }

    #[test]
    fn test_fingerprint_from_hex_rejects_bad_input() {
        // Not hex at all.
        assert!(CertFingerprint::from_hex("zz").is_err());
        // Valid hex, wrong length (2 bytes, and 0 bytes).
        assert!(CertFingerprint::from_hex("abcd").is_err());
        assert!(CertFingerprint::from_hex("").is_err());
    }

    #[test]
    fn test_fingerprint_from_certificate() {
        // Pin the digest to known SHA-256 vectors rather than just its length.
        assert_eq!(CertFingerprint::from_certificate(b"").to_hex(), SHA256_EMPTY);
        assert_eq!(CertFingerprint::from_certificate(b"abc").to_hex(), SHA256_ABC);
    }

    #[test]
    fn test_fingerprint_display_matches_hex() {
        let fp = CertFingerprint::from_certificate(b"abc");
        assert_eq!(fp.to_string(), fp.to_hex());
    }

    #[test]
    fn test_pin_store() {
        let mut store = CertPinStore::new();
        let fp = CertFingerprint::from_bytes([0u8; 32]);

        store.add_pin("server.example.com", fp.clone());

        assert!(store.has_pins("server.example.com"));
        assert!(!store.has_pins("other.example.com"));

        let pins = store.get_pins("server.example.com").unwrap();
        assert_eq!(pins.len(), 1);
        assert_eq!(pins[0], fp);
    }

    #[test]
    fn test_pin_store_verify() {
        let mut store = CertPinStore::new();
        let cert_der = b"test certificate";
        let fp = CertFingerprint::from_certificate(cert_der);

        store.add_pin("server.example.com", fp);

        assert!(store.verify("server.example.com", cert_der));
        assert!(!store.verify("server.example.com", b"wrong cert"));
        // Unpinned server allows any cert
        assert!(store.verify("unpinned.example.com", b"any cert"));
    }

    #[test]
    fn test_pin_store_rotation() {
        let mut store = CertPinStore::new();
        store.add_pin("server.example.com", CertFingerprint::from_certificate(b"old cert"));
        store.add_pin("server.example.com", CertFingerprint::from_certificate(b"new cert"));

        assert_eq!(store.server_count(), 1);
        assert!(store.verify("server.example.com", b"old cert"));
        assert!(store.verify("server.example.com", b"new cert"));
        assert!(!store.verify("server.example.com", b"other cert"));
    }

    #[test]
    fn test_pin_store_remove() {
        let mut store = CertPinStore::new();
        store.add_pin("server.example.com", CertFingerprint::from_bytes([1u8; 32]));
        store.remove_pins("server.example.com");

        assert!(!store.has_pins("server.example.com"));
        assert!(store.get_pins("server.example.com").is_none());
        assert_eq!(store.server_count(), 0);
    }
}