use anyhow::Result;
use bincode_next::{Decode, Encode};
use crate::error::Error as MoshpitError;
// Algorithm identifier strings. These are the values placed in the
// `kex`, `aead`, `mac` and `kdf` lists built by `supported_algorithms`
// and compared verbatim (string equality) by `negotiate`.

// Key-exchange (`kex`) identifiers.
/// Key-exchange identifier `x25519-sha256`.
pub const KEX_X25519_SHA256: &str = "x25519-sha256";
/// Key-exchange identifier `p384-sha384`.
pub const KEX_P384_SHA384: &str = "p384-sha384";
/// Key-exchange identifier `p256-sha256`.
pub const KEX_P256_SHA256: &str = "p256-sha256";
/// Key-exchange identifier `ml-kem-512-sha256`.
pub const KEX_ML_KEM_512_SHA256: &str = "ml-kem-512-sha256";
/// Key-exchange identifier `ml-kem-768-sha256`.
pub const KEX_ML_KEM_768_SHA256: &str = "ml-kem-768-sha256";
/// Key-exchange identifier `ml-kem-1024-sha256`.
pub const KEX_ML_KEM_1024_SHA256: &str = "ml-kem-1024-sha256";
// AEAD cipher (`aead`) identifiers.
/// AEAD identifier `aes256-gcm-siv`.
pub const AEAD_AES256_GCM_SIV: &str = "aes256-gcm-siv";
/// AEAD identifier `aes256-gcm`.
pub const AEAD_AES256_GCM: &str = "aes256-gcm";
/// AEAD identifier `chacha20-poly1305`.
pub const AEAD_CHACHA20_POLY1305: &str = "chacha20-poly1305";
/// AEAD identifier `aes128-gcm-siv`.
pub const AEAD_AES128_GCM_SIV: &str = "aes128-gcm-siv";
// Message-authentication (`mac`) identifiers.
/// MAC identifier `hmac-sha512`.
pub const MAC_HMAC_SHA512: &str = "hmac-sha512";
/// MAC identifier `hmac-sha256`.
pub const MAC_HMAC_SHA256: &str = "hmac-sha256";
// Key-derivation (`kdf`) identifiers.
/// KDF identifier `hkdf-sha256`.
pub const KDF_HKDF_SHA256: &str = "hkdf-sha256";
/// KDF identifier `hkdf-sha384`.
pub const KDF_HKDF_SHA384: &str = "hkdf-sha384";
/// KDF identifier `hkdf-sha512`.
pub const KDF_HKDF_SHA512: &str = "hkdf-sha512";
/// Lists of algorithm identifier strings, one list per category.
///
/// When a value of this type is passed as the first argument to
/// [`negotiate`], the order of entries within each list is significant:
/// the earliest entry that the other side also lists is selected.
///
/// NOTE(review): the `Encode`/`Decode` derives suggest this type is
/// serialized (presumably to exchange it between peers) — confirm
/// against the callers.
#[derive(Clone, Debug, Decode, Encode, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct AlgorithmList {
/// Key-exchange algorithm identifiers.
pub kex: Vec<String>,
/// AEAD cipher identifiers.
pub aead: Vec<String>,
/// MAC algorithm identifiers.
pub mac: Vec<String>,
/// Key-derivation function identifiers.
pub kdf: Vec<String>,
}
/// The single algorithm identifier chosen for each category.
///
/// This is the value returned by [`negotiate`] on success; its `Default`
/// implementation supplies a fixed selection.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NegotiatedAlgorithms {
/// Selected key-exchange algorithm identifier.
pub kex: String,
/// Selected AEAD cipher identifier.
pub aead: String,
/// Selected MAC algorithm identifier.
pub mac: String,
/// Selected key-derivation function identifier.
pub kdf: String,
}
impl Default for NegotiatedAlgorithms {
    /// Returns the fixed default selection: `x25519-sha256` key exchange,
    /// `aes256-gcm-siv` AEAD, `hmac-sha512` MAC and `hkdf-sha256` KDF.
    ///
    /// These are the same identifiers that appear first in each list
    /// returned by `supported_algorithms`.
    fn default() -> Self {
        let kex = String::from(KEX_X25519_SHA256);
        let aead = String::from(AEAD_AES256_GCM_SIV);
        let mac = String::from(MAC_HMAC_SHA512);
        let kdf = String::from(KDF_HKDF_SHA256);
        Self {
            kex,
            aead,
            mac,
            kdf,
        }
    }
}
/// Returns every algorithm identifier this build offers, grouped by category.
///
/// The entries are listed in a fixed order. That order matters when the
/// result is used as the first argument to `negotiate`, which selects the
/// earliest entry the other list also contains.
#[must_use]
pub fn supported_algorithms() -> AlgorithmList {
    // Turns a slice of static identifiers into the owned strings the
    // `AlgorithmList` fields hold, preserving order.
    let owned = |names: &[&str]| -> Vec<String> {
        names.iter().copied().map(String::from).collect()
    };

    let kex = owned(&[
        KEX_X25519_SHA256,
        KEX_ML_KEM_768_SHA256,
        KEX_ML_KEM_512_SHA256,
        KEX_ML_KEM_1024_SHA256,
        KEX_P384_SHA384,
        KEX_P256_SHA256,
    ]);
    let aead = owned(&[
        AEAD_AES256_GCM_SIV,
        AEAD_AES256_GCM,
        AEAD_CHACHA20_POLY1305,
        AEAD_AES128_GCM_SIV,
    ]);
    let mac = owned(&[MAC_HMAC_SHA512, MAC_HMAC_SHA256]);
    let kdf = owned(&[KDF_HKDF_SHA256, KDF_HKDF_SHA384, KDF_HKDF_SHA512]);

    AlgorithmList {
        kex,
        aead,
        mac,
        kdf,
    }
}
pub fn negotiate(
client_prefs: &AlgorithmList,
server_supports: &AlgorithmList,
) -> Result<NegotiatedAlgorithms> {
let pick = |client: &[String], server: &[String]| -> Option<String> {
client.iter().find(|a| server.contains(a)).cloned()
};
let kex =
pick(&client_prefs.kex, &server_supports.kex).ok_or(MoshpitError::NoCommonAlgorithm)?;
let aead =
pick(&client_prefs.aead, &server_supports.aead).ok_or(MoshpitError::NoCommonAlgorithm)?;
let mac =
pick(&client_prefs.mac, &server_supports.mac).ok_or(MoshpitError::NoCommonAlgorithm)?;
let kdf =
pick(&client_prefs.kdf, &server_supports.kdf).ok_or(MoshpitError::NoCommonAlgorithm)?;
Ok(NegotiatedAlgorithms {
kex,
aead,
mac,
kdf,
})
}
/// Unit tests for `negotiate` and `supported_algorithms`.
#[cfg(test)]
mod tests {
    use super::*;

