use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use irontide_core::Id20;
use super::cipher::Rc4;
use super::crypto;
use super::dh::{DH_KEY_SIZE, DhKeypair};
use super::stream::MseStream;
use crate::error::{Error, Result};
/// Outcome of a successful MSE/PE handshake, returned by both
/// `negotiate_outbound` and `negotiate_inbound`.
pub struct NegotiationResult<S> {
/// The transport wrapped for the negotiated method: RC4-encrypted when
/// `crypto_method` has the `crypto::CRYPTO_RC4` bit set, plaintext otherwise.
pub stream: MseStream<S>,
/// The `crypto_select` value agreed on during the handshake.
pub crypto_method: u32,
}
pub async fn negotiate_outbound<S>(
mut stream: S,
skey: &Id20,
crypto_provide: u32,
) -> Result<NegotiationResult<S>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let skey_bytes = skey.as_bytes();
let dh = DhKeypair::generate();
stream.write_all(&dh.public).await?;
stream.flush().await?;
let mut yb = [0u8; DH_KEY_SIZE];
stream.read_exact(&mut yb).await?;
let s = dh.shared_secret(&yb);
let ka = crypto::key_a(&s, skey_bytes);
let kb = crypto::key_b(&s, skey_bytes);
let mut encrypt_cipher = Rc4::new(&ka);
let sync = crypto::sync_marker(&s);
let proof = crypto::skey_proof(skey_bytes, &s);
let mut encrypted_part = Vec::new();
encrypted_part.extend_from_slice(&crypto::VC); encrypted_part.extend_from_slice(&crypto_provide.to_be_bytes()); encrypted_part.extend_from_slice(&0u16.to_be_bytes()); encrypted_part.extend_from_slice(&0u16.to_be_bytes()); encrypt_cipher.apply(&mut encrypted_part);
stream.write_all(&sync).await?;
stream.write_all(&proof).await?;
stream.write_all(&encrypted_part).await?;
stream.flush().await?;
let mut encrypted_vc = crypto::VC; let mut vc_cipher = Rc4::new(&kb);
vc_cipher.apply(&mut encrypted_vc);
let mut raw_buf = Vec::new();
let max_scan = 512 + 8; let mut found_vc = false;
for _ in 0..max_scan {
let mut byte = [0u8; 1];
stream.read_exact(&mut byte).await?;
raw_buf.push(byte[0]);
if raw_buf.len() >= 8 && raw_buf[raw_buf.len() - 8..] == encrypted_vc {
found_vc = true;
break;
}
}
if !found_vc {
return Err(Error::EncryptionHandshakeFailed(
"VC not found in responder reply".into(),
));
}
let mut scan_cipher = Rc4::new(&kb);
let mut skip = [0u8; 8];
scan_cipher.apply(&mut skip);
let mut select_buf = [0u8; 6];
stream.read_exact(&mut select_buf).await?;
scan_cipher.apply(&mut select_buf);
let crypto_select =
u32::from_be_bytes([select_buf[0], select_buf[1], select_buf[2], select_buf[3]]);
let pad_len = u16::from_be_bytes([select_buf[4], select_buf[5]]) as usize;
if crypto_select & crypto_provide == 0 {
return Err(Error::UnsupportedCryptoMethod);
}
if pad_len > 0 {
let mut pad = vec![0u8; pad_len];
stream.read_exact(&mut pad).await?;
scan_cipher.apply(&mut pad);
}
let result_stream = if crypto_select & crypto::CRYPTO_RC4 != 0 {
MseStream::encrypted(stream, scan_cipher, encrypt_cipher)
} else {
MseStream::plaintext(stream)
};
Ok(NegotiationResult {
stream: result_stream,
crypto_method: crypto_select,
})
}
pub async fn negotiate_inbound<S>(
mut stream: S,
skey: &Id20,
prefer_rc4: bool,
) -> Result<NegotiationResult<S>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let skey_bytes = skey.as_bytes();
let mut ya = [0u8; DH_KEY_SIZE];
stream.read_exact(&mut ya).await?;
let dh = DhKeypair::generate();
stream.write_all(&dh.public).await?;
stream.flush().await?;
let s = dh.shared_secret(&ya);
let expected_sync = crypto::sync_marker(&s);
let mut scan_buf = Vec::new();
let max_scan = 512 + 20;
let mut found_sync = false;
for _ in 0..max_scan {
let mut byte = [0u8; 1];
stream.read_exact(&mut byte).await?;
scan_buf.push(byte[0]);
if scan_buf.len() >= 20 && scan_buf[scan_buf.len() - 20..] == expected_sync {
found_sync = true;
break;
}
}
if !found_sync {
return Err(Error::EncryptionHandshakeFailed(
"sync marker not found".into(),
));
}
let mut proof = [0u8; 20];
stream.read_exact(&mut proof).await?;
let expected_proof = crypto::skey_proof(skey_bytes, &s);
if proof != expected_proof {
return Err(Error::EncryptionHandshakeFailed(
"SKEY proof mismatch".into(),
));
}
let ka = crypto::key_a(&s, skey_bytes);
let kb = crypto::key_b(&s, skey_bytes);
let mut decrypt_cipher = Rc4::new(&ka); let mut encrypt_cipher = Rc4::new(&kb);
let mut enc_header = [0u8; 14];
stream.read_exact(&mut enc_header).await?;
decrypt_cipher.apply(&mut enc_header);
if enc_header[..8] != crypto::VC {
return Err(Error::EncryptionHandshakeFailed("VC mismatch".into()));
}
let crypto_provide =
u32::from_be_bytes([enc_header[8], enc_header[9], enc_header[10], enc_header[11]]);
let pad_c_len = u16::from_be_bytes([enc_header[12], enc_header[13]]) as usize;
if pad_c_len > 0 {
let mut pad = vec![0u8; pad_c_len];
stream.read_exact(&mut pad).await?;
decrypt_cipher.apply(&mut pad);
}
let mut ia_len_buf = [0u8; 2];
stream.read_exact(&mut ia_len_buf).await?;
decrypt_cipher.apply(&mut ia_len_buf);
let ia_len = u16::from_be_bytes(ia_len_buf) as usize;
if ia_len > 0 {
let mut ia = vec![0u8; ia_len];
stream.read_exact(&mut ia).await?;
decrypt_cipher.apply(&mut ia);
}
let crypto_select = if prefer_rc4 && (crypto_provide & crypto::CRYPTO_RC4 != 0) {
crypto::CRYPTO_RC4
} else if crypto_provide & crypto::CRYPTO_PLAINTEXT != 0 {
crypto::CRYPTO_PLAINTEXT
} else if crypto_provide & crypto::CRYPTO_RC4 != 0 {
crypto::CRYPTO_RC4
} else {
return Err(Error::UnsupportedCryptoMethod);
};
let mut response = Vec::new();
response.extend_from_slice(&crypto::VC); response.extend_from_slice(&crypto_select.to_be_bytes()); response.extend_from_slice(&0u16.to_be_bytes()); encrypt_cipher.apply(&mut response);
stream.write_all(&response).await?;
stream.flush().await?;
let result_stream = if crypto_select & crypto::CRYPTO_RC4 != 0 {
MseStream::encrypted(stream, decrypt_cipher, encrypt_cipher)
} else {
MseStream::plaintext(stream)
};
Ok(NegotiationResult {
stream: result_stream,
crypto_method: crypto_select,
})
}
// Handshake tests. Each test runs both peers concurrently as spawned tasks over an
// in-memory `tokio::io::duplex` pipe with a 4096-byte buffer.
#[cfg(test)]
mod tests {
use super::super::crypto;
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// Initiator offers only RC4 and the responder prefers RC4. Checks that both sides
// agree on RC4 and that data round-trips in both directions over the wrapped streams.
#[tokio::test]
async fn full_handshake_rc4() {
let info_hash = Id20::from([0xAA; 20]);
let (client_stream, server_stream) = tokio::io::duplex(4096);
let client_handle = tokio::spawn(async move {
negotiate_outbound(client_stream, &info_hash, crypto::CRYPTO_RC4).await
});
let server_handle = tokio::spawn(async move {
// Last argument is `prefer_rc4`.
negotiate_inbound(
server_stream,
&info_hash,
true, )
.await
});
let client_result = client_handle.await.unwrap().unwrap();
let server_result = server_handle.await.unwrap().unwrap();
assert_eq!(client_result.crypto_method, crypto::CRYPTO_RC4);
assert_eq!(server_result.crypto_method, crypto::CRYPTO_RC4);
let mut client = client_result.stream;
let mut server = server_result.stream;
client.write_all(b"ping").await.unwrap();
client.flush().await.unwrap();
let mut buf = [0u8; 4];
server.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"ping");
server.write_all(b"pong").await.unwrap();
server.flush().await.unwrap();
let mut buf = [0u8; 4];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"pong");
}
// Initiator offers only plaintext and the responder does not prefer RC4, so both
// sides should settle on plaintext.
#[tokio::test]
async fn full_handshake_plaintext() {
let info_hash = Id20::from([0xBB; 20]);
let (client_stream, server_stream) = tokio::io::duplex(4096);
let client_handle = tokio::spawn(async move {
negotiate_outbound(client_stream, &info_hash, crypto::CRYPTO_PLAINTEXT).await
});
let server_handle = tokio::spawn(async move {
// Last argument is `prefer_rc4`.
negotiate_inbound(
server_stream,
&info_hash,
false, )
.await
});
let client_result = client_handle.await.unwrap().unwrap();
let server_result = server_handle.await.unwrap().unwrap();
assert_eq!(client_result.crypto_method, crypto::CRYPTO_PLAINTEXT);
assert_eq!(server_result.crypto_method, crypto::CRYPTO_PLAINTEXT);
}
// Exercises the initiator's VC scan with a hand-rolled responder that sends
// `pad_b_len` bytes of PadB right after Yb, then selects RC4 with no PadD.
#[tokio::test]
async fn full_handshake_with_pad_b() {
let info_hash = Id20::from([0xEE; 20]);
let pad_b_len = 73;
let (client_stream, mut server_stream) = tokio::io::duplex(4096);
let client_handle = tokio::spawn(async move {
negotiate_outbound(
client_stream,
&info_hash,
crypto::CRYPTO_RC4 | crypto::CRYPTO_PLAINTEXT,
)
.await
});
let server_handle = tokio::spawn(async move {
let skey_bytes = info_hash.as_bytes();
let mut ya = [0u8; DH_KEY_SIZE];
server_stream.read_exact(&mut ya).await.unwrap();
let dh = DhKeypair::generate();
server_stream.write_all(&dh.public).await.unwrap();
// PadB: filler bytes the initiator must skip while scanning for the encrypted VC.
let pad_b = vec![0xABu8; pad_b_len];
server_stream.write_all(&pad_b).await.unwrap();
server_stream.flush().await.unwrap();
let s = dh.shared_secret(&ya);
// Scan up to the sync marker. `negotiate_outbound` sends no PadA, but this mirrors
// the responder's bounded scan.
let expected_sync = crypto::sync_marker(&s);
let mut scan_buf = Vec::new();
for _ in 0..(512 + 20) {
let mut byte = [0u8; 1];
server_stream.read_exact(&mut byte).await.unwrap();
scan_buf.push(byte[0]);
if scan_buf.len() >= 20 && scan_buf[scan_buf.len() - 20..] == expected_sync {
break;
}
}
// The SKEY proof is consumed but deliberately not verified here.
let mut proof = [0u8; 20];
server_stream.read_exact(&mut proof).await.unwrap();
let ka = crypto::key_a(&s, skey_bytes);
let kb = crypto::key_b(&s, skey_bytes);
let mut decrypt_cipher = Rc4::new(&ka);
let mut encrypt_cipher = Rc4::new(&kb);
// Reads VC + crypto_provide + len(PadC) (14 bytes), then len(IA) (2 bytes).
// This relies on the initiator sending no PadC and no IA, as negotiate_outbound does.
let mut enc_header = [0u8; 14];
server_stream.read_exact(&mut enc_header).await.unwrap();
decrypt_cipher.apply(&mut enc_header);
assert_eq!(&enc_header[..8], &crypto::VC);
let mut ia_len_buf = [0u8; 2];
server_stream.read_exact(&mut ia_len_buf).await.unwrap();
decrypt_cipher.apply(&mut ia_len_buf);
// Reply: ENCRYPT(VC, crypto_select = RC4, len(PadD) = 0).
let mut response = Vec::new();
response.extend_from_slice(&crypto::VC);
response.extend_from_slice(&crypto::CRYPTO_RC4.to_be_bytes());
response.extend_from_slice(&0u16.to_be_bytes()); encrypt_cipher.apply(&mut response);
server_stream.write_all(&response).await.unwrap();
server_stream.flush().await.unwrap();
MseStream::encrypted(server_stream, decrypt_cipher, encrypt_cipher)
});
let client_result = client_handle.await.unwrap().unwrap();
let mut server_stream = server_handle.await.unwrap();
assert_eq!(client_result.crypto_method, crypto::CRYPTO_RC4);
let mut client = client_result.stream;
client.write_all(b"hello pad").await.unwrap();
client.flush().await.unwrap();
let mut buf = [0u8; 9];
server_stream.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello pad");
server_stream.write_all(b"world pad").await.unwrap();
server_stream.flush().await.unwrap();
let mut buf = [0u8; 9];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"world pad");
}
// Peers use different SKEYs, so the responder's proof check should fail. When it
// returns, it drops its end of the pipe, so the initiator's pending read presumably
// errors too. The assertion only requires that at least one side fails.
#[tokio::test]
async fn handshake_skey_mismatch_fails() {
let client_hash = Id20::from([0xCC; 20]);
let server_hash = Id20::from([0xDD; 20]);
let (client_stream, server_stream) = tokio::io::duplex(4096);
let client_handle = tokio::spawn(async move {
negotiate_outbound(client_stream, &client_hash, crypto::CRYPTO_RC4).await
});
let server_handle =
tokio::spawn(async move { negotiate_inbound(server_stream, &server_hash, true).await });
let client_result = client_handle.await.unwrap();
let server_result = server_handle.await.unwrap();
assert!(client_result.is_err() || server_result.is_err());
}
}