use std::io::Read as _;
use std::io::Write as _;
use std::net::TcpStream;
use std::time::Duration;
use bstr::ByteSlice;
use bytes::BytesMut;
use metalssh::constants::msg::SSH_MSG_KEXINIT;
use metalssh::msg::Kexinit;
use metalssh::wire::Packet;
use metalssh::wire::SshDecode;
/// Per-connection state the client accumulates during the SSH handshake.
struct ClientState {
    /// Server identification string, with surrounding whitespace (including
    /// the CR LF terminator) stripped. Empty until received.
    server_id: Vec<u8>,
    /// Identification string this client announces, without a CR LF terminator.
    client_id: Vec<u8>,
    /// Raw payload of the server's SSH_MSG_KEXINIT, once received.
    server_kexinit: Option<Vec<u8>>,
    /// Raw payload of the SSH_MSG_KEXINIT built by this client, once built.
    client_kexinit: Option<Vec<u8>>,
}

impl ClientState {
    /// Identification string announced by this client (RFC 4253 §4.2).
    const CLIENT_ID: &'static [u8] = b"SSH-2.0-MetalSSH_0.0.0";

    /// Creates the state for a fresh connection.
    ///
    /// Nothing has been exchanged yet: the server identification is empty,
    /// neither KEXINIT is present, and the client identification is preset
    /// to [`Self::CLIENT_ID`].
    fn new() -> Self {
        let client_id = Self::CLIENT_ID.to_vec();
        Self {
            client_id,
            server_id: Vec::default(),
            client_kexinit: None,
            server_kexinit: None,
        }
    }
}
/// Connects to a live SSH server and performs the identification-string and
/// KEXINIT exchanges against it.
///
/// The server address (`host:port`) is read from the `METALSSH_TEST_SERVER`
/// environment variable and defaults to `10.0.0.54:22`. The test needs a
/// reachable server, so it is ignored by default. Run it with
/// `cargo test -- --ignored`.
#[test]
#[ignore = "requires a reachable SSH server (set METALSSH_TEST_SERVER)"]
fn real_client() -> anyhow::Result<()> {
    // Bounds connect, read and write so that an unreachable or silent server
    // fails the test instead of hanging it.
    const TIMEOUT: Duration = Duration::from_secs(5);

    let addr =
        std::env::var("METALSSH_TEST_SERVER").unwrap_or_else(|_| "10.0.0.54:22".to_owned());
    // Only the first resolved address is tried.
    let socket_addr = std::net::ToSocketAddrs::to_socket_addrs(addr.as_str())?
        .next()
        .ok_or_else(|| anyhow::anyhow!("{addr} did not resolve to any socket address"))?;

    let mut stream = TcpStream::connect_timeout(&socket_addr, TIMEOUT)?;
    stream.set_read_timeout(Some(TIMEOUT))?;
    stream.set_write_timeout(Some(TIMEOUT))?;

    let mut state = ClientState::new();
    let mut recv_buffer = BytesMut::with_capacity(35000);
    exchange_banner(&mut stream, &mut state, &mut recv_buffer)?;
    exchange_kexinit(&mut stream, &mut state, &mut recv_buffer)?;
    Ok(())
}
/// Performs the identification-string exchange (RFC 4253 §4.2).
///
/// Reads from `stream` until a complete line starting with `SSH-` has arrived
/// and stores it in `state.server_id`, trimmed of surrounding ASCII whitespace
/// (including the CR LF terminator). Any lines the server sends before its
/// identification string are skipped, as the RFC permits. Then sends
/// `state.client_id` followed by CR LF.
///
/// `buffer` is used as scratch space for the received bytes. Bytes that
/// arrive after the identification line are not reported to the caller.
///
/// # Errors
///
/// Returns an error if an I/O operation fails, if the server closes the
/// connection before sending its identification string, or if the buffer
/// fills up without one.
fn exchange_banner(
    stream: &mut TcpStream,
    state: &mut ClientState,
    buffer: &mut BytesMut,
) -> anyhow::Result<()> {
    buffer.resize(35000, 0);
    let mut filled = 0;

    // A single `read` may return only part of the line, so keep reading until
    // a complete identification line is present.
    let server_id = loop {
        if let Some(line) = find_identification_line(&buffer[..filled]) {
            break line.to_vec();
        }
        if filled == buffer.len() {
            anyhow::bail!(
                "no SSH identification string found in the first {filled} bytes from the server"
            );
        }
        let n = stream.read(&mut buffer[filled..])?;
        if n == 0 {
            anyhow::bail!("server closed the connection before sending its identification string");
        }
        filled += n;
    };
    state.server_id = server_id;

    stream.write_all(&state.client_id)?;
    stream.write_all(b"\r\n")?;
    stream.flush()?;
    Ok(())
}

