//! `matrix_bot_sdk/storage.rs`: key/value storage providers (in-memory, JSON file,
//! PostgreSQL) and the storage traits they implement.

use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use anyhow::Context;
use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::RwLock;
8
9#[async_trait]
10pub trait IStorageProvider: Send + Sync {
11    async fn get(&self, key: &str) -> anyhow::Result<Option<Value>>;
12    async fn set(&self, key: &str, value: Value) -> anyhow::Result<()>;
13    async fn delete(&self, key: &str) -> anyhow::Result<()>;
14    async fn keys(&self) -> anyhow::Result<Vec<String>>;
15}
16
17#[async_trait]
18pub trait IAppserviceStorageProvider: IStorageProvider {
19    async fn register_user(&self, user_id: &str) -> anyhow::Result<()>;
20    async fn is_user_registered(&self, user_id: &str) -> anyhow::Result<bool>;
21}
22
23#[async_trait]
24pub trait ICryptoStorageProvider: IStorageProvider {
25    async fn store_session_key(&self, room_id: &str, session_key: &str) -> anyhow::Result<()>;
26    async fn session_key(&self, room_id: &str) -> anyhow::Result<Option<String>>;
27}
28
29#[derive(Debug, Default)]
30pub struct MemoryStorageProvider {
31    entries: RwLock<HashMap<String, Value>>,
32}
33
34impl MemoryStorageProvider {
35    pub fn new() -> Self {
36        Self::default()
37    }
38}
39
40#[async_trait]
41impl IStorageProvider for MemoryStorageProvider {
42    async fn get(&self, key: &str) -> anyhow::Result<Option<Value>> {
43        Ok(self.entries.read().await.get(key).cloned())
44    }
45
46    async fn set(&self, key: &str, value: Value) -> anyhow::Result<()> {
47        self.entries.write().await.insert(key.to_owned(), value);
48        Ok(())
49    }
50
51    async fn delete(&self, key: &str) -> anyhow::Result<()> {
52        self.entries.write().await.remove(key);
53        Ok(())
54    }
55
56    async fn keys(&self) -> anyhow::Result<Vec<String>> {
57        Ok(self.entries.read().await.keys().cloned().collect())
58    }
59}
60
61#[async_trait]
62impl IAppserviceStorageProvider for MemoryStorageProvider {
63    async fn register_user(&self, user_id: &str) -> anyhow::Result<()> {
64        self.set(
65            &format!("appservice:registered:{user_id}"),
66            Value::Bool(true),
67        )
68        .await
69    }
70
71    async fn is_user_registered(&self, user_id: &str) -> anyhow::Result<bool> {
72        Ok(self
73            .get(&format!("appservice:registered:{user_id}"))
74            .await?
75            .and_then(|v| v.as_bool())
76            .unwrap_or(false))
77    }
78}
79
80#[async_trait]
81impl ICryptoStorageProvider for MemoryStorageProvider {
82    async fn store_session_key(&self, room_id: &str, session_key: &str) -> anyhow::Result<()> {
83        self.set(
84            &format!("crypto:session:{room_id}"),
85            Value::String(session_key.to_owned()),
86        )
87        .await
88    }
89
90    async fn session_key(&self, room_id: &str) -> anyhow::Result<Option<String>> {
91        Ok(self
92            .get(&format!("crypto:session:{room_id}"))
93            .await?
94            .and_then(|v| v.as_str().map(ToOwned::to_owned)))
95    }
96}
97
98#[derive(Debug)]
99pub struct SimpleFsStorageProvider {
100    file_path: PathBuf,
101    entries: RwLock<HashMap<String, Value>>,
102}
103
104impl SimpleFsStorageProvider {
105    pub async fn new(path: impl AsRef<Path>) -> anyhow::Result<Self> {
106        let file_path = path.as_ref().to_path_buf();
107        let entries = if tokio::fs::try_exists(&file_path).await? {
108            let content = tokio::fs::read_to_string(&file_path).await?;
109            serde_json::from_str::<HashMap<String, Value>>(&content)?
110        } else {
111            HashMap::new()
112        };
113        Ok(Self {
114            file_path,
115            entries: RwLock::new(entries),
116        })
117    }
118
119    async fn flush(&self) -> anyhow::Result<()> {
120        let serialized = serde_json::to_string_pretty(&*self.entries.read().await)?;
121        tokio::fs::write(&self.file_path, serialized).await?;
122        Ok(())
123    }
124}
125
126#[async_trait]
127impl IStorageProvider for SimpleFsStorageProvider {
128    async fn get(&self, key: &str) -> anyhow::Result<Option<Value>> {
129        Ok(self.entries.read().await.get(key).cloned())
130    }
131
132    async fn set(&self, key: &str, value: Value) -> anyhow::Result<()> {
133        self.entries.write().await.insert(key.to_owned(), value);
134        self.flush().await
135    }
136
137    async fn delete(&self, key: &str) -> anyhow::Result<()> {
138        self.entries.write().await.remove(key);
139        self.flush().await
140    }
141
142    async fn keys(&self) -> anyhow::Result<Vec<String>> {
143        Ok(self.entries.read().await.keys().cloned().collect())
144    }
145}
146
147#[async_trait]
148impl IAppserviceStorageProvider for SimpleFsStorageProvider {
149    async fn register_user(&self, user_id: &str) -> anyhow::Result<()> {
150        self.set(
151            &format!("appservice:registered:{user_id}"),
152            Value::Bool(true),
153        )
154        .await
155    }
156
157    async fn is_user_registered(&self, user_id: &str) -> anyhow::Result<bool> {
158        Ok(self
159            .get(&format!("appservice:registered:{user_id}"))
160            .await?
161            .and_then(|v| v.as_bool())
162            .unwrap_or(false))
163    }
164}
165
166#[async_trait]
167impl ICryptoStorageProvider for SimpleFsStorageProvider {
168    async fn store_session_key(&self, room_id: &str, session_key: &str) -> anyhow::Result<()> {
169        self.set(
170            &format!("crypto:session:{room_id}"),
171            Value::String(session_key.to_owned()),
172        )
173        .await
174    }
175
176    async fn session_key(&self, room_id: &str) -> anyhow::Result<Option<String>> {
177        Ok(self
178            .get(&format!("crypto:session:{room_id}"))
179            .await?
180            .and_then(|v| v.as_str().map(ToOwned::to_owned)))
181    }
182}
183
/// PostgreSQL-backed storage provider using a single `matrix_sdk_store` table
/// (`key TEXT PRIMARY KEY, value JSONB NOT NULL`).
#[cfg(feature = "storage-postgres")]
#[derive(Debug)]
pub struct SimplePostgresStorageProvider {
    /// Client whose connection task is spawned by `connect`.
    client: tokio_postgres::Client,
}

