// ockam_ffi/vault.rs

1use crate::vault_types::{FfiSecretAttributes, SecretKeyHandle};
2use crate::{check_buffer, FfiError, FfiOckamError};
3use crate::{FfiVaultFatPointer, FfiVaultType};
4use core::{future::Future, result::Result as StdResult, slice};
5use futures::future::join_all;
6use lazy_static::lazy_static;
7use ockam_core::compat::collections::BTreeMap;
8use ockam_core::compat::sync::Arc;
9use ockam_core::{Error, Result};
10use ockam_vault::{AsymmetricVault, KeyId, PublicKey, Secret, SecretAttributes, SymmetricVault};
11use ockam_vault::{EphemeralSecretsStore, SecretsStoreReader, Vault};
12use tokio::{runtime::Runtime, sync::RwLock, task};
13
14#[derive(Default)]
15struct SecretsMapping {
16    mapping: BTreeMap<u64, KeyId>,
17    last_index: u64,
18}
19
20impl SecretsMapping {
21    fn insert(&mut self, key_id: KeyId) -> u64 {
22        self.last_index += 1;
23
24        self.mapping.insert(self.last_index, key_id);
25
26        self.last_index
27    }
28
29    fn get(&self, index: u64) -> Result<KeyId> {
30        Ok(self
31            .mapping
32            .get(&index)
33            .cloned()
34            .ok_or(FfiError::EntryNotFound)?)
35    }
36
37    fn take(&mut self, index: u64) -> Result<KeyId> {
38        Ok(self.mapping.remove(&index).ok_or(FfiError::EntryNotFound)?)
39    }
40}
41
42#[derive(Clone, Default)]
43struct VaultEntry {
44    vault: Vault,
45    secrets_mapping: Arc<RwLock<SecretsMapping>>,
46}
47
48impl VaultEntry {
49    async fn insert(&self, key_id: KeyId) -> u64 {
50        self.secrets_mapping.write().await.insert(key_id)
51    }
52
53    async fn get(&self, index: u64) -> Result<KeyId> {
54        self.secrets_mapping.read().await.get(index)
55    }
56
57    async fn take(&self, index: u64) -> Result<KeyId> {
58        self.secrets_mapping.write().await.take(index)
59    }
60}
61
62lazy_static! {
63    static ref SOFTWARE_VAULTS: RwLock<Vec<VaultEntry>> = RwLock::new(vec![]);
64    static ref RUNTIME: Arc<Runtime> = Arc::new(Runtime::new().unwrap());
65}
66
67fn get_runtime() -> Arc<Runtime> {
68    RUNTIME.clone()
69}
70
71fn block_future<F>(f: F) -> <F as Future>::Output
72where
73    F: Future,
74{
75    let rt = get_runtime();
76    task::block_in_place(move || {
77        let local = task::LocalSet::new();
78        local.block_on(&rt, f)
79    })
80}
81
82async fn get_vault_entry(context: FfiVaultFatPointer) -> Result<VaultEntry> {
83    match context.vault_type() {
84        FfiVaultType::Software => {
85            let item = SOFTWARE_VAULTS
86                .read()
87                .await
88                .get(context.handle() as usize)
89                .ok_or(FfiError::VaultNotFound)?
90                .clone();
91
92            Ok(item)
93        }
94    }
95}
96
97/// Create and return a default Ockam Vault.
98#[no_mangle]
99pub extern "C" fn ockam_vault_default_init(context: &mut FfiVaultFatPointer) -> FfiOckamError {
100    handle_panics(|| {
101        // TODO: handle logging
102        let handle = block_future(async move {
103            let mut write_lock = SOFTWARE_VAULTS.write().await;
104            write_lock.push(Default::default());
105            write_lock.len() - 1
106        });
107
108        *context = FfiVaultFatPointer::new(handle as u64, FfiVaultType::Software);
109
110        Ok(())
111    })
112}
113
114/// Compute the SHA-256 hash on `input` and put the result in `digest`.
115/// `digest` must be 32 bytes in length.
116#[no_mangle]
117pub extern "C" fn ockam_vault_sha256(
118    context: FfiVaultFatPointer,
119    input: *const u8,
120    input_length: u32,
121    digest: *mut u8,
122) -> FfiOckamError {
123    handle_panics(|| {
124        check_buffer!(input);
125        check_buffer!(digest);
126
127        let input = unsafe { core::slice::from_raw_parts(input, input_length as usize) };
128
129        let res = block_future(async move {
130            let entry = get_vault_entry(context).await?;
131            Ok::<[u8; 32], Error>(entry.vault.compute_sha256(input))
132        })?;
133
134        unsafe {
135            std::ptr::copy_nonoverlapping(res.as_ptr(), digest, res.len());
136        }
137        Ok(())
138    })
139}
140
141/// Generate a secret key with the specific attributes.
142/// Returns a handle for the secret.
143#[no_mangle]
144pub extern "C" fn ockam_vault_secret_generate(
145    context: FfiVaultFatPointer,
146    secret: &mut SecretKeyHandle,
147    attributes: FfiSecretAttributes,
148) -> FfiOckamError {
149    handle_panics(|| {
150        *secret = block_future(async move {
151            let entry = get_vault_entry(context).await?;
152            let atts = attributes.try_into()?;
153            let key_id = entry.vault.create_ephemeral_secret(atts).await?;
154
155            let index = entry.insert(key_id).await;
156
157            Ok::<u64, Error>(index)
158        })?;
159        Ok(())
160    })
161}
162
163/// Import a secret key with the specific handle and attributes.
164#[no_mangle]
165pub extern "C" fn ockam_vault_secret_import(
166    context: FfiVaultFatPointer,
167    secret: &mut SecretKeyHandle,
168    attributes: FfiSecretAttributes,
169    input: *mut u8,
170    input_length: u32,
171) -> FfiOckamError {
172    handle_panics(|| {
173        check_buffer!(input, input_length);
174        *secret = block_future(async move {
175            let entry = get_vault_entry(context).await?;
176            let atts = attributes.try_into()?;
177
178            let secret_data = unsafe { core::slice::from_raw_parts(input, input_length as usize) };
179
180            let secret = Secret::new(secret_data.to_vec());
181            let key_id = entry.vault.import_ephemeral_secret(secret, atts).await?;
182
183            let index = entry.insert(key_id).await;
184
185            Ok::<u64, Error>(index)
186        })?;
187        Ok(())
188    })
189}
190
191/// Export a secret key with the specific handle to the `output_buffer`.
192#[no_mangle]
193pub extern "C" fn ockam_vault_secret_export(
194    context: FfiVaultFatPointer,
195    secret: SecretKeyHandle,
196    output_buffer: *mut u8,
197    output_buffer_size: u32,
198    output_buffer_length: &mut u32,
199) -> FfiOckamError {
200    *output_buffer_length = 0;
201    handle_panics(|| {
202        block_future(async move {
203            let entry = get_vault_entry(context).await?;
204            let key_id = entry.get(secret).await?;
205            let key = entry
206                .vault
207                .get_ephemeral_secret(&key_id, "secret from ffi")
208                .await?;
209            let key = key.secret().as_ref();
210            if output_buffer_size < key.len() as u32 {
211                return Err(FfiError::BufferTooSmall.into());
212            }
213            *output_buffer_length = key.len() as u32;
214
215            unsafe {
216                std::ptr::copy_nonoverlapping(key.as_ptr(), output_buffer, key.len());
217            };
218            Ok::<(), Error>(())
219        })?;
220        Ok(())
221    })
222}
223
224/// Get the public key, given a secret key, and copy it to the output buffer.
225#[no_mangle]
226pub extern "C" fn ockam_vault_secret_publickey_get(
227    context: FfiVaultFatPointer,
228    secret: SecretKeyHandle,
229    output_buffer: *mut u8,
230    output_buffer_size: u32,
231    output_buffer_length: &mut u32,
232) -> FfiOckamError {
233    *output_buffer_length = 0;
234    handle_panics(|| {
235        block_future(async move {
236            let entry = get_vault_entry(context).await?;
237            let key_id = entry.get(secret).await?;
238            let key = entry.vault.get_public_key(&key_id).await?;
239            if output_buffer_size < key.data().len() as u32 {
240                return Err(FfiError::BufferTooSmall.into());
241            }
242            *output_buffer_length = key.data().len() as u32;
243
244            unsafe {
245                std::ptr::copy_nonoverlapping(key.data().as_ptr(), output_buffer, key.data().len());
246            };
247            Ok::<(), Error>(())
248        })?;
249        Ok(())
250    })
251}
252
253/// Retrieve the attributes for a specified secret.
254#[no_mangle]
255pub extern "C" fn ockam_vault_secret_attributes_get(
256    context: FfiVaultFatPointer,
257    secret: SecretKeyHandle,
258    attributes: &mut FfiSecretAttributes,
259) -> FfiOckamError {
260    handle_panics(|| {
261        *attributes = block_future(async move {
262            let entry = get_vault_entry(context).await?;
263            let key_id = entry.get(secret).await?;
264            let atts = entry.vault.get_secret_attributes(&key_id).await?;
265            Ok::<FfiSecretAttributes, Error>(atts.into())
266        })?;
267        Ok(())
268    })
269}
270
271/// Delete an ockam vault secret.
272#[no_mangle]
273pub extern "C" fn ockam_vault_secret_destroy(
274    context: FfiVaultFatPointer,
275    secret: SecretKeyHandle,
276) -> FfiOckamError {
277    match block_future(async move {
278        let entry = get_vault_entry(context).await?;
279        let key_id = entry.take(secret).await?;
280        entry.vault.delete_ephemeral_secret(key_id).await?;
281        Ok::<(), Error>(())
282    }) {
283        Ok(_) => FfiOckamError::none(),
284        Err(err) => err.into(),
285    }
286}
287
288/// Perform an ECDH operation on the supplied Ockam Vault `secret` and `peer_publickey`. The result
289/// is an Ockam Vault secret of unknown type.
290#[no_mangle]
291pub extern "C" fn ockam_vault_ecdh(
292    context: FfiVaultFatPointer,
293    secret: SecretKeyHandle,
294    peer_publickey: *const u8,
295    peer_publickey_length: u32,
296    shared_secret: &mut SecretKeyHandle,
297) -> FfiOckamError {
298    handle_panics(|| {
299        check_buffer!(peer_publickey, peer_publickey_length);
300
301        let peer_publickey =
302            unsafe { core::slice::from_raw_parts(peer_publickey, peer_publickey_length as usize) };
303
304        *shared_secret = block_future(async move {
305            let entry = get_vault_entry(context).await?;
306            let key_id = entry.get(secret).await?;
307            let atts = entry.vault.get_secret_attributes(&key_id).await?;
308            let pubkey = PublicKey::new(peer_publickey.to_vec(), atts.secret_type());
309            let shared_ctx = entry.vault.ec_diffie_hellman(&key_id, &pubkey).await?;
310            let index = entry.insert(shared_ctx).await;
311            Ok::<u64, Error>(index)
312        })?;
313        Ok(())
314    })
315}
316
317/// Perform an HMAC-SHA256 based key derivation function on the supplied salt and input key
318/// material.
319#[no_mangle]
320pub extern "C" fn ockam_vault_hkdf_sha256(
321    context: FfiVaultFatPointer,
322    salt: SecretKeyHandle,
323    input_key_material: *const SecretKeyHandle,
324    derived_outputs_attributes: *const FfiSecretAttributes,
325    derived_outputs_count: u8,
326    derived_outputs: *mut SecretKeyHandle,
327) -> FfiOckamError {
328    handle_panics(|| {
329        let derived_outputs_count = derived_outputs_count as usize;
330
331        block_future(async move {
332            let entry = get_vault_entry(context).await?;
333            let salt_key_id = entry.get(salt).await?;
334            let ikm_key_id = if input_key_material.is_null() {
335                None
336            } else {
337                let ctx = unsafe { entry.get(*input_key_material).await? };
338                Some(ctx)
339            };
340            let ikm_key_id = ikm_key_id.as_ref();
341
342            let array: &[FfiSecretAttributes] =
343                unsafe { slice::from_raw_parts(derived_outputs_attributes, derived_outputs_count) };
344
345            let mut output_attributes = Vec::<SecretAttributes>::with_capacity(array.len());
346            for x in array.iter() {
347                output_attributes.push(SecretAttributes::try_from(*x)?);
348            }
349
350            // TODO: Hardcoded to be empty for now because any changes
351            // to the C layer requires an API change.
352            // This change was necessary to implement Enrollment since the info string is not
353            // left blank for that protocol, but is blank for the XX key exchange pattern.
354            // If we agree to change the API, then this wouldn't be hardcoded but received
355            // from a parameter in the C API. Elixir and other consumers would be expected
356            // to pass the appropriate flag. The other option is to not expose the vault
357            // directly since it may confuse users about what to pass here and
358            // I don't like the idea of yelling at consumers through comments.
359            // Instead the vault could be encapsulated in channels and key exchanges.
360            // Either way, I don't want to change the API until this decision is finalized.
361            let hkdf_output = entry
362                .vault
363                .hkdf_sha256(&salt_key_id, b"", ikm_key_id, output_attributes)
364                .await?;
365
366            let hkdf_output: Vec<SecretKeyHandle> =
367                join_all(hkdf_output.into_iter().map(|x| entry.insert(x))).await;
368
369            unsafe {
370                std::ptr::copy_nonoverlapping(
371                    hkdf_output.as_ptr(),
372                    derived_outputs,
373                    derived_outputs_count,
374                )
375            };
376            Ok::<(), Error>(())
377        })?;
378        Ok(())
379    })
380}
381
382///   Encrypt a payload using AES-GCM.
383#[no_mangle]
384pub extern "C" fn ockam_vault_aead_aes_gcm_encrypt(
385    context: FfiVaultFatPointer,
386    secret: SecretKeyHandle,
387    nonce: u64,
388    additional_data: *const u8,
389    additional_data_length: u32,
390    plaintext: *const u8,
391    plaintext_length: u32,
392    ciphertext_and_tag: &mut u8,
393    ciphertext_and_tag_size: u32,
394    ciphertext_and_tag_length: &mut u32,
395) -> FfiOckamError {
396    *ciphertext_and_tag_length = 0;
397    handle_panics(|| {
398        check_buffer!(additional_data);
399        check_buffer!(plaintext);
400
401        let additional_data = unsafe {
402            core::slice::from_raw_parts(additional_data, additional_data_length as usize)
403        };
404
405        let plaintext =
406            unsafe { core::slice::from_raw_parts(plaintext, plaintext_length as usize) };
407
408        block_future(async move {
409            let entry = get_vault_entry(context).await?;
410            let key_id = entry.get(secret).await?;
411            let mut nonce_vec = vec![0; 12 - 8];
412            nonce_vec.extend_from_slice(&nonce.to_be_bytes());
413            let ciphertext = entry
414                .vault
415                .aead_aes_gcm_encrypt(&key_id, plaintext, &nonce_vec, additional_data)
416                .await?;
417
418            if ciphertext_and_tag_size < ciphertext.len() as u32 {
419                return Err(FfiError::BufferTooSmall.into());
420            }
421            *ciphertext_and_tag_length = ciphertext.len() as u32;
422
423            unsafe {
424                std::ptr::copy_nonoverlapping(
425                    ciphertext.as_ptr(),
426                    ciphertext_and_tag,
427                    ciphertext.len(),
428                )
429            };
430            Ok::<(), Error>(())
431        })?;
432        Ok(())
433    })
434}
435
436/// Decrypt a payload using AES-GCM.
437#[no_mangle]
438pub extern "C" fn ockam_vault_aead_aes_gcm_decrypt(
439    context: FfiVaultFatPointer,
440    secret: SecretKeyHandle,
441    nonce: u64,
442    additional_data: *const u8,
443    additional_data_length: u32,
444    ciphertext_and_tag: *const u8,
445    ciphertext_and_tag_length: u32,
446    plaintext: &mut u8,
447    plaintext_size: u32,
448    plaintext_length: &mut u32,
449) -> FfiOckamError {
450    *plaintext_length = 0;
451    handle_panics(|| {
452        check_buffer!(ciphertext_and_tag, ciphertext_and_tag_length);
453        check_buffer!(additional_data);
454
455        let additional_data = unsafe {
456            core::slice::from_raw_parts(additional_data, additional_data_length as usize)
457        };
458
459        let ciphertext_and_tag = unsafe {
460            core::slice::from_raw_parts(ciphertext_and_tag, ciphertext_and_tag_length as usize)
461        };
462
463        block_future(async move {
464            let entry = get_vault_entry(context).await?;
465            let key_id = entry.get(secret).await?;
466            let mut nonce_vec = vec![0; 12 - 8];
467            nonce_vec.extend_from_slice(&nonce.to_be_bytes());
468            let plain = entry
469                .vault
470                .aead_aes_gcm_decrypt(&key_id, ciphertext_and_tag, &nonce_vec, additional_data)
471                .await?;
472            if plaintext_size < plain.len() as u32 {
473                return Err(FfiError::BufferTooSmall.into());
474            }
475            *plaintext_length = plain.len() as u32;
476
477            unsafe { std::ptr::copy_nonoverlapping(plain.as_ptr(), plaintext, plain.len()) };
478            Ok::<(), Error>(())
479        })?;
480        Ok(())
481    })
482}
483
484/// De-initialize an Ockam Vault.
485#[no_mangle]
486pub extern "C" fn ockam_vault_deinit(context: FfiVaultFatPointer) -> FfiOckamError {
487    handle_panics(|| {
488        block_future(async move {
489            match context.vault_type() {
490                FfiVaultType::Software => {
491                    let handle = context.handle() as usize;
492                    let mut v = SOFTWARE_VAULTS.write().await;
493                    if handle < v.len() {
494                        v.remove(handle);
495                        Ok(())
496                    } else {
497                        Err(FfiError::VaultNotFound)
498                    }
499                }
500            }
501        })?;
502        Ok(())
503    })
504}
505
506fn handle_panics<F>(f: F) -> FfiOckamError
507where
508    F: FnOnce() -> StdResult<(), FfiOckamError>,
509{
510    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
511    match result {
512        // No error.
513        Ok(Ok(())) => FfiOckamError::none(),
514        // Failed with a specific ockam error:
515        Ok(Err(e)) => e,
516        // Panicked
517        Err(e) => {
518            // Force an abort if either:
519            //
520            // - `e` panics during its `Drop` impl.
521            // - `FfiOckamError::from(FfiError)` panics.
522            //
523            // Both of these are extremely unlikely, but possible.
524            let panic_guard = AbortOnDrop;
525            drop(e);
526            let ret = FfiOckamError::from(FfiError::UnexpectedPanic);
527            core::mem::forget(panic_guard);
528            ret
529        }
530    }
531}
532
/// Guard whose destructor prints a message and aborts the process.
///
/// Create it immediately before a section that must not panic and `mem::forget` it once that
/// section has finished. If the section unwinds instead, the guard is dropped during unwinding
/// and the process aborts.
struct AbortOnDrop;

impl Drop for AbortOnDrop {
    fn drop(&mut self) {
        eprintln!("Panic from error drop, aborting!");
        std::process::abort()
    }
}