use crate::error::BitcoinError;
use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
use chacha20poly1305::{
ChaCha20Poly1305, Nonce,
aead::{Aead, KeyInit},
};
use serde::{Deserialize, Serialize};
/// Tunable settings for a `V2Transport` session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct V2TransportConfig {
/// Not read by any code in this file.
/// NOTE(review): presumably selects "try v2 first, fall back to v1" behaviour — confirm with callers.
pub opportunistic: bool,
/// Largest plaintext, in bytes, that `V2Transport::encrypt_message` accepts.
/// The check is made on the caller's plaintext, before any padding is added.
pub max_message_size: usize,
/// When `true`, `V2Transport::encrypt_message` appends zero bytes to the plaintext before encrypting.
pub enable_padding: bool,
/// Multiple, in bytes, that the padded plaintext length is rounded up to.
pub padding_granularity: usize,
}
impl Default for V2TransportConfig {
fn default() -> Self {
Self {
opportunistic: true,
max_message_size: 4_000_000, enable_padding: true,
padding_granularity: 64,
}
}
}
/// Lifecycle stage of a `V2Transport`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
/// State of a freshly constructed transport; no handshake message has been produced or consumed.
Init,
/// A handshake has been started but session keys have not been derived yet.
HandshakeInProgress,
/// Session keys have been derived; `encrypt_message` and `decrypt_message` are permitted.
Established,
/// Set by `V2Transport::close`, which also discards the session keys and the remote public key.
Closed,
}
/// Directional AEAD state for one established session.
pub struct SessionKeys {
/// Cipher keyed for messages this side sends.
send_cipher: ChaCha20Poly1305,
/// Cipher keyed for messages this side receives.
recv_cipher: ChaCha20Poly1305,
/// Counter from which the next outgoing nonce is built; starts at 0.
send_counter: u64,
/// Counter from which the next expected incoming nonce is built; starts at 0.
recv_counter: u64,
}
impl std::fmt::Debug for SessionKeys {
    /// Formats the two message counters only; the cipher fields are omitted
    /// from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SessionKeys");
        out.field("send_counter", &self.send_counter);
        out.field("recv_counter", &self.recv_counter);
        out.finish()
    }
}
impl SessionKeys {
    /// Derives both directional ciphers from a 32-byte ECDH secret.
    ///
    /// Each key is `SHA-256(shared_secret || label)`, with one fixed label per
    /// traffic direction. `is_initiator` decides which of the two keys is used
    /// for sending and which for receiving, so that one side's send key is the
    /// other side's receive key. Both counters start at zero.
    pub fn derive_from_ecdh(shared_secret: &[u8; 32], is_initiator: bool) -> Self {
        let initiator_to_responder =
            Self::directional_key(shared_secret, b"bip324_initiator_to_responder");
        let responder_to_initiator =
            Self::directional_key(shared_secret, b"bip324_responder_to_initiator");

        let (outbound, inbound) = match is_initiator {
            true => (initiator_to_responder, responder_to_initiator),
            false => (responder_to_initiator, initiator_to_responder),
        };

        Self {
            send_cipher: ChaCha20Poly1305::new(&outbound.into()),
            recv_cipher: ChaCha20Poly1305::new(&inbound.into()),
            send_counter: 0,
            recv_counter: 0,
        }
    }

    /// Hashes the shared secret followed by a direction label into a 32-byte key.
    fn directional_key(shared_secret: &[u8; 32], label: &[u8]) -> [u8; 32] {
        use bitcoin::hashes::{sha256, Hash, HashEngine};

        let mut engine = sha256::Hash::engine();
        engine.input(shared_secret);
        engine.input(label);
        sha256::Hash::from_engine(engine).to_byte_array()
    }

    /// Lays a counter out as a 12-byte nonce: four zero bytes followed by the
    /// counter in little-endian order.
    fn counter_nonce(counter: u64) -> [u8; 12] {
        let mut nonce = [0u8; 12];
        nonce[4..].copy_from_slice(&counter.to_le_bytes());
        nonce
    }

    /// Returns the nonce for the current send counter, then advances the counter.
    fn next_send_nonce(&mut self) -> [u8; 12] {
        let nonce = Self::counter_nonce(self.send_counter);
        self.send_counter += 1;
        nonce
    }