    /// The full list of algorithms this build supports.
    fn current() -> AlgorithmList {
        supported_algorithms()
    }

    /// Identical lists on both sides negotiate the first entry of each category.
    #[test]
    fn negotiate_current_stack_succeeds() {
        let client = current();
        let server = current();
        let negotiated = negotiate(&client, &server).expect("should succeed with identical lists");
        assert_eq!(negotiated.kex, KEX_X25519_SHA256);
        assert_eq!(negotiated.aead, AEAD_AES256_GCM_SIV);
        assert_eq!(negotiated.mac, MAC_HMAC_SHA512);
        assert_eq!(negotiated.kdf, KDF_HKDF_SHA256);
    }

    /// An unknown first preference is skipped in favour of the next shared entry.
    #[test]
    fn negotiate_picks_first_common_kex() {
        let client = AlgorithmList {
            kex: vec!["future-algo".to_string(), KEX_X25519_SHA256.to_string()],
            aead: vec![AEAD_AES256_GCM_SIV.to_string()],
            mac: vec![MAC_HMAC_SHA512.to_string()],
            kdf: vec![KDF_HKDF_SHA256.to_string()],
        };
        let server = current();
        let negotiated = negotiate(&client, &server).expect("should find x25519-sha256");
        assert_eq!(negotiated.kex, KEX_X25519_SHA256);
    }

