1use std::ffi::CStr;
35
36use zeroize::Zeroizing;
37
38use crate::errors::{self, OlmSasError};
39use crate::getrandom;
40use crate::ByteBuf;
41
/// A libolm short authentication string (SAS) object.
///
/// Owns the memory libolm uses for the SAS state together with the raw pointer
/// libolm hands back for it.
pub struct OlmSas {
    /// Pointer returned by `olm_sas` when initialising the memory of `_sas_buf`.
    sas_ptr: *mut olm_sys::OlmSAS,
    /// Backing storage for the libolm object; never read directly, held so the
    /// memory behind `sas_ptr` stays alive as long as this value does.
    _sas_buf: ByteBuf,
    /// Set once `set_their_public_key` succeeds; `generate_bytes` and
    /// `calculate_mac` refuse to run until it is `true`.
    public_key_set: bool,
}
47
48impl Drop for OlmSas {
49 fn drop(&mut self) {
50 unsafe {
51 olm_sys::olm_clear_sas(self.sas_ptr);
52 }
53 }
54}
55
56impl Default for OlmSas {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl OlmSas {
63 pub fn new() -> Self {
64 let mut sas_buf = ByteBuf::new(unsafe { olm_sys::olm_sas_size() });
66 let ptr = unsafe { olm_sys::olm_sas(sas_buf.as_mut_void_ptr()) };
67
68 let random_len = unsafe { olm_sys::olm_create_sas_random_length(ptr) };
69 let mut random_buf: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0; random_len]);
70 getrandom(&mut random_buf);
71
72 let ret =
73 unsafe { olm_sys::olm_create_sas(ptr, random_buf.as_mut_ptr() as *mut _, random_len) };
74
75 if ret == errors::olm_error() {
76 errors::handle_fatal_error(Self::last_error(ptr));
77 }
78
79 Self {
80 sas_ptr: ptr,
81 _sas_buf: sas_buf,
82 public_key_set: false,
83 }
84 }
85
86 pub fn public_key(&self) -> String {
91 let pubkey_length = unsafe { olm_sys::olm_sas_pubkey_length(self.sas_ptr) };
92
93 let mut buffer: Vec<u8> = vec![0; pubkey_length];
94
95 let ret = unsafe {
96 olm_sys::olm_sas_get_pubkey(self.sas_ptr, buffer.as_mut_ptr() as *mut _, pubkey_length)
97 };
98
99 if ret == errors::olm_error() {
100 errors::handle_fatal_error(Self::last_error(self.sas_ptr));
101 }
102
103 unsafe { String::from_utf8_unchecked(buffer) }
104 }
105
106 fn last_error(sas_ptr: *mut olm_sys::OlmSAS) -> OlmSasError {
110 let error = unsafe {
111 let error_raw = olm_sys::olm_sas_last_error(sas_ptr);
112 CStr::from_ptr(error_raw).to_str().unwrap()
113 };
114
115 match error {
116 "NOT_ENOUGH_RANDOM" => OlmSasError::NotEnoughRandom,
117 "OUTPUT_BUFFER_TOO_SMALL" => OlmSasError::OutputBufferTooSmall,
118 "INPUT_BUFFER_TOO_SMALL" => OlmSasError::OutputBufferTooSmall,
119 _ => OlmSasError::Unknown,
120 }
121 }
122
123 pub fn set_their_public_key(&mut self, public_key: String) -> Result<(), OlmSasError> {
135 let ret = unsafe {
136 olm_sys::olm_sas_set_their_key(
137 self.sas_ptr,
138 public_key.as_ptr() as *mut _,
139 public_key.len(),
140 )
141 };
142
143 if ret == errors::olm_error() {
144 Err(Self::last_error(self.sas_ptr))
145 } else {
146 self.public_key_set = true;
147 Ok(())
148 }
149 }
150
151 pub fn generate_bytes(&self, extra_info: &str, length: usize) -> Result<Vec<u8>, OlmSasError> {
163 if !self.public_key_set {
164 return Err(OlmSasError::OtherPublicKeyUnset);
165 } else if length < 1 {
166 return Err(OlmSasError::InvalidLength);
167 }
168
169 let mut out_buffer = vec![0; length];
170
171 let ret = unsafe {
172 olm_sys::olm_sas_generate_bytes(
173 self.sas_ptr,
174 extra_info.as_ptr() as *mut _,
175 extra_info.len(),
176 out_buffer.as_mut_ptr() as *mut _,
177 length,
178 )
179 };
180
181 if ret == errors::olm_error() {
182 Err(Self::last_error(self.sas_ptr))
183 } else {
184 Ok(out_buffer)
185 }
186 }
187
188 pub fn calculate_mac(&self, message: &str, extra_info: &str) -> Result<String, OlmSasError> {
199 if !self.public_key_set {
200 return Err(OlmSasError::OtherPublicKeyUnset);
201 }
202
203 let mac_length = unsafe { olm_sys::olm_sas_mac_length(self.sas_ptr) };
204 let mut mac_buffer = vec![0; mac_length];
205
206 let ret = unsafe {
207 olm_sys::olm_sas_calculate_mac(
208 self.sas_ptr,
209 message.as_ptr() as *mut _,
210 message.len(),
211 extra_info.as_ptr() as *mut _,
212 extra_info.len(),
213 mac_buffer.as_mut_ptr() as *mut _,
214 mac_length,
215 )
216 };
217
218 if ret == errors::olm_error() {
219 Err(Self::last_error(self.sas_ptr))
220 } else {
221 Ok(unsafe { String::from_utf8_unchecked(mac_buffer) })
222 }
223 }
224}
225
#[cfg(test)]
mod test {
    use crate::sas::OlmSas;

