use alloc::string::{String, ToString};
use alloc::vec::Vec;
use crate::error::{Error, Result};
use crate::format::{NameList, Reader, Writer};
/// Message number of `SSH_MSG_KEXINIT`.
///
/// [`KexInit::encode`] writes it as the first payload byte, and [`KexInit::decode`] rejects
/// any payload whose first byte differs.
pub const SSH_MSG_KEXINIT: u8 = 20;
/// Message number of `SSH_MSG_NEWKEYS`.
pub const SSH_MSG_NEWKEYS: u8 = 21;
/// An `SSH_MSG_KEXINIT` message: cookie, ten algorithm name-lists and the
/// `first_kex_packet_follows` flag.
///
/// The fields are declared in the same order that [`KexInit::encode`] writes them and
/// [`KexInit::decode`] reads them. The trailing reserved `u32` is not stored: it is written
/// as `0` and its value is ignored when decoding.
///
/// Each list is ordered. For the client's lists, [`negotiate`] picks the first entry that
/// the server also offers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KexInit {
    /// Cookie bytes, copied verbatim to/from the wire directly after the message number.
    pub cookie: [u8; 16],
    /// Key-exchange algorithm names.
    pub kex: Vec<String>,
    /// Host-key algorithm names.
    pub server_host_key: Vec<String>,
    /// Cipher names for the client-to-server direction.
    pub ciphers_c2s: Vec<String>,
    /// Cipher names for the server-to-client direction.
    pub ciphers_s2c: Vec<String>,
    /// MAC names for the client-to-server direction.
    pub macs_c2s: Vec<String>,
    /// MAC names for the server-to-client direction.
    pub macs_s2c: Vec<String>,
    /// Compression names for the client-to-server direction.
    pub comp_c2s: Vec<String>,
    /// Compression names for the server-to-client direction.
    pub comp_s2c: Vec<String>,
    /// Language tags, client-to-server; carried on the wire but not used by [`negotiate`].
    pub lang_c2s: Vec<String>,
    /// Language tags, server-to-client; carried on the wire but not used by [`negotiate`].
    pub lang_s2c: Vec<String>,
    /// The `first_kex_packet_follows` flag, encoded as a single boolean; [`negotiate`]
    /// consults it when deciding whether a guess was wrong.
    pub first_kex_packet_follows: bool,
}
impl KexInit {
pub fn from_algorithms(algs: &super::KexAlgorithms<'_>, cookie: [u8; 16]) -> Self {
Self {
cookie,
kex: collect(algs.kex),
server_host_key: collect(algs.server_host_key),
ciphers_c2s: collect(algs.ciphers_c2s),
ciphers_s2c: collect(algs.ciphers_s2c),
macs_c2s: collect(algs.macs_c2s),
macs_s2c: collect(algs.macs_s2c),
comp_c2s: collect(algs.comp_c2s),
comp_s2c: collect(algs.comp_s2c),
lang_c2s: collect(algs.lang_c2s),
lang_s2c: collect(algs.lang_s2c),
first_kex_packet_follows: false,
}
}
pub fn encode(&self) -> Vec<u8> {
let mut w = Writer::with_capacity(256);
w.write_u8(SSH_MSG_KEXINIT);
w.write_raw(&self.cookie);
write_name_list(&mut w, &self.kex);
write_name_list(&mut w, &self.server_host_key);
write_name_list(&mut w, &self.ciphers_c2s);
write_name_list(&mut w, &self.ciphers_s2c);
write_name_list(&mut w, &self.macs_c2s);
write_name_list(&mut w, &self.macs_s2c);
write_name_list(&mut w, &self.comp_c2s);
write_name_list(&mut w, &self.comp_s2c);
write_name_list(&mut w, &self.lang_c2s);
write_name_list(&mut w, &self.lang_s2c);
w.write_bool(self.first_kex_packet_follows);
w.write_u32(0);
w.into_vec()
}
pub fn decode(payload: &[u8]) -> Result<Self> {
let mut r = Reader::new(payload);
let msg = r.read_u8()?;
if msg != SSH_MSG_KEXINIT {
return Err(Error::Protocol("expected SSH_MSG_KEXINIT"));
}
let cookie_slice = r.take(16)?;
let mut cookie = [0u8; 16];
cookie.copy_from_slice(cookie_slice);
let kex = read_name_list(&mut r)?;
let server_host_key = read_name_list(&mut r)?;
let ciphers_c2s = read_name_list(&mut r)?;
let ciphers_s2c = read_name_list(&mut r)?;
let macs_c2s = read_name_list(&mut r)?;
let macs_s2c = read_name_list(&mut r)?;
let comp_c2s = read_name_list(&mut r)?;
let comp_s2c = read_name_list(&mut r)?;
let lang_c2s = read_name_list(&mut r)?;
let lang_s2c = read_name_list(&mut r)?;
let first_kex_packet_follows = r.read_bool()?;
let _reserved = r.read_u32()?;
if !r.is_empty() {
return Err(Error::Format("KEXINIT trailing bytes"));
}
Ok(Self {
cookie,
kex,
server_host_key,
ciphers_c2s,
ciphers_s2c,
macs_c2s,
macs_s2c,
comp_c2s,
comp_s2c,
lang_c2s,
lang_s2c,
first_kex_packet_follows,
})
}
}
/// Copies borrowed algorithm names into owned `String`s, preserving their order.
fn collect(items: &[&str]) -> Vec<String> {
    items.iter().copied().map(String::from).collect()
}
/// Writes `names` as a single name-list string: the names joined with `,` and emitted through
/// `Writer::write_string`. An empty slice produces an empty string.
///
/// Names are written verbatim. A name that itself contained `,` would be indistinguishable
/// from two separate names on the wire, so callers must only pass comma-free names.
fn write_name_list(w: &mut Writer, names: &[String]) {
    // `join` produces exactly the bytes the manual comma-interleaving loop did.
    w.write_string(names.join(",").as_bytes());
}
fn read_name_list(r: &mut Reader<'_>) -> Result<Vec<String>> {
let nl = NameList::read(r)?;
let mut out = Vec::new();
for entry in nl.iter() {
let s =
core::str::from_utf8(entry).map_err(|_| Error::Format("non-UTF8 name in name-list"))?;
out.push(s.to_string());
}
Ok(out)
}
/// Outcome of [`negotiate`]: the algorithm chosen for each category, as owned strings.
#[derive(Debug, Clone)]
pub struct NegotiatedOwned {
    /// Negotiated key-exchange algorithm.
    pub kex: String,
    /// Negotiated host-key algorithm.
    pub host_key: String,
    /// Negotiated cipher, client-to-server.
    pub cipher_c2s: String,
    /// Negotiated cipher, server-to-client.
    pub cipher_s2c: String,
    /// Negotiated MAC, client-to-server. Empty string when `crate::cipher::by_name` reports
    /// `cipher_c2s` as AEAD (no MAC is negotiated in that case).
    pub mac_c2s: String,
    /// Negotiated MAC, server-to-client. Empty string when `crate::cipher::by_name` reports
    /// `cipher_s2c` as AEAD (no MAC is negotiated in that case).
    pub mac_s2c: String,
    /// Negotiated compression, client-to-server.
    pub comp_c2s: String,
    /// Negotiated compression, server-to-client.
    pub comp_s2c: String,
    /// `true` when at least one side set `first_kex_packet_follows` and the two sides' first
    /// kex entries or first host-key entries differ.
    // NOTE(review): this combines both sides' flags, so it does not say *which* side sent the
    // guessed packet; confirm callers only discard a packet when the peer's flag was set.
    pub first_kex_packet_follows_wrong_guess: bool,
}
fn pick<'a>(category: &'static str, client: &'a [String], server: &[String]) -> Result<&'a str> {
for name in client {
if server.iter().any(|s| s == name) {
return Ok(name.as_str());
}
}
Err(Error::NoCommonAlgorithm(category))
}
/// Negotiates the algorithm set from the client's and the server's `KexInit`.
///
/// Every category is resolved with `pick`, so the client's preference order wins. When
/// `crate::cipher::by_name` reports the negotiated cipher for a direction as AEAD, that
/// direction's MAC is not negotiated and is returned as the empty string. Language lists are
/// not consulted.
///
/// `first_kex_packet_follows_wrong_guess` is set when either side's
/// `first_kex_packet_follows` flag is set and the sides' first kex entries or first host-key
/// entries differ.
///
/// # Errors
///
/// Returns `Error::NoCommonAlgorithm` for the first category without a common entry, checked
/// in this order: kex, host key, cipher c2s, cipher s2c, MAC c2s, MAC s2c, compression c2s,
/// compression s2c.
pub fn negotiate(client: &KexInit, server: &KexInit) -> Result<NegotiatedOwned> {
    let is_aead = |name: &str| crate::cipher::by_name(name).map_or(false, |spec| spec.aead);

    let kex = pick("kex algorithm", &client.kex, &server.kex)?;
    let host_key = pick(
        "host-key algorithm",
        &client.server_host_key,
        &server.server_host_key,
    )?;
    let cipher_c2s = pick("cipher c2s", &client.ciphers_c2s, &server.ciphers_c2s)?;
    let cipher_s2c = pick("cipher s2c", &client.ciphers_s2c, &server.ciphers_s2c)?;

    // AEAD ciphers carry their own authentication, so no separate MAC is chosen for them.
    let aead_c2s = is_aead(cipher_c2s);
    let aead_s2c = is_aead(cipher_s2c);
    let mac_c2s = if !aead_c2s {
        pick("mac c2s", &client.macs_c2s, &server.macs_c2s)?
    } else {
        ""
    };
    let mac_s2c = if !aead_s2c {
        pick("mac s2c", &client.macs_s2c, &server.macs_s2c)?
    } else {
        ""
    };

    let comp_c2s = pick("compression c2s", &client.comp_c2s, &server.comp_c2s)?;
    let comp_s2c = pick("compression s2c", &client.comp_s2c, &server.comp_s2c)?;

    // A guess is only "wrong" if somebody guessed and the preferred kex or host key differ.
    let someone_guessed = client.first_kex_packet_follows || server.first_kex_packet_follows;
    let preferences_differ = client.kex.first() != server.kex.first()
        || client.server_host_key.first() != server.server_host_key.first();

    Ok(NegotiatedOwned {
        kex: kex.into(),
        host_key: host_key.into(),
        cipher_c2s: cipher_c2s.into(),
        cipher_s2c: cipher_s2c.into(),
        mac_c2s: mac_c2s.into(),
        mac_s2c: mac_s2c.into(),
        comp_c2s: comp_c2s.into(),
        comp_s2c: comp_s2c.into(),
        first_kex_packet_follows_wrong_guess: someone_guessed && preferences_differ,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::kex::{defaults, KexAlgorithms};

    /// Algorithm preferences taken from the crate defaults, with empty language lists.
    fn defaults_algorithms() -> KexAlgorithms<'static> {
        KexAlgorithms {
            kex: defaults::KEX,
            server_host_key: defaults::HOST_KEY,
            ciphers_c2s: defaults::CIPHERS,
            ciphers_s2c: defaults::CIPHERS,
            macs_c2s: defaults::MACS,
            macs_s2c: defaults::MACS,
            comp_c2s: defaults::COMP,
            comp_s2c: defaults::COMP,
            lang_c2s: &[],
            lang_s2c: &[],
        }
    }

    /// A `KexInit` built from `defaults_algorithms` with an all-zero cookie.
    fn default_kexinit() -> KexInit {
        KexInit::from_algorithms(&defaults_algorithms(), [0u8; 16])
    }

    #[test]
    fn round_trip_encode_decode() {
        let algs = defaults_algorithms();
        let mut cookie = [0u8; 16];
        for (i, b) in cookie.iter_mut().enumerate() {
            *b = i as u8;
        }
        let ki = KexInit::from_algorithms(&algs, cookie);
        let bytes = ki.encode();
        assert_eq!(bytes[0], SSH_MSG_KEXINIT);
        assert_eq!(&bytes[1..17], &cookie[..]);
        let decoded = KexInit::decode(&bytes).unwrap();
        assert_eq!(decoded, ki);
    }

    #[test]
    fn round_trip_preserves_first_kex_packet_follows() {
        let mut ki = default_kexinit();
        ki.first_kex_packet_follows = true;
        let decoded = KexInit::decode(&ki.encode()).unwrap();
        assert!(decoded.first_kex_packet_follows);
        assert_eq!(decoded, ki);
    }

    #[test]
    fn decode_rejects_wrong_message_type() {
        // Cookie, ten zero-length name-lists, flag and reserved word behind the NEWKEYS number.
        let mut buf = Vec::from([SSH_MSG_NEWKEYS]);
        buf.extend_from_slice(&[0u8; 16]);
        for _ in 0..10 {
            buf.extend_from_slice(&[0, 0, 0, 0]);
        }
        buf.push(0);
        buf.extend_from_slice(&[0, 0, 0, 0]);
        assert!(matches!(KexInit::decode(&buf), Err(Error::Protocol(_))));
    }

    #[test]
    fn decode_rejects_trailing_bytes() {
        let mut bytes = default_kexinit().encode();
        bytes.push(0xff);
        assert!(matches!(KexInit::decode(&bytes), Err(Error::Format(_))));
    }

    #[test]
    fn decode_rejects_truncated_payload() {
        let bytes = default_kexinit().encode();
        assert!(KexInit::decode(&bytes[..bytes.len() - 1]).is_err());
        assert!(KexInit::decode(&[]).is_err());
    }

    #[test]
    fn negotiate_picks_clients_first_common() {
        let mut a = default_kexinit();
        a.kex = collect(&["curve25519-sha256", "ecdh-sha2-nistp256"]);
        let mut b = a.clone();
        b.kex = collect(&["ecdh-sha2-nistp256", "curve25519-sha256"]);
        let neg = negotiate(&a, &b).unwrap();
        assert_eq!(neg.kex, "curve25519-sha256");
    }

    #[test]
    fn negotiate_rejects_no_overlap() {
        let mut a = default_kexinit();
        let mut b = a.clone();
        a.kex = vec!["curve25519-sha256".into()];
        b.kex = vec!["diffie-hellman-group14-sha256".into()];
        assert!(matches!(
            negotiate(&a, &b),
            Err(Error::NoCommonAlgorithm("kex algorithm"))
        ));
    }

    #[test]
    fn negotiate_skips_mac_for_aead() {
        let mut a = default_kexinit();
        let mut b = a.clone();
        a.ciphers_c2s = vec!["chacha20-poly1305@openssh.com".into()];
        a.ciphers_s2c = vec!["chacha20-poly1305@openssh.com".into()];
        b.ciphers_c2s = vec!["chacha20-poly1305@openssh.com".into()];
        b.ciphers_s2c = vec!["chacha20-poly1305@openssh.com".into()];
        a.macs_c2s.clear();
        a.macs_s2c.clear();
        b.macs_c2s.clear();
        b.macs_s2c.clear();
        let neg = negotiate(&a, &b).unwrap();
        assert_eq!(neg.cipher_c2s, "chacha20-poly1305@openssh.com");
        assert_eq!(neg.mac_c2s, "");
        assert_eq!(neg.mac_s2c, "");
    }

    #[test]
    fn negotiate_picks_mac_for_non_aead_cipher() {
        let mut a = default_kexinit();
        a.ciphers_c2s = vec!["aes128-ctr".into()];
        a.ciphers_s2c = vec!["aes128-ctr".into()];
        a.macs_c2s = collect(&["hmac-sha2-256", "hmac-sha2-512"]);
        a.macs_s2c = collect(&["hmac-sha2-512"]);
        let mut b = a.clone();
        b.macs_c2s = collect(&["hmac-sha2-512", "hmac-sha2-256"]);
        let neg = negotiate(&a, &b).unwrap();
        assert_eq!(neg.mac_c2s, "hmac-sha2-256");
        assert_eq!(neg.mac_s2c, "hmac-sha2-512");
    }

    #[test]
    fn negotiate_wrong_guess_requires_flag_and_differing_preferences() {
        let mut client = default_kexinit();
        client.kex = collect(&["curve25519-sha256", "ecdh-sha2-nistp256"]);
        let mut server = client.clone();
        server.kex = collect(&["ecdh-sha2-nistp256", "curve25519-sha256"]);

        // Preferences differ, but nobody guessed.
        assert!(!negotiate(&client, &server).unwrap().first_kex_packet_follows_wrong_guess);

        // Preferences differ and the client guessed.
        client.first_kex_packet_follows = true;
        assert!(negotiate(&client, &server).unwrap().first_kex_packet_follows_wrong_guess);

        // Same preferred kex and host key: the guess is right.
        server.kex = client.kex.clone();
        assert!(!negotiate(&client, &server).unwrap().first_kex_packet_follows_wrong_guess);

        // Same preferred kex but different preferred host key: the guess is wrong again.
        client.server_host_key = collect(&["rsa-sha2-256", "ssh-ed25519"]);
        server.server_host_key = collect(&["ssh-ed25519", "rsa-sha2-256"]);
        assert!(negotiate(&client, &server).unwrap().first_kex_packet_follows_wrong_guess);
    }
}