olm_rs/
sas.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! This module wraps around all functions following the pattern `olm_sas_*`.
16//!
17//! # Example
18//!
19//! ```
20//! # use olm_rs::sas::OlmSas;
21//! let mut alice = OlmSas::new();
22//! let mut bob = OlmSas::new();
23//!
24//! alice.set_their_public_key(bob.public_key()).unwrap();
25//! bob.set_their_public_key(alice.public_key()).unwrap();
26//!
27//! assert_eq!(
28//!     alice.generate_bytes("", 5).unwrap(),
29//!     bob.generate_bytes("", 5).unwrap()
30//! );
31//!
32//! ```
33
34use std::ffi::CStr;
35
36use zeroize::Zeroizing;
37
38use crate::errors::{self, OlmSasError};
39use crate::getrandom;
40use crate::ByteBuf;
41
42pub struct OlmSas {
43    sas_ptr: *mut olm_sys::OlmSAS,
44    _sas_buf: ByteBuf,
45    public_key_set: bool,
46}
47
48impl Drop for OlmSas {
49    fn drop(&mut self) {
50        unsafe {
51            olm_sys::olm_clear_sas(self.sas_ptr);
52        }
53    }
54}
55
56impl Default for OlmSas {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl OlmSas {
63    pub fn new() -> Self {
64        // allocate buffer for OlmAccount to be written into
65        let mut sas_buf = ByteBuf::new(unsafe { olm_sys::olm_sas_size() });
66        let ptr = unsafe { olm_sys::olm_sas(sas_buf.as_mut_void_ptr()) };
67
68        let random_len = unsafe { olm_sys::olm_create_sas_random_length(ptr) };
69        let mut random_buf: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0; random_len]);
70        getrandom(&mut random_buf);
71
72        let ret =
73            unsafe { olm_sys::olm_create_sas(ptr, random_buf.as_mut_ptr() as *mut _, random_len) };
74
75        if ret == errors::olm_error() {
76            errors::handle_fatal_error(Self::last_error(ptr));
77        }
78
79        Self {
80            sas_ptr: ptr,
81            _sas_buf: sas_buf,
82            public_key_set: false,
83        }
84    }
85
86    /// Get the public key for the SAS object.
87    ///
88    /// This returns the public key of the SAS object that can then be shared
89    /// with another user to perform the authentication process.
90    pub fn public_key(&self) -> String {
91        let pubkey_length = unsafe { olm_sys::olm_sas_pubkey_length(self.sas_ptr) };
92
93        let mut buffer: Vec<u8> = vec![0; pubkey_length];
94
95        let ret = unsafe {
96            olm_sys::olm_sas_get_pubkey(self.sas_ptr, buffer.as_mut_ptr() as *mut _, pubkey_length)
97        };
98
99        if ret == errors::olm_error() {
100            errors::handle_fatal_error(Self::last_error(self.sas_ptr));
101        }
102
103        unsafe { String::from_utf8_unchecked(buffer) }
104    }
105
106    /// Returns the last error that occurred for an OlmSas object.
107    /// Since error codes are encoded as CStrings by libolm,
108    /// OlmSasError::Unknown is returned on an unknown error code.
109    fn last_error(sas_ptr: *mut olm_sys::OlmSAS) -> OlmSasError {
110        let error = unsafe {
111            let error_raw = olm_sys::olm_sas_last_error(sas_ptr);
112            CStr::from_ptr(error_raw).to_str().unwrap()
113        };
114
115        match error {
116            "NOT_ENOUGH_RANDOM" => OlmSasError::NotEnoughRandom,
117            "OUTPUT_BUFFER_TOO_SMALL" => OlmSasError::OutputBufferTooSmall,
118            "INPUT_BUFFER_TOO_SMALL" => OlmSasError::OutputBufferTooSmall,
119            _ => OlmSasError::Unknown,
120        }
121    }
122
123    /// Set the public key of the other user.
124    ///
125    /// This sets the public key of the other user, it needs to be set before
126    /// bytes can be generated for the authentication string and a MAC can be
127    /// calculated.
128    ///
129    /// Returns an error if the public key was too short or invalid.
130    ///
131    /// # Arguments
132    ///
133    /// * `public_key` - The public key of the other user.
134    pub fn set_their_public_key(&mut self, public_key: String) -> Result<(), OlmSasError> {
135        let ret = unsafe {
136            olm_sys::olm_sas_set_their_key(
137                self.sas_ptr,
138                public_key.as_ptr() as *mut _,
139                public_key.len(),
140            )
141        };
142
143        if ret == errors::olm_error() {
144            Err(Self::last_error(self.sas_ptr))
145        } else {
146            self.public_key_set = true;
147            Ok(())
148        }
149    }
150
151    /// Generate bytes to use for the short authentication string.
152    ///
153    /// Note the other public key needs to be set for this method to work.
154    /// Returns an error if it isn't set.
155    ///
156    /// # Arguments
157    ///
158    /// * `extra_info` - Extra information to mix in when generating the
159    ///     bytes.
160    ///
161    /// * `length` - The number of bytes to generate.
162    pub fn generate_bytes(&self, extra_info: &str, length: usize) -> Result<Vec<u8>, OlmSasError> {
163        if !self.public_key_set {
164            return Err(OlmSasError::OtherPublicKeyUnset);
165        } else if length < 1 {
166            return Err(OlmSasError::InvalidLength);
167        }
168
169        let mut out_buffer = vec![0; length];
170
171        let ret = unsafe {
172            olm_sys::olm_sas_generate_bytes(
173                self.sas_ptr,
174                extra_info.as_ptr() as *mut _,
175                extra_info.len(),
176                out_buffer.as_mut_ptr() as *mut _,
177                length,
178            )
179        };
180
181        if ret == errors::olm_error() {
182            Err(Self::last_error(self.sas_ptr))
183        } else {
184            Ok(out_buffer)
185        }
186    }
187
188    /// Generate a message authentication code based on the shared secret.
189    ///
190    /// Note the other public key needs to be set for this method to work.
191    /// Returns an error if it isn't set.
192    ///
193    /// # Arguments
194    ///
195    /// * `message` - The message to produce the authentication code for.
196    ///
197    /// * `extra_info` - Extra information to mix in when generating the MAC.
198    pub fn calculate_mac(&self, message: &str, extra_info: &str) -> Result<String, OlmSasError> {
199        if !self.public_key_set {
200            return Err(OlmSasError::OtherPublicKeyUnset);
201        }
202
203        let mac_length = unsafe { olm_sys::olm_sas_mac_length(self.sas_ptr) };
204        let mut mac_buffer = vec![0; mac_length];
205
206        let ret = unsafe {
207            olm_sys::olm_sas_calculate_mac(
208                self.sas_ptr,
209                message.as_ptr() as *mut _,
210                message.len(),
211                extra_info.as_ptr() as *mut _,
212                extra_info.len(),
213                mac_buffer.as_mut_ptr() as *mut _,
214                mac_length,
215            )
216        };
217
218        if ret == errors::olm_error() {
219            Err(Self::last_error(self.sas_ptr))
220        } else {
221            Ok(unsafe { String::from_utf8_unchecked(mac_buffer) })
222        }
223    }
224}
225
#[cfg(test)]
mod test {
    use crate::errors::OlmSasError;
    use crate::sas::OlmSas;

