use std::io;
use std::time::Duration;
use num_bigint::BigUint;
use rand::RngExt;
use sha1::{Digest, Sha1};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::timeout;
use super::metainfo::Sha1Hash;
use crate::error::{EngineError, NetworkErrorKind, ProtocolErrorKind, Result};
/// The 768-bit prime `P` of the MSE Diffie-Hellman exchange, big-endian.
///
/// Value from the BitTorrent Message Stream Encryption specification:
/// `0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563`.
/// Every MSE peer uses exactly this value. Any other modulus makes the derived shared secret,
/// and with it every handshake hash and RC4 key, disagree with the peer's.
pub const DH_PRIME: [u8; 96] = [
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xC9, 0x0F, 0xDA, 0xA2, 0x21, 0x68, 0xC2, 0x34,
    0xC4, 0xC6, 0x62, 0x8B, 0x80, 0xDC, 0x1C, 0xD1, 0x29, 0x02, 0x4E, 0x08, 0x8A, 0x67, 0xCC, 0x74,
    0x02, 0x0B, 0xBE, 0xA6, 0x3B, 0x13, 0x9B, 0x22, 0x51, 0x4A, 0x08, 0x79, 0x8E, 0x34, 0x04, 0xDD,
    0xEF, 0x95, 0x19, 0xB3, 0xCD, 0x3A, 0x43, 0x1B, 0x30, 0x2B, 0x0A, 0x6D, 0xF2, 0x5F, 0x14, 0x37,
    0x4F, 0xE1, 0x35, 0x6D, 0x6D, 0x51, 0xC2, 0x45, 0xE4, 0x85, 0xB5, 0x76, 0x62, 0x5E, 0x7E, 0xC6,
    0xF4, 0x4C, 0x42, 0xE9, 0xA6, 0x3A, 0x36, 0x21, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x05, 0x63,
];
/// Generator `G` of the MSE Diffie-Hellman group.
pub const DH_GENERATOR: u64 = 2;
/// Verification constant: 8 zero bytes, sent encrypted so each side can check its keys.
pub const VC: [u8; 8] = [0u8; 8];
/// Upper bound (inclusive) on the length of each MSE padding field (PadA/PadB/PadC/PadD).
pub const MAX_PADDING: usize = 512;
/// Number of initial RC4 keystream bytes thrown away before use ("RC4-drop1024").
pub const RC4_DISCARD: usize = 1024;
/// `crypto_provide`/`crypto_select` bit: payload stays in plaintext after the handshake.
pub const CRYPTO_PLAINTEXT: u32 = 0x01;
/// `crypto_provide`/`crypto_select` bit: payload is RC4-encrypted after the handshake.
pub const CRYPTO_RC4: u32 = 0x02;
/// Time limit applied to each awaited handshake I/O step.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);
/// How outgoing connections use Message Stream Encryption.
///
/// `Disabled` and `Allowed` connect without an MSE handshake. `Preferred` and `Required` run
/// it; `Required` additionally rejects a plaintext outcome.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EncryptionPolicy {
    /// No MSE handshake on outgoing connections.
    Disabled,
    /// Outgoing connections stay plaintext as well.
    Allowed,
    /// Attempt the MSE handshake (the default).
    #[default]
    Preferred,
    /// Attempt the MSE handshake and refuse to continue in plaintext.
    Required,
}

/// Tunables for the MSE handshake.
#[derive(Debug, Clone)]
pub struct MseConfig {
    /// When to attempt the handshake.
    pub policy: EncryptionPolicy,
    /// Offer `CRYPTO_PLAINTEXT` in `crypto_provide`.
    pub allow_plaintext: bool,
    /// Offer `CRYPTO_RC4` in `crypto_provide`.
    pub allow_rc4: bool,
    /// Smallest random padding length sent after our public key.
    pub min_padding: usize,
    /// Largest random padding length sent after our public key.
    pub max_padding: usize,
}

impl Default for MseConfig {
    /// `Preferred` policy, both payload methods offered, padding of `0..=MAX_PADDING` bytes.
    fn default() -> Self {
        Self {
            policy: EncryptionPolicy::Preferred,
            allow_plaintext: true,
            allow_rc4: true,
            min_padding: 0,
            max_padding: MAX_PADDING,
        }
    }
}

