use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
/// Classification of what a server sent back (or failed to send back) after a
/// ClientHello, as produced by `parse_server_response` and `probe_once`.
#[derive(Debug)]
pub enum ServerResponse {
/// The first record was a handshake record whose message type is ServerHello (0x02).
ServerHello {
/// Cipher suite code point read from the two bytes that follow the session id.
selected_suite: u16,
/// Group code point taken from a `key_share` (0x0033) extension, or `None`
/// when no usable `key_share` extension was found.
selected_group: Option<u16>,
/// Protocol version code point extracted from the ServerHello by
/// `parse_server_hello`.
tls_version: u16,
},
/// An alert record with level 0x02 and description 0x28 was received. Also
/// produced when an alert record is too short to contain both alert bytes.
HandshakeFailure,
/// Any other alert record; `probe_once` additionally returns this when the
/// write fails, the peer closes the stream, a read fails, or the response
/// cannot be parsed.
ConnectionClose,
/// `probe_once` ran out of its time budget before a complete record arrived.
Timeout,
}
/// Classifies the first TLS record in `bytes`.
///
/// * Record type `0x15` (alert): yields [`ServerResponse::HandshakeFailure`]
///   when the alert is level `0x02` / description `0x28`, or when the buffer is
///   too short (fewer than 7 bytes) to hold both alert bytes; every other alert
///   yields [`ServerResponse::ConnectionClose`].
/// * Record type `0x16` (handshake): delegated to `parse_server_hello`.
///
/// # Errors
///
/// Returns `Err` when fewer than 5 bytes (one record header) are available,
/// when the record type is neither `0x15` nor `0x16`, or when
/// `parse_server_hello` rejects the handshake record.
pub fn parse_server_response(bytes: &[u8]) -> Result<ServerResponse, String> {
    if bytes.len() < 5 {
        return Err("too short".into());
    }

    let record_type = bytes[0];
    match record_type {
        0x16 => parse_server_hello(bytes),
        0x15 => {
            // The two alert bytes sit right behind the 5-byte record header.
            let outcome = match bytes.get(5..7) {
                Some([0x02, 0x28]) => ServerResponse::HandshakeFailure,
                Some(_) => ServerResponse::ConnectionClose,
                // Alert body missing or incomplete.
                None => ServerResponse::HandshakeFailure,
            };
            Ok(outcome)
        }
        other => Err(format!("unexpected record type: 0x{:02x}", other)),
    }
}
/// Parses a handshake record that is expected to start with a ServerHello.
///
/// `bytes` must begin at the 5-byte record header. On success the result is a
/// [`ServerResponse::ServerHello`] carrying:
///
/// * `selected_suite` – the two bytes following the session id;
/// * `selected_group` – whatever `parse_key_share_group` finds in the
///   extensions block;
/// * `tls_version` – the value from the `supported_versions` (0x002b)
///   extension when the server sent one with a 2-byte payload, otherwise the
///   version field at the start of the ServerHello body. TLS 1.3 servers put
///   0x0303 in the body field and signal the real version only through the
///   extension, so the body field alone would misreport them.
///
/// The ServerHello body is limited to the length declared in the handshake
/// header, so bytes of a following handshake message in the same record are
/// never misread as an extensions block.
///
/// # Errors
///
/// Returns `Err` when the record header, record payload, handshake header or
/// ServerHello body is shorter than required, or when the handshake message
/// type is not 0x02.
fn parse_server_hello(bytes: &[u8]) -> Result<ServerResponse, String> {
    // Returns the payload of the first extension of type `wanted` inside the
    // extensions block starting at `extensions_offset`, or `None` when the
    // block is absent, malformed, or does not contain that extension.
    fn find_extension(body: &[u8], extensions_offset: usize, wanted: u16) -> Option<&[u8]> {
        let list_start = extensions_offset + 2;
        let declared = body.get(extensions_offset..list_start)?;
        let list_len = usize::from(u16::from_be_bytes([declared[0], declared[1]]));
        let mut rest = body.get(list_start..list_start + list_len)?;
        while rest.len() >= 4 {
            let ext_type = u16::from_be_bytes([rest[0], rest[1]]);
            let ext_len = usize::from(u16::from_be_bytes([rest[2], rest[3]]));
            let data = rest.get(4..4 + ext_len)?;
            if ext_type == wanted {
                return Some(data);
            }
            rest = &rest[4 + ext_len..];
        }
        None
    }

    if bytes.len() < 5 {
        return Err("record too short".into());
    }
    let record_len = usize::from(u16::from_be_bytes([bytes[3], bytes[4]]));
    let payload = bytes
        .get(5..5 + record_len)
        .ok_or_else(|| String::from("record truncated"))?;

    if payload.len() < 4 {
        return Err("handshake too short".into());
    }
    if payload[0] != 0x02 {
        return Err(format!(
            "expected ServerHello (0x02), got 0x{:02x}",
            payload[0]
        ));
    }

    // The handshake header carries a 24-bit body length. Honour it so that a
    // second handshake message in the same record is not parsed as part of
    // this ServerHello; clamp to what is actually present so a message that
    // continues in a later record is still parsed as far as possible.
    let declared_len = u32::from_be_bytes([0, payload[1], payload[2], payload[3]]) as usize;
    let available = &payload[4..];
    let body = &available[..declared_len.min(available.len())];

    // version (2) + random (32) + session_id length (1)
    if body.len() < 2 + 32 + 1 {
        return Err("ServerHello body too short".into());
    }
    let legacy_version = u16::from_be_bytes([body[0], body[1]]);
    let sid_len = usize::from(body[34]);
    let after_sid = 35 + sid_len;
    if body.len() < after_sid + 2 {
        return Err("ServerHello truncated after session_id".into());
    }
    let selected_suite = u16::from_be_bytes([body[after_sid], body[after_sid + 1]]);

    // Skip the cipher suite (2 bytes) and the compression method (1 byte).
    let extensions_offset = after_sid + 3;
    let selected_group = parse_key_share_group(body, extensions_offset);

    // In a ServerHello the supported_versions payload is exactly the selected
    // version; anything else is ignored and the body field is reported.
    let tls_version = match find_extension(body, extensions_offset, 0x002b) {
        Some(&[hi, lo]) => u16::from_be_bytes([hi, lo]),
        _ => legacy_version,
    };

    Ok(ServerResponse::ServerHello {
        selected_suite,
        selected_group,
        tls_version,
    })
}
/// Looks for a `key_share` (0x0033) extension and returns its leading group id.
///
/// `extensions_offset` is the index in `body` of the 2-byte extensions-block
/// length. The block is walked entry by entry (2-byte type, 2-byte length,
/// payload); the first `key_share` entry whose payload holds at least two bytes
/// supplies the result, read big-endian from the start of that payload.
///
/// Returns `None` when the length prefix or the declared block does not fit in
/// `body`, when an entry claims more bytes than remain in the block, or when no
/// suitable `key_share` entry exists.
fn parse_key_share_group(body: &[u8], extensions_offset: usize) -> Option<u16> {
    let list_start = extensions_offset + 2;
    let declared = body.get(extensions_offset..list_start)?;
    let list_len = usize::from(u16::from_be_bytes([declared[0], declared[1]]));

    // Everything from here on is confined to the declared extensions block.
    let mut remaining = body.get(list_start..list_start + list_len)?;

    while let [type_hi, type_lo, len_hi, len_lo, tail @ ..] = remaining {
        let kind = u16::from_be_bytes([*type_hi, *type_lo]);
        let data_len = usize::from(u16::from_be_bytes([*len_hi, *len_lo]));
        if data_len > tail.len() {
            // Entry overruns the block: stop scanning.
            break;
        }
        let (data, next) = tail.split_at(data_len);
        if kind == 0x0033 {
            if let [group_hi, group_lo, ..] = data {
                return Some(u16::from_be_bytes([*group_hi, *group_lo]));
            }
        }
        remaining = next;
    }
    None
}
/// Serialises a ClientHello wrapped in a single TLS handshake record.
///
/// # Parameters
///
/// * `sni` – host name for the `server_name` extension; the extension is
///   omitted when the string is empty.
/// * `cipher_suites` – cipher suite code points, emitted in the given order.
/// * `named_groups` – group code points for `supported_groups`. The `key_share`
///   extension offers 0x001D when it is listed (or when the list is empty),
///   otherwise the first listed group.
/// * `max_version` – when at least 0x0304 the `supported_versions` extension
///   offers 0x0304 and 0x0303; otherwise only 0x0303.
///
/// # Returns
///
/// The record bytes: a 5-byte record header (type 0x16, version 0x0301,
/// length), a 4-byte handshake header (type 0x01, 24-bit length) and the
/// ClientHello body. The body uses version 0x0303, an all-zero 32-byte random,
/// an empty session id and a single null compression method.
///
/// The key share is fixed filler rather than a generated public key: 32 bytes,
/// all zero except that the first byte is 9 when the offered group is 0x001D
/// because it appears in `named_groups`.
///
/// Length fields are written with truncating casts, so inputs that do not fit
/// their 16-bit fields are not rejected.
pub fn build_client_hello(
    sni: &str,
    cipher_suites: &[u16],
    named_groups: &[u16],
    max_version: u16,
) -> Vec<u8> {
    // Appends `value` in big-endian byte order.
    fn put_u16(out: &mut Vec<u8>, value: u16) {
        out.extend_from_slice(&value.to_be_bytes());
    }

    // Flattens 16-bit code points into consecutive big-endian byte pairs.
    fn encode_u16_list(values: &[u16]) -> Vec<u8> {
        let mut encoded = Vec::new();
        for value in values {
            encoded.extend_from_slice(&value.to_be_bytes());
        }
        encoded
    }

    const X25519: u16 = 0x001D;
    const SIGNATURE_SCHEMES: [u16; 9] = [
        0x0403, 0x0503, 0x0603, 0x0804, 0x0805, 0x0806, 0x0401, 0x0501, 0x0601,
    ];

    let mut ext_block: Vec<u8> = Vec::new();

    // server_name (0x0000): a list holding one host_name entry.
    if !sni.is_empty() {
        let host = sni.as_bytes();
        let host_len = host.len();
        // name type (1) + name length (2) + name
        let entry_list_len = host_len + 3;
        put_u16(&mut ext_block, 0x0000);
        put_u16(&mut ext_block, (entry_list_len + 2) as u16);
        put_u16(&mut ext_block, entry_list_len as u16);
        ext_block.push(0x00);
        put_u16(&mut ext_block, host_len as u16);
        ext_block.extend_from_slice(host);
    }

    // supported_versions (0x002b): one-byte list length, then the versions.
    let offered_versions: &[u16] = if max_version < 0x0304 {
        &[0x0303]
    } else {
        &[0x0304, 0x0303]
    };
    let versions_byte_len = offered_versions.len() * 2;
    put_u16(&mut ext_block, 0x002b);
    put_u16(&mut ext_block, (versions_byte_len + 1) as u16);
    ext_block.push(versions_byte_len as u8);
    ext_block.extend_from_slice(&encode_u16_list(offered_versions));

    // supported_groups (0x000a)
    let group_bytes = encode_u16_list(named_groups);
    let group_list_len = group_bytes.len() as u16;
    put_u16(&mut ext_block, 0x000a);
    put_u16(&mut ext_block, group_list_len + 2);
    put_u16(&mut ext_block, group_list_len);
    ext_block.extend_from_slice(&group_bytes);

    // signature_algorithms (0x000d)
    let scheme_bytes = encode_u16_list(&SIGNATURE_SCHEMES);
    let scheme_list_len = scheme_bytes.len() as u16;
    put_u16(&mut ext_block, 0x000d);
    put_u16(&mut ext_block, scheme_list_len + 2);
    put_u16(&mut ext_block, scheme_list_len);
    ext_block.extend_from_slice(&scheme_bytes);

    // key_share (0x0033): exactly one entry of group, key length, key bytes.
    let offers_x25519 = named_groups.contains(&X25519);
    let share_group = if offers_x25519 {
        X25519
    } else {
        named_groups.first().copied().unwrap_or(X25519)
    };
    let mut share_key = vec![0u8; 32];
    if offers_x25519 {
        share_key[0] = 9;
    }
    let share_key_len = share_key.len() as u16;
    let share_list_len: u16 = 2 + 2 + share_key_len;
    put_u16(&mut ext_block, 0x0033);
    put_u16(&mut ext_block, share_list_len + 2);
    put_u16(&mut ext_block, share_list_len);
    put_u16(&mut ext_block, share_group);
    put_u16(&mut ext_block, share_key_len);
    ext_block.extend_from_slice(&share_key);

    // ClientHello body.
    let suite_bytes = encode_u16_list(cipher_suites);
    let mut client_hello: Vec<u8> = Vec::new();
    put_u16(&mut client_hello, 0x0303);
    client_hello.extend_from_slice(&[0u8; 32]); // random
    client_hello.push(0x00); // session id length
    put_u16(&mut client_hello, suite_bytes.len() as u16);
    client_hello.extend_from_slice(&suite_bytes);
    client_hello.extend_from_slice(&[0x01, 0x00]); // one compression method: null
    put_u16(&mut client_hello, ext_block.len() as u16);
    client_hello.extend_from_slice(&ext_block);

    // Handshake header: message type plus 24-bit length.
    let body_len = client_hello.len();
    let mut handshake: Vec<u8> = Vec::new();
    handshake.push(0x01);
    handshake.extend_from_slice(&[
        (body_len >> 16) as u8,
        (body_len >> 8) as u8,
        body_len as u8,
    ]);
    handshake.extend_from_slice(&client_hello);

    // Record header: content type, record version, 16-bit length.
    let fragment_len = handshake.len();
    let mut record: Vec<u8> = Vec::new();
    record.extend_from_slice(&[
        0x16,
        0x03,
        0x01,
        (fragment_len >> 8) as u8,
        fragment_len as u8,
    ]);
    record.extend_from_slice(&handshake);
    record
}
/// Sends `hello` on `stream` and classifies the first TLS record that comes back.
///
/// After the write succeeds, bytes are read until one complete record (5-byte
/// header plus the length announced in header bytes 3..5) has been buffered,
/// and the buffer is handed to `parse_server_response`.
///
/// `timeout_ms` is the total budget for the read phase; the clock starts after
/// the ClientHello has been written.
///
/// # Returns
///
/// * the parsed response when a full record arrived and parsed successfully;
/// * [`ServerResponse::Timeout`] when the budget ran out first;
/// * [`ServerResponse::ConnectionClose`] when the write failed, the peer closed
///   the stream, a read failed, or the record could not be parsed.
pub async fn probe_once(stream: &mut TcpStream, hello: &[u8], timeout_ms: u64) -> ServerResponse {
    use tokio::time::{Duration, Instant};

    if stream.write_all(hello).await.is_err() {
        return ServerResponse::ConnectionClose;
    }

    let deadline = Instant::now() + Duration::from_millis(timeout_ms);
    let mut received: Vec<u8> = Vec::with_capacity(8192);
    let mut scratch = [0u8; 4096];

    loop {
        // Until the header is complete only the header itself is awaited;
        // afterwards the header's length field says how much more is needed.
        let wanted = match received.get(3..5) {
            Some(len_field) => 5 + u16::from_be_bytes([len_field[0], len_field[1]]) as usize,
            None => 5,
        };
        if received.len() >= wanted {
            break;
        }

        let Some(budget) = deadline.checked_duration_since(Instant::now()) else {
            return ServerResponse::Timeout;
        };

        let outcome = tokio::time::timeout(budget, stream.read(&mut scratch)).await;
        match outcome {
            Err(_) => return ServerResponse::Timeout,
            Ok(Err(_)) | Ok(Ok(0)) => return ServerResponse::ConnectionClose,
            Ok(Ok(count)) => received.extend_from_slice(&scratch[..count]),
        }
    }

    match parse_server_response(&received) {
        Ok(response) => response,
        Err(_) => ServerResponse::ConnectionClose,
    }
}