use crate::vault_types::{FfiSecretAttributes, SecretKeyHandle};
use crate::{check_buffer, FfiError, FfiOckamError};
use crate::{FfiVaultFatPointer, FfiVaultType};
use core::{future::Future, result::Result as StdResult, slice};
use futures::future::join_all;
use lazy_static::lazy_static;
use ockam_core::compat::collections::BTreeMap;
use ockam_core::compat::sync::Arc;
use ockam_core::vault::{
AsymmetricVault, Hasher, KeyId, PublicKey, Secret, SecretAttributes, SecretKey, SecretVault,
SymmetricVault,
};
use ockam_core::{Error, Result};
use ockam_vault::Vault;
use tokio::{runtime::Runtime, sync::RwLock, task};
/// Per-vault table translating the integer secret handles exposed over FFI into vault `KeyId`s.
///
/// Handles come from a counter that starts at 1, so `0` is never issued, and a removed
/// handle is not reissued unless the `u64` counter overflows.
#[derive(Default)]
struct SecretsMapping {
    /// Live handle -> key id associations.
    mapping: BTreeMap<u64, KeyId>,
    /// Most recently issued handle (`0` before the first insertion).
    last_index: u64,
}

impl SecretsMapping {
    /// Stores `key_id` under a freshly allocated handle and returns that handle.
    fn insert(&mut self, key_id: KeyId) -> u64 {
        let handle = self.last_index + 1;
        self.last_index = handle;
        self.mapping.insert(handle, key_id);
        handle
    }

    /// Returns a copy of the key id registered under `index`.
    ///
    /// Fails with `FfiError::EntryNotFound` (converted to `Error`) for unknown handles.
    fn get(&self, index: u64) -> Result<KeyId> {
        self.mapping
            .get(&index)
            .cloned()
            .ok_or_else(|| FfiError::EntryNotFound.into())
    }

    /// Removes the association for `index` and returns its key id.
    ///
    /// Fails with `FfiError::EntryNotFound` (converted to `Error`) for unknown handles.
    fn take(&mut self, index: u64) -> Result<KeyId> {
        match self.mapping.remove(&index) {
            Some(key_id) => Ok(key_id),
            None => Err(FfiError::EntryNotFound.into()),
        }
    }
}
/// A software vault paired with the handle table for the secrets created through it.
///
/// Clones share the handle table through the `Arc`; NOTE(review): presumably cloning
/// `Vault` also shares its key storage — confirm against `ockam_vault`.
#[derive(Clone, Default)]
struct VaultEntry {
    vault: Vault,
    secrets_mapping: Arc<RwLock<SecretsMapping>>,
}

impl VaultEntry {
    /// Registers `key_id` in the handle table and returns its new FFI handle.
    async fn insert(&self, key_id: KeyId) -> u64 {
        let mut table = self.secrets_mapping.write().await;
        table.insert(key_id)
    }

    /// Resolves an FFI handle to its key id without removing it.
    async fn get(&self, index: u64) -> Result<KeyId> {
        let table = self.secrets_mapping.read().await;
        table.get(index)
    }

