// File: mls_rs_provider_sqlite/psk.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

5use crate::SqLiteDataStorageError;
6use mls_rs_core::psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage};
7use rusqlite::{params, Connection, OptionalExtension};
8use std::{
9    ops::Deref,
10    sync::{Arc, Mutex},
11};
12
13#[derive(Debug, Clone)]
14/// SQLite storage for MLS pre-shared keys.
15pub struct SqLitePreSharedKeyStorage {
16    connection: Arc<Mutex<Connection>>,
17}
18
19impl SqLitePreSharedKeyStorage {
20    pub(crate) fn new(connection: Connection) -> SqLitePreSharedKeyStorage {
21        SqLitePreSharedKeyStorage {
22            connection: Arc::new(Mutex::new(connection)),
23        }
24    }
25
26    /// Insert a pre-shared key into storage.
27    pub fn insert(&self, psk_id: &[u8], psk: &PreSharedKey) -> Result<(), SqLiteDataStorageError> {
28        let connection = self.connection.lock().unwrap();
29
30        // Upsert into the database
31        connection
32            .execute(
33                "INSERT INTO psk (psk_id, data) VALUES (?,?) ON CONFLICT(psk_id) DO UPDATE SET data=excluded.data",
34                params![psk_id, psk.deref()],
35            )
36            .map(|_| ())
37            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
38    }
39
40    /// Get a pre-shared key from storage based on a unique id.
41    pub fn get(&self, psk_id: &[u8]) -> Result<Option<PreSharedKey>, SqLiteDataStorageError> {
42        let connection = self.connection.lock().unwrap();
43
44        connection
45            .query_row(
46                "SELECT data FROM psk WHERE psk_id = ?",
47                params![psk_id],
48                |row| Ok(PreSharedKey::new(row.get(0)?)),
49            )
50            .optional()
51            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
52    }
53
54    /// Delete a pre-shared key from storage based on a unique id.
55    pub fn delete(&self, psk_id: &[u8]) -> Result<(), SqLiteDataStorageError> {
56        let connection = self.connection.lock().unwrap();
57
58        connection
59            .execute("DELETE FROM psk WHERE psk_id = ?", params![psk_id])
60            .map(|_| ())
61            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
62    }
63}
64
65#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
66#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
67impl PreSharedKeyStorage for SqLitePreSharedKeyStorage {
68    type Error = SqLiteDataStorageError;
69
70    async fn get(&self, id: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
71        self.get(id)
72            .map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))
73    }
74}
75
#[cfg(test)]
mod tests {
    use mls_rs_core::psk::PreSharedKey;

    use crate::{
        SqLiteDataStorageEngine,
        {connection_strategy::MemoryStrategy, test_utils::gen_rand_bytes},
    };

    use super::SqLitePreSharedKeyStorage;

    /// Build a random PSK id (`gen_rand_bytes(32)`) paired with a random key (`gen_rand_bytes(64)`).
    fn test_psk() -> (Vec<u8>, PreSharedKey) {
        let psk_id = gen_rand_bytes(32);
        let stored_psk = PreSharedKey::new(gen_rand_bytes(64));

        (psk_id, stored_psk)
    }

    /// Open PSK storage through `SqLiteDataStorageEngine` with `MemoryStrategy`
    /// (presumably an in-memory database, so each call starts empty — confirm).
    fn test_storage() -> SqLitePreSharedKeyStorage {
        SqLiteDataStorageEngine::new(MemoryStrategy)
            .unwrap()
            .pre_shared_key_storage()
            .unwrap()
    }

    #[test]
    fn test_insert() {
        let (psk_id, psk) = test_psk();
        let storage = test_storage();

        storage.insert(&psk_id, &psk).unwrap();

        let from_storage = storage.get(&psk_id).unwrap().unwrap();
        assert_eq!(from_storage, psk);
    }

    #[test]
    fn test_insert_existing_overwrite() {
        let (psk_id, psk) = test_psk();
        let (_, new_psk) = test_psk();

        let storage = test_storage();

        storage.insert(&psk_id, &psk).unwrap();
        storage.insert(&psk_id, &new_psk).unwrap();

        let from_storage = storage.get(&psk_id).unwrap().unwrap();
        assert_eq!(from_storage, new_psk);
    }

    #[test]
    fn test_delete() {
        let (psk_id, psk) = test_psk();
        let storage = test_storage();

        storage.insert(&psk_id, &psk).unwrap();
        storage.delete(&psk_id).unwrap();

        assert!(storage.get(&psk_id).unwrap().is_none());
    }

    #[test]
    fn test_get_missing() {
        // A lookup for an id that was never inserted must be `Ok(None)`, not an error.
        let (psk_id, _) = test_psk();
        let storage = test_storage();

        assert!(storage.get(&psk_id).unwrap().is_none());
    }

    #[test]
    fn test_delete_missing_is_ok() {
        // Deleting an absent id affects zero rows, which `delete` reports as success.
        let (psk_id, _) = test_psk();
        let storage = test_storage();

        storage.delete(&psk_id).unwrap();
        assert!(storage.get(&psk_id).unwrap().is_none());
    }

    #[test]
    fn test_delete_leaves_other_keys() {
        let (psk_id, psk) = test_psk();
        let (other_id, other_psk) = test_psk();
        let storage = test_storage();

        storage.insert(&psk_id, &psk).unwrap();
        storage.insert(&other_id, &other_psk).unwrap();
        storage.delete(&psk_id).unwrap();

        assert!(storage.get(&psk_id).unwrap().is_none());
        assert_eq!(storage.get(&other_id).unwrap().unwrap(), other_psk);
    }
}