1use std::collections::HashMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use vodozemac::megolm::{
14 GroupSession, GroupSessionPickle, InboundGroupSession, InboundGroupSessionPickle,
15 MegolmMessage, SessionConfig, SessionKey,
16};
17
18use crate::error::{HuddleError, Result};
19use crate::storage::repo::{self, StoredMegolmSession};
20use std::sync::OnceLock;
21
22use crate::storage::Db;
23
/// Process-wide key used to encrypt megolm session pickles before they are
/// written to the database. Filled at most once by
/// [`install_session_persist_key`].
static SESSION_PERSIST_KEY_OVERRIDE: OnceLock<[u8; 32]> = OnceLock::new();

/// Pickle key used whenever no key has been installed.
///
/// NOTE(review): this is an all-zero key, so anything persisted before a real
/// key is installed is not meaningfully protected at rest — confirm that
/// startup always calls `install_session_persist_key` first.
const SESSION_PERSIST_KEY_DEFAULT: [u8; 32] = [0; 32];

/// Installs the key used to encrypt persisted megolm session pickles.
///
/// Only the first call takes effect; any later call is ignored, so the key
/// cannot change once it has been set.
pub fn install_session_persist_key(key: [u8; 32]) {
    // `set` rejects the value when a key is already present; that rejection is
    // deliberately discarded so the first installed key wins.
    let _ = SESSION_PERSIST_KEY_OVERRIDE.set(key);
}

