mls-rs-provider-sqlite 0.23.0

SQLite-based state storage for mls-rs
Documentation
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use mls_rs_core::{
    key_package::{KeyPackageData, KeyPackageStorage},
    mls_rs_codec::{MlsDecode, MlsEncode},
    time::MlsTime,
};
use rusqlite::{params, Connection, OptionalExtension};
use std::sync::{Arc, Mutex};

use crate::SqLiteDataStorageError;

/// Key package storage backed by the SQLite `key_package` table.
///
/// Cloning this type is cheap: every clone holds the same `Arc`, so all clones share (and
/// serialize access to) one underlying SQLite connection.
///
/// # Limitations
///
/// Expiration timestamps are stored as SQLite INTEGER (signed 64-bit), limiting
/// the maximum timestamp to [`i64::MAX`] (9,223,372,036,854,775,807). Operations
/// with timestamps exceeding this value will return [`SqLiteDataStorageError::TimestampOverflow`].
#[derive(Clone, Debug)]
pub struct SqLiteKeyPackageStorage {
    /// Connection shared between clones; every operation locks it for the duration of its query.
    connection: Arc<Mutex<Connection>>,
}

impl SqLiteKeyPackageStorage {
    pub(crate) fn new(connection: Connection) -> SqLiteKeyPackageStorage {
        SqLiteKeyPackageStorage {
            connection: Arc::new(Mutex::new(connection)),
        }
    }

    fn insert(
        &mut self,
        id: &[u8],
        key_package: KeyPackageData,
    ) -> Result<(), SqLiteDataStorageError> {
        let connection = self.connection.lock().unwrap();

        connection
            .execute(
                "INSERT INTO key_package (id, expiration, data) VALUES (?,?,?)",
                params![
                    id,
                    i64::try_from(key_package.expiration).map_err(|_| {
                        SqLiteDataStorageError::TimestampOverflow(key_package.expiration)
                    })?,
                    key_package
                        .mls_encode_to_vec()
                        .map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))?
                ],
            )
            .map(|_| ())
            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
    }

    fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, SqLiteDataStorageError> {
        let connection = self.connection.lock().unwrap();

        connection
            .query_row(
                "SELECT data FROM key_package WHERE id = ?",
                params![id],
                |row| {
                    Ok(
                        KeyPackageData::mls_decode(&mut row.get::<_, Vec<u8>>(0)?.as_slice())
                            .unwrap(),
                    )
                },
            )
            .optional()
            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
    }

    /// Delete a specific key package from storage based on it's id.
    pub fn delete(&self, id: &[u8]) -> Result<(), SqLiteDataStorageError> {
        let connection = self.connection.lock().unwrap();

        connection
            .execute("DELETE FROM key_package where id = ?", params![id])
            .map(|_| ())
            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
    }

    /// Delete key packages that are expired based on the current system clock time.
    pub fn delete_expired(&self) -> Result<(), SqLiteDataStorageError> {
        self.delete_expired_by_time(MlsTime::now().seconds_since_epoch())
    }

    /// Delete key packages that are expired based on an application provided time in seconds since
    /// unix epoch.
    pub fn delete_expired_by_time(&self, time: u64) -> Result<(), SqLiteDataStorageError> {
        let connection = self.connection.lock().unwrap();

        connection
            .execute(
                "DELETE FROM key_package where expiration < ?",
                params![i64::try_from(time)
                    .map_err(|_| SqLiteDataStorageError::TimestampOverflow(time))?],
            )
            .map(|_| ())
            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
    }

    /// Total number of key packages held in storage.
    pub fn count(&self) -> Result<usize, SqLiteDataStorageError> {
        let connection = self.connection.lock().unwrap();

        connection
            .query_row("SELECT count(*) FROM key_package", params![], |row| {
                row.get::<_, i64>(0).and_then(|v| {
                    usize::try_from(v).map_err(|_| rusqlite::Error::IntegralValueOutOfRange(0, v))
                })
            })
            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
    }

    /// Total number of key packages that will still remain in storage at a specific application provided
    /// time in seconds since unix epoch. This assumes that the application would also be calling
    /// [SqLiteKeyPackageStorage::delete_expired] at a reasonable cadence to be accurate.
    pub fn count_at_time(&self, time: u64) -> Result<usize, SqLiteDataStorageError> {
        let connection = self.connection.lock().unwrap();

        connection
            .query_row(
                "SELECT count(*) FROM key_package where expiration >= ?",
                params![i64::try_from(time)
                    .map_err(|_| SqLiteDataStorageError::TimestampOverflow(time))?],
                |row| {
                    row.get::<_, i64>(0).and_then(|v| {
                        usize::try_from(v)
                            .map_err(|_| rusqlite::Error::IntegralValueOutOfRange(0, v))
                    })
                },
            )
            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
    }
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
impl KeyPackageStorage for SqLiteKeyPackageStorage {
    type Error = SqLiteDataStorageError;

