1use crate::account::OlmAccount;
19use crate::errors::{self, OlmSessionError};
20use crate::getrandom;
21use crate::{ByteBuf, PicklingMode};
22use std::cmp::Ordering;
23use std::convert::TryFrom;
24use std::ffi::CStr;
25use std::fmt;
26
27use zeroize::Zeroizing;
28
/// Safe wrapper around a libolm `OlmSession`.
///
/// The libolm session object is constructed inside `_olm_session_buf`, and
/// `olm_session_ptr` points into that buffer (see `create_session_with`).
/// Dropping an `OlmSession` calls `olm_clear_session` before the buffer is
/// released. Because it holds a raw pointer, this type is neither `Send` nor
/// `Sync`.
///
/// NOTE(review): the derived `Debug` formats `_olm_session_buf` through
/// `ByteBuf`'s own `Debug` impl — confirm that impl does not print the raw
/// session bytes (which would include key material).
#[derive(Debug)]
pub struct OlmSession {
    /// Pointer to the libolm session living inside `_olm_session_buf`.
    pub(crate) olm_session_ptr: *mut olm_sys::OlmSession,
    /// Backing storage for the session; must outlive every use of `olm_session_ptr`.
    _olm_session_buf: ByteBuf,
}
35
/// A regular (non pre-key) Olm message, holding its ciphertext as a `String`.
#[derive(Debug, Clone)]
pub struct Message(String);

impl Message {
    /// Wraps an already-encrypted `ciphertext` string.
    fn new(ciphertext: String) -> Self {
        Self(ciphertext)
    }
}

/// An Olm pre-key message, holding its ciphertext as a `String`.
#[derive(Debug, Clone)]
pub struct PreKeyMessage(String);

impl PreKeyMessage {
    /// Wraps an already-encrypted pre-key `message` string.
    fn new(message: String) -> Self {
        Self(message)
    }
}

/// Either kind of Olm message produced by encryption or accepted for decryption.
#[derive(Debug, Clone)]
pub enum OlmMessage {
    /// A regular message.
    Message(Message),
    /// A pre-key message.
    PreKey(PreKeyMessage),
}
68
/// Error returned when a raw message type number is neither the pre-key nor
/// the regular Olm message type.
#[derive(Debug)]
pub struct UnknownOlmMessageType;

impl fmt::Display for UnknownOlmMessageType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Unknown message type")
    }
}

