klave-connector 0.1.2

Module allowing easy connection to deployed Klave apps
Documentation
use std::time::Duration;
use base64::{self, Engine};
use futures::SinkExt;
use rand::rngs::OsRng;
use num_bigint::BigUint;
use thiserror::Error;
use tokio::{net::TcpStream, time::timeout};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::MaybeTlsStream;
use futures_util::stream::{SplitSink, SplitStream, StreamExt};
use tokio_tungstenite::WebSocketStream;
use p256::ecdsa::{signature, Signature, VerifyingKey};
use signature::Verifier;
use aes_gcm::{Aes128Gcm, KeyInit, aead::Error as AesGcmError};
use p256::{EncodedPoint, PublicKey as EcdhPublicKey, SecretKey};
use elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint};
use crate::key::Key;
use crate::types::SCPSession;
use crate::utils::{get_random_bytes, get_sha256_bytes, pointwise_xor};
use crate::scp::{ConnectionState, encrypt_data, decrypt_data};

#[derive(Debug, Clone)]
pub struct UserKey {
    private_key: p256::SecretKey,
    public_key: p256::PublicKey,
}

impl UserKey {
    pub fn new(private_key_bytes: &[u8]) -> Result<Self, p256::ecdsa::Error> {
        let private_key = match p256::SecretKey::from_bytes(private_key_bytes.into()) {
            Ok(key) => key,
            Err(_) => return Err(p256::ecdsa::Error::default()),
        };
        let public_key = private_key.public_key();
        Ok(Self { private_key, public_key })
    }

    pub async fn get_raw_public_key(&self) -> Vec<u8> {
        self.public_key.to_encoded_point(false).as_bytes()[1..].to_vec()
    }

    pub fn get_signing_key(&self) -> p256::ecdsa::SigningKey {
        p256::ecdsa::SigningKey::from(self.private_key.clone())
    }

    pub fn get_verifying_key(&self) -> p256::ecdsa::VerifyingKey {
        self.public_key.into()
    }
}