/// Returns the first complete (`\n`-terminated) line of `data` that starts
/// with `SSH-`, trimmed of surrounding ASCII whitespace.
///
/// Returns `None` while no such complete line is present. A trailing partial
/// line is ignored.
fn find_identification_line(data: &[u8]) -> Option<&[u8]> {
    let last_newline = data.iter().rposition(|&b| b == b'\n')?;
    data[..last_newline]
        .split(|&b| b == b'\n')
        .map(|line| line.trim_ascii())
        .find(|line| line.starts_with(b"SSH-"))
}
fn exchange_kexinit(
stream: &mut TcpStream,
state: &mut ClientState,
buffer: &mut BytesMut,
) -> anyhow::Result<()> {
buffer.clear();
buffer.resize(35000, 0);
let n = stream.read(buffer)?;
let packet = Packet::new(&buffer[..n], 0);
let payload = packet.payload()?;
if payload[0] != SSH_MSG_KEXINIT {
anyhow::bail!("Expected KEXINIT, got message type {}", payload[0]);
}
state.server_kexinit = Some(payload.to_vec());
let server_kexinit = Kexinit::from_bytes(payload)?;
let server_kex_algs = server_kexinit.kex_algorithms;
let mut kex_algs_vec: Vec<&[u8]> = server_kex_algs.split_str(",").collect();
let prioritized = [
b"curve25519-sha256".as_slice(),
b"curve25519-sha256@libssh.org".as_slice(),
];
for priority_alg in prioritized.iter().rev() {
if let Some(pos) = kex_algs_vec.iter().position(|&alg| alg == *priority_alg) {
let alg = kex_algs_vec.remove(pos);
kex_algs_vec.insert(0, alg);
}
}
for alg in kex_algs_vec.iter_mut() {
if *alg == b"kex-strict-s-v00@openssh.com" {
*alg = b"kex-strict-c-v00@openssh.com";
} else if *alg == b"ext-info-s" {
*alg = b"ext-info-c";
}
}
let mut kex_algs_string = Vec::new();
for (i, alg) in kex_algs_vec.iter().enumerate() {
if i > 0 {
kex_algs_string.push(b',');
}
kex_algs_string.extend_from_slice(alg);
}
let client_cookie = b"\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42\x42";
let mut client_kexinit_buffer = vec![0u8; 2048];
let (client_kexinit, size) = Kexinit::write(
&mut client_kexinit_buffer,
client_cookie,
&kex_algs_string,
server_kexinit.server_host_key_algorithms,
server_kexinit.encryption_algorithms_client_to_server,
server_kexinit.encryption_algorithms_server_to_client,
server_kexinit.mac_algorithms_client_to_server,
server_kexinit.mac_algorithms_server_to_client,
server_kexinit.compression_algorithms_client_to_server,
server_kexinit.compression_algorithms_server_to_client,
server_kexinit.languages_client_to_server,
server_kexinit.languages_server_to_client,
server_kexinit.first_kex_packet_follows,
0,
)?;
state.client_kexinit = Some(client_kexinit.raw()[..size].to_vec());
let kexinit_packet = Packet::from_payload(&client_kexinit.raw()[..size], 8, 0);
stream.write_all(kexinit_packet.as_ref())?;
stream.flush()?;
Ok(())
}