matrix_sdk_crypto/olm/
account.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, HashMap},
17    fmt,
18    ops::{Deref, Not as _},
19    sync::Arc,
20    time::Duration,
21};
22
23use hkdf::Hkdf;
24use js_option::JsOption;
25use matrix_sdk_common::deserialized_responses::{
26    AlgorithmInfo, DeviceLinkProblem, EncryptionInfo, VerificationLevel, VerificationState,
27};
28#[cfg(test)]
29use ruma::api::client::dehydrated_device::DehydratedDeviceV1;
30use ruma::{
31    api::client::{
32        dehydrated_device::{DehydratedDeviceData, DehydratedDeviceV2},
33        keys::{
34            upload_keys,
35            upload_signatures::v3::{Request as SignatureUploadRequest, SignedKeys},
36        },
37    },
38    events::{room::history_visibility::HistoryVisibility, AnyToDeviceEvent},
39    serde::Raw,
40    DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm,
41    OneTimeKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedOneTimeKeyId, OwnedUserId, RoomId,
42    SecondsSinceUnixEpoch, UInt, UserId,
43};
44use serde::{de::Error, Deserialize, Serialize};
45use serde_json::{
46    value::{to_raw_value, RawValue as RawJsonValue},
47    Value,
48};
49use sha2::{Digest, Sha256};
50use tokio::sync::Mutex;
51use tracing::{debug, field::debug, info, instrument, trace, warn, Span};
52use vodozemac::{
53    base64_encode,
54    olm::{
55        Account as InnerAccount, AccountPickle, IdentityKeys, OlmMessage,
56        OneTimeKeyGenerationResult, PreKeyMessage, SessionConfig,
57    },
58    Curve25519PublicKey, Ed25519Signature, KeyId, PickleError,
59};
60
61use super::{
62    utility::SignJson, EncryptionSettings, InboundGroupSession, OutboundGroupSession,
63    PrivateCrossSigningIdentity, Session, SessionCreationError as MegolmSessionCreationError,
64};
65#[cfg(feature = "experimental-algorithms")]
66use crate::types::events::room::encrypted::OlmV2Curve25519AesSha2Content;
67use crate::{
68    dehydrated_devices::DehydrationError,
69    error::{EventError, OlmResult, SessionCreationError},
70    identities::DeviceData,
71    olm::SenderData,
72    store::{Changes, DeviceChanges, Store},
73    types::{
74        events::{
75            olm_v1::AnyDecryptedOlmEvent,
76            room::encrypted::{
77                EncryptedToDeviceEvent, OlmV1Curve25519AesSha2Content,
78                ToDeviceEncryptedEventContent,
79            },
80        },
81        requests::UploadSigningKeysRequest,
82        CrossSigningKey, DeviceKeys, EventEncryptionAlgorithm, MasterPubkey, OneTimeKey, SignedKey,
83    },
84    Device, OlmError, SignatureError,
85};
86
87#[derive(Debug)]
88enum PrekeyBundle {
89    Olm3DH { key: SignedKey },
90}
91
92#[derive(Debug, Clone)]
93pub(crate) enum SessionType {
94    New(Session),
95    Existing(Session),
96}
97
98#[derive(Debug)]
99pub struct InboundCreationResult {
100    pub session: Session,
101    pub plaintext: String,
102}
103
104impl SessionType {
105    #[cfg(test)]
106    pub fn session(self) -> Session {
107        match self {
108            SessionType::New(s) => s,
109            SessionType::Existing(s) => s,
110        }
111    }
112}
113
114/// A struct witnessing a successful decryption of an Olm-encrypted to-device
115/// event.
116///
117/// Contains the decrypted event plaintext along with some associated metadata,
118/// such as the identity (Curve25519) key of the to-device event sender.
119#[derive(Debug)]
120pub(crate) struct OlmDecryptionInfo {
121    pub session: SessionType,
122    pub message_hash: OlmMessageHash,
123    pub inbound_group_session: Option<InboundGroupSession>,
124    pub result: DecryptionResult,
125}
126
127#[derive(Debug)]
128pub(crate) struct DecryptionResult {
129    // AnyDecryptedOlmEvent is pretty big at 512 bytes, box it to reduce stack size
130    pub event: Box<AnyDecryptedOlmEvent>,
131    pub raw_event: Raw<AnyToDeviceEvent>,
132    pub sender_key: Curve25519PublicKey,
133    pub encryption_info: EncryptionInfo,
134}
135
136/// A hash of a successfully decrypted Olm message.
137///
138/// Can be used to check if a message has been replayed to us.
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct OlmMessageHash {
141    /// The curve25519 key of the sender that sent us the Olm message.
142    pub sender_key: String,
143    /// The hash of the message.
144    pub hash: String,
145}
146
147impl OlmMessageHash {
148    fn new(sender_key: Curve25519PublicKey, ciphertext: &OlmMessage) -> Self {
149        let (message_type, ciphertext) = ciphertext.clone().to_parts();
150        let sender_key = sender_key.to_base64();
151
152        let sha = Sha256::new()
153            .chain_update(sender_key.as_bytes())
154            .chain_update([message_type as u8])
155            .chain_update(ciphertext)
156            .finalize();
157
158        Self { sender_key, hash: base64_encode(sha.as_slice()) }
159    }
160}
161
162/// Account data that's static for the lifetime of a Client.
163///
164/// This data never changes once it's set, so it can be freely passed and cloned
165/// everywhere.
166#[derive(Clone)]
167#[cfg_attr(not(tarpaulin_include), derive(Debug))]
168pub struct StaticAccountData {
169    /// The user_id this account belongs to.
170    pub user_id: OwnedUserId,
171    /// The device_id of this entry.
172    pub device_id: OwnedDeviceId,
173    /// The associated identity keys.
174    pub identity_keys: Arc<IdentityKeys>,
175    /// Whether the account is for a dehydrated device.
176    pub dehydrated: bool,
177    // The creation time of the account in milliseconds since epoch.
178    creation_local_time: MilliSecondsSinceUnixEpoch,
179}
180
181impl StaticAccountData {
182    const ALGORITHMS: &'static [&'static EventEncryptionAlgorithm] = &[
183        &EventEncryptionAlgorithm::OlmV1Curve25519AesSha2,
184        #[cfg(feature = "experimental-algorithms")]
185        &EventEncryptionAlgorithm::OlmV2Curve25519AesSha2,
186        &EventEncryptionAlgorithm::MegolmV1AesSha2,
187        #[cfg(feature = "experimental-algorithms")]
188        &EventEncryptionAlgorithm::MegolmV2AesSha2,
189    ];
190
191    /// Create a group session pair.
192    ///
193    /// This session pair can be used to encrypt and decrypt messages meant for
194    /// a large group of participants.
195    ///
196    /// The outbound session is used to encrypt messages while the inbound one
197    /// is used to decrypt messages encrypted by the outbound one.
198    ///
199    /// # Arguments
200    ///
201    /// * `room_id` - The ID of the room where the group session will be used.
202    ///
203    /// * `settings` - Settings determining the algorithm and rotation period of
204    ///   the outbound group session.
205    pub async fn create_group_session_pair(
206        &self,
207        room_id: &RoomId,
208        settings: EncryptionSettings,
209        own_sender_data: SenderData,
210    ) -> Result<(OutboundGroupSession, InboundGroupSession), MegolmSessionCreationError> {
211        trace!(?room_id, algorithm = settings.algorithm.as_str(), "Creating a new room key");
212
213        let visibility = settings.history_visibility.clone();
214        let algorithm = settings.algorithm.to_owned();
215
216        let outbound = OutboundGroupSession::new(
217            self.device_id.clone(),
218            self.identity_keys.clone(),
219            room_id,
220            settings,
221        )?;
222
223        let identity_keys = &self.identity_keys;
224
225        let sender_key = identity_keys.curve25519;
226        let signing_key = identity_keys.ed25519;
227        let shared_history = shared_history_from_history_visibility(&visibility);
228
229        let inbound = InboundGroupSession::new(
230            sender_key,
231            signing_key,
232            room_id,
233            &outbound.session_key().await,
234            own_sender_data,
235            algorithm,
236            Some(visibility),
237            shared_history,
238        )?;
239
240        Ok((outbound, inbound))
241    }
242
243    #[cfg(any(test, feature = "testing"))]
244    #[allow(dead_code)]
245    /// Testing only facility to create a group session pair with default
246    /// settings.
247    pub async fn create_group_session_pair_with_defaults(
248        &self,
249        room_id: &RoomId,
250    ) -> (OutboundGroupSession, InboundGroupSession) {
251        self.create_group_session_pair(
252            room_id,
253            EncryptionSettings::default(),
254            SenderData::unknown(),
255        )
256        .await
257        .expect("Can't create default group session pair")
258    }
259
260    /// Get the key ID of our Ed25519 signing key.
261    pub fn signing_key_id(&self) -> OwnedDeviceKeyId {
262        DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id())
263    }
264
265    /// Check if the given JSON is signed by this Account key.
266    ///
267    /// This method should only be used if an object's signature needs to be
268    /// checked multiple times, and you'd like to avoid performing the
269    /// canonicalization step each time.
270    ///
271    /// **Note**: Use this method with caution, the `canonical_json` needs to be
272    /// correctly canonicalized and make sure that the object you are checking
273    /// the signature for is allowed to be signed by our own device.
274    pub fn has_signed_raw(
275        &self,
276        signatures: &crate::types::Signatures,
277        canonical_json: &str,
278    ) -> Result<(), SignatureError> {
279        use crate::olm::utility::VerifyJson;
280
281        let signing_key = self.identity_keys.ed25519;
282
283        signing_key.verify_canonicalized_json(
284            &self.user_id,
285            &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id()),
286            signatures,
287            canonical_json,
288        )
289    }
290
291    /// Generate the unsigned `DeviceKeys` from this `StaticAccountData`.
292    pub fn unsigned_device_keys(&self) -> DeviceKeys {
293        let identity_keys = self.identity_keys();
294        let keys = BTreeMap::from([
295            (
296                DeviceKeyId::from_parts(DeviceKeyAlgorithm::Curve25519, &self.device_id),
297                identity_keys.curve25519.into(),
298            ),
299            (
300                DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
301                identity_keys.ed25519.into(),
302            ),
303        ]);
304
305        let mut ret = DeviceKeys::new(
306            (*self.user_id).to_owned(),
307            (*self.device_id).to_owned(),
308            Self::ALGORITHMS.iter().map(|a| (**a).clone()).collect(),
309            keys,
310            Default::default(),
311        );
312        if self.dehydrated {
313            ret.dehydrated = JsOption::Some(true);
314        }
315        ret
316    }
317
318    /// Get the user id of the owner of the account.
319    pub fn user_id(&self) -> &UserId {
320        &self.user_id
321    }
322
323    /// Get the device ID that owns this account.
324    pub fn device_id(&self) -> &DeviceId {
325        &self.device_id
326    }
327
328    /// Get the public parts of the identity keys for the account.
329    pub fn identity_keys(&self) -> IdentityKeys {
330        *self.identity_keys
331    }
332
333    /// Get the local timestamp creation of the account in secs since epoch.
334    pub fn creation_local_time(&self) -> MilliSecondsSinceUnixEpoch {
335        self.creation_local_time
336    }
337}
338
339/// Account holding identity keys for which sessions can be created.
340///
341/// An account is the central identity for encrypted communication between two
342/// devices.
343pub struct Account {
344    pub(crate) static_data: StaticAccountData,
345    /// `vodozemac` account.
346    inner: Box<InnerAccount>,
347    /// Is this account ready to encrypt messages? (i.e. has it shared keys with
348    /// a homeserver)
349    shared: bool,
350    /// The number of signed one-time keys we have uploaded to the server. If
351    /// this is None, no action will be taken. After a sync request the client
352    /// needs to set this for us, depending on the count we will suggest the
353    /// client to upload new keys.
354    uploaded_signed_key_count: u64,
355    /// The timestamp of the last time we generated a fallback key. Fallback
356    /// keys are rotated in a time-based manner. This field records when we
357    /// either generated our first fallback key or rotated one.
358    ///
359    /// Will be `None` if we never created a fallback key, or if we're migrating
360    /// from a `AccountPickle` that didn't use time-based fallback key
361    /// rotation.
362    fallback_creation_timestamp: Option<MilliSecondsSinceUnixEpoch>,
363}
364
365impl Deref for Account {
366    type Target = StaticAccountData;
367
368    fn deref(&self) -> &Self::Target {
369        &self.static_data
370    }
371}
372
373/// A pickled version of an `Account`.
374///
375/// Holds all the information that needs to be stored in a database to restore
376/// an account.
377#[derive(Serialize, Deserialize)]
378#[allow(missing_debug_implementations)]
379pub struct PickledAccount {
380    /// The user id of the account owner.
381    pub user_id: OwnedUserId,
382    /// The device ID of the account owner.
383    pub device_id: OwnedDeviceId,
384    /// The pickled version of the Olm account.
385    pub pickle: AccountPickle,
386    /// Was the account shared.
387    pub shared: bool,
388    /// Whether this is for a dehydrated device
389    #[serde(default)]
390    pub dehydrated: bool,
391    /// The number of uploaded one-time keys we have on the server.
392    pub uploaded_signed_key_count: u64,
393    /// The local time creation of this account (milliseconds since epoch), used
394    /// as creation time of own device
395    #[serde(default = "default_account_creation_time")]
396    pub creation_local_time: MilliSecondsSinceUnixEpoch,
397    /// The timestamp of the last time we generated a fallback key.
398    #[serde(default)]
399    pub fallback_key_creation_timestamp: Option<MilliSecondsSinceUnixEpoch>,
400}
401
402fn default_account_creation_time() -> MilliSecondsSinceUnixEpoch {
403    MilliSecondsSinceUnixEpoch(UInt::default())
404}
405
406#[cfg(not(tarpaulin_include))]
407impl fmt::Debug for Account {
408    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
409        f.debug_struct("Account")
410            .field("identity_keys", &self.identity_keys())
411            .field("shared", &self.shared())
412            .finish()
413    }
414}
415
416pub type OneTimeKeys = BTreeMap<OwnedOneTimeKeyId, Raw<ruma::encryption::OneTimeKey>>;
417pub type FallbackKeys = OneTimeKeys;
418
419impl Account {
420    pub(crate) fn new_helper(
421        mut account: InnerAccount,
422        user_id: &UserId,
423        device_id: &DeviceId,
424    ) -> Self {
425        let identity_keys = account.identity_keys();
426
427        // Let's generate some initial one-time keys while we're here. Since we know
428        // that this is a completely new [`Account`] we're certain that the
429        // server does not yet have any one-time keys of ours.
430        //
431        // This ensures we upload one-time keys along with our device keys right
432        // away, rather than waiting for the key counts to be echoed back to us
433        // from the server.
434        //
435        // It would be nice to do this for the fallback key as well but we can't assume
436        // that the server supports fallback keys. Maybe one of these days we
437        // will be able to do so.
438        account.generate_one_time_keys(account.max_number_of_one_time_keys());
439
440        Self {
441            static_data: StaticAccountData {
442                user_id: user_id.into(),
443                device_id: device_id.into(),
444                identity_keys: Arc::new(identity_keys),
445                dehydrated: false,
446                creation_local_time: MilliSecondsSinceUnixEpoch::now(),
447            },
448            inner: Box::new(account),
449            shared: false,
450            uploaded_signed_key_count: 0,
451            fallback_creation_timestamp: None,
452        }
453    }
454
455    /// Create a fresh new account, this will generate the identity key-pair.
456    pub fn with_device_id(user_id: &UserId, device_id: &DeviceId) -> Self {
457        let account = InnerAccount::new();
458
459        Self::new_helper(account, user_id, device_id)
460    }
461
462    /// Create a new random Olm Account, the long-term Curve25519 identity key
463    /// encoded as base64 will be used for the device ID.
464    pub fn new(user_id: &UserId) -> Self {
465        let account = InnerAccount::new();
466        let device_id: OwnedDeviceId =
467            base64_encode(account.identity_keys().curve25519.as_bytes()).into();
468
469        Self::new_helper(account, user_id, &device_id)
470    }
471
472    /// Create a new random Olm Account for a dehydrated device
473    pub fn new_dehydrated(user_id: &UserId) -> Self {
474        let account = InnerAccount::new();
475        let device_id: OwnedDeviceId =
476            base64_encode(account.identity_keys().curve25519.as_bytes()).into();
477
478        let mut ret = Self::new_helper(account, user_id, &device_id);
479        ret.static_data.dehydrated = true;
480        ret
481    }
482
483    /// Get the immutable data for this account.
484    pub fn static_data(&self) -> &StaticAccountData {
485        &self.static_data
486    }
487
488    /// Update the uploaded key count.
489    ///
490    /// # Arguments
491    ///
492    /// * `new_count` - The new count that was reported by the server.
493    pub fn update_uploaded_key_count(&mut self, new_count: u64) {
494        self.uploaded_signed_key_count = new_count;
495    }
496
497    /// Get the currently known uploaded key count.
498    pub fn uploaded_key_count(&self) -> u64 {
499        self.uploaded_signed_key_count
500    }
501
502    /// Has the account been shared with the server.
503    pub fn shared(&self) -> bool {
504        self.shared
505    }
506
507    /// Mark the account as shared.
508    ///
509    /// Messages shouldn't be encrypted with the session before it has been
510    /// shared.
511    pub fn mark_as_shared(&mut self) {
512        self.shared = true;
513    }
514
515    /// Get the one-time keys of the account.
516    ///
517    /// This can be empty, keys need to be generated first.
518    pub fn one_time_keys(&self) -> HashMap<KeyId, Curve25519PublicKey> {
519        self.inner.one_time_keys()
520    }
521
522    /// Generate count number of one-time keys.
523    pub fn generate_one_time_keys(&mut self, count: usize) -> OneTimeKeyGenerationResult {
524        self.inner.generate_one_time_keys(count)
525    }
526
527    /// Get the maximum number of one-time keys the account can hold.
528    pub fn max_one_time_keys(&self) -> usize {
529        self.inner.max_number_of_one_time_keys()
530    }
531
532    pub(crate) fn update_key_counts(
533        &mut self,
534        one_time_key_counts: &BTreeMap<OneTimeKeyAlgorithm, UInt>,
535        unused_fallback_keys: Option<&[OneTimeKeyAlgorithm]>,
536    ) {
537        if let Some(count) = one_time_key_counts.get(&OneTimeKeyAlgorithm::SignedCurve25519) {
538            let count: u64 = (*count).into();
539            let old_count = self.uploaded_key_count();
540
541            // Some servers might always return the key counts in the sync
542            // response, we don't want to the logs with noop changes if they do
543            // so.
544            if count != old_count {
545                debug!(
546                    "Updated uploaded one-time key count {} -> {count}.",
547                    self.uploaded_key_count(),
548                );
549            }
550
551            self.update_uploaded_key_count(count);
552            self.generate_one_time_keys_if_needed();
553        }
554
555        // If the server supports fallback keys or if it did so in the past, shown by
556        // the existence of a fallback creation timestamp, generate a new one if
557        // we don't have one, or if the current fallback key expired.
558        if unused_fallback_keys.is_some() || self.fallback_creation_timestamp.is_some() {
559            self.generate_fallback_key_if_needed();
560        }
561    }
562
563    /// Generate new one-time keys that need to be uploaded to the server.
564    ///
565    /// Returns None if no keys need to be uploaded, otherwise the number of
566    /// newly generated one-time keys. May return 0 if some one-time keys are
567    /// already generated but weren't uploaded.
568    ///
569    /// Generally `Some` means that keys should be uploaded, while `None` means
570    /// that keys should not be uploaded.
571    #[instrument(skip_all)]
572    pub fn generate_one_time_keys_if_needed(&mut self) -> Option<u64> {
573        // Only generate one-time keys if there aren't any, otherwise the caller
574        // might have failed to upload them the last time this method was
575        // called.
576        if !self.one_time_keys().is_empty() {
577            return Some(0);
578        }
579
580        let count = self.uploaded_key_count();
581        let max_keys = self.max_one_time_keys();
582
583        if count >= max_keys as u64 {
584            return None;
585        }
586
587        let key_count = (max_keys as u64) - count;
588        let key_count: usize = key_count.try_into().unwrap_or(max_keys);
589
590        let result = self.generate_one_time_keys(key_count);
591
592        debug!(
593            count = key_count,
594            discarded_keys = ?result.removed,
595            created_keys = ?result.created,
596            "Generated new one-time keys"
597        );
598
599        Some(key_count as u64)
600    }
601
602    /// Generate a new fallback key iff a unpublished one isn't already inside
603    /// of vodozemac and if the currently active one expired.
604    ///
605    /// The former is checked using [`Account::fallback_key().is_empty()`],
606    /// which is a hashmap that gets cleared by the
607    /// [`Account::mark_keys_as_published()`] call.
608    pub(crate) fn generate_fallback_key_if_needed(&mut self) {
609        if self.inner.fallback_key().is_empty() && self.fallback_key_expired() {
610            let removed_fallback_key = self.inner.generate_fallback_key();
611            self.fallback_creation_timestamp = Some(MilliSecondsSinceUnixEpoch::now());
612
613            debug!(
614                ?removed_fallback_key,
615                "The fallback key either expired or we didn't have one: generated a new fallback key.",
616            );
617        }
618    }
619
620    /// Check if our most recent fallback key has expired.
621    ///
622    /// We consider the fallback key to be expired if it's older than a week.
623    /// This is the lower bound for the recommended signed pre-key bundle
624    /// rotation interval in the X3DH spec[1].
625    ///
626    /// [1]: https://signal.org/docs/specifications/x3dh/#publishing-keys
627    fn fallback_key_expired(&self) -> bool {
628        const FALLBACK_KEY_MAX_AGE: Duration = Duration::from_secs(3600 * 24 * 7);
629
630        if let Some(time) = self.fallback_creation_timestamp {
631            // `to_system_time()` returns `None` if the the UNIX_EPOCH + `time` doesn't fit
632            // into a i64. This will likely never happen, but let's rotate the
633            // key in case the values are messed up for some other reason.
634            let Some(system_time) = time.to_system_time() else {
635                return true;
636            };
637
638            // `elapsed()` errors if the `system_time` is in the future, this should mean
639            // that our clock has changed to the past, let's rotate just in case
640            // and then we'll get to a normal time.
641            let Ok(elapsed) = system_time.elapsed() else {
642                return true;
643            };
644
645            // Alright, our times are normal and we know how much time elapsed since the
646            // last time we created/rotated a fallback key.
647            //
648            // If the key is older than a week, then we rotate it.
649            elapsed > FALLBACK_KEY_MAX_AGE
650        } else {
651            // We never created a fallback key, or we're migrating to the time-based
652            // fallback key rotation, so let's generate a new fallback key.
653            true
654        }
655    }
656
657    fn fallback_key(&self) -> HashMap<KeyId, Curve25519PublicKey> {
658        self.inner.fallback_key()
659    }
660
661    /// Get a tuple of device, one-time, and fallback keys that need to be
662    /// uploaded.
663    ///
664    /// If no keys need to be uploaded the `DeviceKeys` will be `None` and the
665    /// one-time and fallback keys maps will be empty.
666    pub fn keys_for_upload(&self) -> (Option<DeviceKeys>, OneTimeKeys, FallbackKeys) {
667        let device_keys = self.shared().not().then(|| self.device_keys());
668
669        let one_time_keys = self.signed_one_time_keys();
670        let fallback_keys = self.signed_fallback_keys();
671
672        (device_keys, one_time_keys, fallback_keys)
673    }
674
675    /// Mark the current set of one-time keys as being published.
676    pub fn mark_keys_as_published(&mut self) {
677        self.inner.mark_keys_as_published();
678    }
679
680    /// Sign the given string using the accounts signing key.
681    ///
682    /// Returns the signature as a base64 encoded string.
683    pub fn sign(&self, string: &str) -> Ed25519Signature {
684        self.inner.sign(string)
685    }
686
687    /// Get a serializable version of the `Account` so it can be persisted.
688    pub fn pickle(&self) -> PickledAccount {
689        let pickle = self.inner.pickle();
690
691        PickledAccount {
692            user_id: self.user_id().to_owned(),
693            device_id: self.device_id().to_owned(),
694            pickle,
695            shared: self.shared(),
696            dehydrated: self.static_data.dehydrated,
697            uploaded_signed_key_count: self.uploaded_key_count(),
698            creation_local_time: self.static_data.creation_local_time,
699            fallback_key_creation_timestamp: self.fallback_creation_timestamp,
700        }
701    }
702
703    pub(crate) fn dehydrate(&self, pickle_key: &[u8; 32]) -> Raw<DehydratedDeviceData> {
704        let dehydration_result = self
705            .inner
706            .to_dehydrated_device(pickle_key)
707            .expect("We should be able to convert a freshly created Account into a libolm pickle");
708
709        let data = DehydratedDeviceData::V2(DehydratedDeviceV2::new(
710            dehydration_result.ciphertext,
711            dehydration_result.nonce,
712        ));
713        Raw::from_json(to_raw_value(&data).expect("Couldn't serialize our dehydrated device data"))
714    }
715
716    pub(crate) fn rehydrate(
717        pickle_key: &[u8; 32],
718        user_id: &UserId,
719        device_id: &DeviceId,
720        device_data: Raw<DehydratedDeviceData>,
721    ) -> Result<Self, DehydrationError> {
722        let data = device_data.deserialize()?;
723
724        match data {
725            DehydratedDeviceData::V1(d) => {
726                let pickle_key = expand_legacy_pickle_key(pickle_key, device_id);
727                let account =
728                    InnerAccount::from_libolm_pickle(&d.device_pickle, pickle_key.as_ref())?;
729                Ok(Self::new_helper(account, user_id, device_id))
730            }
731            DehydratedDeviceData::V2(d) => {
732                let account =
733                    InnerAccount::from_dehydrated_device(&d.device_pickle, &d.nonce, pickle_key)?;
734                Ok(Self::new_helper(account, user_id, device_id))
735            }
736            _ => Err(DehydrationError::Json(serde_json::Error::custom(format!(
737                "Unsupported dehydrated device algorithm {:?}",
738                data.algorithm()
739            )))),
740        }
741    }
742
743    /// Produce a dehydrated device using a format described in an older version
744    /// of MSC3814.
745    #[cfg(test)]
746    pub(crate) fn legacy_dehydrate(&self, pickle_key: &[u8; 32]) -> Raw<DehydratedDeviceData> {
747        let pickle_key = expand_legacy_pickle_key(pickle_key, &self.device_id);
748        let device_pickle = self
749            .inner
750            .to_libolm_pickle(pickle_key.as_ref())
751            .expect("We should be able to convert a freshly created Account into a libolm pickle");
752
753        let data = DehydratedDeviceData::V1(DehydratedDeviceV1::new(device_pickle));
754        Raw::from_json(to_raw_value(&data).expect("Couldn't serialize our dehydrated device data"))
755    }
756
757    /// Restore an account from a previously pickled one.
758    ///
759    /// # Arguments
760    ///
761    /// * `pickle` - The pickled version of the Account.
762    ///
763    /// * `pickle_mode` - The mode that was used to pickle the account, either
764    ///   an unencrypted mode or an encrypted using passphrase.
765    pub fn from_pickle(pickle: PickledAccount) -> Result<Self, PickleError> {
766        let account: vodozemac::olm::Account = pickle.pickle.into();
767        let identity_keys = account.identity_keys();
768
769        Ok(Self {
770            static_data: StaticAccountData {
771                user_id: (*pickle.user_id).into(),
772                device_id: (*pickle.device_id).into(),
773                identity_keys: Arc::new(identity_keys),
774                dehydrated: pickle.dehydrated,
775                creation_local_time: pickle.creation_local_time,
776            },
777            inner: Box::new(account),
778            shared: pickle.shared,
779            uploaded_signed_key_count: pickle.uploaded_signed_key_count,
780            fallback_creation_timestamp: pickle.fallback_key_creation_timestamp,
781        })
782    }
783
784    /// Sign the device keys of the account and return them so they can be
785    /// uploaded.
786    pub fn device_keys(&self) -> DeviceKeys {
787        let mut device_keys = self.unsigned_device_keys();
788
789        // Create a copy of the device keys containing only fields that will
790        // get signed.
791        let json_device_keys =
792            serde_json::to_value(&device_keys).expect("device key is always safe to serialize");
793        let signature = self
794            .sign_json(json_device_keys)
795            .expect("Newly created device keys can always be signed");
796
797        device_keys.signatures.add_signature(
798            self.user_id().to_owned(),
799            DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.static_data.device_id),
800            signature,
801        );
802
803        device_keys
804    }
805
806    /// Bootstrap Cross-Signing
807    pub async fn bootstrap_cross_signing(
808        &self,
809    ) -> (PrivateCrossSigningIdentity, UploadSigningKeysRequest, SignatureUploadRequest) {
810        PrivateCrossSigningIdentity::with_account(self).await
811    }
812
813    /// Sign the given CrossSigning Key in place
814    pub fn sign_cross_signing_key(
815        &self,
816        cross_signing_key: &mut CrossSigningKey,
817    ) -> Result<(), SignatureError> {
818        #[allow(clippy::needless_borrows_for_generic_args)]
819        // XXX: false positive, see https://github.com/rust-lang/rust-clippy/issues/12856
820        let signature = self.sign_json(serde_json::to_value(&cross_signing_key)?)?;
821
822        cross_signing_key.signatures.add_signature(
823            self.user_id().to_owned(),
824            DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id()),
825            signature,
826        );
827
828        Ok(())
829    }
830
831    /// Sign the given Master Key
832    pub fn sign_master_key(
833        &self,
834        master_key: &MasterPubkey,
835    ) -> Result<SignatureUploadRequest, SignatureError> {
836        let public_key =
837            master_key.get_first_key().ok_or(SignatureError::MissingSigningKey)?.to_base64().into();
838
839        let mut cross_signing_key: CrossSigningKey = master_key.as_ref().clone();
840        cross_signing_key.signatures.clear();
841        self.sign_cross_signing_key(&mut cross_signing_key)?;
842
843        let mut user_signed_keys = SignedKeys::new();
844        user_signed_keys.add_cross_signing_keys(public_key, cross_signing_key.to_raw());
845
846        let signed_keys = [(self.user_id().to_owned(), user_signed_keys)].into();
847        Ok(SignatureUploadRequest::new(signed_keys))
848    }
849
    /// Convert a JSON value to the canonical representation and sign the JSON
    /// string, producing an Ed25519 signature.
    ///
    /// This delegates to the `sign_json` implementation available on the inner
    /// Olm account.
    ///
    /// # Arguments
    ///
    /// * `json` - The value that should be converted into a canonical JSON
    ///   string.
    ///
    /// # Errors
    ///
    /// Returns a [`SignatureError`] if the inner signing step fails
    /// (NOTE(review): presumably when `json` cannot be canonicalized — confirm).
    pub fn sign_json(&self, json: Value) -> Result<Ed25519Signature, SignatureError> {
        self.inner.sign_json(json)
    }
