use core::fmt;
use std::cell::RefCell;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use bytes::Bytes;
use log::{debug, error, warn};
use ssh_encoding::{Decode, Encode};
use ssh_key::{Mpint, PublicKey, Signature};
use super::IncomingSshPacket;
use crate::client::{Config, NewKeys};
use crate::kex::dh::groups::DhGroup;
use crate::kex::{KEXES, KexAlgorithm, KexAlgorithmImplementor, KexCause, KexProgress};
use crate::keys::key::parse_public_key;
use crate::negotiation::{Names, Select};
use crate::session::Exchange;
use crate::sshbuffer::PacketWriter;
use crate::{CryptoVec, Error, SshId, msg, negotiation, strict_kex_violation};
thread_local! {
static HASH_BUFFER: RefCell<CryptoVec> = RefCell::new(CryptoVec::new());
}
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum ClientKexState {
Created,
WaitingForGexReply {
names: Names,
kex: KexAlgorithm,
},
WaitingForDhReply {
names: Names,
kex: KexAlgorithm,
},
WaitingForNewKeys {
server_host_key: PublicKey,
newkeys: NewKeys,
},
}
pub(crate) struct ClientKex {
exchange: Exchange,
cause: KexCause,
state: ClientKexState,
config: Arc<Config>,
}
impl Debug for ClientKex {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let mut s = f.debug_struct("ClientKex");
s.field("cause", &self.cause);
match self.state {
ClientKexState::Created => {
s.field("state", &"created");
}
ClientKexState::WaitingForGexReply { .. } => {
s.field("state", &"waiting for GEX response");
}
ClientKexState::WaitingForDhReply { .. } => {
s.field("state", &"waiting for DH response");
}
ClientKexState::WaitingForNewKeys { .. } => {
s.field("state", &"waiting for NEWKEYS");
}
}
s.finish()
}
}
impl ClientKex {
pub fn new(
config: Arc<Config>,
client_sshid: &SshId,
server_sshid: &[u8],
cause: KexCause,
) -> Self {
let exchange = Exchange::new(client_sshid.as_kex_hash_bytes(), server_sshid);
Self {
config,
exchange,
cause,
state: ClientKexState::Created,
}
}
pub fn kexinit(&mut self, output: &mut PacketWriter) -> Result<(), Error> {
self.exchange.client_kex_init =
negotiation::write_kex(&self.config.preferred, output, None)?;
Ok(())
}
pub fn step(
mut self,
input: Option<&mut IncomingSshPacket>,
output: &mut PacketWriter,
) -> Result<KexProgress<Self>, Error> {
match self.state {
ClientKexState::Created => {
let Some(input) = input else {
return Err(Error::KexInit);
};
if input.buffer.first() != Some(&msg::KEXINIT) {
error!(
"Unexpected kex message at this stage: {:?}",
input.buffer.first()
);
return Err(Error::KexInit);
}
let names = {
self.exchange
.server_kex_init
.extend_from_slice(&input.buffer);
negotiation::Client::read_kex(
&input.buffer,
&self.config.preferred,
None,
&self.cause,
)?
};
debug!("negotiated algorithms: {names:?}");
if names.strict_kex() && !self.cause.is_rekey() && input.seqn.0 != 1 {
return Err(strict_kex_violation(
msg::KEXINIT,
input.seqn.0 as usize - 1,
));
}
let mut kex = KEXES.get(&names.kex).ok_or(Error::UnknownAlgo)?.make();
if kex.skip_exchange() {
let newkeys = compute_keys(
Vec::new(),
kex,
names.clone(),
self.exchange.clone(),
self.cause.session_id(),
)?;
output.packet(|w| {
msg::NEWKEYS.encode(w)?;
Ok(())
})?;
return Ok(KexProgress::Done {
newkeys,
server_host_key: None,
});
}
if kex.is_dh_gex() {
output.packet(|w| {
kex.client_dh_gex_init(&self.config.gex, w)?;
Ok(())
})?;
self.state = ClientKexState::WaitingForGexReply { names, kex };
} else {
output.packet(|w| {
kex.client_dh(&mut self.exchange.client_ephemeral, w)?;
Ok(())
})?;
self.state = ClientKexState::WaitingForDhReply { names, kex };
}
Ok(KexProgress::NeedsReply {
kex: self,
reset_seqn: false,
})
}
ClientKexState::WaitingForGexReply { names, mut kex } => {
let Some(input) = input else {
return Err(Error::KexInit);
};
if input.buffer.first() != Some(&msg::KEX_DH_GEX_GROUP) {
error!(
"Unexpected kex message at this stage: {:?}",
input.buffer.first()
);
return Err(Error::KexInit);
}
#[allow(clippy::indexing_slicing)] let mut r = &input.buffer[1..];
let prime = Mpint::decode(&mut r)?;
let generator = Mpint::decode(&mut r)?;
debug!("received gex group: prime={prime}, generator={generator}");
let group = DhGroup {
prime: prime.as_bytes().to_vec().into(),
generator: generator.as_bytes().to_vec().into(),
};
if group.bit_size() < self.config.gex.min_group_size
|| group.bit_size() > self.config.gex.max_group_size
{
warn!(
"DH prime size ({} bits) not within requested range",
group.bit_size()
);
return Err(Error::KexInit);
}
let exchange = &mut self.exchange;
exchange.gex = Some((self.config.gex.clone(), group.clone()));
kex.dh_gex_set_group(group)?;
output.packet(|w| {
kex.client_dh(&mut exchange.client_ephemeral, w)?;
Ok(())
})?;
self.state = ClientKexState::WaitingForDhReply { names, kex };
Ok(KexProgress::NeedsReply {
kex: self,
reset_seqn: false,
})
}
ClientKexState::WaitingForDhReply { mut names, mut kex } => {
let Some(input) = input else {
return Err(Error::KexInit);
};
if names.ignore_guessed {
debug!("ignoring guessed kex");
names.ignore_guessed = false;
self.state = ClientKexState::WaitingForDhReply { names, kex };
return Ok(KexProgress::NeedsReply {
kex: self,
reset_seqn: false,
});
}
if input.buffer.first()
!= Some(match kex.is_dh_gex() {
true => &msg::KEX_DH_GEX_REPLY,
false => &msg::KEX_ECDH_REPLY,
})
{
error!(
"Unexpected kex message at this stage: {:?}",
input.buffer.first()
);
return Err(Error::KexInit);
}
#[allow(clippy::indexing_slicing)] let r = &mut &input.buffer[1..];
let server_host_key = Bytes::decode(r)?; let server_host_key = parse_public_key(&server_host_key)?;
debug!(
"received server host key: {:?}",
server_host_key.to_openssh()
);
let server_ephemeral = Bytes::decode(r)?;
self.exchange
.server_ephemeral
.extend_from_slice(&server_ephemeral);
kex.compute_shared_secret(&self.exchange.server_ephemeral)?;
let mut pubkey_vec = Vec::new();
server_host_key.to_bytes()?.encode(&mut pubkey_vec)?;
let exchange = &self.exchange;
let hash = HASH_BUFFER.with({
|buffer| {
let mut buffer = buffer.borrow_mut();
buffer.clear();
kex.compute_exchange_hash(&pubkey_vec, exchange, &mut buffer)
}
})?;
let signature = Bytes::decode(r)?;
let signature = Signature::decode(&mut &signature[..])?;
if let Err(e) =
signature::Verifier::verify(&server_host_key, hash.as_ref(), &signature)
{
debug!("wrong server sig: {e:?}");
return Err(Error::WrongServerSig);
}
let newkeys = compute_keys(
hash,
kex,
names.clone(),
self.exchange.clone(),
self.cause.session_id(),
)?;
output.packet(|w| {
msg::NEWKEYS.encode(w)?;
Ok(())
})?;
let reset_seqn = newkeys.names.strict_kex() || self.cause.is_strict_rekey();
self.state = ClientKexState::WaitingForNewKeys {
server_host_key,
newkeys,
};
Ok(KexProgress::NeedsReply {
kex: self,
reset_seqn,
})
}
ClientKexState::WaitingForNewKeys {
server_host_key,
newkeys,
} => {
let Some(input) = input else {
return Err(Error::KexInit);
};
if input.buffer.first() != Some(&msg::NEWKEYS) {
error!(
"Unexpected kex message at this stage: {:?}",
input.buffer.first()
);
return Err(Error::Kex);
}
Ok(KexProgress::Done {
newkeys,
server_host_key: Some(server_host_key),
})
}
}
}
}
fn compute_keys(
hash: Vec<u8>,
kex: KexAlgorithm,
names: Names,
exchange: Exchange,
session_id: Option<&CryptoVec>,
) -> Result<NewKeys, Error> {
let session_id_ref: &[u8] = match session_id {
Some(sid) => sid,
None => &hash,
};
let c = kex.compute_keys(
session_id_ref,
&hash,
names.cipher,
names.server_mac,
names.client_mac,
false,
)?;
let session_id_cv = match session_id {
Some(s) => s.clone(),
None => {
let mut cv = CryptoVec::new();
cv.extend(&hash);
cv
}
};
Ok(NewKeys {
exchange,
names,
kex,
key: 0,
cipher: c,
session_id: session_id_cv,
})
}