1use crate::{
2 base_key_type::BaseKeyType,
3 chain_type::ChainType,
4 session_record::{SessionRecord, SessionEntry, CurrentRatchet, IndexInfo, PendingPreKey, ChainInfo, ChainKey},
5 crypto,
6 curve::{self, KeyPair},
7 errors::{UntrustedIdentityKeyError, PreKeyError},
8 protocol_address::ProtocolAddress,
9 queue_job::queue_job,
10};
11use std::sync::Arc;
12
13pub trait SessionStorage: Send + Sync {
14 fn is_trusted_identity(&self, address: &str, identity_key: &[u8]) -> impl std::future::Future<Output = bool> + Send;
15 fn load_session(&self, address: &str) -> impl std::future::Future<Output = Option<SessionRecord>> + Send;
16 fn store_session(&self, address: &str, record: SessionRecord) -> impl std::future::Future<Output = ()> + Send;
17 fn load_pre_key(&self, pre_key_id: u32) -> impl std::future::Future<Output = Option<KeyPair>> + Send;
18 fn load_signed_pre_key(&self, signed_pre_key_id: u32) -> impl std::future::Future<Output = Option<KeyPair>> + Send;
19 fn get_our_identity(&self) -> impl std::future::Future<Output = KeyPair> + Send;
20}
21
/// A remote device's published key bundle, consumed by `SessionBuilder::init_outgoing`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Device {
    /// Registration id of the remote device; stored on the resulting session entry.
    pub registration_id: u32,
    /// The device's public identity key. It is checked for trust and used to verify
    /// `signed_pre_key.signature`.
    pub identity_key: Vec<u8>,
    /// The device's signed pre key; its signature must verify against `identity_key`.
    pub signed_pre_key: SignedPreKeyBundle,
    /// Optional pre key. When present, its public key contributes an additional key
    /// agreement and its id is recorded on the session's pending pre key.
    pub pre_key: Option<PreKeyBundle>,
}

/// A signed pre key as published by a remote device.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignedPreKeyBundle {
    /// Identifier of this signed pre key.
    pub key_id: u32,
    /// Public half of the signed pre key.
    pub public_key: Vec<u8>,
    /// Signature over `public_key`, made with the device's identity key.
    pub signature: Vec<u8>,
}