    /// A preferred ML-KEM key exchange is chosen when the other side lists it.
    #[test]
    fn negotiate_picks_ml_kem_when_preferred_and_supported() {
        let client = AlgorithmList {
            kex: vec![
                KEX_ML_KEM_768_SHA256.to_string(),
                KEX_X25519_SHA256.to_string(),
            ],
            aead: vec![AEAD_AES256_GCM_SIV.to_string()],
            mac: vec![MAC_HMAC_SHA512.to_string()],
            kdf: vec![KDF_HKDF_SHA256.to_string()],
        };
        let server = current();
        let negotiated = negotiate(&client, &server).expect("should find ml-kem-768-sha256");
        assert_eq!(negotiated.kex, KEX_ML_KEM_768_SHA256);
    }

    /// When the other side lacks ML-KEM, the next preference is used instead.
    #[test]
    fn negotiate_falls_back_from_ml_kem_to_ecdh() {
        let client = AlgorithmList {
            kex: vec![
                KEX_ML_KEM_768_SHA256.to_string(),
                KEX_X25519_SHA256.to_string(),
            ],
            aead: vec![AEAD_AES256_GCM_SIV.to_string()],
            mac: vec![MAC_HMAC_SHA512.to_string()],
            kdf: vec![KDF_HKDF_SHA256.to_string()],
        };
        let server = AlgorithmList {
            kex: vec![KEX_X25519_SHA256.to_string()],
            aead: vec![AEAD_AES256_GCM_SIV.to_string()],
            mac: vec![MAC_HMAC_SHA512.to_string()],
            kdf: vec![KDF_HKDF_SHA256.to_string()],
        };
        let negotiated = negotiate(&client, &server).expect("should fall back to x25519");
        assert_eq!(negotiated.kex, KEX_X25519_SHA256);
    }

    /// No shared key exchange yields `NoCommonAlgorithm`.
    #[test]
    fn negotiate_no_common_kex_returns_error() {
        let client = AlgorithmList {
            kex: vec!["unknown-kex".to_string()],
            aead: vec![AEAD_AES256_GCM_SIV.to_string()],
            mac: vec![MAC_HMAC_SHA512.to_string()],
            kdf: vec![KDF_HKDF_SHA256.to_string()],
        };
        let server = current();
        let err = negotiate(&client, &server).unwrap_err();
        assert!(
            err.downcast_ref::<MoshpitError>()
                .is_some_and(|e| *e == MoshpitError::NoCommonAlgorithm)
        );
    }

    /// No shared AEAD yields `NoCommonAlgorithm`.
    #[test]
    fn negotiate_no_common_aead_returns_error() {
        let client = AlgorithmList {
            kex: vec![KEX_X25519_SHA256.to_string()],
            aead: vec!["unknown-aead".to_string()],
            mac: vec![MAC_HMAC_SHA512.to_string()],
            kdf: vec![KDF_HKDF_SHA256.to_string()],
        };
        let server = current();
        let err = negotiate(&client, &server).unwrap_err();
        assert!(
            err.downcast_ref::<MoshpitError>()
                .is_some_and(|e| *e == MoshpitError::NoCommonAlgorithm)
        );
    }

    /// An empty preference list for a category can never match, so negotiation fails.
    #[test]
    fn negotiate_empty_client_list_returns_error() {
        let client = AlgorithmList {
            kex: vec![],
            aead: vec![AEAD_AES256_GCM_SIV.to_string()],
            mac: vec![MAC_HMAC_SHA512.to_string()],
            kdf: vec![KDF_HKDF_SHA256.to_string()],
        };
        let server = current();
        assert!(negotiate(&client, &server).is_err());
    }

