use crate::provider_core::ProviderCore;
use ipc_queue::Identified;
use std::cell::RefCell;
use std::mem;
use std::os::fortanix_sgx::usercalls::alloc::{User, UserSafe};
use std::os::fortanix_sgx::usercalls::raw::{Usercall, UsercallNrs};
/// Implemented by values whose user-space memory can be released with [`batch_drop`].
///
/// The actual requirement lives in a private supertrait and every eligible type is covered by
/// the blanket impl below. Code outside this module can use the trait as a bound but cannot
/// implement it for additional types.
pub trait BatchDroppable: private::BatchDroppable {}

impl<T> BatchDroppable for T where T: private::BatchDroppable {}
pub fn batch_drop<T: BatchDroppable>(t: T) {
t.batch_drop();
}
mod private {
    use super::*;

    /// Number of queued usercalls that triggers a send attempt.
    const BATCH_SIZE: usize = 8;

    /// Per-thread queue of `free` usercalls waiting to be sent together.
    struct BatchDropProvider {
        core: ProviderCore,
        /// Usercalls that have been assigned an id but not yet sent, oldest first.
        deferred: Vec<Identified<Usercall>>,
    }

    impl BatchDropProvider {
        pub fn new() -> Self {
            Self {
                core: ProviderCore::new(None),
                deferred: Vec::with_capacity(BATCH_SIZE),
            }
        }

        /// Sends a prefix of `deferred` and returns how many usercalls were sent.
        ///
        /// It first tries `try_send_multiple_usercalls`. If that sends nothing, it sends the
        /// first element with `send_usercall` and reports 1. So the result is at least 1
        /// for any non-empty slice, which guarantees the callers' loops make progress.
        /// Returns 0 only for an empty slice.
        fn make_progress(&self, deferred: &[Identified<Usercall>]) -> usize {
            if deferred.is_empty() {
                return 0;
            }
            let sent = self.core.try_send_multiple_usercalls(deferred);
            if sent == 0 {
                self.core.send_usercall(deferred[0]);
                return 1;
            }
            sent
        }

        /// Assigns an id to `u` and queues it.
        ///
        /// Once `BATCH_SIZE` calls are queued, it sends as many as possible and removes them
        /// from the front of the queue. At least one is always sent (see `make_progress`), so
        /// the queue never grows beyond `BATCH_SIZE` entries.
        fn maybe_send_usercall(&mut self, u: Usercall) {
            self.deferred.push(self.core.assign_id(u));
            if self.deferred.len() < BATCH_SIZE {
                return;
            }
            let sent = self.make_progress(&self.deferred);
            self.deferred.drain(..sent);
        }

        /// Takes ownership of `buf` without freeing it and queues a `free` usercall for its
        /// memory, passing the pointer, size in bytes and alignment.
        pub fn free<T: UserSafe + ?Sized>(&mut self, buf: User<T>) {
            let ptr = buf.into_raw();
            // SAFETY: `ptr` was just produced by `User::into_raw` from a live allocation, so it
            // is non-null and carries the pointer metadata (e.g. slice length) of `T`.
            // `size_of_val` only needs a shared reference to compute the size from the type
            // and metadata. A `&mut` (as used previously) would needlessly assert exclusive
            // access to user memory.
            let size = unsafe { mem::size_of_val(&*ptr) };
            let alignment = T::align_of();
            let ptr = ptr as *mut u8;
            let u = Usercall(UsercallNrs::free as _, ptr as _, size as _, alignment as _, 0);
            self.maybe_send_usercall(u);
        }
    }

    impl Drop for BatchDropProvider {
        fn drop(&mut self) {
            // Flush every still-queued usercall so no `free` is lost when this provider
            // (i.e. the owning thread's `PROVIDER`) is destroyed.
            let mut sent = 0;
            while sent < self.deferred.len() {
                sent += self.make_progress(&self.deferred[sent..]);
            }
        }
    }

    std::thread_local! {
        static PROVIDER: RefCell<BatchDropProvider> = RefCell::new(BatchDropProvider::new());
    }

    /// Sealed supertrait of the public `BatchDroppable`. It is unreachable outside this file,
    /// so only the impls written here can be batch-dropped.
    pub trait BatchDroppable {
        fn batch_drop(self);
    }

    impl<T: UserSafe + ?Sized> BatchDroppable for User<T> {
        fn batch_drop(self) {
            // `try_with` rather than `with`: once this thread's `PROVIDER` is being or has been
            // destroyed (e.g. `batch_drop` called from another thread-local's destructor),
            // `with` would panic. On that error path the closure is dropped without running,
            // which drops `self`. The buffer is then released by `User`'s own `Drop` impl
            // instead of being batched.
            let _ = PROVIDER.try_with(move |p| p.borrow_mut().free(self));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::batch_drop;
    use std::os::fortanix_sgx::usercalls::alloc::User;
    use std::thread;

    /// Allocates an uninitialized user-space byte buffer with a random length in `0..256`
    /// and hands it to `batch_drop`.
    fn drop_random_buffer() {
        let len = rand::random::<usize>() % 256;
        batch_drop(User::<[u8]>::uninitialized(len));
    }

    #[test]
    fn basic() {
        (0..100).for_each(|_| drop_random_buffer());
    }

    #[test]
    fn multiple_threads() {
        const THREADS: usize = 16;
        let workers: Vec<_> = (0..THREADS)
            .map(|_| thread::spawn(|| (0..1000).for_each(|_| drop_random_buffer())))
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
    }
}