libsignal-rust 0.1.0

A complete, full-featured Rust implementation of the libsignal protocol.
Documentation
use crate::{
    base_key_type::BaseKeyType,
    chain_type::ChainType,
    session_record::{SessionRecord, SessionEntry, CurrentRatchet, IndexInfo, PendingPreKey, ChainInfo, ChainKey},
    crypto,
    curve::{self, KeyPair},
    errors::{UntrustedIdentityKeyError, PreKeyError},
    protocol_address::ProtocolAddress,
    queue_job::queue_job,
};
use std::sync::Arc;

/// Storage and trust hooks that [`SessionBuilder`] relies on.
///
/// Every method returns a future that is required to be `Send`.
///
/// NOTE(review): `SessionBuilder::init_outgoing` passes the bare address id to
/// `is_trusted_identity`, while `init_incoming` passes the full `ProtocolAddress`
/// string; implementations may need to handle both forms — confirm intent.
pub trait SessionStorage: Send + Sync {
    /// Resolves to `true` when `identity_key` is accepted for `address`.
    fn is_trusted_identity(
        &self,
        address: &str,
        identity_key: &[u8],
    ) -> impl std::future::Future<Output = bool> + Send;

    /// Resolves to the stored session record for `address`, or `None` if absent.
    fn load_session(
        &self,
        address: &str,
    ) -> impl std::future::Future<Output = Option<SessionRecord>> + Send;

    /// Persists `record` under `address`.
    fn store_session(
        &self,
        address: &str,
        record: SessionRecord,
    ) -> impl std::future::Future<Output = ()> + Send;

    /// Resolves to the one-time pre-key pair with id `pre_key_id`, or `None` if unknown.
    fn load_pre_key(
        &self,
        pre_key_id: u32,
    ) -> impl std::future::Future<Output = Option<KeyPair>> + Send;

    /// Resolves to the signed pre-key pair with id `signed_pre_key_id`, or `None` if unknown.
    fn load_signed_pre_key(
        &self,
        signed_pre_key_id: u32,
    ) -> impl std::future::Future<Output = Option<KeyPair>> + Send;

    /// Resolves to the local identity key pair (its private half is used in key agreement).
    fn get_our_identity(&self) -> impl std::future::Future<Output = KeyPair> + Send;
}

/// A remote device's pre-key bundle, consumed by `SessionBuilder::init_outgoing`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Device {
    /// Remote registration id, stored on the resulting session.
    pub registration_id: u32,
    /// Remote identity public key; checked for trust and used to verify
    /// `signed_pre_key.signature`.
    pub identity_key: Vec<u8>,
    /// The remote's signed pre-key (always required).
    pub signed_pre_key: SignedPreKeyBundle,
    /// Optional one-time pre-key; when present it contributes an extra agreement.
    pub pre_key: Option<PreKeyBundle>,
}

/// A signed pre-key as published by a remote device.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignedPreKeyBundle {
    /// Identifier the remote uses to look this key up again.
    pub key_id: u32,
    /// The signed pre-key's public half.
    pub public_key: Vec<u8>,
    /// Signature over `public_key`, made with the device's identity key.
    pub signature: Vec<u8>,
}

/// A one-time pre-key as published by a remote device.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreKeyBundle {
    /// Identifier the remote uses to look this key up again.
    pub key_id: u32,
    /// The one-time pre-key's public half.
    pub public_key: Vec<u8>,
}

/// The session-setup fields of an incoming pre-key message, consumed by
/// `SessionBuilder::init_incoming`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreKeyWhisperMessage {
    /// Sender's registration id, stored on the resulting session.
    pub registration_id: u32,
    /// Id of our one-time pre-key the sender used, if any.
    pub pre_key_id: Option<u32>,
    /// Id of our signed pre-key the sender used.
    pub signed_pre_key_id: u32,
    /// Sender's base public key; identifies the session in the record.
    pub base_key: Vec<u8>,
    /// Sender's identity public key.
    pub identity_key: Vec<u8>,
    /// Inner message bytes (not read by `SessionBuilder`).
    pub message: Vec<u8>,
}

/// Builds Signal sessions with one remote [`ProtocolAddress`].
///
/// The storage handle is reference-counted so it can be moved into queued jobs.
pub struct SessionBuilder<T>
where
    T: SessionStorage,
{
    /// Backend for identities, pre-keys and session records.
    storage: Arc<T>,
    /// The remote peer sessions are built for.
    addr: ProtocolAddress,
}