    /// Unknown first choices are skipped independently in every category.
    #[test]
    fn negotiate_preference_ordering_respected() {
        let client = AlgorithmList {
            kex: vec!["future-kex".to_string(), KEX_X25519_SHA256.to_string()],
            aead: vec!["future-aead".to_string(), AEAD_AES256_GCM_SIV.to_string()],
            mac: vec!["future-mac".to_string(), MAC_HMAC_SHA512.to_string()],
            kdf: vec!["future-kdf".to_string(), KDF_HKDF_SHA256.to_string()],
        };
        let server = current();
        let negotiated = negotiate(&client, &server).expect("second-choice should match");
        assert_eq!(negotiated.kex, KEX_X25519_SHA256);
        assert_eq!(negotiated.aead, AEAD_AES256_GCM_SIV);
        assert_eq!(negotiated.mac, MAC_HMAC_SHA512);
        assert_eq!(negotiated.kdf, KDF_HKDF_SHA256);
    }

    /// The list passed as the FIRST argument drives the choice: here the
    /// server's list is passed first, so its ordering wins over the client's.
    #[test]
    fn negotiate_server_preference_order_wins() {
        let server_prefs = AlgorithmList {
            kex: vec![KEX_P384_SHA384.to_string(), KEX_X25519_SHA256.to_string()],
            aead: vec![
                AEAD_CHACHA20_POLY1305.to_string(),
                AEAD_AES256_GCM_SIV.to_string(),
            ],
            mac: vec![MAC_HMAC_SHA256.to_string(), MAC_HMAC_SHA512.to_string()],
            kdf: vec![KDF_HKDF_SHA512.to_string(), KDF_HKDF_SHA256.to_string()],
        };
        let client_offered = AlgorithmList {
            kex: vec![KEX_X25519_SHA256.to_string(), KEX_P384_SHA384.to_string()],
            aead: vec![
                AEAD_AES256_GCM_SIV.to_string(),
                AEAD_CHACHA20_POLY1305.to_string(),
            ],
            mac: vec![MAC_HMAC_SHA512.to_string(), MAC_HMAC_SHA256.to_string()],
            kdf: vec![KDF_HKDF_SHA256.to_string(), KDF_HKDF_SHA512.to_string()],
        };
        let negotiated = negotiate(&server_prefs, &client_offered)
            .expect("should find common algorithms in server preference order");
        // The failure message names the algorithm actually asserted.
        assert_eq!(
            negotiated.kex, KEX_P384_SHA384,
            "server prefers p384-sha384"
        );
        assert_eq!(
            negotiated.aead, AEAD_CHACHA20_POLY1305,
            "server prefers chacha20"
        );
        assert_eq!(negotiated.mac, MAC_HMAC_SHA256, "server prefers sha256 mac");
        assert_eq!(
            negotiated.kdf, KDF_HKDF_SHA512,
            "server prefers hkdf-sha512"
        );
    }

    /// Every identifier constant declared in this module must be advertised.
    #[test]
    fn supported_algorithms_contains_all_known_algorithms() {
        let algos = supported_algorithms();
        assert!(algos.kex.contains(&KEX_X25519_SHA256.to_string()));
        assert!(algos.kex.contains(&KEX_ML_KEM_512_SHA256.to_string()));
        assert!(algos.kex.contains(&KEX_ML_KEM_768_SHA256.to_string()));
        assert!(algos.kex.contains(&KEX_ML_KEM_1024_SHA256.to_string()));
        assert!(algos.kex.contains(&KEX_P384_SHA384.to_string()));
        assert!(algos.kex.contains(&KEX_P256_SHA256.to_string()));
        assert!(algos.aead.contains(&AEAD_AES256_GCM_SIV.to_string()));
        assert!(algos.aead.contains(&AEAD_AES256_GCM.to_string()));
        assert!(algos.aead.contains(&AEAD_CHACHA20_POLY1305.to_string()));
        assert!(algos.aead.contains(&AEAD_AES128_GCM_SIV.to_string()));
        assert!(algos.mac.contains(&MAC_HMAC_SHA512.to_string()));
        assert!(algos.mac.contains(&MAC_HMAC_SHA256.to_string()));
        assert!(algos.kdf.contains(&KDF_HKDF_SHA256.to_string()));
        // Previously unchecked, so removing it from the list went unnoticed.
        assert!(algos.kdf.contains(&KDF_HKDF_SHA384.to_string()));
        assert!(algos.kdf.contains(&KDF_HKDF_SHA512.to_string()));
    }
}