impl MseConfig {
    /// Bit set of payload methods to advertise: `CRYPTO_PLAINTEXT` and/or `CRYPTO_RC4`,
    /// or `0` when both are disabled.
    pub fn crypto_provide(&self) -> u32 {
        [
            (self.allow_plaintext, CRYPTO_PLAINTEXT),
            (self.allow_rc4, CRYPTO_RC4),
        ]
        .into_iter()
        .filter_map(|(enabled, bit)| enabled.then_some(bit))
        .fold(0, |mask, bit| mask | bit)
    }
}
/// RC4 stream-cipher state: the 256-byte permutation `S` plus the `i`/`j` indices.
///
/// Encryption and decryption are the same operation. `process` XORs the next keystream
/// bytes into a buffer in place.
#[derive(Clone)]
pub struct Rc4Cipher {
    state: [u8; 256],
    i: u8,
    j: u8,
}
impl Rc4Cipher {
    /// Creates a cipher for `key` and drops the first `RC4_DISCARD` keystream bytes, as MSE
    /// requires for both directions.
    ///
    /// # Panics
    /// If `key` is empty.
    pub fn new(key: &[u8]) -> Self {
        let mut cipher = Self::from_key_schedule(key);
        let mut discard = [0u8; RC4_DISCARD];
        cipher.process(&mut discard);
        cipher
    }

    /// Like [`Rc4Cipher::new`] but keeps the initial keystream bytes (test-only, so
    /// published RC4 vectors can be checked).
    ///
    /// # Panics
    /// If `key` is empty.
    #[cfg(test)]
    pub fn new_no_discard(key: &[u8]) -> Self {
        Self::from_key_schedule(key)
    }

    /// RC4 key-scheduling algorithm: builds the initial permutation from `key`.
    /// Only the first 256 key bytes influence the schedule.
    fn from_key_schedule(key: &[u8]) -> Self {
        // An empty key would otherwise panic with an opaque "remainder with a divisor of zero".
        assert!(!key.is_empty(), "RC4 key must not be empty");
        let mut state = [0u8; 256];
        for (slot, value) in state.iter_mut().zip(0u8..=255) {
            *slot = value;
        }
        let mut j: u8 = 0;
        for (i, &k) in (0..256).zip(key.iter().cycle()) {
            j = j.wrapping_add(state[i]).wrapping_add(k);
            state.swap(i, j as usize);
        }
        Self { state, i: 0, j: 0 }
    }

    /// XORs the next `data.len()` keystream bytes into `data`, advancing the cipher.
    pub fn process(&mut self, data: &mut [u8]) {
        for byte in data.iter_mut() {
            self.i = self.i.wrapping_add(1);
            self.j = self.j.wrapping_add(self.state[self.i as usize]);
            self.state.swap(self.i as usize, self.j as usize);
            let k = self.state
                [(self.state[self.i as usize].wrapping_add(self.state[self.j as usize])) as usize];
            *byte ^= k;
        }
    }
}
/// An ephemeral Diffie-Hellman key pair over `DH_PRIME` with generator `DH_GENERATOR`.
pub struct DhKeyPair {
    /// Secret exponent built from 20 random bytes (160 bits).
    private: BigUint,
    /// `DH_GENERATOR ^ private mod DH_PRIME`, big-endian, zero-padded on the left to 96 bytes.
    public: [u8; 96],
}
impl DhKeyPair {
    /// Encodes `value` big-endian into exactly 96 bytes, left-padding with zeros.
    /// At most the first 96 encoded bytes are copied. Callers only pass values reduced
    /// modulo `DH_PRIME`, which always fit.
    fn pad_to_96(value: &BigUint) -> [u8; 96] {
        let encoded = value.to_bytes_be();
        let used = encoded.len().min(96);
        let mut out = [0u8; 96];
        out[96 - used..].copy_from_slice(&encoded[..used]);
        out
    }

    /// Draws a fresh random secret exponent and computes the matching public value.
    pub fn generate() -> Self {
        let mut secret = [0u8; 20];
        rand::rng().fill(&mut secret);
        let private = BigUint::from_bytes_be(&secret);
        let generator = BigUint::from(DH_GENERATOR);
        let modulus = BigUint::from_bytes_be(&DH_PRIME);
        let public = Self::pad_to_96(&generator.modpow(&private, &modulus));
        Self { private, public }
    }