impl<T: SessionStorage + 'static> SessionBuilder<T> {
    pub fn new(storage: Arc<T>, protocol_address: ProtocolAddress) -> Self {
        Self {
            addr: protocol_address,
            storage,
        }
    }

    pub async fn init_outgoing(&self, device: Device) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let storage = self.storage.clone();
        let addr = self.addr.clone();
        
        queue_job(addr.to_string(), async move {
            if !storage.is_trusted_identity(&addr.id, &device.identity_key).await {
                return Err(Box::new(UntrustedIdentityKeyError::new(addr.id.clone(), device.identity_key)) as Box<dyn std::error::Error + Send + Sync>);
            }

            curve::verify_signature(&device.identity_key, &device.signed_pre_key.public_key, &device.signed_pre_key.signature)?;

            let base_key = curve::generate_key_pair();
            let device_pre_key = device.pre_key.as_ref().map(|pk| &pk.public_key);

            let session = Self::static_init_session(
                storage.clone(),
                true,
                Some(&base_key),
                None,
                &device.identity_key,
                device_pre_key.map(|v| &**v),
                Some(&device.signed_pre_key.public_key),
                device.registration_id,
            ).await?;

            let mut session_mut = session.clone();
            session_mut.pending_pre_key = Some(PendingPreKey {
                signed_key_id: device.signed_pre_key.key_id,
                base_key: base_key.pub_key.clone(),
                pre_key_id: device.pre_key.map(|pk| pk.key_id),
            });

            let mut record = storage.load_session(&addr.to_string()).await.unwrap_or_else(|| SessionRecord::new());
            
            if let Some(open_session) = record.get_open_session() {
                let base_key = open_session.index_info.base_key.clone();
                record.close_session(&base_key);
            }

            record.set_session(session_mut);
            storage.store_session(&addr.to_string(), record).await;
            Ok(())
        }).await
    }

    pub async fn init_incoming(&self, record: &mut SessionRecord, message: &PreKeyWhisperMessage) -> Result<Option<u32>, Box<dyn std::error::Error + Send + Sync>> {
        let fq_addr = self.addr.to_string();
        
        if !self.storage.is_trusted_identity(&fq_addr, &message.identity_key).await {
            return Err(Box::new(UntrustedIdentityKeyError::new(self.addr.id.clone(), message.identity_key.clone())));
        }

        if record.get_session(&message.base_key).is_some() {
            return Ok(None);
        }

        let pre_key_pair = if let Some(pre_key_id) = message.pre_key_id {
            self.storage.load_pre_key(pre_key_id).await
        } else {
            None
        };

        if message.pre_key_id.is_some() && pre_key_pair.is_none() {
            return Err(Box::new(PreKeyError::new("Invalid PreKey ID")));
        }

        let signed_pre_key_pair = self.storage.load_signed_pre_key(message.signed_pre_key_id).await
            .ok_or_else(|| PreKeyError::new("Missing SignedPreKey"))?;

        if let Some(open_session) = record.get_open_session() {
            let base_key = open_session.index_info.base_key.clone();
            record.close_session(&base_key);
        }

        let session = Self::static_init_session(
            self.storage.clone(),
            false,
            pre_key_pair.as_ref(),
            Some(&signed_pre_key_pair),
            &message.identity_key,
            Some(&message.base_key),
            None,
            message.registration_id,
        ).await?;

        record.set_session(session);
        Ok(message.pre_key_id)
    }

    async fn static_init_session<S: SessionStorage>(
        storage: Arc<S>,
        is_initiator: bool,
        our_ephemeral_key: Option<&KeyPair>,
        our_signed_key: Option<&KeyPair>,
        their_identity_pub_key: &[u8],
        their_ephemeral_pub_key: Option<&[u8]>,
        their_signed_pub_key: Option<&[u8]>,
        registration_id: u32,
    ) -> Result<SessionEntry, Box<dyn std::error::Error + Send + Sync>> {
        let our_signed_key = if is_initiator {
            our_ephemeral_key.unwrap()
        } else {
            our_signed_key.unwrap()
        };

        let their_signed_pub_key = if is_initiator {
            their_signed_pub_key.unwrap()
        } else {
            their_ephemeral_pub_key.unwrap()
        };

        let shared_secret_len = if our_ephemeral_key.is_none() || their_ephemeral_pub_key.is_none() {
            32 * 4
        } else {
            32 * 5
        };

        let mut shared_secret = vec![0xffu8; 32];
        shared_secret.resize(shared_secret_len, 0);

        let our_identity = storage.get_our_identity().await;
        let a1 = curve::calculate_agreement(their_signed_pub_key, &our_identity.priv_key)?;
        let a2 = curve::calculate_agreement(their_identity_pub_key, &our_signed_key.priv_key)?;
        let a3 = curve::calculate_agreement(their_signed_pub_key, &our_signed_key.priv_key)?;

        if is_initiator {
            shared_secret[32..64].copy_from_slice(&a1);
            shared_secret[64..96].copy_from_slice(&a2);
        } else {
            shared_secret[64..96].copy_from_slice(&a1);
            shared_secret[32..64].copy_from_slice(&a2);
        }
        shared_secret[96..128].copy_from_slice(&a3);

        if let (Some(our_eph), Some(their_eph)) = (our_ephemeral_key, their_ephemeral_pub_key) {
            let a4 = curve::calculate_agreement(their_eph, &our_eph.priv_key)?;
            shared_secret[128..160].copy_from_slice(&a4);
        }

        let master_key = crypto::derive_secrets(&shared_secret, &[0u8; 32], b"WhisperText", None)?;

        let mut session = SessionEntry::new();
        session.registration_id = registration_id;
        session.current_ratchet = CurrentRatchet {
            root_key: master_key[0].clone(),
            ephemeral_key_pair: if is_initiator { 
                curve::generate_key_pair() 
            } else { 
                our_signed_key.clone() 
            },
            last_remote_ephemeral_key: their_signed_pub_key.to_vec(),
            previous_counter: 0,
        };

        session.index_info = IndexInfo {
            created: chrono::Utc::now().timestamp() as u64,
            used: chrono::Utc::now().timestamp() as u64,
            remote_identity_key: their_identity_pub_key.to_vec(),
            base_key: if is_initiator { 
                our_ephemeral_key.unwrap().pub_key.clone() 
            } else { 
                their_ephemeral_pub_key.unwrap().to_vec() 
            },
            base_key_type: if is_initiator { BaseKeyType::Ours } else { BaseKeyType::Theirs },
            closed: -1,
        };

        if is_initiator {
            let ephemeral_pub_key = session.current_ratchet.ephemeral_key_pair.pub_key.clone();
        session.add_chain(&ephemeral_pub_key, ChainInfo {
                message_keys: Default::default(),
                chain_key: ChainKey {
                    counter: -1,
                    key: Some(master_key[1].clone()),
                },
                chain_type: ChainType::Sending,
            })?;
        }

        Ok(session)
    }
}