    /// Returns the nonce for the current receive counter, then advances the counter.
    fn next_recv_nonce(&mut self) -> [u8; 12] {
        let nonce = Self::counter_nonce(self.recv_counter);
        self.recv_counter += 1;
        nonce
    }
}
/// One endpoint of an encrypted peer connection: holds the local key pair,
/// the handshake state, and (once established) the session keys.
#[derive(Debug)]
pub struct V2Transport {
/// Message-size and padding settings supplied at construction.
config: V2TransportConfig,
/// Secret key generated in `new`; used for the ECDH step of the handshake.
local_privkey: SecretKey,
/// Public key matching `local_privkey`; its serialization is the handshake message.
local_pubkey: PublicKey,
/// Peer's public key once parsed from a handshake message; cleared by `close`.
remote_pubkey: Option<PublicKey>,
/// `Some` after a completed handshake; cleared by `close`.
session_keys: Option<SessionKeys>,
/// Current lifecycle stage.
state: ConnectionState,
/// secp256k1 context used to derive the public key and to perform the ECDH multiplication.
secp: Secp256k1<bitcoin::secp256k1::All>,
}
impl V2Transport {
pub fn new(config: V2TransportConfig) -> Result<Self, BitcoinError> {
use bitcoin::secp256k1::rand::rngs::OsRng;
let secp = Secp256k1::new();
let local_privkey = SecretKey::new(&mut OsRng);
let local_pubkey = PublicKey::from_secret_key(&secp, &local_privkey);
Ok(Self {
config,
local_privkey,
local_pubkey,
remote_pubkey: None,
session_keys: None,
state: ConnectionState::Init,
secp,
})
}
pub fn local_public_key(&self) -> PublicKey {
self.local_pubkey
}
pub fn state(&self) -> ConnectionState {
self.state
}
pub fn initiate_handshake(&mut self) -> Result<Vec<u8>, BitcoinError> {
if self.state != ConnectionState::Init {
return Err(BitcoinError::InvalidAddress(
"Handshake already in progress or completed".to_string(),
));
}
self.state = ConnectionState::HandshakeInProgress;
Ok(self.local_pubkey.serialize().to_vec())
}
pub fn respond_handshake(&mut self, initiator_pubkey: &[u8]) -> Result<Vec<u8>, BitcoinError> {
if self.state != ConnectionState::Init {
return Err(BitcoinError::InvalidAddress(
"Invalid state for handshake response".to_string(),
));
}
let remote_pk = PublicKey::from_slice(initiator_pubkey)
.map_err(|e| BitcoinError::InvalidAddress(format!("Invalid public key: {}", e)))?;
self.remote_pubkey = Some(remote_pk);
self.state = ConnectionState::HandshakeInProgress;
self.complete_handshake(false)?;
Ok(self.local_pubkey.serialize().to_vec())
}
pub fn finalize_handshake(&mut self, responder_pubkey: &[u8]) -> Result<(), BitcoinError> {
if self.state != ConnectionState::HandshakeInProgress {
return Err(BitcoinError::InvalidAddress(
"No handshake in progress".to_string(),
));
}
let remote_pk = PublicKey::from_slice(responder_pubkey)
.map_err(|e| BitcoinError::InvalidAddress(format!("Invalid public key: {}", e)))?;
self.remote_pubkey = Some(remote_pk);
self.complete_handshake(true)
}
fn complete_handshake(&mut self, is_initiator: bool) -> Result<(), BitcoinError> {
let remote_pk = self
.remote_pubkey
.ok_or_else(|| BitcoinError::InvalidAddress("Remote public key not set".to_string()))?;
let shared_point = remote_pk
.mul_tweak(&self.secp, &self.local_privkey.into())
.map_err(|e| BitcoinError::InvalidAddress(format!("ECDH failed: {}", e)))?;
let shared_secret = shared_point.serialize();
let mut secret_bytes = [0u8; 32];
secret_bytes.copy_from_slice(&shared_secret[1..33]);
self.session_keys = Some(SessionKeys::derive_from_ecdh(&secret_bytes, is_initiator));
self.state = ConnectionState::Established;
Ok(())
}
pub fn encrypt_message(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, BitcoinError> {
if self.state != ConnectionState::Established {
return Err(BitcoinError::InvalidAddress(
"Connection not established".to_string(),
));
}
if plaintext.len() > self.config.max_message_size {
return Err(BitcoinError::InvalidAddress(
"Message too large".to_string(),
));
}
let padding_len = if self.config.enable_padding {
self.calculate_padding(plaintext.len())
} else {
0
};
let keys = self.session_keys.as_mut().ok_or_else(|| {
BitcoinError::InvalidAddress("Session keys not available".to_string())
})?;
let nonce_bytes = keys.next_send_nonce();
let nonce = Nonce::from_slice(&nonce_bytes);
let mut padded_plaintext = plaintext.to_vec();
padded_plaintext.resize(plaintext.len() + padding_len, 0);
let ciphertext = keys
.send_cipher
.encrypt(nonce, padded_plaintext.as_ref())
.map_err(|e| BitcoinError::InvalidAddress(format!("Encryption failed: {}", e)))?;
let mut result = Vec::with_capacity(nonce_bytes.len() + ciphertext.len());
result.extend_from_slice(&nonce_bytes);
result.extend_from_slice(&ciphertext);
Ok(result)
}
pub fn decrypt_message(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, BitcoinError> {
if self.state != ConnectionState::Established {
return Err(BitcoinError::InvalidAddress(
"Connection not established".to_string(),
));
}
if ciphertext.len() < 28 {
return Err(BitcoinError::InvalidAddress(
"Ciphertext too short".to_string(),
));
}
let keys = self.session_keys.as_mut().ok_or_else(|| {
BitcoinError::InvalidAddress("Session keys not available".to_string())
})?;
let nonce_bytes = keys.next_recv_nonce();
let received_nonce = &ciphertext[..12];
if received_nonce != nonce_bytes {
return Err(BitcoinError::InvalidAddress(
"Nonce mismatch - possible replay attack".to_string(),
));
}
let nonce = Nonce::from_slice(received_nonce);
let plaintext = keys
.recv_cipher
.decrypt(nonce, &ciphertext[12..])
.map_err(|e| BitcoinError::InvalidAddress(format!("Decryption failed: {}", e)))?;
Ok(plaintext)
}
fn calculate_padding(&self, message_len: usize) -> usize {
if !self.config.enable_padding {
return 0;
}
let granularity = self.config.padding_granularity;
let remainder = message_len % granularity;
if remainder == 0 {
0
} else {
granularity - remainder
}
}
pub fn close(&mut self) {
self.state = ConnectionState::Closed;
self.session_keys = None;
self.remote_pubkey = None;
}
}
/// Running counters describing transport activity; all start at zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct V2TransportStats {
/// Number of `record_send` calls.
pub messages_sent: u64,
/// Number of `record_receive` calls.
pub messages_received: u64,
/// Sum of the `message_size` values passed to `record_send`.
pub bytes_sent: u64,
/// Sum of the `message_size` values passed to `record_receive`.
pub bytes_received: u64,
/// Sum of the `padding` values passed to `record_send`.
pub padding_bytes: u64,
/// Number of `record_handshake` calls, successful or not.
pub handshake_attempts: u64,
/// Number of `record_handshake` calls made with `successful == true`.
pub successful_connections: u64,
/// Number of `record_fallback` calls.
pub v1_fallbacks: u64,
}
impl V2TransportStats {
    /// Returns a record with every counter at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Counts one sent message of `message_size` bytes, `padding` of which
    /// are attributed to padding.
    pub fn record_send(&mut self, message_size: usize, padding: usize) {
        let size = message_size as u64;
        let pad = padding as u64;
        self.messages_sent += 1;
        self.bytes_sent += size;
        self.padding_bytes += pad;
    }

