1use std::ffi::CStr;
18
19use zeroize::Zeroizing;
20
21use crate::errors::{self, OlmPkDecryptionError, OlmPkEncryptionError, OlmPkSigningError};
22use crate::{getrandom, ByteBuf, PicklingMode};
23
/// An encrypted message, as returned by [`OlmPkEncryption::encrypt`] and
/// accepted by [`OlmPkDecryption::decrypt`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PkMessage {
    /// The ciphertext part of the message.
    pub ciphertext: String,
    /// The MAC part of the message.
    pub mac: String,
    /// The ephemeral key part of the message.
    pub ephemeral_key: String,
}

impl PkMessage {
    /// Assemble a message from its three parts.
    ///
    /// Note that the argument order (`ephemeral_key`, `mac`, `ciphertext`)
    /// differs from the field declaration order.
    pub fn new(ephemeral_key: String, mac: String, ciphertext: String) -> Self {
        PkMessage {
            ciphertext,
            mac,
            ephemeral_key,
        }
    }
}
50
51pub struct OlmPkEncryption {
53 ptr: *mut olm_sys::OlmPkEncryption,
54 _buf: ByteBuf,
55}
56
57impl Drop for OlmPkEncryption {
58 fn drop(&mut self) {
59 unsafe {
60 olm_sys::olm_clear_pk_encryption(self.ptr);
61 }
62 }
63}
64
65impl Default for OlmPkDecryption {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl OlmPkEncryption {
72 pub fn new(recipient_key: &str) -> Self {
79 let mut buf = ByteBuf::new(unsafe { olm_sys::olm_pk_encryption_size() });
80 let ptr = unsafe { olm_sys::olm_pk_encryption(buf.as_mut_void_ptr()) };
81
82 unsafe {
83 olm_sys::olm_pk_encryption_set_recipient_key(
84 ptr,
85 recipient_key.as_ptr() as *mut _,
86 recipient_key.len(),
87 );
88 }
89
90 Self { ptr, _buf: buf }
91 }
92
93 fn last_error(ptr: *mut olm_sys::OlmPkEncryption) -> OlmPkEncryptionError {
94 let error = unsafe {
95 let error_raw = olm_sys::olm_pk_encryption_last_error(ptr);
96 CStr::from_ptr(error_raw).to_str().unwrap()
97 };
98 error.into()
99 }
100
101 pub fn encrypt(&self, plaintext: &str) -> PkMessage {
116 let random_length = unsafe { olm_sys::olm_pk_encrypt_random_length(self.ptr) };
117
118 let mut random_buf = Zeroizing::new(vec![0; random_length]);
119 getrandom(&mut random_buf);
120
121 let ciphertext_length =
122 unsafe { olm_sys::olm_pk_ciphertext_length(self.ptr, plaintext.len()) };
123
124 let mac_length = unsafe { olm_sys::olm_pk_mac_length(self.ptr) };
125
126 let ephemeral_key_size = unsafe { olm_sys::olm_pk_key_length() };
127
128 let mut ciphertext = vec![0; ciphertext_length];
129 let mut mac = vec![0; mac_length];
130 let mut ephemeral_key = vec![0; ephemeral_key_size];
131
132 let ret = unsafe {
133 olm_sys::olm_pk_encrypt(
134 self.ptr,
135 plaintext.as_ptr() as *const _,
136 plaintext.len(),
137 ciphertext.as_mut_ptr() as *mut _,
138 ciphertext.len(),
139 mac.as_mut_ptr() as *mut _,
140 mac.len(),
141 ephemeral_key.as_mut_ptr() as *mut _,
142 ephemeral_key.len(),
143 random_buf.as_ptr() as *mut _,
144 random_buf.len(),
145 )
146 };
147
148 if ret == errors::olm_error() {
149 errors::handle_fatal_error(OlmPkEncryption::last_error(self.ptr));
150 }
151
152 let ciphertext = unsafe { String::from_utf8_unchecked(ciphertext) };
153 let mac = unsafe { String::from_utf8_unchecked(mac) };
154 let ephemeral_key = unsafe { String::from_utf8_unchecked(ephemeral_key) };
155
156 PkMessage {
157 ciphertext,
158 mac,
159 ephemeral_key,
160 }
161 }
162}
163
164pub struct OlmPkDecryption {
166 ptr: *mut olm_sys::OlmPkDecryption,
167 _buf: ByteBuf,
168 public_key: String,
169}
170
171impl Drop for OlmPkDecryption {
172 fn drop(&mut self) {
173 unsafe {
174 olm_sys::olm_clear_pk_decryption(self.ptr);
175 }
176 }
177}
178
179impl OlmPkDecryption {
180 pub fn new() -> Self {
189 let random_len = Self::private_key_length();
190 let mut random_buf = Zeroizing::new(vec![0; random_len]);
191 getrandom(&mut random_buf);
192
193 Self::from_bytes(&random_buf)
194 .expect("Can't create a PK decryption object from a valid random key")
195 }
196
197 pub fn private_key_length() -> usize {
199 unsafe { olm_sys::olm_pk_private_key_length() }
200 }
201
202 pub fn from_bytes(bytes: &[u8]) -> Result<Self, OlmPkDecryptionError> {
217 let (ptr, buf) = OlmPkDecryption::init();
218
219 let key_length = unsafe { olm_sys::olm_pk_key_length() };
220 let mut key_buffer = vec![0; key_length];
221
222 let ret = unsafe {
223 olm_sys::olm_pk_key_from_private(
224 ptr,
225 key_buffer.as_mut_ptr() as *mut _,
226 key_buffer.len(),
227 bytes.as_ptr() as *const _,
228 bytes.len(),
229 )
230 };
231
232 if ret == errors::olm_error() {
233 Err(Self::last_error(ptr))
234 } else {
235 let public_key = String::from_utf8(key_buffer)
236 .expect("Can't convert the public key buffer to a string");
237
238 Ok(Self {
239 ptr,
240 _buf: buf,
241 public_key,
242 })
243 }
244 }
245
246 fn init() -> (*mut olm_sys::OlmPkDecryption, ByteBuf) {
247 let mut buf = ByteBuf::new(unsafe { olm_sys::olm_pk_decryption_size() });
248 let ptr = unsafe { olm_sys::olm_pk_decryption(buf.as_mut_void_ptr() as *mut _) };
249
250 (ptr, buf)
251 }
252
253 fn last_error(ptr: *mut olm_sys::OlmPkDecryption) -> OlmPkDecryptionError {
254 let error = unsafe {
255 let error_raw = olm_sys::olm_pk_decryption_last_error(ptr);
256 CStr::from_ptr(error_raw).to_str().unwrap()
257 };
258 error.into()
259 }
260
261 pub fn pickle(&self, mode: PicklingMode) -> String {
276 let mut pickled_buf: Vec<u8> =
277 vec![0; unsafe { olm_sys::olm_pickle_pk_decryption_length(self.ptr) }];
278
279 let pickle_error = {
280 let key = Zeroizing::new(crate::convert_pickling_mode_to_key(mode));
281
282 unsafe {
283 olm_sys::olm_pickle_pk_decryption(
284 self.ptr,
285 key.as_ptr() as *const _,
286 key.len(),
287 pickled_buf.as_mut_ptr() as *mut _,
288 pickled_buf.len(),
289 )
290 }
291 };
292
293 let pickled_result =
294 String::from_utf8(pickled_buf).expect("Pickle string is not valid utf-8");
295
296 if pickle_error == errors::olm_error() {
297 errors::handle_fatal_error(Self::last_error(self.ptr));
298 }
299
300 pickled_result
301 }
302
303 pub fn unpickle(mut pickle: String, mode: PicklingMode) -> Result<Self, OlmPkDecryptionError> {
326 let (ptr, buf) = OlmPkDecryption::init();
327
328 let pubkey_length = unsafe { olm_sys::olm_pk_signing_public_key_length() };
329 let mut pubkey_buffer = vec![0; pubkey_length];
330
331 let unpickle_error = {
332 let key = Zeroizing::new(crate::convert_pickling_mode_to_key(mode));
333
334 unsafe {
335 olm_sys::olm_unpickle_pk_decryption(
336 ptr,
337 key.as_ptr() as *const _,
338 key.len(),
339 pickle.as_mut_ptr() as *mut _,
340 pickle.len(),
341 pubkey_buffer.as_mut_ptr() as *mut _,
342 pubkey_buffer.len(),
343 )
344 }
345 };
346
347 let public_key = String::from_utf8(pubkey_buffer)
348 .expect("Can't conver the public key buffer to a string");
349
350 if unpickle_error == errors::olm_error() {
351 Err(Self::last_error(ptr))
352 } else {
353 Ok(Self {
354 ptr,
355 _buf: buf,
356 public_key,
357 })
358 }
359 }
360
361 pub fn decrypt(&self, mut message: PkMessage) -> Result<String, OlmPkDecryptionError> {
384 let max_plaintext = {
385 let ret =
386 unsafe { olm_sys::olm_pk_max_plaintext_length(self.ptr, message.ciphertext.len()) };
387
388 if ret == errors::olm_error() {
389 return Err(OlmPkDecryptionError::InvalidBase64);
390 }
391
392 ret
393 };
394
395 let mut plaintext = vec![0; max_plaintext];
396
397 let plaintext_len = unsafe {
398 olm_sys::olm_pk_decrypt(
399 self.ptr,
400 message.ephemeral_key.as_ptr() as *const _,
401 message.ephemeral_key.len(),
402 message.mac.as_ptr() as *const _,
403 message.mac.len(),
404 message.ciphertext.as_mut_ptr() as *mut _,
405 message.ciphertext.len(),
406 plaintext.as_mut_ptr() as *mut _,
407 max_plaintext,
408 )
409 };
410
411 if plaintext_len == errors::olm_error() {
412 Err(Self::last_error(self.ptr))
413 } else {
414 plaintext.truncate(plaintext_len);
415 Ok(String::from_utf8_lossy(&plaintext).to_string())
416 }
417 }
418
419 pub fn public_key(&self) -> &str {
424 &self.public_key
425 }
426}
427
428pub struct OlmPkSigning {
430 ptr: *mut olm_sys::OlmPkSigning,
431 _buf: ByteBuf,
432 public_key: String,
433}
434
435impl Drop for OlmPkSigning {
436 fn drop(&mut self) {
437 unsafe { olm_sys::olm_clear_pk_signing(self.ptr) };
438 }
439}
440
441impl OlmPkSigning {
442 pub fn new(seed: &[u8]) -> Result<Self, OlmPkSigningError> {
451 if seed.len() != OlmPkSigning::seed_length() {
452 return Err(OlmPkSigningError::InvalidSeed);
453 }
454
455 let mut buffer = ByteBuf::new(unsafe { olm_sys::olm_pk_signing_size() });
456
457 let ptr = unsafe { olm_sys::olm_pk_signing(buffer.as_mut_void_ptr() as *mut _) };
458 let pubkey_length = unsafe { olm_sys::olm_pk_signing_public_key_length() };
459 let mut pubkey_buffer = vec![0; pubkey_length];
460
461 let ret = unsafe {
462 olm_sys::olm_pk_signing_key_from_seed(
463 ptr,
464 pubkey_buffer.as_mut_ptr() as *mut _,
465 pubkey_length,
466 seed.as_ptr() as *const _,
467 seed.len(),
468 )
469 };
470
471 if ret == errors::olm_error() {
472 Err(OlmPkSigning::last_error(ptr))
473 } else {
474 Ok(Self {
475 ptr,
476 _buf: buffer,
477 public_key: String::from_utf8(pubkey_buffer)
478 .expect("Can't conver the public key buffer to a string"),
479 })
480 }
481 }
482
483 fn last_error(ptr: *mut olm_sys::OlmPkSigning) -> OlmPkSigningError {
484 let error = unsafe {
485 let error_raw = olm_sys::olm_pk_signing_last_error(ptr);
486 CStr::from_ptr(error_raw).to_str().unwrap()
487 };
488 error.into()
489 }
490
491 pub fn seed_length() -> usize {
493 unsafe { olm_sys::olm_pk_signing_seed_length() }
494 }
495
496 pub fn generate_seed() -> Vec<u8> {
499 let length = OlmPkSigning::seed_length();
500 let mut buffer = Zeroizing::new(vec![0; length]);
501
502 getrandom(&mut buffer);
503
504 buffer.to_vec()
505 }
506
507 pub fn public_key(&self) -> &str {
527 &self.public_key
528 }
529
530 pub fn sign(&self, message: &str) -> String {
543 let signature_len = unsafe { olm_sys::olm_pk_signature_length() };
544
545 let mut signature = vec![0; signature_len];
546
547 let ret = unsafe {
548 olm_sys::olm_pk_sign(
549 self.ptr,
550 message.as_ptr() as *mut _,
551 message.len(),
552 signature.as_mut_ptr() as *mut _,
553 signature_len,
554 )
555 };
556
557 if ret == errors::olm_error() {
558 errors::handle_fatal_error(Self::last_error(self.ptr));
559 }
560
561 String::from_utf8(signature).expect("Can't conver the signature to a string")
562 }
563}
564
#[cfg(test)]
mod test {
    use crate::errors::OlmPkDecryptionError;
    use crate::pk::{OlmPkDecryption, OlmPkEncryption, OlmPkSigning, PkMessage};
    use crate::utility::OlmUtility;
    use crate::PicklingMode;

    /// Plaintext shared by the signing and encryption round-trip tests.
    const SECRET: &str = "It's a secret to everyone";

    /// Build a signing object from a freshly generated seed.
    fn random_signer() -> OlmPkSigning {
        OlmPkSigning::new(&OlmPkSigning::generate_seed()).unwrap()
    }

    #[test]
    fn create_pk_sign() {
        let seed = OlmPkSigning::generate_seed();
        assert!(OlmPkSigning::new(&seed).is_ok());
    }

    #[test]
    fn invalid_seed() {
        let expected = OlmPkSigning::seed_length();

        // Empty, one byte short and one byte too long must all be rejected.
        for &len in &[0, expected - 1, expected + 1] {
            assert!(
                OlmPkSigning::new(&vec![0; len]).is_err(),
                "a seed of length {} was accepted",
                len
            );
        }
    }

    #[test]
    fn seed_random() {
        let first = OlmPkSigning::generate_seed();
        let second = OlmPkSigning::generate_seed();
        assert_ne!(&first[..], &second[..]);
    }

    #[test]
    fn sign_a_message() {
        let signer = random_signer();
        let utility = OlmUtility::new();
        let signature = signer.sign(SECRET);

        let genuine = utility.ed25519_verify(signer.public_key(), SECRET, signature.clone());
        assert!(genuine.is_ok());

        let tampered = utility.ed25519_verify(signer.public_key(), "Hello world", signature);
        assert!(tampered.is_err());
    }

    #[test]
    fn encrypt_a_message() {
        let decryption = OlmPkDecryption::new();
        let encryption = OlmPkEncryption::new(decryption.public_key());

        let plaintext = decryption.decrypt(encryption.encrypt(SECRET)).unwrap();
        assert_eq!(plaintext, SECRET);
    }

    #[test]
    fn pickle() {
        let original = OlmPkDecryption::new();
        let encryption = OlmPkEncryption::new(original.public_key());
        let encrypted_message = encryption.encrypt(SECRET);

        let pickled = original.pickle(PicklingMode::Unencrypted);
        let restored = OlmPkDecryption::unpickle(pickled, PicklingMode::Unencrypted).unwrap();

        assert_eq!(restored.decrypt(encrypted_message).unwrap(), SECRET);
    }

    #[test]
    fn invalid_unpickle() {
        let encrypted_pickle = OlmPkDecryption::new().pickle(PicklingMode::Encrypted {
            key: Vec::from("wordpass"),
        });

        let result = OlmPkDecryption::unpickle(encrypted_pickle, PicklingMode::Unencrypted);
        assert!(result.is_err());
    }

    #[test]
    fn invalid_decrypt() {
        let alice = OlmPkDecryption::new();
        let someone_else = OlmPkDecryption::new();
        let malory = OlmPkEncryption::new(someone_else.public_key());

        let encrypted_message = malory.encrypt(SECRET);
        assert!(alice.decrypt(encrypted_message).is_err());
    }

    #[test]
    fn attempt_decrypt_invalid_base64() {
        let decryption = OlmPkDecryption::new();
        let message = PkMessage::new(String::new(), String::new(), "1".to_string());

        assert_eq!(
            decryption.decrypt(message),
            Err(OlmPkDecryptionError::InvalidBase64)
        );
    }
}