    #[test]
    fn test_creation() {
        let alice = OlmSas::new();
        assert!(!alice.public_key().is_empty());
    }

    #[test]
    fn test_set_pubkey() {
        let mut alice = OlmSas::new();

        assert!(alice.set_their_public_key(alice.public_key()).is_ok());
        assert!(alice.set_their_public_key("".to_string()).is_err());
    }

    #[test]
    fn test_generate_bytes() {
        let mut alice = OlmSas::new();
        let mut bob = OlmSas::new();

        // Must fail with the specific "no peer key" error, not just any error.
        assert!(matches!(
            alice.generate_bytes("", 5),
            Err(OlmSasError::OtherPublicKeyUnset)
        ));

        assert!(alice.set_their_public_key(bob.public_key()).is_ok());
        assert!(bob.set_their_public_key(alice.public_key()).is_ok());

        // Zero-length output is rejected before reaching libolm.
        assert!(matches!(
            alice.generate_bytes("", 0),
            Err(OlmSasError::InvalidLength)
        ));

        let alice_bytes = alice.generate_bytes("", 5).unwrap();
        assert_eq!(alice_bytes.len(), 5);
        assert_eq!(alice_bytes, bob.generate_bytes("", 5).unwrap());
        assert_ne!(
            alice.generate_bytes("fake", 5).unwrap(),
            bob.generate_bytes("", 5).unwrap()
        );
    }

    #[test]
    fn test_calculate_mac() {
        let mut alice = OlmSas::new();
        let mut bob = OlmSas::new();

        let message = "It's a secret to everyone".to_string();

        assert!(matches!(
            alice.calculate_mac(&message, ""),
            Err(OlmSasError::OtherPublicKeyUnset)
        ));

        assert!(alice.set_their_public_key(bob.public_key()).is_ok());
        assert!(bob.set_their_public_key(alice.public_key()).is_ok());

        assert_eq!(
            alice.calculate_mac(&message, "").unwrap(),
            bob.calculate_mac(&message, "").unwrap()
        );
        assert_ne!(
            alice.calculate_mac("fake", "").unwrap(),
            bob.calculate_mac(&message, "").unwrap()
        );
    }
}