use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use irontide_core::Id20;
use super::cipher::Rc4;
use super::crypto;
use super::dh::{DH_KEY_SIZE, DhKeypair};
use super::stream::MseStream;
use crate::error::{Error, Result};
/// Number of bytes requested per `read` call while scanning for a marker.
const SCAN_CHUNK_SIZE: usize = 256;

/// Reads from `stream` until `marker` is found and returns the bytes that were
/// read *after* the marker (the "overflow"), which the caller must consume
/// before reading from `stream` again.
///
/// The earliest occurrence of `marker` is used, regardless of how the incoming
/// data happens to be split across `read` calls. Because each request is capped
/// at `max_scan - bytes_read + marker.len()`, a marker ending up to
/// `marker.len()` bytes past `max_scan` is still accepted.
///
/// An empty `marker` matches immediately, without reading, and yields an empty
/// overflow.
///
/// # Errors
///
/// - Any I/O error returned by `stream`.
/// - [`Error::EncryptionHandshakeFailed`] if the stream reaches EOF before the
///   marker is found, or if `max_scan` bytes were read without a match.
async fn scan_for_marker<S: AsyncRead + Unpin>(
    stream: &mut S,
    marker: &[u8],
    max_scan: usize,
) -> Result<Vec<u8>> {
    let marker_len = marker.len();
    if marker_len == 0 {
        // `slice::windows(0)` panics; an empty marker trivially matches here.
        return Ok(Vec::new());
    }
    let mut scan_buf = Vec::with_capacity(max_scan + SCAN_CHUNK_SIZE);
    while scan_buf.len() < max_scan {
        let old_len = scan_buf.len();
        let want = SCAN_CHUNK_SIZE.min(max_scan - old_len + marker_len);
        scan_buf.resize(old_len + want, 0);
        let n = stream.read(&mut scan_buf[old_len..]).await?;
        if n == 0 {
            return Err(Error::EncryptionHandshakeFailed(
                "stream closed during marker scan".into(),
            ));
        }
        scan_buf.truncate(old_len + n);
        // Windows starting before `search_start` lie entirely within bytes that
        // were already searched, so only windows touching new data are checked.
        // Taking the *first* hit keeps the result independent of read chunking.
        let search_start = old_len.saturating_sub(marker_len - 1);
        if let Some(pos) = scan_buf[search_start..]
            .windows(marker_len)
            .position(|w| w == marker)
        {
            let match_end = search_start + pos + marker_len;
            return Ok(scan_buf[match_end..].to_vec());
        }
    }
    Err(Error::EncryptionHandshakeFailed(
        "marker not found within scan limit".into(),
    ))
}
/// Fills all of `buf`, taking bytes from the front of `overflow` first and
/// reading whatever is still missing from `stream`.
///
/// Bytes copied out of `overflow` are removed from it; any surplus beyond
/// `buf.len()` stays in `overflow` for later calls.
///
/// # Errors
///
/// Propagates the error from `read_exact` (including unexpected EOF) when
/// `stream` cannot supply the remaining bytes.
async fn read_from_overflow_then_stream<S: AsyncRead + Unpin>(
    overflow: &mut Vec<u8>,
    stream: &mut S,
    buf: &mut [u8],
) -> Result<()> {
    let buffered = buf.len().min(overflow.len());
    let (head, tail) = buf.split_at_mut(buffered);
    head.copy_from_slice(&overflow[..buffered]);
    overflow.drain(..buffered);
    if !tail.is_empty() {
        stream.read_exact(tail).await?;
    }
    Ok(())
}
/// Outcome of a successful MSE negotiation, returned by both
/// `negotiate_outbound` and `negotiate_inbound`.
pub struct NegotiationResult<S> {
    /// The wrapped connection: built with `MseStream::encrypted` when the RC4
    /// bit was selected, otherwise with `MseStream::plaintext`.
    pub stream: MseStream<S>,
    /// The `crypto_select` bitfield that was agreed on during the handshake.
    pub crypto_method: u32,
}
pub async fn negotiate_outbound<S>(
mut stream: S,
skey: &Id20,
crypto_provide: u32,
) -> Result<NegotiationResult<S>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let skey_bytes = skey.as_bytes();
let dh = DhKeypair::generate();
stream.write_all(&dh.public).await?;
stream.flush().await?;
let mut yb = [0u8; DH_KEY_SIZE];
stream.read_exact(&mut yb).await?;
let s = dh.shared_secret(&yb);
let ka = crypto::key_a(&s, skey_bytes);
let kb = crypto::key_b(&s, skey_bytes);
let mut encrypt_cipher = Rc4::new(&ka);
let sync = crypto::sync_marker(&s);
let proof = crypto::skey_proof(skey_bytes, &s);
let mut encrypted_part = Vec::new();
encrypted_part.extend_from_slice(&crypto::VC); encrypted_part.extend_from_slice(&crypto_provide.to_be_bytes()); encrypted_part.extend_from_slice(&0u16.to_be_bytes()); encrypted_part.extend_from_slice(&0u16.to_be_bytes()); encrypt_cipher.apply(&mut encrypted_part);
stream.write_all(&sync).await?;
stream.write_all(&proof).await?;
stream.write_all(&encrypted_part).await?;
stream.flush().await?;
let mut encrypted_vc = crypto::VC; let mut vc_cipher = Rc4::new(&kb);
vc_cipher.apply(&mut encrypted_vc);
let mut overflow = scan_for_marker(&mut stream, &encrypted_vc, 512 + 8).await?;
let mut scan_cipher = Rc4::new(&kb);
let mut skip = [0u8; 8];
scan_cipher.apply(&mut skip);
let mut select_buf = [0u8; 6];
let from_overflow = overflow.len().min(6);
select_buf[..from_overflow].copy_from_slice(&overflow[..from_overflow]);
overflow.drain(..from_overflow);
if from_overflow < 6 {
stream.read_exact(&mut select_buf[from_overflow..]).await?;
}
scan_cipher.apply(&mut select_buf);
let crypto_select =
u32::from_be_bytes([select_buf[0], select_buf[1], select_buf[2], select_buf[3]]);
let pad_len = u16::from_be_bytes([select_buf[4], select_buf[5]]) as usize;
if crypto_select & crypto_provide == 0 {
return Err(Error::UnsupportedCryptoMethod);
}
if pad_len > 0 {
let mut pad = vec![0u8; pad_len];
let from_overflow = overflow.len().min(pad_len);
pad[..from_overflow].copy_from_slice(&overflow[..from_overflow]);
overflow.drain(..from_overflow);
if from_overflow < pad_len {
stream.read_exact(&mut pad[from_overflow..]).await?;
}
scan_cipher.apply(&mut pad);
}
if !overflow.is_empty() {
scan_cipher.apply(&mut overflow);
}
let result_stream = if crypto_select & crypto::CRYPTO_RC4 != 0 {
MseStream::encrypted(stream, scan_cipher, encrypt_cipher, overflow)
} else {
MseStream::plaintext(stream)
};
Ok(NegotiationResult {
stream: result_stream,
crypto_method: crypto_select,
})
}
/// Runs the receiving ("B") side of the MSE handshake on `stream`.
///
/// Sequence, as implemented here:
/// 1. Read the peer's DH public key (`DH_KEY_SIZE` bytes) and send ours.
/// 2. Scan past at most 512 bytes of padding for the sync marker, then check
///    the 20-byte SKEY proof against `crypto::skey_proof`.
/// 3. Decrypt `VC || crypto_provide || len(PadC)` with the key-A cipher, skip
///    PadC, then read and decrypt `len(IA)` and the initial payload IA.
/// 4. Pick a method from `crypto_provide` (RC4 first when `prefer_rc4`,
///    otherwise plaintext first) and reply with `VC || crypto_select ||
///    len(PadD)=0` encrypted with the key-B cipher.
///
/// On RC4, the returned stream yields IA first, followed by any further bytes
/// that were already buffered during the handshake, so no initiator data is lost.
///
/// # Errors
///
/// - Any I/O error from `stream`.
/// - [`Error::EncryptionHandshakeFailed`] on a missing sync marker, SKEY proof
///   mismatch, VC mismatch, PadC longer than 512 bytes, or when plaintext is
///   selected while IA or other payload bytes are pending.
/// - [`Error::UnsupportedCryptoMethod`] if `crypto_provide` offers neither RC4
///   nor plaintext.
pub async fn negotiate_inbound<S>(
    mut stream: S,
    skey: &Id20,
    prefer_rc4: bool,
) -> Result<NegotiationResult<S>>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Maximum PadA / PadC length accepted from the peer.
    const MAX_PAD_LEN: usize = 512;

    let skey_bytes = skey.as_bytes();

    // DH key exchange.
    let mut ya = [0u8; DH_KEY_SIZE];
    stream.read_exact(&mut ya).await?;
    let dh = DhKeypair::generate();
    stream.write_all(&dh.public).await?;
    stream.flush().await?;
    let s = dh.shared_secret(&ya);

    // Skip the peer's padding by scanning for the sync marker; `overflow` holds
    // everything that was read past it.
    let expected_sync = crypto::sync_marker(&s);
    let mut overflow =
        scan_for_marker(&mut stream, &expected_sync, MAX_PAD_LEN + expected_sync.len()).await?;

    let mut proof = [0u8; 20];
    read_from_overflow_then_stream(&mut overflow, &mut stream, &mut proof).await?;
    let expected_proof = crypto::skey_proof(skey_bytes, &s);
    if proof != expected_proof {
        return Err(Error::EncryptionHandshakeFailed(
            "SKEY proof mismatch".into(),
        ));
    }

    // The initiator encrypts with key A; we reply with key B.
    let ka = crypto::key_a(&s, skey_bytes);
    let kb = crypto::key_b(&s, skey_bytes);
    let mut decrypt_cipher = Rc4::new(&ka);
    let mut encrypt_cipher = Rc4::new(&kb);

    // VC (8 bytes) || crypto_provide (4) || len(PadC) (2).
    let mut enc_header = [0u8; 14];
    read_from_overflow_then_stream(&mut overflow, &mut stream, &mut enc_header).await?;
    decrypt_cipher.apply(&mut enc_header);
    if enc_header[..8] != crypto::VC {
        return Err(Error::EncryptionHandshakeFailed("VC mismatch".into()));
    }
    let crypto_provide =
        u32::from_be_bytes([enc_header[8], enc_header[9], enc_header[10], enc_header[11]]);
    let pad_c_len = usize::from(u16::from_be_bytes([enc_header[12], enc_header[13]]));
    if pad_c_len > MAX_PAD_LEN {
        return Err(Error::EncryptionHandshakeFailed(
            "PadC length exceeds 512 bytes".into(),
        ));
    }
    if pad_c_len > 0 {
        let mut pad = vec![0u8; pad_c_len];
        read_from_overflow_then_stream(&mut overflow, &mut stream, &mut pad).await?;
        decrypt_cipher.apply(&mut pad);
    }

    // len(IA) and IA are decrypted with key A before any method is selected.
    // IA is kept (not discarded) so it can be delivered to the caller.
    let mut ia_len_buf = [0u8; 2];
    read_from_overflow_then_stream(&mut overflow, &mut stream, &mut ia_len_buf).await?;
    decrypt_cipher.apply(&mut ia_len_buf);
    let ia_len = usize::from(u16::from_be_bytes(ia_len_buf));
    let mut initial_payload = vec![0u8; ia_len];
    if ia_len > 0 {
        read_from_overflow_then_stream(&mut overflow, &mut stream, &mut initial_payload).await?;
        decrypt_cipher.apply(&mut initial_payload);
    }

    let crypto_select = if prefer_rc4 && (crypto_provide & crypto::CRYPTO_RC4 != 0) {
        crypto::CRYPTO_RC4
    } else if crypto_provide & crypto::CRYPTO_PLAINTEXT != 0 {
        crypto::CRYPTO_PLAINTEXT
    } else if crypto_provide & crypto::CRYPTO_RC4 != 0 {
        crypto::CRYPTO_RC4
    } else {
        return Err(Error::UnsupportedCryptoMethod);
    };
    let use_rc4 = crypto_select & crypto::CRYPTO_RC4 != 0;

    // `MseStream::plaintext` only takes the stream, so pending IA/payload bytes
    // could not be handed to the caller; fail instead of silently dropping them.
    if !use_rc4 && (!initial_payload.is_empty() || !overflow.is_empty()) {
        return Err(Error::EncryptionHandshakeFailed(
            "buffered payload cannot be forwarded on a plaintext stream".into(),
        ));
    }

    let mut response = Vec::with_capacity(crypto::VC.len() + 6);
    response.extend_from_slice(&crypto::VC);
    response.extend_from_slice(&crypto_select.to_be_bytes());
    response.extend_from_slice(&0u16.to_be_bytes()); // len(PadD)
    encrypt_cipher.apply(&mut response);
    stream.write_all(&response).await?;
    stream.flush().await?;

    let result_stream = if use_rc4 {
        // Bytes buffered past IA continue the key-A RC4 keystream: decrypt them
        // and queue them after IA so the caller receives the data in order.
        let mut pending = initial_payload;
        if !overflow.is_empty() {
            decrypt_cipher.apply(&mut overflow);
            pending.extend_from_slice(&overflow);
        }
        MseStream::encrypted(stream, decrypt_cipher, encrypt_cipher, pending)
    } else {
        MseStream::plaintext(stream)
    };
    Ok(NegotiationResult {
        stream: result_stream,
        crypto_method: crypto_select,
    })
}
#[cfg(test)]
mod tests {
    use super::super::crypto;
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // End-to-end RC4 handshake between `negotiate_outbound` and
    // `negotiate_inbound` over an in-memory duplex pipe, followed by one
    // message in each direction through the resulting streams.
    #[tokio::test]
    async fn full_handshake_rc4() {
        let info_hash = Id20::from([0xAA; 20]);
        let (client_stream, server_stream) = tokio::io::duplex(4096);
        let client_handle = tokio::spawn(async move {
            negotiate_outbound(client_stream, &info_hash, crypto::CRYPTO_RC4).await
        });
        let server_handle = tokio::spawn(async move {
            negotiate_inbound(
                server_stream,
                &info_hash,
                true, // prefer_rc4
            )
            .await
        });
        let client_result = client_handle.await.unwrap().unwrap();
        let server_result = server_handle.await.unwrap().unwrap();
        assert_eq!(client_result.crypto_method, crypto::CRYPTO_RC4);
        assert_eq!(server_result.crypto_method, crypto::CRYPTO_RC4);
        let mut client = client_result.stream;
        let mut server = server_result.stream;
        client.write_all(b"ping").await.unwrap();
        client.flush().await.unwrap();
        let mut buf = [0u8; 4];
        server.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"ping");
        server.write_all(b"pong").await.unwrap();
        server.flush().await.unwrap();
        let mut buf = [0u8; 4];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"pong");
    }

    // The initiator offers only plaintext and the responder does not prefer
    // RC4, so both sides must agree on plaintext.
    #[tokio::test]
    async fn full_handshake_plaintext() {
        let info_hash = Id20::from([0xBB; 20]);
        let (client_stream, server_stream) = tokio::io::duplex(4096);
        let client_handle = tokio::spawn(async move {
            negotiate_outbound(client_stream, &info_hash, crypto::CRYPTO_PLAINTEXT).await
        });
        let server_handle = tokio::spawn(async move {
            negotiate_inbound(
                server_stream,
                &info_hash,
                false, // prefer_rc4
            )
            .await
        });
        let client_result = client_handle.await.unwrap().unwrap();
        let server_result = server_handle.await.unwrap().unwrap();
        assert_eq!(client_result.crypto_method, crypto::CRYPTO_PLAINTEXT);
        assert_eq!(server_result.crypto_method, crypto::CRYPTO_PLAINTEXT);
    }

    // The responder is hand-rolled here so it can append 73 bytes of PadB
    // after its DH public key; `negotiate_outbound` must scan past that
    // padding to find the encrypted VC, and the resulting streams must work.
    #[tokio::test]
    async fn full_handshake_with_pad_b() {
        let info_hash = Id20::from([0xEE; 20]);
        let pad_b_len = 73;
        let (client_stream, mut server_stream) = tokio::io::duplex(4096);
        let client_handle = tokio::spawn(async move {
            negotiate_outbound(
                client_stream,
                &info_hash,
                crypto::CRYPTO_RC4 | crypto::CRYPTO_PLAINTEXT,
            )
            .await
        });
        let server_handle = tokio::spawn(async move {
            let skey_bytes = info_hash.as_bytes();
            // Read the initiator's DH key, then send ours followed by PadB.
            let mut ya = [0u8; DH_KEY_SIZE];
            server_stream.read_exact(&mut ya).await.unwrap();
            let dh = DhKeypair::generate();
            server_stream.write_all(&dh.public).await.unwrap();
            let pad_b = vec![0xABu8; pad_b_len];
            server_stream.write_all(&pad_b).await.unwrap();
            server_stream.flush().await.unwrap();
            let s = dh.shared_secret(&ya);
            // Find the sync marker one byte at a time so nothing past it is
            // consumed; the following fields can then be read with `read_exact`.
            let expected_sync = crypto::sync_marker(&s);
            let mut scan_buf = Vec::new();
            for _ in 0..(512 + 20) {
                let mut byte = [0u8; 1];
                server_stream.read_exact(&mut byte).await.unwrap();
                scan_buf.push(byte[0]);
                if scan_buf.len() >= 20 && scan_buf[scan_buf.len() - 20..] == expected_sync {
                    break;
                }
            }
            // The SKEY proof is read but deliberately not verified here.
            let mut proof = [0u8; 20];
            server_stream.read_exact(&mut proof).await.unwrap();
            let ka = crypto::key_a(&s, skey_bytes);
            let kb = crypto::key_b(&s, skey_bytes);
            let mut decrypt_cipher = Rc4::new(&ka);
            let mut encrypt_cipher = Rc4::new(&kb);
            // VC || crypto_provide || len(PadC). Only VC is checked; no PadC is
            // read because `negotiate_outbound` always sends len(PadC) = 0.
            let mut enc_header = [0u8; 14];
            server_stream.read_exact(&mut enc_header).await.unwrap();
            decrypt_cipher.apply(&mut enc_header);
            assert_eq!(&enc_header[..8], &crypto::VC);
            // len(IA); `negotiate_outbound` sends 0, so no IA follows.
            let mut ia_len_buf = [0u8; 2];
            server_stream.read_exact(&mut ia_len_buf).await.unwrap();
            decrypt_cipher.apply(&mut ia_len_buf);
            // Reply VC || crypto_select = RC4 || len(PadD) = 0, under key B.
            let mut response = Vec::new();
            response.extend_from_slice(&crypto::VC);
            response.extend_from_slice(&crypto::CRYPTO_RC4.to_be_bytes());
            response.extend_from_slice(&0u16.to_be_bytes());
            encrypt_cipher.apply(&mut response);
            server_stream.write_all(&response).await.unwrap();
            server_stream.flush().await.unwrap();
            MseStream::encrypted(server_stream, decrypt_cipher, encrypt_cipher, Vec::new())
        });
        let client_result = client_handle.await.unwrap().unwrap();
        let mut server_stream = server_handle.await.unwrap();
        assert_eq!(client_result.crypto_method, crypto::CRYPTO_RC4);
        let mut client = client_result.stream;
        client.write_all(b"hello pad").await.unwrap();
        client.flush().await.unwrap();
        let mut buf = [0u8; 9];
        server_stream.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello pad");
        server_stream.write_all(b"world pad").await.unwrap();
        server_stream.flush().await.unwrap();
        let mut buf = [0u8; 9];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"world pad");
    }

    // Client and server use different info hashes, so the SKEY-derived values
    // disagree; at least one side must report an error.
    #[tokio::test]
    async fn handshake_skey_mismatch_fails() {
        let client_hash = Id20::from([0xCC; 20]);
        let server_hash = Id20::from([0xDD; 20]);
        let (client_stream, server_stream) = tokio::io::duplex(4096);
        let client_handle = tokio::spawn(async move {
            negotiate_outbound(client_stream, &client_hash, crypto::CRYPTO_RC4).await
        });
        let server_handle =
            tokio::spawn(async move { negotiate_inbound(server_stream, &server_hash, true).await });
        let client_result = client_handle.await.unwrap();
        let server_result = server_handle.await.unwrap();
        assert!(client_result.is_err() || server_result.is_err());
    }

    // Marker at the very start: everything after it is returned as overflow.
    #[tokio::test]
    async fn scan_for_marker_finds_at_offset_zero() {
        let marker = b"MARK";
        let mut data = Vec::new();
        data.extend_from_slice(marker);
        data.extend_from_slice(b"trailing");
        let mut cursor = std::io::Cursor::new(data);
        let overflow = scan_for_marker(&mut cursor, marker, 100).await.unwrap();
        assert_eq!(&overflow, b"trailing");
    }

    // EOF before the marker appears must be an error.
    #[tokio::test]
    async fn scan_for_marker_stream_closed() {
        let marker = b"NOTHERE";
        let data = b"short";
        let mut cursor = std::io::Cursor::new(data.to_vec());
        let result = scan_for_marker(&mut cursor, marker, 100).await;
        assert!(result.is_err());
    }

    // Leading padding is skipped and only the bytes after the marker are returned.
    #[tokio::test]
    async fn scan_for_marker_overflow_correctness() {
        let marker = b"VC";
        let mut data = Vec::new();
        data.extend_from_slice(b"padding_padding_");
        data.extend_from_slice(marker);
        data.extend_from_slice(b"extra_bytes_after");
        let mut cursor = std::io::Cursor::new(data);
        let overflow = scan_for_marker(&mut cursor, marker, 200).await.unwrap();
        assert_eq!(&overflow, b"extra_bytes_after");
    }

    // The marker occupies bytes 253..261, straddling the 256-byte read boundary.
    #[tokio::test]
    async fn scan_for_marker_chunk_boundary() {
        let marker = [0xAA; 8];
        let mut data = vec![0u8; 253];
        data.extend_from_slice(&marker);
        data.extend_from_slice(b"tail");
        let mut cursor = std::io::Cursor::new(data);
        let overflow = scan_for_marker(&mut cursor, &marker, 512).await.unwrap();
        assert_eq!(&overflow, b"tail");
    }

    // A marker after the full 512 bytes of padding is still found; nothing
    // follows it, so the overflow is empty.
    #[tokio::test]
    async fn scan_for_marker_max_pad_512() {
        let marker = [0xBB; 20];
        let mut data = vec![0u8; 512];
        data.extend_from_slice(&marker);
        let mut cursor = std::io::Cursor::new(data);
        let overflow = scan_for_marker(&mut cursor, &marker, 512 + 20)
            .await
            .unwrap();
        assert!(overflow.is_empty());
    }

    // The marker starts at byte 100, well past the 50-byte scan limit.
    #[tokio::test]
    async fn scan_for_marker_exceeds_max() {
        let marker = b"NEEDLE";
        let mut data = vec![0u8; 100];
        data.extend_from_slice(marker);
        let mut cursor = std::io::Cursor::new(data);
        let result = scan_for_marker(&mut cursor, marker, 50).await;
        assert!(result.is_err());
    }
}