#[cfg(feature = "storage-postgres")]
impl SimplePostgresStorageProvider {
    /// Connects to `conn_str` without TLS and ensures the store table exists.
    ///
    /// Must be awaited inside a Tokio runtime because it calls `tokio::spawn`.
    /// NOTE: `CREATE TABLE IF NOT EXISTS` does not validate the schema of an
    /// already-existing table.
    pub async fn connect(conn_str: &str) -> anyhow::Result<Self> {
        let (client, connection) = tokio_postgres::connect(conn_str, tokio_postgres::NoTls).await?;

        // The connection future performs the actual socket I/O and has to be
        // polled for `client` to make progress; failures are only logged.
        let driver = async move {
            if let Err(error) = connection.await {
                tracing::error!("postgres connection error: {error}");
            }
        };
        tokio::spawn(driver);

        let create_table =
            "CREATE TABLE IF NOT EXISTS matrix_sdk_store (key TEXT PRIMARY KEY, value JSONB NOT NULL)";
        client.execute(create_table, &[]).await?;

        Ok(Self { client })
    }
}
209}
210
#[cfg(feature = "storage-postgres")]
#[async_trait]
impl IStorageProvider for SimplePostgresStorageProvider {
    /// Fetches the JSONB `value` stored for `key`, if any.
    ///
    /// # Errors
    ///
    /// Fails if the query fails or the column cannot be decoded as JSON.
    async fn get(&self, key: &str) -> anyhow::Result<Option<Value>> {
        let row = self
            .client
            .query_opt("SELECT value FROM matrix_sdk_store WHERE key = $1", &[&key])
            .await?;
        // `try_get` instead of `get`: the table is only created `IF NOT EXISTS`,
        // so a pre-existing table with another column type must surface as an
        // error rather than a panic inside the library.
        let value = row.map(|r| r.try_get::<_, Value>(0)).transpose()?;
        Ok(value)
    }

