libsignal_rust/
session_builder.rs

1use crate::{
2    base_key_type::BaseKeyType,
3    chain_type::ChainType,
4    session_record::{SessionRecord, SessionEntry, CurrentRatchet, IndexInfo, PendingPreKey, ChainInfo, ChainKey},
5    crypto,
6    curve::{self, KeyPair},
7    errors::{UntrustedIdentityKeyError, PreKeyError},
8    protocol_address::ProtocolAddress,
9    queue_job::queue_job,
10};
11use std::sync::Arc;
12
13pub trait SessionStorage: Send + Sync {
14    fn is_trusted_identity(&self, address: &str, identity_key: &[u8]) -> impl std::future::Future<Output = bool> + Send;
15    fn load_session(&self, address: &str) -> impl std::future::Future<Output = Option<SessionRecord>> + Send;
16    fn store_session(&self, address: &str, record: SessionRecord) -> impl std::future::Future<Output = ()> + Send;
17    fn load_pre_key(&self, pre_key_id: u32) -> impl std::future::Future<Output = Option<KeyPair>> + Send;
18    fn load_signed_pre_key(&self, signed_pre_key_id: u32) -> impl std::future::Future<Output = Option<KeyPair>> + Send;
19    fn get_our_identity(&self) -> impl std::future::Future<Output = KeyPair> + Send;
20}
21
22pub struct Device {
23    pub registration_id: u32,
24    pub identity_key: Vec<u8>,
25    pub signed_pre_key: SignedPreKeyBundle,
26    pub pre_key: Option<PreKeyBundle>,
27}
28
/// The signed pre-key portion of a remote [`Device`] bundle.
pub struct SignedPreKeyBundle {
    /// Id recorded as `signed_key_id` in the session's pending pre-key.
    pub key_id: u32,
    /// Public half of the signed pre-key.
    pub public_key: Vec<u8>,
    /// Signature over `public_key`, verified against the device identity key.
    pub signature: Vec<u8>,
}
34
/// The optional one-time pre-key portion of a remote [`Device`] bundle.
pub struct PreKeyBundle {
    /// Id recorded as `pre_key_id` in the session's pending pre-key.
    pub key_id: u32,
    /// Public half of the one-time pre-key.
    pub public_key: Vec<u8>,
}
39
/// The handshake fields of an incoming pre-key message, consumed by
/// `SessionBuilder::init_incoming`.
pub struct PreKeyWhisperMessage {
    /// Sender's registration id, copied onto the new session.
    pub registration_id: u32,
    /// Id of our one-time pre-key the sender used, if any.
    pub pre_key_id: Option<u32>,
    /// Id of our signed pre-key the sender used.
    pub signed_pre_key_id: u32,
    /// Sender's base key; also the key the resulting session is indexed by.
    pub base_key: Vec<u8>,
    /// Sender's public identity key; checked for trust.
    pub identity_key: Vec<u8>,
    /// Message payload (not read by `SessionBuilder`).
    pub message: Vec<u8>,
}
48
49pub struct SessionBuilder<T: SessionStorage> {
50    addr: ProtocolAddress,
51    storage: Arc<T>,
52}
53
54impl<T: SessionStorage + 'static> SessionBuilder<T> {
55    pub fn new(storage: Arc<T>, protocol_address: ProtocolAddress) -> Self {
56        Self {
57            addr: protocol_address,
58            storage,
59        }
60    }
61
62    pub async fn init_outgoing(&self, device: Device) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
63        let storage = self.storage.clone();
64        let addr = self.addr.clone();
65        
66        queue_job(addr.to_string(), async move {
67            if !storage.is_trusted_identity(&addr.id, &device.identity_key).await {
68                return Err(Box::new(UntrustedIdentityKeyError::new(addr.id.clone(), device.identity_key)) as Box<dyn std::error::Error + Send + Sync>);
69            }
70
71            curve::verify_signature(&device.identity_key, &device.signed_pre_key.public_key, &device.signed_pre_key.signature)?;
72
73            let base_key = curve::generate_key_pair();
74            let device_pre_key = device.pre_key.as_ref().map(|pk| &pk.public_key);
75
76            let session = Self::static_init_session(
77                storage.clone(),
78                true,
79                Some(&base_key),
80                None,
81                &device.identity_key,
82                device_pre_key.map(|v| &**v),
83                Some(&device.signed_pre_key.public_key),
84                device.registration_id,
85            ).await?;
86
87            let mut session_mut = session.clone();
88            session_mut.pending_pre_key = Some(PendingPreKey {
89                signed_key_id: device.signed_pre_key.key_id,
90                base_key: base_key.pub_key.clone(),
91                pre_key_id: device.pre_key.map(|pk| pk.key_id),
92            });
93
94            let mut record = storage.load_session(&addr.to_string()).await.unwrap_or_else(|| SessionRecord::new());
95            
96            if let Some(open_session) = record.get_open_session() {
97                let base_key = open_session.index_info.base_key.clone();
98                record.close_session(&base_key);
99            }
100
101            record.set_session(session_mut);
102            storage.store_session(&addr.to_string(), record).await;
103            Ok(())
104        }).await
105    }
106
107    pub async fn init_incoming(&self, record: &mut SessionRecord, message: &PreKeyWhisperMessage) -> Result<Option<u32>, Box<dyn std::error::Error + Send + Sync>> {
108        let fq_addr = self.addr.to_string();
109        
110        if !self.storage.is_trusted_identity(&fq_addr, &message.identity_key).await {
111            return Err(Box::new(UntrustedIdentityKeyError::new(self.addr.id.clone(), message.identity_key.clone())));
112        }
113
114        if record.get_session(&message.base_key).is_some() {
115            return Ok(None);
116        }
117
118        let pre_key_pair = if let Some(pre_key_id) = message.pre_key_id {
119            self.storage.load_pre_key(pre_key_id).await
120        } else {
121            None
122        };
123
124        if message.pre_key_id.is_some() && pre_key_pair.is_none() {
125            return Err(Box::new(PreKeyError::new("Invalid PreKey ID")));
126        }
127
128        let signed_pre_key_pair = self.storage.load_signed_pre_key(message.signed_pre_key_id).await
129            .ok_or_else(|| PreKeyError::new("Missing SignedPreKey"))?;
130
131        if let Some(open_session) = record.get_open_session() {
132            let base_key = open_session.index_info.base_key.clone();
133            record.close_session(&base_key);
134        }
135
136        let session = Self::static_init_session(
137            self.storage.clone(),
138            false,
139            pre_key_pair.as_ref(),
140            Some(&signed_pre_key_pair),
141            &message.identity_key,
142            Some(&message.base_key),
143            None,
144            message.registration_id,
145        ).await?;
146
147        record.set_session(session);
148        Ok(message.pre_key_id)
149    }
150
151    async fn static_init_session<S: SessionStorage>(
152        storage: Arc<S>,
153        is_initiator: bool,
154        our_ephemeral_key: Option<&KeyPair>,
155        our_signed_key: Option<&KeyPair>,
156        their_identity_pub_key: &[u8],
157        their_ephemeral_pub_key: Option<&[u8]>,
158        their_signed_pub_key: Option<&[u8]>,
159        registration_id: u32,
160    ) -> Result<SessionEntry, Box<dyn std::error::Error + Send + Sync>> {
161        let our_signed_key = if is_initiator {
162            our_ephemeral_key.unwrap()
163        } else {
164            our_signed_key.unwrap()
165        };
166
167        let their_signed_pub_key = if is_initiator {
168            their_signed_pub_key.unwrap()
169        } else {
170            their_ephemeral_pub_key.unwrap()
171        };
172
173        let shared_secret_len = if our_ephemeral_key.is_none() || their_ephemeral_pub_key.is_none() {
174            32 * 4
175        } else {
176            32 * 5
177        };
178
179        let mut shared_secret = vec![0xffu8; 32];
180        shared_secret.resize(shared_secret_len, 0);
181
182        let our_identity = storage.get_our_identity().await;
183        let a1 = curve::calculate_agreement(their_signed_pub_key, &our_identity.priv_key)?;
184        let a2 = curve::calculate_agreement(their_identity_pub_key, &our_signed_key.priv_key)?;
185        let a3 = curve::calculate_agreement(their_signed_pub_key, &our_signed_key.priv_key)?;
186
187        if is_initiator {
188            shared_secret[32..64].copy_from_slice(&a1);
189            shared_secret[64..96].copy_from_slice(&a2);
190        } else {
191            shared_secret[64..96].copy_from_slice(&a1);
192            shared_secret[32..64].copy_from_slice(&a2);
193        }
194        shared_secret[96..128].copy_from_slice(&a3);
195
196        if let (Some(our_eph), Some(their_eph)) = (our_ephemeral_key, their_ephemeral_pub_key) {
197            let a4 = curve::calculate_agreement(their_eph, &our_eph.priv_key)?;
198            shared_secret[128..160].copy_from_slice(&a4);
199        }
200
201        let master_key = crypto::derive_secrets(&shared_secret, &[0u8; 32], b"WhisperText", None)?;
202
203        let mut session = SessionEntry::new();
204        session.registration_id = registration_id;
205        session.current_ratchet = CurrentRatchet {
206            root_key: master_key[0].clone(),
207            ephemeral_key_pair: if is_initiator { 
208                curve::generate_key_pair() 
209            } else { 
210                our_signed_key.clone() 
211            },
212            last_remote_ephemeral_key: their_signed_pub_key.to_vec(),
213            previous_counter: 0,
214        };
215
216        session.index_info = IndexInfo {
217            created: chrono::Utc::now().timestamp() as u64,
218            used: chrono::Utc::now().timestamp() as u64,
219            remote_identity_key: their_identity_pub_key.to_vec(),
220            base_key: if is_initiator { 
221                our_ephemeral_key.unwrap().pub_key.clone() 
222            } else { 
223                their_ephemeral_pub_key.unwrap().to_vec() 
224            },
225            base_key_type: if is_initiator { BaseKeyType::Ours } else { BaseKeyType::Theirs },
226            closed: -1,
227        };
228
229        if is_initiator {
230            let ephemeral_pub_key = session.current_ratchet.ephemeral_key_pair.pub_key.clone();
231        session.add_chain(&ephemeral_pub_key, ChainInfo {
232                message_keys: Default::default(),
233                chain_key: ChainKey {
234                    counter: -1,
235                    key: Some(master_key[1].clone()),
236                },
237                chain_type: ChainType::Sending,
238            })?;
239        }
240
241        Ok(session)
242    }
243}