    async fn insert(&mut self, id: Vec<u8>, pkg: KeyPackageData) -> Result<(), Self::Error> {
        self.insert(id.as_slice(), pkg)
    }

    async fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, Self::Error> {
        self.get(id)
    }

    async fn delete(&mut self, id: &[u8]) -> Result<(), Self::Error> {
        (*self).delete(id)
    }
}

#[cfg(test)]
mod tests {
    use super::SqLiteKeyPackageStorage;
    use crate::{
        SqLiteDataStorageEngine, SqLiteDataStorageError,
        {connection_strategy::MemoryStrategy, test_utils::gen_rand_bytes},
    };
    use assert_matches::assert_matches;
    use mls_rs_core::{crypto::HpkeSecretKey, key_package::KeyPackageData};

    /// Fresh, empty in-memory key package storage.
    fn test_storage() -> SqLiteKeyPackageStorage {
        SqLiteDataStorageEngine::new(MemoryStrategy)
            .unwrap()
            .key_package_storage()
            .unwrap()
    }

    /// Random key package id and data with a fixed expiration of 123.
    fn test_key_package() -> (Vec<u8>, KeyPackageData) {
        let key_id = gen_rand_bytes(32);
        let key_package = KeyPackageData::new(
            gen_rand_bytes(256),
            HpkeSecretKey::from(gen_rand_bytes(256)),
            HpkeSecretKey::from(gen_rand_bytes(256)),
            123,
        );

        (key_id, key_package)
    }

    #[test]
    fn key_package_insert() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage
            .insert(&key_package_id, key_package.clone())
            .unwrap();

        let from_storage = storage.get(&key_package_id).unwrap().unwrap();
        assert_eq!(from_storage, key_package);
    }

    #[test]
    fn duplicate_insert_should_fail() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage
            .insert(&key_package_id, key_package.clone())
            .unwrap();

        let dupe_res = storage.insert(&key_package_id, key_package);

        assert_matches!(dupe_res, Err(SqLiteDataStorageError::SqlEngineError(_)));
    }

    #[test]
    fn key_package_not_found() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage.insert(&key_package_id, key_package).unwrap();

        let (another_package_id, _) = test_key_package();

        assert!(storage.get(&another_package_id).unwrap().is_none());
    }

    #[test]
    fn key_package_delete() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage.insert(&key_package_id, key_package).unwrap();

        storage.delete(&key_package_id).unwrap();
        assert!(storage.get(&key_package_id).unwrap().is_none());
    }

    #[test]
    fn expired_key_package_delete() {
        let mut storage = test_storage();

        let data = [1, 15, 30, 1698652376].map(|exp| {
            let mut kp = test_key_package();
            kp.1.expiration = exp;
            kp
        });

        for (id, data) in &data {
            storage.insert(id, data.clone()).unwrap();
        }

        // Expiry is strict: only packages expiring before 30 are removed.
        storage.delete_expired_by_time(30).unwrap();

        assert!(storage.get(&data[0].0).unwrap().is_none());
        assert!(storage.get(&data[1].0).unwrap().is_none());
        storage.get(&data[2].0).unwrap().unwrap();
        storage.get(&data[3].0).unwrap().unwrap();

        // The remaining expirations (30 and Oct 2023) are both in the past of the system clock.
        storage.delete_expired().unwrap();

        assert!(storage.get(&data[2].0).unwrap().is_none());
        assert!(storage.get(&data[3].0).unwrap().is_none());
    }

    #[test]
    fn key_count() {
        let mut storage = test_storage();

        for (key_package_id, key_package) in (0..10).map(|_| test_key_package()) {
            storage.insert(&key_package_id, key_package).unwrap();
        }

        assert_eq!(storage.count().unwrap(), 10);
    }

    #[test]
    fn key_count_at_time() {
        let mut storage = test_storage();

        let mut kp_1 = test_key_package();
        kp_1.1.expiration = 1;
        storage.insert(&kp_1.0, kp_1.1).unwrap();

        let mut kp_2 = test_key_package();
        kp_2.1.expiration = 2;
        storage.insert(&kp_2.0, kp_2.1).unwrap();

        assert_eq!(storage.count_at_time(3).unwrap(), 0);
        assert_eq!(storage.count_at_time(2).unwrap(), 1);
        assert_eq!(storage.count_at_time(1).unwrap(), 2);
        assert_eq!(storage.count_at_time(0).unwrap(), 2);
    }

    #[test]
    fn timestamp_overflow() {
        let mut storage = test_storage();
        let (id, mut kp) = test_key_package();
        kp.expiration = u64::MAX;

        let err = storage.insert(&id, kp).unwrap_err();
        assert_matches!(err, SqLiteDataStorageError::TimestampOverflow(u64::MAX));

        let err = storage.delete_expired_by_time(u64::MAX).unwrap_err();
        assert_matches!(err, SqLiteDataStorageError::TimestampOverflow(u64::MAX));

        let err = storage.count_at_time(u64::MAX).unwrap_err();
        assert_matches!(err, SqLiteDataStorageError::TimestampOverflow(u64::MAX));
    }
}