860
861    /// Sign and prepare one-time keys to be uploaded.
862    ///
863    /// If no one-time keys need to be uploaded, returns an empty `BTreeMap`.
864    pub fn signed_one_time_keys(&self) -> OneTimeKeys {
865        let one_time_keys = self.one_time_keys();
866
867        if one_time_keys.is_empty() {
868            BTreeMap::new()
869        } else {
870            self.signed_keys(one_time_keys, false)
871        }
872    }
873
874    /// Sign and prepare fallback keys to be uploaded.
875    ///
876    /// If no fallback keys need to be uploaded returns an empty BTreeMap.
877    pub fn signed_fallback_keys(&self) -> FallbackKeys {
878        let fallback_key = self.fallback_key();
879
880        if fallback_key.is_empty() {
881            BTreeMap::new()
882        } else {
883            self.signed_keys(fallback_key, true)
884        }
885    }
886
887    fn signed_keys(
888        &self,
889        keys: HashMap<KeyId, Curve25519PublicKey>,
890        fallback: bool,
891    ) -> OneTimeKeys {
892        let mut keys_map = BTreeMap::new();
893
894        for (key_id, key) in keys {
895            let signed_key = self.sign_key(key, fallback);
896
897            keys_map.insert(
898                OneTimeKeyId::from_parts(
899                    OneTimeKeyAlgorithm::SignedCurve25519,
900                    key_id.to_base64().as_str().into(),
901                ),
902                signed_key.into_raw(),
903            );
904        }
905
906        keys_map
907    }
908
909    fn sign_key(&self, key: Curve25519PublicKey, fallback: bool) -> SignedKey {
910        let mut key = if fallback {
911            SignedKey::new_fallback(key.to_owned())
912        } else {
913            SignedKey::new(key.to_owned())
914        };
915
916        let signature = self
917            .sign_json(serde_json::to_value(&key).expect("Can't serialize a signed key"))
918            .expect("Newly created one-time keys can always be signed");
919
920        key.signatures_mut().add_signature(
921            self.user_id().to_owned(),
922            DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id()),
923            signature,
924        );
925
926        key
927    }
928
929    /// Create a new session with another account given a one-time key.
930    ///
931    /// Returns the newly created session or a `OlmSessionError` if creating a
932    /// session failed.
933    ///
934    /// # Arguments
935    ///
936    /// * `config` - The session config that should be used when creating the
937    ///   Session.
938    ///
939    /// * `identity_key` - The other account's identity/curve25519 key.
940    ///
941    /// * `one_time_key` - A signed one-time key that the other account created
942    ///   and shared with us.
943    ///
944    /// * `fallback_used` - Was the one-time key a fallback key.
945    ///
946    /// * `our_device_keys` - Our own `DeviceKeys`, including cross-signing
947    ///   signatures if applicable, for embedding in encrypted messages.
948    pub fn create_outbound_session_helper(
949        &self,
950        config: SessionConfig,
951        identity_key: Curve25519PublicKey,
952        one_time_key: Curve25519PublicKey,
953        fallback_used: bool,
954        our_device_keys: DeviceKeys,
955    ) -> Session {
956        let session = self.inner.create_outbound_session(config, identity_key, one_time_key);
957
958        let now = SecondsSinceUnixEpoch::now();
959        let session_id = session.session_id();
960
961        Session {
962            inner: Arc::new(Mutex::new(session)),
963            session_id: session_id.into(),
964            sender_key: identity_key,
965            our_device_keys,
966            created_using_fallback_key: fallback_used,
967            creation_time: now,
968            last_use_time: now,
969        }
970    }
971
972    #[instrument(
973        skip_all,
974        fields(
975            user_id = ?device.user_id(),
976            device_id = ?device.device_id(),
977            algorithms = ?device.algorithms()
978        )
979    )]
980    fn find_pre_key_bundle(
981        device: &DeviceData,
982        key_map: &OneTimeKeys,
983    ) -> Result<PrekeyBundle, SessionCreationError> {
984        let mut keys = key_map.iter();
985
986        let first_key = keys.next().ok_or_else(|| {
987            SessionCreationError::OneTimeKeyMissing(
988                device.user_id().to_owned(),
989                device.device_id().into(),
990            )
991        })?;
992
993        let first_key_id = first_key.0.to_owned();
994        let first_key = OneTimeKey::deserialize(first_key_id.algorithm(), first_key.1)?;
995
996        let result = match first_key {
997            OneTimeKey::SignedKey(key) => Ok(PrekeyBundle::Olm3DH { key }),
998        };
999
1000        trace!(?result, "Finished searching for a valid pre-key bundle");
1001
1002        result
1003    }
1004
1005    /// Create a new session with another account given a one-time key and a
1006    /// device.
1007    ///
1008    /// Returns the newly created session or a `OlmSessionError` if creating a
1009    /// session failed.
1010    ///
1011    /// # Arguments
1012    /// * `device` - The other account's device.
1013    ///
1014    /// * `key_map` - A map from the algorithm and device ID to the one-time key
1015    ///   that the other account created and shared with us.
1016    ///
1017    /// * `our_device_keys` - Our own `DeviceKeys`, including cross-signing
1018    ///   signatures if applicable, for embedding in encrypted messages.
1019    #[allow(clippy::result_large_err)]
1020    pub fn create_outbound_session(
1021        &self,
1022        device: &DeviceData,
1023        key_map: &OneTimeKeys,
1024        our_device_keys: DeviceKeys,
1025    ) -> Result<Session, SessionCreationError> {
1026        let pre_key_bundle = Self::find_pre_key_bundle(device, key_map)?;
1027
1028        match pre_key_bundle {
1029            PrekeyBundle::Olm3DH { key } => {
1030                device.verify_one_time_key(&key).map_err(|error| {
1031                    SessionCreationError::InvalidSignature {
1032                        signing_key: device.ed25519_key().map(Box::new),
1033                        one_time_key: key.clone().into(),
1034                        error: error.into(),
1035                    }
1036                })?;
1037
1038                let identity_key = device.curve25519_key().ok_or_else(|| {
1039                    SessionCreationError::DeviceMissingCurveKey(
1040                        device.user_id().to_owned(),
1041                        device.device_id().into(),
1042                    )
1043                })?;
1044
1045                let is_fallback = key.fallback();
1046                let one_time_key = key.key();
1047                let config = device.olm_session_config();
1048
1049                Ok(self.create_outbound_session_helper(
1050                    config,
1051                    identity_key,
1052                    one_time_key,
1053                    is_fallback,
1054                    our_device_keys,
1055                ))
1056            }
1057        }
1058    }
1059
1060    /// Create a new session with another account given a pre-key Olm message.
1061    ///
1062    /// Returns the newly created session or a `OlmSessionError` if creating a
1063    /// session failed.
1064    ///
1065    /// # Arguments
1066    ///
1067    /// * `their_identity_key` - The other account's identity/curve25519 key.
1068    ///
1069    /// * `our_device_keys` - Our own `DeviceKeys`, including cross-signing
1070    ///   signatures if applicable, for embedding in encrypted messages.
1071    ///
1072    /// * `message` - A pre-key Olm message that was sent to us by the other
1073    ///   account.
1074    pub fn create_inbound_session(
1075        &mut self,
1076        their_identity_key: Curve25519PublicKey,
1077        our_device_keys: DeviceKeys,
1078        message: &PreKeyMessage,
1079    ) -> Result<InboundCreationResult, SessionCreationError> {
1080        Span::current().record("session_id", debug(message.session_id()));
1081        trace!("Creating a new Olm session from a pre-key message");
1082
1083        let result = self.inner.create_inbound_session(their_identity_key, message)?;
1084        let now = SecondsSinceUnixEpoch::now();
1085        let session_id = result.session.session_id();
1086
1087        debug!(session=?result.session, "Decrypted an Olm message from a new Olm session");
1088
1089        let session = Session {
1090            inner: Arc::new(Mutex::new(result.session)),
1091            session_id: session_id.into(),
1092            sender_key: their_identity_key,
1093            our_device_keys,
1094            created_using_fallback_key: false,
1095            creation_time: now,
1096            last_use_time: now,
1097        };
1098
1099        let plaintext = String::from_utf8_lossy(&result.plaintext).to_string();
1100
1101        Ok(InboundCreationResult { session, plaintext })
1102    }
1103
1104    #[cfg(any(test, feature = "testing"))]
1105    #[allow(dead_code)]
1106    /// Testing only helper to create a session for the given Account
1107    pub async fn create_session_for_test_helper(
1108        &mut self,
1109        other: &mut Account,
1110    ) -> (Session, Session) {
1111        use ruma::events::dummy::ToDeviceDummyEventContent;
1112
1113        other.generate_one_time_keys(1);
1114        let one_time_map = other.signed_one_time_keys();
1115        let device = DeviceData::from_account(other);
1116
1117        let mut our_session =
1118            self.create_outbound_session(&device, &one_time_map, self.device_keys()).unwrap();
1119
1120        other.mark_keys_as_published();
1121
1122        let message = our_session
1123            .encrypt(&device, "m.dummy", ToDeviceDummyEventContent::new(), None)
1124            .await
1125            .unwrap()
1126            .deserialize()
1127            .unwrap();
1128
1129        #[cfg(feature = "experimental-algorithms")]
1130        let content = if let ToDeviceEncryptedEventContent::OlmV2Curve25519AesSha2(c) = message {
1131            c
1132        } else {
1133            panic!("Invalid encrypted event algorithm {}", message.algorithm());
1134        };
1135
1136        #[cfg(not(feature = "experimental-algorithms"))]
1137        let content = if let ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(c) = message {
1138            c
1139        } else {
1140            panic!("Invalid encrypted event algorithm {}", message.algorithm());
1141        };
1142
1143        let prekey = if let OlmMessage::PreKey(m) = content.ciphertext {
1144            m
1145        } else {
1146            panic!("Wrong Olm message type");
1147        };
1148
1149        let our_device = DeviceData::from_account(self);
1150        let other_session = other
1151            .create_inbound_session(
1152                our_device.curve25519_key().unwrap(),
1153                other.device_keys(),
1154                &prekey,
1155            )
1156            .unwrap();
1157
1158        (our_session, other_session.session)
1159    }
1160
1161    async fn decrypt_olm_helper(
1162        &mut self,
1163        store: &Store,
1164        sender: &UserId,
1165        sender_key: Curve25519PublicKey,
1166        ciphertext: &OlmMessage,
1167    ) -> OlmResult<OlmDecryptionInfo> {
1168        let message_hash = OlmMessageHash::new(sender_key, ciphertext);
1169
1170        match self.decrypt_and_parse_olm_message(store, sender, sender_key, ciphertext).await {
1171            Ok((session, result)) => {
1172                Ok(OlmDecryptionInfo { session, message_hash, result, inbound_group_session: None })
1173            }
1174            Err(OlmError::SessionWedged(user_id, sender_key)) => {
1175                if store.is_message_known(&message_hash).await? {
1176                    info!(?sender_key, "An Olm message got replayed, decryption failed");
1177                    Err(OlmError::ReplayedMessage(user_id, sender_key))
1178                } else {
1179                    Err(OlmError::SessionWedged(user_id, sender_key))
1180                }
1181            }
1182            Err(e) => Err(e),
1183        }
1184    }
1185
1186    #[cfg(feature = "experimental-algorithms")]
1187    async fn decrypt_olm_v2(
1188        &mut self,
1189        store: &Store,
1190        sender: &UserId,
1191        content: &OlmV2Curve25519AesSha2Content,
1192    ) -> OlmResult<OlmDecryptionInfo> {
1193        self.decrypt_olm_helper(store, sender, content.sender_key, &content.ciphertext).await
1194    }
1195
1196    #[instrument(skip_all, fields(sender, sender_key = ?content.sender_key))]
1197    async fn decrypt_olm_v1(
1198        &mut self,
1199        store: &Store,
1200        sender: &UserId,
1201        content: &OlmV1Curve25519AesSha2Content,
1202    ) -> OlmResult<OlmDecryptionInfo> {
1203        if content.recipient_key != self.static_data.identity_keys.curve25519 {
1204            warn!("Olm event doesn't contain a ciphertext for our key");
1205
1206            Err(EventError::MissingCiphertext.into())
1207        } else {
1208            Box::pin(self.decrypt_olm_helper(
1209                store,
1210                sender,
1211                content.sender_key,
1212                &content.ciphertext,
1213            ))
1214            .await
1215        }
1216    }
1217
1218    #[instrument(skip_all, fields(algorithm = ?event.content.algorithm()))]
1219    pub(crate) async fn decrypt_to_device_event(
1220        &mut self,
1221        store: &Store,
1222        event: &EncryptedToDeviceEvent,
1223    ) -> OlmResult<OlmDecryptionInfo> {
1224        trace!("Decrypting a to-device event");
1225
1226        match &event.content {
1227            ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(c) => {
1228                self.decrypt_olm_v1(store, &event.sender, c).await
1229            }
1230            #[cfg(feature = "experimental-algorithms")]
1231            ToDeviceEncryptedEventContent::OlmV2Curve25519AesSha2(c) => {
1232                self.decrypt_olm_v2(store, &event.sender, c).await
1233            }
1234            ToDeviceEncryptedEventContent::Unknown(_) => {
1235                warn!(
1236                    "Error decrypting an to-device event, unsupported \
1237                    encryption algorithm"
1238                );
1239
1240                Err(EventError::UnsupportedAlgorithm.into())
1241            }
1242        }
1243    }
1244
1245    /// Handles a response to a /keys/upload request.
1246    pub fn receive_keys_upload_response(
1247        &mut self,
1248        response: &upload_keys::v3::Response,
1249    ) -> OlmResult<()> {
1250        if !self.shared() {
1251            debug!("Marking account as shared");
1252        }
1253        self.mark_as_shared();
1254
1255        debug!("Marking one-time keys as published");
1256        // First mark the current keys as published, as updating the key counts might
1257        // generate some new keys if we're still below the limit.
1258        self.mark_keys_as_published();
1259        self.update_key_counts(&response.one_time_key_counts, None);
1260
1261        Ok(())
1262    }
1263
1264    /// Try to decrypt an olm message, creating a new session if necessary.
1265    async fn decrypt_olm_message(
1266        &mut self,
1267        store: &Store,
1268        sender: &UserId,
1269        sender_key: Curve25519PublicKey,
1270        message: &OlmMessage,
1271    ) -> Result<(SessionType, String), OlmError> {
1272        let existing_sessions = store.get_sessions(&sender_key.to_base64()).await?;
1273
1274        match message {
1275            OlmMessage::Normal(_) => {
1276                let mut errors_by_olm_session = Vec::new();
1277
1278                if let Some(sessions) = existing_sessions {
1279                    // Try to decrypt the message using each Session we share with the
1280                    // given curve25519 sender key.
1281                    for session in sessions.lock().await.iter_mut() {
1282                        match session.decrypt(message).await {
1283                            Ok(p) => {
1284                                // success!
1285                                return Ok((SessionType::Existing(session.clone()), p));
1286                            }
1287
1288                            Err(e) => {
1289                                // An error here is completely normal, after all we don't know
1290                                // which session was used to encrypt a message.
1291                                // We keep hold of the error, so that if *all* sessions fail to
1292                                // decrypt, we can log something useful.
1293                                errors_by_olm_session.push((session.session_id().to_owned(), e));
1294                            }
1295                        }
1296                    }
1297                }
1298
1299                warn!(
1300                    ?errors_by_olm_session,
1301                    "Failed to decrypt a non-pre-key message with all available sessions"
1302                );
1303                Err(OlmError::SessionWedged(sender.to_owned(), sender_key))
1304            }
1305
1306            OlmMessage::PreKey(prekey_message) => {
1307                // First try to decrypt using an existing session.
1308                if let Some(sessions) = existing_sessions {
1309                    for session in sessions.lock().await.iter_mut() {
1310                        if prekey_message.session_id() != session.session_id() {
1311                            // wrong session
1312                            continue;
1313                        }
1314
1315                        if let Ok(p) = session.decrypt(message).await {
1316                            // success!
1317                            return Ok((SessionType::Existing(session.clone()), p));
1318                        }
1319
1320                        // The message was intended for this session, but we weren't able to
1321                        // decrypt it.
1322                        //
1323                        // There's no point trying any other sessions, nor should we try to
1324                        // create a new one since we have already previously created a `Session`
1325                        // with the same keys.
1326                        //
1327                        // (Attempts to create a new session would likely fail anyway since the
1328                        // corresponding one-time key would've been already used up in the
1329                        // previous session creation operation. The one exception where this
1330                        // would not be so is if the fallback key was used for creating the
1331                        // session in lieu of an OTK.)
1332
1333                        warn!(
1334                            session_id = session.session_id(),
1335                            "Failed to decrypt a pre-key message with the corresponding session"
1336                        );
1337
1338                        return Err(OlmError::SessionWedged(
1339                            session.our_device_keys.user_id.to_owned(),
1340                            session.sender_key(),
1341                        ));
1342                    }
1343                }
1344
1345                let device_keys = store.get_own_device().await?.as_device_keys().clone();
1346                let result =
1347                    match self.create_inbound_session(sender_key, device_keys, prekey_message) {
1348                        Ok(r) => r,
1349                        Err(e) => {
1350                            warn!(
1351                                "Failed to create a new Olm session from a pre-key message: {e:?}"
1352                            );
1353                            return Err(OlmError::SessionWedged(sender.to_owned(), sender_key));
1354                        }
1355                    };
1356
1357                // We need to add the new session to the session cache, otherwise
1358                // we might try to create the same session again.
1359                // TODO: separate the session cache from the storage so we only add
1360                // it to the cache but don't store it.
1361                let mut changes =
1362                    Changes { sessions: vec![result.session.clone()], ..Default::default() };
1363
1364                // Any new Olm session will bump the Olm wedging index for the
1365                // sender's device, if we have their device, which will cause us
1366                // to re-send existing Megolm sessions to them the next time we
1367                // use the session.  If we don't have their device, this means
1368                // that we haven't tried to send them any Megolm sessions yet,
1369                // so we don't need to worry about it.
1370                if let Some(device) = store.get_device_from_curve_key(sender, sender_key).await? {
1371                    let mut device_data = device.inner;
1372                    device_data.olm_wedging_index.increment();
1373
1374                    changes.devices =
1375                        DeviceChanges { changed: vec![device_data], ..Default::default() };
1376                }
1377
1378                store.save_changes(changes).await?;
1379
1380                Ok((SessionType::New(result.session), result.plaintext))
1381            }
1382        }
1383    }
1384
1385    /// Decrypt an Olm message, creating a new Olm session if necessary, and
1386    /// parse the result.
1387    #[instrument(skip(self, store), fields(session, session_id))]
1388    async fn decrypt_and_parse_olm_message(
1389        &mut self,
1390        store: &Store,
1391        sender: &UserId,
1392        sender_key: Curve25519PublicKey,
1393        message: &OlmMessage,
1394    ) -> OlmResult<(SessionType, DecryptionResult)> {
1395        let (session, plaintext) =
1396            self.decrypt_olm_message(store, sender, sender_key, message).await?;
1397
1398        trace!("Successfully decrypted an Olm message");
1399
1400        match self.parse_decrypted_to_device_event(store, sender, sender_key, plaintext).await {
1401            Ok(result) => Ok((session, result)),
1402            Err(e) => {
1403                // We might have created a new session but decryption might still
1404                // have failed, store it for the error case here, this is fine
1405                // since we don't expect this to happen often or at all.
1406                match session {
1407                    SessionType::New(s) | SessionType::Existing(s) => {
1408                        store.save_sessions(&[s]).await?;
1409                    }
1410                }
1411
1412                warn!(
1413                    error = ?e,
1414                    "A to-device message was successfully decrypted but \
1415                    parsing and checking the event fields failed"
1416                );
1417
1418                Err(e)
1419            }
1420        }
1421    }
1422
    /// Parse the decrypted plaintext as JSON and verify that it wasn't
    /// forwarded by a third party.
    ///
    /// These checks are mandated by the Matrix specification's description of
    /// the `m.olm.v1.curve25519-aes-sha2` algorithm:
    ///
    /// > Other properties are included in order to prevent an attacker from
    /// > publishing someone else's Curve25519 keys as their own and
    /// > subsequently claiming to have sent messages which they didn't.
    /// > sender must correspond to the user who sent the event, recipient to
    /// > the local user, and recipient_keys to the local Ed25519 key.
    ///
    /// The checks run in this order, and the first failure is returned:
    ///
    /// 1. the decrypted `recipient` is our own user ID;
    /// 2. the decrypted `sender` matches the `sender` of the outer event;
    /// 3. the decrypted `recipient_keys.ed25519` is our own Ed25519 key;
    /// 4. any embedded `sender_device_keys` are valid (`check_sender_device_keys`);
    /// 5. depending on the event type:
    ///    * `m.room_key`: the sender's Ed25519 key is not checked here (it is
    ///      deferred until room events are decrypted);
    ///    * room key bundles: `sender_device_keys` must be present (step 4
    ///      already validated them);
    ///    * anything else: the sender's device must be found in the store via
    ///      `sender_key`, and its Ed25519 key must equal the decrypted
    ///      `keys.ed25519`.
    ///
    /// # Arguments
    ///
    /// * `store` - Used to look up the sender's device by Curve25519 key.
    /// * `sender` -  The `sender` field from the top level of the received
    ///   event.
    /// * `sender_key` - The `sender_key` from the cleartext `content` of the
    ///   received event (which should also have been used to find or establish
    ///   the Olm session that was used to decrypt the event -- so it is
    ///   guaranteed to be correct).
    /// * `plaintext` - The decrypted content of the event.
    ///
    /// # Errors
    ///
    /// The following are returned through `OlmResult`: JSON parse failures,
    /// `EventError::MismatchedSender` (steps 1 and 2),
    /// `EventError::MismatchedKeys` (step 3 and the device key comparison),
    /// `EventError::MissingSigningKey` (missing device, device key or bundle
    /// sender keys), errors from `check_sender_device_keys`, and store errors.
    async fn parse_decrypted_to_device_event(
        &self,
        store: &Store,
        sender: &UserId,
        sender_key: Curve25519PublicKey,
        plaintext: String,
    ) -> OlmResult<DecryptionResult> {
        let event: Box<AnyDecryptedOlmEvent> = serde_json::from_str(&plaintext)?;
        let identity_keys = &self.static_data.identity_keys;

        // The event must be addressed to our user. Note that this mismatch is
        // reported with the `MismatchedSender` variant.
        if event.recipient() != self.static_data.user_id {
            Err(EventError::MismatchedSender(
                event.recipient().to_owned(),
                self.static_data.user_id.clone(),
            )
            .into())
        }
        // Check that the `sender` in the decrypted to-device event matches that at the
        // top level of the encrypted event.
        else if event.sender() != sender {
            Err(EventError::MismatchedSender(event.sender().to_owned(), sender.to_owned()).into())
        } else if identity_keys.ed25519 != event.recipient_keys().ed25519 {
            Err(EventError::MismatchedKeys(
                identity_keys.ed25519.into(),
                event.recipient_keys().ed25519.into(),
            )
            .into())
        } else {
            // If the event contained sender_device_keys, check them now.
            // WARN: If you move or modify this check, ensure that the code below is still
            // valid. The processing of the historic room key bundle depends on this being
            // here.
            Self::check_sender_device_keys(event.as_ref(), sender_key)?;
            // Only filled in by the generic branch below. When it stays `None`,
            // `get_olm_encryption_info` reports a `MissingDevice` link problem.
            let mut sender_device: Option<Device> = None;
            if let AnyDecryptedOlmEvent::RoomKey(_) = event.as_ref() {
                // If this event is an `m.room_key` event, defer the check for
                // the Ed25519 key of the sender until we decrypt room events.
                // This ensures that we receive the room key even if we don't
                // have access to the device.
            } else if let AnyDecryptedOlmEvent::RoomKeyBundle(_) = event.as_ref() {
                // If this is a room key bundle we're requiring the device keys to be part of
                // the `AnyDecryptedOlmEvent`. This ensures that we can skip the check for the
                // Ed25519 key below since `Self::check_sender_device_keys` already did so.
                //
                // If the event didn't contain any sender device keys we'll throw an error
                // refusing to decrypt the room key bundle.
                event.sender_device_keys().ok_or(EventError::MissingSigningKey).inspect_err(
                    |_| {
                        warn!("The room key bundle was missing the sender device keys in the event")
                    },
                )?;
            } else {
                // All other events: the sender's known device (looked up by the
                // Curve25519 key the session was established with) must carry
                // the same Ed25519 key that the decrypted event claims.
                let device = store
                    .get_device_from_curve_key(event.sender(), sender_key)
                    .await?
                    .ok_or(EventError::MissingSigningKey)?;

                let key = device.ed25519_key().ok_or(EventError::MissingSigningKey)?;

                if key != event.keys().ed25519 {
                    return Err(EventError::MismatchedKeys(
                        key.into(),
                        event.keys().ed25519.into(),
                    )
                    .into());
                }
                sender_device = Some(device);
            }

            let encryption_info = Self::get_olm_encryption_info(sender_key, sender, &sender_device);

            Ok(DecryptionResult {
                event,
                raw_event: Raw::from_json(RawJsonValue::from_string(plaintext)?),
                sender_key,
                encryption_info,
            })
        }
    }
