// public_appservice/cache.rs

1use crate::config::Config;
2use redis::{AsyncCommands, RedisError};
3use serde::{Deserialize, Serialize};
4
5use crate::appservice::RoomSummary;
6use crate::rooms::PublicRoom;
7
8pub trait Cacheable: Serialize + for<'a> Deserialize<'a> + Send + Sync {}
9
10impl<T> Cacheable for T where T: Serialize + for<'a> Deserialize<'a> + Send + Sync {}
11
/// Converts a value into the string used as a Redis key.
///
/// Two-part keys are joined with a single `:`; for example
/// `("room_state", "!abc:example.org")` becomes `"room_state:!abc:example.org"`.
pub trait CacheKey {
    /// Returns the Redis key for this value.
    fn cache_key(&self) -> String;
}

impl CacheKey for String {
    fn cache_key(&self) -> String {
        // The key is the string itself.
        self.as_str().to_owned()
    }
}

impl CacheKey for &str {
    fn cache_key(&self) -> String {
        String::from(*self)
    }
}

impl CacheKey for (&str, &str) {
    fn cache_key(&self) -> String {
        let (namespace, id) = self;
        format!("{namespace}:{id}")
    }
}

impl CacheKey for (&str, String) {
    fn cache_key(&self) -> String {
        let (namespace, id) = self;
        format!("{namespace}:{id}")
    }
}

40#[derive(Debug, Clone)]
41pub struct Cache {
42    pub client: redis::Client,
43}
44
45impl Cache {
46    pub async fn new(config: &Config) -> Result<Self, anyhow::Error> {
47        let url = format!("redis://{}", config.redis.url);
48        let client = redis::Client::open(url)?;
49
50        Ok(Self { client })
51    }
52
53    pub async fn cache_data<T>(&self, key: &str, data: &T, ttl: u64) -> Result<(), RedisError>
54    where
55        T: Cacheable,
56    {
57        let mut conn = self.client.get_multiplexed_tokio_connection().await?;
58
59        let serialized = serde_json::to_string(data).map_err(|e| {
60            RedisError::from((
61                redis::ErrorKind::IoError,
62                "Serialization error",
63                e.to_string(),
64            ))
65        })?;
66
67        let _: () = conn.set_ex(key, serialized, ttl).await?;
68        Ok(())
69    }
70
71    pub async fn get_cached_data<T>(&self, key: &str) -> Result<Option<T>, RedisError>
72    where
73        T: Cacheable,
74    {
75        let mut conn = self.client.get_multiplexed_tokio_connection().await?;
76
77        let exists: bool = conn.exists(key).await?;
78        if !exists {
79            return Ok(None);
80        }
81
82        let data: String = conn.get(key).await?;
83        let value = serde_json::from_str(&data).map_err(|e| {
84            RedisError::from((
85                redis::ErrorKind::IoError,
86                "Deserialization error",
87                e.to_string(),
88            ))
89        })?;
90        Ok(Some(value))
91    }
92
93    pub async fn cache_or_fetch<T, F, Fut>(
94        &self,
95        key: &str,
96        ttl: u64,
97        fetch_fn: F,
98    ) -> Result<T, RedisError>
99    where
100        T: Cacheable,
101        F: FnOnce() -> Fut,
102        Fut: std::future::Future<Output = Result<T, RedisError>>,
103    {
104        if let Some(cached) = self.get_cached_data::<T>(key).await? {
105            return Ok(cached);
106        }
107
108        let data = fetch_fn().await?;
109
110        let _ = self.cache_data(key, &data, ttl).await;
111
112        Ok(data)
113    }
114
115    pub async fn cache_with_key<K, T>(&self, key: K, data: &T, ttl: u64) -> Result<(), RedisError>
116    where
117        K: CacheKey,
118        T: Cacheable,
119    {
120        self.cache_data(&key.cache_key(), data, ttl).await
121    }
122
123    pub async fn get_with_key<K, T>(&self, key: K) -> Result<Option<T>, RedisError>
124    where
125        K: CacheKey,
126        T: Cacheable,
127    {
128        self.get_cached_data(&key.cache_key()).await
129    }
130
131    pub async fn delete_cached_data(&self, key: &str) -> Result<(), RedisError> {
132        let mut conn = self.client.get_multiplexed_tokio_connection().await?;
133        let _: () = conn.del(key).await?;
134        Ok(())
135    }
136
137    pub async fn cache_rooms(&self, rooms: &Vec<PublicRoom>, ttl: u64) -> Result<(), RedisError> {
138        self.cache_data("public_rooms", rooms, ttl).await
139    }
140
141    pub async fn get_cached_rooms(&self) -> Result<Vec<PublicRoom>, RedisError> {
142        self.get_cached_data("public_rooms")
143            .await?
144            .ok_or_else(|| RedisError::from((redis::ErrorKind::ResponseError, "Key not found")))
145    }
146
147    pub async fn get_cached_room_state(
148        &self,
149        room_id: &str,
150    ) -> Result<Vec<PublicRoom>, RedisError> {
151        let key = format!("room_state:{room_id}");
152        self.get_cached_data(&key)
153            .await?
154            .ok_or_else(|| RedisError::from((redis::ErrorKind::ResponseError, "Key not found")))
155    }
156
157    pub async fn cache_public_spaces(
158        &self,
159        rooms: &Vec<RoomSummary>,
160        ttl: u64,
161    ) -> Result<(), RedisError> {
162        self.cache_data("public_spaces", rooms, ttl).await
163    }
164
165    pub async fn get_cached_public_spaces(&self) -> Result<Vec<RoomSummary>, RedisError> {
166        self.get_cached_data("public_spaces")
167            .await?
168            .ok_or_else(|| RedisError::from((redis::ErrorKind::ResponseError, "Key not found")))
169    }
170
171    pub async fn cache_room_state(
172        &self,
173        room_id: &str,
174        state: &Vec<PublicRoom>,
175        ttl: u64,
176    ) -> Result<(), RedisError> {
177        let key = format!("room_state:{room_id}");
178        self.cache_data(&key, state, ttl).await
179    }
180
181    pub async fn cache_proxy_response(
182        &self,
183        key: &str,
184        data: &[u8],
185        ttl: u64,
186    ) -> Result<(), RedisError> {
187        let mut conn = self.client.get_multiplexed_tokio_connection().await?;
188        let _: () = conn.set_ex(key, data, ttl).await?;
189        Ok(())
190    }
191
192    pub async fn get_cached_proxy_response(&self, key: &str) -> Result<Vec<u8>, RedisError> {
193        let mut conn = self.client.get_multiplexed_tokio_connection().await?;
194
195        if !conn.exists(key).await? {
196            return Err(RedisError::from((
197                redis::ErrorKind::ResponseError,
198                "Key not found",
199            )));
200        }
201
202        conn.get(key).await
203    }
204
205    pub async fn cache_multiple<T>(&self, items: Vec<(&str, &T, u64)>) -> Result<(), RedisError>
206    where
207        T: Cacheable,
208    {
209        for (key, data, ttl) in items {
210            self.cache_data(key, data, ttl).await?;
211        }
212        Ok(())
213    }
214
215    pub async fn delete_multiple(&self, keys: &[&str]) -> Result<(), RedisError> {
216        for key in keys {
217            self.delete_cached_data(key).await?;
218        }
219        Ok(())
220    }
221}