olm_rs/
pk.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! This module wraps around all functions following the pattern `olm_pk_*`.
16
17use std::ffi::CStr;
18
19use zeroize::Zeroizing;
20
21use crate::errors::{self, OlmPkDecryptionError, OlmPkEncryptionError, OlmPkSigningError};
22use crate::{getrandom, ByteBuf, PicklingMode};
23
/// A PK encrypted message.
///
/// Produced by [`OlmPkEncryption::encrypt`] and consumed by
/// [`OlmPkDecryption::decrypt`]; all three parts are needed to decrypt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PkMessage {
    /// The encrypted payload.
    pub ciphertext: String,
    /// Message Authentication Code of the encrypted payload.
    pub mac: String,
    /// Public part of the ephemeral key that was used (together with the
    /// recipient's key) to derive the symmetric encryption key.
    pub ephemeral_key: String,
}

impl PkMessage {
    /// Create a new PK encrypted message.
    ///
    /// # Arguments
    ///
    /// * `ephemeral_key` - the public part of the ephemeral key used (together
    ///     with the recipient's key) to generate a symmetric encryption key.
    ///
    /// * `mac` - Message Authentication Code of the encrypted message
    ///
    /// * `ciphertext` - The cipher text of the encrypted message
    pub fn new(ephemeral_key: String, mac: String, ciphertext: String) -> Self {
        Self {
            ciphertext,
            mac,
            ephemeral_key,
        }
    }
}
50
51/// The encryption part of a PK encrypted channel.
52pub struct OlmPkEncryption {
53    ptr: *mut olm_sys::OlmPkEncryption,
54    _buf: ByteBuf,
55}
56
57impl Drop for OlmPkEncryption {
58    fn drop(&mut self) {
59        unsafe {
60            olm_sys::olm_clear_pk_encryption(self.ptr);
61        }
62    }
63}
64
65impl Default for OlmPkDecryption {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl OlmPkEncryption {
72    /// Create a new PK encryption object.
73    ///
74    /// # Arguments
75    ///
76    /// * `recipient_key` - a public key that will be used for encryption, the
77    ///     public key will be provided by the matching decryption object.
78    pub fn new(recipient_key: &str) -> Self {
79        let mut buf = ByteBuf::new(unsafe { olm_sys::olm_pk_encryption_size() });
80        let ptr = unsafe { olm_sys::olm_pk_encryption(buf.as_mut_void_ptr()) };
81
82        unsafe {
83            olm_sys::olm_pk_encryption_set_recipient_key(
84                ptr,
85                recipient_key.as_ptr() as *mut _,
86                recipient_key.len(),
87            );
88        }
89
90        Self { ptr, _buf: buf }
91    }
92
93    fn last_error(ptr: *mut olm_sys::OlmPkEncryption) -> OlmPkEncryptionError {
94        let error = unsafe {
95            let error_raw = olm_sys::olm_pk_encryption_last_error(ptr);
96            CStr::from_ptr(error_raw).to_str().unwrap()
97        };
98        error.into()
99    }
100
101    /// Encrypt a plaintext message.
102    ///
103    /// Returns the encrypted PkMessage.
104    ///
105    /// # Arguments
106    ///
107    /// * `plaintext` - A string that will be encrypted using the PkEncryption
108    ///     object.
109    ///
110    /// # Panics
111    /// * `InputBufferTooSmall` if the ciphertext, ephemeral key, or  mac
112    /// buffers are too small.
113    /// * `OutputBufferTooSmall` if the random buffer is too small.
114    /// * on malformed UTF-8 coding of the ciphertext provided by libolm
115    pub fn encrypt(&self, plaintext: &str) -> PkMessage {
116        let random_length = unsafe { olm_sys::olm_pk_encrypt_random_length(self.ptr) };
117
118        let mut random_buf = Zeroizing::new(vec![0; random_length]);
119        getrandom(&mut random_buf);
120
121        let ciphertext_length =
122            unsafe { olm_sys::olm_pk_ciphertext_length(self.ptr, plaintext.len()) };
123
124        let mac_length = unsafe { olm_sys::olm_pk_mac_length(self.ptr) };
125
126        let ephemeral_key_size = unsafe { olm_sys::olm_pk_key_length() };
127
128        let mut ciphertext = vec![0; ciphertext_length];
129        let mut mac = vec![0; mac_length];
130        let mut ephemeral_key = vec![0; ephemeral_key_size];
131
132        let ret = unsafe {
133            olm_sys::olm_pk_encrypt(
134                self.ptr,
135                plaintext.as_ptr() as *const _,
136                plaintext.len(),
137                ciphertext.as_mut_ptr() as *mut _,
138                ciphertext.len(),
139                mac.as_mut_ptr() as *mut _,
140                mac.len(),
141                ephemeral_key.as_mut_ptr() as *mut _,
142                ephemeral_key.len(),
143                random_buf.as_ptr() as *mut _,
144                random_buf.len(),
145            )
146        };
147
148        if ret == errors::olm_error() {
149            errors::handle_fatal_error(OlmPkEncryption::last_error(self.ptr));
150        }
151
152        let ciphertext = unsafe { String::from_utf8_unchecked(ciphertext) };
153        let mac = unsafe { String::from_utf8_unchecked(mac) };
154        let ephemeral_key = unsafe { String::from_utf8_unchecked(ephemeral_key) };
155
156        PkMessage {
157            ciphertext,
158            mac,
159            ephemeral_key,
160        }
161    }
162}
163
164/// The decryption part of a PK encrypted channel.
165pub struct OlmPkDecryption {
166    ptr: *mut olm_sys::OlmPkDecryption,
167    _buf: ByteBuf,
168    public_key: String,
169}
170
171impl Drop for OlmPkDecryption {
172    fn drop(&mut self) {
173        unsafe {
174            olm_sys::olm_clear_pk_decryption(self.ptr);
175        }
176    }
177}
178
179impl OlmPkDecryption {
180    /// Create a new PK decryption object initializing the private key to a
181    /// random value.
182    ///
183    /// # Panics
184    /// * `NOT_ENOUGH_RANDOM` if there's not enough random data provided when
185    /// creating the OlmPkDecryption object.
186    /// * on malformed UTF-8 coding of the public key that is generated by
187    /// libolm.
188    pub fn new() -> Self {
189        let random_len = Self::private_key_length();
190        let mut random_buf = Zeroizing::new(vec![0; random_len]);
191        getrandom(&mut random_buf);
192
193        Self::from_bytes(&random_buf)
194            .expect("Can't create a PK decryption object from a valid random key")
195    }
196
197    /// Get the number of bytes a private key needs to have.
198    pub fn private_key_length() -> usize {
199        unsafe { olm_sys::olm_pk_private_key_length() }
200    }
201
202    /// Create a new PK decryption object from the given private key.
203    ///
204    /// # Arguments
205    ///
206    /// * `bytes` - An array of random bytes, the number of bytes this method
207    /// expects can be checked using the [`OlmPkDecryption::private_key_length`]
208    /// method.
209    ///
210    /// **Warning**: The caller needs to ensure that the passed in bytes are
211    /// cryptographically sound.
212    ///
213    /// # Panics
214    /// * on malformed UTF-8 coding of the public key that is generated by
215    /// libolm.
216    pub fn from_bytes(bytes: &[u8]) -> Result<Self, OlmPkDecryptionError> {
217        let (ptr, buf) = OlmPkDecryption::init();
218
219        let key_length = unsafe { olm_sys::olm_pk_key_length() };
220        let mut key_buffer = vec![0; key_length];
221
222        let ret = unsafe {
223            olm_sys::olm_pk_key_from_private(
224                ptr,
225                key_buffer.as_mut_ptr() as *mut _,
226                key_buffer.len(),
227                bytes.as_ptr() as *const _,
228                bytes.len(),
229            )
230        };
231
232        if ret == errors::olm_error() {
233            Err(Self::last_error(ptr))
234        } else {
235            let public_key = String::from_utf8(key_buffer)
236                .expect("Can't convert the public key buffer to a string");
237
238            Ok(Self {
239                ptr,
240                _buf: buf,
241                public_key,
242            })
243        }
244    }
245
246    fn init() -> (*mut olm_sys::OlmPkDecryption, ByteBuf) {
247        let mut buf = ByteBuf::new(unsafe { olm_sys::olm_pk_decryption_size() });
248        let ptr = unsafe { olm_sys::olm_pk_decryption(buf.as_mut_void_ptr() as *mut _) };
249
250        (ptr, buf)
251    }
252
253    fn last_error(ptr: *mut olm_sys::OlmPkDecryption) -> OlmPkDecryptionError {
254        let error = unsafe {
255            let error_raw = olm_sys::olm_pk_decryption_last_error(ptr);
256            CStr::from_ptr(error_raw).to_str().unwrap()
257        };
258        error.into()
259    }
260
261    /// Store a PkDecryption object.
262    ///
263    /// Stores a [`OlmPkDecryption`] object as a base64 string. Encrypts the object
264    /// using the supplied passphrase. Returns a byte object containing the
265    /// base64 encoded string of the pickled session.
266    ///
267    /// # Arguments
268    ///
269    /// * `mode` - The pickle mode that should be used to store the decryption
270    /// object.
271    ///
272    /// # Panics
273    /// * `OUTPUT_BUFFER_TOO_SMALL` for OlmSession's pickled buffer
274    /// * on malformed UTF-8 coding of the pickling provided by libolm
275    pub fn pickle(&self, mode: PicklingMode) -> String {
276        let mut pickled_buf: Vec<u8> =
277            vec![0; unsafe { olm_sys::olm_pickle_pk_decryption_length(self.ptr) }];
278
279        let pickle_error = {
280            let key = Zeroizing::new(crate::convert_pickling_mode_to_key(mode));
281
282            unsafe {
283                olm_sys::olm_pickle_pk_decryption(
284                    self.ptr,
285                    key.as_ptr() as *const _,
286                    key.len(),
287                    pickled_buf.as_mut_ptr() as *mut _,
288                    pickled_buf.len(),
289                )
290            }
291        };
292
293        let pickled_result =
294            String::from_utf8(pickled_buf).expect("Pickle string is not valid utf-8");
295
296        if pickle_error == errors::olm_error() {
297            errors::handle_fatal_error(Self::last_error(self.ptr));
298        }
299
300        pickled_result
301    }
302
303    /// Restore a previously stored OlmPkDecryption object.
304    ///
305    /// Creates a [`OlmPkDecryption`] object from a pickled base64 string. Decrypts
306    /// the pickled object using the supplied passphrase.
307    ///
308    /// # Arguments
309    ///
310    /// * `mode` - The pickle mode that should be used to store the decryption
311    /// object.
312    ///
313    /// # C-API equivalent
314    /// `olm_unpickle_pk_decryption`
315    ///
316    /// # Errors
317    ///
318    /// * `BadAccountKey` if the key doesn't match the one the account was encrypted with
319    /// * `InvalidBase64` if decoding the supplied `pickled` string slice fails
320    ///
321    /// # Panics
322    ///
323    /// * on malformed UTF-8 coding of the public key that is generated by
324    /// libolm.
325    pub fn unpickle(mut pickle: String, mode: PicklingMode) -> Result<Self, OlmPkDecryptionError> {
326        let (ptr, buf) = OlmPkDecryption::init();
327
328        let pubkey_length = unsafe { olm_sys::olm_pk_signing_public_key_length() };
329        let mut pubkey_buffer = vec![0; pubkey_length];
330
331        let unpickle_error = {
332            let key = Zeroizing::new(crate::convert_pickling_mode_to_key(mode));
333
334            unsafe {
335                olm_sys::olm_unpickle_pk_decryption(
336                    ptr,
337                    key.as_ptr() as *const _,
338                    key.len(),
339                    pickle.as_mut_ptr() as *mut _,
340                    pickle.len(),
341                    pubkey_buffer.as_mut_ptr() as *mut _,
342                    pubkey_buffer.len(),
343                )
344            }
345        };
346
347        let public_key = String::from_utf8(pubkey_buffer)
348            .expect("Can't conver the public key buffer to a string");
349
350        if unpickle_error == errors::olm_error() {
351            Err(Self::last_error(ptr))
352        } else {
353            Ok(Self {
354                ptr,
355                _buf: buf,
356                public_key,
357            })
358        }
359    }
360
361    /// Decrypts a PK message using this decryption object.
362    ///
363    /// Decoding is lossy, meaing if the decrypted plaintext contains invalid
364    /// UTF-8 symbols, they will be returned as `U+FFFD` (�).
365    ///
366    /// # Arguments
367    ///
368    /// * `message` - The encrypted PkMessage that should be decrypted.
369    ///
370    /// # C-API equivalent
371    /// `olm_pk_decrypt`
372    ///
373    /// # Errors
374    /// * `InvalidBase64` on invalid base64 coding for supplied arguments
375    /// * `BadMessageVersion` on unsupported protocol version
376    /// * `BadMessageFormat` on failing to decode the message
377    /// * `BadMessageMac` on invalid message MAC
378    ///
379    /// # Panics
380    ///
381    /// * `OutputBufferTooSmall` on plaintext output buffer
382    ///
383    pub fn decrypt(&self, mut message: PkMessage) -> Result<String, OlmPkDecryptionError> {
384        let max_plaintext = {
385            let ret =
386                unsafe { olm_sys::olm_pk_max_plaintext_length(self.ptr, message.ciphertext.len()) };
387
388            if ret == errors::olm_error() {
389                return Err(OlmPkDecryptionError::InvalidBase64);
390            }
391
392            ret
393        };
394
395        let mut plaintext = vec![0; max_plaintext];
396
397        let plaintext_len = unsafe {
398            olm_sys::olm_pk_decrypt(
399                self.ptr,
400                message.ephemeral_key.as_ptr() as *const _,
401                message.ephemeral_key.len(),
402                message.mac.as_ptr() as *const _,
403                message.mac.len(),
404                message.ciphertext.as_mut_ptr() as *mut _,
405                message.ciphertext.len(),
406                plaintext.as_mut_ptr() as *mut _,
407                max_plaintext,
408            )
409        };
410
411        if plaintext_len == errors::olm_error() {
412            Err(Self::last_error(self.ptr))
413        } else {
414            plaintext.truncate(plaintext_len);
415            Ok(String::from_utf8_lossy(&plaintext).to_string())
416        }
417    }
418
419    /// Get the public key of the decryption object.
420    ///
421    /// This can be used to initialize a encryption object to encrypt messages
422    /// for this decryption object.
423    pub fn public_key(&self) -> &str {
424        &self.public_key
425    }
426}
427
428/// Signs messages using public key cryptography.
429pub struct OlmPkSigning {
430    ptr: *mut olm_sys::OlmPkSigning,
431    _buf: ByteBuf,
432    public_key: String,
433}
434
435impl Drop for OlmPkSigning {
436    fn drop(&mut self) {
437        unsafe { olm_sys::olm_clear_pk_signing(self.ptr) };
438    }
439}
440
441impl OlmPkSigning {
442    /// Create a new signing object.
443    ///
444    /// # Arguments
445    ///
446    /// * `seed` - the seed to use as the private key for signing. The seed must
447    ///     have the same length as the seeds generated by
448    ///     [`OlmPkSigning::generate_seed()`]. The correct length can be checked
449    ///     using [`OlmPkSigning::seed_length()`] as well.
450    pub fn new(seed: &[u8]) -> Result<Self, OlmPkSigningError> {
451        if seed.len() != OlmPkSigning::seed_length() {
452            return Err(OlmPkSigningError::InvalidSeed);
453        }
454
455        let mut buffer = ByteBuf::new(unsafe { olm_sys::olm_pk_signing_size() });
456
457        let ptr = unsafe { olm_sys::olm_pk_signing(buffer.as_mut_void_ptr() as *mut _) };
458        let pubkey_length = unsafe { olm_sys::olm_pk_signing_public_key_length() };
459        let mut pubkey_buffer = vec![0; pubkey_length];
460
461        let ret = unsafe {
462            olm_sys::olm_pk_signing_key_from_seed(
463                ptr,
464                pubkey_buffer.as_mut_ptr() as *mut _,
465                pubkey_length,
466                seed.as_ptr() as *const _,
467                seed.len(),
468            )
469        };
470
471        if ret == errors::olm_error() {
472            Err(OlmPkSigning::last_error(ptr))
473        } else {
474            Ok(Self {
475                ptr,
476                _buf: buffer,
477                public_key: String::from_utf8(pubkey_buffer)
478                    .expect("Can't conver the public key buffer to a string"),
479            })
480        }
481    }
482
483    fn last_error(ptr: *mut olm_sys::OlmPkSigning) -> OlmPkSigningError {
484        let error = unsafe {
485            let error_raw = olm_sys::olm_pk_signing_last_error(ptr);
486            CStr::from_ptr(error_raw).to_str().unwrap()
487        };
488        error.into()
489    }
490
491    /// Get the required seed length.
492    pub fn seed_length() -> usize {
493        unsafe { olm_sys::olm_pk_signing_seed_length() }
494    }
495
496    /// Generate a random seed that can be used to initialize a [`OlmPkSigning`]
497    /// object.
498    pub fn generate_seed() -> Vec<u8> {
499        let length = OlmPkSigning::seed_length();
500        let mut buffer = Zeroizing::new(vec![0; length]);
501
502        getrandom(&mut buffer);
503
504        buffer.to_vec()
505    }
506
507    /// Get the public key of the the [`OlmPkSigning`] object.
508    ///
509    /// This can be used to check the signature of a messsage that has been
510    /// signed by this object.
511    ///
512    /// # Example
513    ///
514    /// ```
515    /// # use olm_rs::pk::OlmPkSigning;
516    /// # use olm_rs::utility::OlmUtility;
517    /// let message = "It's a secret to everyone".to_string();
518    ///
519    /// let sign = OlmPkSigning::new(&OlmPkSigning::generate_seed()).unwrap();
520    /// let utility = OlmUtility::new();
521    ///
522    /// let signature = sign.sign(&message);
523    ///
524    /// utility.ed25519_verify(sign.public_key(), &message, signature).unwrap();
525    /// ```
526    pub fn public_key(&self) -> &str {
527        &self.public_key
528    }
529
530    /// Sign a message using this object.
531    ///
532    /// # Arguments
533    ///
534    /// * `message` - The message that should be signed with the private key of
535    ///     this object.
536    ///
537    /// # Panics
538    ///
539    /// * `OUTPUT_BUFFER_TOO_SMALL` for the signature buffer that is provided to
540    /// libolm.
541    /// * on malformed UTF-8 coding of the signature provided by libolm.
542    pub fn sign(&self, message: &str) -> String {
543        let signature_len = unsafe { olm_sys::olm_pk_signature_length() };
544
545        let mut signature = vec![0; signature_len];
546
547        let ret = unsafe {
548            olm_sys::olm_pk_sign(
549                self.ptr,
550                message.as_ptr() as *mut _,
551                message.len(),
552                signature.as_mut_ptr() as *mut _,
553                signature_len,
554            )
555        };
556
557        if ret == errors::olm_error() {
558            errors::handle_fatal_error(Self::last_error(self.ptr));
559        }
560
561        String::from_utf8(signature).expect("Can't conver the signature to a string")
562    }
563}
564
#[cfg(test)]
mod test {
    use crate::errors::OlmPkDecryptionError;
    use crate::pk::{OlmPkDecryption, OlmPkEncryption, OlmPkSigning, PkMessage};
    use crate::utility::OlmUtility;
    use crate::PicklingMode;

    /// Plaintext shared by the sign/encrypt round-trip tests.
    const SECRET: &str = "It's a secret to everyone";

    #[test]
    fn create_pk_sign() {
        let seed = OlmPkSigning::generate_seed();
        assert!(OlmPkSigning::new(&seed).is_ok());
    }

    #[test]
    fn invalid_seed() {
        let expected = OlmPkSigning::seed_length();

        // Empty, one byte short and one byte too long must all be rejected.
        for &len in &[0, expected - 1, expected + 1] {
            assert!(OlmPkSigning::new(&vec![0; len]).is_err());
        }
    }

    #[test]
    fn seed_random() {
        let first = OlmPkSigning::generate_seed();
        let second = OlmPkSigning::generate_seed();
        assert_ne!(&first[..], &second[..]);
    }

    #[test]
    fn sign_a_message() {
        let signer = OlmPkSigning::new(&OlmPkSigning::generate_seed()).unwrap();
        let utility = OlmUtility::new();
        let signature = signer.sign(SECRET);

        let genuine = utility.ed25519_verify(signer.public_key(), SECRET, signature.clone());
        assert!(genuine.is_ok());

        let tampered = utility.ed25519_verify(signer.public_key(), "Hello world", signature);
        assert!(tampered.is_err());
    }

    #[test]
    fn encrypt_a_message() {
        let decryption = OlmPkDecryption::new();
        let encryption = OlmPkEncryption::new(decryption.public_key());

        let plaintext = decryption.decrypt(encryption.encrypt(SECRET)).unwrap();

        assert_eq!(SECRET, plaintext);
    }

    #[test]
    fn pickle() {
        let original = OlmPkDecryption::new();
        let encryption = OlmPkEncryption::new(original.public_key());
        let encrypted = encryption.encrypt(SECRET);

        let pickled = original.pickle(PicklingMode::Unencrypted);
        let restored = OlmPkDecryption::unpickle(pickled, PicklingMode::Unencrypted).unwrap();

        assert_eq!(SECRET, restored.decrypt(encrypted).unwrap());
    }

    #[test]
    fn invalid_unpickle() {
        let decryption = OlmPkDecryption::new();
        let encrypted_mode = PicklingMode::Encrypted {
            key: Vec::from("wordpass"),
        };

        let pickled = decryption.pickle(encrypted_mode);

        assert!(OlmPkDecryption::unpickle(pickled, PicklingMode::Unencrypted).is_err());
    }

    #[test]
    fn invalid_decrypt() {
        let alice = OlmPkDecryption::new();
        let mallory = OlmPkEncryption::new(OlmPkDecryption::new().public_key());

        let intercepted = mallory.encrypt(SECRET);

        assert!(alice.decrypt(intercepted).is_err());
    }

    #[test]
    fn attempt_decrypt_invalid_base64() {
        let decryption = OlmPkDecryption::new();
        // A single-character ciphertext can never be valid base64.
        let message = PkMessage::new(String::new(), String::new(), "1".to_string());

        assert_eq!(
            Err(OlmPkDecryptionError::InvalidBase64),
            decryption.decrypt(message)
        );
    }
}