impl std::error::Error for UnknownOlmMessageType {}
79
80impl OlmMessage {
81 pub fn from_type_and_ciphertext(
89 message_type: usize,
90 ciphertext: String,
91 ) -> Result<Self, UnknownOlmMessageType> {
92 match message_type {
93 olm_sys::OLM_MESSAGE_TYPE_PRE_KEY => {
94 Ok(OlmMessage::PreKey(PreKeyMessage::new(ciphertext)))
95 }
96 olm_sys::OLM_MESSAGE_TYPE_MESSAGE => Ok(OlmMessage::Message(Message::new(ciphertext))),
97 _ => Err(UnknownOlmMessageType),
98 }
99 }
100
101 #[allow(clippy::wrong_self_convention)]
102 pub fn to_tuple(self) -> (OlmMessageType, String) {
104 match self {
105 OlmMessage::Message(m) => (OlmMessageType::Message, m.0),
106 OlmMessage::PreKey(m) => (OlmMessageType::PreKey, m.0),
107 }
108 }
109}
110
111impl OlmSession {
112 pub(crate) fn create_inbound_session(
124 account: &OlmAccount,
125 mut message: PreKeyMessage,
126 ) -> Result<Self, OlmSessionError> {
127 Self::create_session_with(|olm_session_ptr| unsafe {
128 let one_time_key_message_buf = message.0.as_bytes_mut();
129 olm_sys::olm_create_inbound_session(
130 olm_session_ptr,
131 account.olm_account_ptr,
132 one_time_key_message_buf.as_mut_ptr() as *mut _,
133 one_time_key_message_buf.len(),
134 )
135 })
136 }
137
138 pub(crate) fn create_inbound_session_from(
150 account: &OlmAccount,
151 their_identity_key: &str,
152 mut one_time_key_message: PreKeyMessage,
153 ) -> Result<Self, OlmSessionError> {
154 Self::create_session_with(|olm_session_ptr| {
155 let their_identity_key_buf = their_identity_key.as_bytes();
156 unsafe {
157 let one_time_key_message_buf = one_time_key_message.0.as_bytes_mut();
158 olm_sys::olm_create_inbound_session_from(
159 olm_session_ptr,
160 account.olm_account_ptr,
161 their_identity_key_buf.as_ptr() as *const _,
162 their_identity_key_buf.len(),
163 one_time_key_message_buf.as_mut_ptr() as *mut _,
164 one_time_key_message_buf.len(),
165 )
166 }
167 })
168 }
169
170 pub(crate) fn create_outbound_session(
183 account: &OlmAccount,
184 their_identity_key: &str,
185 their_one_time_key: &str,
186 ) -> Result<Self, OlmSessionError> {
187 Self::create_session_with(|olm_session_ptr| {
188 let their_identity_key_buf = their_identity_key.as_bytes();
189 let their_one_time_key_buf = their_one_time_key.as_bytes();
190 let random_len =
191 unsafe { olm_sys::olm_create_outbound_session_random_length(olm_session_ptr) };
192 let mut random_buf: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0; random_len]);
193 getrandom(&mut random_buf);
194
195 unsafe {
196 olm_sys::olm_create_outbound_session(
197 olm_session_ptr,
198 account.olm_account_ptr,
199 their_identity_key_buf.as_ptr() as *const _,
200 their_identity_key_buf.len(),
201 their_one_time_key_buf.as_ptr() as *const _,
202 their_one_time_key_buf.len(),
203 random_buf.as_mut_ptr() as *mut _,
204 random_len,
205 )
206 }
207 })
208 }
209
210 fn create_session_with<F: FnMut(*mut olm_sys::OlmSession) -> usize>(
212 mut f: F,
213 ) -> Result<OlmSession, OlmSessionError> {
214 let mut olm_session_buf = ByteBuf::new(unsafe { olm_sys::olm_session_size() });
215 let olm_session_ptr = unsafe { olm_sys::olm_session(olm_session_buf.as_mut_void_ptr()) };
216
217 let error = f(olm_session_ptr);
218 if error == errors::olm_error() {
219 let last_error = Self::last_error(olm_session_ptr);
220 if last_error == OlmSessionError::NotEnoughRandom {
221 errors::handle_fatal_error(OlmSessionError::NotEnoughRandom);
222 }
223
224 Err(last_error)
225 } else {
226 Ok(OlmSession {
227 olm_session_ptr,
228 _olm_session_buf: olm_session_buf,
229 })
230 }
231 }
232
233 fn last_error(session_ptr: *mut olm_sys::OlmSession) -> OlmSessionError {
235 let error_raw = unsafe { olm_sys::olm_session_last_error(session_ptr) };
237 let error = unsafe { CStr::from_ptr(error_raw).to_str().unwrap() };
238
239 match error {
240 "BAD_ACCOUNT_KEY" => OlmSessionError::BadAccountKey,
241 "BAD_MESSAGE_MAC" => OlmSessionError::BadMessageMac,
242 "BAD_MESSAGE_FORMAT" => OlmSessionError::BadMessageFormat,
243 "BAD_MESSAGE_KEY_ID" => OlmSessionError::BadMessageKeyId,
244 "BAD_MESSAGE_VERSION" => OlmSessionError::BadMessageVersion,
245 "INVALID_BASE64" => OlmSessionError::InvalidBase64,
246 "NOT_ENOUGH_RANDOM" => OlmSessionError::NotEnoughRandom,
247 "OUTPUT_BUFFER_TOO_SMALL" => OlmSessionError::OutputBufferTooSmall,
248 _ => OlmSessionError::Unknown,
249 }
250 }
251
252 pub fn session_id(&self) -> String {
262 let session_id_len = unsafe { olm_sys::olm_session_id_length(self.olm_session_ptr) };
263 let mut session_id_buf: Vec<u8> = vec![0; session_id_len];
264
265 let error = unsafe {
266 olm_sys::olm_session_id(
267 self.olm_session_ptr,
268 session_id_buf.as_mut_ptr() as *mut _,
269 session_id_len,
270 )
271 };
272
273 let session_id_result = String::from_utf8(session_id_buf).unwrap();
274
275 if error == errors::olm_error() {
276 errors::handle_fatal_error(Self::last_error(self.olm_session_ptr));
277 }
278
279 session_id_result
280 }
281
282 pub fn pickle(&self, mode: PicklingMode) -> String {
292 let pickled_len = unsafe { olm_sys::olm_pickle_session_length(self.olm_session_ptr) };
293 let mut pickled_buf = vec![0; pickled_len];
294
295 let pickle_error = {
296 let key = Zeroizing::new(crate::convert_pickling_mode_to_key(mode));
297
298 unsafe {
299 olm_sys::olm_pickle_session(
300 self.olm_session_ptr,
301 key.as_ptr() as *const _,
302 key.len(),
303 pickled_buf.as_mut_ptr() as *mut _,
304 pickled_len,
305 )
306 }
307 };
308
309 let pickled_result = String::from_utf8(pickled_buf).unwrap();
310
311 if pickle_error == errors::olm_error() {
312 errors::handle_fatal_error(Self::last_error(self.olm_session_ptr));
313 }
314
315 pickled_result
316 }
317
318 pub fn unpickle(mut pickled: String, mode: PicklingMode) -> Result<Self, OlmSessionError> {
328 let key = Zeroizing::new(crate::convert_pickling_mode_to_key(mode));
329
330 Self::create_session_with(|olm_session_ptr| {
331 let pickled_len = pickled.len();
332 unsafe {
333 let pickled_buf = pickled.as_bytes_mut();
334
335 olm_sys::olm_unpickle_session(
336 olm_session_ptr,
337 key.as_ptr() as *const _,
338 key.len(),
339 pickled_buf.as_mut_ptr() as *mut _,
340 pickled_len,
341 )
342 }
343 })
344 }
345
346 pub fn encrypt(&self, plaintext: &str) -> OlmMessage {
357 let plaintext_buf = plaintext.as_bytes();
358 let plaintext_len = plaintext_buf.len();
359 let message_len =
360 unsafe { olm_sys::olm_encrypt_message_length(self.olm_session_ptr, plaintext_len) };
361 let mut message_buf: Vec<u8> = vec![0; message_len];
362
363 let message_type = self.encrypt_message_type();
364
365 let encrypt_error = {
366 let random_len = unsafe { olm_sys::olm_encrypt_random_length(self.olm_session_ptr) };
367 let mut random_buf: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0; random_len]);
368 getrandom(&mut random_buf);
369
370 unsafe {
371 olm_sys::olm_encrypt(
372 self.olm_session_ptr,
373 plaintext_buf.as_ptr() as *const _,
374 plaintext_len,
375 random_buf.as_mut_ptr() as *mut _,
376 random_len,
377 message_buf.as_mut_ptr() as *mut _,
378 message_len,
379 )
380 }
381 };
382
383 let message_result = String::from_utf8(message_buf).unwrap();
384
385 if encrypt_error == errors::olm_error() {
386 errors::handle_fatal_error(Self::last_error(self.olm_session_ptr));
387 }
388
389 match message_type {
390 OlmMessageType::Message => OlmMessage::Message(Message::new(message_result)),
391 OlmMessageType::PreKey => OlmMessage::PreKey(PreKeyMessage::new(message_result)),
392 }
393 }
394
395 pub fn decrypt(&self, message: OlmMessage) -> Result<String, OlmSessionError> {
412 let (message_type, mut ciphertext) = message.to_tuple();
414 let message_type_val = match message_type {
415 OlmMessageType::PreKey => olm_sys::OLM_MESSAGE_TYPE_PRE_KEY,
416 _ => olm_sys::OLM_MESSAGE_TYPE_MESSAGE,
417 };
418
419 let mut message_for_len = ciphertext.to_owned();
422 let message_buf = unsafe { message_for_len.as_bytes_mut() };
423 let message_len = message_buf.len();
424 let message_ptr = message_buf.as_mut_ptr() as *mut _;
425
426 let plaintext_max_len = unsafe {
427 olm_sys::olm_decrypt_max_plaintext_length(
428 self.olm_session_ptr,
429 message_type_val,
430 message_ptr,
431 message_len,
432 )
433 };
434 if plaintext_max_len == errors::olm_error() {
435 return Err(Self::last_error(self.olm_session_ptr));
436 }
437
438 let mut plaintext_buf = Zeroizing::new(vec![0; plaintext_max_len]);
439
440 let message_buf = unsafe { ciphertext.as_bytes_mut() };
441 let message_len = message_buf.len();
442 let message_ptr = message_buf.as_mut_ptr() as *mut _;
443
444 let plaintext_result_len = unsafe {
445 olm_sys::olm_decrypt(
446 self.olm_session_ptr,
447 message_type_val,
448 message_ptr,
449 message_len,
450 plaintext_buf.as_mut_ptr() as *mut _,
451 plaintext_max_len,
452 )
453 };
454
455 let decrypt_error = plaintext_result_len;
456 if decrypt_error == errors::olm_error() {
457 let last_error = Self::last_error(self.olm_session_ptr);
458 if last_error == OlmSessionError::OutputBufferTooSmall {
459 errors::handle_fatal_error(OlmSessionError::OutputBufferTooSmall);
460 }
461 return Err(last_error);
462 }
463
464 plaintext_buf.truncate(plaintext_result_len);
465 Ok(String::from_utf8_lossy(&plaintext_buf).to_string())
466 }
467
468 pub(crate) fn encrypt_message_type(&self) -> OlmMessageType {
478 let message_type_result =
479 unsafe { olm_sys::olm_encrypt_message_type(self.olm_session_ptr) };
480
481 let message_type_error = message_type_result;
483
484 if message_type_error == errors::olm_error() {
485 errors::handle_fatal_error(Self::last_error(self.olm_session_ptr));
486 }
487
488 match message_type_result {
489 olm_sys::OLM_MESSAGE_TYPE_PRE_KEY => OlmMessageType::PreKey,
490 _ => OlmMessageType::Message,
491 }
492 }
493
494 pub fn has_received_message(&self) -> bool {
500 0 != unsafe { olm_sys::olm_session_has_received_message(self.olm_session_ptr) }
512 }
513
514 pub fn matches_inbound_session(
525 &self,
526 mut message: PreKeyMessage,
527 ) -> Result<bool, OlmSessionError> {
528 let matches_result = unsafe {
529 let one_time_key_message_buf = message.0.as_bytes_mut();
530
531 olm_sys::olm_matches_inbound_session(
532 self.olm_session_ptr,
533 one_time_key_message_buf.as_mut_ptr() as *mut _,
534 one_time_key_message_buf.len(),
535 )
536 };
537
538 let matches_error = matches_result;
540 if matches_error == errors::olm_error() {
541 Err(OlmSession::last_error(self.olm_session_ptr))
542 } else {
543 match matches_result {
544 0 => Ok(false),
545 1 => Ok(true),
546 _ => Err(OlmSessionError::Unknown),
547 }
548 }
549 }
550
551 pub fn matches_inbound_session_from(
562 &self,
563 their_identity_key: &str,
564 mut message: PreKeyMessage,
565 ) -> Result<bool, OlmSessionError> {
566 let their_identity_key_buf = their_identity_key.as_bytes();
567 let their_identity_key_ptr = their_identity_key_buf.as_ptr() as *const _;
568 let matches_result = unsafe {
569 let one_time_key_message_buf = message.0.as_bytes_mut();
570
571 olm_sys::olm_matches_inbound_session_from(
572 self.olm_session_ptr,
573 their_identity_key_ptr,
574 their_identity_key_buf.len(),
575 one_time_key_message_buf.as_mut_ptr() as *mut _,
576 one_time_key_message_buf.len(),
577 )
578 };
579
580 let matches_error = matches_result;
582 if matches_error == errors::olm_error() {
583 Err(OlmSession::last_error(self.olm_session_ptr))
584 } else {
585 match matches_result {
586 0 => Ok(false),
587 1 => Ok(true),
588 _ => Err(OlmSessionError::Unknown),
589 }
590 }
591 }
592}
593
594#[derive(Clone, Copy, Debug, PartialEq)]
596pub enum OlmMessageType {
597 PreKey,
598 Message,
599}
600
601impl From<OlmMessageType> for usize {
602 fn from(message_type: OlmMessageType) -> Self {
603 match message_type {
604 OlmMessageType::PreKey => olm_sys::OLM_MESSAGE_TYPE_PRE_KEY,
605 OlmMessageType::Message => olm_sys::OLM_MESSAGE_TYPE_MESSAGE,
606 }
607 }
608}
609
610impl TryFrom<usize> for OlmMessageType {
611 type Error = ();
612
613 fn try_from(message_type: usize) -> Result<OlmMessageType, ()> {
614 match message_type {
615 olm_sys::OLM_MESSAGE_TYPE_PRE_KEY => Ok(OlmMessageType::PreKey),
616 olm_sys::OLM_MESSAGE_TYPE_MESSAGE => Ok(OlmMessageType::Message),
617 _ => Err(()),
618 }
619 }
620}
621
622impl Ord for OlmSession {
624 fn cmp(&self, other: &Self) -> Ordering {
625 self.session_id().cmp(&other.session_id())
626 }
627}
628
629impl PartialOrd for OlmSession {
630 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
631 Some(self.cmp(other))
632 }
633}
634
635impl PartialEq for OlmSession {
636 fn eq(&self, other: &Self) -> bool {
637 self.session_id() == other.session_id()
638 }
639}
640
641impl Eq for OlmSession {}
642
643impl Drop for OlmSession {
644 fn drop(&mut self) {
645 unsafe {
646 olm_sys::olm_clear_session(self.olm_session_ptr);
647 }
648 }
649}
650
#[cfg(test)]
mod test {
    use crate::account::OlmAccount;
    use crate::session::{OlmMessage, OlmMessageType};

    #[test]
    fn message_type() {
        let alice = OlmAccount::new();
        let bob = OlmAccount::new();

        alice.generate_one_time_keys(1);

        // Olm sessions are established against the recipient's Curve25519
        // identity key. The Ed25519 key is a signing key and would produce a
        // session Alice can never decrypt.
        let identity_key = alice.parsed_identity_keys().curve25519().to_owned();
        let one_time_key = alice
            .parsed_one_time_keys()
            .curve25519()
            .values()
            .next()
            .unwrap()
            .to_owned();

        let outbound_session = bob
            .create_outbound_session(&identity_key, &one_time_key)
            .unwrap();

        assert_eq!(
            OlmMessageType::PreKey,
            outbound_session.encrypt_message_type()
        );
        assert!(!outbound_session.has_received_message());

        // The reported type must match the variant `encrypt` actually returns.
        match outbound_session.encrypt("Hello") {
            OlmMessage::PreKey(_) => {}
            OlmMessage::Message(_) => panic!("expected a pre-key message"),
        }
    }
}