use alloc::vec::Vec;
use purecrypto::ec::boxed::{BoxedEcdhPrivateKey, BoxedEcdsaPublicKey};
use purecrypto::ec::curves::CurveId;
use purecrypto::hash::{Digest, Sha256, Sha384, Sha512};
use purecrypto::rng::{CryptoRng, RngCore};
use super::common::{
KexContext, KexInitOut, KexOutput, SSH_MSG_KEX_ECDH_INIT, SSH_MSG_KEX_ECDH_REPLY,
};
use super::hash::{mpint_bytes, ExchangeHash};
use super::Kex;
use crate::error::{Error, Result};
use crate::format::Reader;
use crate::hostkey::HostKeyVerify;
/// `ecdh-sha2-nistp256`: ECDH over NIST P-256 with a SHA-256 exchange hash.
///
/// Zero-sized marker type; the exchange operations are provided as inherent
/// associated functions generated by `ecdh_impl!`.
#[derive(Debug, Clone, Copy, Default)]
pub struct EcdhSha2Nistp256;
impl Kex for EcdhSha2Nistp256 {
    const NAME: &'static str = "ecdh-sha2-nistp256";
    const HASH_LEN: usize = 32;
}
/// `ecdh-sha2-nistp384`: ECDH over NIST P-384 with a SHA-384 exchange hash.
///
/// Zero-sized marker type; the exchange operations are provided as inherent
/// associated functions generated by `ecdh_impl!`.
#[derive(Debug, Clone, Copy, Default)]
pub struct EcdhSha2Nistp384;
impl Kex for EcdhSha2Nistp384 {
    const NAME: &'static str = "ecdh-sha2-nistp384";
    const HASH_LEN: usize = 48;
}
/// `ecdh-sha2-nistp521`: ECDH over NIST P-521 with a SHA-512 exchange hash.
///
/// Zero-sized marker type; the exchange operations are provided as inherent
/// associated functions generated by `ecdh_impl!`.
#[derive(Debug, Clone, Copy, Default)]
pub struct EcdhSha2Nistp521;
impl Kex for EcdhSha2Nistp521 {
    const NAME: &'static str = "ecdh-sha2-nistp521";
    const HASH_LEN: usize = 64;
}
/// Client-side state kept between sending `SSH_MSG_KEX_ECDH_INIT` and
/// processing the server's `SSH_MSG_KEX_ECDH_REPLY`.
pub struct ClientState {
    /// Curve the ephemeral key was generated on; `client_finish` rejects the
    /// state if it does not match the algorithm being finished.
    curve: CurveId,
    /// Ephemeral ECDH private key used to derive the shared secret.
    secret: BoxedEcdhPrivateKey,
    /// SEC1 encoding of the client's ephemeral public key (Q_C), kept because
    /// it is an input to the exchange hash.
    q_c: Vec<u8>,
}
/// Output of the server side of the ECDH exchange.
pub struct ServerReplyOut {
    /// Encoded `SSH_MSG_KEX_ECDH_REPLY` (message byte, K_S, Q_S, signature)
    /// to send to the client.
    pub payload: Vec<u8>,
    /// Shared secret K (mpint-encoded) and exchange hash H.
    pub kex: KexOutput,
}
/// Size in bytes of one field element (one affine coordinate) on `curve`.
///
/// P-521's 521-bit field rounds up to 66 bytes.
fn field_len(curve: CurveId) -> usize {
    let bytes = match curve {
        CurveId::P256 | CurveId::Secp256k1 => 32,
        CurveId::P384 => 48,
        CurveId::P521 => 66,
    };
    bytes
}
/// Encoded length of an uncompressed SEC1 point on `curve`: one leading
/// format byte followed by the X and Y coordinates.
fn sec1_point_len(curve: CurveId) -> usize {
    let coordinate = field_len(curve);
    coordinate * 2 + 1
}
/// Client half, step one: generate an ephemeral key pair on `curve` and
/// encode `SSH_MSG_KEX_ECDH_INIT` carrying the SEC1 public key Q_C as an SSH
/// string (u32 big-endian length followed by the bytes).
///
/// Returns the state needed to finish the exchange together with the payload
/// to send to the server.
fn client_init<R: RngCore + CryptoRng>(curve: CurveId, rng: &mut R) -> (ClientState, KexInitOut) {
    let secret = BoxedEcdhPrivateKey::generate(curve, rng);
    let q_c = secret.public_key().to_sec1();

    let len_prefix = (q_c.len() as u32).to_be_bytes();
    let mut payload = Vec::with_capacity(1 + len_prefix.len() + q_c.len());
    payload.push(SSH_MSG_KEX_ECDH_INIT);
    payload.extend_from_slice(&len_prefix);
    payload.extend_from_slice(&q_c);

    let state = ClientState { curve, secret, q_c };
    let init = KexInitOut { payload };
    (state, init)
}
fn server_reply_inner<D, R, S>(
curve: CurveId,
rng: &mut R,
init_payload: &[u8],
host_key: &S,
ctx: &KexContext<'_>,
) -> Result<ServerReplyOut>
where
D: Digest,
R: RngCore + CryptoRng,
S: crate::hostkey::HostKey + ?Sized,
{
let mut r = Reader::new(init_payload);
let msg = r.read_u8()?;
if msg != SSH_MSG_KEX_ECDH_INIT {
return Err(Error::Protocol("expected SSH_MSG_KEX_ECDH_INIT"));
}
let q_c_bytes = r.read_string()?;
if q_c_bytes.len() != sec1_point_len(curve) {
return Err(Error::Format("ECDH Q_C wrong length"));
}
let peer = BoxedEcdsaPublicKey::from_sec1(curve, q_c_bytes)
.map_err(|_| Error::Format("invalid ECDH Q_C"))?;
let secret = BoxedEcdhPrivateKey::generate(curve, rng);
let q_s = secret.public_key().to_sec1();
let k_raw = secret
.diffie_hellman(&peer)
.map_err(|_| Error::Crypto("ECDH agreement failed"))?;
let k_s = host_key.public_blob();
let mut eh = ExchangeHash::<D>::new();
eh.write_string(ctx.v_c);
eh.write_string(ctx.v_s);
eh.write_string(ctx.i_c);
eh.write_string(ctx.i_s);
eh.write_string(&k_s);
eh.write_string(q_c_bytes);
eh.write_string(&q_s);
eh.write_mpint(&k_raw);
let h = eh.finalize();
let sig = host_key.sign(&h)?;
let mut payload = Vec::with_capacity(1 + 4 + k_s.len() + 4 + q_s.len() + 4 + sig.len());
payload.push(SSH_MSG_KEX_ECDH_REPLY);
payload.extend_from_slice(&(k_s.len() as u32).to_be_bytes());
payload.extend_from_slice(&k_s);
payload.extend_from_slice(&(q_s.len() as u32).to_be_bytes());
payload.extend_from_slice(&q_s);
payload.extend_from_slice(&(sig.len() as u32).to_be_bytes());
payload.extend_from_slice(&sig);
let k = mpint_bytes(&k_raw);
Ok(ServerReplyOut {
payload,
kex: KexOutput { k, h },
})
}
fn client_finish_inner<D: Digest>(
state: ClientState,
reply_payload: &[u8],
verifier: &dyn HostKeyVerify,
ctx: &KexContext<'_>,
) -> Result<KexOutput> {
let mut r = Reader::new(reply_payload);
let msg = r.read_u8()?;
if msg != SSH_MSG_KEX_ECDH_REPLY {
return Err(Error::Protocol("expected SSH_MSG_KEX_ECDH_REPLY"));
}
let k_s = r.read_string()?;
let q_s_bytes = r.read_string()?;
if q_s_bytes.len() != sec1_point_len(state.curve) {
return Err(Error::Format("ECDH Q_S wrong length"));
}
let sig = r.read_string()?;
let peer = BoxedEcdsaPublicKey::from_sec1(state.curve, q_s_bytes)
.map_err(|_| Error::Format("invalid ECDH Q_S"))?;
let k_raw = state
.secret
.diffie_hellman(&peer)
.map_err(|_| Error::Crypto("ECDH agreement failed"))?;
let mut eh = ExchangeHash::<D>::new();
eh.write_string(ctx.v_c);
eh.write_string(ctx.v_s);
eh.write_string(ctx.i_c);
eh.write_string(ctx.i_s);
eh.write_string(k_s);
eh.write_string(&state.q_c);
eh.write_string(q_s_bytes);
eh.write_mpint(&k_raw);
let h = eh.finalize();
verifier.verify(&h, sig)?;
let k = mpint_bytes(&k_raw);
Ok(KexOutput { k, h })
}
/// Generates the inherent API for one `ecdh-sha2-*` algorithm type.
///
/// * `$alg` — the marker type implementing `Kex`.
/// * `$curve` — the `CurveId` the algorithm runs on.
/// * `$digest` — the digest used for the exchange hash.
macro_rules! ecdh_impl {
    ($alg:ident, $curve:expr, $digest:ty) => {
        impl $alg {
            /// SSH algorithm name, mirrored from the `Kex` trait constant.
            pub const NAME: &'static str = <Self as Kex>::NAME;
            /// Exchange-hash length in bytes, mirrored from the `Kex` trait constant.
            pub const HASH_LEN: usize = <Self as Kex>::HASH_LEN;

            /// Client step one: generate an ephemeral key and build `SSH_MSG_KEX_ECDH_INIT`.
            pub fn client_init<R: RngCore + CryptoRng>(rng: &mut R) -> (ClientState, KexInitOut) {
                client_init($curve, rng)
            }

            /// Server step: consume `SSH_MSG_KEX_ECDH_INIT` and produce the signed reply.
            pub fn server_reply<R, S>(
                rng: &mut R,
                init_payload: &[u8],
                host_key: &S,
                ctx: &KexContext<'_>,
            ) -> Result<ServerReplyOut>
            where
                R: RngCore + CryptoRng,
                S: crate::hostkey::HostKey + ?Sized,
            {
                server_reply_inner::<$digest, R, S>($curve, rng, init_payload, host_key, ctx)
            }

            /// Client step two: consume `SSH_MSG_KEX_ECDH_REPLY`.
            ///
            /// Rejects `state` if it was created for a different curve.
            pub fn client_finish(
                state: ClientState,
                reply_payload: &[u8],
                verifier: &dyn HostKeyVerify,
                ctx: &KexContext<'_>,
            ) -> Result<KexOutput> {
                if state.curve == $curve {
                    client_finish_inner::<$digest>(state, reply_payload, verifier, ctx)
                } else {
                    Err(Error::Protocol("ECDH curve mismatch"))
                }
            }
        }
    };
}
// Each algorithm pairs its curve with the SHA-2 digest named in its
// algorithm string.
ecdh_impl! { EcdhSha2Nistp256, CurveId::P256, Sha256 }
ecdh_impl! { EcdhSha2Nistp384, CurveId::P384, Sha384 }
ecdh_impl! { EcdhSha2Nistp521, CurveId::P521, Sha512 }
#[cfg(test)]
mod tests {
    use super::*;

    /// Each algorithm exposes the RFC 5656 name and its digest's output length.
    #[test]
    fn algorithm_constants() {
        let table: [(&str, usize, &str, usize); 3] = [
            (EcdhSha2Nistp256::NAME, EcdhSha2Nistp256::HASH_LEN, "ecdh-sha2-nistp256", 32),
            (EcdhSha2Nistp384::NAME, EcdhSha2Nistp384::HASH_LEN, "ecdh-sha2-nistp384", 48),
            (EcdhSha2Nistp521::NAME, EcdhSha2Nistp521::HASH_LEN, "ecdh-sha2-nistp521", 64),
        ];
        for &(name, hash_len, want_name, want_hash_len) in table.iter() {
            assert_eq!(name, want_name);
            assert_eq!(hash_len, want_hash_len);
        }
    }

    /// Uncompressed SEC1 point sizes: 1 + 2 * coordinate length.
    #[test]
    fn sec1_lengths_match_curves() {
        let cases: [(CurveId, usize); 3] = [
            (CurveId::P256, 65),
            (CurveId::P384, 97),
            (CurveId::P521, 133),
        ];
        for &(curve, want) in cases.iter() {
            assert_eq!(sec1_point_len(curve), want);
        }
    }
}