matrix_bot_sdk/
storage.rs1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde_json::Value;
7use tokio::sync::RwLock;
8
9#[async_trait]
10pub trait IStorageProvider: Send + Sync {
11 async fn get(&self, key: &str) -> anyhow::Result<Option<Value>>;
12 async fn set(&self, key: &str, value: Value) -> anyhow::Result<()>;
13 async fn delete(&self, key: &str) -> anyhow::Result<()>;
14 async fn keys(&self) -> anyhow::Result<Vec<String>>;
15}
16
17#[async_trait]
18pub trait IAppserviceStorageProvider: IStorageProvider {
19 async fn register_user(&self, user_id: &str) -> anyhow::Result<()>;
20 async fn is_user_registered(&self, user_id: &str) -> anyhow::Result<bool>;
21}
22
23#[async_trait]
24pub trait ICryptoStorageProvider: IStorageProvider {
25 async fn store_session_key(&self, room_id: &str, session_key: &str) -> anyhow::Result<()>;
26 async fn session_key(&self, room_id: &str) -> anyhow::Result<Option<String>>;
27}
28
29#[derive(Debug, Default)]
30pub struct MemoryStorageProvider {
31 entries: RwLock<HashMap<String, Value>>,
32}
33
34impl MemoryStorageProvider {
35 pub fn new() -> Self {
36 Self::default()
37 }
38}
39
40#[async_trait]
41impl IStorageProvider for MemoryStorageProvider {
42 async fn get(&self, key: &str) -> anyhow::Result<Option<Value>> {
43 Ok(self.entries.read().await.get(key).cloned())
44 }
45
46 async fn set(&self, key: &str, value: Value) -> anyhow::Result<()> {
47 self.entries.write().await.insert(key.to_owned(), value);
48 Ok(())
49 }
50
51 async fn delete(&self, key: &str) -> anyhow::Result<()> {
52 self.entries.write().await.remove(key);
53 Ok(())
54 }
55
56 async fn keys(&self) -> anyhow::Result<Vec<String>> {
57 Ok(self.entries.read().await.keys().cloned().collect())
58 }
59}
60
61#[async_trait]
62impl IAppserviceStorageProvider for MemoryStorageProvider {
63 async fn register_user(&self, user_id: &str) -> anyhow::Result<()> {
64 self.set(
65 &format!("appservice:registered:{user_id}"),
66 Value::Bool(true),
67 )
68 .await
69 }
70
71 async fn is_user_registered(&self, user_id: &str) -> anyhow::Result<bool> {
72 Ok(self
73 .get(&format!("appservice:registered:{user_id}"))
74 .await?
75 .and_then(|v| v.as_bool())
76 .unwrap_or(false))
77 }
78}
79
80#[async_trait]
81impl ICryptoStorageProvider for MemoryStorageProvider {
82 async fn store_session_key(&self, room_id: &str, session_key: &str) -> anyhow::Result<()> {
83 self.set(
84 &format!("crypto:session:{room_id}"),
85 Value::String(session_key.to_owned()),
86 )
87 .await
88 }
89
90 async fn session_key(&self, room_id: &str) -> anyhow::Result<Option<String>> {
91 Ok(self
92 .get(&format!("crypto:session:{room_id}"))
93 .await?
94 .and_then(|v| v.as_str().map(ToOwned::to_owned)))
95 }
96}
97
98#[derive(Debug)]
99pub struct SimpleFsStorageProvider {
100 file_path: PathBuf,
101 entries: RwLock<HashMap<String, Value>>,
102}
103
104impl SimpleFsStorageProvider {
105 pub async fn new(path: impl AsRef<Path>) -> anyhow::Result<Self> {
106 let file_path = path.as_ref().to_path_buf();
107 let entries = if tokio::fs::try_exists(&file_path).await? {
108 let content = tokio::fs::read_to_string(&file_path).await?;
109 serde_json::from_str::<HashMap<String, Value>>(&content)?
110 } else {
111 HashMap::new()
112 };
113 Ok(Self {
114 file_path,
115 entries: RwLock::new(entries),
116 })
117 }
118
119 async fn flush(&self) -> anyhow::Result<()> {
120 let serialized = serde_json::to_string_pretty(&*self.entries.read().await)?;
121 tokio::fs::write(&self.file_path, serialized).await?;
122 Ok(())
123 }
124}
125
126#[async_trait]
127impl IStorageProvider for SimpleFsStorageProvider {
128 async fn get(&self, key: &str) -> anyhow::Result<Option<Value>> {
129 Ok(self.entries.read().await.get(key).cloned())
130 }
131
132 async fn set(&self, key: &str, value: Value) -> anyhow::Result<()> {
133 self.entries.write().await.insert(key.to_owned(), value);
134 self.flush().await
135 }
136
137 async fn delete(&self, key: &str) -> anyhow::Result<()> {
138 self.entries.write().await.remove(key);
139 self.flush().await
140 }
141
142 async fn keys(&self) -> anyhow::Result<Vec<String>> {
143 Ok(self.entries.read().await.keys().cloned().collect())
144 }
145}
146
147#[async_trait]
148impl IAppserviceStorageProvider for SimpleFsStorageProvider {
149 async fn register_user(&self, user_id: &str) -> anyhow::Result<()> {
150 self.set(
151 &format!("appservice:registered:{user_id}"),
152 Value::Bool(true),
153 )
154 .await
155 }
156
157 async fn is_user_registered(&self, user_id: &str) -> anyhow::Result<bool> {
158 Ok(self
159 .get(&format!("appservice:registered:{user_id}"))
160 .await?
161 .and_then(|v| v.as_bool())
162 .unwrap_or(false))
163 }
164}
165
166#[async_trait]
167impl ICryptoStorageProvider for SimpleFsStorageProvider {
168 async fn store_session_key(&self, room_id: &str, session_key: &str) -> anyhow::Result<()> {
169 self.set(
170 &format!("crypto:session:{room_id}"),
171 Value::String(session_key.to_owned()),
172 )
173 .await
174 }
175
176 async fn session_key(&self, room_id: &str) -> anyhow::Result<Option<String>> {
177 Ok(self
178 .get(&format!("crypto:session:{room_id}"))
179 .await?
180 .and_then(|v| v.as_str().map(ToOwned::to_owned)))
181 }
182}
183
/// Storage provider backed by the `matrix_sdk_store` table in PostgreSQL.
///
/// Only available with the `storage-postgres` feature.
#[cfg(feature = "storage-postgres")]
#[derive(Debug)]
pub struct SimplePostgresStorageProvider {
    /// Client handle; its connection driver is spawned in `connect`.
    client: tokio_postgres::Client,
}
189
#[cfg(feature = "storage-postgres")]
impl SimplePostgresStorageProvider {
    /// Connects (without TLS) using `conn_str` and makes sure the
    /// `matrix_sdk_store (key TEXT PRIMARY KEY, value JSONB NOT NULL)` table exists.
    ///
    /// The connection half is spawned as a background Tokio task. Errors it
    /// reports after startup are logged, not returned.
    ///
    /// # Errors
    ///
    /// Fails if connecting or the `CREATE TABLE IF NOT EXISTS` statement fails.
    pub async fn connect(conn_str: &str) -> anyhow::Result<Self> {
        let (client, driver) =
            tokio_postgres::connect(conn_str, tokio_postgres::NoTls).await?;

        // The connection object performs the actual socket I/O and must be
        // driven on its own for `client` to make progress.
        tokio::spawn(async move {
            if let Err(error) = driver.await {
                tracing::error!("postgres connection error: {error}");
            }
        });

        let create_table = "CREATE TABLE IF NOT EXISTS matrix_sdk_store (key TEXT PRIMARY KEY, value JSONB NOT NULL)";
        client.execute(create_table, &[]).await?;

        Ok(Self { client })
    }
}
210
/// Each operation is a single SQL statement against `matrix_sdk_store`.
///
/// Column reads use `Row::try_get` rather than `Row::get`. A type mismatch,
/// such as a pre-existing table with a different schema (which `CREATE TABLE
/// IF NOT EXISTS` does not correct), then becomes an error instead of a panic.
#[cfg(feature = "storage-postgres")]
#[async_trait]
impl IStorageProvider for SimplePostgresStorageProvider {
    async fn get(&self, key: &str) -> anyhow::Result<Option<Value>> {
        let row = self
            .client
            .query_opt("SELECT value FROM matrix_sdk_store WHERE key = $1", &[&key])
            .await?;
        Ok(row.map(|r| r.try_get(0)).transpose()?)
    }

