1use std::collections::{hash_map, BTreeMap, HashMap};
2use std::time::Instant;
3
4use diesel::prelude::*;
5use futures_util::stream::{FuturesUnordered, StreamExt};
6use serde_json::json;
7
8use crate::core::client::key::{ClaimKeysResBody, UploadSigningKeysReqBody};
9use crate::core::encryption::{CrossSigningKey, DeviceKeys, OneTimeKey};
10use crate::core::federation::key::{claim_keys_request, query_keys_request, QueryKeysReqBody, QueryKeysResBody};
11use crate::core::identifiers::*;
12use crate::core::{client, federation, DeviceKeyAlgorithm, JsonValue, MatrixError, Seqnum, UnixMillis};
13use crate::schema::*;
14use crate::user::clean_signatures;
15use crate::{DataError, connect, DataResult};
16
/// A row of the `e2e_cross_signing_keys` table: one cross-signing key of a user.
///
/// In this module `key_type` is written and queried as `"master"`,
/// `"self_signing"` or `"user_signing"`, and `key_data` holds the
/// serde_json-serialized [`CrossSigningKey`].
///
/// Field order matters: `Queryable` maps columns to fields by position.
#[derive(Identifiable, Insertable, Queryable, Debug, Clone)]
#[diesel(table_name = e2e_cross_signing_keys)]
pub struct DbCrossSigningKey {
    /// Primary key of the row.
    pub id: i64,

    /// Owner of the key.
    pub user_id: OwnedUserId,
    /// Kind of key (`"master"`, `"self_signing"` or `"user_signing"` in this module).
    pub key_type: String,
    /// The key serialized as JSON.
    pub key_data: JsonValue,
}
/// Insertable form of [`DbCrossSigningKey`]; omits `id` (presumably assigned by
/// the database — confirm against the schema).
#[derive(Insertable, Debug, Clone)]
#[diesel(table_name = e2e_cross_signing_keys)]
pub struct NewDbCrossSigningKey {
    /// Owner of the key.
    pub user_id: OwnedUserId,
    /// Kind of key (`"master"`, `"self_signing"` or `"user_signing"` in this module).
    pub key_type: String,
    /// The key serialized as JSON.
    pub key_data: JsonValue,
}
33
/// A row of the `e2e_cross_signing_sigs` table: one signature made by
/// `origin_user_id` with key `origin_key_id` over a device of `target_user_id`.
///
/// Field order matters: `Queryable` maps columns to fields by position.
#[derive(Identifiable, Queryable, Debug, Clone)]
#[diesel(table_name = e2e_cross_signing_sigs)]
pub struct DbCrossSignature {
    /// Primary key of the row.
    pub id: i64,

    /// User who produced the signature.
    pub origin_user_id: OwnedUserId,
    /// Key id (as parsed by `DeviceKeyId::parse`) that produced the signature.
    pub origin_key_id: OwnedDeviceKeyId,
    /// User owning the signed device.
    pub target_user_id: OwnedUserId,
    /// Signed device id.
    pub target_device_id: OwnedDeviceId,
    /// The signature value as received.
    pub signature: String,
}
/// Insertable form of [`DbCrossSignature`]; omits `id` (presumably assigned by
/// the database — confirm against the schema).
#[derive(Insertable, Debug, Clone)]
#[diesel(table_name = e2e_cross_signing_sigs)]
pub struct NewDbCrossSignature {
    /// User who produced the signature.
    pub origin_user_id: OwnedUserId,
    /// Key id that produced the signature.
    pub origin_key_id: OwnedDeviceKeyId,
    /// User owning the signed device.
    pub target_user_id: OwnedUserId,
    /// Signed device id.
    pub target_device_id: OwnedDeviceId,
    /// The signature value as received.
    pub signature: String,
}
54
/// A row of the `e2e_fallback_keys` table: a fallback key of one device.
///
/// Field order matters: `Queryable` maps columns to fields by position.
#[derive(Identifiable, Queryable, Debug, Clone)]
#[diesel(table_name = e2e_fallback_keys)]
pub struct DbFallbackKey {
    // NOTE(review): every other table in this module uses `i64` ids; confirm
    // that `e2e_fallback_keys.id` really is a text column.
    pub id: String,

