1use std::collections::HashMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use vodozemac::megolm::{
14 GroupSession, GroupSessionPickle, InboundGroupSession, InboundGroupSessionPickle,
15 MegolmMessage, SessionConfig, SessionKey,
16};
17
18use tracing::warn;
19
20use crate::error::{HuddleError, Result};
21use crate::storage::repo::{self, StoredMegolmSession};
22
23use crate::storage::Db;
24
25pub struct RoomCrypto {
27 room_id: String,
28 our_fingerprint: String,
29 outbound: GroupSession,
30 inbound: HashMap<(String, String), InboundGroupSession>,
32 db: Db,
33 persist_key: [u8; 32],
38}
39
40impl RoomCrypto {
41 pub fn new_for_room(
44 db: Db,
45 room_id: String,
46 our_fingerprint: String,
47 persist_key: [u8; 32],
48 ) -> Result<Self> {
49 let outbound = GroupSession::new(SessionConfig::version_1());
50 let crypto = Self {
51 room_id,
52 our_fingerprint,
53 outbound,
54 inbound: HashMap::new(),
55 db,
56 persist_key,
57 };
58 crypto.persist_outbound()?;
59 Ok(crypto)
60 }
61
62 pub fn load(
70 db: Db,
71 room_id: String,
72 our_fingerprint: String,
73 persist_key: [u8; 32],
74 ) -> Result<Option<Self>> {
75 let sessions = repo::load_megolm_sessions_for_room(&db, &room_id)?;
76 let mut outbound: Option<GroupSession> = None;
77 let mut inbound: HashMap<(String, String), InboundGroupSession> = HashMap::new();
78
79 for s in sessions {
80 let data_str = match String::from_utf8(s.session_data) {
81 Ok(d) => d,
82 Err(e) => {
83 warn!(%e, room_id = %room_id, "skipping persisted megolm session: invalid utf8");
84 continue;
85 }
86 };
87 if s.is_outbound {
88 match GroupSessionPickle::from_encrypted(&data_str, &persist_key) {
89 Ok(p) => outbound = Some(GroupSession::from_pickle(p)),
90 Err(e) => {
91 warn!(%e, room_id = %room_id, "skipping persisted outbound megolm session: restore failed");
92 }
93 }
94 } else {
95 match InboundGroupSessionPickle::from_encrypted(&data_str, &persist_key) {
96 Ok(p) => {
97 inbound.insert(
98 (s.sender_fingerprint, s.session_id),
99 InboundGroupSession::from_pickle(p),
100 );
101 }
102 Err(e) => {
103 warn!(%e, room_id = %room_id, "skipping persisted inbound megolm session: restore failed");
104 }
105 }
106 }
107 }
108
109 match outbound {
110 Some(outbound) => Ok(Some(Self {
111 room_id,
112 our_fingerprint,
113 outbound,
114 inbound,
115 db,
116 persist_key,
117 })),
118 None => Ok(None),
119 }
120 }
121
122 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<(String, Vec<u8>)> {
125 let msg = self.outbound.encrypt(plaintext);
126 let session_id = self.outbound.session_id();
127 self.persist_outbound()?;
128 Ok((session_id, msg.to_bytes()))
129 }
130
131 pub fn decrypt(
133 &mut self,
134 sender_fingerprint: &str,
135 session_id: &str,
136 ciphertext: &[u8],
137 ) -> Result<Vec<u8>> {
138 let key = (sender_fingerprint.to_string(), session_id.to_string());
139 let session = self.inbound.get_mut(&key).ok_or_else(|| {
140 HuddleError::Session(format!(
141 "no inbound megolm session for {sender_fingerprint} / {session_id}"
142 ))
143 })?;
144 let msg = MegolmMessage::from_bytes(ciphertext)
145 .map_err(|e| HuddleError::Session(format!("bad megolm message: {e}")))?;
146 let decrypted = session
147 .decrypt(&msg)
148 .map_err(|e| HuddleError::Session(format!("megolm decrypt failed: {e}")))?;
149
150 let persisted = session.pickle().encrypt(&self.persist_key);
152 repo::save_megolm_session(
153 &self.db,
154 &StoredMegolmSession {
155 room_id: self.room_id.clone(),
156 sender_fingerprint: sender_fingerprint.to_string(),
157 session_id: session_id.to_string(),
158 session_data: persisted.into_bytes(),
159 is_outbound: false,
160 created_at: now_unix(),
161 },
162 )?;
163
164 Ok(decrypted.plaintext)
165 }
166
167 pub fn add_inbound_session(
170 &mut self,
171 sender_fingerprint: &str,
172 session_key_b64: &str,
173 ) -> Result<()> {
174 let key = SessionKey::from_base64(session_key_b64)
175 .map_err(|e| HuddleError::Session(format!("bad session key: {e}")))?;
176 let session = InboundGroupSession::new(&key, SessionConfig::version_1());
177 let session_id = session.session_id();
178
179 let persisted = session.pickle().encrypt(&self.persist_key);
180 repo::save_megolm_session(
181 &self.db,
182 &StoredMegolmSession {
183 room_id: self.room_id.clone(),
184 sender_fingerprint: sender_fingerprint.to_string(),
185 session_id: session_id.clone(),
186 session_data: persisted.into_bytes(),
187 is_outbound: false,
188 created_at: now_unix(),
189 },
190 )?;
191
192 self.inbound
193 .insert((sender_fingerprint.to_string(), session_id), session);
194 Ok(())
195 }
196
197 pub fn our_session_key_b64(&self) -> String {
199 self.outbound.session_key().to_base64()
200 }
201
202 pub fn our_session_id(&self) -> String {
203 self.outbound.session_id()
204 }
205
206 pub fn our_fingerprint(&self) -> &str {
207 &self.our_fingerprint
208 }
209
210 pub fn room_id(&self) -> &str {
211 &self.room_id
212 }
213
214 fn persist_outbound(&self) -> Result<()> {
215 let persisted = self.outbound.pickle().encrypt(&self.persist_key);
216 repo::save_megolm_session(
217 &self.db,
218 &StoredMegolmSession {
219 room_id: self.room_id.clone(),
220 sender_fingerprint: self.our_fingerprint.clone(),
221 session_id: self.outbound.session_id(),
222 session_data: persisted.into_bytes(),
223 is_outbound: true,
224 created_at: now_unix(),
225 },
226 )?;
227 Ok(())
228 }
229}
230
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn now_unix() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        // Clock is set before 1970; fall back to the epoch itself.
        Err(_) => 0,
    }
}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::open_db_in_memory;
    use crate::storage::repo::{derive_room_id, insert_room, RoomKind, StoredRoom};

    /// All-zero pickle key shared by every test.
    const TEST_KEY: [u8; 32] = [0u8; 32];

    /// Inserts an encrypted group room into `db` and returns its derived id.
    fn setup_room(db: &Db, name: &str, creator_fp: &str) -> String {
        let created_at = 1000;
        let id = derive_room_id(creator_fp, name, created_at);
        insert_room(
            db,
            &StoredRoom {
                id: id.clone(),
                name: name.into(),
                creator_fingerprint: creator_fp.into(),
                encrypted: true,
                passphrase_salt: None,
                created_at,
                last_active: None,
                kind: RoomKind::Group,
            },
        )
        .unwrap();
        id
    }

    /// Builds a fresh `RoomCrypto` for `fingerprint` in `room_id` on `db`.
    fn fresh_crypto(db: &Db, room_id: &str, fingerprint: &str) -> RoomCrypto {
        RoomCrypto::new_for_room(db.clone(), room_id.to_string(), fingerprint.into(), TEST_KEY)
            .unwrap()
    }

    /// A message encrypted by one member decrypts for another who holds the key.
    #[test]
    fn outbound_encrypt_inbound_decrypt() {
        let db_alice = open_db_in_memory().unwrap();
        let db_bob = open_db_in_memory().unwrap();
        let room_id = setup_room(&db_alice, "test", "alice-fp");
        setup_room(&db_bob, "test", "alice-fp");

        let mut alice = fresh_crypto(&db_alice, &room_id, "alice-fp");
        let mut bob = fresh_crypto(&db_bob, &room_id, "bob-fp");

        let alice_key = alice.our_session_key_b64();
        bob.add_inbound_session("alice-fp", &alice_key).unwrap();

        let (session_id, ciphertext) = alice.encrypt(b"hello group").unwrap();
        let plaintext = bob.decrypt("alice-fp", &session_id, &ciphertext).unwrap();
        assert_eq!(plaintext, b"hello group");
    }

    /// Both members can send to each other once keys are exchanged.
    #[test]
    fn bidirectional_round_trip() {
        let db_a = open_db_in_memory().unwrap();
        let db_b = open_db_in_memory().unwrap();
        let room_id = setup_room(&db_a, "r", "a-fp");
        setup_room(&db_b, "r", "a-fp");

        let mut alice = fresh_crypto(&db_a, &room_id, "a-fp");
        let mut bob = fresh_crypto(&db_b, &room_id, "b-fp");

        let bob_key = bob.our_session_key_b64();
        alice.add_inbound_session("b-fp", &bob_key).unwrap();
        let alice_key = alice.our_session_key_b64();
        bob.add_inbound_session("a-fp", &alice_key).unwrap();

        let (sid_a, ct_a) = alice.encrypt(b"from alice").unwrap();
        assert_eq!(bob.decrypt("a-fp", &sid_a, &ct_a).unwrap(), b"from alice");

        let (sid_b, ct_b) = bob.encrypt(b"from bob").unwrap();
        assert_eq!(alice.decrypt("b-fp", &sid_b, &ct_b).unwrap(), b"from bob");
    }

    /// The outbound session survives a drop and reload with the same id.
    #[test]
    fn outbound_persists_and_reloads() {
        let db = open_db_in_memory().unwrap();
        let room_id = setup_room(&db, "r", "me-fp");

        let original_session_id = {
            let mut crypto = fresh_crypto(&db, &room_id, "me-fp");
            let session_id = crypto.our_session_id();
            let (_, _) = crypto.encrypt(b"advance the ratchet").unwrap();
            session_id
            // `crypto` is dropped here, before the reload below.
        };

        let reloaded = RoomCrypto::load(db.clone(), room_id.clone(), "me-fp".into(), TEST_KEY)
            .unwrap()
            .expect("should have outbound session");
        assert_eq!(reloaded.our_session_id(), original_session_id);
    }

    /// Decrypting for a sender we hold no session for is an error.
    #[test]
    fn decrypt_unknown_sender_errors() {
        let db = open_db_in_memory().unwrap();
        let room_id = setup_room(&db, "r", "me-fp");
        let mut crypto = fresh_crypto(&db, &room_id, "me-fp");

        assert!(crypto.decrypt("unknown-fp", "session-id", b"junk").is_err());
    }
}