    /// Upserts: inserts a new row or overwrites the value of an existing key.
    async fn set(&self, key: &str, value: Value) -> anyhow::Result<()> {
        self.client
            .execute(
                "INSERT INTO matrix_sdk_store (key, value) VALUES ($1, $2) \
                 ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value",
                &[&key, &value],
            )
            .await?;
        Ok(())
    }

    async fn delete(&self, key: &str) -> anyhow::Result<()> {
        self.client
            .execute("DELETE FROM matrix_sdk_store WHERE key = $1", &[&key])
            .await?;
        Ok(())
    }

    /// Returns all keys sorted by the database's `ORDER BY key`.
    async fn keys(&self) -> anyhow::Result<Vec<String>> {
        let rows = self
            .client
            .query("SELECT key FROM matrix_sdk_store ORDER BY key", &[])
            .await?;
        let keys = rows
            .iter()
            .map(|row| row.try_get(0))
            .collect::<Result<Vec<String>, _>>()?;
        Ok(keys)
    }
}
248
/// Placeholder compiled when the `storage-postgres` feature is disabled.
///
/// Its `connect` always returns an error.
#[cfg(not(feature = "storage-postgres"))]
#[derive(Default, Debug)]
pub struct SimplePostgresStorageProvider;
252
253#[cfg(not(feature = "storage-postgres"))]
254impl SimplePostgresStorageProvider {
255 pub async fn connect(_conn_str: &str) -> anyhow::Result<Self> {
256 anyhow::bail!("feature 'storage-postgres' is not enabled");
257 }
258}
259
260#[derive(Clone)]
261pub struct RustSdkCryptoStorageProvider {
262 inner: Arc<dyn IStorageProvider>,
263}
264
265impl RustSdkCryptoStorageProvider {
266 pub fn new(inner: Arc<dyn IStorageProvider>) -> Self {
267 Self { inner }
268 }
269}
270
271#[async_trait]
272impl IStorageProvider for RustSdkCryptoStorageProvider {
273 async fn get(&self, key: &str) -> anyhow::Result<Option<Value>> {
274 self.inner.get(&format!("rust_crypto:{key}")).await
275 }
276
277 async fn set(&self, key: &str, value: Value) -> anyhow::Result<()> {
278 self.inner.set(&format!("rust_crypto:{key}"), value).await
279 }
280
281 async fn delete(&self, key: &str) -> anyhow::Result<()> {
282 self.inner.delete(&format!("rust_crypto:{key}")).await
283 }
284
285 async fn keys(&self) -> anyhow::Result<Vec<String>> {
286 let prefixed = self.inner.keys().await?;
287 Ok(prefixed
288 .into_iter()
289 .filter_map(|k| k.strip_prefix("rust_crypto:").map(ToOwned::to_owned))
290 .collect())
291 }
292}
293
294#[async_trait]
295impl ICryptoStorageProvider for RustSdkCryptoStorageProvider {
296 async fn store_session_key(&self, room_id: &str, session_key: &str) -> anyhow::Result<()> {
297 self.set(
298 &format!("session:{room_id}"),
299 Value::String(session_key.to_owned()),
300 )
301 .await
302 }
303
304 async fn session_key(&self, room_id: &str) -> anyhow::Result<Option<String>> {
305 Ok(self
306 .get(&format!("session:{room_id}"))
307 .await?
308 .and_then(|v| v.as_str().map(ToOwned::to_owned)))
309 }
310}