    /// Owner of the device.
    pub user_id: OwnedUserId,
    /// Device the key belongs to.
    pub device_id: OwnedDeviceId,
    /// Key algorithm as a string.
    pub algorithm: String,
    /// Full key id.
    pub key_id: OwnedDeviceKeyId,
    /// The key serialized as JSON.
    pub key_data: JsonValue,
    // presumably the time the fallback key was handed out, `None` if unused —
    // TODO confirm against the code that writes it.
    pub used_at: Option<i64>,
    /// Creation time.
    pub created_at: UnixMillis,
}
/// Insertable form of [`DbFallbackKey`] (no `id` field).
#[derive(Insertable, Debug, Clone)]
#[diesel(table_name = e2e_fallback_keys)]
pub struct NewDbFallbackKey {
    /// Owner of the device.
    pub user_id: OwnedUserId,
    /// Device the key belongs to.
    pub device_id: OwnedDeviceId,
    /// Key algorithm as a string.
    pub algorithm: String,
    /// Full key id.
    pub key_id: OwnedDeviceKeyId,
    /// The key serialized as JSON.
    pub key_data: JsonValue,
    // presumably the time the fallback key was handed out — TODO confirm.
    pub used_at: Option<i64>,
    /// Creation time.
    pub created_at: UnixMillis,
}
79
/// A row of the `e2e_one_time_keys` table: one unclaimed one-time key of a device.
///
/// Rows are written by `add_one_time_key` and deleted when claimed by
/// `claim_one_time_key`. Field order matters: `Queryable` maps columns to
/// fields by position.
#[derive(Identifiable, Queryable, Debug, Clone)]
#[diesel(table_name = e2e_one_time_keys)]
pub struct DbOneTimeKey {
    /// Primary key of the row.
    pub id: i64,

    /// Owner of the device.
    pub user_id: OwnedUserId,
    /// Device the key belongs to.
    pub device_id: OwnedDeviceId,
    /// Algorithm part of `key_id`, stored separately for filtering/grouping.
    pub algorithm: String,
    /// Full key id.
    pub key_id: OwnedDeviceKeyId,
    /// The serialized [`OneTimeKey`].
    pub key_data: JsonValue,
    /// Creation time.
    pub created_at: UnixMillis,
}
/// Insertable form of [`DbOneTimeKey`] (no `id` field).
#[derive(Insertable, Debug, Clone)]
#[diesel(table_name = e2e_one_time_keys)]
pub struct NewDbOneTimeKey {
    /// Owner of the device.
    pub user_id: OwnedUserId,
    /// Device the key belongs to.
    pub device_id: OwnedDeviceId,
    /// Algorithm part of `key_id`.
    pub algorithm: String,
    /// Full key id.
    pub key_id: OwnedDeviceKeyId,
    /// The serialized [`OneTimeKey`].
    pub key_data: JsonValue,
    /// Creation time.
    pub created_at: UnixMillis,
}
102
/// A row of the `e2e_device_keys` table: the uploaded device keys of one device.
///
/// Field order matters: `Queryable` maps columns to fields by position.
#[derive(Identifiable, Queryable, Debug, Clone)]
#[diesel(table_name = e2e_device_keys)]
pub struct DbDeviceKey {
    /// Primary key of the row.
    pub id: i64,

    /// Owner of the device.
    pub user_id: OwnedUserId,
    /// Device the keys belong to.
    pub device_id: OwnedDeviceId,
    /// Algorithm column (not set by [`NewDbDeviceKey`]).
    pub algorithm: String,
    /// Stream position; `add_device_keys` always writes 0.
    pub stream_id: i64,
    /// Device display name taken from the uploaded keys' unsigned data.
    pub display_name: Option<String>,
    /// The serialized [`DeviceKeys`].
    pub key_data: JsonValue,
    /// Time of the last write.
    pub created_at: UnixMillis,
}
/// Insertable/changeset form of [`DbDeviceKey`]; also used as the `SET` clause
/// of the upsert in `add_device_keys`.
// NOTE(review): there is no `algorithm` field here although `DbDeviceKey` has
// one — presumably the column has a database default; confirm against the schema.
#[derive(Insertable, AsChangeset, Debug, Clone)]
#[diesel(table_name = e2e_device_keys)]
pub struct NewDbDeviceKey {
    /// Owner of the device.
    pub user_id: OwnedUserId,
    /// Device the keys belong to.
    pub device_id: OwnedDeviceId,
    /// Stream position.
    pub stream_id: i64,
    /// Device display name.
    pub display_name: Option<String>,
    /// The serialized [`DeviceKeys`].
    pub key_data: JsonValue,
    /// Time of the write.
    pub created_at: UnixMillis,
}
126
/// A row of the `e2e_key_changes` table: a marker that `user_id`'s keys changed.
///
/// `mark_device_key_update` writes one row per joined room (`room_id = Some`)
/// plus one row with `room_id = None`. Field order matters: `Queryable` maps
/// columns to fields by position.
#[derive(Identifiable, Queryable, Debug, Clone)]
#[diesel(table_name = e2e_key_changes)]
pub struct DbKeyChange {
    /// Primary key of the row.
    pub id: i64,

