use hkdf::Hkdf;
use rand::RngCore;
use sha2::Sha256;
use x25519_dalek::{PublicKey, StaticSecret};
use crate::error::{HuddleError, Result};
/// Length in bytes of a verification transaction id (128 random bits).
///
/// The transaction id is used as the HKDF salt when deriving the SAS, so both
/// peers must agree on it for their codes to match.
pub const TX_ID_LEN: usize = 128 / 8;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SasCode {
pub emoji_indices: [u8; 6],
pub decimal: String,
}
impl SasCode {
pub fn emoji_string(&self) -> String {
self.emoji_indices
.iter()
.map(|i| SAS_EMOJI[*i as usize].0)
.collect::<Vec<_>>()
.join(" ")
}
pub fn emoji_labels(&self) -> String {
self.emoji_indices
.iter()
.map(|i| SAS_EMOJI[*i as usize].1)
.collect::<Vec<_>>()
.join(" / ")
}
}
/// Starts a new verification session.
///
/// Returns `(tx_id, secret, public)`:
/// * `tx_id` — [`TX_ID_LEN`] random bytes identifying the exchange;
/// * `secret` — a freshly generated X25519 secret;
/// * `public` — the public key matching `secret`.
///
/// All randomness comes from `rand::thread_rng()`.
pub fn new_session() -> ([u8; TX_ID_LEN], StaticSecret, PublicKey) {
    let mut rng = rand::thread_rng();

    let mut transaction_id = [0u8; TX_ID_LEN];
    rng.fill_bytes(&mut transaction_id);

    let our_secret = StaticSecret::random_from_rng(&mut rng);
    let our_public = PublicKey::from(&our_secret);

    (transaction_id, our_secret, our_public)
}
/// Derives the short authentication string both peers compare out of band.
///
/// The derivation works as follows:
/// 1. Compute X25519 between `our_secret` and `their_public`.
/// 2. Run HKDF-SHA256 with `tx_id` as the salt and `b"huddle-sas-v1"` as the
///    info string, producing 9 bytes of output keying material (OKM).
/// 3. Split the OKM:
///    * bytes `0..6`: the low 6 bits of each byte select one of the 64 entries
///      of [`SAS_EMOJI`];
///    * bytes `6..9`: read as a big-endian 24-bit integer, reduced modulo
///      1 000 000 and zero-padded to six decimal digits.
///
/// Two honest peers compute the same DH output and therefore the same code.
///
/// Notes:
/// * The DH output is not checked for being all-zero here. With a low-order
///   `their_public`, the code would not depend on either secret, so such
///   keys should be rejected before calling.
/// * 2^24 is not a multiple of 10^6, so decimal values below 777 216 are
///   slightly more likely than the rest. Changing that would alter the codes
///   produced under `huddle-sas-v1`.
pub fn derive_sas_code(
    our_secret: &StaticSecret,
    their_public: &PublicKey,
    tx_id: &[u8; TX_ID_LEN],
) -> SasCode {
    let dh_output = our_secret.diffie_hellman(their_public);

    let mut okm = [0u8; 9];
    Hkdf::<Sha256>::new(Some(tx_id), dh_output.as_bytes())
        .expand(b"huddle-sas-v1", &mut okm)
        .expect("9 bytes is well within HKDF output limit");

    let (emoji_bytes, digit_bytes) = okm.split_at(6);

    let mut emoji_indices = [0u8; 6];
    for (slot, byte) in emoji_indices.iter_mut().zip(emoji_bytes) {
        // 64 == 2^6, so keeping the low 6 bits is an unbiased table index.
        *slot = byte & 0x3f;
    }

    let number = u32::from_be_bytes([0, digit_bytes[0], digit_bytes[1], digit_bytes[2]]);
    let decimal = format!("{:06}", number % 1_000_000);

    SasCode {
        emoji_indices,
        decimal,
    }
}
/// Emoji table used to render a [`SasCode`], as `(glyph, label)` pairs.
///
/// `SasCode::emoji_string` uses the glyph column and `SasCode::emoji_labels`
/// uses the label column.
///
/// The table has exactly 64 entries, so every 6-bit value is a valid index.
/// `derive_sas_code` masks each output byte with `0x3f` to produce one.
///
/// Reordering or replacing entries changes which emoji are shown for a given
/// code. Peers presumably must share this exact table, so treat it as part of
/// the `huddle-sas-v1` format.
pub const SAS_EMOJI: [(&str, &str); 64] = [
("🐶", "dog"),
("🐱", "cat"),
("🦁", "lion"),
("🐴", "horse"),
("🦄", "unicorn"),
("🐷", "pig"),
("🐘", "elephant"),
("🐰", "rabbit"),
("🐼", "panda"),
("🐔", "rooster"),
("🐧", "penguin"),
("🐢", "turtle"),
("🐟", "fish"),
("🐙", "octopus"),
("🦋", "butterfly"),
("🌷", "flower"),
("🌳", "tree"),
("🌵", "cactus"),
("🍄", "mushroom"),
("🌍", "globe"),
("🌙", "moon"),
("☁️", "cloud"),
("🔥", "fire"),
("🍌", "banana"),
("🍎", "apple"),
("🍓", "strawberry"),
("🌽", "corn"),
("🍕", "pizza"),
("🎂", "cake"),
("❤️", "heart"),
("🙂", "smiley"),
("🤖", "robot"),
("🎩", "hat"),
("👓", "glasses"),
("🔧", "spanner"),
("🎅", "santa"),
("👍", "thumbs up"),
("☂️", "umbrella"),
("⌛", "hourglass"),
("⏰", "clock"),
("🎁", "gift"),
("💡", "lightbulb"),
("📕", "book"),
("✏️", "pencil"),
("📎", "paperclip"),
("✂️", "scissors"),
("🔒", "lock"),
("🔑", "key"),
("🔨", "hammer"),
("☎️", "telephone"),
("🏁", "flag"),
("🚂", "train"),
("🚲", "bicycle"),
("✈️", "plane"),
("🚀", "rocket"),
("🏆", "trophy"),
("⚽", "ball"),
("🎸", "guitar"),
("🎺", "trumpet"),
("🔔", "bell"),
("⚓", "anchor"),
("🎧", "headphones"),
("📁", "folder"),
("📌", "pin"),
];
pub fn parse_pubkey(b64: &str) -> Result<PublicKey> {
use base64::engine::general_purpose::STANDARD as B64;
use base64::Engine;
let bytes = B64
.decode(b64)
.map_err(|e| HuddleError::Session(format!("bad x25519 pubkey b64: {e}")))?;
if bytes.len() != 32 {
return Err(HuddleError::Session(format!(
"x25519 pubkey is {} bytes, expected 32",
bytes.len()
)));
}
let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes);
Ok(PublicKey::from(arr))
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD as B64;
    use base64::Engine;

    /// Two honest peers sharing a transaction id derive identical, well-formed codes.
    #[test]
    fn both_sides_derive_same_code() {
        let (tx_id, alice_secret, alice_pub) = new_session();
        let (_, bob_secret, bob_pub) = new_session();
        let alice_code = derive_sas_code(&alice_secret, &bob_pub, &tx_id);
        let bob_code = derive_sas_code(&bob_secret, &alice_pub, &tx_id);
        assert_eq!(alice_code, bob_code);
        assert_eq!(alice_code.decimal.len(), 6);
        assert!(alice_code.decimal.chars().all(|c| c.is_ascii_digit()));
        for i in alice_code.emoji_indices {
            assert!((i as usize) < SAS_EMOJI.len());
        }
    }

    /// The transaction id salts the KDF, so changing it must change the code.
    #[test]
    fn different_tx_id_yields_different_code() {
        let (tx_id_a, alice_secret, _) = new_session();
        let (_, _, bob_pub) = new_session();
        let alice_code = derive_sas_code(&alice_secret, &bob_pub, &tx_id_a);
        let mut tx_id_b = tx_id_a;
        tx_id_b[0] ^= 0xff;
        let alice_code_b = derive_sas_code(&alice_secret, &bob_pub, &tx_id_b);
        assert_ne!(alice_code, alice_code_b);
    }

    /// If a third key is substituted in transit, the two victims end up with
    /// different DH outputs and therefore different codes. The genuine
    /// exchange still agrees.
    #[test]
    fn mitm_substitute_yields_different_code() {
        let (tx_id, alice_secret, alice_pub) = new_session();
        let (_, bob_secret, bob_pub) = new_session();
        let (_, _mallory_secret, mallory_pub) = new_session();
        let alice_thinks_bob = derive_sas_code(&alice_secret, &mallory_pub, &tx_id);
        let bob_thinks_alice = derive_sas_code(&bob_secret, &mallory_pub, &tx_id);
        assert_ne!(alice_thinks_bob, bob_thinks_alice);
        let alice_real = derive_sas_code(&alice_secret, &bob_pub, &tx_id);
        let bob_real = derive_sas_code(&bob_secret, &alice_pub, &tx_id);
        assert_eq!(alice_real, bob_real);
    }

    /// A freshly generated public key survives a base64 encode/parse round trip.
    #[test]
    fn pubkey_round_trip() {
        let (_, _, pub_) = new_session();
        let encoded = B64.encode(pub_.as_bytes());
        let decoded = parse_pubkey(&encoded).unwrap();
        assert_eq!(decoded.as_bytes(), pub_.as_bytes());
    }

    /// Input that is not valid base64 is reported as a session error.
    #[test]
    fn parse_pubkey_rejects_invalid_base64() {
        assert!(matches!(
            parse_pubkey("not base64!"),
            Err(HuddleError::Session(_))
        ));
    }

    /// Valid base64 that does not decode to exactly 32 bytes is rejected.
    #[test]
    fn parse_pubkey_rejects_wrong_length() {
        for len in [0usize, 31, 33] {
            let encoded = B64.encode(vec![0x42u8; len]);
            assert!(matches!(
                parse_pubkey(&encoded),
                Err(HuddleError::Session(_))
            ));
        }
    }

    /// Rendering maps indices through `SAS_EMOJI` with the expected separators.
    #[test]
    fn emoji_rendering_uses_table() {
        let code = SasCode {
            emoji_indices: [0, 1, 2, 3, 4, 63],
            decimal: "000000".to_string(),
        };
        assert_eq!(code.emoji_string(), "🐶 🐱 🦁 🐴 🦄 📌");
        assert_eq!(
            code.emoji_labels(),
            "dog / cat / lion / horse / unicorn / pin"
        );
    }
}