/// Returns the installed persist key, or the all-zero default when none has
/// been installed yet.
fn persist_key() -> &'static [u8; 32] {
    match SESSION_PERSIST_KEY_OVERRIDE.get() {
        Some(installed) => installed,
        None => &SESSION_PERSIST_KEY_DEFAULT,
    }
}
42
43pub struct RoomCrypto {
45 room_id: String,
46 our_fingerprint: String,
47 outbound: GroupSession,
48 inbound: HashMap<(String, String), InboundGroupSession>,
50 db: Db,
51}
52
53impl RoomCrypto {
54 pub fn new_for_room(db: Db, room_id: String, our_fingerprint: String) -> Result<Self> {
56 let outbound = GroupSession::new(SessionConfig::version_1());
57 let crypto = Self {
58 room_id,
59 our_fingerprint,
60 outbound,
61 inbound: HashMap::new(),
62 db,
63 };
64 crypto.persist_outbound()?;
65 Ok(crypto)
66 }
67
68 pub fn load(db: Db, room_id: String, our_fingerprint: String) -> Result<Option<Self>> {
71 let sessions = repo::load_megolm_sessions_for_room(&db, &room_id)?;
72 let mut outbound: Option<GroupSession> = None;
73 let mut inbound: HashMap<(String, String), InboundGroupSession> = HashMap::new();
74
75 for s in sessions {
76 let data_str = String::from_utf8(s.session_data)
77 .map_err(|e| HuddleError::Session(format!("session data utf8: {e}")))?;
78 if s.is_outbound {
79 let p = GroupSessionPickle::from_encrypted(&data_str, persist_key())
80 .map_err(|e| HuddleError::Session(format!("restore outbound: {e}")))?;
81 outbound = Some(GroupSession::from_pickle(p));
82 } else {
83 let p = InboundGroupSessionPickle::from_encrypted(&data_str, persist_key())
84 .map_err(|e| HuddleError::Session(format!("restore inbound: {e}")))?;
85 let session = InboundGroupSession::from_pickle(p);
86 inbound.insert((s.sender_fingerprint, s.session_id), session);
87 }
88 }
89
90 match outbound {
91 Some(outbound) => Ok(Some(Self {
92 room_id,
93 our_fingerprint,
94 outbound,
95 inbound,
96 db,
97 })),
98 None => Ok(None),
99 }
100 }
101
102 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<(String, Vec<u8>)> {
105 let msg = self.outbound.encrypt(plaintext);
106 let session_id = self.outbound.session_id();
107 self.persist_outbound()?;
108 Ok((session_id, msg.to_bytes()))
109 }
110
111 pub fn decrypt(
113 &mut self,
114 sender_fingerprint: &str,
115 session_id: &str,
116 ciphertext: &[u8],
117 ) -> Result<Vec<u8>> {
118 let key = (sender_fingerprint.to_string(), session_id.to_string());
119 let session = self.inbound.get_mut(&key).ok_or_else(|| {
120 HuddleError::Session(format!(
121 "no inbound megolm session for {sender_fingerprint} / {session_id}"
122 ))
123 })?;
124 let msg = MegolmMessage::from_bytes(ciphertext)
125 .map_err(|e| HuddleError::Session(format!("bad megolm message: {e}")))?;
126 let decrypted = session
127 .decrypt(&msg)
128 .map_err(|e| HuddleError::Session(format!("megolm decrypt failed: {e}")))?;
129
130 let persisted = session.pickle().encrypt(persist_key());
132 repo::save_megolm_session(
133 &self.db,
134 &StoredMegolmSession {
135 room_id: self.room_id.clone(),
136 sender_fingerprint: sender_fingerprint.to_string(),
137 session_id: session_id.to_string(),
138 session_data: persisted.into_bytes(),
139 is_outbound: false,
140 created_at: now_unix(),
141 },
142 )?;
143
144 Ok(decrypted.plaintext)
145 }
146
147 pub fn add_inbound_session(
150 &mut self,
151 sender_fingerprint: &str,
152 session_key_b64: &str,
153 ) -> Result<()> {
154 let key = SessionKey::from_base64(session_key_b64)
155 .map_err(|e| HuddleError::Session(format!("bad session key: {e}")))?;
156 let session = InboundGroupSession::new(&key, SessionConfig::version_1());
157 let session_id = session.session_id();
158
159 let persisted = session.pickle().encrypt(persist_key());
160 repo::save_megolm_session(
161 &self.db,
162 &StoredMegolmSession {
163 room_id: self.room_id.clone(),
164 sender_fingerprint: sender_fingerprint.to_string(),
165 session_id: session_id.clone(),
166 session_data: persisted.into_bytes(),
167 is_outbound: false,
168 created_at: now_unix(),
169 },
170 )?;
171
172 self.inbound
173 .insert((sender_fingerprint.to_string(), session_id), session);
174 Ok(())
175 }
176
177 pub fn our_session_key_b64(&self) -> String {
179 self.outbound.session_key().to_base64()
180 }
181
182 pub fn our_session_id(&self) -> String {
183 self.outbound.session_id()
184 }
185
186 pub fn our_fingerprint(&self) -> &str {
187 &self.our_fingerprint
188 }
189
190 pub fn room_id(&self) -> &str {
191 &self.room_id
192 }
193
194 fn persist_outbound(&self) -> Result<()> {
195 let persisted = self.outbound.pickle().encrypt(persist_key());
196 repo::save_megolm_session(
197 &self.db,
198 &StoredMegolmSession {
199 room_id: self.room_id.clone(),
200 sender_fingerprint: self.our_fingerprint.clone(),
201 session_id: self.outbound.session_id(),
202 session_data: persisted.into_bytes(),
203 is_outbound: true,
204 created_at: now_unix(),
205 },
206 )?;
207 Ok(())
208 }
209}
210
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Never panics. A clock set before the epoch yields `0`, and a value too
/// large for `i64` saturates at `i64::MAX` instead of wrapping.
fn now_unix() -> i64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            i64::try_from(elapsed.as_secs()).unwrap_or(i64::MAX)
        })
}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::open_db_in_memory;
    use crate::storage::repo::{derive_room_id, insert_room, StoredRoom};

    /// Inserts an encrypted room row into `db` and returns its derived id.
    fn setup_room(db: &Db, name: &str, creator_fp: &str) -> String {
        let created_at = 1000;
        let id = derive_room_id(creator_fp, name, created_at);
        insert_room(
            db,
            &StoredRoom {
                id: id.clone(),
                name: name.into(),
                creator_fingerprint: creator_fp.into(),
                encrypted: true,
                passphrase_salt: None,
                created_at,
                last_active: None,
            },
        )
        .unwrap();
        id
    }

    /// Builds brand-new room crypto for `fingerprint` on its own clone of `db`.
    fn fresh_crypto(db: &Db, room_id: &str, fingerprint: &str) -> RoomCrypto {
        RoomCrypto::new_for_room(db.clone(), room_id.to_string(), fingerprint.into()).unwrap()
    }

    #[test]
    fn outbound_encrypt_inbound_decrypt() {
        let (db_alice, db_bob) = (open_db_in_memory().unwrap(), open_db_in_memory().unwrap());
        let room_id = setup_room(&db_alice, "test", "alice-fp");
        setup_room(&db_bob, "test", "alice-fp");

        let mut alice = fresh_crypto(&db_alice, &room_id, "alice-fp");
        let mut bob = fresh_crypto(&db_bob, &room_id, "bob-fp");

        let alice_key = alice.our_session_key_b64();
        bob.add_inbound_session("alice-fp", &alice_key).unwrap();

        let (session_id, ciphertext) = alice.encrypt(b"hello group").unwrap();
        let decrypted = bob.decrypt("alice-fp", &session_id, &ciphertext).unwrap();
        assert_eq!(decrypted, b"hello group");
    }

    #[test]
    fn bidirectional_round_trip() {
        let (db_a, db_b) = (open_db_in_memory().unwrap(), open_db_in_memory().unwrap());
        let room_id = setup_room(&db_a, "r", "a-fp");
        setup_room(&db_b, "r", "a-fp");

        let mut alice = fresh_crypto(&db_a, &room_id, "a-fp");
        let mut bob = fresh_crypto(&db_b, &room_id, "b-fp");

        let bob_key = bob.our_session_key_b64();
        alice.add_inbound_session("b-fp", &bob_key).unwrap();
        let alice_key = alice.our_session_key_b64();
        bob.add_inbound_session("a-fp", &alice_key).unwrap();

        let (alice_sid, alice_ct) = alice.encrypt(b"from alice").unwrap();
        let at_bob = bob.decrypt("a-fp", &alice_sid, &alice_ct).unwrap();
        assert_eq!(at_bob, b"from alice");

        let (bob_sid, bob_ct) = bob.encrypt(b"from bob").unwrap();
        let at_alice = alice.decrypt("b-fp", &bob_sid, &bob_ct).unwrap();
        assert_eq!(at_alice, b"from bob");
    }

    #[test]
    fn outbound_persists_and_reloads() {
        let db = open_db_in_memory().unwrap();
        let room_id = setup_room(&db, "r", "me-fp");

        // The scope drops the original crypto before reloading from the DB.
        let original_session_id = {
            let mut crypto = fresh_crypto(&db, &room_id, "me-fp");
            let sid = crypto.our_session_id();
            crypto.encrypt(b"advance the ratchet").unwrap();
            sid
        };

        let reloaded = RoomCrypto::load(db.clone(), room_id.clone(), "me-fp".into())
            .unwrap()
            .expect("should have outbound session");
        assert_eq!(reloaded.our_session_id(), original_session_id);
    }

    #[test]
    fn decrypt_unknown_sender_errors() {
        let db = open_db_in_memory().unwrap();
        let room_id = setup_room(&db, "r", "me-fp");
        let mut crypto = fresh_crypto(&db, &room_id, "me-fp");
        let outcome = crypto.decrypt("unknown-fp", "session-id", b"junk");
        assert!(outcome.is_err());
    }
}