// quadra-a-core 0.1.0-beta.1
//
// Core protocol primitives for the quadra-a agent network.
// Documentation
use anyhow::{Context, Result};
use serde_json::Value;

use crate::e2e::cbor_x::{encode, E2eCborValue};
use crate::e2e::crypto::{decrypt_xchacha20poly1305, encrypt_xchacha20poly1305};
use crate::e2e::types::{E2EMessageType, PreKeyMessage, SessionMessage, E2E_PROTOCOL_VERSION};
use crate::protocol::cbor_decode_value;

/// Looks up `field` in `value` and returns it as an owned `String`.
///
/// # Errors
///
/// Fails with `Missing string field <field>` when the key is absent or the
/// entry is not a string.
fn get_str(value: &Value, field: &str) -> Result<String> {
    let text: Option<&str> = value.get(field).and_then(|entry| entry.as_str());
    text.map(str::to_owned)
        .with_context(|| format!("Missing string field {}", field))
}

/// Looks up `field` in `value` and returns it as a `u64`.
///
/// # Errors
///
/// Fails with `Missing integer field <field>` when the key is absent or the
/// entry cannot be read as an unsigned 64-bit integer.
fn get_u64(value: &Value, field: &str) -> Result<u64> {
    let number: Option<u64> = value.get(field).and_then(|entry| entry.as_u64());
    number.with_context(|| format!("Missing integer field {}", field))
}

/// Looks up `field` in `value` and converts an array of integers into bytes.
///
/// # Errors
///
/// * `Missing byte-array field <field>` when the key is absent or the entry is
///   not an array.
/// * `Field <field> contains non-byte value` when any element is not an
///   unsigned integer in `0..=255`.
fn get_bytes(value: &Value, field: &str) -> Result<Vec<u8>> {
    let items = value
        .get(field)
        .and_then(|entry| entry.as_array())
        .with_context(|| format!("Missing byte-array field {}", field))?;

    let mut bytes = Vec::with_capacity(items.len());
    for item in items {
        // The range check makes the narrowing cast below lossless.
        let byte = item
            .as_u64()
            .and_then(|raw| if raw <= 255 { Some(raw as u8) } else { None })
            .with_context(|| format!("Field {} contains non-byte value", field))?;
        bytes.push(byte);
    }
    Ok(bytes)
}

/// Lists the wire fields of a `PreKeyMessage` as ordered key/value pairs.
///
/// The entries are emitted in a fixed order. `recipientOneTimePreKeyId` is
/// included only when the message carries one, and `ciphertext` is appended
/// last only when `include_ciphertext` is `true`.
fn build_prekey_map(
    message: &PreKeyMessage,
    include_ciphertext: bool,
) -> Vec<(&str, E2eCborValue<'_>)> {
    // Upper bound on the number of entries: 11 fixed + optional id + nonce + ciphertext.
    let mut entries = Vec::with_capacity(14);

    entries.push(("version", E2eCborValue::Int(i64::from(message.version))));
    entries.push(("type", E2eCborValue::Str("PREKEY_MESSAGE")));
    entries.push(("senderDid", E2eCborValue::Str(&message.sender_did)));
    entries.push(("receiverDid", E2eCborValue::Str(&message.receiver_did)));
    entries.push((
        "senderDeviceId",
        E2eCborValue::Str(&message.sender_device_id),
    ));
    entries.push((
        "receiverDeviceId",
        E2eCborValue::Str(&message.receiver_device_id),
    ));
    entries.push(("sessionId", E2eCborValue::Str(&message.session_id)));
    entries.push(("messageId", E2eCborValue::Str(&message.message_id)));
    entries.push((
        "initiatorIdentityKey",
        E2eCborValue::TypedBytes(&message.initiator_identity_key),
    ));
    entries.push((
        "initiatorEphemeralKey",
        E2eCborValue::TypedBytes(&message.initiator_ephemeral_key),
    ));
    entries.push((
        "recipientSignedPreKeyId",
        E2eCborValue::Int(i64::from(message.recipient_signed_pre_key_id)),
    ));

    // `extend` over an `Option` pushes zero or one entry.
    entries.extend(message.recipient_one_time_pre_key_id.map(|id| {
        (
            "recipientOneTimePreKeyId",
            E2eCborValue::Int(i64::from(id)),
        )
    }));

    entries.push(("nonce", E2eCborValue::TypedBytes(&message.nonce)));

    if include_ciphertext {
        entries.push(("ciphertext", E2eCborValue::TypedBytes(&message.ciphertext)));
    }

    entries
}