    /// Counts one received message of `message_size` bytes.
    pub fn record_receive(&mut self, message_size: usize) {
        let size = message_size as u64;
        self.messages_received += 1;
        self.bytes_received += size;
    }

    /// Counts one handshake attempt and, when `successful`, one successful
    /// connection.
    pub fn record_handshake(&mut self, successful: bool) {
        self.handshake_attempts += 1;
        // `true` converts to 1 and `false` to 0.
        self.successful_connections += u64::from(successful);
    }

    /// Counts one fallback to the v1 transport.
    pub fn record_fallback(&mut self) {
        self.v1_fallbacks += 1;
    }

    /// Returns `padding_bytes / bytes_sent`, or `0.0` when nothing has been
    /// sent yet.
    pub fn encryption_overhead_ratio(&self) -> f64 {
        match self.bytes_sent {
            0 => 0.0,
            sent => self.padding_bytes as f64 / sent as f64,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the complete handshake between two fresh transports that share
    /// `config` and returns them as `(initiator, responder)`.
    fn established_pair(config: V2TransportConfig) -> (V2Transport, V2Transport) {
        let mut initiator = V2Transport::new(config.clone()).unwrap();
        let mut responder = V2Transport::new(config).unwrap();
        let hello = initiator.initiate_handshake().unwrap();
        let reply = responder.respond_handshake(&hello).unwrap();
        initiator.finalize_handshake(&reply).unwrap();
        (initiator, responder)
    }

    // The default configuration enables opportunistic mode and padding and
    // caps messages at 4,000,000 bytes.
    #[test]
    fn test_config_default() {
        let defaults = V2TransportConfig::default();
        assert!(defaults.opportunistic);
        assert!(defaults.enable_padding);
        assert_eq!(defaults.max_message_size, 4_000_000);
    }

    // A new transport starts in the `Init` state.
    #[test]
    fn test_transport_creation() {
        let transport = V2Transport::new(V2TransportConfig::default()).unwrap();
        assert_eq!(transport.state(), ConnectionState::Init);
    }

    // Freshly derived keys start with both counters at zero for either role.
    #[test]
    fn test_session_keys_derivation() {
        let secret = [42u8; 32];
        let keys_initiator = SessionKeys::derive_from_ecdh(&secret, true);
        let keys_responder = SessionKeys::derive_from_ecdh(&secret, false);
        assert_eq!(keys_initiator.send_counter, 0);
        assert_eq!(keys_initiator.recv_counter, 0);
        assert_eq!(keys_responder.send_counter, 0);
        assert_eq!(keys_responder.recv_counter, 0);
    }

    // Walks the three handshake steps and checks the state after each one.
    #[test]
    fn test_handshake_flow() {
        let shared_config = V2TransportConfig::default();
        let mut initiator = V2Transport::new(shared_config.clone()).unwrap();
        let mut responder = V2Transport::new(shared_config).unwrap();

        let hello = initiator.initiate_handshake().unwrap();
        assert_eq!(initiator.state(), ConnectionState::HandshakeInProgress);

        let reply = responder.respond_handshake(&hello).unwrap();
        assert_eq!(responder.state(), ConnectionState::Established);

        initiator.finalize_handshake(&reply).unwrap();
        assert_eq!(initiator.state(), ConnectionState::Established);
    }

    // With padding off, a message encrypted by one side decrypts to the same
    // bytes on the other side.
    #[test]
    fn test_encrypt_decrypt() {
        let config = V2TransportConfig {
            enable_padding: false,
            ..Default::default()
        };
        let (mut initiator, mut responder) = established_pair(config);

        let plaintext = b"Hello, Bitcoin!";
        let ciphertext = initiator.encrypt_message(plaintext).unwrap();
        assert!(ciphertext.len() >= plaintext.len() + 28);

        let decrypted = responder.decrypt_message(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    // Padding rounds a length up to the next multiple of the granularity.
    #[test]
    fn test_padding_calculation() {
        let config = V2TransportConfig {
            enable_padding: true,
            padding_granularity: 64,
            ..Default::default()
        };
        let transport = V2Transport::new(config).unwrap();

        for (message_len, expected_padding) in [(0, 0), (64, 0), (65, 63), (100, 28)] {
            assert_eq!(transport.calculate_padding(message_len), expected_padding);
        }
    }

    // The record_* methods accumulate into the matching counters.
    #[test]
    fn test_stats_tracking() {
        let mut stats = V2TransportStats::new();
        stats.record_send(100, 28);
        stats.record_send(200, 56);
        stats.record_receive(150);

        assert_eq!(stats.messages_sent, 2);
        assert_eq!(stats.messages_received, 1);
        assert_eq!(stats.bytes_sent, 300);
        assert_eq!(stats.padding_bytes, 84);
        assert!(stats.encryption_overhead_ratio() > 0.0);
    }

    // Closing moves the transport to `Closed` and leaves no session keys.
    #[test]
    fn test_connection_close() {
        let mut transport = V2Transport::new(V2TransportConfig::default()).unwrap();
        transport.close();
        assert_eq!(transport.state(), ConnectionState::Closed);
        assert!(transport.session_keys.is_none());
    }

    // Encrypting before any handshake has run is rejected.
    #[test]
    fn test_encrypt_before_handshake_fails() {
        let mut transport = V2Transport::new(V2TransportConfig::default()).unwrap();
        assert!(transport.encrypt_message(b"test").is_err());
    }

    // A plaintext longer than `max_message_size` is rejected.
    #[test]
    fn test_message_too_large() {
        let config = V2TransportConfig {
            max_message_size: 100,
            ..Default::default()
        };
        let (mut initiator, _responder) = established_pair(config);

        let oversized = vec![0u8; 200];
        assert!(initiator.encrypt_message(&oversized).is_err());
    }
}