    /// Resolves an FFI handle to its key id and removes it from the table.
    async fn take(&self, index: u64) -> Result<KeyId> {
        let mut table = self.secrets_mapping.write().await;
        table.take(index)
    }
}
lazy_static! {
    /// Every software vault created by `ockam_vault_default_init`; a vault's handle is its
    /// index in this vector.
    static ref SOFTWARE_VAULTS: RwLock<Vec<VaultEntry>> = RwLock::new(Vec::new());
    /// Shared Tokio runtime that drives the async vault API from the blocking C entry
    /// points. Built lazily on first use; panics there if the runtime cannot be created.
    static ref RUNTIME: Arc<Runtime> = Arc::new(Runtime::new().unwrap());
}
/// Returns a new reference-counted handle to the process-wide runtime.
fn get_runtime() -> Arc<Runtime> {
    Arc::clone(&*RUNTIME)
}
/// Synchronously drives `f` to completion on the shared runtime and returns its output.
///
/// The future runs inside a fresh `LocalSet`, so it does not need to be `Send`. The call is
/// wrapped in `task::block_in_place` so that blocking is tolerated if this is invoked from
/// a runtime worker thread.
fn block_future<F>(f: F) -> F::Output
where
    F: Future,
{
    let runtime = get_runtime();
    task::block_in_place(move || task::LocalSet::new().block_on(&runtime, f))
}
/// Looks up the vault addressed by `context` and returns a clone of its entry.
///
/// Fails with `FfiError::VaultNotFound` (converted to `Error`) when the handle is out of range.
async fn get_vault_entry(context: FfiVaultFatPointer) -> Result<VaultEntry> {
    match context.vault_type() {
        FfiVaultType::Software => {
            let index = context.handle() as usize;
            let vaults = SOFTWARE_VAULTS.read().await;
            match vaults.get(index) {
                Some(entry) => Ok(entry.clone()),
                None => Err(FfiError::VaultNotFound.into()),
            }
        }
    }
}
#[no_mangle]
pub extern "C" fn ockam_vault_default_init(context: &mut FfiVaultFatPointer) -> FfiOckamError {
handle_panics(|| {
let handle = block_future(async move {
let mut write_lock = SOFTWARE_VAULTS.write().await;
write_lock.push(Default::default());
write_lock.len() - 1
});
*context = FfiVaultFatPointer::new(handle as u64, FfiVaultType::Software);
Ok(())
})
}
#[no_mangle]
pub extern "C" fn ockam_vault_sha256(
context: FfiVaultFatPointer,
input: *const u8,
input_length: u32,
digest: *mut u8,
) -> FfiOckamError {
handle_panics(|| {
check_buffer!(input);
check_buffer!(digest);
let input = unsafe { core::slice::from_raw_parts(input, input_length as usize) };
let res = block_future(async move {
let entry = get_vault_entry(context).await?;
entry.vault.sha256(input).await
})?;
unsafe {
std::ptr::copy_nonoverlapping(res.as_ptr(), digest, res.len());
}
Ok(())
})
}
#[no_mangle]
pub extern "C" fn ockam_vault_secret_generate(
context: FfiVaultFatPointer,
secret: &mut SecretKeyHandle,
attributes: FfiSecretAttributes,
) -> FfiOckamError {
handle_panics(|| {
*secret = block_future(async move {
let entry = get_vault_entry(context).await?;
let atts = attributes.try_into()?;
let key_id = entry.vault.secret_generate(atts).await?;
let index = entry.insert(key_id).await;
Ok::<u64, Error>(index)
})?;
Ok(())
})
}
#[no_mangle]
pub extern "C" fn ockam_vault_secret_import(
context: FfiVaultFatPointer,
secret: &mut SecretKeyHandle,
attributes: FfiSecretAttributes,
input: *mut u8,
input_length: u32,
) -> FfiOckamError {
handle_panics(|| {
check_buffer!(input, input_length);
*secret = block_future(async move {
let entry = get_vault_entry(context).await?;
let atts = attributes.try_into()?;
let secret_data = unsafe { core::slice::from_raw_parts(input, input_length as usize) };
let secret = Secret::Key(SecretKey::new(secret_data.to_vec()));
let key_id = entry.vault.secret_import(secret, atts).await?;
let index = entry.insert(key_id).await;
Ok::<u64, Error>(index)
})?;
Ok(())
})
}
#[no_mangle]
pub extern "C" fn ockam_vault_secret_export(
context: FfiVaultFatPointer,
secret: SecretKeyHandle,
output_buffer: *mut u8,
output_buffer_size: u32,
output_buffer_length: &mut u32,
) -> FfiOckamError {
*output_buffer_length = 0;
handle_panics(|| {
block_future(async move {
let entry = get_vault_entry(context).await?;
let key_id = entry.get(secret).await?;
let key = entry.vault.secret_export(&key_id).await?;
if output_buffer_size < key.try_as_key()?.as_ref().len() as u32 {
return Err(FfiError::BufferTooSmall.into());
}
*output_buffer_length = key.try_as_key()?.as_ref().len() as u32;
unsafe {
std::ptr::copy_nonoverlapping(
key.try_as_key()?.as_ref().as_ptr(),
output_buffer,
key.try_as_key()?.as_ref().len(),
);
};
Ok::<(), Error>(())
})?;
Ok(())
})
}
#[no_mangle]
pub extern "C" fn ockam_vault_secret_publickey_get(
context: FfiVaultFatPointer,
secret: SecretKeyHandle,
output_buffer: *mut u8,
output_buffer_size: u32,
output_buffer_length: &mut u32,
) -> FfiOckamError {
*output_buffer_length = 0;
handle_panics(|| {
block_future(async move {
let entry = get_vault_entry(context).await?;
let key_id = entry.get(secret).await?;
let key = entry.vault.secret_public_key_get(&key_id).await?;
if output_buffer_size < key.data().len() as u32 {
return Err(FfiError::BufferTooSmall.into());
}
*output_buffer_length = key.data().len() as u32;
unsafe {
std::ptr::copy_nonoverlapping(key.data().as_ptr(), output_buffer, key.data().len());
};
Ok::<(), Error>(())
})?;
Ok(())
})
}
#[no_mangle]
pub extern "C" fn ockam_vault_secret_attributes_get(
context: FfiVaultFatPointer,
secret: SecretKeyHandle,
attributes: &mut FfiSecretAttributes,
) -> FfiOckamError {
handle_panics(|| {
*attributes = block_future(async move {
let entry = get_vault_entry(context).await?;
let key_id = entry.get(secret).await?;
let atts = entry.vault.secret_attributes_get(&key_id).await?;
Ok::<FfiSecretAttributes, Error>(atts.into())
})?;
Ok(())
})
}
#[no_mangle]
pub extern "C" fn ockam_vault_secret_destroy(
context: FfiVaultFatPointer,
secret: SecretKeyHandle,
) -> FfiOckamError {
match block_future(async move {
let entry = get_vault_entry(context).await?;
let key_id = entry.take(secret).await?;
entry.vault.secret_destroy(key_id).await?;
Ok::<(), Error>(())
}) {
Ok(_) => FfiOckamError::none(),
Err(err) => err.into(),
}
}
#[no_mangle]
pub extern "C" fn ockam_vault_ecdh(
context: FfiVaultFatPointer,
secret: SecretKeyHandle,
peer_publickey: *const u8,
peer_publickey_length: u32,
shared_secret: &mut SecretKeyHandle,
) -> FfiOckamError {
handle_panics(|| {
check_buffer!(peer_publickey, peer_publickey_length);
let peer_publickey =
unsafe { core::slice::from_raw_parts(peer_publickey, peer_publickey_length as usize) };
*shared_secret = block_future(async move {
let entry = get_vault_entry(context).await?;
let key_id = entry.get(secret).await?;
let atts = entry.vault.secret_attributes_get(&key_id).await?;
let pubkey = PublicKey::new(peer_publickey.to_vec(), atts.stype());
let shared_ctx = entry.vault.ec_diffie_hellman(&key_id, &pubkey).await?;
let index = entry.insert(shared_ctx).await;
Ok::<u64, Error>(index)
})?;
Ok(())
})
}
#[no_mangle]
pub extern "C" fn ockam_vault_hkdf_sha256(
context: FfiVaultFatPointer,
salt: SecretKeyHandle,
input_key_material: *const SecretKeyHandle,
derived_outputs_attributes: *const FfiSecretAttributes,
derived_outputs_count: u8,
derived_outputs: *mut SecretKeyHandle,
) -> FfiOckamError {
handle_panics(|| {
let derived_outputs_count = derived_outputs_count as usize;
block_future(async move {
let entry = get_vault_entry(context).await?;
let salt_key_id = entry.get(salt).await?;
let ikm_key_id = if input_key_material.is_null() {
None
} else {
let ctx = unsafe { entry.get(*input_key_material).await? };
Some(ctx)
};
let ikm_key_id = ikm_key_id.as_ref();
let array: &[FfiSecretAttributes] =
unsafe { slice::from_raw_parts(derived_outputs_attributes, derived_outputs_count) };
let mut output_attributes = Vec::<SecretAttributes>::with_capacity(array.len());
for x in array.iter() {
output_attributes.push(SecretAttributes::try_from(*x)?);
}
let hkdf_output = entry
.vault
.hkdf_sha256(&salt_key_id, b"", ikm_key_id, output_attributes)
.await?;
let hkdf_output: Vec<SecretKeyHandle> =
join_all(hkdf_output.into_iter().map(|x| entry.insert(x))).await;
unsafe {
std::ptr::copy_nonoverlapping(
hkdf_output.as_ptr(),
derived_outputs,
derived_outputs_count,
)
};
Ok::<(), Error>(())
})?;
Ok(())
})
}
#[no_mangle]
pub extern "C" fn ockam_vault_aead_aes_gcm_encrypt(
context: FfiVaultFatPointer,
secret: SecretKeyHandle,
nonce: u16,
additional_data: *const u8,
additional_data_length: u32,
plaintext: *const u8,
plaintext_length: u32,
ciphertext_and_tag: &mut u8,
ciphertext_and_tag_size: u32,
ciphertext_and_tag_length: &mut u32,
) -> FfiOckamError {
*ciphertext_and_tag_length = 0;
handle_panics(|| {
check_buffer!(additional_data);
check_buffer!(plaintext);
let additional_data = unsafe {
core::slice::from_raw_parts(additional_data, additional_data_length as usize)
};
let plaintext =
unsafe { core::slice::from_raw_parts(plaintext, plaintext_length as usize) };
block_future(async move {
let entry = get_vault_entry(context).await?;
let key_id = entry.get(secret).await?;
let mut nonce_vec = vec![0; 12 - 2];
nonce_vec.extend_from_slice(&nonce.to_be_bytes());
let ciphertext = entry
.vault
.aead_aes_gcm_encrypt(&key_id, plaintext, &nonce_vec, additional_data)
.await?;
if ciphertext_and_tag_size < ciphertext.len() as u32 {
return Err(FfiError::BufferTooSmall.into());
}
*ciphertext_and_tag_length = ciphertext.len() as u32;
unsafe {
std::ptr::copy_nonoverlapping(
ciphertext.as_ptr(),
ciphertext_and_tag,
ciphertext.len(),
)
};
Ok::<(), Error>(())
})?;
Ok(())
})
}
#[no_mangle]
pub extern "C" fn ockam_vault_aead_aes_gcm_decrypt(
context: FfiVaultFatPointer,
secret: SecretKeyHandle,
nonce: u16,
additional_data: *const u8,
additional_data_length: u32,
ciphertext_and_tag: *const u8,
ciphertext_and_tag_length: u32,
plaintext: &mut u8,
plaintext_size: u32,
plaintext_length: &mut u32,
) -> FfiOckamError {
*plaintext_length = 0;
handle_panics(|| {
check_buffer!(ciphertext_and_tag, ciphertext_and_tag_length);
check_buffer!(additional_data);
let additional_data = unsafe {
core::slice::from_raw_parts(additional_data, additional_data_length as usize)
};
let ciphertext_and_tag = unsafe {
core::slice::from_raw_parts(ciphertext_and_tag, ciphertext_and_tag_length as usize)
};
block_future(async move {
let entry = get_vault_entry(context).await?;
let key_id = entry.get(secret).await?;
let mut nonce_vec = vec![0; 12 - 2];
nonce_vec.extend_from_slice(&nonce.to_be_bytes());
let plain = entry
.vault
.aead_aes_gcm_decrypt(&key_id, ciphertext_and_tag, &nonce_vec, additional_data)
.await?;
if plaintext_size < plain.len() as u32 {
return Err(FfiError::BufferTooSmall.into());
}
*plaintext_length = plain.len() as u32;
unsafe { std::ptr::copy_nonoverlapping(plain.as_ptr(), plaintext, plain.len()) };
Ok::<(), Error>(())
})?;
Ok(())
})
}
#[no_mangle]
pub extern "C" fn ockam_vault_deinit(context: FfiVaultFatPointer) -> FfiOckamError {
handle_panics(|| {
block_future(async move {
match context.vault_type() {
FfiVaultType::Software => {
let handle = context.handle() as usize;
let mut v = SOFTWARE_VAULTS.write().await;
if handle < v.len() {
v.remove(handle);
Ok(())
} else {
Err(FfiError::VaultNotFound)
}
}
}
})?;
Ok(())
})
}
/// Runs `f` and converts its outcome into an `FfiOckamError` for the C caller.
///
/// Returns `FfiOckamError::none()` when `f` returns `Ok(())`, and the error itself when `f`
/// returns `Err`. If `f` panics, the panic is caught by `catch_unwind` and reported as
/// `FfiError::UnexpectedPanic`, so the unwind does not cross the `extern "C"` boundary.
fn handle_panics<F>(f: F) -> FfiOckamError
where
    F: FnOnce() -> StdResult<(), FfiOckamError>,
{
    // `catch_unwind` requires `UnwindSafe`; `AssertUnwindSafe` asserts it for closures
    // that capture `&mut` out-parameters.
    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
    match result {
        Ok(Ok(())) => FfiOckamError::none(),
        Ok(Err(e)) => e,
        Err(e) => {
            // Dropping the panic payload (or building the error) can itself panic.
            // While `panic_guard` is alive, such a second panic unwinds through it, and its
            // `Drop` aborts the process instead of letting the unwind escape into C.
            let panic_guard = AbortOnDrop;
            drop(e);
            let ret = FfiOckamError::from(FfiError::UnexpectedPanic);
            // Both steps succeeded: disarm the guard without running its `Drop`.
            core::mem::forget(panic_guard);
            ret
        }
    }
}
/// Guard whose destructor aborts the process.
///
/// `handle_panics` arms one around code that may panic a second time, and disarms it with
/// `core::mem::forget` once that code has finished normally.
struct AbortOnDrop;

impl Drop for AbortOnDrop {
    /// Reports the failure on stderr and aborts; this never returns.
    fn drop(&mut self) {
        eprintln!("Panic from error drop, aborting!");
        std::process::abort()
    }
}