1522
1523    /// Gets the EncryptionInfo for a successfully decrypted to-device message
1524    /// that have passed the mismatched sender_key/user_id validation.
1525    ///
1526    /// `sender_device` is optional because for some to-device messages we defer
1527    /// the check for the ed25519 key, in that case the
1528    /// `verification_state` will have a `MissingDevice` link problem.
1529    fn get_olm_encryption_info(
1530        sender_key: Curve25519PublicKey,
1531        sender_id: &UserId,
1532        sender_device: &Option<Device>,
1533    ) -> EncryptionInfo {
1534        let verification_state = sender_device
1535            .as_ref()
1536            .map(|device| {
1537                if device.is_verified() {
1538                    // The device is locally verified or signed by a verified user
1539                    VerificationState::Verified
1540                } else if device.is_cross_signed_by_owner() {
1541                    // The device is not verified, but it is signed by its owner
1542                    if device
1543                        .device_owner_identity
1544                        .as_ref()
1545                        .expect("A device cross-signed by the owner must have an owner identity")
1546                        .was_previously_verified()
1547                    {
1548                        VerificationState::Unverified(VerificationLevel::VerificationViolation)
1549                    } else {
1550                        VerificationState::Unverified(VerificationLevel::UnverifiedIdentity)
1551                    }
1552                } else {
1553                    // No identity or not signed
1554                    VerificationState::Unverified(VerificationLevel::UnsignedDevice)
1555                }
1556            })
1557            .unwrap_or(VerificationState::Unverified(VerificationLevel::None(
1558                DeviceLinkProblem::MissingDevice,
1559            )));
1560
1561        let encryption_info = EncryptionInfo {
1562            sender: sender_id.to_owned(),
1563            sender_device: sender_device.as_ref().map(|d| d.device_id().to_owned()),
1564            algorithm_info: AlgorithmInfo::OlmV1Curve25519AesSha2 {
1565                curve25519_public_key_base64: sender_key.to_base64(),
1566            },
1567            verification_state,
1568        };
1569        encryption_info
1570    }
1571
1572    /// If the plaintext of the decrypted message includes a
1573    /// `sender_device_keys` property per [MSC4147], check that it is valid.
1574    ///
1575    /// # Arguments
1576    ///
1577    /// * `event` - The decrypted and deserialized plaintext of the event.
1578    /// * `sender_key` - The curve25519 key of the sender of the event.
1579    ///
1580    /// [MSC4147]: https://github.com/matrix-org/matrix-spec-proposals/pull/4147
1581    fn check_sender_device_keys(
1582        event: &AnyDecryptedOlmEvent,
1583        sender_key: Curve25519PublicKey,
1584    ) -> OlmResult<()> {
1585        let Some(sender_device_keys) = event.sender_device_keys() else {
1586            return Ok(());
1587        };
1588
1589        // Check the signature within the device_keys structure
1590        let sender_device_data = DeviceData::try_from(sender_device_keys).map_err(|err| {
1591            warn!(
1592                "Received a to-device message with sender_device_keys with \
1593                 invalid signature: {err:?}",
1594            );
1595            OlmError::EventError(EventError::InvalidSenderDeviceKeys)
1596        })?;
1597
1598        // Check that the Ed25519 key in the sender_device_keys matches the `ed25519`
1599        // key in the `keys` field in the event.
1600        if sender_device_data.ed25519_key() != Some(event.keys().ed25519) {
1601            warn!(
1602                "Received a to-device message with sender_device_keys with incorrect \
1603                 ed25519 key: expected {:?}, got {:?}",
1604                event.keys().ed25519,
1605                sender_device_data.ed25519_key(),
1606            );
1607            return Err(OlmError::EventError(EventError::InvalidSenderDeviceKeys));
1608        }
1609
1610        // Check that the Curve25519 key in the sender_device_keys matches the key that
1611        // was used for the Olm session.
1612        if sender_device_data.curve25519_key() != Some(sender_key) {
1613            warn!(
1614                "Received a to-device message with sender_device_keys with incorrect \
1615                 curve25519 key: expected {sender_key:?}, got {:?}",
1616                sender_device_data.curve25519_key(),
1617            );
1618            return Err(OlmError::EventError(EventError::InvalidSenderDeviceKeys));
1619        }
1620
1621        Ok(())
1622    }
1623
1624    /// Internal use only.
1625    ///
1626    /// Cloning should only be done for testing purposes or when we are certain
1627    /// that we don't want the inner state to be shared.
1628    #[doc(hidden)]
1629    pub fn deep_clone(&self) -> Self {
1630        // `vodozemac::Account` isn't really cloneable, but... Don't tell anyone.
1631        Self::from_pickle(self.pickle()).unwrap()
1632    }
1633}
1634
1635impl PartialEq for Account {
1636    fn eq(&self, other: &Self) -> bool {
1637        self.identity_keys() == other.identity_keys() && self.shared() == other.shared()
1638    }
1639}
1640
1641/// Calculate the shared history flag from the history visibility as defined in
1642/// [MSC3061]
1643///
1644/// The MSC defines that the shared history flag should be set to true when the
1645/// history visibility setting is set to `shared` or `world_readable`:
1646///
1647/// > A room key is flagged as having been used for shared history when it was
1648/// > used to encrypt a message while the room's history visibility setting
1649/// > was set to world_readable or shared.
1650///
1651/// In all other cases, even if we encounter a custom history visibility, we
1652/// should return false:
1653///
1654/// > If the client does not have an m.room.history_visibility state event for
1655/// > the room, or its value is not understood, the client should treat it as if
1656/// > its value is joined for the purposes of determining whether the key is
1657/// > used for shared history.
1658///
1659/// [MSC3061]: https://github.com/matrix-org/matrix-spec-proposals/pull/3061
1660pub(crate) fn shared_history_from_history_visibility(
1661    history_visibility: &HistoryVisibility,
1662) -> bool {
1663    match history_visibility {
1664        HistoryVisibility::Shared | HistoryVisibility::WorldReadable => true,
1665        HistoryVisibility::Invited | HistoryVisibility::Joined | _ => false,
1666    }
1667}
1668
1669/// Expand the pickle key for an older version of dehydrated devices
1670///
1671/// The `org.matrix.msc3814.v1.olm` variant of dehydrated devices used the
1672/// libolm Account pickle format for the dehydrated device. The libolm pickle
1673/// encryption scheme uses HKDF to deterministically expand an input key
1674/// material, usually 32 bytes, into a AES key, MAC key, and the initialization
1675/// vector (IV).
1676///
1677/// This means that the same input key material will always end up producing the
1678/// same AES key, and IV.
1679///
1680/// This encryption scheme is used in the Olm double ratchet and was designed to
1681/// minimize the size of the ciphertext. As a tradeof, it requires a unique
1682/// input key material for each plaintext that gets encrypted, otherwise IV
1683/// reuse happens.
1684///
1685/// To combat the IV reuse, we're going to create a per-dehydrated-device unique
1686/// pickle key by expanding the key itself with the device ID used as the salt.
1687fn expand_legacy_pickle_key(key: &[u8; 32], device_id: &DeviceId) -> Box<[u8; 32]> {
1688    let kdf: Hkdf<Sha256> = Hkdf::new(Some(device_id.as_bytes()), key);
1689    let mut key = Box::new([0u8; 32]);
1690
1691    kdf.expand(b"dehydrated-device-pickle-key", key.as_mut_slice())
1692        .expect("We should be able to expand the 32 byte pickle key");
1693
1694    key
1695}
1696
#[cfg(test)]
mod tests {
    // Unit tests for `Account`. They cover one-time and fallback key
    // generation, key signing and signature verification, and creation
    // timestamps. They also cover the MSC3061 shared history flag.
    use std::{
        collections::{BTreeMap, BTreeSet},
        ops::Deref,
        time::Duration,
    };