/// Errors that can occur while performing the SCP handshake.
#[derive(Error, Debug)]
pub enum HandshakeError {
    /// A socket was required but has not been set.
    #[error("Socket not set")]
    SocketNotSet,
    /// An endpoint was required but has not been set.
    #[error("Endpoint not set")]
    EndpointNotSet,
    /// Underlying I/O failure; `std::io::Error` converts into this via `?`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// An operation did not complete within its allowed time.
    #[error("Timeout")]
    Timeout,
    /// A base64 value (for example a trusted key) could not be decoded;
    /// `base64::DecodeError` converts into this via `?`.
    #[error("Base64 decode error: {0}")]
    Base64Decode(#[from] base64::DecodeError),
    /// An ECDH key-agreement step failed.
    #[error("ECDH error")]
    EcdhError,
    /// ECDSA signing, verification or key-parsing error from `p256`;
    /// `p256::ecdsa::Error` converts into this via `?`.
    #[error("ECDSA error: {0}")]
    Ecdsa(#[from] p256::ecdsa::Error),
    /// AES-GCM failure. Has no `#[from]`, so it must be constructed explicitly.
    #[error("AES-GCM error: {0}")]
    AesGcm(AesGcmError),
    /// Bytes were expected to be UTF-8 but were not.
    #[error("UTF-8 error: {0}")]
    Utf8(#[from] std::str::Utf8Error),
    /// ASN.1 decoding failed; the string describes the failure.
    #[error("ASN.1 decode error: {0}")]
    Asn1Decode(String),
    /// Any other handshake failure, described by the message.
    #[error("Handshake failed: {0}")]
    Generic(String),
}

/// Reads one handshake message carrying `n` payload bytes from the socket.
///
/// Binary WebSocket frames are accumulated until `n + 4` bytes have been
/// collected. The first 4 bytes (the message prefix) are then dropped and the
/// remaining `n` bytes are returned. Bytes beyond `n + 4` in the last frame are
/// discarded. Non-binary, non-close messages (text, ping, pong, ...) are skipped.
///
/// `timeout_duration` bounds each individual wait for a frame, not the whole read.
///
/// # Errors
/// - `HandshakeError::Timeout` if no frame arrives within `timeout_duration`.
/// - `HandshakeError::Generic` if the peer sends a close frame, the stream ends
///   before enough bytes arrive, or the WebSocket layer reports an error.
async fn read_n_bytes(
    reader: &mut SplitStream<&mut WebSocketStream<MaybeTlsStream<TcpStream>>>,
    n: usize,
    timeout_duration: Duration,
) -> Result<Vec<u8>, HandshakeError> {
    const PREFIX_LEN: usize = 4;
    let n_with_prefix = n + PREFIX_LEN;
    let mut buffer = Vec::with_capacity(n_with_prefix);

    while buffer.len() < n_with_prefix {
        let next = timeout(timeout_duration, reader.next())
            .await
            .map_err(|_| HandshakeError::Timeout)?;
        match next {
            Some(Ok(Message::Binary(data))) => {
                let wanted = n_with_prefix - buffer.len();
                buffer.extend_from_slice(&data[..data.len().min(wanted)]);
            }
            Some(Ok(Message::Close(_))) => {
                return Err(HandshakeError::Generic("Connection closed unexpectedly".into()));
            }
            // Control/text frames carry no handshake data.
            Some(Ok(_)) => {}
            Some(Err(e)) => {
                return Err(HandshakeError::Generic(format!("WebSocket error: {}", e)));
            }
            // The stream ended: this is a truncated message, not a timeout.
            None => {
                return Err(HandshakeError::Generic(format!(
                    "Connection ended after {} of {} expected bytes",
                    buffer.len(),
                    n_with_prefix
                )));
            }
        }
    }

    // Strip the 4-byte message prefix in place instead of reallocating.
    buffer.drain(..PREFIX_LEN);
    Ok(buffer)
}

/// Sends `data` as a single binary WebSocket frame, prefixed with the 4 bytes
/// `[0, 0, 0, 1]` that every outgoing handshake message in this module carries.
///
/// # Errors
/// Returns `HandshakeError::Generic` if the frame cannot be sent.
async fn send_bytes(
    writer: &mut SplitSink<&mut WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
    data: &[u8],
) -> Result<(), HandshakeError> {
    const PREFIX: [u8; 4] = [0, 0, 0, 1];

    // Build the frame in one allocation rather than shifting the payload
    // once per prefix byte with repeated `insert(0, _)`.
    let mut framed = Vec::with_capacity(PREFIX.len() + data.len());
    framed.extend_from_slice(&PREFIX);
    framed.extend_from_slice(data);

    writer
        .send(Message::Binary(framed))
        .await
        .map_err(|e| HandshakeError::Generic(format!("Failed to send message: {}", e)))
}

/// Produces the client's proof-of-work answer to the server's challenge.
///
/// Placeholder: no work is performed yet; the server random is echoed back
/// unchanged as an owned buffer.
/// TODO: replace with the real proof-of-work computation once it is specified.
fn compute_proof_of_work(server_random: &[u8]) -> Vec<u8> {
    Vec::from(server_random)
}

/// Returns `bytes` resized to exactly `size` bytes.
///
/// Shorter input is left-padded with zero bytes. Longer (or equal) input keeps
/// only its trailing `size` bytes, so leading bytes are dropped.
fn pad_to_size(bytes: &[u8], size: usize) -> Vec<u8> {
    let zero_fill = size.saturating_sub(bytes.len());
    let skip = bytes.len().saturating_sub(size);
    std::iter::repeat(0u8)
        .take(zero_fill)
        .chain(bytes[skip..].iter().copied())
        .collect()
}

fn extract_key_and_iv(server_identity_bytes: &[u8], ecdh_secret: &SecretKey) -> Result<(Vec<u8>, Vec<u8>), HandshakeError> {
    if server_identity_bytes.len() < 32 + 64 + 64 {
        return Err(HandshakeError::Generic("Invalid server identity bytes length".into()));
    }
    let pre_master_secret = &server_identity_bytes[0..32];
    let server_ecdh_pub_key_bytes = &server_identity_bytes[32..96];
    
    let server_ecdh_public_key = EcdhPublicKey::from_encoded_point(
        &EncodedPoint::from_bytes(
            [&[0x04], server_ecdh_pub_key_bytes].concat()
        ).unwrap()
    ).unwrap();

    // Compute the shared secret using the SecretKey
    let shared_secret = elliptic_curve::ecdh::diffie_hellman(
        ecdh_secret.to_nonzero_scalar(),
        server_ecdh_public_key.as_affine()
    ).raw_secret_bytes().to_vec();                     
    let sha256_common = get_sha256_bytes(shared_secret.as_slice());
    let symmetric_key: Vec<u8> = pointwise_xor(pre_master_secret, sha256_common.as_slice());
    let iv = &symmetric_key[16..32];
    let key = &symmetric_key[0..16];

    Ok((key.to_vec(), iv.to_vec()))
}

/// Converts a parsed ECDSA signature into its fixed-width raw `r || s` form.
///
/// Despite the name, the input is an already-decoded `Signature`, not DER bytes.
/// Each component is reduced to its minimal big-endian encoding (leading zero
/// bytes removed) and then left-padded back to the component width. The width is
/// taken to be the length of `user_key`'s raw private key.
///
/// # Errors
/// Returns `HandshakeError::Generic` if either component is wider than that length.
fn der_signature_to_raw(
    user_key: &Key,
    signature: &Signature,
) -> Result<Vec<u8>, HandshakeError> {
    // Minimal big-endian encoding of one signature component.
    fn minimal_be(component: &[u8]) -> Vec<u8> {
        BigUint::from_bytes_be(component).to_bytes_be()
    }

    let r = minimal_be(&signature.r().to_bytes());
    let s = minimal_be(&signature.s().to_bytes());
    let width = user_key.get_raw_private_key().len();

    let too_wide = r.len() > width || s.len() > width;
    if too_wide {
        return Err(HandshakeError::Generic("Signature components are too large".into()));
    }

    let mut raw = pad_to_size(&r, width);
    raw.extend(pad_to_size(&s, width));
    Ok(raw)
}

pub async fn process<'a>(
    mut writer: &mut SplitSink<&mut WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
    mut reader: &mut SplitStream<&mut WebSocketStream<MaybeTlsStream<TcpStream>>>,
    user_key: &Key,
    known_trusted_key: Option<&str>,
    connect_timeout: Duration,
) -> Result<(SCPSession, ConnectionState), Box<dyn std::error::Error + Send + Sync>> {

    // --- Client Hello ---
    let ecdh_private_key = SecretKey::random(&mut OsRng);    
    let ecdh_pub_key_raw = ecdh_private_key.public_key().to_sec1_bytes()[1..].to_vec();
    if ecdh_pub_key_raw.len() != 64 {
        return Err(HandshakeError::Generic("Invalid ECDH public key length".into()).into());
    }    
    send_bytes(&mut writer, &ecdh_pub_key_raw).await?;

    // --- Server Hello ---
    let server_hello_bytes = read_n_bytes(&mut reader, 32 + 4, connect_timeout).await?;
    let server_random = &server_hello_bytes[0..32];

    // --- Client Proof of Work ---
    let pow = compute_proof_of_work(server_random);
    let trusted_key =  match known_trusted_key {
        Some(key) => {
            let temp = base64::engine::general_purpose::URL_SAFE.decode(key)?;
            if temp.len() != 64 {
                return Err(HandshakeError::Generic("Invalid trusted key length".into()).into());
            }
            temp
        }
        None => {
            //Random key with 64 bytes
            let temp = get_random_bytes(64);
            temp
        }
    };
    let client_proof_of_work = [pow, trusted_key].concat();

    send_bytes(&mut writer, &client_proof_of_work).await?;

    // --- Server Identity ---
    let server_identity_bytes = read_n_bytes(& mut reader, 32 + 64 + 64, connect_timeout).await?;

    let (key, iv) = extract_key_and_iv(&server_identity_bytes, &ecdh_private_key)?;
    let crypto_key = Aes128Gcm::new(key.as_slice().into());
    let session = SCPSession { crypto_key, iv: iv.try_into().map_err(|_| HandshakeError::Generic("Failed to create IV".into()))? };

    // --- Client Proof of Identity ---
    let public_key_raw = user_key.get_raw_public_key();
    let nonce = get_random_bytes(32);    

    let signing_key = UserKey::new(user_key.get_raw_private_key().to_vec().as_slice())
    .map_err(|_| HandshakeError::Generic("Failed to create signing key".into()))?
    .get_signing_key();
    let (signature, _recovery_id) = signing_key.sign_recoverable(&nonce).map_err(HandshakeError::Ecdsa)?;
    let signed_nonce = der_signature_to_raw(user_key, &signature)?;

    let client_proof_of_identity = [nonce, ecdh_pub_key_raw, public_key_raw, signed_nonce].concat();

    let iv_offset = get_random_bytes(16);    
    let encrypted_client_proof_of_identity = {
        encrypt_data(&session, &client_proof_of_identity, &iv_offset)?
    };

    send_bytes(&mut writer, &encrypted_client_proof_of_identity).await?;

    // --- Server Proof of Identity ---
    let server_proof_of_identity_encrypted = read_n_bytes(& mut reader, 64 + 64, connect_timeout).await?;
    let server_proof_of_identity = {
        decrypt_data(&session, &server_proof_of_identity_encrypted)?
    };

    let welcome = b"Hey you! Welcome to Secretarium!";
    let to_verify = [&server_proof_of_identity[0..32], welcome].concat();
    
    let server_signature = Signature::from_slice(&server_proof_of_identity[32..96]).unwrap();
    let server_ecdsa_pub_key_bytes = &server_identity_bytes[96..160];

    let server_ecdsa_verifying_key = VerifyingKey::from_encoded_point(
        &EncodedPoint::from_bytes(
            [&[0x04], server_ecdsa_pub_key_bytes].concat()
        ).unwrap()
    ).unwrap();
    if server_ecdsa_verifying_key.verify(&to_verify, &server_signature).is_err() {
        return Err(HandshakeError::Generic("ECDSA verification failed".into()).into());
    }        

    Ok((session, ConnectionState::OPEN))
}

#[cfg(test)]
mod tests {    
    use hex;
    use super::*;

    /// Replays the handshake computations with fixed inputs and checks them
    /// against hard-coded expected hex values:
    /// - the session key/IV derived by `extract_key_and_iv`;
    /// - the raw `r || s` form produced by `der_signature_to_raw`;
    /// - the assembled client proof of identity;
    /// - its encryption under a fixed IV offset.
    #[test]
    fn test_server_identity() {
        // Fixed server identity message, passed to `extract_key_and_iv` below.
        let server_identity_hex = "e479053863dd7bd4c440ea62e6d1db0e59eb02c08f81144be704afb87e96561e13448169d219dfe19bdbe1734bd87f7822504de07e9e5930c4b1ccd5d1933f5941d1ad46e485f4d1225ba43adbaab6b8a1cf53862236714fab24a7171bca26eaae5883fc2212a8f11e60a6d661dc1af8bfbca32b403ef74699b2c2d0a76fb07f8e54cada6ab9b57a8faae1fb5ed1c589efe926b2af307a414824ad1af37ff20a";
        let server_identity_bytes = hex::decode(server_identity_hex).unwrap();

        // Fixed 32-byte user private key, wrapped in the crate's `Key` type.
        let user_key_hex = "bb6594d95a216082d54777d8e6b5d99985beff9b5e6ba99519c53bbfdb8ac2a1";
        let user_key_bytes = hex::decode(user_key_hex).unwrap();
        let user_key_bytes: [u8; 32] = user_key_bytes
        .try_into()
        .expect("Private key must be 32 bytes");
    
        let user_key = match SecretKey::from_bytes(&user_key_bytes.into()) {
            Ok(secret_key) => Key::new(Some(secret_key)),
            Err(_) => panic!("Failed to create SecretKey from bytes"),
        };

        // Fixed 32-byte client ephemeral ECDH private key (normally random in `process`).
        let ecdh_key_hex = "fa2c3c7a9a52f4104ba8834c8d57afcbeb123aa72833c150bec3af5eb41f22e7";
        let ecdh_key_bytes = hex::decode(ecdh_key_hex).unwrap();
        let ecdh_key_bytes: [u8; 32] = ecdh_key_bytes
        .try_into()
        .expect("Private key must be 32 bytes");

        // Create a SecretKey from the byte array
        let ecdh_key = SecretKey::from_bytes(&ecdh_key_bytes.into())
        .expect("Failed to create SecretKey from bytes");

        // --- Session key derivation ---
        let (key, iv) = extract_key_and_iv(&server_identity_bytes, &ecdh_key.clone()).unwrap();
        let key_hex = hex::encode(key.clone());
        let iv_hex = hex::encode(iv.clone());

        println!("Key: {:?}", key_hex);
        println!("IV: {:?}", iv_hex);

        let expected_key = "1c510cda986452d4ebb32f702b08589b";
        let expected_iv = "c71953dea4339dd0e796188127f0baeb";

        assert_eq!(key_hex, expected_key);
        assert_eq!(iv_hex, expected_iv);

        // Build the session exactly as `process` does.
        let crypto_key = Aes128Gcm::new(key.as_slice().into());
        let session = SCPSession { crypto_key, iv: match iv.try_into() {
            Ok(iv) => iv,
            Err(_) => panic!("Failed to create IV from bytes"),
        } };
    
        // --- Client Proof of Identity ---
        let _public_key_raw = user_key.get_raw_public_key();
        // Fixed 32-byte nonce (normally random in `process`).
        let nonce_hex = "5c513f8ab90094b01cda913d858ca71e14ee81ba484e9d6130abd96c21c5e8fa";
        let nonce_bytes = hex::decode(nonce_hex).unwrap();

        // Signing is exercised here only to check it succeeds; the resulting
        // signature is discarded and the fixed DER signature below is used instead.
        let signing_key = match UserKey::new(user_key.get_raw_private_key().to_vec().as_slice()) {
            Ok(user_key) => user_key.get_signing_key(),
            Err(_) => panic!("Failed to create UserKey from bytes"),
        };        
        let (_signature, _recovery_id) = match signing_key.sign_recoverable(&nonce_bytes) {
            Ok(signature) => signature,
            Err(_) => panic!("Failed to sign nonce"),
        };

        // Fixed DER-encoded signature; its r and s values are the two halves of
        // the expected raw signature asserted below.
        let signature_hex = "3046022100eee8e43c035ef81badc6c02107cc13a29eac119b05fb645de8636e4fea7b240b022100ab83b4a4a35f293c6898da1f9c8d2c79eedd1b395d20eec4779949d5cdf22f7e";
        let signature_bytes = hex::decode(signature_hex).unwrap();
        let signature = match Signature::from_der(signature_bytes.as_slice()) {
            Ok(signature) => signature,
            Err(_) => panic!("Failed to create Signature from bytes"),
        };

        let signed_nonce = match der_signature_to_raw(&user_key, &signature) {
            Ok(signed_nonce) => signed_nonce,
            Err(e) => panic!("Failed to sign nonce: {:?}", e),
        };
        
        let signed_nonce_hex = hex::encode(&signed_nonce);

        println!("Nonce: {:?}", nonce_bytes);
        println!("Signed Nonce: {:?}", signed_nonce);
        
        assert_eq!(signed_nonce_hex, "eee8e43c035ef81badc6c02107cc13a29eac119b05fb645de8636e4fea7b240bab83b4a4a35f293c6898da1f9c8d2c79eedd1b395d20eec4779949d5cdf22f7e");

        // Same layout as in `process`: nonce || ecdh_pub || user_pub || signature.
        let client_proof_of_identity = [nonce_bytes, ecdh_key.public_key().to_sec1_bytes()[1..].to_vec(), user_key.get_raw_public_key(), signed_nonce].concat();

        println!("Client Proof of Identity: {:?}", hex::encode(&client_proof_of_identity));

        let expected_client_proof_of_identity = "5c513f8ab90094b01cda913d858ca71e14ee81ba484e9d6130abd96c21c5e8faa4c5eef4e22806412594d08a08c66201f3bb39161db47166b44f111c9eb6a40983ba3a408c2c44afded0d4b65359eb25397e4a39e2d5667f1329b413ff0efc8f04cc2484e17744aad0554e16d8b42c04688dd838a84b7dfacf53d629217ea404535c81b1fadb66e4dea5039001104d83935fb64bcd9d514dc364a893b445fad5eee8e43c035ef81badc6c02107cc13a29eac119b05fb645de8636e4fea7b240bab83b4a4a35f293c6898da1f9c8d2c79eedd1b395d20eec4779949d5cdf22f7e";
        assert_eq!(hex::encode(&client_proof_of_identity), expected_client_proof_of_identity);

        // --- Encryption with a fixed IV offset (normally random in `process`) ---
        // The expected ciphertext below begins with this offset.
        let iv_offset_hex = "ca682c217eaae46717e57df1b6f72693";
        let iv_offset_bytes = hex::decode(iv_offset_hex).unwrap();
        let encrypted_client_proof_of_identity = match encrypt_data(&session, &client_proof_of_identity, iv_offset_bytes.as_slice()) {
            Ok(encrypted_client_proof_of_identity) => encrypted_client_proof_of_identity,
            Err(e) => panic!("Failed to encrypt client proof of identity: {:?}", e),
        };

        println!("Encrypted Client Proof of Identity: {:?}", hex::encode(&encrypted_client_proof_of_identity));

        let expected_encrypted_client_proof_of_identity = "ca682c217eaae46717e57df1b6f72693d5a3126577a30fc4207887e98e5aa3ce802f4a8b3ed193ebe39882247d39ea0ce79d78d1aa04c3db94e40a4b020eb91e6f53e87ac466c6c72b90773d66202dd9d266daa54c6bfc3e2b2b0e11004e686c959fa89e98a1ff58c2f49a1b2ce07d8c2fdd48b97aaef4876763850bb3b0d566f168445dae4ec9018f277e650d794d8653a4336156f5fa5a3f99ee51c687e7f1b84d70cb8b8ebc8c2fa97d210e208850ba05b06cfb95ad99ccbda6687fc082643a437669f220b9c15782d92e6a498bc4d1ede69cfcc0d7c7badafa473cd5a6aab3413659b365f1365c960cf2553bf33db7a1e16a866583387bed40e9587b805d";
        assert_eq!(hex::encode(&encrypted_client_proof_of_identity), expected_encrypted_client_proof_of_identity);
    }
}