/// An (unsigned) pre key as published by a remote device.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreKeyBundle {
    /// Identifier of this pre key.
    pub key_id: u32,
    /// Public half of the pre key.
    pub public_key: Vec<u8>,
}
39
/// Header fields of an incoming pre-key message, consumed by `SessionBuilder::init_incoming`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreKeyWhisperMessage {
    /// Sender's registration id; stored on the newly created session entry.
    pub registration_id: u32,
    /// Id of our pre key the sender used, if any. It must exist in storage when set.
    pub pre_key_id: Option<u32>,
    /// Id of our signed pre key the sender used; it must exist in storage.
    pub signed_pre_key_id: u32,
    /// Sender's base public key; identifies the resulting session within a record.
    pub base_key: Vec<u8>,
    /// Sender's public identity key; must be trusted for the sender's address.
    pub identity_key: Vec<u8>,
    /// Embedded message bytes. NOTE(review): not read by the session builder; presumably
    /// the inner ratchet message, handled elsewhere.
    pub message: Vec<u8>,
}
48
49pub struct SessionBuilder<T: SessionStorage> {
50 addr: ProtocolAddress,
51 storage: Arc<T>,
52}
53
54impl<T: SessionStorage + 'static> SessionBuilder<T> {
55 pub fn new(storage: Arc<T>, protocol_address: ProtocolAddress) -> Self {
56 Self {
57 addr: protocol_address,
58 storage,
59 }
60 }
61
62 pub async fn init_outgoing(&self, device: Device) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
63 let storage = self.storage.clone();
64 let addr = self.addr.clone();
65
66 queue_job(addr.to_string(), async move {
67 if !storage.is_trusted_identity(&addr.id, &device.identity_key).await {
68 return Err(Box::new(UntrustedIdentityKeyError::new(addr.id.clone(), device.identity_key)) as Box<dyn std::error::Error + Send + Sync>);
69 }
70
71 curve::verify_signature(&device.identity_key, &device.signed_pre_key.public_key, &device.signed_pre_key.signature)?;
72
73 let base_key = curve::generate_key_pair();
74 let device_pre_key = device.pre_key.as_ref().map(|pk| &pk.public_key);
75
76 let session = Self::static_init_session(
77 storage.clone(),
78 true,
79 Some(&base_key),
80 None,
81 &device.identity_key,
82 device_pre_key.map(|v| &**v),
83 Some(&device.signed_pre_key.public_key),
84 device.registration_id,
85 ).await?;
86
87 let mut session_mut = session.clone();
88 session_mut.pending_pre_key = Some(PendingPreKey {
89 signed_key_id: device.signed_pre_key.key_id,
90 base_key: base_key.pub_key.clone(),
91 pre_key_id: device.pre_key.map(|pk| pk.key_id),
92 });
93
94 let mut record = storage.load_session(&addr.to_string()).await.unwrap_or_else(|| SessionRecord::new());
95
96 if let Some(open_session) = record.get_open_session() {
97 let base_key = open_session.index_info.base_key.clone();
98 record.close_session(&base_key);
99 }
100
101 record.set_session(session_mut);
102 storage.store_session(&addr.to_string(), record).await;
103 Ok(())
104 }).await
105 }
106
107 pub async fn init_incoming(&self, record: &mut SessionRecord, message: &PreKeyWhisperMessage) -> Result<Option<u32>, Box<dyn std::error::Error + Send + Sync>> {
108 let fq_addr = self.addr.to_string();
109
110 if !self.storage.is_trusted_identity(&fq_addr, &message.identity_key).await {
111 return Err(Box::new(UntrustedIdentityKeyError::new(self.addr.id.clone(), message.identity_key.clone())));
112 }
113
114 if record.get_session(&message.base_key).is_some() {
115 return Ok(None);
116 }
117
118 let pre_key_pair = if let Some(pre_key_id) = message.pre_key_id {
119 self.storage.load_pre_key(pre_key_id).await
120 } else {
121 None
122 };
123
124 if message.pre_key_id.is_some() && pre_key_pair.is_none() {
125 return Err(Box::new(PreKeyError::new("Invalid PreKey ID")));
126 }
127
128 let signed_pre_key_pair = self.storage.load_signed_pre_key(message.signed_pre_key_id).await
129 .ok_or_else(|| PreKeyError::new("Missing SignedPreKey"))?;
130
131 if let Some(open_session) = record.get_open_session() {
132 let base_key = open_session.index_info.base_key.clone();
133 record.close_session(&base_key);
134 }
135
136 let session = Self::static_init_session(
137 self.storage.clone(),
138 false,
139 pre_key_pair.as_ref(),
140 Some(&signed_pre_key_pair),
141 &message.identity_key,
142 Some(&message.base_key),
143 None,
144 message.registration_id,
145 ).await?;
146
147 record.set_session(session);
148 Ok(message.pre_key_id)
149 }
150
151 async fn static_init_session<S: SessionStorage>(
152 storage: Arc<S>,
153 is_initiator: bool,
154 our_ephemeral_key: Option<&KeyPair>,
155 our_signed_key: Option<&KeyPair>,
156 their_identity_pub_key: &[u8],
157 their_ephemeral_pub_key: Option<&[u8]>,
158 their_signed_pub_key: Option<&[u8]>,
159 registration_id: u32,
160 ) -> Result<SessionEntry, Box<dyn std::error::Error + Send + Sync>> {
161 let our_signed_key = if is_initiator {
162 our_ephemeral_key.unwrap()
163 } else {
164 our_signed_key.unwrap()
165 };
166
167 let their_signed_pub_key = if is_initiator {
168 their_signed_pub_key.unwrap()
169 } else {
170 their_ephemeral_pub_key.unwrap()
171 };
172
173 let shared_secret_len = if our_ephemeral_key.is_none() || their_ephemeral_pub_key.is_none() {
174 32 * 4
175 } else {
176 32 * 5
177 };
178
179 let mut shared_secret = vec![0xffu8; 32];
180 shared_secret.resize(shared_secret_len, 0);
181
182 let our_identity = storage.get_our_identity().await;
183 let a1 = curve::calculate_agreement(their_signed_pub_key, &our_identity.priv_key)?;
184 let a2 = curve::calculate_agreement(their_identity_pub_key, &our_signed_key.priv_key)?;
185 let a3 = curve::calculate_agreement(their_signed_pub_key, &our_signed_key.priv_key)?;
186
187 if is_initiator {
188 shared_secret[32..64].copy_from_slice(&a1);
189 shared_secret[64..96].copy_from_slice(&a2);
190 } else {
191 shared_secret[64..96].copy_from_slice(&a1);
192 shared_secret[32..64].copy_from_slice(&a2);
193 }
194 shared_secret[96..128].copy_from_slice(&a3);
195
196 if let (Some(our_eph), Some(their_eph)) = (our_ephemeral_key, their_ephemeral_pub_key) {
197 let a4 = curve::calculate_agreement(their_eph, &our_eph.priv_key)?;
198 shared_secret[128..160].copy_from_slice(&a4);
199 }
200
201 let master_key = crypto::derive_secrets(&shared_secret, &[0u8; 32], b"WhisperText", None)?;
202
203 let mut session = SessionEntry::new();
204 session.registration_id = registration_id;
205 session.current_ratchet = CurrentRatchet {
206 root_key: master_key[0].clone(),
207 ephemeral_key_pair: if is_initiator {
208 curve::generate_key_pair()
209 } else {
210 our_signed_key.clone()
211 },
212 last_remote_ephemeral_key: their_signed_pub_key.to_vec(),
213 previous_counter: 0,
214 };
215
216 session.index_info = IndexInfo {
217 created: chrono::Utc::now().timestamp() as u64,
218 used: chrono::Utc::now().timestamp() as u64,
219 remote_identity_key: their_identity_pub_key.to_vec(),
220 base_key: if is_initiator {
221 our_ephemeral_key.unwrap().pub_key.clone()
222 } else {
223 their_ephemeral_pub_key.unwrap().to_vec()
224 },
225 base_key_type: if is_initiator { BaseKeyType::Ours } else { BaseKeyType::Theirs },
226 closed: -1,
227 };
228
229 if is_initiator {
230 let ephemeral_pub_key = session.current_ratchet.ephemeral_key_pair.pub_key.clone();
231 session.add_chain(&ephemeral_pub_key, ChainInfo {
232 message_keys: Default::default(),
233 chain_key: ChainKey {
234 counter: -1,
235 key: Some(master_key[1].clone()),
236 },
237 chain_type: ChainType::Sending,
238 })?;
239 }
240
241 Ok(session)
242 }
243}