    /// Returns `peer_public ^ private mod DH_PRIME` as 96 big-endian bytes.
    ///
    /// NOTE(review): `peer_public` is not validated; degenerate values such as 0 or 1
    /// yield a predictable secret.
    pub fn compute_shared_secret(&self, peer_public: &[u8; 96]) -> [u8; 96] {
        let modulus = BigUint::from_bytes_be(&DH_PRIME);
        let shared = BigUint::from_bytes_be(peer_public).modpow(&self.private, &modulus);
        Self::pad_to_96(&shared)
    }

    /// Our public value, ready to send as `Ya`/`Yb`.
    pub fn public_bytes(&self) -> &[u8; 96] {
        &self.public
    }
}
/// Derives the two RC4 ciphers for a connection from the DH shared secret and info hash.
///
/// Keys are `SHA1("keyA" || S || SKEY)` and `SHA1("keyB" || S || SKEY)`. The initiator
/// encrypts with keyA and decrypts with keyB; the responder does the opposite. Returns
/// `(encrypt_cipher, decrypt_cipher)` for the side selected by `is_initiator`.
pub fn derive_rc4_keys(
    shared_secret: &[u8; 96],
    info_hash: &Sha1Hash,
    is_initiator: bool,
) -> (Rc4Cipher, Rc4Cipher) {
    let key_for = |label: &[u8; 4]| -> [u8; 20] {
        let mut digest = Sha1::new();
        digest.update(label);
        digest.update(shared_secret);
        digest.update(info_hash);
        digest.finalize().into()
    };
    let key_a = key_for(b"keyA");
    let key_b = key_for(b"keyB");
    let (outgoing, incoming) = match is_initiator {
        true => (key_a, key_b),
        false => (key_b, key_a),
    };
    (Rc4Cipher::new(&outgoing), Rc4Cipher::new(&incoming))
}
/// A TCP connection that has completed the MSE handshake.
///
/// Bytes are passed through the ciphers only when `crypto_method == CRYPTO_RC4`: encrypted
/// on write, decrypted on read. Any other value leaves the payload untouched.
///
/// NOTE(review): tokio documents `read_exact` and `write_all` as not cancellation-safe.
/// Dropping one of those futures part-way (e.g. losing a `select!` race) can consume bytes or
/// advance a cipher without the matching transfer. That desynchronises the RC4 keystream from
/// the peer's, so treat the stream as unusable afterwards.
pub struct EncryptedStream {
    inner: TcpStream,
    encrypt_cipher: Rc4Cipher,
    decrypt_cipher: Rc4Cipher,
    /// Payload method recorded at construction; only `CRYPTO_RC4` enables the ciphers.
    pub crypto_method: u32,
}
impl EncryptedStream {
    /// Wraps `stream`. `encrypt_cipher` handles outgoing bytes; `decrypt_cipher` handles
    /// incoming bytes.
    pub fn new(
        stream: TcpStream,
        encrypt_cipher: Rc4Cipher,
        decrypt_cipher: Rc4Cipher,
        crypto_method: u32,
    ) -> Self {
        Self {
            inner: stream,
            encrypt_cipher,
            decrypt_cipher,
            crypto_method,
        }
    }

    /// Whether payload bytes go through the RC4 ciphers.
    fn rc4_active(&self) -> bool {
        self.crypto_method == CRYPTO_RC4
    }

