use std::io::prelude::*;
use std::net::TcpStream;
use std::time::Duration;
use bstr::ByteSlice;
use bytes::BytesMut;
use metalssh::constants::msg::SSH_MSG_KEX_ECDH_REPLY;
use metalssh::constants::msg::SSH_MSG_KEXINIT;
use metalssh::constants::msg::SSH_MSG_NEWKEYS;
use metalssh::crypto::cipher::Cipher;
use metalssh::crypto::cipher::chacha20poly1305::ChaCha20Poly1305;
use metalssh::crypto::hash::HashFn;
use metalssh::crypto::hash::Sha256;
use metalssh::crypto::kdf;
use metalssh::crypto::kex::Initiate;
use metalssh::crypto::kex::curve25519_sha256;
use metalssh::msg::kex_ecdh_init::KexEcdhInitBuilder;
use metalssh::msg::kex_ecdh_reply::KexEcdhReply;
use metalssh::msg::kexinit::Kexinit;
use metalssh::msg::kexinit::KexinitBuilder;
use metalssh::packet::Packet;
use metalssh::proto::SshReadable;
use metalssh::proto::SshWritable;
use scroll::Pread;
/// Per-connection state carried between received SSH packets.
///
/// Holds the inputs of the exchange hash (identification strings and raw
/// KEXINIT payloads), the in-flight key exchange, and the receive-side cipher
/// together with its packet counter.
struct ClientState {
    /// Identification string this client sent, without the trailing CR LF.
    client_id: Vec<u8>,
    /// Identification string received from the server, with surrounding
    /// whitespace trimmed.
    server_id: Vec<u8>,
    /// Raw payload of the KEXINIT this client sent, once built.
    client_kexinit: Option<Vec<u8>>,
    /// Raw payload of the server's KEXINIT, once received.
    server_kexinit: Option<Vec<u8>>,
    /// Ephemeral curve25519 exchange, held from KEX_ECDH_INIT until the
    /// server's KEX_ECDH_REPLY consumes it.
    kex: Option<curve25519_sha256::Initiator>,
    /// Cipher used to decrypt server-to-client packets once keys are derived.
    cipher: Option<ChaCha20Poly1305>,
    /// Becomes true only after both sides have sent NEWKEYS; from then on
    /// incoming packets are decrypted.
    encryption_active: bool,
    /// Receive-side packet sequence number handed to the cipher when
    /// decrypting.
    sequence_number: u32,
}
fn main() -> anyhow::Result<()> {
let mut stream = TcpStream::connect("10.0.0.54:22")?;
stream.set_read_timeout(Some(Duration::from_secs(5)))?;
let mut recv_buffer = BytesMut::with_capacity(35000);
recv_buffer.resize(35000, 0);
// Read and handle SSH banner
let n = stream.read(&mut recv_buffer)?;
let banner = recv_buffer[..n].read_bytes_until(&mut 0, b'\n').unwrap();
println!("read: {n}, banner len: {}", banner.len());
let server_id = banner.trim_ascii().to_vec();
println!("found server banner: '{}'", server_id.as_bstr());
// Send client banner
let client_id = b"SSH-2.0-MetalSSH_0.0.0".to_vec();
stream.write_all(&client_id)?;
stream.write_all(b"\r\n")?;
// Initialize client state
let mut state = ClientState {
server_id,
client_id,
server_kexinit: None,
client_kexinit: None,
kex: None,
cipher: None,
encryption_active: false,
sequence_number: 0,
};
// Main receive loop
let mut packet_count = 0;
loop {
// Clear and reuse the buffer
recv_buffer.clear();
recv_buffer.resize(35000, 0);
println!("\nWaiting for packet {}...", packet_count + 1);
// Don't sleep before the first read attempt
if packet_count > 0 {
std::thread::sleep(Duration::from_millis(100));
}
let n = match stream.read(&mut recv_buffer) {
Ok(0) => {
println!("Connection closed by server (EOF)");
break;
}
Ok(n) => {
println!("Raw bytes received: {:02x?}", &recv_buffer[..n.min(50)]);
n
}
Err(e)
if e.kind() == std::io::ErrorKind::WouldBlock
|| e.kind() == std::io::ErrorKind::TimedOut =>
{
println!("Read timeout - no more data from server");
break;
}
Err(e) => {
println!("Read error: {} (kind: {:?})", e, e.kind());
return Err(e.into());
}
};
packet_count += 1;
println!("Packet {}: read {n} bytes", packet_count);
// Process all packets in the buffer (server may pipeline multiple packets)
let mut offset = 0;
while offset < n {
// Check if we have enough bytes for a packet length field
if n - offset < 4 {
println!(
"Not enough bytes for packet length, remaining: {}",
n - offset
);
break;
}
// Read packet length (may be encrypted for ChaCha20-Poly1305)
let packet_len = if state.encryption_active {
// For ChaCha20-Poly1305, the packet length is encrypted
// We need to decrypt it first
let cipher = state.cipher.as_ref().expect("Cipher not initialized");
let encrypted_packet = Packet::new(&recv_buffer[offset..n], 16);
cipher.decrypt_packet_length(&encrypted_packet, state.sequence_number)?
} else {
// No encryption yet, read length directly
u32::from_be_bytes([
recv_buffer[offset],
recv_buffer[offset + 1],
recv_buffer[offset + 2],
recv_buffer[offset + 3],
])
};
let mac_len = if state.encryption_active { 16 } else { 0 };
let total_packet_len = 4 + packet_len as usize + mac_len;
// Check if we have the complete packet
if offset + total_packet_len > n {
println!(
"Incomplete packet, need {} more bytes",
offset + total_packet_len - n
);
break;
}
println!(
"Processing packet at offset {}, length {} (encrypted: {})",
offset, total_packet_len, state.encryption_active
);
// Dispatch this packet
dispatch_message(
&mut stream,
&mut state,
&recv_buffer[offset..offset + total_packet_len],
)?;
offset += total_packet_len;
}
}
Ok(())
}
fn dispatch_message(
stream: &mut TcpStream,
state: &mut ClientState,
data: &[u8],
) -> anyhow::Result<()> {
// Determine MAC length based on encryption state
let mac_len = if state.encryption_active { 16 } else { 0 };
// If encrypted, we need to decrypt first
if state.encryption_active {
let cipher = state.cipher.as_ref().expect("Cipher not initialized");
// Make a mutable copy for decryption
let mut decrypted_data = data.to_vec();
let mut decrypted_packet = Packet::new(&mut decrypted_data, mac_len);
cipher.decrypt_packet(&mut decrypted_packet, state.sequence_number)?;
state.sequence_number += 1;
// Now parse the decrypted packet
let packet = Packet::new(&decrypted_data, mac_len);
println!("\n✓ Packet decrypted successfully!");
dbg!(&packet);
let payload = packet.payload()?;
let msg_type: u8 = payload.pread(0).unwrap();
match msg_type {
_ => {
println!("Received encrypted message type: {}", msg_type);
println!("Decrypted payload: {:02x?}", payload);
}
}
return Ok(());
}
// Unencrypted packet processing
let packet = Packet::new(data, mac_len);
dbg!(&packet);
let payload = packet.payload().unwrap();
// Read message type
let msg_type: u8 = payload.pread(0).unwrap();
match msg_type {
SSH_MSG_KEXINIT => {
// Store the raw server KEXINIT packet for later hash calculation
state.server_kexinit = Some(payload.to_vec());
// Parse the server's KEXINIT using our new Kexinit struct
let server_kexinit = Kexinit::new(payload)?;
println!("Server KEXINIT parsed:");
dbg!(&server_kexinit);
// Parse the server's kex algorithms to modify them
// Note: kex_algorithms() returns the raw name-list contents (without length
// prefix) so we parse it directly by splitting on commas
let server_kex_algs = server_kexinit.kex_algorithms();
let mut kex_algs_vec = server_kex_algs.split_str(",").collect::<Vec<_>>();
println!("\nServer's kex algorithms:");
for alg in &kex_algs_vec {
println!(" {}", alg.as_bstr());
}
// Reorder the kex algorithms to prioritize curve25519-sha256
// Find and move curve25519-sha256 variants to the front
let prioritized = [
b"curve25519-sha256".as_slice(),
b"curve25519-sha256@libssh.org".as_slice(),
];
for priority_alg in prioritized.iter().rev() {
if let Some(pos) = kex_algs_vec.iter().position(|&alg| alg == *priority_alg) {
let alg = kex_algs_vec.remove(pos);
kex_algs_vec.insert(0, alg);
}
}
// Replace server-specific extensions with client equivalents
// ext-info-s -> ext-info-c
// kex-strict-s-v00@openssh.com -> kex-strict-c-v00@openssh.com
for alg in kex_algs_vec.iter_mut() {
if *alg == b"kex-strict-s-v00@openssh.com" {
*alg = b"kex-strict-c-v00@openssh.com";
} else if *alg == b"ext-info-s" {
*alg = b"ext-info-c";
}
}
// Build the modified kex algorithms name-list
let mut kex_algs_string = Vec::new();
for (i, alg) in kex_algs_vec.iter().enumerate() {
if i > 0 {
kex_algs_string.push(b',');
}
kex_algs_string.extend_from_slice(alg);
}
println!("\nClient's reordered kex algorithms:");
for alg in &kex_algs_vec {
println!(" {}", alg.as_bstr());
}
// Generate a random cookie for the client
// In a real implementation, this should be cryptographically random
// For now, we'll use a simple pattern
let client_cookie = b"\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42";
// Build our client KEXINIT response using the modified values
let client_kexinit = KexinitBuilder::new()
.cookie(client_cookie)
.kex_algorithms(&kex_algs_string)
.server_host_key_algorithms(server_kexinit.server_host_key_algorithms())
.encryption_algorithms_client_to_server(
server_kexinit.encryption_algorithms_client_to_server(),
)
.encryption_algorithms_server_to_client(
server_kexinit.encryption_algorithms_server_to_client(),
)
.mac_algorithms_client_to_server(server_kexinit.mac_algorithms_client_to_server())
.mac_algorithms_server_to_client(server_kexinit.mac_algorithms_server_to_client())
.compression_algorithms_client_to_server(
server_kexinit.compression_algorithms_client_to_server(),
)
.compression_algorithms_server_to_client(
server_kexinit.compression_algorithms_server_to_client(),
)
.languages_client_to_server(server_kexinit.languages_client_to_server())
.languages_server_to_client(server_kexinit.languages_server_to_client())
.first_kex_packet_follows(server_kexinit.first_kex_packet_follows())
.build()?;
println!(
"\nClient KEXINIT built, size: {} bytes",
client_kexinit.len()
);
// Store the client KEXINIT for later hash calculation
state.client_kexinit = Some(client_kexinit.clone());
// Wrap the KEXINIT payload in an SSH packet using Packet::from_payload
let kexinit_packet = Packet::from_payload(&client_kexinit, 8, 0);
println!(
"KEXINIT packet: length={}, padding={}",
kexinit_packet.packet_length()?,
kexinit_packet.padding_length()?
);
// Send the packet
let packet_data = kexinit_packet.as_ref();
println!(
"\nSending client KEXINIT packet, total size: {} bytes",
packet_data.len()
);
println!(
"First 100 bytes: {:02x?}",
&packet_data[..packet_data.len().min(100)]
);
stream.write_all(packet_data)?;
stream.flush()?;
println!("Client KEXINIT sent successfully!");
// Try to peek if there's data available
println!("Checking for immediate server response...");
std::thread::sleep(Duration::from_millis(200));
// Set a very short timeout to check for immediate response
stream.set_read_timeout(Some(Duration::from_millis(500)))?;
// Now send KEX_ECDH_INIT with our curve25519 public key
println!("\n=== Generating curve25519 keypair ===");
let kex = curve25519_sha256::Initiator::new()?;
let public_key_bytes = kex.public_key();
println!("Generated public key: {} bytes", public_key_bytes.len());
println!("Public key bytes: {:02x?}", public_key_bytes);
// Build KEX_ECDH_INIT message using the builder API
println!("\n=== Building KEX_ECDH_INIT packet ===");
let kex_init_payload = KexEcdhInitBuilder::new().q_c(public_key_bytes).build()?;
println!("Payload size: {} bytes", kex_init_payload.len());
// Store the kex for later use in KEX_ECDH_REPLY
state.kex = Some(kex);
// Wrap in SSH packet using Packet::from_payload
let kex_packet = Packet::from_payload(&kex_init_payload, 8, 0);
println!(
"KEX_ECDH_INIT packet: length={}, padding={}",
kex_packet.packet_length()?,
kex_packet.padding_length()?
);
// Send KEX_ECDH_INIT packet
let kex_packet_data = kex_packet.as_ref();
println!(
"\nSending KEX_ECDH_INIT packet, total size: {} bytes",
kex_packet_data.len()
);
println!(
"First 50 bytes: {:02x?}",
&kex_packet_data[..kex_packet_data.len().min(50)]
);
stream.write_all(kex_packet_data)?;
stream.flush()?;
println!("KEX_ECDH_INIT sent successfully!");
// Reset timeout for reading server response
stream.set_read_timeout(Some(Duration::from_secs(5)))?;
}
SSH_MSG_KEX_ECDH_REPLY => {
// Parse the server's KEX_ECDH_REPLY using our KexEcdhReply struct
let kex_ecdh_reply = KexEcdhReply::new(payload)?;
println!("\n=== Received KEX_ECDH_REPLY ===");
println!(
"Server's public host key (K_S): {} bytes",
kex_ecdh_reply.k_s().len()
);
println!(
"Server's ephemeral public key (Q_S): {} bytes",
kex_ecdh_reply.q_s().len()
);
println!("Signature: {} bytes", kex_ecdh_reply.signature().len());
dbg!(&kex_ecdh_reply);
// Calculate shared secret K using ECDH
println!("\n=== Calculating shared secret K ===");
let kex = state.kex.take().expect("Kex not found");
// Save client public key before consuming kex
let client_public_key = kex.public_key().to_vec();
let shared_secret = kex.agree(kex_ecdh_reply.q_s())?;
println!(
"Shared secret K: {} bytes (mpint encoded)",
shared_secret.len()
);
// Compute exchange hash H according to RFC 4253 Section 8
println!("\n=== Computing exchange hash H ===");
let client_kexinit = state
.client_kexinit
.as_ref()
.expect("Client KEXINIT not found");
let server_kexinit = state
.server_kexinit
.as_ref()
.expect("Server KEXINIT not found");
// TODO: Having to encode these fields back with their length prefixes exposes
// the want to have a method on the message that returns the wire encoding
// before parsing, ie its length prefix + bytes.
// Encode each field as SSH byte-strings into temporary buffers
let mut v_c = vec![0u8; 4 + state.client_id.len()];
v_c.write_byte_string(&state.client_id, &mut 0)?;
let mut v_s = vec![0u8; 4 + state.server_id.len()];
v_s.write_byte_string(&state.server_id, &mut 0)?;
let mut i_c = vec![0u8; 4 + client_kexinit.len()];
i_c.write_byte_string(client_kexinit, &mut 0)?;
let mut i_s = vec![0u8; 4 + server_kexinit.len()];
i_s.write_byte_string(server_kexinit, &mut 0)?;
let mut k_s = vec![0u8; 4 + kex_ecdh_reply.k_s().len()];
k_s.write_byte_string(kex_ecdh_reply.k_s(), &mut 0)?;
let mut q_c = vec![0u8; 4 + client_public_key.len()];
q_c.write_byte_string(&client_public_key, &mut 0)?;
let mut q_s = vec![0u8; 4 + kex_ecdh_reply.q_s().len()];
q_s.write_byte_string(kex_ecdh_reply.q_s(), &mut 0)?;
// Hash with SHA-256, passing each encoded field separately
let h = Sha256::exchange_hash(&v_c, &v_s, &i_c, &i_s, &k_s, &q_c, &q_s, &shared_secret);
let h_bytes = &h;
println!("Exchange hash H: {} bytes", h_bytes.len());
println!("H = {:02x?}", h_bytes);
// Derive encryption keys using SSH KDF (RFC 4253 Section 7.2)
println!("\n=== Deriving encryption keys ===");
// Initial IV client to server: HASH(K || H || "A" || session_id)
// Initial IV server to client: HASH(K || H || "B" || session_id)
// Encryption key client to server: HASH(K || H || "C" || session_id)
// Encryption key server to client: HASH(K || H || "D" || session_id)
// Integrity key client to server: HASH(K || H || "E" || session_id)
// Integrity key server to client: HASH(K || H || "F" || session_id)
// For ChaCha20-Poly1305, we need 64 bytes total (32 for main key + 32 for
// poly1305 key) We'll derive key material for server to client
// direction
let session_id = h_bytes; // First H becomes the session_id
// Use ssh_kdf to derive the encryption key with our new Sha256 type
let mut key_s2c = [0u8; 64];
kdf::ssh_kdf::<Sha256, 32>(
&shared_secret,
h_bytes,
kdf::key_type::ENCRYPTION_KEY_SERVER_TO_CLIENT,
session_id,
&mut key_s2c,
);
println!("Encryption key (server to client): {} bytes", key_s2c.len());
// Create ChaCha20-Poly1305 cipher with the derived key material
let mut key_material = [0u8; 64];
key_material.copy_from_slice(&key_s2c);
state.cipher = Some(ChaCha20Poly1305::new(key_material));
println!("ChaCha20-Poly1305 cipher initialized!");
// Send SSH_MSG_NEWKEYS
println!("\n=== Sending SSH_MSG_NEWKEYS ===");
let newkeys_payload = vec![SSH_MSG_NEWKEYS];
let newkeys_packet = Packet::from_payload(&newkeys_payload, 8, 0);
stream.write_all(newkeys_packet.as_ref())?;
stream.flush()?;
println!("SSH_MSG_NEWKEYS sent!");
// The server should send NEWKEYS immediately, so don't wait
stream.set_read_timeout(Some(Duration::from_millis(100)))?;
}
SSH_MSG_NEWKEYS => {
println!("\n=== Received SSH_MSG_NEWKEYS ===");
println!("Server has switched to encrypted mode!");
println!("Activating ChaCha20-Poly1305 encryption for all subsequent packets");
// Now activate encryption - all subsequent packets will be encrypted
state.encryption_active = true;
}
_ => {
println!("Received unhandled message type: {}", msg_type);
println!("Packet details:");
dbg!(&packet);
}
}
Ok(())
}