/// Lists the wire fields of a `SessionMessage` as ordered key/value pairs.
///
/// The entries are emitted in a fixed order. `ciphertext` is appended last
/// only when `include_ciphertext` is `true`.
fn build_session_map(
    message: &SessionMessage,
    include_ciphertext: bool,
) -> Vec<(&str, E2eCborValue<'_>)> {
    // Upper bound on the number of entries: 12 fixed + ciphertext.
    let mut entries = Vec::with_capacity(13);

    entries.push(("version", E2eCborValue::Int(i64::from(message.version))));
    entries.push(("type", E2eCborValue::Str("SESSION_MESSAGE")));
    entries.push(("senderDid", E2eCborValue::Str(&message.sender_did)));
    entries.push(("receiverDid", E2eCborValue::Str(&message.receiver_did)));
    entries.push((
        "senderDeviceId",
        E2eCborValue::Str(&message.sender_device_id),
    ));
    entries.push((
        "receiverDeviceId",
        E2eCborValue::Str(&message.receiver_device_id),
    ));
    entries.push(("sessionId", E2eCborValue::Str(&message.session_id)));
    entries.push(("messageId", E2eCborValue::Str(&message.message_id)));
    entries.push((
        "ratchetPublicKey",
        E2eCborValue::TypedBytes(&message.ratchet_public_key),
    ));
    entries.push((
        "previousChainLength",
        E2eCborValue::Int(i64::from(message.previous_chain_length)),
    ));
    entries.push((
        "messageNumber",
        E2eCborValue::Int(i64::from(message.message_number)),
    ));
    entries.push(("nonce", E2eCborValue::TypedBytes(&message.nonce)));

    if include_ciphertext {
        entries.push(("ciphertext", E2eCborValue::TypedBytes(&message.ciphertext)));
    }

    entries
}

/// Encodes every `PreKeyMessage` field except `ciphertext` into a byte string.
///
/// `encrypt_prekey_message` and `decrypt_prekey_message` pass these bytes to
/// the cipher next to the key, nonce and payload.
///
/// The current implementation always returns `Ok`.
pub fn build_prekey_message_associated_data(message: &PreKeyMessage) -> Result<Vec<u8>> {
    let header_fields = build_prekey_map(message, false);
    let associated_data = encode(&E2eCborValue::Map(header_fields));
    Ok(associated_data)
}

/// Encodes every `SessionMessage` field except `ciphertext` into a byte string.
///
/// `encrypt_session_message` and `decrypt_session_message` pass these bytes to
/// the cipher next to the key, nonce and payload.
///
/// The current implementation always returns `Ok`.
pub fn build_session_message_associated_data(message: &SessionMessage) -> Result<Vec<u8>> {
    let header_fields = build_session_map(message, false);
    let associated_data = encode(&E2eCborValue::Map(header_fields));
    Ok(associated_data)
}

/// Returns a copy of `message` whose `ciphertext` holds the encrypted `plaintext`.
///
/// The message's own `nonce` is used, and the associated data is derived from
/// the message fields via `build_prekey_message_associated_data`. Any
/// `ciphertext` already present on `message` is replaced in the returned copy;
/// `message` itself is not modified.
///
/// # Errors
///
/// Propagates failures from building the associated data or from
/// `encrypt_xchacha20poly1305`.
pub fn encrypt_prekey_message(
    message: &PreKeyMessage,
    key: &[u8],
    plaintext: &[u8],
) -> Result<PreKeyMessage> {
    let associated_data = build_prekey_message_associated_data(message)?;
    let ciphertext =
        encrypt_xchacha20poly1305(key, &message.nonce, plaintext, &associated_data)?;

    let mut sealed = message.clone();
    sealed.ciphertext = ciphertext;
    Ok(sealed)
}

/// Decrypts `message.ciphertext` with `key` and the message's own `nonce`.
///
/// The associated data is rebuilt from the message fields via
/// `build_prekey_message_associated_data`.
///
/// # Errors
///
/// Propagates failures from building the associated data or from
/// `decrypt_xchacha20poly1305`.
pub fn decrypt_prekey_message(message: &PreKeyMessage, key: &[u8]) -> Result<Vec<u8>> {
    let associated_data = build_prekey_message_associated_data(message)?;
    decrypt_xchacha20poly1305(key, &message.nonce, &message.ciphertext, &associated_data)
}

/// Returns a copy of `message` whose `ciphertext` holds the encrypted `plaintext`.
///
/// The message's own `nonce` is used, and the associated data is derived from
/// the message fields via `build_session_message_associated_data`. Any
/// `ciphertext` already present on `message` is replaced in the returned copy;
/// `message` itself is not modified.
///
/// # Errors
///
/// Propagates failures from building the associated data or from
/// `encrypt_xchacha20poly1305`.
pub fn encrypt_session_message(
    message: &SessionMessage,
    key: &[u8],
    plaintext: &[u8],
) -> Result<SessionMessage> {
    let associated_data = build_session_message_associated_data(message)?;
    let ciphertext =
        encrypt_xchacha20poly1305(key, &message.nonce, plaintext, &associated_data)?;

    let mut sealed = message.clone();
    sealed.ciphertext = ciphertext;
    Ok(sealed)
}

