// palpo_data/user/key.rs

1use std::collections::{hash_map, BTreeMap, HashMap};
2use std::time::Instant;
3
4use diesel::prelude::*;
5use futures_util::stream::{FuturesUnordered, StreamExt};
6use serde_json::json;
7
8use crate::core::client::key::{ClaimKeysResBody, UploadSigningKeysReqBody};
9use crate::core::encryption::{CrossSigningKey, DeviceKeys, OneTimeKey};
10use crate::core::federation::key::{claim_keys_request, query_keys_request, QueryKeysReqBody, QueryKeysResBody};
11use crate::core::identifiers::*;
12use crate::core::{client, federation, DeviceKeyAlgorithm, JsonValue, MatrixError, Seqnum, UnixMillis};
13use crate::schema::*;
14use crate::user::clean_signatures;
15use crate::{DataError, connect, DataResult};
16
17#[derive(Identifiable, Insertable, Queryable, Debug, Clone)]
18#[diesel(table_name = e2e_cross_signing_keys)]
19pub struct DbCrossSigningKey {
20    pub id: i64,
21
22    pub user_id: OwnedUserId,
23    pub key_type: String,
24    pub key_data: JsonValue,
25}
26#[derive(Insertable, Debug, Clone)]
27#[diesel(table_name = e2e_cross_signing_keys)]
28pub struct NewDbCrossSigningKey {
29    pub user_id: OwnedUserId,
30    pub key_type: String,
31    pub key_data: JsonValue,
32}
33
34#[derive(Identifiable, Queryable, Debug, Clone)]
35#[diesel(table_name = e2e_cross_signing_sigs)]
36pub struct DbCrossSignature {
37    pub id: i64,
38
39    pub origin_user_id: OwnedUserId,
40    pub origin_key_id: OwnedDeviceKeyId,
41    pub target_user_id: OwnedUserId,
42    pub target_device_id: OwnedDeviceId,
43    pub signature: String,
44}
45#[derive(Insertable, Debug, Clone)]
46#[diesel(table_name = e2e_cross_signing_sigs)]
47pub struct NewDbCrossSignature {
48    pub origin_user_id: OwnedUserId,
49    pub origin_key_id: OwnedDeviceKeyId,
50    pub target_user_id: OwnedUserId,
51    pub target_device_id: OwnedDeviceId,
52    pub signature: String,
53}
54
55#[derive(Identifiable, Queryable, Debug, Clone)]
56#[diesel(table_name = e2e_fallback_keys)]
57pub struct DbFallbackKey {
58    pub id: String,
59
60    pub user_id: OwnedUserId,
61    pub device_id: OwnedDeviceId,
62    pub algorithm: String,
63    pub key_id: OwnedDeviceKeyId,
64    pub key_data: JsonValue,
65    pub used_at: Option<i64>,
66    pub created_at: UnixMillis,
67}
68#[derive(Insertable, Debug, Clone)]
69#[diesel(table_name = e2e_fallback_keys)]
70pub struct NewDbFallbackKey {
71    pub user_id: OwnedUserId,
72    pub device_id: OwnedDeviceId,
73    pub algorithm: String,
74    pub key_id: OwnedDeviceKeyId,
75    pub key_data: JsonValue,
76    pub used_at: Option<i64>,
77    pub created_at: UnixMillis,
78}
79
80#[derive(Identifiable, Queryable, Debug, Clone)]
81#[diesel(table_name = e2e_one_time_keys)]
82pub struct DbOneTimeKey {
83    pub id: i64,
84
85    pub user_id: OwnedUserId,
86    pub device_id: OwnedDeviceId,
87    pub algorithm: String,
88    pub key_id: OwnedDeviceKeyId,
89    pub key_data: JsonValue,
90    pub created_at: UnixMillis,
91}
92#[derive(Insertable, Debug, Clone)]
93#[diesel(table_name = e2e_one_time_keys)]
94pub struct NewDbOneTimeKey {
95    pub user_id: OwnedUserId,
96    pub device_id: OwnedDeviceId,
97    pub algorithm: String,
98    pub key_id: OwnedDeviceKeyId,
99    pub key_data: JsonValue,
100    pub created_at: UnixMillis,
101}
102
103#[derive(Identifiable, Queryable, Debug, Clone)]
104#[diesel(table_name = e2e_device_keys)]
105pub struct DbDeviceKey {
106    pub id: i64,
107
108    pub user_id: OwnedUserId,
109    pub device_id: OwnedDeviceId,
110    pub algorithm: String,
111    pub stream_id: i64,
112    pub display_name: Option<String>,
113    pub key_data: JsonValue,
114    pub created_at: UnixMillis,
115}
116#[derive(Insertable, AsChangeset, Debug, Clone)]
117#[diesel(table_name = e2e_device_keys)]
118pub struct NewDbDeviceKey {
119    pub user_id: OwnedUserId,
120    pub device_id: OwnedDeviceId,
121    pub stream_id: i64,
122    pub display_name: Option<String>,
123    pub key_data: JsonValue,
124    pub created_at: UnixMillis,
125}
126
127#[derive(Identifiable, Queryable, Debug, Clone)]
128#[diesel(table_name = e2e_key_changes)]
129pub struct DbKeyChange {
130    pub id: i64,
131
132    pub user_id: OwnedUserId,
133    pub room_id: Option<OwnedRoomId>,
134    pub occur_sn: i64,
135    pub changed_at: UnixMillis,
136}
137#[derive(Insertable, AsChangeset, Debug, Clone)]
138#[diesel(table_name = e2e_key_changes)]
139pub struct NewDbKeyChange {
140    pub user_id: OwnedUserId,
141    pub room_id: Option<OwnedRoomId>,
142    pub occur_sn: i64,
143    pub changed_at: UnixMillis,
144}
145
146pub fn get_master_key(
147    sender_id: Option<&UserId>,
148    user_id: &UserId,
149    allowed_signatures: &dyn Fn(&UserId) -> bool,
150) -> DataResult<Option<CrossSigningKey>> {
151    let key_data = e2e_cross_signing_keys::table
152        .filter(e2e_cross_signing_keys::user_id.eq(user_id))
153        .filter(e2e_cross_signing_keys::key_type.eq("master"))
154        .select(e2e_cross_signing_keys::key_data)
155        .first::<JsonValue>(&mut connect()?)
156        .optional()?;
157    if let Some(mut key_data) = key_data {
158        clean_signatures(&mut key_data, sender_id, user_id, allowed_signatures)?;
159        Ok(serde_json::from_value(key_data).ok())
160    } else {
161        Ok(None)
162    }
163}
164
165pub fn get_self_signing_key(
166    sender_id: Option<&UserId>,
167    user_id: &UserId,
168    allowed_signatures: &dyn Fn(&UserId) -> bool,
169) -> DataResult<Option<CrossSigningKey>> {
170    let key_data = e2e_cross_signing_keys::table
171        .filter(e2e_cross_signing_keys::user_id.eq(user_id))
172        .filter(e2e_cross_signing_keys::key_type.eq("self_signing"))
173        .select(e2e_cross_signing_keys::key_data)
174        .first::<JsonValue>(&mut connect()?)
175        .optional()?;
176    if let Some(mut key_data) = key_data {
177        clean_signatures(&mut key_data, sender_id, user_id, allowed_signatures)?;
178        Ok(serde_json::from_value(key_data).ok())
179    } else {
180        Ok(None)
181    }
182}
183pub fn get_user_signing_key(user_id: &OwnedUserId) -> DataResult<Option<CrossSigningKey>> {
184    e2e_cross_signing_keys::table
185        .filter(e2e_cross_signing_keys::user_id.eq(user_id))
186        .filter(e2e_cross_signing_keys::key_type.eq("user_signing"))
187        .select(e2e_cross_signing_keys::key_data)
188        .first::<JsonValue>(&mut connect()?)
189        .map(|data| serde_json::from_value(data).ok())
190        .optional()
191        .map(|v| v.flatten())
192        .map_err(Into::into)
193}
194
195pub fn add_one_time_key(
196    user_id: &OwnedUserId,
197    device_id: &DeviceId,
198    key_id: &DeviceKeyId,
199    one_time_key: &OneTimeKey,
200) -> DataResult<()> {
201    diesel::insert_into(e2e_one_time_keys::table)
202        .values(&NewDbOneTimeKey {
203            user_id: user_id.to_owned(),
204            device_id: device_id.to_owned(),
205            algorithm: key_id.algorithm().to_string(),
206            key_id: key_id.to_owned(),
207            key_data: serde_json::to_value(one_time_key).unwrap(),
208            created_at: UnixMillis::now(),
209        })
210        .on_conflict((
211            e2e_one_time_keys::user_id,
212            e2e_one_time_keys::device_id,
213            e2e_one_time_keys::algorithm,
214            e2e_one_time_keys::key_id,
215        ))
216        .do_update()
217        .set(e2e_one_time_keys::key_data.eq(serde_json::to_value(one_time_key).unwrap()))
218        .execute(&mut connect()?)?;
219    Ok(())
220}
221
222pub fn claim_one_time_key(
223    user_id: &OwnedUserId,
224    device_id: &DeviceId,
225    key_algorithm: &DeviceKeyAlgorithm,
226) -> DataResult<Option<(OwnedDeviceKeyId, OneTimeKey)>> {
227    let one_time_key = e2e_one_time_keys::table
228        .filter(e2e_one_time_keys::user_id.eq(user_id))
229        .filter(e2e_one_time_keys::device_id.eq(device_id))
230        .filter(e2e_one_time_keys::algorithm.eq(key_algorithm.as_ref()))
231        .order(e2e_one_time_keys::id.desc())
232        .first::<DbOneTimeKey>(&mut connect()?)
233        .optional()?;
234    if let Some(DbOneTimeKey {
235        id, key_id, key_data, ..
236    }) = one_time_key
237    {
238        diesel::delete(e2e_one_time_keys::table.find(id)).execute(&mut connect()?)?;
239        Ok(Some((key_id, serde_json::from_value::<OneTimeKey>(key_data)?)))
240    } else {
241        Ok(None)
242    }
243}
244
245pub fn count_one_time_keys(user_id: &UserId, device_id: &DeviceId) -> DataResult<BTreeMap<DeviceKeyAlgorithm, u64>> {
246    let list = e2e_one_time_keys::table
247        .filter(e2e_one_time_keys::user_id.eq(user_id))
248        .filter(e2e_one_time_keys::device_id.eq(device_id))
249        .group_by(e2e_one_time_keys::algorithm)
250        .select((e2e_one_time_keys::algorithm, diesel::dsl::count_star()))
251        .load::<(String, i64)>(&mut connect()?)?;
252    Ok(BTreeMap::from_iter(
253        list.into_iter().map(|(k, v)| (DeviceKeyAlgorithm::from(k), v as u64)),
254    ))
255}
256
257pub fn add_device_keys(user_id: &UserId, device_id: &DeviceId, device_keys: &DeviceKeys) -> DataResult<()> {
258    println!(
259        ">>>>>>>>>>>>>>>>>>add add_device_keys user_id: {:?} device_id: {device_id} device_keys:{device_keys:?}",
260        user_id
261    );
262    let new_device_key = NewDbDeviceKey {
263        user_id: user_id.to_owned(),
264        device_id: device_id.to_owned(),
265        stream_id: 0,
266        display_name: device_keys.unsigned.device_display_name.clone(),
267        key_data: serde_json::to_value(device_keys).unwrap(),
268        created_at: UnixMillis::now(),
269    };
270    diesel::insert_into(e2e_device_keys::table)
271        .values(&new_device_key)
272        .on_conflict((e2e_device_keys::user_id, e2e_device_keys::device_id))
273        .do_update()
274        .set(&new_device_key)
275        .execute(&mut connect()?)?;
276    mark_device_key_update(user_id)?;
277    Ok(())
278}
279
280pub fn add_cross_signing_keys(
281    user_id: &UserId,
282    master_key: &CrossSigningKey,
283    self_signing_key: &Option<CrossSigningKey>,
284    user_signing_key: &Option<CrossSigningKey>,
285    notify: bool,
286) -> DataResult<()> {
287    // TODO: Check signatures
288    diesel::insert_into(e2e_cross_signing_keys::table)
289        .values(NewDbCrossSigningKey {
290            user_id: user_id.to_owned(),
291            key_type: "master".to_owned(),
292            key_data: serde_json::to_value(master_key)?,
293        })
294        .execute(&mut connect()?)?;
295
296    // Self-signing key
297    if let Some(self_signing_key) = self_signing_key {
298        let mut self_signing_key_ids = self_signing_key.keys.values();
299
300        let self_signing_key_id = self_signing_key_ids
301            .next()
302            .ok_or(MatrixError::invalid_param("Self signing key contained no key."))?;
303
304        if self_signing_key_ids.next().is_some() {
305            return Err(MatrixError::invalid_param("Self signing key contained more than one key.").into());
306        }
307
308        diesel::insert_into(e2e_cross_signing_keys::table)
309            .values(NewDbCrossSigningKey {
310                user_id: user_id.to_owned(),
311                key_type: "self_signing".to_owned(),
312                key_data: serde_json::to_value(self_signing_key)?,
313            })
314            .execute(&mut connect()?)?;
315    }
316
317    // User-signing key
318    if let Some(user_signing_key) = user_signing_key {
319        let mut user_signing_key_ids = user_signing_key.keys.values();
320
321        let user_signing_key_id = user_signing_key_ids
322            .next()
323            .ok_or(MatrixError::invalid_param("User signing key contained no key."))?;
324
325        if user_signing_key_ids.next().is_some() {
326            return Err(MatrixError::invalid_param("User signing key contained more than one key.").into());
327        }
328
329        diesel::insert_into(e2e_cross_signing_keys::table)
330            .values(NewDbCrossSigningKey {
331                user_id: user_id.to_owned(),
332                key_type: "user_signing".to_owned(),
333                key_data: serde_json::to_value(user_signing_key)?,
334            })
335            .execute(&mut connect()?)?;
336    }
337
338    if notify {
339        mark_device_key_update(user_id)?;
340    }
341
342    Ok(())
343}
344
345pub fn sign_key(
346    target_user_id: &UserId,
347    target_device_id: &str,
348    signature: (String, String),
349    sender_id: &UserId,
350) -> DataResult<()> {
351    // let cross_signing_key = e2e_cross_signing_keys::table
352    //     .filter(e2e_cross_signing_keys::user_id.eq(target_id))
353    //     .filter(e2e_cross_signing_keys::key_type.eq("master"))
354    //     .order_by(e2e_cross_signing_keys::id.desc())
355    //     .first::<DbCrossSigningKey>(&mut *connect()?)?;
356    // let mut cross_signing_key: CrossSigningKey = serde_json::from_value(cross_signing_key.key_data.clone())?;
357    let origin_key_id = DeviceKeyId::parse(&signature.0)?.to_owned();
358
359    // cross_signing_key
360    //     .signatures
361    //     .entry(sender_id.to_owned())
362    //     .or_defaut()
363    //     .insert(key_id.clone(), signature.1);
364
365    diesel::insert_into(e2e_cross_signing_sigs::table)
366        .values(NewDbCrossSignature {
367            origin_user_id: sender_id.to_owned(),
368            origin_key_id,
369            target_user_id: target_user_id.to_owned(),
370            target_device_id: OwnedDeviceId::from(target_device_id),
371            signature: signature.1,
372        })
373        .execute(&mut connect()?)?;
374    mark_device_key_update(target_user_id)
375}
376
377pub fn mark_device_key_update(user_id: &UserId) -> DataResult<()> {
378    println!(">>>>>>>>>>>>>>mark_device_key_update, user_id: {:?}", user_id);
379    let changed_at = UnixMillis::now();
380    for room_id in crate::user::joined_rooms(user_id, 0)? {
381        // comment for testing
382        // // Don't send key updates to unencrypted rooms
383        // if crate::room::state::get_state(&room_id, &StateEventType::RoomEncryption, "")?.is_none() {
384        //     continue;
385        // }
386
387        let change = NewDbKeyChange {
388            user_id: user_id.to_owned(),
389            room_id: Some(room_id.to_owned()),
390            changed_at,
391            occur_sn: crate::next_sn()?,
392        };
393
394        diesel::delete(
395            e2e_key_changes::table
396                .filter(e2e_key_changes::user_id.eq(user_id))
397                .filter(e2e_key_changes::room_id.eq(room_id)),
398        )
399        .execute(&mut connect()?)?;
400        diesel::insert_into(e2e_key_changes::table)
401            .values(&change)
402            .execute(&mut connect()?)?;
403    }
404
405    let change = NewDbKeyChange {
406        user_id: user_id.to_owned(),
407        room_id: None,
408        changed_at,
409        occur_sn: crate::next_sn()?,
410    };
411
412    diesel::delete(
413        e2e_key_changes::table
414            .filter(e2e_key_changes::user_id.eq(user_id))
415            .filter(e2e_key_changes::room_id.is_null()),
416    )
417    .execute(&mut connect()?)?;
418    diesel::insert_into(e2e_key_changes::table)
419        .values(&change)
420        .execute(&mut connect()?)?;
421
422    Ok(())
423}
424
425pub fn get_device_keys(user_id: &UserId, device_id: &DeviceId) -> DataResult<Option<DeviceKeys>> {
426    e2e_device_keys::table
427        .filter(e2e_device_keys::user_id.eq(user_id))
428        .filter(e2e_device_keys::device_id.eq(device_id))
429        .select(e2e_device_keys::key_data)
430        .first::<JsonValue>(&mut *connect()?)
431        .optional()?
432        .map(|v| serde_json::from_value(v).map_err(Into::into))
433        .transpose()
434}
435
436pub fn get_device_keys_and_sigs(user_id: &UserId, device_id: &DeviceId) -> DataResult<Option<DeviceKeys>> {
437    let Some(mut device_keys) = get_device_keys(user_id, device_id)? else {
438        return Ok(None);
439    };
440    let signatures = e2e_cross_signing_sigs::table
441        .filter(e2e_cross_signing_sigs::origin_user_id.eq(user_id))
442        .filter(e2e_cross_signing_sigs::target_user_id.eq(user_id))
443        .filter(e2e_cross_signing_sigs::target_device_id.eq(device_id))
444        .load::<DbCrossSignature>(&mut *connect()?)?;
445    for DbCrossSignature {
446        origin_key_id,
447        signature,
448        ..
449    } in signatures
450    {
451        device_keys
452            .signatures
453            .entry(user_id.to_owned())
454            .or_default()
455            .insert(origin_key_id, signature);
456    }
457    Ok(Some(device_keys))
458}
459
460pub fn keys_changed_users(user_id: &UserId, since_sn: i64, until_sn: Option<i64>) -> DataResult<Vec<OwnedUserId>> {
461    let room_ids = crate::user::joined_rooms(user_id, 0)?;
462    if let Some(until_sn) = until_sn {
463        e2e_key_changes::table
464            .filter(
465                e2e_key_changes::room_id
466                    .eq_any(&room_ids)
467                    .or(e2e_key_changes::room_id.is_null()),
468            )
469            .filter(e2e_key_changes::occur_sn.ge(since_sn))
470            .filter(e2e_key_changes::occur_sn.le(until_sn))
471            .select(e2e_key_changes::user_id)
472            .load::<OwnedUserId>(&mut connect()?)
473            .map_err(Into::into)
474    } else {
475        e2e_key_changes::table
476            .filter(
477                e2e_key_changes::room_id
478                    .eq_any(&room_ids)
479                    .or(e2e_key_changes::room_id.is_null()),
480            )
481            .filter(e2e_key_changes::occur_sn.ge(since_sn))
482            .select(e2e_key_changes::user_id)
483            .load::<OwnedUserId>(&mut connect()?)
484            .map_err(Into::into)
485    }
486}
487
488pub fn room_keys_changed(
489    room_id: &RoomId,
490    since_sn: i64,
491    until_sn: Option<i64>,
492) -> DataResult<Vec<(OwnedUserId, Seqnum)>> {
493    if let Some(until_sn) = until_sn {
494        e2e_key_changes::table
495            .filter(e2e_key_changes::room_id.eq(room_id))
496            .filter(e2e_key_changes::occur_sn.ge(since_sn))
497            .filter(e2e_key_changes::occur_sn.le(until_sn))
498            .select((e2e_key_changes::user_id, e2e_key_changes::occur_sn))
499            .load::<(OwnedUserId, i64)>(&mut connect()?)
500            .map_err(Into::into)
501    } else {
502        e2e_key_changes::table
503            .filter(e2e_key_changes::room_id.eq(room_id))
504            .filter(e2e_key_changes::occur_sn.ge(since_sn))
505            .select((e2e_key_changes::user_id, e2e_key_changes::occur_sn))
506            .load::<(OwnedUserId, i64)>(&mut connect()?)
507            .map_err(Into::into)
508    }
509}
510
511// Check if a key provided in `body` differs from the same key stored in the DB. Returns
512// true on the first difference. If a key exists in `body` but does not exist in the DB,
513// returns True. If `body` has no keys, this always returns False.
514// Note by 'key' we mean Matrix key rather than JSON key.
515
516// The purpose of this function is to detect whether or not we need to apply UIA checks.
517// We must apply UIA checks if any key in the database is being overwritten. If a key is
518// being inserted for the first time, or if the key exactly matches what is in the database,
519// then no UIA check needs to be performed.
520
521// Args:
522//     user_id: The user who sent the `body`.
523//     body: The JSON request body from POST /keys/device_signing/upload
524// Returns:
525//     true if any key in `body` has a different value in the database.
526pub fn has_different_keys(user_id: &UserId, body: &UploadSigningKeysReqBody) -> DataResult<bool> {
527    //TODO: NOW
528    Ok(true)
529}