    /// User whose keys changed.
    pub user_id: OwnedUserId,
    /// Room the change is recorded for, or `None` for the room-less marker.
    pub room_id: Option<OwnedRoomId>,
    /// Sequence number obtained from `crate::next_sn()` when the change was recorded.
    pub occur_sn: i64,
    /// Wall-clock time of the change.
    pub changed_at: UnixMillis,
}
/// Insertable form of [`DbKeyChange`] (no `id` field).
#[derive(Insertable, AsChangeset, Debug, Clone)]
#[diesel(table_name = e2e_key_changes)]
pub struct NewDbKeyChange {
    /// User whose keys changed.
    pub user_id: OwnedUserId,
    /// Room the change is recorded for, or `None` for the room-less marker.
    pub room_id: Option<OwnedRoomId>,
    /// Sequence number of the change.
    pub occur_sn: i64,
    /// Wall-clock time of the change.
    pub changed_at: UnixMillis,
}
145
146pub fn get_master_key(
147 sender_id: Option<&UserId>,
148 user_id: &UserId,
149 allowed_signatures: &dyn Fn(&UserId) -> bool,
150) -> DataResult<Option<CrossSigningKey>> {
151 let key_data = e2e_cross_signing_keys::table
152 .filter(e2e_cross_signing_keys::user_id.eq(user_id))
153 .filter(e2e_cross_signing_keys::key_type.eq("master"))
154 .select(e2e_cross_signing_keys::key_data)
155 .first::<JsonValue>(&mut connect()?)
156 .optional()?;
157 if let Some(mut key_data) = key_data {
158 clean_signatures(&mut key_data, sender_id, user_id, allowed_signatures)?;
159 Ok(serde_json::from_value(key_data).ok())
160 } else {
161 Ok(None)
162 }
163}
164
165pub fn get_self_signing_key(
166 sender_id: Option<&UserId>,
167 user_id: &UserId,
168 allowed_signatures: &dyn Fn(&UserId) -> bool,
169) -> DataResult<Option<CrossSigningKey>> {
170 let key_data = e2e_cross_signing_keys::table
171 .filter(e2e_cross_signing_keys::user_id.eq(user_id))
172 .filter(e2e_cross_signing_keys::key_type.eq("self_signing"))
173 .select(e2e_cross_signing_keys::key_data)
174 .first::<JsonValue>(&mut connect()?)
175 .optional()?;
176 if let Some(mut key_data) = key_data {
177 clean_signatures(&mut key_data, sender_id, user_id, allowed_signatures)?;
178 Ok(serde_json::from_value(key_data).ok())
179 } else {
180 Ok(None)
181 }
182}
183pub fn get_user_signing_key(user_id: &OwnedUserId) -> DataResult<Option<CrossSigningKey>> {
184 e2e_cross_signing_keys::table
185 .filter(e2e_cross_signing_keys::user_id.eq(user_id))
186 .filter(e2e_cross_signing_keys::key_type.eq("user_signing"))
187 .select(e2e_cross_signing_keys::key_data)
188 .first::<JsonValue>(&mut connect()?)
189 .map(|data| serde_json::from_value(data).ok())
190 .optional()
191 .map(|v| v.flatten())
192 .map_err(Into::into)
193}
194
195pub fn add_one_time_key(
196 user_id: &OwnedUserId,
197 device_id: &DeviceId,
198 key_id: &DeviceKeyId,
199 one_time_key: &OneTimeKey,
200) -> DataResult<()> {
201 diesel::insert_into(e2e_one_time_keys::table)
202 .values(&NewDbOneTimeKey {
203 user_id: user_id.to_owned(),
204 device_id: device_id.to_owned(),
205 algorithm: key_id.algorithm().to_string(),
206 key_id: key_id.to_owned(),
207 key_data: serde_json::to_value(one_time_key).unwrap(),
208 created_at: UnixMillis::now(),
209 })
210 .on_conflict((
211 e2e_one_time_keys::user_id,
212 e2e_one_time_keys::device_id,
213 e2e_one_time_keys::algorithm,
214 e2e_one_time_keys::key_id,
215 ))
216 .do_update()
217 .set(e2e_one_time_keys::key_data.eq(serde_json::to_value(one_time_key).unwrap()))
218 .execute(&mut connect()?)?;
219 Ok(())
220}
221
222pub fn claim_one_time_key(
223 user_id: &OwnedUserId,
224 device_id: &DeviceId,
225 key_algorithm: &DeviceKeyAlgorithm,
226) -> DataResult<Option<(OwnedDeviceKeyId, OneTimeKey)>> {
227 let one_time_key = e2e_one_time_keys::table
228 .filter(e2e_one_time_keys::user_id.eq(user_id))
229 .filter(e2e_one_time_keys::device_id.eq(device_id))
230 .filter(e2e_one_time_keys::algorithm.eq(key_algorithm.as_ref()))
231 .order(e2e_one_time_keys::id.desc())
232 .first::<DbOneTimeKey>(&mut connect()?)
233 .optional()?;
234 if let Some(DbOneTimeKey {
235 id, key_id, key_data, ..
236 }) = one_time_key
237 {
238 diesel::delete(e2e_one_time_keys::table.find(id)).execute(&mut connect()?)?;
239 Ok(Some((key_id, serde_json::from_value::<OneTimeKey>(key_data)?)))
240 } else {
241 Ok(None)
242 }
243}
244
245pub fn count_one_time_keys(user_id: &UserId, device_id: &DeviceId) -> DataResult<BTreeMap<DeviceKeyAlgorithm, u64>> {
246 let list = e2e_one_time_keys::table
247 .filter(e2e_one_time_keys::user_id.eq(user_id))
248 .filter(e2e_one_time_keys::device_id.eq(device_id))
249 .group_by(e2e_one_time_keys::algorithm)
250 .select((e2e_one_time_keys::algorithm, diesel::dsl::count_star()))
251 .load::<(String, i64)>(&mut connect()?)?;
252 Ok(BTreeMap::from_iter(
253 list.into_iter().map(|(k, v)| (DeviceKeyAlgorithm::from(k), v as u64)),
254 ))
255}
256
257pub fn add_device_keys(user_id: &UserId, device_id: &DeviceId, device_keys: &DeviceKeys) -> DataResult<()> {
258 println!(
259 ">>>>>>>>>>>>>>>>>>add add_device_keys user_id: {:?} device_id: {device_id} device_keys:{device_keys:?}",
260 user_id
261 );
262 let new_device_key = NewDbDeviceKey {
263 user_id: user_id.to_owned(),
264 device_id: device_id.to_owned(),
265 stream_id: 0,
266 display_name: device_keys.unsigned.device_display_name.clone(),
267 key_data: serde_json::to_value(device_keys).unwrap(),
268 created_at: UnixMillis::now(),
269 };
270 diesel::insert_into(e2e_device_keys::table)
271 .values(&new_device_key)
272 .on_conflict((e2e_device_keys::user_id, e2e_device_keys::device_id))
273 .do_update()
274 .set(&new_device_key)
275 .execute(&mut connect()?)?;
276 mark_device_key_update(user_id)?;
277 Ok(())
278}
279
280pub fn add_cross_signing_keys(
281 user_id: &UserId,
282 master_key: &CrossSigningKey,
283 self_signing_key: &Option<CrossSigningKey>,
284 user_signing_key: &Option<CrossSigningKey>,
285 notify: bool,
286) -> DataResult<()> {
287 diesel::insert_into(e2e_cross_signing_keys::table)
289 .values(NewDbCrossSigningKey {
290 user_id: user_id.to_owned(),
291 key_type: "master".to_owned(),
292 key_data: serde_json::to_value(master_key)?,
293 })
294 .execute(&mut connect()?)?;
295
296 if let Some(self_signing_key) = self_signing_key {
298 let mut self_signing_key_ids = self_signing_key.keys.values();
299
300 let self_signing_key_id = self_signing_key_ids
301 .next()
302 .ok_or(MatrixError::invalid_param("Self signing key contained no key."))?;
303
304 if self_signing_key_ids.next().is_some() {
305 return Err(MatrixError::invalid_param("Self signing key contained more than one key.").into());
306 }
307
308 diesel::insert_into(e2e_cross_signing_keys::table)
309 .values(NewDbCrossSigningKey {
310 user_id: user_id.to_owned(),
311 key_type: "self_signing".to_owned(),
312 key_data: serde_json::to_value(self_signing_key)?,
313 })
314 .execute(&mut connect()?)?;
315 }
316
317 if let Some(user_signing_key) = user_signing_key {
319 let mut user_signing_key_ids = user_signing_key.keys.values();
320
321 let user_signing_key_id = user_signing_key_ids
322 .next()
323 .ok_or(MatrixError::invalid_param("User signing key contained no key."))?;
324
325 if user_signing_key_ids.next().is_some() {
326 return Err(MatrixError::invalid_param("User signing key contained more than one key.").into());
327 }
328
329 diesel::insert_into(e2e_cross_signing_keys::table)
330 .values(NewDbCrossSigningKey {
331 user_id: user_id.to_owned(),
332 key_type: "user_signing".to_owned(),
333 key_data: serde_json::to_value(user_signing_key)?,
334 })
335 .execute(&mut connect()?)?;
336 }
337
338 if notify {
339 mark_device_key_update(user_id)?;
340 }
341
342 Ok(())
343}
344
345pub fn sign_key(
346 target_user_id: &UserId,
347 target_device_id: &str,
348 signature: (String, String),
349 sender_id: &UserId,
350) -> DataResult<()> {
351 let origin_key_id = DeviceKeyId::parse(&signature.0)?.to_owned();
358
359 diesel::insert_into(e2e_cross_signing_sigs::table)
366 .values(NewDbCrossSignature {
367 origin_user_id: sender_id.to_owned(),
368 origin_key_id,
369 target_user_id: target_user_id.to_owned(),
370 target_device_id: OwnedDeviceId::from(target_device_id),
371 signature: signature.1,
372 })
373 .execute(&mut connect()?)?;
374 mark_device_key_update(target_user_id)
375}
376
377pub fn mark_device_key_update(user_id: &UserId) -> DataResult<()> {
378 println!(">>>>>>>>>>>>>>mark_device_key_update, user_id: {:?}", user_id);
379 let changed_at = UnixMillis::now();
380 for room_id in crate::user::joined_rooms(user_id, 0)? {
381 let change = NewDbKeyChange {
388 user_id: user_id.to_owned(),
389 room_id: Some(room_id.to_owned()),
390 changed_at,
391 occur_sn: crate::next_sn()?,
392 };
393
394 diesel::delete(
395 e2e_key_changes::table
396 .filter(e2e_key_changes::user_id.eq(user_id))
397 .filter(e2e_key_changes::room_id.eq(room_id)),
398 )
399 .execute(&mut connect()?)?;
400 diesel::insert_into(e2e_key_changes::table)
401 .values(&change)
402 .execute(&mut connect()?)?;
403 }
404
405 let change = NewDbKeyChange {
406 user_id: user_id.to_owned(),
407 room_id: None,
408 changed_at,
409 occur_sn: crate::next_sn()?,
410 };
411
412 diesel::delete(
413 e2e_key_changes::table
414 .filter(e2e_key_changes::user_id.eq(user_id))
415 .filter(e2e_key_changes::room_id.is_null()),
416 )
417 .execute(&mut connect()?)?;
418 diesel::insert_into(e2e_key_changes::table)
419 .values(&change)
420 .execute(&mut connect()?)?;
421
422 Ok(())
423}
424
425pub fn get_device_keys(user_id: &UserId, device_id: &DeviceId) -> DataResult<Option<DeviceKeys>> {
426 e2e_device_keys::table
427 .filter(e2e_device_keys::user_id.eq(user_id))
428 .filter(e2e_device_keys::device_id.eq(device_id))
429 .select(e2e_device_keys::key_data)
430 .first::<JsonValue>(&mut *connect()?)
431 .optional()?
432 .map(|v| serde_json::from_value(v).map_err(Into::into))
433 .transpose()
434}
435
436pub fn get_device_keys_and_sigs(user_id: &UserId, device_id: &DeviceId) -> DataResult<Option<DeviceKeys>> {
437 let Some(mut device_keys) = get_device_keys(user_id, device_id)? else {
438 return Ok(None);
439 };
440 let signatures = e2e_cross_signing_sigs::table
441 .filter(e2e_cross_signing_sigs::origin_user_id.eq(user_id))
442 .filter(e2e_cross_signing_sigs::target_user_id.eq(user_id))
443 .filter(e2e_cross_signing_sigs::target_device_id.eq(device_id))
444 .load::<DbCrossSignature>(&mut *connect()?)?;
445 for DbCrossSignature {
446 origin_key_id,
447 signature,
448 ..
449 } in signatures
450 {
451 device_keys
452 .signatures
453 .entry(user_id.to_owned())
454 .or_default()
455 .insert(origin_key_id, signature);
456 }
457 Ok(Some(device_keys))
458}
459
460pub fn keys_changed_users(user_id: &UserId, since_sn: i64, until_sn: Option<i64>) -> DataResult<Vec<OwnedUserId>> {
461 let room_ids = crate::user::joined_rooms(user_id, 0)?;
462 if let Some(until_sn) = until_sn {
463 e2e_key_changes::table
464 .filter(
465 e2e_key_changes::room_id
466 .eq_any(&room_ids)
467 .or(e2e_key_changes::room_id.is_null()),
468 )
469 .filter(e2e_key_changes::occur_sn.ge(since_sn))
470 .filter(e2e_key_changes::occur_sn.le(until_sn))
471 .select(e2e_key_changes::user_id)
472 .load::<OwnedUserId>(&mut connect()?)
473 .map_err(Into::into)
474 } else {
475 e2e_key_changes::table
476 .filter(
477 e2e_key_changes::room_id
478 .eq_any(&room_ids)
479 .or(e2e_key_changes::room_id.is_null()),
480 )
481 .filter(e2e_key_changes::occur_sn.ge(since_sn))
482 .select(e2e_key_changes::user_id)
483 .load::<OwnedUserId>(&mut connect()?)
484 .map_err(Into::into)
485 }
486}
487
488pub fn room_keys_changed(
489 room_id: &RoomId,
490 since_sn: i64,
491 until_sn: Option<i64>,
492) -> DataResult<Vec<(OwnedUserId, Seqnum)>> {
493 if let Some(until_sn) = until_sn {
494 e2e_key_changes::table
495 .filter(e2e_key_changes::room_id.eq(room_id))
496 .filter(e2e_key_changes::occur_sn.ge(since_sn))
497 .filter(e2e_key_changes::occur_sn.le(until_sn))
498 .select((e2e_key_changes::user_id, e2e_key_changes::occur_sn))
499 .load::<(OwnedUserId, i64)>(&mut connect()?)
500 .map_err(Into::into)
501 } else {
502 e2e_key_changes::table
503 .filter(e2e_key_changes::room_id.eq(room_id))
504 .filter(e2e_key_changes::occur_sn.ge(since_sn))
505 .select((e2e_key_changes::user_id, e2e_key_changes::occur_sn))
506 .load::<(OwnedUserId, i64)>(&mut connect()?)
507 .map_err(Into::into)
508 }
509}
510
/// Judging by its name, meant to report whether the signing keys in `body`
/// differ from those stored for `user_id`.
///
/// NOTE(review): the current implementation ignores both arguments and
/// unconditionally returns `Ok(true)`, so every upload is treated as
/// "different". Confirm whether a real comparison is still intended.
pub fn has_different_keys(user_id: &UserId, body: &UploadSigningKeysReqBody) -> DataResult<bool> {
    Ok(true)
}