/// Decrypts `message.ciphertext` with `key` and the message's own `nonce`.
///
/// The associated data is rebuilt from the message fields via
/// `build_session_message_associated_data`.
///
/// # Errors
///
/// Propagates failures from building the associated data or from
/// `decrypt_xchacha20poly1305`.
pub fn decrypt_session_message(message: &SessionMessage, key: &[u8]) -> Result<Vec<u8>> {
    let associated_data = build_session_message_associated_data(message)?;
    decrypt_xchacha20poly1305(key, &message.nonce, &message.ciphertext, &associated_data)
}

/// Serializes a complete `PreKeyMessage`, ciphertext included, to bytes.
///
/// # Errors
///
/// Fails with `Unsupported PREKEY_MESSAGE version: <n>` when `message.version`
/// differs from `E2E_PROTOCOL_VERSION`.
pub fn encode_prekey_message(message: &PreKeyMessage) -> Result<Vec<u8>> {
    let version = message.version;
    if version == E2E_PROTOCOL_VERSION {
        let fields = build_prekey_map(message, true);
        return Ok(encode(&E2eCborValue::Map(fields)));
    }
    anyhow::bail!("Unsupported PREKEY_MESSAGE version: {}", version)
}

pub fn decode_prekey_message(bytes: &[u8]) -> Result<PreKeyMessage> {
    let value = cbor_decode_value(bytes).context("Failed to decode PREKEY_MESSAGE")?;
    if value.get("version").and_then(Value::as_u64) != Some(E2E_PROTOCOL_VERSION as u64)
        || value.get("type").and_then(Value::as_str) != Some("PREKEY_MESSAGE")
    {
        anyhow::bail!("Invalid PREKEY_MESSAGE payload");
    }

    Ok(PreKeyMessage {
        version: get_u64(&value, "version")? as u8,
        message_type: E2EMessageType::PreKeyMessage,
        sender_did: get_str(&value, "senderDid")?,
        receiver_did: get_str(&value, "receiverDid")?,
        sender_device_id: get_str(&value, "senderDeviceId")?,
        receiver_device_id: get_str(&value, "receiverDeviceId")?,
        session_id: get_str(&value, "sessionId")?,
        message_id: get_str(&value, "messageId")?,
        initiator_identity_key: get_bytes(&value, "initiatorIdentityKey")?,
        initiator_ephemeral_key: get_bytes(&value, "initiatorEphemeralKey")?,
        recipient_signed_pre_key_id: get_u64(&value, "recipientSignedPreKeyId")? as u32,
        recipient_one_time_pre_key_id: value
            .get("recipientOneTimePreKeyId")
            .and_then(Value::as_u64)
            .map(|value| value as u32),
        nonce: get_bytes(&value, "nonce")?,
        ciphertext: get_bytes(&value, "ciphertext")?,
    })
}

/// Serializes a complete `SessionMessage`, ciphertext included, to bytes.
///
/// # Errors
///
/// Fails with `Unsupported SESSION_MESSAGE version: <n>` when `message.version`
/// differs from `E2E_PROTOCOL_VERSION`.
pub fn encode_session_message(message: &SessionMessage) -> Result<Vec<u8>> {
    let version = message.version;
    if version == E2E_PROTOCOL_VERSION {
        let fields = build_session_map(message, true);
        return Ok(encode(&E2eCborValue::Map(fields)));
    }
    anyhow::bail!("Unsupported SESSION_MESSAGE version: {}", version)
}

pub fn decode_session_message(bytes: &[u8]) -> Result<SessionMessage> {
    let value = cbor_decode_value(bytes).context("Failed to decode SESSION_MESSAGE")?;
    if value.get("version").and_then(Value::as_u64) != Some(E2E_PROTOCOL_VERSION as u64)
        || value.get("type").and_then(Value::as_str) != Some("SESSION_MESSAGE")
    {
        anyhow::bail!("Invalid SESSION_MESSAGE payload");
    }

    Ok(SessionMessage {
        version: get_u64(&value, "version")? as u8,
        message_type: E2EMessageType::SessionMessage,
        sender_did: get_str(&value, "senderDid")?,
        receiver_did: get_str(&value, "receiverDid")?,
        sender_device_id: get_str(&value, "senderDeviceId")?,
        receiver_device_id: get_str(&value, "receiverDeviceId")?,
        session_id: get_str(&value, "sessionId")?,
        message_id: get_str(&value, "messageId")?,
        ratchet_public_key: get_bytes(&value, "ratchetPublicKey")?,
        previous_chain_length: get_u64(&value, "previousChainLength")? as u32,
        message_number: get_u64(&value, "messageNumber")? as u32,
        nonce: get_bytes(&value, "nonce")?,
        ciphertext: get_bytes(&value, "ciphertext")?,
    })
}