    /// Reads up to `buf.len()` bytes and decrypts the ones received. Returns `Ok(0)` at end
    /// of stream.
    pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let received = self.inner.read(buf).await?;
        if received != 0 && self.rc4_active() {
            self.decrypt_cipher.process(&mut buf[..received]);
        }
        Ok(received)
    }

    /// Fills `buf` completely, then decrypts it.
    pub async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        self.inner.read_exact(buf).await?;
        if self.rc4_active() {
            self.decrypt_cipher.process(buf);
        }
        Ok(())
    }

    /// Writes all of `buf`. Under RC4 an encrypted copy is sent and `buf` stays untouched.
    pub async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        if !self.rc4_active() {
            return self.inner.write_all(buf).await;
        }
        let mut ciphertext = buf.to_vec();
        self.encrypt_cipher.process(&mut ciphertext);
        self.inner.write_all(&ciphertext).await
    }

    /// Flushes the underlying TCP stream.
    pub async fn flush(&mut self) -> io::Result<()> {
        self.inner.flush().await
    }

    /// Remote address of the underlying TCP stream.
    pub fn peer_addr(&self) -> io::Result<std::net::SocketAddr> {
        self.inner.peer_addr()
    }

    /// Local address of the underlying TCP stream.
    pub fn local_addr(&self) -> io::Result<std::net::SocketAddr> {
        self.inner.local_addr()
    }

    /// Shuts down the write half of the underlying TCP stream.
    pub async fn shutdown(&mut self) -> io::Result<()> {
        self.inner.shutdown().await
    }
}
/// Outcome of an MSE handshake attempt.
pub enum MseHandshakeResult {
    /// Handshake completed; payload I/O goes through the boxed stream.
    Encrypted(Box<EncryptedStream>),
    /// Continue without MSE on this stream. NOTE(review): the `Vec<u8>` presumably carries
    /// bytes already read from the peer; `mse_handshake_outgoing` only passes an empty one.
    Plaintext(TcpStream, Vec<u8>),
    /// Handshake failed with the given error.
    Failed(EngineError),
}
pub async fn mse_handshake_outgoing(
mut stream: TcpStream,
info_hash: Sha1Hash,
config: &MseConfig,
) -> MseHandshakeResult {
let key_pair = DhKeyPair::generate();
let padding_len = rand::rng().random_range(config.min_padding..=config.max_padding);
let mut padding = vec![0u8; padding_len];
rand::RngExt::fill(&mut rand::rng(), &mut padding[..]);
let mut send_buf = Vec::with_capacity(96 + padding_len);
send_buf.extend_from_slice(key_pair.public_bytes());
send_buf.extend_from_slice(&padding);
if let Err(e) = timeout(HANDSHAKE_TIMEOUT, stream.write_all(&send_buf)).await {
return match config.policy {
EncryptionPolicy::Required => MseHandshakeResult::Failed(EngineError::network(
NetworkErrorKind::Timeout,
format!("MSE handshake timeout: {}", e),
)),
_ => MseHandshakeResult::Plaintext(stream, vec![]),
};
}
let mut yb = [0u8; 96];
match timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut yb)).await {
Ok(Ok(_)) => {}
Ok(Err(e)) => {
return match config.policy {
EncryptionPolicy::Required => MseHandshakeResult::Failed(EngineError::network(
NetworkErrorKind::ConnectionReset,
format!("Failed to receive Yb: {}", e),
)),
_ => MseHandshakeResult::Plaintext(stream, vec![]),
};
}
Err(_) => {
return match config.policy {
EncryptionPolicy::Required => MseHandshakeResult::Failed(EngineError::network(
NetworkErrorKind::Timeout,
"Timeout receiving Yb",
)),
_ => MseHandshakeResult::Plaintext(stream, vec![]),
};
}
}
let shared_secret = key_pair.compute_shared_secret(&yb);
let mut hasher = Sha1::new();
hasher.update(b"req1");
hasher.update(shared_secret);
let req1_hash: [u8; 20] = hasher.finalize().into();
let mut hasher = Sha1::new();
hasher.update(b"req2");
hasher.update(info_hash);
let req2_hash: [u8; 20] = hasher.finalize().into();
let mut hasher = Sha1::new();
hasher.update(b"req3");
hasher.update(shared_secret);
let req3_hash: [u8; 20] = hasher.finalize().into();
let mut skey_hash = [0u8; 20];
for i in 0..20 {
skey_hash[i] = req2_hash[i] ^ req3_hash[i];
}
let mut crypto_provide: u32 = 0;
if config.allow_rc4 {
crypto_provide |= CRYPTO_RC4;
}
if config.allow_plaintext {
crypto_provide |= CRYPTO_PLAINTEXT;
}
let (mut encrypt_cipher, _) = derive_rc4_keys(&shared_secret, &info_hash, true);
let padc_len: u16 = rand::rng().random_range(0..512);
let mut padc = vec![0u8; padc_len as usize];
rand::RngExt::fill(&mut rand::rng(), &mut padc[..]);
let ia_len: u16 = 0;
let mut encrypted_part = Vec::new();
encrypted_part.extend_from_slice(&VC);
encrypted_part.extend_from_slice(&crypto_provide.to_be_bytes());
encrypted_part.extend_from_slice(&padc_len.to_be_bytes());
encrypted_part.extend_from_slice(&padc);
encrypted_part.extend_from_slice(&ia_len.to_be_bytes());
encrypt_cipher.process(&mut encrypted_part);
let mut send_buf = Vec::new();
send_buf.extend_from_slice(&req1_hash);
send_buf.extend_from_slice(&skey_hash);
send_buf.extend_from_slice(&encrypted_part);
if let Err(e) = timeout(HANDSHAKE_TIMEOUT, stream.write_all(&send_buf)).await {
return MseHandshakeResult::Failed(EngineError::network(
NetworkErrorKind::Timeout,
format!("Failed to send crypto handshake: {}", e),
));
}
let (_, decrypt_cipher) = derive_rc4_keys(&shared_secret, &info_hash, true);
match receive_crypto_response(&mut stream, decrypt_cipher, config).await {
Ok((crypto_method, final_decrypt)) => {
MseHandshakeResult::Encrypted(Box::new(EncryptedStream::new(
stream,
encrypt_cipher,
final_decrypt,
crypto_method,
)))
}
Err(e) => MseHandshakeResult::Failed(e),
}
}
async fn receive_crypto_response(
stream: &mut TcpStream,
mut decrypt_cipher: Rc4Cipher,
_config: &MseConfig,
) -> Result<(u32, Rc4Cipher)> {
let mut buf = [0u8; 14];
timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut buf))
.await
.map_err(|_| EngineError::network(NetworkErrorKind::Timeout, "Timeout reading response"))?
.map_err(|e| {
EngineError::network(
NetworkErrorKind::ConnectionReset,
format!("Failed to read response: {}", e),
)
})?;
decrypt_cipher.process(&mut buf);
if buf[..8] != VC {
return Err(EngineError::protocol(
ProtocolErrorKind::PeerProtocol,
"Invalid VC in response",
));
}
let crypto_select = u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]]);
let padd_len = u16::from_be_bytes([buf[12], buf[13]]) as usize;
if padd_len > 0 {
let mut padd = vec![0u8; padd_len];
timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut padd))
.await
.map_err(|_| EngineError::network(NetworkErrorKind::Timeout, "Timeout reading PadD"))?
.map_err(|e| {
EngineError::network(
NetworkErrorKind::ConnectionReset,
format!("Failed to read PadD: {}", e),
)
})?;
decrypt_cipher.process(&mut padd);
}
Ok((crypto_select, decrypt_cipher))
}
pub enum PeerStream {
Plain(TcpStream),
Encrypted(Box<EncryptedStream>),
Utp(super::utp::UtpSocket),
}
impl PeerStream {
pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Plain(s) => s.read(buf).await,
Self::Encrypted(s) => s.as_mut().read(buf).await,
Self::Utp(s) => s.read(buf).await,
}
}
pub async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
match self {
Self::Plain(s) => s.read_exact(buf).await.map(|_| ()),
Self::Encrypted(s) => s.as_mut().read_exact(buf).await,
Self::Utp(s) => s.read_exact(buf).await,
}
}
pub async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
match self {
Self::Plain(s) => s.write_all(buf).await,
Self::Encrypted(s) => s.as_mut().write_all(buf).await,
Self::Utp(s) => s.write_all(buf).await,
}
}
pub async fn flush(&mut self) -> io::Result<()> {
match self {
Self::Plain(s) => s.flush().await,
Self::Encrypted(s) => s.as_mut().flush().await,
Self::Utp(s) => s.flush().await,
}
}
pub fn is_encrypted(&self) -> bool {
matches!(self, Self::Encrypted(_))
}
pub fn peer_addr(&self) -> io::Result<std::net::SocketAddr> {
match self {
Self::Plain(s) => s.peer_addr(),
Self::Encrypted(s) => s.as_ref().peer_addr(),
Self::Utp(s) => s.peer_addr(),
}
}
pub async fn shutdown(&mut self) -> io::Result<()> {
match self {
Self::Plain(s) => s.shutdown().await,
Self::Encrypted(s) => s.as_mut().shutdown().await,
Self::Utp(s) => s.shutdown().await,
}
}
}
/// Wraps an already-connected outgoing `stream` according to `config.policy`.
///
/// * `Disabled`, `Allowed`: returned as `PeerStream::Plain` with no MSE traffic.
/// * `Preferred`, `Required`: runs `mse_handshake_outgoing`.
///   - An `Encrypted` outcome becomes `PeerStream::Encrypted`.
///   - A `Plaintext` outcome is accepted for `Preferred` and rejected for `Required`. Its
///     byte buffer is dropped; NOTE(review): the outgoing handshake only passes an empty one.
///   - A `Failed` outcome is returned as an error.
///
/// # Errors
/// Any handshake failure, and a plaintext outcome under `Required`.
pub async fn connect_with_mse(
    stream: TcpStream,
    info_hash: Sha1Hash,
    config: &MseConfig,
) -> Result<PeerStream> {
    match config.policy {
        EncryptionPolicy::Disabled | EncryptionPolicy::Allowed => Ok(PeerStream::Plain(stream)),
        EncryptionPolicy::Preferred | EncryptionPolicy::Required => {
            match mse_handshake_outgoing(stream, info_hash, config).await {
                MseHandshakeResult::Encrypted(enc) => Ok(PeerStream::Encrypted(enc)),
                MseHandshakeResult::Plaintext(_, _)
                    if config.policy == EncryptionPolicy::Required =>
                {
                    Err(EngineError::protocol(
                        ProtocolErrorKind::PeerProtocol,
                        "Encryption required but peer doesn't support it",
                    ))
                }
                MseHandshakeResult::Plaintext(stream, _) => Ok(PeerStream::Plain(stream)),
                // The error propagates under both policies.
                MseHandshakeResult::Failed(e) => Err(e),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_rc4_cipher() {
        let key = b"test_key_1234567";
        let mut cipher1 = Rc4Cipher::new_no_discard(key);
        let mut cipher2 = Rc4Cipher::new_no_discard(key);
        let original = b"Hello, World! This is a test message.";
        let mut data = original.to_vec();
        cipher1.process(&mut data);
        assert_ne!(&data[..], &original[..]);
        cipher2.process(&mut data);
        assert_eq!(&data[..], &original[..]);
    }
    /// Round-trip alone cannot catch a wrong (but self-inverse) keystream; pin published
    /// RC4 vectors instead.
    #[test]
    fn test_rc4_known_answer_vectors() {
        let cases: [(&[u8], &[u8], &[u8]); 3] = [
            (
                b"Key",
                b"Plaintext",
                &[0xBB, 0xF3, 0x16, 0xE8, 0xD9, 0x40, 0xAF, 0x0A, 0xD3],
            ),
            (b"Wiki", b"pedia", &[0x10, 0x21, 0xBF, 0x04, 0x20]),
            (
                b"Secret",
                b"Attack at dawn",
                &[
                    0x45, 0xA0, 0x1F, 0x64, 0x5F, 0xC3, 0x5B, 0x38, 0x35, 0x52, 0x54, 0x4B, 0x9B,
                    0xF5,
                ],
            ),
        ];
        for (key, plaintext, expected) in cases {
            let mut data = plaintext.to_vec();
            Rc4Cipher::new_no_discard(key).process(&mut data);
            assert_eq!(&data[..], expected);
        }
    }
    /// `Rc4Cipher::new` must skip exactly `RC4_DISCARD` keystream bytes.
    #[test]
    fn test_rc4_discard_skips_exactly_rc4_discard_bytes() {
        let key = b"discard_check_key";
        let mut manual = Rc4Cipher::new_no_discard(key);
        let mut burn = vec![0u8; RC4_DISCARD];
        manual.process(&mut burn);
        let mut expected = [0u8; 32];
        manual.process(&mut expected);
        let mut actual = [0u8; 32];
        Rc4Cipher::new(key).process(&mut actual);
        assert_eq!(actual, expected);
    }
    #[test]
    fn test_dh_key_exchange() {
        let alice = DhKeyPair::generate();
        let bob = DhKeyPair::generate();
        let secret_a = alice.compute_shared_secret(bob.public_bytes());
        let secret_b = bob.compute_shared_secret(alice.public_bytes());
        assert_eq!(secret_a, secret_b, "Shared secrets should match");
    }
    #[test]
    fn test_key_derivation() {
        let shared_secret = [0x42u8; 96];
        let info_hash = [0x12u8; 20];
        let (enc_a, dec_a) = derive_rc4_keys(&shared_secret, &info_hash, true);
        let (enc_b, dec_b) = derive_rc4_keys(&shared_secret, &info_hash, false);
        let original = b"test data for encryption";
        let mut data = original.to_vec();
        let mut enc_a = enc_a;
        let mut dec_b = dec_b;
        enc_a.process(&mut data);
        dec_b.process(&mut data);
        assert_eq!(&data[..], &original[..]);
        let mut data = original.to_vec();
        let mut enc_b = enc_b;
        let mut dec_a = dec_a;
        enc_b.process(&mut data);
        dec_a.process(&mut data);
        assert_eq!(&data[..], &original[..]);
    }
    /// The two directions use different keys (keyA vs keyB), so their keystreams differ.
    #[test]
    fn test_key_derivation_directions_differ() {
        let (mut enc, mut dec) = derive_rc4_keys(&[0x42u8; 96], &[0x12u8; 20], true);
        let mut outgoing = [0u8; 16];
        let mut incoming = [0u8; 16];
        enc.process(&mut outgoing);
        dec.process(&mut incoming);
        assert_ne!(outgoing, incoming);
    }
    #[test]
    fn test_mse_config_default() {
        let config = MseConfig::default();
        assert_eq!(config.policy, EncryptionPolicy::Preferred);
        assert!(config.allow_plaintext);
        assert!(config.allow_rc4);
    }
    #[test]
    fn test_encryption_policy_disabled() {
        let policy = EncryptionPolicy::Disabled;
        assert_eq!(policy, EncryptionPolicy::Disabled);
    }
    #[test]
    fn test_encryption_policy_default_is_preferred() {
        let policy = EncryptionPolicy::default();
        assert_eq!(policy, EncryptionPolicy::Preferred);
    }
    #[test]
    fn test_mse_config_crypto_provide() {
        let mut config = MseConfig::default();
        config.allow_plaintext = true;
        config.allow_rc4 = true;
        let provide = config.crypto_provide();
        assert_eq!(provide, CRYPTO_PLAINTEXT | CRYPTO_RC4);
        config.allow_plaintext = false;
        config.allow_rc4 = true;
        let provide = config.crypto_provide();
        assert_eq!(provide, CRYPTO_RC4);
        config.allow_plaintext = true;
        config.allow_rc4 = false;
        let provide = config.crypto_provide();
        assert_eq!(provide, CRYPTO_PLAINTEXT);
        config.allow_plaintext = false;
        config.allow_rc4 = false;
        let provide = config.crypto_provide();
        assert_eq!(provide, 0);
    }
    #[test]
    fn test_crypto_constants() {
        assert_eq!(CRYPTO_PLAINTEXT, 0x01);
        assert_eq!(CRYPTO_RC4, 0x02);
        assert_eq!(CRYPTO_PLAINTEXT & CRYPTO_RC4, 0);
    }
    #[test]
    fn test_dh_prime_length() {
        assert_eq!(DH_PRIME.len(), 96);
    }
    #[test]
    fn test_rc4_discard_constant() {
        assert_eq!(RC4_DISCARD, 1024);
    }
    #[test]
    fn test_vc_constant() {
        assert_eq!(VC.len(), 8);
        assert!(VC.iter().all(|&b| b == 0));
    }
    #[test]
    fn test_max_padding() {
        assert_eq!(MAX_PADDING, 512);
    }
    #[test]
    fn test_rc4_discard_security() {
        let key = b"security_test_key";
        let mut with_discard = Rc4Cipher::new(key);
        let mut without_discard = Rc4Cipher::new_no_discard(key);
        let mut data1 = vec![0u8; 32];
        let mut data2 = vec![0u8; 32];
        with_discard.process(&mut data1);
        without_discard.process(&mut data2);
        assert_ne!(
            data1, data2,
            "RC4 with discard should produce different output"
        );
    }
    #[test]
    fn test_dh_generates_unique_keys() {
        let kp1 = DhKeyPair::generate();
        let kp2 = DhKeyPair::generate();
        assert_ne!(
            kp1.public_bytes(),
            kp2.public_bytes(),
            "Generated key pairs should be unique"
        );
    }
    #[test]
    fn test_bidirectional_encryption() {
        let shared_secret = [0xABu8; 96];
        let info_hash = [0xCDu8; 20];
        let (mut enc_a, mut dec_a) = derive_rc4_keys(&shared_secret, &info_hash, true);
        let (mut enc_b, mut dec_b) = derive_rc4_keys(&shared_secret, &info_hash, false);
        let msg_a_to_b = b"Hello from A";
        let mut encrypted = msg_a_to_b.to_vec();
        enc_a.process(&mut encrypted);
        dec_b.process(&mut encrypted);
        assert_eq!(&encrypted[..], msg_a_to_b);
        let msg_b_to_a = b"Hello from B";
        let mut encrypted = msg_b_to_a.to_vec();
        enc_b.process(&mut encrypted);
        dec_a.process(&mut encrypted);
        assert_eq!(&encrypted[..], msg_b_to_a);
    }
}