    use anyhow::Result;
    use matrix_sdk_test::async_test;
    use ruma::{
        device_id, events::room::history_visibility::HistoryVisibility, room_id, user_id, DeviceId,
        MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, UserId,
    };
    use serde_json::json;

    use super::Account;
    use crate::{
        olm::{account::shared_history_from_history_visibility, SignedJsonObject},
        types::{DeviceKeys, SignedKey},
        DeviceData, EncryptionSettings,
    };

    /// The user ID used by the test accounts.
    fn user_id() -> &'static UserId {
        user_id!("@alice:localhost")
    }

    /// The device ID used by the test accounts.
    fn device_id() -> &'static DeviceId {
        device_id!("DEVICEID")
    }

    /// One-time keys are re-offered until published. New ones are generated
    /// only once the uploaded key count reported for the account drops.
    #[test]
    fn test_one_time_key_creation() -> Result<()> {
        let mut account = Account::with_device_id(user_id(), device_id());

        // A fresh account already has one-time keys ready for upload.
        let (_, one_time_keys, _) = account.keys_for_upload();
        assert!(!one_time_keys.is_empty());

        // Until they are marked as published, the same keys are offered again.
        let (_, second_one_time_keys, _) = account.keys_for_upload();
        assert!(!second_one_time_keys.is_empty());

        let one_time_key_ids: BTreeSet<&OneTimeKeyId> =
            one_time_keys.keys().map(Deref::deref).collect();
        let second_one_time_key_ids: BTreeSet<&OneTimeKeyId> =
            second_one_time_keys.keys().map(Deref::deref).collect();

        assert_eq!(one_time_key_ids, second_one_time_key_ids);

        // After publishing, with an uploaded count of 50, no new keys are
        // generated.
        account.mark_keys_as_published();
        account.update_uploaded_key_count(50);
        account.generate_one_time_keys_if_needed();

        let (_, third_one_time_keys, _) = account.keys_for_upload();
        assert!(third_one_time_keys.is_empty());

        // Once the uploaded count drops to zero, a fresh batch is generated.
        account.update_uploaded_key_count(0);
        account.generate_one_time_keys_if_needed();

        let (_, fourth_one_time_keys, _) = account.keys_for_upload();
        assert!(!fourth_one_time_keys.is_empty());

        let fourth_one_time_key_ids: BTreeSet<&OneTimeKeyId> =
            fourth_one_time_keys.keys().map(Deref::deref).collect();

        // The new batch must differ from the originally published key IDs.
        assert_ne!(one_time_key_ids, fourth_one_time_key_ids);
        Ok(())
    }

    /// Walks the fallback key upload decision through each stage. The
    /// server's fallback support is first unknown, then unsupported, then
    /// supported. After that comes an unexpired published key, and finally an
    /// expired key.
    #[test]
    fn test_fallback_key_creation() -> Result<()> {
        let mut account = Account::with_device_id(user_id(), device_id());

        let (_, _, fallback_keys) = account.keys_for_upload();

        // We don't create fallback keys since we don't know if the server
        // supports them, we need to receive a sync response to decide if we're
        // going to create them or not.
        assert!(
            fallback_keys.is_empty(),
            "We should not upload fallback keys until we know if the server supports them."
        );

        let one_time_keys = BTreeMap::from([(OneTimeKeyAlgorithm::SignedCurve25519, 50u8.into())]);

        // A `None` here means that the server doesn't support fallback keys, no
        // fallback key gets uploaded.
        account.update_key_counts(&one_time_keys, None);
        let (_, _, fallback_keys) = account.keys_for_upload();
        assert!(
            fallback_keys.is_empty(),
            "We should not upload a fallback key if we're certain that the server doesn't support \
             them."
        );

        // The empty array means that the server supports fallback keys but
        // there isn't an unused fallback key on the server. This time we upload
        // a fallback key.
        let unused_fallback_keys = &[];
        account.update_key_counts(&one_time_keys, Some(unused_fallback_keys.as_ref()));
        let (_, _, fallback_keys) = account.keys_for_upload();
        assert!(
            !fallback_keys.is_empty(),
            "We should upload the initial fallback key if the server supports them."
        );
        account.mark_keys_as_published();

        // There's no unused fallback key on the server, but our initial fallback key
        // did not yet expire.
        let unused_fallback_keys = &[];
        account.update_key_counts(&one_time_keys, Some(unused_fallback_keys.as_ref()));
        let (_, _, fallback_keys) = account.keys_for_upload();
        assert!(
            fallback_keys.is_empty(),
            "We should not upload new fallback keys unless our current fallback key expires."
        );

        // Backdate the fallback key's creation timestamp by 30 days so that it
        // counts as expired.
        let fallback_key_timestamp =
            account.fallback_creation_timestamp.unwrap().to_system_time().unwrap()
                - Duration::from_secs(3600 * 24 * 30);

        account.fallback_creation_timestamp =
            Some(MilliSecondsSinceUnixEpoch::from_system_time(fallback_key_timestamp).unwrap());

        // Even with `None` (server claims no fallback support), an expired key
        // must be rotated.
        account.update_key_counts(&one_time_keys, None);
        let (_, _, fallback_keys) = account.keys_for_upload();
        assert!(
            !fallback_keys.is_empty(),
            "Now that our fallback key has expired, we should try to upload a new one, even if the \
             server supposedly doesn't support fallback keys anymore"
        );

        Ok(())
    }

    /// Signing a fallback key yields canonical JSON containing
    /// `"fallback":true`. The resulting signature verifies both against the
    /// account and against the `DeviceData` derived from it.
    #[test]
    fn test_fallback_key_signing() -> Result<()> {
        let key = vodozemac::Curve25519PublicKey::from_base64(
            "7PUPP6Ijt5R8qLwK2c8uK5hqCNF9tOzWYgGaAay5JBs",
        )?;
        let account = Account::with_device_id(user_id(), device_id());

        let key = account.sign_key(key, true);

        let canonical_key = key.to_canonical_json()?;

        assert_eq!(
            canonical_key,
            "{\"fallback\":true,\"key\":\"7PUPP6Ijt5R8qLwK2c8uK5hqCNF9tOzWYgGaAay5JBs\"}"
        );

        account
            .has_signed_raw(key.signatures(), &canonical_key)
            .expect("Couldn't verify signature");

        let device = DeviceData::from_account(&account);
        device.verify_one_time_key(&key).expect("The device can verify its own signature");

        Ok(())
    }

    /// The account's creation time lies between timestamps taken just before
    /// and just after construction. The `DeviceData` built from the account
    /// reports that same time as its first-seen timestamp.
    #[test]
    fn test_account_and_device_creation_timestamp() -> Result<()> {
        let now = MilliSecondsSinceUnixEpoch::now();
        let account = Account::with_device_id(user_id(), device_id());
        let then = MilliSecondsSinceUnixEpoch::now();

        assert!(account.creation_local_time() >= now);
        assert!(account.creation_local_time() <= then);

        let device = DeviceData::from_account(&account);
        assert_eq!(account.creation_local_time(), device.first_time_seen_ts());

        Ok(())
    }

    /// A fallback key signed by the fixture device `EXPDYDPWZH` verifies
    /// against that device's keys.
    ///
    /// NOTE(review): this test awaits nothing and could be a plain `#[test]`.
    #[async_test]
    async fn test_fallback_key_signature_verification() -> Result<()> {
        let fallback_key = json!({
            "fallback": true,
            "key": "XPFqtLvBepBmW6jSAbBuJbhEpprBhQOX1IjUu+cnMF4",
            "signatures": {
                "@dkasak_c:matrix.org": {
                    "ed25519:EXPDYDPWZH": "RJCBMJPL5hvjxgq8rmLmqkNOuPsaan7JeL1wsE+gW6R39G894lb2sBmzapHeKCn/KFjmkonPLkICApRDS+zyDw"
                }
            }
        });

        let device_keys = json!({
            "algorithms": [
                "m.olm.v1.curve25519-aes-sha2",
                "m.megolm.v1.aes-sha2"
            ],
            "device_id": "EXPDYDPWZH",
            "keys": {
                "curve25519:EXPDYDPWZH": "k7f3igo0Vrdm88JSSA5d3OCuUfHYELChB2b57aOROB8",
                "ed25519:EXPDYDPWZH": "GdjYI8fxs175gSpYRJkyN6FRfvcyTsNOhJ2OR/Ggp+E"
            },
            "signatures": {
                "@dkasak_c:matrix.org": {
                    "ed25519:EXPDYDPWZH": "kzrtfQMbJXWXQ1uzhybtwFnGk0JJBS4Mg8VPMusMu6U8MPJccwoHVZKo5+owuHTzIodI+GZYqLmMSzvfvsChAA"
                }
            },
            "user_id": "@dkasak_c:matrix.org",
            "unsigned": {}
        });

        let device_keys: DeviceKeys = serde_json::from_value(device_keys).unwrap();
        let device = DeviceData::try_from(&device_keys).unwrap();
        let fallback_key: SignedKey = serde_json::from_value(fallback_key).unwrap();

        device
            .verify_one_time_key(&fallback_key)
            .expect("The fallback key should pass the signature verification");

        Ok(())
    }

    /// Only `world_readable` and `shared` set the MSC3061 shared history flag.
    /// `joined`, `invited` and custom values do not.
    #[test]
    fn test_shared_history_flag_from_history_visibility() {
        assert!(
            shared_history_from_history_visibility(&HistoryVisibility::WorldReadable),
            "The world readable visibility should set the shared history flag to true"
        );

        assert!(
            shared_history_from_history_visibility(&HistoryVisibility::Shared),
            "The shared visibility should set the shared history flag to true"
        );

        assert!(
            !shared_history_from_history_visibility(&HistoryVisibility::Joined),
            "The joined visibility should set the shared history flag to false"
        );

        assert!(
            !shared_history_from_history_visibility(&HistoryVisibility::Invited),
            "The invited visibility should set the shared history flag to false"
        );

        let visibility = HistoryVisibility::from("custom_visibility");
        assert!(
            !shared_history_from_history_visibility(&visibility),
            "A custom visibility should set the shared history flag to false"
        );
    }

    /// A group session created with `HistoryVisibility::Shared` encryption
    /// settings carries the shared history flag.
    #[async_test]
    async fn test_shared_history_set_when_creating_group_sessions() {
        let account = Account::new(user_id());
        let room_id = room_id!("!room:id");
        let settings = EncryptionSettings {
            history_visibility: HistoryVisibility::Shared,
            ..Default::default()
        };

        let (_, session) = account
            .create_group_session_pair(room_id, settings, Default::default())
            .await
            .expect("We should be able to create a group session pair");

        assert!(
            session.shared_history(),
            "The shared history flag should have been set when we created the new session"
        );
    }
}