diff --git a/src/crypto_connection/custom_tls_stream.rs b/src/crypto_connection/custom_tls_stream.rs
index 55247dd..7e36595 100644
@@ -325,6 +325,27 @@ where
}
}
+// Implement AsyncPing for CustomTlsStream
+impl<IO> crate::async_stream::AsyncPing for CustomTlsStream<IO>
+where
+ IO: AsyncRead + AsyncWrite + Unpin,
+{
+ fn supports_ping(&self) -> bool {
+ false // TLS doesn't have a built-in ping mechanism
+ }
+
+ fn poll_write_ping(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<bool>> {
+ Poll::Ready(Ok(false))
+ }
+}
+
+// Implement AsyncStream blanket trait
+impl<IO> crate::async_stream::AsyncStream for CustomTlsStream<IO>
+where
+ IO: AsyncRead + AsyncWrite + Unpin + Send,
+{
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/src/reality/mod.rs b/src/reality/mod.rs
index 1187fbb..786f78c 100644
@@ -9,7 +9,6 @@
mod reality_crypto;
mod reality_util;
-mod reality_client_handshake;
mod reality_client_handler;
mod reality_server_handler;
mod reality_destination;
@@ -17,8 +16,6 @@ mod reality_certificate;
mod reality_tls13_keys;
mod reality_tls13_messages;
mod reality_tls13_crypto;
-mod reality_tls13_stream;
-mod reality_server_handshake;
mod reality_server_connection;
mod reality_client_connection;
diff --git a/src/reality/reality_client_connection.rs b/src/reality/reality_client_connection.rs
index 89b6b3d..2e0c1c5 100644
@@ -10,7 +10,7 @@ use crate::crypto_connection::reader_writer::{CryptoReader, CryptoWriter, IoStat
use super::reality_crypto::{derive_auth_key, perform_ecdh, encrypt_session_id};
use super::reality_util::{extract_client_public_key, extract_server_public_key};
-use super::reality_client_handshake::construct_client_hello;
+
use super::reality_tls13_messages::*;
use super::reality_tls13_keys::{derive_handshake_keys, derive_traffic_keys, compute_finished_verify_data, derive_application_secrets};
use super::reality_tls13_crypto::{encrypt_handshake_message, decrypt_handshake_message, encrypt_tls13_record, decrypt_tls13_record};
diff --git a/src/reality/reality_client_handler.rs b/src/reality/reality_client_handler.rs
index 9ca98b6..af686e3 100644
@@ -1,16 +1,18 @@
use std::sync::Arc;
use async_trait::async_trait;
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::address::NetLocation;
use crate::async_stream::AsyncStream;
+use crate::crypto_connection::{Connection, CustomTlsStream};
+use crate::reality::{RealityClientConnection, RealityClientConfig};
use crate::tcp_handler::{TcpClientHandler, TcpClientSetupResult};
-/// REALITY client handler
+/// REALITY client handler using buffered Connection API
///
-/// This handler establishes REALITY-obfuscated TLS connections.
-/// It modifies the ClientHello to include encrypted REALITY metadata
-/// and verifies the server certificate using HMAC-SHA512.
+/// This handler establishes REALITY-obfuscated TLS connections using
+/// the buffered sans-I/O pattern with RealityClientConnection.
#[derive(Debug)]
pub struct RealityClientHandler {
public_key: [u8; 32],
@@ -43,10 +45,10 @@ impl TcpClientHandler for RealityClientHandler {
async fn setup_client_stream(
&self,
server_stream: &mut Box<dyn AsyncStream>,
- client_stream: Box<dyn AsyncStream>,
+ mut client_stream: Box<dyn AsyncStream>,
remote_location: NetLocation,
) -> std::io::Result<TcpClientSetupResult> {
- // Perform REALITY TLS 1.3 handshake
+ // Extract server name as string
let server_name_str = match &self.server_name {
rustls::pki_types::ServerName::DnsName(name) => name.as_ref(),
rustls::pki_types::ServerName::IpAddress(ip) => {
@@ -63,17 +65,74 @@ impl TcpClientHandler for RealityClientHandler {
}
};
- let tls_stream = crate::reality::reality_client_handshake::perform_reality_client_handshake(
- client_stream,
- &self.public_key,
- &self.short_id,
- server_name_str,
- )
- .await?;
+ // Step 1: Create buffered REALITY client connection
+ log::debug!("REALITY CLIENT: Creating buffered RealityClientConnection");
+ let reality_config = RealityClientConfig {
+ public_key: self.public_key,
+ short_id: self.short_id,
+ server_name: server_name_str.to_string(),
+ };
+
+ let mut reality_conn = RealityClientConnection::new(reality_config)?;
+
+ // Step 2: Write ClientHello to server
+ log::debug!("REALITY CLIENT: Writing ClientHello");
+ {
+ use std::io::Write;
+ let mut write_buf = Vec::new();
+ while reality_conn.wants_write() {
+ reality_conn.write_tls(&mut write_buf)?;
+ }
+ if !write_buf.is_empty() {
+ client_stream.write_all(&write_buf).await?;
+ client_stream.flush().await?;
+ }
+ }
+
+ // Step 3: Read server's handshake messages
+ log::debug!("REALITY CLIENT: Reading server handshake");
+ {
+ use std::io::Read;
+ let mut buf = vec![0u8; 16384]; // Large enough for server handshake
+ let n = client_stream.read(&mut buf).await?;
+ if n == 0 {
+ return Err(std::io::Error::new(
+ std::io::ErrorKind::UnexpectedEof,
+ "EOF while waiting for server handshake",
+ ));
+ }
+ let mut cursor = std::io::Cursor::new(&buf[..n]);
+ reality_conn.read_tls(&mut cursor)?;
+ }
+
+ // Step 4: Process server handshake
+ log::debug!("REALITY CLIENT: Processing server handshake");
+ reality_conn.process_new_packets()?;
+
+ // Step 5: Write client Finished message
+ log::debug!("REALITY CLIENT: Writing client Finished");
+ {
+ use std::io::Write;
+ let mut write_buf = Vec::new();
+ while reality_conn.wants_write() {
+ reality_conn.write_tls(&mut write_buf)?;
+ }
+ if !write_buf.is_empty() {
+ client_stream.write_all(&write_buf).await?;
+ client_stream.flush().await?;
+ }
+ }
+
+ // Step 6: Wrap in Connection enum and CustomTlsStream
+ log::debug!("REALITY CLIENT: Wrapping in CustomTlsStream");
+ let connection = Connection::new_reality_client(reality_conn);
+ let tls_stream = CustomTlsStream::new(client_stream, connection);
+
+ log::debug!("REALITY CLIENT: Handshake completed successfully");
// Pass the TLS-wrapped stream to the inner protocol handler
self.handler
- .setup_client_stream(server_stream, tls_stream, remote_location)
+ .setup_client_stream(server_stream, Box::new(tls_stream), remote_location)
.await
}
}
diff --git a/src/reality/reality_client_handshake.rs b/src/reality/reality_client_handshake.rs
deleted file mode 100644
index 84f8997..0000000
@@ -1,954 +0,0 @@
-// REALITY TLS 1.3 Client Handshake
-//
-// Performs complete TLS 1.3 client handshake with REALITY protocol:
-// 1. Generates ClientHello with encrypted REALITY SessionId
-// 2. Receives and validates ServerHello with HMAC-signed certificate
-// 3. Completes handshake and derives application secrets
-// 4. Returns encrypted Tls13Stream
-
-use crate::address::NetLocation;
-use crate::async_stream::AsyncStream;
-use super::reality_crypto::{
- create_session_id, decrypt_session_id, derive_auth_key, encrypt_session_id,
- get_current_timestamp, perform_ecdh,
-};
-use super::reality_tls13_crypto::{decrypt_handshake_message, encrypt_tls13_record};
-use super::reality_tls13_keys::{
- compute_finished_verify_data, derive_handshake_keys, derive_application_secrets,
- derive_traffic_keys,
-};
-use super::reality_tls13_messages::*;
-use super::reality_tls13_stream::Tls13Stream;
-use super::REALITY_PROTOCOL_VERSION;
-use aws_lc_rs::digest;
-use std::io::{Error, ErrorKind, Result};
-use tokio::io::{AsyncReadExt, AsyncWriteExt};
-use aws_lc_rs::agreement;
-
-/// Perform REALITY TLS 1.3 client handshake
-///
-/// # Arguments
-/// * `stream` - TCP connection to REALITY server
-/// * `server_public_key` - Server's X25519 public key (from config)
-/// * `short_id` - Client's short ID (from config)
-/// * `server_name` - SNI hostname for TLS handshake
-///
-/// # Returns
-/// Encrypted Tls13Stream ready for sending application data
-pub async fn perform_reality_client_handshake(
- mut stream: Box<dyn AsyncStream>,
- server_public_key: &[u8; 32],
- short_id: &[u8; 8],
- server_name: &str,
-) -> Result<Box<dyn AsyncStream>> {
- log::info!("========== REALITY CLIENT HANDSHAKE CALLED ==========");
- log::info!("REALITY CLIENT: Starting TLS 1.3 handshake to {}", server_name);
-
- // 1. Generate client's ephemeral X25519 keypair
- let client_private_key = agreement::PrivateKey::generate(&agreement::X25519)
- .map_err(|_| Error::new(ErrorKind::Other, "Failed to generate X25519 key"))?;
- let client_public_key = client_private_key
- .compute_public_key()
- .map_err(|_| Error::new(ErrorKind::Other, "Failed to compute public key"))?;
-
- log::debug!(" Generated client X25519 keypair");
-
- // 2. Perform ECDH with server's public key
- let mut shared_secret = [0u8; 32];
- agreement::agree(
- &client_private_key,
- &agreement::UnparsedPublicKey::new(&agreement::X25519, server_public_key),
- Error::new(ErrorKind::Other, "ECDH failed"),
- |key_material| {
- shared_secret.copy_from_slice(key_material);
- Ok(())
- },
- )?;
-
- log::debug!(" Client public key (full 32 bytes): {:02x?}", client_public_key.as_ref());
- log::debug!(" ECDH shared_secret (full 32 bytes): {:?}", &shared_secret);
-
- // 3. Generate ClientHello random
- use aws_lc_rs::rand::SecureRandom;
- let rng = aws_lc_rs::rand::SystemRandom::new();
- let mut client_random = [0u8; 32];
- rng.fill(&mut client_random)
- .map_err(|_| Error::new(ErrorKind::Other, "Failed to generate random"))?;
-
- log::debug!(" Generated client random");
-
- // 4. Derive AuthKey using HKDF-SHA256
- let salt: [u8; 20] = client_random[0..20]
- .try_into()
- .map_err(|_| Error::new(ErrorKind::InvalidData, "Salt extraction failed"))?;
- let auth_key = derive_auth_key(&shared_secret, &salt, b"REALITY")?;
-
- log::debug!(" Derived auth_key: {:?}", &auth_key[..8]);
-
- // 5. Create REALITY SessionId
- let timestamp = get_current_timestamp();
- let session_id = create_session_id(REALITY_PROTOCOL_VERSION, timestamp, short_id);
-
- log::info!(" Created SessionId: version={}.{}.{}, timestamp={}, short_id={:02x?}",
- REALITY_PROTOCOL_VERSION[0],
- REALITY_PROTOCOL_VERSION[1],
- REALITY_PROTOCOL_VERSION[2],
- timestamp,
- &short_id[..]);
- log::debug!(" SessionId (plaintext first 16 bytes): {:02x?}", &session_id[0..16]);
-
- // 6. Construct ClientHello
- let client_hello = construct_client_hello(
- &client_random,
- &session_id, // Plaintext SessionId for now, will encrypt in-place
- client_public_key.as_ref(),
- server_name,
- )?;
-
- log::debug!(" Constructed ClientHello ({} bytes)", client_hello.len());
-
- // 7. Encrypt SessionId using ClientHello with ZERO SessionId as AAD (matches Xray-core behavior)
- let mut modified_client_hello = client_hello;
-
- let nonce: [u8; 12] = client_random[20..32]
- .try_into()
- .map_err(|_| Error::new(ErrorKind::InvalidData, "Nonce extraction failed"))?;
-
- let plaintext: [u8; 16] = session_id[0..16]
- .try_into()
- .map_err(|_| Error::new(ErrorKind::InvalidData, "SessionId plaintext extraction failed"))?;
-
- // Zero out the SessionId in the ClientHello to create the AAD (this matches what the server will use)
- // SessionId location: handshake_type(1) + length(3) + version(2) + random(32) + session_id_length(1) = offset 39
- modified_client_hello[39..71].copy_from_slice(&[0u8; 32]);
-
- log::debug!(" ENCRYPTION: auth_key={:02x?}", &auth_key[..8]);
- log::debug!(" ENCRYPTION: nonce={:02x?}", &nonce);
- log::debug!(" ENCRYPTION: plaintext={:02x?}", &plaintext);
- log::debug!(" ENCRYPTION: aad_len={} (ClientHello with zero SessionId)", modified_client_hello.len());
-
- // Encrypt with AAD = ClientHello with zero SessionId
- let encrypted_session_id = encrypt_session_id(&plaintext, &auth_key, &nonce, &modified_client_hello)?;
-
- // Now replace the zero SessionId with the encrypted one
- modified_client_hello[39..71].copy_from_slice(&encrypted_session_id);
-
- log::debug!(" Encrypted SessionId and inserted into ClientHello");
-
- // 9. Wrap ClientHello in TLS record and send
- let mut client_hello_frame = write_record_header(0x16, modified_client_hello.len() as u16);
- client_hello_frame.extend_from_slice(&modified_client_hello);
-
- stream.write_all(&client_hello_frame).await?;
- stream.flush().await?;
-
- log::debug!(" Sent ClientHello ({} bytes total)", client_hello_frame.len());
-
- // 10. Compute ClientHello hash for transcript
- let mut ch_transcript = digest::Context::new(&digest::SHA256);
- ch_transcript.update(&modified_client_hello);
- let client_hello_hash = ch_transcript.finish();
-
- // 11. Read ServerHello
- let mut header = [0u8; 5];
- stream.read_exact(&mut header).await?;
-
- let record_type = header[0];
- let record_length = u16::from_be_bytes([header[3], header[4]]) as usize;
-
- if record_type != 0x16 {
- return Err(Error::new(
- ErrorKind::InvalidData,
- format!("Expected Handshake record, got 0x{:02x}", record_type),
- ));
- }
-
- let mut server_hello_data = vec![0u8; record_length];
- stream.read_exact(&mut server_hello_data).await?;
-
- log::debug!(" Received ServerHello ({} bytes)", server_hello_data.len());
-
- // 12. Parse ServerHello
- let (server_random, cipher_suite, server_key_share) =
- parse_server_hello(&server_hello_data)?;
-
- log::debug!(" Parsed ServerHello: cipher=0x{:04x}, key_share_len={}",
- cipher_suite,
- server_key_share.len());
-
- // 13. Compute transcript hash through ServerHello
- let mut ch_sh_transcript = digest::Context::new(&digest::SHA256);
- ch_sh_transcript.update(&modified_client_hello);
- ch_sh_transcript.update(&server_hello_data);
- let mut handshake_transcript = ch_sh_transcript.clone();
- let server_hello_hash = ch_sh_transcript.finish();
-
- // 14. Perform ECDH with server's ephemeral key
- let mut tls_shared_secret = [0u8; 32];
- agreement::agree(
- &client_private_key,
- &agreement::UnparsedPublicKey::new(&agreement::X25519, &server_key_share),
- Error::new(ErrorKind::Other, "TLS ECDH failed"),
- |key_material| {
- tls_shared_secret.copy_from_slice(key_material);
- Ok(())
- },
- )?;
-
- log::debug!(" TLS shared_secret: {:?}", &tls_shared_secret[..8]);
-
- // 15. Derive handshake keys
- let hs_keys = derive_handshake_keys(
- &tls_shared_secret,
- client_hello_hash.as_ref(),
- server_hello_hash.as_ref(),
- )?;
-
- let (server_hs_key, server_hs_iv) = derive_traffic_keys(
- &hs_keys.server_handshake_traffic_secret,
- cipher_suite,
- )?;
-
- let (client_hs_key, client_hs_iv) = derive_traffic_keys(
- &hs_keys.client_handshake_traffic_secret,
- cipher_suite,
- )?;
-
- log::debug!(" Derived handshake keys");
-
- // 16. Read ChangeCipherSpec (compatibility message, optional)
- stream.read_exact(&mut header).await?;
- let mut record_type = header[0];
- let mut record_length = u16::from_be_bytes([header[3], header[4]]) as usize;
- log::debug!(" First record after ServerHello: type=0x{:02x}, length={}", record_type, record_length);
-
- if record_type == 0x14 {
- // ChangeCipherSpec
- let mut ccs_data = vec![0u8; record_length];
- stream.read_exact(&mut ccs_data).await?;
- log::warn!(" Skipped ChangeCipherSpec");
-
- // Read next record
- stream.read_exact(&mut header).await?;
- record_type = header[0];
- record_length = u16::from_be_bytes([header[3], header[4]]) as usize;
- log::debug!(" After CCS, next record: type=0x{:02x}, length={}", record_type, record_length);
- }
-
- let mut server_seq = 0u64;
- let mut server_finished = false;
-
- // 17. Read encrypted handshake messages
- loop {
- if record_type != 0x17 {
- return Err(Error::new(
- ErrorKind::InvalidData,
- format!("Expected ApplicationData, got 0x{:02x}", record_type),
- ));
- }
-
- let mut ciphertext = vec![0u8; record_length];
- stream.read_exact(&mut ciphertext).await?;
-
- // Decrypt
- let plaintext = decrypt_handshake_message(
- &server_hs_key,
- &server_hs_iv,
- server_seq,
- &ciphertext,
- record_length as u16,
- )?;
-
- server_seq += 1;
-
- // Process multiple handshake messages in the plaintext
- let mut offset = 0;
- while offset < plaintext.len() {
- if offset + 4 > plaintext.len() {
- break; // Not enough bytes for header
- }
-
- let msg_type = plaintext[offset];
- let msg_len = u32::from_be_bytes([0, plaintext[offset+1], plaintext[offset+2], plaintext[offset+3]]) as usize;
-
- if offset + 4 + msg_len > plaintext.len() {
- return Err(Error::new(ErrorKind::InvalidData, "Invalid handshake message length"));
- }
-
- let message = &plaintext[offset..offset+4+msg_len];
- log::debug!(" Received encrypted message type: {} (length: {})", msg_type, msg_len);
-
- match msg_type {
- 8 => {
- // EncryptedExtensions
- log::debug!(" Processing EncryptedExtensions");
- handshake_transcript.update(message);
- }
- 11 => {
- // Certificate
- log::debug!(" Processing Certificate");
- handshake_transcript.update(message);
-
- // Verify HMAC signature (REALITY-specific verification)
- verify_reality_certificate(message, &auth_key)?;
- log::info!(" ✓ REALITY certificate verified via HMAC");
- }
- 15 => {
- // CertificateVerify
- log::debug!(" Processing CertificateVerify");
- handshake_transcript.update(message);
- }
- 20 => {
- // Finished
- log::debug!(" Processing server Finished");
-
- // Verify server Finished
- let handshake_hash_before_finished = handshake_transcript.clone().finish();
- let expected_verify_data = compute_finished_verify_data(
- &hs_keys.server_handshake_traffic_secret,
- handshake_hash_before_finished.as_ref(),
- )?;
-
- let verify_data = &message[4..];
- if verify_data != &expected_verify_data[..] {
- return Err(Error::new(
- ErrorKind::InvalidData,
- "Server Finished verify_data mismatch",
- ));
- }
-
- log::info!(" ✓ Server Finished verified");
- handshake_transcript.update(message);
- server_finished = true;
- }
- _ => {
- return Err(Error::new(
- ErrorKind::InvalidData,
- format!("Unexpected handshake message type: {}", msg_type),
- ));
- }
- }
-
- // Move to next message
- offset += 4 + msg_len;
- }
-
- // Check if we received the server Finished message
- if server_finished {
- break;
- }
-
- // Read next record
- stream.read_exact(&mut header).await?;
- record_type = header[0];
- record_length = u16::from_be_bytes([header[3], header[4]]) as usize;
- }
-
- // 18. Derive application traffic secrets BEFORE sending client Finished
- // Application secrets use hash up to server Finished, not including client Finished
- let handshake_hash_for_app_secrets = handshake_transcript.clone().finish();
- let (client_app_secret, server_app_secret) = derive_application_secrets(
- &hs_keys.master_secret,
- handshake_hash_for_app_secrets.as_ref(),
- )?;
-
- log::debug!(" Application secrets derived");
-
- // 19. Send ChangeCipherSpec (compatibility)
- log::debug!(" Sending ChangeCipherSpec");
- let ccs = write_record_header(0x14, 1);
- stream.write_all(&ccs).await?;
- stream.write_all(&[0x01]).await?;
- stream.flush().await?; // Ensure CCS is sent before Finished
-
- // 20. Compute and send client Finished
- let handshake_hash_for_client_finished = handshake_transcript.clone().finish();
- let client_finished_verify = compute_finished_verify_data(
- &hs_keys.client_handshake_traffic_secret,
- handshake_hash_for_client_finished.as_ref(),
- )?;
-
- let finished_msg = construct_finished(&client_finished_verify)?;
- // Note: We update transcript AFTER deriving app secrets
- handshake_transcript.update(&finished_msg);
-
- log::debug!(" Sending client Finished");
-
- // Encrypt Finished message with handshake keys (not app keys!)
- let mut buffer = finished_msg.clone();
- buffer.push(0x16); // ContentType: Handshake
-
- // Create TLS record header for AAD
- let record_header = write_record_header(0x17, (buffer.len() + 16) as u16); // +16 for GCM tag
-
- let finished_ciphertext = encrypt_tls13_record(&client_hs_key, &client_hs_iv, 0, &buffer, &record_header)?;
-
- // Send encrypted Finished
- stream.write_all(&record_header).await?;
- stream.write_all(&finished_ciphertext).await?;
- stream.flush().await?; // Ensure Finished is sent before creating TLS stream
-
- log::debug!(" Client Finished sent");
-
- // 21. Create TLS 1.3 stream wrapper
- let tls_stream = Tls13Stream::new(
- stream,
- &client_app_secret,
- &server_app_secret,
- cipher_suite,
- true, // is_client = true
- )?;
-
- log::info!("REALITY CLIENT: Handshake complete!");
-
- Ok(Box::new(tls_stream))
-}
-
-/// Verify REALITY server certificate using HMAC
-fn verify_reality_certificate(certificate_msg: &[u8], auth_key: &[u8; 32]) -> Result<()> {
- // Certificate message format:
- // - type(1) = 11
- // - length(3)
- // - certificate_request_context_length(1) = 0
- // - certificates_length(3)
- // - certificate_list...
-
- if certificate_msg.len() < 8 {
- return Err(Error::new(ErrorKind::InvalidData, "Certificate message too short"));
- }
-
- if certificate_msg[0] != 11 {
- return Err(Error::new(ErrorKind::InvalidData, "Not a Certificate message"));
- }
-
- // Skip: type(1) + length(3) + context_len(1)
- let cert_list_len = u32::from_be_bytes([0, certificate_msg[5], certificate_msg[6], certificate_msg[7]]) as usize;
-
- if certificate_msg.len() < 8 + cert_list_len {
- return Err(Error::new(ErrorKind::InvalidData, "Certificate list truncated"));
- }
-
- // Parse first certificate entry
- // Format: cert_data_length(3) + cert_data + extensions_length(2)
- if certificate_msg.len() < 11 {
- return Err(Error::new(ErrorKind::InvalidData, "No certificate data"));
- }
-
- let cert_data_len = u32::from_be_bytes([0, certificate_msg[8], certificate_msg[9], certificate_msg[10]]) as usize;
-
- if certificate_msg.len() < 11 + cert_data_len {
- return Err(Error::new(ErrorKind::InvalidData, "Certificate data truncated"));
- }
-
- let cert_der = &certificate_msg[11..11 + cert_data_len];
-
- log::debug!(" Certificate DER length: {}", cert_der.len());
- log::debug!(" Certificate first 32 bytes: {:02x?}", &cert_der[..32.min(cert_der.len())]);
- log::debug!(" Certificate last 32 bytes: {:02x?}", &cert_der[cert_der.len().saturating_sub(32)..]);
-
- // Extract Ed25519 public key and signature from certificate
- // For REALITY, the certificate uses HMAC-SHA512 as the signature
- let (public_key, signature) = extract_ed25519_cert_data(cert_der)?;
-
- log::debug!(" Extracted public key: {:?}", &public_key[..16]);
- log::debug!(" Extracted signature: {:?}", &signature[..16]);
-
- // Verify: signature == HMAC-SHA512(auth_key, public_key)
- use aws_lc_rs::hmac;
-
- let key = hmac::Key::new(hmac::HMAC_SHA512, auth_key);
- let expected_signature = hmac::sign(&key, &public_key);
-
- if signature != expected_signature.as_ref() {
- return Err(Error::new(
- ErrorKind::InvalidData,
- "REALITY certificate HMAC verification failed",
- ));
- }
-
- Ok(())
-}
-
-/// Extract Ed25519 public key and signature from DER certificate
-fn extract_ed25519_cert_data(cert_der: &[u8]) -> Result<([u8; 32], Vec<u8>)> {
- log::debug!(" Looking for Ed25519 key and signature in {} byte certificate", cert_der.len());
-
- // Debug: Find all BIT STRING tags
- for i in 0..cert_der.len().saturating_sub(2) {
- if cert_der[i] == 0x03 {
- let length = cert_der[i + 1];
- log::debug!(" Found BIT STRING at offset {}: length={}", i, length);
- }
- }
-
- // Find Ed25519 public key (32 bytes in BIT STRING, length=33)
- let mut public_key = [0u8; 32];
- let mut found_pk = false;
-
- for i in 0..cert_der.len().saturating_sub(35) {
- if cert_der[i] == 0x03 && cert_der[i + 1] == 33 && cert_der[i + 2] == 0x00 {
- // BIT STRING, length 33, unused bits 0
- public_key.copy_from_slice(&cert_der[i + 3..i + 35]);
- found_pk = true;
- log::debug!(" Found Ed25519 public key at offset {}", i);
- break;
- }
- }
-
- if !found_pk {
- return Err(Error::new(
- ErrorKind::InvalidData,
- "Could not find Ed25519 public key",
- ));
- }
-
- // Find signature (last BIT STRING with length 65: unused_bits(1) + signature(64))
- let mut signature = Vec::new();
- // Need to check up to and including the position where a 67-byte sequence could start
- for i in (0..=cert_der.len().saturating_sub(67)).rev() {
- if i + 67 <= cert_der.len() && cert_der[i] == 0x03 && cert_der[i + 1] == 0x41 && cert_der[i + 2] == 0x00 {
- // BIT STRING, length 65 (0x41), unused bits 0
- signature = cert_der[i + 3..i + 67].to_vec();
- log::debug!(" Found signature at offset {}", i);
- break;
- }
- }
-
- if signature.is_empty() {
- return Err(Error::new(
- ErrorKind::InvalidData,
- "Could not find certificate signature",
- ));
- }
-
- Ok((public_key, signature))
-}
-
-/// Parse ServerHello to extract server random, cipher suite, and key share
-fn parse_server_hello(data: &[u8]) -> Result<([u8; 32], u16, Vec<u8>)> {
- // ServerHello format:
- // - type(1) = 2
- // - length(3)
- // - version(2) = 0x0303
- // - random(32)
- // - session_id_length(1)
- // - session_id(variable)
- // - cipher_suite(2)
- // - compression(1) = 0
- // - extensions_length(2)
- // - extensions
-
- if data.len() < 38 {
- return Err(Error::new(ErrorKind::InvalidData, "ServerHello too short"));
- }
-
- if data[0] != 2 {
- return Err(Error::new(ErrorKind::InvalidData, "Not a ServerHello"));
- }
-
- // Extract random
- let mut server_random = [0u8; 32];
- server_random.copy_from_slice(&data[6..38]);
-
- // Parse session_id
- let session_id_len = data[38] as usize;
- if data.len() < 39 + session_id_len + 3 {
- return Err(Error::new(ErrorKind::InvalidData, "ServerHello truncated"));
- }
-
- let offset = 39 + session_id_len;
- let cipher_suite = u16::from_be_bytes([data[offset], data[offset + 1]]);
-
- // Parse extensions to find key_share
- let ext_offset = offset + 3;
- if data.len() < ext_offset + 2 {
- return Err(Error::new(ErrorKind::InvalidData, "No extensions"));
- }
-
- let extensions_length = u16::from_be_bytes([data[ext_offset], data[ext_offset + 1]]) as usize;
- let mut ext_data = &data[ext_offset + 2..ext_offset + 2 + extensions_length];
-
- let mut key_share_data = Vec::new();
-
- while ext_data.len() >= 4 {
- let ext_type = u16::from_be_bytes([ext_data[0], ext_data[1]]);
- let ext_len = u16::from_be_bytes([ext_data[2], ext_data[3]]) as usize;
-
- if ext_data.len() < 4 + ext_len {
- break;
- }
-
- // key_share extension type = 51
- if ext_type == 51 && ext_len >= 4 {
- let key_len = u16::from_be_bytes([ext_data[6], ext_data[7]]) as usize;
- if ext_len >= 4 + key_len {
- key_share_data = ext_data[8..8 + key_len].to_vec();
- }
- }
-
- ext_data = &ext_data[4 + ext_len..];
- }
-
- if key_share_data.is_empty() {
- return Err(Error::new(ErrorKind::InvalidData, "No key_share found"));
- }
-
- Ok((server_random, cipher_suite, key_share_data))
-}
-
-/// Construct TLS 1.3 ClientHello
-///
-/// Returns handshake message bytes (without record header)
-pub fn construct_client_hello(
- client_random: &[u8; 32],
- session_id: &[u8; 32],
- client_public_key: &[u8],
- server_name: &str,
-) -> Result<Vec<u8>> {
- let mut hello = Vec::with_capacity(512);
-
- // Handshake message type: ClientHello (0x01)
- hello.push(0x01);
-
- // Placeholder for handshake message length (3 bytes)
- let length_offset = hello.len();
- hello.extend_from_slice(&[0u8; 3]);
-
- // TLS version: 3.3 (TLS 1.2 for compatibility)
- hello.extend_from_slice(&[0x03, 0x03]);
-
- // Client random (32 bytes)
- hello.extend_from_slice(client_random);
-
- // Session ID length (1 byte) + Session ID (32 bytes)
- hello.push(32);
- hello.extend_from_slice(session_id);
-
- // Cipher suites
- // Support only TLS_AES_128_GCM_SHA256 (0x1301)
- hello.extend_from_slice(&[0x00, 0x02]); // Cipher suites length: 2 bytes
- hello.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256
-
- // Compression methods (1 method: null)
- hello.extend_from_slice(&[0x01, 0x00]);
-
- // Extensions
- let extensions_offset = hello.len();
- hello.extend_from_slice(&[0u8; 2]); // Placeholder for extensions length
-
- let mut extensions = Vec::new();
-
- // 1. server_name extension (type 0)
- {
- let server_name_bytes = server_name.as_bytes();
- let server_name_len = server_name_bytes.len();
-
- extensions.extend_from_slice(&[0x00, 0x00]); // Extension type: server_name
- let ext_len = 5 + server_name_len;
- extensions.extend_from_slice(&(ext_len as u16).to_be_bytes()); // Extension length
- extensions.extend_from_slice(&((server_name_len + 3) as u16).to_be_bytes()); // Server name list length
- extensions.push(0x00); // Name type: host_name
- extensions.extend_from_slice(&(server_name_len as u16).to_be_bytes()); // Name length
- extensions.extend_from_slice(server_name_bytes); // Server name
- }
-
- // 2. supported_versions extension (type 43)
- {
- extensions.extend_from_slice(&[0x00, 0x2b]); // Extension type: supported_versions
- extensions.extend_from_slice(&[0x00, 0x03]); // Extension length: 3
- extensions.push(0x02); // Supported versions length: 2
- extensions.extend_from_slice(&[0x03, 0x04]); // TLS 1.3
- }
-
- // 3. supported_groups extension (type 10)
- {
- extensions.extend_from_slice(&[0x00, 0x0a]); // Extension type: supported_groups
- extensions.extend_from_slice(&[0x00, 0x04]); // Extension length: 4
- extensions.extend_from_slice(&[0x00, 0x02]); // Supported groups length: 2
- extensions.extend_from_slice(&[0x00, 0x1d]); // x25519
- }
-
- // 4. key_share extension (type 51)
- {
- let key_share_start = extensions.len();
- extensions.extend_from_slice(&[0x00, 0x33]); // Extension type: key_share
- let key_share_len = 2 + 4 + client_public_key.len();
- extensions.extend_from_slice(&(key_share_len as u16).to_be_bytes()); // Extension length
- let key_share_list_len = 4 + client_public_key.len();
- extensions.extend_from_slice(&(key_share_list_len as u16).to_be_bytes()); // Key share list length
- extensions.extend_from_slice(&[0x00, 0x1d]); // Group: x25519
- extensions.extend_from_slice(&(client_public_key.len() as u16).to_be_bytes()); // Key length
- extensions.extend_from_slice(client_public_key); // Public key
- let key_share_end = extensions.len();
- log::debug!(" DEBUG: key_share extension bytes: {:02x?}", &extensions[key_share_start..key_share_end]);
- }
-
- // 5. signature_algorithms extension (type 13)
- {
- extensions.extend_from_slice(&[0x00, 0x0d]); // Extension type: signature_algorithms
- extensions.extend_from_slice(&[0x00, 0x04]); // Extension length: 4
- extensions.extend_from_slice(&[0x00, 0x02]); // Signature algorithms length: 2
- extensions.extend_from_slice(&[0x08, 0x07]); // ed25519
- }
-
- // Write extensions length
- let extensions_length = extensions.len();
- hello[extensions_offset..extensions_offset + 2]
- .copy_from_slice(&(extensions_length as u16).to_be_bytes());
-
- // Append extensions
- hello.extend_from_slice(&extensions);
-
- // Write handshake message length
- let message_length = hello.len() - 4; // Exclude type (1) and length (3)
- hello[length_offset..length_offset + 3]
- .copy_from_slice(&(message_length as u32).to_be_bytes()[1..]);
-
- Ok(hello)
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use std::io::{Error, ErrorKind, Result};
-
- #[test]
- fn test_parse_server_hello_valid() {
- // Construct a valid ServerHello message
- let mut data = Vec::new();
-
- // Handshake type (ServerHello = 0x02)
- data.push(0x02);
- // Length (3 bytes)
- data.extend_from_slice(&[0x00, 0x00, 0x4C]); // 76 bytes
- // Version (TLS 1.2 = 0x0303)
- data.extend_from_slice(&[0x03, 0x03]);
- // Server random (32 bytes)
- let server_random = [0xaa; 32];
- data.extend_from_slice(&server_random);
- // Session ID length
- data.push(32);
- // Session ID (32 bytes)
- let session_id = [0xbb; 32];
- data.extend_from_slice(&session_id);
- // Cipher suite (2 bytes)
- data.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256
- // Compression method
- data.push(0x00);
- // Extensions length
- data.extend_from_slice(&[0x00, 0x04]);
- // Extension type and data
- data.extend_from_slice(&[0x00, 0x2b, 0x00, 0x02, 0x03, 0x04]);
-
- let result = parse_server_hello(&data);
- assert!(result.is_ok());
-
- let (parsed_random, cipher_suite, extensions) = result.unwrap();
- assert_eq!(parsed_random, server_random);
- assert_eq!(cipher_suite, 0x1301);
- assert!(!extensions.is_empty());
- }
-
- #[test]
- fn test_parse_server_hello_invalid_type() {
- let mut data = Vec::new();
- // Wrong handshake type
- data.push(0x01); // ClientHello instead of ServerHello
- data.extend_from_slice(&[0x00, 0x00, 0x4C]);
-
- let result = parse_server_hello(&data);
- assert!(result.is_err());
- assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidData);
- }
-
- #[test]
- fn test_parse_server_hello_too_short() {
- let data = vec![0x02, 0x00, 0x00, 0x10]; // Too short
- let result = parse_server_hello(&data);
- assert!(result.is_err());
- }
-
- #[test]
- fn test_construct_client_hello() {
- let client_random = [0x11; 32];
- let session_id = [0x22; 32];
- let public_key_bytes = [0x33; 32];
- let server_name = "example.com";
-
- let result = construct_client_hello(
- &client_random,
- &session_id,
- &public_key_bytes,
- server_name,
- );
-
- assert!(result.is_ok());
- let client_hello = result.unwrap();
-
- // Verify basic structure
- assert_eq!(client_hello[0], 0x01); // ClientHello type
-
- // Check that length fields are present
- let length = ((client_hello[1] as usize) << 16)
- | ((client_hello[2] as usize) << 8)
- | (client_hello[3] as usize);
- assert!(length > 100); // Should be reasonably sized
-
- // Check version (TLS 1.2 = 0x0303)
- assert_eq!(client_hello[4], 0x03);
- assert_eq!(client_hello[5], 0x03);
-
- // Check client random is included
- assert_eq!(&client_hello[6..38], &client_random);
- }
-
- #[test]
- fn test_construct_client_hello_empty_server_name() {
- let client_random = [0x11; 32];
- let session_id = [0x22; 32];
- let public_key_bytes = [0x33; 32];
-
- let result = construct_client_hello(
- &client_random,
- &session_id,
- &public_key_bytes,
- "", // Empty server name
- );
-
- assert!(result.is_ok());
- let client_hello = result.unwrap();
-
- // Should still create a valid ClientHello
- assert_eq!(client_hello[0], 0x01);
- }
-
- #[test]
- fn test_construct_client_hello_long_server_name() {
- let client_random = [0x11; 32];
- let session_id = [0x22; 32];
- let public_key_bytes = [0x33; 32];
- let server_name = "a".repeat(255); // Max length server name
-
- let result = construct_client_hello(
- &client_random,
- &session_id,
- &public_key_bytes,
- &server_name,
- );
-
- assert!(result.is_ok());
- let client_hello = result.unwrap();
-
- // Should handle long server names
- assert_eq!(client_hello[0], 0x01);
- assert!(client_hello.len() > 300); // Should be larger with long SNI
- }
-
- #[test]
- fn test_extract_ed25519_cert_data_no_ed25519() {
- // Create a certificate without Ed25519 OID
- let cert_der = vec![0x30, 0x82, 0x01, 0x00]; // Just a SEQUENCE header
-
- let result = extract_ed25519_cert_data(&cert_der);
- assert!(result.is_err());
- }
-
- #[test]
- fn test_verify_reality_certificate_invalid_type() {
- // Create a message with wrong handshake type
- let mut cert_msg = Vec::new();
- cert_msg.push(0x01); // ClientHello instead of Certificate
- cert_msg.extend_from_slice(&[0x00, 0x00, 0x10]);
- cert_msg.extend_from_slice(&[0xaa; 16]);
-
- let auth_key = [0x42; 32];
- let result = verify_reality_certificate(&cert_msg, &auth_key);
-
- assert!(result.is_err());
- assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidData);
- }
-
- #[test]
- fn test_verify_reality_certificate_too_short() {
- // Create a message that's too short
- let cert_msg = vec![0x0b, 0x00, 0x00]; // Missing data
-
- let auth_key = [0x42; 32];
- let result = verify_reality_certificate(&cert_msg, &auth_key);
-
- assert!(result.is_err());
- }
-
- // Mock stream for testing
- struct MockStream {
- read_data: Vec<u8>,
- read_pos: usize,
- written_data: Vec<u8>,
- }
-
- impl MockStream {
- fn new() -> Self {
- MockStream {
- read_data: Vec::new(),
- read_pos: 0,
- written_data: Vec::new(),
- }
- }
-
- fn set_read_data(&mut self, data: Vec<u8>) {
- self.read_data = data;
- self.read_pos = 0;
- }
- }
-
- use std::pin::Pin;
- use std::task::{Context, Poll};
- use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
-
- impl AsyncRead for MockStream {
- fn poll_read(
- mut self: Pin<&mut Self>,
- _cx: &mut Context<'_>,
- buf: &mut ReadBuf<'_>,
- ) -> Poll<std::io::Result<()>> {
- let remaining = &self.read_data[self.read_pos..];
- let to_read = std::cmp::min(remaining.len(), buf.remaining());
-
- buf.put_slice(&remaining[..to_read]);
- self.read_pos += to_read;
-
- Poll::Ready(Ok(()))
- }
- }
-
- impl AsyncWrite for MockStream {
- fn poll_write(
- mut self: Pin<&mut Self>,
- _cx: &mut Context<'_>,
- buf: &[u8],
- ) -> Poll<std::io::Result<usize>> {
- self.written_data.extend_from_slice(buf);
- Poll::Ready(Ok(buf.len()))
- }
-
- fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
- Poll::Ready(Ok(()))
- }
-
- fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
- Poll::Ready(Ok(()))
- }
- }
-
- impl crate::async_stream::AsyncPing for MockStream {
- fn supports_ping(&self) -> bool {
- false
- }
-
- fn poll_write_ping(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<bool>> {
- Poll::Ready(Ok(false))
- }
- }
-
- impl AsyncStream for MockStream {}
-}
diff --git a/src/reality/reality_connection_wrapper.rs b/src/reality/reality_connection_wrapper.rs
deleted file mode 100644
index a26ff13..0000000
@@ -1,389 +0,0 @@
-/// REALITY Server Connection Wrapper
-///
-/// This module implements manual TLS handshake control using rustls::ServerConnection
-/// to inject an HMAC-signed certificate instead of the valid one.
-
-use crate::async_stream::AsyncStream;
-use std::io::{Cursor, Read, Write};
-use std::sync::Arc;
-use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
-
-pub struct RealityServerConnection;
-
-impl RealityServerConnection {
- /// Performs TLS handshake with HMAC certificate injection
- ///
- /// # Arguments
- /// * `stream` - The underlying TCP stream
- /// * `server_config` - rustls ServerConfig with VALID certificate (for internal use)
- /// * `client_hello_frame` - The ClientHello TLS frame
- /// * `hmac_cert_der` - Certificate with HMAC signature to send to client
- ///
- /// # Returns
- /// Box<dyn AsyncStream> wrapping the completed TLS connection
- pub async fn perform_handshake(
- mut stream: Box<dyn AsyncStream>,
- server_config: Arc<rustls::ServerConfig>,
- client_hello_frame: Vec<u8>,
- hmac_cert_der: Vec<u8>,
- ) -> std::io::Result<Box<dyn AsyncStream>> {
- log::info!("REALITY: Creating ServerConnection");
-
- // Create ServerConnection
- let mut server_conn = rustls::ServerConnection::new(server_config)
- .map_err(|e| std::io::Error::new(
- std::io::ErrorKind::InvalidData,
- format!("Failed to create ServerConnection: {}", e)
- ))?;
-
- // Feed the ClientHello to ServerConnection
- log::info!("REALITY: Feeding ClientHello to ServerConnection");
- let mut cursor = Cursor::new(&client_hello_frame);
- while cursor.position() < client_hello_frame.len() as u64 {
- let n = server_conn.read_tls(&mut cursor)
- .map_err(|e| std::io::Error::new(
- std::io::ErrorKind::InvalidData,
- format!("Failed to feed ClientHello: {}", e)
- ))?;
- if n == 0 {
- break;
- }
- }
-
- // Process the ClientHello
- log::debug!("REALITY: Processing ClientHello");
- server_conn.process_new_packets()
- .map_err(|e| std::io::Error::new(
- std::io::ErrorKind::InvalidData,
- format!("Failed to process ClientHello: {}", e)
- ))?;
-
- // Get the handshake response (ServerHello, Certificate, etc.)
- log::info!("REALITY: Generating handshake response");
- let mut response_buf = Vec::new();
- server_conn.write_tls(&mut response_buf)
- .map_err(|e| std::io::Error::new(
- std::io::ErrorKind::InvalidData,
- format!("Failed to write TLS response: {}", e)
- ))?;
-
- log::debug!("REALITY: Response buffer size: {}", response_buf.len());
-
- // Replace the certificate in the response with HMAC certificate
- log::debug!("REALITY: Replacing certificate with HMAC version");
- let modified_response = replace_certificate_in_handshake(&response_buf, &hmac_cert_der)?;
-
- // Send the modified response to the client
- log::debug!("REALITY: Sending modified response to client");
- stream.write_all(&modified_response).await?;
- stream.flush().await?;
-
- // Continue with the rest of the handshake
- log::info!("REALITY: Completing handshake");
- loop {
- // Read from client
- let mut client_buf = vec![0u8; 16384];
- let n = stream.read(&mut client_buf).await?;
- if n == 0 {
- break;
- }
- client_buf.truncate(n);
-
- // Feed to ServerConnection
- let mut cursor = Cursor::new(&client_buf);
- server_conn.read_tls(&mut cursor)
- .map_err(|e| std::io::Error::new(
- std::io::ErrorKind::InvalidData,
- format!("Failed to read client TLS data: {}", e)
- ))?;
-
- // Process packets
- server_conn.process_new_packets()
- .map_err(|e| std::io::Error::new(
- std::io::ErrorKind::InvalidData,
- format!("Failed to process TLS packets: {}", e)
- ))?;
-
- // Write response
- let mut response_buf = Vec::new();
- server_conn.write_tls(&mut response_buf)
- .map_err(|e| std::io::Error::new(
- std::io::ErrorKind::InvalidData,
- format!("Failed to write TLS response: {}", e)
- ))?;
-
- if !response_buf.is_empty() {
- stream.write_all(&response_buf).await?;
- stream.flush().await?;
- }
-
- // Check if handshake is complete
- if !server_conn.is_handshaking() {
- log::info!("REALITY: Handshake complete!");
- break;
- }
- }
-
- // Wrap the stream with our custom wrapper
- let tls_stream = RealityTlsStream {
- inner: stream,
- conn: server_conn,
- };
-
- Ok(Box::new(tls_stream))
- }
-}
-
-/// Replaces the certificate in a TLS handshake message
-///
-/// Searches for the Certificate handshake message and replaces the certificate DER bytes
-fn replace_certificate_in_handshake(
- handshake_buf: &[u8],
- new_cert_der: &[u8],
-) -> std::io::Result<Vec<u8>> {
- // TLS 1.3 handshake message format:
- // - TLS Record: [ContentType(1), Version(2), Length(2), Fragment]
- // - Handshake: [HandshakeType(1), Length(3), Body]
- // - Certificate: [CertificateRequestContext(1+), CertificateList]
-
- let mut result = Vec::new();
- let mut i = 0;
-
- while i < handshake_buf.len() {
- // Check for TLS record
- if i + 5 > handshake_buf.len() {
- // Not enough data for TLS record header
- result.extend_from_slice(&handshake_buf[i..]);
- break;
- }
-
- let content_type = handshake_buf[i];
- let record_len = u16::from_be_bytes([handshake_buf[i + 3], handshake_buf[i + 4]]) as usize;
-
- if content_type == 0x16 { // Handshake record
- // Check if this contains a Certificate message (type 0x0B)
- if i + 5 + 1 <= handshake_buf.len() {
- let handshake_type = handshake_buf[i + 5];
-
- if handshake_type == 0x0B { // Certificate message
- log::debug!("REALITY: Found Certificate message at offset {}", i);
- log::debug!("REALITY: Original record length: {}", record_len);
- log::debug!("REALITY: HMAC cert DER length: {}", new_cert_der.len());
-
- // Replace this entire handshake message with our HMAC certificate
- let replacement = build_certificate_handshake(new_cert_der)?;
- log::debug!("REALITY: Replacement record length: {}", replacement.len());
- log::debug!("REALITY: First 20 bytes of HMAC cert: {:?}", &new_cert_der[..20.min(new_cert_der.len())]);
-
- result.extend_from_slice(&replacement);
-
- // Skip the original certificate message
- i += 5 + record_len;
- continue;
- }
- }
- }
-
- // Copy this record as-is
- if i + 5 + record_len <= handshake_buf.len() {
- result.extend_from_slice(&handshake_buf[i..i + 5 + record_len]);
- i += 5 + record_len;
- } else {
- // Incomplete record, copy rest
- result.extend_from_slice(&handshake_buf[i..]);
- break;
- }
- }
-
- Ok(result)
-}
-
-/// Builds a TLS 1.3 Certificate handshake message
-fn build_certificate_handshake(cert_der: &[u8]) -> std::io::Result<Vec<u8>> {
- // TLS 1.3 Certificate message structure:
- // - CertificateRequestContext (1 byte length + context)
- // - CertificateList (3 bytes length + entries)
- // - CertificateEntry:
- // - cert_data (3 bytes length + DER)
- // - extensions (2 bytes length + extensions)
-
- let mut cert_msg = Vec::new();
-
- // CertificateRequestContext (empty)
- cert_msg.push(0);
-
- // CertificateList
- let cert_entry_len = 3 + cert_der.len() + 2; // cert_data + extensions
- let cert_list_len = 3 + cert_entry_len;
-
- // CertificateList length (3 bytes)
- cert_msg.push(((cert_list_len >> 16) & 0xFF) as u8);
- cert_msg.push(((cert_list_len >> 8) & 0xFF) as u8);
- cert_msg.push((cert_list_len & 0xFF) as u8);
-
- // CertificateEntry: cert_data length (3 bytes)
- cert_msg.push(((cert_der.len() >> 16) & 0xFF) as u8);
- cert_msg.push(((cert_der.len() >> 8) & 0xFF) as u8);
- cert_msg.push((cert_der.len() & 0xFF) as u8);
-
- // CertificateEntry: cert_data
- cert_msg.extend_from_slice(cert_der);
-
- // CertificateEntry: extensions (empty)
- cert_msg.push(0);
- cert_msg.push(0);
-
- // Wrap in Handshake message
- let mut handshake = Vec::new();
- handshake.push(0x0B); // Certificate handshake type
-
- // Handshake length (3 bytes)
- handshake.push(((cert_msg.len() >> 16) & 0xFF) as u8);
- handshake.push(((cert_msg.len() >> 8) & 0xFF) as u8);
- handshake.push((cert_msg.len() & 0xFF) as u8);
-
- handshake.extend_from_slice(&cert_msg);
-
- // Wrap in TLS Record
- let mut record = Vec::new();
- record.push(0x16); // ContentType: Handshake
- record.push(0x03); // Version: TLS 1.2 (for compatibility)
- record.push(0x03);
-
- // Record length (2 bytes)
- record.push(((handshake.len() >> 8) & 0xFF) as u8);
- record.push((handshake.len() & 0xFF) as u8);
-
- record.extend_from_slice(&handshake);
-
- Ok(record)
-}
-
-/// Custom TLS stream wrapper for REALITY
-struct RealityTlsStream {
- inner: Box<dyn AsyncStream>,
- conn: rustls::ServerConnection,
-}
-
-impl AsyncRead for RealityTlsStream {
- fn poll_read(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- buf: &mut tokio::io::ReadBuf<'_>,
- ) -> std::task::Poll<std::io::Result<()>> {
- use std::pin::Pin;
-
- // Try to read decrypted data from the connection
- if let Ok(n) = self.conn.reader().read(buf.initialize_unfilled()) {
- if n > 0 {
- buf.advance(n);
- return std::task::Poll::Ready(Ok(()));
- }
- }
-
- // Need more TLS data from the network
- let mut tls_buf = vec![0u8; 16384];
- let mut read_buf = tokio::io::ReadBuf::new(&mut tls_buf);
-
- match Pin::new(&mut self.inner).poll_read(cx, &mut read_buf) {
- std::task::Poll::Ready(Ok(())) => {
- let n = read_buf.filled().len();
- if n == 0 {
- return std::task::Poll::Ready(Ok(()));
- }
-
- // Feed TLS data to the connection
- let mut cursor = Cursor::new(&tls_buf[..n]);
- if let Err(e) = self.conn.read_tls(&mut cursor) {
- return std::task::Poll::Ready(Err(std::io::Error::new(
- std::io::ErrorKind::Other,
- format!("TLS read error: {}", e)
- )));
- }
-
- // Process TLS packets
- if let Err(e) = self.conn.process_new_packets() {
- return std::task::Poll::Ready(Err(std::io::Error::new(
- std::io::ErrorKind::Other,
- format!("TLS process error: {}", e)
- )));
- }
-
- // Try reading again
- if let Ok(n) = self.conn.reader().read(buf.initialize_unfilled()) {
- buf.advance(n);
- }
-
- std::task::Poll::Ready(Ok(()))
- }
- std::task::Poll::Ready(Err(e)) => std::task::Poll::Ready(Err(e)),
- std::task::Poll::Pending => std::task::Poll::Pending,
- }
- }
-}
-
-impl AsyncWrite for RealityTlsStream {
- fn poll_write(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- buf: &[u8],
- ) -> std::task::Poll<Result<usize, std::io::Error>> {
- use std::pin::Pin;
-
- // Write plaintext data to the connection
- let n = self.conn.writer().write(buf)
- .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("TLS write error: {}", e)))?;
-
- // Send encrypted data to the network
- let mut tls_buf = Vec::new();
- self.conn.write_tls(&mut tls_buf)
- .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("TLS write_tls error: {}", e)))?;
-
- if !tls_buf.is_empty() {
- // We need to write this to the network
- match Pin::new(&mut self.inner).poll_write(cx, &tls_buf) {
- std::task::Poll::Ready(Ok(_)) => {
- // Flush to ensure it's sent
- let _ = Pin::new(&mut self.inner).poll_flush(cx);
- }
- std::task::Poll::Ready(Err(e)) => return std::task::Poll::Ready(Err(e)),
- std::task::Poll::Pending => return std::task::Poll::Pending,
- }
- }
-
- std::task::Poll::Ready(Ok(n))
- }
-
- fn poll_flush(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Result<(), std::io::Error>> {
- use std::pin::Pin;
- Pin::new(&mut self.inner).poll_flush(cx)
- }
-
- fn poll_shutdown(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Result<(), std::io::Error>> {
- use std::pin::Pin;
- self.conn.send_close_notify();
- Pin::new(&mut self.inner).poll_shutdown(cx)
- }
-}
-
-impl crate::async_stream::AsyncPing for RealityTlsStream {
- fn supports_ping(&self) -> bool {
- self.inner.supports_ping()
- }
-
- fn poll_write_ping(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<std::io::Result<bool>> {
- use std::pin::Pin;
- Pin::new(&mut self.inner).poll_write_ping(cx)
- }
-}
-
-impl crate::async_stream::AsyncStream for RealityTlsStream {}
diff --git a/src/reality/reality_server_handler.rs b/src/reality/reality_server_handler.rs
index 7546448..a2ac9e3 100644
@@ -1,20 +1,17 @@
-use std::pin::Pin;
use std::sync::Arc;
-use std::task::{Context, Poll};
use std::time::{SystemTime, UNIX_EPOCH};
-use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::address::NetLocation;
use crate::async_stream::AsyncStream;
use crate::client_proxy_selector::ClientProxySelector;
+use crate::crypto_connection::{Connection, CustomTlsStream};
use crate::option_util::NoneOrOne;
-use crate::reality::reality_crypto::{
- decrypt_session_id, derive_auth_key, generate_reality_certificate, perform_ecdh,
-};
-use crate::reality::reality_server_handshake::perform_reality_server_handshake;
+use crate::reality::reality_crypto::{decrypt_session_id, derive_auth_key, perform_ecdh};
use crate::reality::reality_util::{
extract_client_public_key, extract_client_random, extract_session_id,
};
+use crate::reality::{RealityServerConnection, RealityServerConfig};
use crate::shadow_tls::ParsedClientHello;
use crate::tcp_client_connector::TcpClientConnector;
use crate::tcp_handler::TcpServerHandler;
@@ -32,32 +29,27 @@ pub struct RealityServerTarget {
pub override_proxy_provider: NoneOrOne<Arc<ClientProxySelector<TcpClientConnector>>>,
}
-/// Sets up a REALITY server stream
+/// Sets up a REALITY server stream using buffered Connection API
///
/// This function:
-/// 1. Extracts client's public key from ClientHello
-/// 2. Performs ECDH with server's private key
-/// 3. Derives AuthKey using HKDF
-/// 4. Decrypts SessionId
-/// 5. Validates metadata (version, timestamp, short_id)
-/// 6. Generates certificate with HMAC signature
-/// 7. Completes TLS handshake
-/// 8. Returns TLS stream
+/// 1. Validates the ClientHello (ECDH, SessionId decryption, metadata)
+/// 2. Creates a RealityServerConnection with buffered handshake
+/// 3. Feeds the ClientHello to the connection
+/// 4. Wraps in CustomTlsStream for async I/O
+/// 5. Returns the encrypted stream
pub async fn setup_reality_server_stream(
- server_stream: Box<dyn AsyncStream>,
+ mut server_stream: Box<dyn AsyncStream>,
parsed_client_hello: ParsedClientHello,
config: &RealityServerTarget,
) -> std::io::Result<Box<dyn AsyncStream>> {
log::debug!("REALITY DEBUG: setup_reality_server_stream called");
let client_hello_frame = &parsed_client_hello.client_hello_frame;
- log::debug!("REALITY DEBUG: ClientHello frame length: {}",
- client_hello_frame.len());
+ log::debug!("REALITY DEBUG: ClientHello frame length: {}", client_hello_frame.len());
// Step 1: Extract client's X25519 public key from KeyShare extension
log::debug!("REALITY DEBUG: Extracting client public key");
let client_public_key = extract_client_public_key(client_hello_frame)?;
- log::debug!("REALITY DEBUG: Client public key extracted: {:?}",
- &client_public_key[..16]);
+ log::debug!("REALITY DEBUG: Client public key extracted: {:?}", &client_public_key[..16]);
// Step 2: Extract ClientRandom
let client_random = extract_client_random(client_hello_frame)?;
@@ -67,54 +59,35 @@ pub async fn setup_reality_server_stream(
if session_id_vec.len() != 32 {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
- format!(
- "Invalid SessionId length: {} (expected 32)",
- session_id_vec.len()
- ),
+ format!("Invalid SessionId length: {} (expected 32)", session_id_vec.len()),
));
}
let encrypted_session_id: [u8; 32] = session_id_vec.as_slice().try_into().map_err(|_| {
- std::io::Error::new(
- std::io::ErrorKind::InvalidData,
- "SessionId conversion failed",
- )
+ std::io::Error::new(std::io::ErrorKind::InvalidData, "SessionId conversion failed")
})?;
// Step 4: Perform ECDH with server's private key
let auth_key = perform_ecdh(&config.private_key, &client_public_key)?;
- log::debug!("REALITY DEBUG: ECDH shared secret (first 16 bytes): {:?}",
- &auth_key[..16]);
+ log::debug!("REALITY DEBUG: ECDH shared secret (first 16 bytes): {:?}", &auth_key[..16]);
// Step 5: Derive AuthKey with HKDF
let salt: [u8; 20] = client_random[0..20].try_into().map_err(|_| {
std::io::Error::new(std::io::ErrorKind::InvalidData, "Salt extraction failed")
})?;
- log::debug!("REALITY DEBUG: HKDF salt (first 20 bytes of random): {:?}",
- &salt);
+ log::debug!("REALITY DEBUG: HKDF salt (first 20 bytes of random): {:?}", &salt);
let auth_key = derive_auth_key(&auth_key, &salt, b"REALITY")?;
- log::debug!("REALITY DEBUG: Derived auth_key (first 16 bytes): {:?}",
- &auth_key[..16]);
+ log::debug!("REALITY DEBUG: Derived auth_key (first 16 bytes): {:?}", &auth_key[..16]);
// Step 6: Decrypt SessionId
let nonce: [u8; 12] = client_random[20..32].try_into().map_err(|_| {
std::io::Error::new(std::io::ErrorKind::InvalidData, "Nonce extraction failed")
})?;
log::debug!("REALITY DEBUG: Nonce (bytes 20-31 of random): {:?}", &nonce);
- log::debug!("REALITY DEBUG: Encrypted SessionId: {:?}",
- &encrypted_session_id);
+ log::debug!("REALITY DEBUG: Encrypted SessionId: {:?}", &encrypted_session_id);
- // Find the ClientHello handshake message in the frame (skip TLS header)
+ // CRITICAL: The AAD used during encryption contained ZEROS at the SessionId location
const TLS_HEADER_LEN: usize = 5;
let client_hello_handshake = &client_hello_frame[TLS_HEADER_LEN..];
- log::debug!("REALITY DEBUG: AAD length (ClientHello handshake): {}",
- client_hello_handshake.len());
- log::debug!("REALITY DEBUG: AAD first 32 bytes: {:?}",
- &client_hello_handshake[..32.min(client_hello_handshake.len())]
- );
-
- // CRITICAL: The AAD used during encryption contained ZEROS at the SessionId location (bytes 39-70),
- // not the encrypted SessionId. We need to reconstruct this for decryption.
- // SessionId is at fixed offset 39 in the ClientHello handshake message (after type + length + version + random + sessionid_length)
let mut aad_for_decryption = client_hello_handshake.to_vec();
if aad_for_decryption.len() >= 39 + 32 {
// Replace encrypted SessionId with zeros
@@ -190,21 +163,70 @@ pub async fn setup_reality_server_stream(
));
}
- // Step 8-9: Use raw key bytes (no need to convert to dalek types)
- // The private_key and client_public_key are already [u8; 32]
+ // Step 8: Create buffered REALITY connection
+ log::debug!("REALITY DEBUG: Creating buffered RealityServerConnection");
+ let reality_config = RealityServerConfig {
+ private_key: config.private_key,
+ short_ids: config.short_ids.clone(),
+ dest: config.dest.clone(),
+ max_time_diff: config.max_time_diff,
+ min_client_version: config.min_client_version,
+ max_client_version: config.max_client_version,
+ };
+
+ let mut reality_conn = RealityServerConnection::new(reality_config)?;
+
+ // Step 9: Feed the ClientHello to the connection
+ log::debug!("REALITY DEBUG: Feeding ClientHello to connection via read_tls");
+ {
+ use std::io::Read;
+ let mut cursor = std::io::Cursor::new(client_hello_frame);
+ reality_conn.read_tls(&mut cursor)?;
+ }
- // Step 10: Perform manual TLS 1.3 handshake with HMAC-signed certificate
- log::debug!("REALITY DEBUG: Starting manual TLS 1.3 handshake");
- let tls_stream = perform_reality_server_handshake(
- server_stream,
- client_hello_frame,
- &client_public_key,
- &config.private_key,
- &auth_key,
- &config.dest,
- )
- .await?;
+ // Step 10: Process the ClientHello to advance handshake
+ log::debug!("REALITY DEBUG: Processing ClientHello via process_new_packets");
+ reality_conn.process_new_packets()?;
+
+ // Step 11: Write the server's handshake response to the stream
+ log::debug!("REALITY DEBUG: Writing server handshake response");
+ {
+ use std::io::Write;
+ let mut write_buf = Vec::new();
+ while reality_conn.wants_write() {
+ reality_conn.write_tls(&mut write_buf)?;
+ }
+ if !write_buf.is_empty() {
+ server_stream.write_all(&write_buf).await?;
+ server_stream.flush().await?;
+ }
+ }
+
+ // Step 12: Read client's Finished message
+ log::debug!("REALITY DEBUG: Reading client Finished");
+ {
+ use std::io::Read;
+ let mut buf = vec![0u8; 4096];
+ let n = server_stream.read(&mut buf).await?;
+ if n == 0 {
+ return Err(std::io::Error::new(
+ std::io::ErrorKind::UnexpectedEof,
+ "EOF while waiting for client Finished",
+ ));
+ }
+ let mut cursor = std::io::Cursor::new(&buf[..n]);
+ reality_conn.read_tls(&mut cursor)?;
+ }
+
+ // Step 13: Process client Finished to complete handshake
+ log::debug!("REALITY DEBUG: Processing client Finished");
+ reality_conn.process_new_packets()?;
+
+ // Step 14: Wrap in Connection enum and CustomTlsStream
+ log::debug!("REALITY DEBUG: Wrapping in CustomTlsStream");
+ let connection = Connection::new_reality_server(reality_conn);
+ let tls_stream = CustomTlsStream::new(server_stream, connection);
log::debug!("REALITY DEBUG: TLS 1.3 handshake completed successfully");
- Ok(tls_stream)
+ Ok(Box::new(tls_stream))
}
diff --git a/src/reality/reality_server_handshake.rs b/src/reality/reality_server_handshake.rs
deleted file mode 100644
index 777c7d3..0000000
@@ -1,575 +0,0 @@
-// REALITY TLS 1.3 Handshake Orchestration
-//
-// Performs complete TLS 1.3 handshake with HMAC-signed certificate
-
-use crate::address::NetLocation;
-use crate::async_stream::AsyncStream;
-use super::reality_certificate::generate_hmac_certificate;
-use super::reality_destination::probe_destination_handshake;
-use super::reality_tls13_keys::{derive_handshake_keys, derive_application_secrets, derive_traffic_keys, compute_finished_verify_data};
-use super::reality_tls13_messages::*;
-use super::reality_tls13_crypto::encrypt_handshake_message;
-use super::reality_tls13_stream::Tls13Stream;
-use aws_lc_rs::digest;
-use std::io::{Error, ErrorKind, Result};
-use tokio::io::{AsyncReadExt, AsyncWriteExt};
-use aws_lc_rs::agreement;
-use aws_lc_rs::rand::{SecureRandom, SystemRandom};
-use crate::reality::reality_crypto::perform_ecdh;
-
-/// Perform REALITY TLS 1.3 server handshake
-///
-/// # Arguments
-/// * `stream` - TCP stream with authenticated client (SessionId validated)
-/// * `client_hello_frame` - Complete ClientHello TLS record
-/// * `client_public_key` - Client's X25519 public key (from ClientHello)
-/// * `server_private_key` - Server's X25519 private key
-/// * `auth_key` - Derived authentication key (from REALITY ECDH)
-/// * `destination` - Real destination to proxy certificate parameters from
-///
-/// # Returns
-/// Encrypted Tls13Stream ready for proxying application data
-pub async fn perform_reality_server_handshake(
- mut stream: Box<dyn AsyncStream>,
- client_hello_frame: &[u8],
- client_public_key: &[u8; 32],
- server_private_key: &[u8; 32],
- auth_key: &[u8; 32],
- destination: &NetLocation,
-) -> Result<Box<dyn AsyncStream>> {
- log::info!("REALITY HANDSHAKE: Starting TLS 1.3 handshake");
-
- // 1. Perform ECDH with client's public key
- let shared_secret = perform_ecdh(server_private_key, client_public_key)
- .map_err(|e| Error::new(ErrorKind::InvalidData, format!("ECDH failed: {:?}", e)))?;
- log::debug!(" ECDH shared_secret: {:?}", &shared_secret[..8]);
-
- // 2. Generate our own X25519 keypair for key_share
- let rng = SystemRandom::new();
- let mut our_private_bytes = [0u8; 32];
- rng.fill(&mut our_private_bytes)
- .map_err(|_| Error::new(ErrorKind::Other, "Failed to generate random bytes"))?;
-
- let our_private_key = agreement::PrivateKey::from_private_key(&agreement::X25519, &our_private_bytes)
- .map_err(|_| Error::new(ErrorKind::Other, "Failed to create X25519 key"))?;
- let our_public_key_bytes = our_private_key.compute_public_key()
- .map_err(|_| Error::new(ErrorKind::Other, "Failed to compute public key"))?;
-
- // 3. Connect to destination and get ServerHello parameters
- log::debug!(" Probing destination: {}", destination);
- let dest_handshake = probe_destination_handshake(destination, client_hello_frame).await?;
-
- log::debug!(" Got destination parameters:");
- log::debug!(" cipher_suite: 0x{:04x}", dest_handshake.cipher_suite);
- log::debug!(" key_share_group: 0x{:04x}", dest_handshake.key_share_group);
-
- // 4. Generate HMAC-signed Ed25519 certificate with destination hostname
- let dest_hostname = match destination.address() {
- crate::address::Address::Hostname(h) => h.as_str(),
- crate::address::Address::Ipv4(ip) => return Err(Error::new(
- ErrorKind::InvalidInput,
- format!("REALITY requires a hostname, not an IP address: {}", ip),
- )),
- crate::address::Address::Ipv6(ip) => return Err(Error::new(
- ErrorKind::InvalidInput,
- format!("REALITY requires a hostname, not an IP address: {}", ip),
- )),
- };
-
- log::debug!(" Generating HMAC certificate for {}", dest_hostname);
- let (cert_der, signing_key) = generate_hmac_certificate(auth_key, dest_hostname)?;
-
- // 5. Extract session_id from ClientHello for ServerHello
- // ClientHello format: type(1) + length(3) + version(2) + random(32) + session_id_length(1) + ...
- if client_hello_frame.len() < 44 {
- return Err(Error::new(ErrorKind::InvalidData, "ClientHello too short"));
- }
-
- let ch_handshake_start = 5; // Skip TLS record header
- let session_id_length = client_hello_frame[ch_handshake_start + 38] as usize;
-
- if client_hello_frame.len() < ch_handshake_start + 39 + session_id_length {
- return Err(Error::new(ErrorKind::InvalidData, "Invalid session_id"));
- }
-
- let session_id = &client_hello_frame[ch_handshake_start + 39..ch_handshake_start + 39 + session_id_length];
-
- // 6. Construct ServerHello
- log::debug!(" Constructing ServerHello");
- let server_hello = construct_server_hello(
- &dest_handshake.server_random,
- session_id,
- dest_handshake.cipher_suite,
- our_public_key_bytes.as_ref(),
- )?;
-
- // 7. Compute handshake transcript hashes
- let client_hello_handshake = &client_hello_frame[5..]; // Skip TLS record header
-
- let mut ch_transcript = digest::Context::new(&digest::SHA256);
- ch_transcript.update(client_hello_handshake);
- let client_hello_hash = ch_transcript.finish();
-
- let mut ch_sh_transcript = digest::Context::new(&digest::SHA256);
- ch_sh_transcript.update(client_hello_handshake);
- ch_sh_transcript.update(&server_hello);
-
- // Clone before finalizing
- let mut handshake_transcript = ch_sh_transcript.clone();
- let server_hello_hash = ch_sh_transcript.finish();
-
- // 8. Derive TLS 1.3 keys
- log::debug!(" Deriving TLS 1.3 keys");
-
- // Perform ECDH for TLS 1.3 key derivation
- let peer_public_key = agreement::UnparsedPublicKey::new(&agreement::X25519, client_public_key);
- let mut keys_for_ecdh_bytes = [0u8; 32];
- agreement::agree(
- &our_private_key,
- &peer_public_key,
- Error::new(ErrorKind::Other, "ECDH failed"),
- |key_material| {
- keys_for_ecdh_bytes.copy_from_slice(key_material);
- Ok(())
- },
- )?;
-
- // Construct EncryptedExtensions
- let encrypted_extensions = construct_encrypted_extensions()?;
- handshake_transcript.update(&encrypted_extensions);
-
- // Construct Certificate
- let certificate = construct_certificate(&cert_der)?;
- handshake_transcript.update(&certificate);
-
- // Construct CertificateVerify
- let cert_verify_hash = handshake_transcript.clone().finish();
- let certificate_verify = construct_certificate_verify(&signing_key, cert_verify_hash.as_ref())?;
- handshake_transcript.update(&certificate_verify);
-
- // Compute handshake hash before Finished
- let handshake_hash_before_finished = handshake_transcript.clone().finish();
-
- // Phase 1: Derive handshake keys and master secret (BEFORE Finished)
- // Per RFC 8446, application secrets must be derived AFTER server Finished
- log::info!(" Phase 1: Deriving handshake keys");
- let hs_keys = derive_handshake_keys(
- &keys_for_ecdh_bytes,
- client_hello_hash.as_ref(),
- server_hello_hash.as_ref(),
- )?;
-
- // Derive handshake traffic keys for encrypting handshake messages
- let (server_hs_key, server_hs_iv) = derive_traffic_keys(
- &hs_keys.server_handshake_traffic_secret,
- dest_handshake.cipher_suite,
- )?;
-
- // 9. Send ServerHello (plaintext)
- log::debug!(" Sending ServerHello");
- let mut sh_record = write_record_header(0x16, server_hello.len() as u16);
- sh_record.extend_from_slice(&server_hello);
- stream.write_all(&sh_record).await?;
-
- // 10. Send ChangeCipherSpec (for compatibility)
- log::debug!(" Sending ChangeCipherSpec");
- let ccs = write_record_header(0x14, 1);
- stream.write_all(&ccs).await?;
- stream.write_all(&[0x01]).await?;
- stream.flush().await?; // Ensure CCS is sent before encrypted messages
-
- let mut write_seq = 0u64;
-
- // 11. Send EncryptedExtensions (encrypted)
- log::debug!(" Sending EncryptedExtensions");
- let (ee_header, ee_ciphertext) = encrypt_handshake_message(
- &server_hs_key,
- &server_hs_iv,
- write_seq,
- &encrypted_extensions,
- )?;
- write_seq += 1;
-
- stream.write_all(&ee_header).await?;
- stream.write_all(&ee_ciphertext).await?;
-
- // 12. Send Certificate (encrypted)
- log::debug!(" Sending Certificate");
- let (cert_header, cert_ciphertext) = encrypt_handshake_message(
- &server_hs_key,
- &server_hs_iv,
- write_seq,
- &certificate,
- )?;
- write_seq += 1;
-
- stream.write_all(&cert_header).await?;
- stream.write_all(&cert_ciphertext).await?;
-
- // 13. Send CertificateVerify (encrypted)
- log::debug!(" Sending CertificateVerify");
- let (cv_header, cv_ciphertext) = encrypt_handshake_message(
- &server_hs_key,
- &server_hs_iv,
- write_seq,
- &certificate_verify,
- )?;
- write_seq += 1;
-
- stream.write_all(&cv_header).await?;
- stream.write_all(&cv_ciphertext).await?;
-
- // 14. Compute and send Finished
- log::debug!(" Computing server Finished");
- let server_finished_verify = compute_finished_verify_data(
- &hs_keys.server_handshake_traffic_secret,
- handshake_hash_before_finished.as_ref(),
- )?;
-
- let finished_msg = construct_finished(&server_finished_verify)?;
- handshake_transcript.update(&finished_msg);
-
- log::debug!(" Sending server Finished");
- let (fin_header, fin_ciphertext) = encrypt_handshake_message(
- &server_hs_key,
- &server_hs_iv,
- write_seq,
- &finished_msg,
- )?;
-
- stream.write_all(&fin_header).await?;
- stream.write_all(&fin_ciphertext).await?;
- stream.flush().await?; // Ensure server Finished is sent before waiting for client
-
- // Phase 2: Derive application traffic secrets (AFTER Finished)
- // Per RFC 8446 Section 7.1, application secrets MUST use transcript including Finished
- log::debug!(" Phase 2: Deriving application traffic secrets");
- let handshake_hash_with_finished = handshake_transcript.clone().finish();
- let (client_app_secret, server_app_secret) = derive_application_secrets(
- &hs_keys.master_secret,
- handshake_hash_with_finished.as_ref(),
- )?;
-
- log::debug!(" Application secrets derived with correct transcript!");
-
- // 15. Read and verify client Finished
- log::debug!(" Reading client Finished");
- let (client_hs_key, client_hs_iv) = derive_traffic_keys(
- &hs_keys.client_handshake_traffic_secret,
- dest_handshake.cipher_suite,
- )?;
-
- // Read TLS record header
- let mut header = [0u8; 5];
- stream.read_exact(&mut header).await?;
-
- let mut record_type = header[0];
- let mut record_length = u16::from_be_bytes([header[3], header[4]]);
-
- // Skip ChangeCipherSpec if present (TLS 1.3 compatibility message)
- if record_type == 0x14 {
- log::warn!(" Skipping ChangeCipherSpec (compatibility message)");
- // Read and discard the ChangeCipherSpec payload (always 1 byte: 0x01)
- let mut ccs_payload = vec![0u8; record_length as usize];
- stream.read_exact(&mut ccs_payload).await?;
-
- // Read the next record header (should be the encrypted Finished)
- stream.read_exact(&mut header).await?;
- record_type = header[0];
- record_length = u16::from_be_bytes([header[3], header[4]]);
- }
-
- if record_type != 0x17 {
- return Err(Error::new(
- ErrorKind::InvalidData,
- format!("Expected ApplicationData after ChangeCipherSpec, got 0x{:02x}", record_type),
- ));
- }
-
- // Read encrypted Finished
- let mut ciphertext = vec![0u8; record_length as usize];
- stream.read_exact(&mut ciphertext).await?;
-
- // Decrypt
- use super::reality_tls13_crypto::decrypt_handshake_message;
- let client_finished_msg = decrypt_handshake_message(
- &client_hs_key,
- &client_hs_iv,
- 0, // Client's first encrypted message
- &ciphertext,
- record_length,
- )?;
-
- // Verify it's a Finished message
- if client_finished_msg.len() < 4 || client_finished_msg[0] != 20 {
- return Err(Error::new(
- ErrorKind::InvalidData,
- "Expected Finished message",
- ));
- }
-
- // Extract verify_data
- let verify_data = &client_finished_msg[4..];
-
- // Compute expected verify_data
- let handshake_hash_for_client_finished = handshake_transcript.clone().finish();
- let expected_verify_data = compute_finished_verify_data(
- &hs_keys.client_handshake_traffic_secret,
- handshake_hash_for_client_finished.as_ref(),
- )?;
-
- if verify_data != &expected_verify_data[..] {
- return Err(Error::new(
- ErrorKind::InvalidData,
- "Client Finished verify_data mismatch",
- ));
- }
-
- log::info!(" Client Finished verified!");
-
- // 16. Handshake complete! Create Tls13Stream for application data
- log::debug!(" Creating TLS 1.3 encrypted stream with correctly-derived application secrets");
-
- let tls_stream = Tls13Stream::new(
- stream,
- &client_app_secret,
- &server_app_secret,
- dest_handshake.cipher_suite,
- false, // is_client = false (this is server)
- )?;
-
- log::info!("REALITY HANDSHAKE: Complete!");
-
- Ok(Box::new(tls_stream))
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use tokio::io::{AsyncReadExt, AsyncWriteExt};
-
- // Mock AsyncStream for testing
- #[derive(Debug)]
- struct MockStream {
- read_data: Vec<u8>,
- write_data: Vec<u8>,
- read_pos: usize,
- }
-
- impl MockStream {
- fn new(read_data: Vec<u8>) -> Self {
- MockStream {
- read_data,
- write_data: Vec::new(),
- read_pos: 0,
- }
- }
- }
-
- use std::pin::Pin;
- use std::task::{Context, Poll};
- use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
-
- impl AsyncRead for MockStream {
- fn poll_read(
- mut self: Pin<&mut Self>,
- _cx: &mut Context<'_>,
- buf: &mut ReadBuf<'_>,
- ) -> Poll<std::io::Result<()>> {
- let available = self.read_data.len() - self.read_pos;
- if available == 0 {
- return Poll::Ready(Ok(()));
- }
- let to_read = std::cmp::min(available, buf.remaining());
- buf.put_slice(&self.read_data[self.read_pos..self.read_pos + to_read]);
- self.read_pos += to_read;
- Poll::Ready(Ok(()))
- }
- }
-
- impl AsyncWrite for MockStream {
- fn poll_write(
- mut self: Pin<&mut Self>,
- _cx: &mut Context<'_>,
- buf: &[u8],
- ) -> Poll<std::io::Result<usize>> {
- self.write_data.extend_from_slice(buf);
- Poll::Ready(Ok(buf.len()))
- }
-
- fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
- Poll::Ready(Ok(()))
- }
-
- fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
- Poll::Ready(Ok(()))
- }
- }
-
- use crate::async_stream::AsyncPing;
-
- impl AsyncPing for MockStream {
- fn supports_ping(&self) -> bool {
- false
- }
-
- fn poll_write_ping(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<bool>> {
- Poll::Ready(Ok(false))
- }
-
- fn poll_read_pong(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<bool>> {
- Poll::Ready(Ok(false))
- }
- }
-
- impl AsyncStream for MockStream {}
-
- #[test]
- fn test_session_id_extraction() {
- // Create a mock ClientHello with a known session_id
- let mut client_hello = vec![0x16, 0x03, 0x03, 0x00, 0x50]; // TLS record header
- client_hello.extend_from_slice(&[0x01, 0x00, 0x00, 0x4c]); // Handshake header
- client_hello.extend_from_slice(&[0x03, 0x03]); // Version
- client_hello.extend_from_slice(&[0xaa; 32]); // Random
- client_hello.push(0x20); // Session ID length (32)
- let session_id = vec![0xbb; 32];
- client_hello.extend_from_slice(&session_id);
-
- // Extract session_id should work correctly
- let extracted_start = 5 + 38; // Skip TLS header + handshake header + version + random + length byte
- let extracted_end = extracted_start + 32;
- assert_eq!(&client_hello[extracted_start..extracted_end], &session_id[..]);
- }
-
- #[test]
- fn test_client_hello_parsing_boundaries() {
- // Test with minimum size ClientHello
- let mut client_hello = vec![0x16, 0x03, 0x03, 0x00, 0x27]; // TLS record header
- client_hello.extend_from_slice(&[0x01, 0x00, 0x00, 0x23]); // Handshake header
- client_hello.extend_from_slice(&[0x03, 0x03]); // Version
- client_hello.extend_from_slice(&[0xaa; 32]); // Random
- client_hello.push(0x00); // No session ID
-
- // Should be able to handle empty session_id
- let session_id_len = client_hello[5 + 38] as usize;
- assert_eq!(session_id_len, 0);
- }
-
- #[test]
- fn test_server_random_generation() {
- // Server random should be 32 bytes
- use aws_lc_rs::rand::SecureRandom;
- let rng = aws_lc_rs::rand::SystemRandom::new();
- let mut server_random = [0u8; 32];
- rng.fill(&mut server_random).unwrap();
- assert_eq!(server_random.len(), 32);
- // Should not be all zeros
- assert!(!server_random.iter().all(|&b| b == 0));
- }
-
- #[test]
- fn test_ecdh_key_generation() {
- use aws_lc_rs::agreement;
-
- // Generate a key pair
- let rng = aws_lc_rs::rand::SystemRandom::new();
- let mut private_bytes = [0u8; 32];
- rng.fill(&mut private_bytes).unwrap();
-
- let private_key = agreement::PrivateKey::from_private_key(
- &agreement::X25519,
- &private_bytes
- ).unwrap();
-
- let public_key_bytes = private_key.compute_public_key().unwrap();
-
- // Public key should be 32 bytes for X25519
- assert_eq!(public_key_bytes.as_ref().len(), 32);
- }
-
- #[test]
- fn test_cipher_suite_selection() {
- // Common cipher suites from destination probe
- let valid_cipher_suites = vec![
- 0x1301, // TLS_AES_128_GCM_SHA256
- 0x1302, // TLS_AES_256_GCM_SHA384
- 0x1303, // TLS_CHACHA20_POLY1305_SHA256
- ];
-
- for &suite in &valid_cipher_suites {
- // All these should be valid TLS 1.3 cipher suites
- assert!(suite >= 0x1301 && suite <= 0x1303);
- }
- }
-
- #[test]
- fn test_certificate_context_validation() {
- // Certificate context should be empty for server certificates
- let cert_context: Vec<u8> = vec![];
- assert_eq!(cert_context.len(), 0);
- }
-
- #[test]
- fn test_handshake_message_type_values() {
- // Verify handshake message type constants
- const CLIENT_HELLO: u8 = 1;
- const SERVER_HELLO: u8 = 2;
- const ENCRYPTED_EXTENSIONS: u8 = 8;
- const CERTIFICATE: u8 = 11;
- const CERTIFICATE_VERIFY: u8 = 15;
- const FINISHED: u8 = 20;
-
- assert_eq!(CLIENT_HELLO, 1);
- assert_eq!(SERVER_HELLO, 2);
- assert_eq!(ENCRYPTED_EXTENSIONS, 8);
- assert_eq!(CERTIFICATE, 11);
- assert_eq!(CERTIFICATE_VERIFY, 15);
- assert_eq!(FINISHED, 20);
- }
-
- #[test]
- fn test_tls_version_constants() {
- // TLS versions
- const TLS_1_2: u16 = 0x0303;
- const TLS_1_3: u16 = 0x0304;
-
- // In TLS 1.3, we still use 0x0303 in record headers for compatibility
- let record_version = TLS_1_2;
- assert_eq!(record_version, 0x0303);
-
- // But use 0x0304 in the handshake messages
- let handshake_version = TLS_1_3;
- assert_eq!(handshake_version, 0x0304);
- }
-
- #[test]
- fn test_extension_types() {
- // Common TLS extensions
- const KEY_SHARE: u16 = 51;
- const SUPPORTED_VERSIONS: u16 = 43;
- const SIGNATURE_ALGORITHMS: u16 = 13;
-
- assert_eq!(KEY_SHARE, 51);
- assert_eq!(SUPPORTED_VERSIONS, 43);
- assert_eq!(SIGNATURE_ALGORITHMS, 13);
- }
-
- #[test]
- fn test_record_header_construction() {
- // Test TLS record header creation
- fn write_record_header(content_type: u8, length: u16) -> Vec<u8> {
- vec![
- content_type,
- 0x03, 0x03, // TLS 1.2 for compatibility
- (length >> 8) as u8,
- (length & 0xff) as u8,
- ]
- }
-
- let header = write_record_header(0x16, 0x1234);
- assert_eq!(header, vec![0x16, 0x03, 0x03, 0x12, 0x34]);
-
- let header = write_record_header(0x17, 0x0001);
- assert_eq!(header, vec![0x17, 0x03, 0x03, 0x00, 0x01]);
- }
-}
diff --git a/src/reality/reality_tls13_messages.rs b/src/reality/reality_tls13_messages.rs
index 4154de0..bb60386 100644
@@ -267,6 +267,114 @@ pub fn construct_finished(verify_data: &[u8]) -> Result<Vec<u8>> {
/// Write TLS record header
///
+/// Construct TLS 1.3 ClientHello message
+///
+/// Returns handshake message bytes (without record header)
+pub fn construct_client_hello(
+ client_random: &[u8; 32],
+ session_id: &[u8; 32],
+ client_public_key: &[u8],
+ server_name: &str,
+) -> Result<Vec<u8>> {
+ let mut hello = Vec::with_capacity(512);
+
+ // Handshake message type: ClientHello (0x01)
+ hello.push(0x01);
+
+ // Placeholder for handshake message length (3 bytes)
+ let length_offset = hello.len();
+ hello.extend_from_slice(&[0u8; 3]);
+
+ // TLS version: 3.3 (TLS 1.2 for compatibility)
+ hello.extend_from_slice(&[0x03, 0x03]);
+
+ // Client random (32 bytes)
+ hello.extend_from_slice(client_random);
+
+ // Session ID length (1 byte) + Session ID (32 bytes)
+ hello.push(32);
+ hello.extend_from_slice(session_id);
+
+ // Cipher suites
+ // Support only TLS_AES_128_GCM_SHA256 (0x1301)
+ hello.extend_from_slice(&[0x00, 0x02]); // Cipher suites length: 2 bytes
+ hello.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256
+
+ // Compression methods (1 method: null)
+ hello.extend_from_slice(&[0x01, 0x00]);
+
+ // Extensions
+ let extensions_offset = hello.len();
+ hello.extend_from_slice(&[0u8; 2]); // Placeholder for extensions length
+
+ let mut extensions = Vec::new();
+
+ // 1. server_name extension (type 0)
+ {
+ let server_name_bytes = server_name.as_bytes();
+ let server_name_len = server_name_bytes.len();
+
+ extensions.extend_from_slice(&[0x00, 0x00]); // Extension type: server_name
+ let ext_len = 5 + server_name_len;
+ extensions.extend_from_slice(&(ext_len as u16).to_be_bytes()); // Extension length
+ extensions.extend_from_slice(&((server_name_len + 3) as u16).to_be_bytes()); // Server name list length
+ extensions.push(0x00); // Name type: host_name
+ extensions.extend_from_slice(&(server_name_len as u16).to_be_bytes()); // Name length
+ extensions.extend_from_slice(server_name_bytes); // Server name
+ }
+
+ // 2. supported_versions extension (type 43)
+ {
+ extensions.extend_from_slice(&[0x00, 0x2b]); // Extension type: supported_versions
+ extensions.extend_from_slice(&[0x00, 0x03]); // Extension length: 3
+ extensions.push(0x02); // Supported versions length: 2
+ extensions.extend_from_slice(&[0x03, 0x04]); // TLS 1.3
+ }
+
+ // 3. supported_groups extension (type 10)
+ {
+ extensions.extend_from_slice(&[0x00, 0x0a]); // Extension type: supported_groups
+ extensions.extend_from_slice(&[0x00, 0x04]); // Extension length: 4
+ extensions.extend_from_slice(&[0x00, 0x02]); // Supported groups length: 2
+ extensions.extend_from_slice(&[0x00, 0x1d]); // x25519
+ }
+
+ // 4. key_share extension (type 51)
+ {
+ extensions.extend_from_slice(&[0x00, 0x33]); // Extension type: key_share
+ let key_share_len = 2 + 4 + client_public_key.len();
+ extensions.extend_from_slice(&(key_share_len as u16).to_be_bytes()); // Extension length
+ let key_share_list_len = 4 + client_public_key.len();
+ extensions.extend_from_slice(&(key_share_list_len as u16).to_be_bytes()); // Key share list length
+ extensions.extend_from_slice(&[0x00, 0x1d]); // Group: x25519
+ extensions.extend_from_slice(&(client_public_key.len() as u16).to_be_bytes()); // Key length
+ extensions.extend_from_slice(client_public_key); // Public key
+ }
+
+ // 5. signature_algorithms extension (type 13)
+ {
+ extensions.extend_from_slice(&[0x00, 0x0d]); // Extension type: signature_algorithms
+ extensions.extend_from_slice(&[0x00, 0x04]); // Extension length: 4
+ extensions.extend_from_slice(&[0x00, 0x02]); // Signature algorithms length: 2
+ extensions.extend_from_slice(&[0x08, 0x07]); // ed25519
+ }
+
+ // Write extensions length
+ let extensions_length = extensions.len();
+ hello[extensions_offset..extensions_offset + 2]
+ .copy_from_slice(&(extensions_length as u16).to_be_bytes());
+
+ // Append extensions
+ hello.extend_from_slice(&extensions);
+
+ // Write handshake message length
+ let message_length = hello.len() - 4; // Exclude type (1) and length (3)
+ hello[length_offset..length_offset + 3]
+ .copy_from_slice(&(message_length as u32).to_be_bytes()[1..]);
+
+ Ok(hello)
+}
+
/// # Arguments
/// * `record_type` - TLS record type (0x16 for Handshake, 0x17 for ApplicationData)
/// * `length` - Length of record payload
diff --git a/src/reality/reality_tls13_stream.rs b/src/reality/reality_tls13_stream.rs
deleted file mode 100644
index 75324a6..0000000
@@ -1,530 +0,0 @@
-// TLS 1.3 Post-Handshake Stream Wrapper
-//
-// Wraps an AsyncStream to provide TLS 1.3 encryption/decryption for application data
-
-use crate::async_stream::{AsyncStream, AsyncPing};
-use super::reality_tls13_crypto::{encrypt_tls13_record, decrypt_tls13_record};
-use super::reality_tls13_keys::derive_traffic_keys;
-use std::io::{Error, ErrorKind, Result};
-use std::pin::Pin;
-use std::task::{Context, Poll};
-use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
-
-/// TLS 1.3 encrypted stream wrapper
-///
-/// Provides transparent encryption/decryption of application data
-/// after TLS 1.3 handshake completes
-pub struct Tls13Stream {
- inner: Box<dyn AsyncStream>,
-
- // Encryption state
- write_key: Vec<u8>,
- write_iv: Vec<u8>,
- write_seq: u64,
-
- // Decryption state
- read_key: Vec<u8>,
- read_iv: Vec<u8>,
- read_seq: u64,
-
- // Read buffer for partial records
- read_buffer: Vec<u8>,
- read_buffer_pos: usize,
-
- // Pending encrypted record data
- pending_read: Vec<u8>,
- pending_read_pos: usize,
-
- // Connection state
- received_close_notify: bool,
- shutdown_sent: bool,
-
- // Write buffer for partial record writes
- write_buffer: Vec<u8>,
- write_buffer_pos: usize,
-}
-
-impl Tls13Stream {
- /// Create new TLS 1.3 stream from traffic secrets
- ///
- /// # Arguments
- /// * `inner` - Underlying async stream
- /// * `client_application_traffic_secret` - Client app traffic secret
- /// * `server_application_traffic_secret` - Server app traffic secret
- /// * `cipher_suite` - TLS cipher suite (e.g., 0x1301 for AES-128-GCM-SHA256)
- /// * `is_client` - True if this is a client, false if this is a server
- pub fn new(
- inner: Box<dyn AsyncStream>,
- client_application_traffic_secret: &[u8],
- server_application_traffic_secret: &[u8],
- cipher_suite: u16,
- is_client: bool,
- ) -> Result<Self> {
- // Derive keys based on whether we're client or server
- let (write_key, write_iv, read_key, read_iv) = if is_client {
- // CLIENT:
- // - Write with client keys (client -> server direction)
- // - Read with server keys (server -> client direction)
- let (write_key, write_iv) = derive_traffic_keys(
- client_application_traffic_secret,
- cipher_suite,
- )?;
- let (read_key, read_iv) = derive_traffic_keys(
- server_application_traffic_secret,
- cipher_suite,
- )?;
- (write_key, write_iv, read_key, read_iv)
- } else {
- // SERVER:
- // - Read with client keys (client -> server direction)
- // - Write with server keys (server -> client direction)
- let (read_key, read_iv) = derive_traffic_keys(
- client_application_traffic_secret,
- cipher_suite,
- )?;
- let (write_key, write_iv) = derive_traffic_keys(
- server_application_traffic_secret,
- cipher_suite,
- )?;
- (write_key, write_iv, read_key, read_iv)
- };
-
- log::debug!("TLS13 STREAM DEBUG: Created stream with keys:");
- log::debug!(" read_key: {:?}", &read_key[..8]);
- log::debug!(" write_key: {:?}", &write_key[..8]);
-
- Ok(Tls13Stream {
- inner,
- write_key,
- write_iv,
- write_seq: 0,
- read_key,
- read_iv,
- read_seq: 0,
- read_buffer: Vec::new(),
- read_buffer_pos: 0,
- pending_read: Vec::new(),
- pending_read_pos: 0,
- received_close_notify: false,
- shutdown_sent: false,
- write_buffer: Vec::new(),
- write_buffer_pos: 0,
- })
- }
-
- /// Read and decrypt one TLS record from the stream
- fn read_tls_record(
- &mut self,
- cx: &mut Context<'_>,
- ) -> Poll<Result<()>> {
- // Read TLS record header (5 bytes: type + version + length)
- let mut header = [0u8; 5];
- let mut header_buf = ReadBuf::new(&mut header);
-
- match Pin::new(&mut self.inner).poll_read(cx, &mut header_buf) {
- Poll::Ready(Ok(())) if header_buf.filled().len() == 5 => {
- let record_type = header[0];
- let length = u16::from_be_bytes([header[3], header[4]]);
-
- // Expect ApplicationData (0x17) for encrypted records
- if record_type != 0x17 {
- return Poll::Ready(Err(Error::new(
- ErrorKind::InvalidData,
- format!("Unexpected record type: 0x{:02x}", record_type),
- )));
- }
-
- // Read record payload
- let mut ciphertext = vec![0u8; length as usize];
- let mut ciphertext_buf = ReadBuf::new(&mut ciphertext);
-
- match Pin::new(&mut self.inner).poll_read(cx, &mut ciphertext_buf) {
- Poll::Ready(Ok(())) if ciphertext_buf.filled().len() == length as usize => {
- // Decrypt the record
- let mut aad = Vec::new();
- aad.push(0x17); // ApplicationData
- aad.extend_from_slice(&[0x03, 0x03]); // TLS 1.2
- aad.extend_from_slice(&length.to_be_bytes());
-
- log::debug!("TLS13 STREAM: Decrypting record seq={} len={}", self.read_seq, ciphertext.len());
- log::debug!("TLS13 STREAM: read_key={:?}", &self.read_key[..8]);
- log::debug!("TLS13 STREAM: read_iv={:?}", &self.read_iv);
- log::debug!("TLS13 STREAM: ciphertext (first 32)={:?}", &ciphertext[..32.min(ciphertext.len())]);
-
- match decrypt_tls13_record(
- &self.read_key,
- &self.read_iv,
- self.read_seq,
- &ciphertext,
- &aad,
- ) {
- Ok(mut plaintext) => {
- log::info!("TLS13 STREAM: Decryption successful, plaintext len={}", plaintext.len());
- self.read_seq += 1;
-
- // Remove ContentType trailer
- if plaintext.is_empty() {
- return Poll::Ready(Err(Error::new(
- ErrorKind::InvalidData,
- "Empty plaintext",
- )));
- }
-
- let content_type = plaintext.pop().unwrap();
-
- // Handle different content types
- match content_type {
- 0x17 => {
- // ApplicationData - store in pending buffer
- self.pending_read = plaintext;
- self.pending_read_pos = 0;
- Poll::Ready(Ok(()))
- }
- 0x15 => {
- // Alert - check if it's close_notify (normal) or an error
- if plaintext.len() >= 2 {
- let alert_level = plaintext[0];
- let alert_description = plaintext[1];
- log::debug!("TLS13 STREAM: Alert received - level={}, description={}", alert_level, alert_description);
-
- // close_notify has description = 0
- if alert_description == 0 {
- // close_notify - treat as EOF
- log::debug!("TLS13 STREAM: Received close_notify alert (normal connection closure)");
- // Mark that we've received close_notify
- self.received_close_notify = true;
- // Clear pending buffer
- self.pending_read = Vec::new();
- self.pending_read_pos = 0;
- Poll::Ready(Ok(()))
- } else {
- // Other alerts are errors
- Poll::Ready(Err(Error::new(
- ErrorKind::ConnectionAborted,
- format!("TLS Alert: level={}, description={}", alert_level, alert_description),
- )))
- }
- } else {
- Poll::Ready(Err(Error::new(
- ErrorKind::InvalidData,
- "Invalid alert message",
- )))
- }
- }
- _ => {
- Poll::Ready(Err(Error::new(
- ErrorKind::InvalidData,
- format!("Unexpected content type: 0x{:02x}", content_type),
- )))
- }
- }
- }
- Err(e) => Poll::Ready(Err(e)),
- }
- }
- Poll::Ready(Ok(())) => {
- Poll::Ready(Err(Error::new(
- ErrorKind::UnexpectedEof,
- "Incomplete record payload",
- )))
- }
- Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
- Poll::Pending => Poll::Pending,
- }
- }
- Poll::Ready(Ok(())) => {
- Poll::Ready(Err(Error::new(
- ErrorKind::UnexpectedEof,
- "Incomplete record header",
- )))
- }
- Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
- Poll::Pending => Poll::Pending,
- }
- }
-}
-
-impl AsyncRead for Tls13Stream {
- fn poll_read(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut ReadBuf<'_>,
- ) -> Poll<Result<()>> {
- // If we've received close_notify, return EOF (0 bytes)
- if self.received_close_notify {
- return Poll::Ready(Ok(()));
- }
-
- // If we have pending decrypted data, return it
- if self.pending_read_pos < self.pending_read.len() {
- let available = self.pending_read.len() - self.pending_read_pos;
- let to_copy = std::cmp::min(available, buf.remaining());
-
- buf.put_slice(&self.pending_read[self.pending_read_pos..self.pending_read_pos + to_copy]);
- self.pending_read_pos += to_copy;
-
- return Poll::Ready(Ok(()));
- }
-
- // Need to read a new TLS record
- match self.read_tls_record(cx) {
- Poll::Ready(Ok(())) => {
- // Check if we received close_notify while reading
- if self.received_close_notify {
- // Return EOF
- Poll::Ready(Ok(()))
- } else {
- // Recursively call poll_read to return the data
- self.poll_read(cx, buf)
- }
- }
- Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
- Poll::Pending => Poll::Pending,
- }
- }
-}
-
-impl AsyncWrite for Tls13Stream {
- fn poll_write(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &[u8],
- ) -> Poll<Result<usize>> {
- // If we have a pending write, complete it first
- if self.write_buffer_pos < self.write_buffer.len() {
- // Extract values to avoid borrow conflict
- let this = self.as_mut().get_mut();
- let remaining_len = this.write_buffer.len() - this.write_buffer_pos;
-
- // Create a temporary buffer to hold the slice
- let mut temp_buf = vec![0u8; remaining_len];
- temp_buf.copy_from_slice(&this.write_buffer[this.write_buffer_pos..]);
-
- match Pin::new(&mut this.inner).poll_write(cx, &temp_buf) {
- Poll::Ready(Ok(n)) => {
- this.write_buffer_pos += n;
- if this.write_buffer_pos >= this.write_buffer.len() {
- // Finished writing buffered data
- this.write_buffer.clear();
- this.write_buffer_pos = 0;
- } else {
- // Still have more to write, return Pending
- return Poll::Pending;
- }
- }
- Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
- Poll::Pending => return Poll::Pending,
- }
- }
-
- // TLS 1.3 maximum plaintext size is 16384 bytes (2^14)
- const MAX_PLAINTEXT_SIZE: usize = 16384;
-
- // Limit the write to MAX_PLAINTEXT_SIZE
- let to_write = std::cmp::min(buf.len(), MAX_PLAINTEXT_SIZE);
- let write_buf = &buf[..to_write];
-
- // Encrypt application data
- let mut plaintext = Vec::new();
- plaintext.extend_from_slice(write_buf);
- plaintext.push(0x17); // ContentType: ApplicationData
-
- // Additional data for AEAD
- let ciphertext_length = (plaintext.len() + 16) as u16; // +16 for auth tag
- let mut aad = Vec::new();
- aad.push(0x17); // ApplicationData
- aad.extend_from_slice(&[0x03, 0x03]); // TLS 1.2
- aad.extend_from_slice(&ciphertext_length.to_be_bytes());
-
- let ciphertext = match encrypt_tls13_record(
- &self.write_key,
- &self.write_iv,
- self.write_seq,
- &plaintext,
- &aad,
- ) {
- Ok(ct) => ct,
- Err(e) => return Poll::Ready(Err(e)),
- };
-
- self.write_seq += 1;
-
- // Construct full TLS record
- let mut record = Vec::new();
- record.extend_from_slice(&aad); // Header
- record.extend_from_slice(&ciphertext);
-
- // Try to write the record
- match Pin::new(&mut self.inner).poll_write(cx, &record) {
- Poll::Ready(Ok(n)) => {
- if n == record.len() {
- // Successfully wrote entire record
- Poll::Ready(Ok(to_write))
- } else if n > 0 {
- // Partial write - buffer the remainder
- self.write_buffer = record;
- self.write_buffer_pos = n;
- // Still report success for the user data
- Poll::Ready(Ok(to_write))
- } else {
- // No bytes written
- Poll::Ready(Err(Error::new(
- ErrorKind::WriteZero,
- "Failed to write TLS record",
- )))
- }
- }
- Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
- Poll::Pending => Poll::Pending,
- }
- }
-
- fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
- // First complete any pending writes
- if self.write_buffer_pos < self.write_buffer.len() {
- // Extract values to avoid borrow conflict
- let this = self.as_mut().get_mut();
- let remaining_len = this.write_buffer.len() - this.write_buffer_pos;
-
- // Create a temporary buffer to hold the slice
- let mut temp_buf = vec![0u8; remaining_len];
- temp_buf.copy_from_slice(&this.write_buffer[this.write_buffer_pos..]);
-
- match Pin::new(&mut this.inner).poll_write(cx, &temp_buf) {
- Poll::Ready(Ok(n)) => {
- this.write_buffer_pos += n;
- if this.write_buffer_pos >= this.write_buffer.len() {
- // Finished writing buffered data
- this.write_buffer.clear();
- this.write_buffer_pos = 0;
- } else {
- // Still have more to write
- return Poll::Pending;
- }
- }
- Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
- Poll::Pending => return Poll::Pending,
- }
- }
-
- // Then flush the underlying stream to ensure all data is sent
- Pin::new(&mut self.inner).poll_flush(cx)
- }
-
- fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
- // For TLS 1.3, we should send a close_notify alert before shutting down
- // However, for simplicity and to avoid deadlocks, we'll just ensure
- // the underlying stream is properly flushed before shutdown
-
- // First ensure all pending data is flushed
- match Pin::new(&mut self.inner).poll_flush(cx) {
- Poll::Pending => return Poll::Pending,
- Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
- Poll::Ready(Ok(())) => {}
- }
-
- // Now shutdown the underlying stream
- Pin::new(&mut self.inner).poll_shutdown(cx)
- }
-}
-
-impl AsyncPing for Tls13Stream {
- fn supports_ping(&self) -> bool {
- self.inner.supports_ping()
- }
-
- fn poll_write_ping(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<bool>> {
- Pin::new(&mut self.inner).poll_write_ping(cx)
- }
-}
-
-impl AsyncStream for Tls13Stream {}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::reality::reality_tls13_crypto::{encrypt_tls13_record, decrypt_tls13_record};
-
- #[test]
- fn test_tls13_stream_key_derivation() {
- // Test that client and server derive opposite keys
- let client_secret = vec![0x11u8; 32];
- let server_secret = vec![0x22u8; 32];
-
- // Derive keys for client
- let (client_write_key, client_write_iv, client_read_key, client_read_iv) = {
- let (write_key, write_iv) = derive_traffic_keys(&client_secret, 0x1301).unwrap();
- let (read_key, read_iv) = derive_traffic_keys(&server_secret, 0x1301).unwrap();
- (write_key, write_iv, read_key, read_iv)
- };
-
- // Derive keys for server (should be opposite)
- let (server_write_key, server_write_iv, server_read_key, server_read_iv) = {
- let (write_key, write_iv) = derive_traffic_keys(&server_secret, 0x1301).unwrap();
- let (read_key, read_iv) = derive_traffic_keys(&client_secret, 0x1301).unwrap();
- (write_key, write_iv, read_key, read_iv)
- };
-
- // Client writes with client keys, server reads with client keys
- assert_eq!(client_write_key, server_read_key);
- assert_eq!(client_write_iv, server_read_iv);
-
- // Server writes with server keys, client reads with server keys
- assert_eq!(server_write_key, client_read_key);
- assert_eq!(server_write_iv, client_read_iv);
- }
-
- #[test]
- fn test_tls13_encrypt_decrypt_roundtrip() {
- let key = vec![0x42u8; 16]; // AES-128 key
- let iv = vec![0x11u8; 12]; // GCM IV
- let plaintext = b"Hello, TLS 1.3!";
- let mut plaintext_with_type = plaintext.to_vec();
- plaintext_with_type.push(0x17); // ApplicationData content type
-
- // AAD for TLS record
- let mut aad = Vec::new();
- aad.push(0x17); // ApplicationData
- aad.extend_from_slice(&[0x03, 0x03]); // TLS 1.2
- let ciphertext_len = (plaintext_with_type.len() + 16) as u16; // +16 for tag
- aad.extend_from_slice(&ciphertext_len.to_be_bytes());
-
- // Encrypt
- let ciphertext = encrypt_tls13_record(&key, &iv, 0, &plaintext_with_type, &aad).unwrap();
- assert_eq!(ciphertext.len(), plaintext_with_type.len() + 16);
-
- // Decrypt
- let decrypted = decrypt_tls13_record(&key, &iv, 0, &ciphertext, &aad).unwrap();
- assert_eq!(decrypted, plaintext_with_type);
- }
-
- #[test]
- fn test_sequence_number_in_nonce() {
- let key = vec![0x42u8; 16];
- let iv = vec![0x11u8; 12];
- let plaintext = b"test";
- let mut plaintext_with_type = plaintext.to_vec();
- plaintext_with_type.push(0x17);
-
- let mut aad = Vec::new();
- aad.push(0x17);
- aad.extend_from_slice(&[0x03, 0x03]);
- let ciphertext_len = (plaintext_with_type.len() + 16) as u16;
- aad.extend_from_slice(&ciphertext_len.to_be_bytes());
-
- // Encrypt with sequence 0
- let ct1 = encrypt_tls13_record(&key, &iv, 0, &plaintext_with_type, &aad).unwrap();
-
- // Encrypt with sequence 1 - should give different ciphertext
- let ct2 = encrypt_tls13_record(&key, &iv, 1, &plaintext_with_type, &aad).unwrap();
- assert_ne!(ct1, ct2);
-
- // Decrypting with wrong sequence should fail
- let result = decrypt_tls13_record(&key, &iv, 1, &ct1, &aad);
- assert!(result.is_err());
-
- // Decrypting with correct sequence should work
- let result = decrypt_tls13_record(&key, &iv, 0, &ct1, &aad);
- assert!(result.is_ok());
- }
-}
diff --git a/src/reality/reality_tls13_stream_test.rs b/src/reality/reality_tls13_stream_test.rs
deleted file mode 100644
index 310737d..0000000
@@ -1,172 +0,0 @@
-// Unit tests for reality_tls13_stream.rs
-
-#[cfg(test)]
-mod tests {
- use crate::reality::reality_tls13_stream::Tls13Stream;
- use crate::async_stream::AsyncStream;
- use tokio::io::{AsyncReadExt, AsyncWriteExt};
- use std::pin::Pin;
- use std::task::{Context, Poll};
- use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
-
- // Mock stream for testing
- struct MockStream {
- read_data: Vec<u8>,
- read_pos: usize,
- written_data: Vec<u8>,
- }
-
- impl MockStream {
- fn new() -> Self {
- MockStream {
- read_data: Vec::new(),
- read_pos: 0,
- written_data: Vec::new(),
- }
- }
-
- fn set_read_data(&mut self, data: Vec<u8>) {
- self.read_data = data;
- self.read_pos = 0;
- }
-
- fn get_written_data(&self) -> &[u8] {
- &self.written_data
- }
- }
-
- impl AsyncRead for MockStream {
- fn poll_read(
- mut self: Pin<&mut Self>,
- _cx: &mut Context<'_>,
- buf: &mut ReadBuf<'_>,
- ) -> Poll<std::io::Result<()>> {
- let remaining = &self.read_data[self.read_pos..];
- let to_read = std::cmp::min(remaining.len(), buf.remaining());
-
- buf.put_slice(&remaining[..to_read]);
- self.read_pos += to_read;
-
- Poll::Ready(Ok(()))
- }
- }
-
- impl AsyncWrite for MockStream {
- fn poll_write(
- mut self: Pin<&mut Self>,
- _cx: &mut Context<'_>,
- buf: &[u8],
- ) -> Poll<std::io::Result<usize>> {
- self.written_data.extend_from_slice(buf);
- Poll::Ready(Ok(buf.len()))
- }
-
- fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
- Poll::Ready(Ok(()))
- }
-
- fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
- Poll::Ready(Ok(()))
- }
- }
-
- impl crate::async_stream::AsyncPing for MockStream {
- fn supports_ping(&self) -> bool {
- false
- }
-
- fn poll_write_ping(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<bool>> {
- Poll::Ready(Ok(false))
- }
- }
-
- impl AsyncStream for MockStream {}
-
- #[tokio::test]
- async fn test_tls13_stream_creation() {
- let mock_stream = Box::new(MockStream::new());
- let client_secret = vec![0x11u8; 32];
- let server_secret = vec![0x22u8; 32];
-
- let stream_result = Tls13Stream::new(
- mock_stream,
- &client_secret,
- &server_secret,
- 0x1301, // TLS_AES_128_GCM_SHA256
- true, // is_client
- );
-
- assert!(stream_result.is_ok());
- }
-
- #[tokio::test]
- async fn test_tls13_stream_write_read_roundtrip() {
- use crate::reality::reality_tls13_crypto::encrypt_tls13_record;
-
- let client_secret = vec![0x11u8; 32];
- let server_secret = vec![0x22u8; 32];
-
- // Create a client stream for writing
- let mut client_mock = Box::new(MockStream::new());
- let mut client_stream = Tls13Stream::new(
- client_mock,
- &client_secret,
- &server_secret,
- 0x1301, // TLS_AES_128_GCM_SHA256
- true, // is_client
- ).unwrap();
-
- // Write some data
- let test_data = b"Hello, TLS 1.3!";
- let written = client_stream.write(test_data).await.unwrap();
- assert_eq!(written, test_data.len());
-
- // The stream should have written an encrypted TLS record
- // We can't easily test the decryption without implementing the full logic,
- // but we can verify that data was written
- // Note: In a real test, we'd need to access the internal mock stream
- }
-
- #[test]
- fn test_derive_traffic_keys() {
- use crate::reality::reality_tls13_keys::derive_traffic_keys;
-
- let traffic_secret = vec![0x42u8; 32];
-
- // Test AES-128-GCM
- let result = derive_traffic_keys(&traffic_secret, 0x1301);
- assert!(result.is_ok());
- let (key, iv) = result.unwrap();
- assert_eq!(key.len(), 16); // AES-128 key
- assert_eq!(iv.len(), 12); // GCM IV
-
- // Test AES-256-GCM
- let result = derive_traffic_keys(&traffic_secret, 0x1302);
- assert!(result.is_ok());
- let (key, iv) = result.unwrap();
- assert_eq!(key.len(), 32); // AES-256 key
- assert_eq!(iv.len(), 12); // GCM IV
-
- // Test unsupported cipher
- let result = derive_traffic_keys(&traffic_secret, 0xFFFF);
- assert!(result.is_err());
- }
-
- #[test]
- fn test_compute_finished_verify_data() {
- use crate::reality::reality_tls13_keys::compute_finished_verify_data;
-
- let base_key = vec![0x11u8; 32];
- let handshake_hash = vec![0x22u8; 32];
-
- let result = compute_finished_verify_data(&base_key, &handshake_hash);
- assert!(result.is_ok());
-
- let verify_data = result.unwrap();
- assert_eq!(verify_data.len(), 32); // HMAC-SHA256 output
-
- // Should be deterministic
- let result2 = compute_finished_verify_data(&base_key, &handshake_hash).unwrap();
- assert_eq!(verify_data, result2);
- }
-}
\ No newline at end of file