    /// Upserts `value` for `key` via `INSERT ... ON CONFLICT (key) DO UPDATE`.
    async fn set(&self, key: &str, value: Value) -> anyhow::Result<()> {
        self.client
            .execute(
                "INSERT INTO matrix_sdk_store (key, value) VALUES ($1, $2) \
                 ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value",
                &[&key, &value],
            )
            .await?;
        Ok(())
    }

    /// Deletes the row for `key`; deleting a missing key affects zero rows and succeeds.
    async fn delete(&self, key: &str) -> anyhow::Result<()> {
        self.client
            .execute("DELETE FROM matrix_sdk_store WHERE key = $1", &[&key])
            .await?;
        Ok(())
    }

    /// Lists all keys sorted by the database's ordering for `key`.
    ///
    /// # Errors
    ///
    /// Fails if the query fails or any `key` cannot be decoded as text.
    async fn keys(&self) -> anyhow::Result<Vec<String>> {
        let rows = self
            .client
            .query("SELECT key FROM matrix_sdk_store ORDER BY key", &[])
            .await?;
        // Decode fallibly for the same reason as in `get`.
        let keys = rows
            .iter()
            .map(|row| row.try_get::<_, String>(0))
            .collect::<Result<Vec<_>, _>>()?;
        Ok(keys)
    }
}
248
249#[cfg(not(feature = "storage-postgres"))]
250#[derive(Debug, Default)]
251pub struct SimplePostgresStorageProvider;
252
253#[cfg(not(feature = "storage-postgres"))]
254impl SimplePostgresStorageProvider {
255    pub async fn connect(_conn_str: &str) -> anyhow::Result<Self> {
256        anyhow::bail!("feature 'storage-postgres' is not enabled");
257    }
258}
259
260#[derive(Clone)]
261pub struct RustSdkCryptoStorageProvider {
262    inner: Arc<dyn IStorageProvider>,
263}
264
265impl RustSdkCryptoStorageProvider {
266    pub fn new(inner: Arc<dyn IStorageProvider>) -> Self {
267        Self { inner }
268    }
269}
270
271#[async_trait]
272impl IStorageProvider for RustSdkCryptoStorageProvider {
273    async fn get(&self, key: &str) -> anyhow::Result<Option<Value>> {
274        self.inner.get(&format!("rust_crypto:{key}")).await
275    }
276
277    async fn set(&self, key: &str, value: Value) -> anyhow::Result<()> {
278        self.inner.set(&format!("rust_crypto:{key}"), value).await
279    }
280
281    async fn delete(&self, key: &str) -> anyhow::Result<()> {
282        self.inner.delete(&format!("rust_crypto:{key}")).await
283    }
284
285    async fn keys(&self) -> anyhow::Result<Vec<String>> {
286        let prefixed = self.inner.keys().await?;
287        Ok(prefixed
288            .into_iter()
289            .filter_map(|k| k.strip_prefix("rust_crypto:").map(ToOwned::to_owned))
290            .collect())
291    }
292}
293
294#[async_trait]
295impl ICryptoStorageProvider for RustSdkCryptoStorageProvider {
296    async fn store_session_key(&self, room_id: &str, session_key: &str) -> anyhow::Result<()> {
297        self.set(
298            &format!("session:{room_id}"),
299            Value::String(session_key.to_owned()),
300        )
301        .await
302    }
303
304    async fn session_key(&self, room_id: &str) -> anyhow::Result<Option<String>> {
305        Ok(self
306            .get(&format!("session:{room_id}"))
307            .await?
308            .and_then(|v| v.as_str().map(ToOwned::to_owned)))
309    }
310}