    #[test]
    fn test_creation() {
        let alice = OlmSas::new();
        let bob = OlmSas::new();

        assert!(!alice.public_key().is_empty());
        // Each object is seeded independently, so keys must differ.
        assert_ne!(alice.public_key(), bob.public_key());
    }

    #[test]
    fn test_set_pubkey() {
        let mut alice = OlmSas::new();

        assert!(alice.set_their_public_key(alice.public_key()).is_ok());
        assert!(alice.set_their_public_key("".to_string()).is_err());

        // A truncated key must be rejected rather than read past its end.
        let truncated: String = alice.public_key().chars().take(4).collect();
        assert!(alice.set_their_public_key(truncated).is_err());
    }

    #[test]
    fn test_generate_bytes() {
        let mut alice = OlmSas::new();
        let mut bob = OlmSas::new();

        assert!(alice.generate_bytes("", 5).is_err());

        assert!(alice.set_their_public_key(bob.public_key()).is_ok());
        assert!(bob.set_their_public_key(alice.public_key()).is_ok());

        let alice_bytes = alice.generate_bytes("", 5).unwrap();
        assert_eq!(alice_bytes.len(), 5);
        assert_eq!(alice_bytes, bob.generate_bytes("", 5).unwrap());
        assert_ne!(
            alice.generate_bytes("fake", 5).unwrap(),
            bob.generate_bytes("", 5).unwrap()
        );

        // Zero-length output is rejected even once the key is set.
        assert!(alice.generate_bytes("", 0).is_err());
    }

    #[test]
    fn test_calculate_mac() {
        let mut alice = OlmSas::new();
        let mut bob = OlmSas::new();

        let message = "It's a secret to everyone".to_string();

        assert!(alice.calculate_mac(&message, "").is_err());

        assert!(alice.set_their_public_key(bob.public_key()).is_ok());
        assert!(bob.set_their_public_key(alice.public_key()).is_ok());

        let alice_mac = alice.calculate_mac(&message, "").unwrap();
        assert!(!alice_mac.is_empty());
        assert_eq!(alice_mac, bob.calculate_mac(&message, "").unwrap());
        assert_ne!(
            alice.calculate_mac("fake", "").unwrap(),
            bob.calculate_mac(&message, "").unwrap()
        );
    }
}