type Result<T> = std::result::Result<T, wapc::errors::Error>;
use std::sync::Arc;
use std::time::Duration;
use crossbeam::channel::{Receiver as SyncReceiver, SendTimeoutError, Sender as SyncSender};
use rusty_pool::ThreadPool;
use tokio::sync::oneshot::Sender as OneshotSender;
use wapc::WapcHost;
use crate::errors::Error;
#[must_use]
pub struct HostPool {
pub name: String,
pool: Option<ThreadPool>,
factory: Arc<dyn Fn() -> WapcHost + Send + Sync + 'static>,
max_threads: usize,
max_wait: Duration,
max_idle: Duration,
tx: SyncSender<WorkerMessage>,
rx: SyncReceiver<WorkerMessage>,
}
impl std::fmt::Debug for HostPool {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("HostPool")
.field("name", &self.name)
.field("tx", &self.tx)
.field("rx", &self.rx)
.finish()
}
}
type WorkerMessage = (
OneshotSender<std::result::Result<Vec<u8>, wapc::errors::Error>>,
String,
Vec<u8>,
);
impl HostPool {
pub fn new<N, F>(
name: N,
factory: F,
min_threads: usize,
max_threads: usize,
max_wait: Duration,
max_idle: Duration,
) -> Self
where
N: AsRef<str>,
F: Fn() -> WapcHost + Send + Sync + 'static,
{
debug!("Creating new wapc host pool with size {}", max_threads);
let arcfn = Arc::new(factory);
let pool = rusty_pool::Builder::new()
.name(name.as_ref().to_owned())
.core_size(min_threads)
.max_size(max_threads)
.keep_alive(Duration::from_millis(0))
.build();
let (tx, rx) = crossbeam::channel::bounded::<WorkerMessage>(1);
let pool = Self {
name: name.as_ref().to_owned(),
factory: arcfn,
pool: Some(pool),
max_threads,
max_wait,
max_idle,
tx,
rx,
};
for _ in 0..min_threads {
pool.spawn(None).unwrap();
}
pool
}
#[must_use]
pub fn num_active_workers(&self) -> usize {
match &self.pool {
Some(pool) => pool.get_current_worker_count(),
None => 0,
}
}
fn spawn(&self, max_idle: Option<Duration>) -> Result<()> {
match &self.pool {
Some(pool) => {
let name = self.name.clone();
let i = pool.get_current_worker_count();
let factory = self.factory.clone();
let rx = self.rx.clone();
pool.execute(move || {
trace!("Host thread {}.{} started...", name, i);
let host = factory();
loop {
let message = match max_idle {
None => rx.recv().map_err(|e| e.to_string()),
Some(duration) => rx.recv_timeout(duration).map_err(|e| e.to_string()),
};
if let Err(e) = message {
debug!("Host thread {}.{} closing: {}", name, i, e);
break;
}
let (tx, op, payload) = message.unwrap();
trace!(
"Host thread {}.{} received call for {} with {} byte payload",
name,
i,
op,
payload.len()
);
let result = host.call(&op, &payload);
if tx.send(result).is_err() {
error!("Host thread {}.{} failed when returning a value...", name, i);
}
}
trace!("Host thread {}.{} stopped.", name, i);
});
Ok(())
}
None => Err(Error::NoPool.into()),
}
}
pub async fn call<T: AsRef<str> + Sync + Send>(&self, op: T, payload: Vec<u8>) -> Result<Vec<u8>> {
let (tx, rx) = tokio::sync::oneshot::channel();
let result = match self
.tx
.send_timeout((tx, op.as_ref().to_owned(), payload), self.max_wait)
{
Ok(_) => Ok(()),
Err(e) => {
let args = match e {
SendTimeoutError::Timeout(args) => {
debug!("Timeout on pool '{}'", self.name);
args
}
SendTimeoutError::Disconnected(args) => {
warn!("Pool worker disconnected on pool '{}'", self.name);
args
}
};
if self.num_active_workers() < self.max_threads {
if let Err(e) = self.spawn(Some(self.max_idle)) {
error!("Error spawning worker for host pool '{}': {}", self.name, e);
};
}
self.tx.send(args)
}
};
if let Err(e) = result {
return Err(wapc::errors::Error::General(e.to_string()));
}
match rx.await {
Ok(res) => res,
Err(e) => Err(wapc::errors::Error::General(e.to_string())),
}
}
pub fn shutdown(&mut self) -> Result<()> {
let pool = self
.pool
.take()
.ok_or_else(|| wapc::errors::Error::from(crate::errors::Error::NoPool))?;
pool.shutdown_join();
Ok(())
}
}
#[must_use]
pub struct HostPoolBuilder {
name: Option<String>,
factory: Option<Box<dyn Fn() -> WapcHost + Send + Sync + 'static>>,
min_threads: usize,
max_threads: usize,
max_wait: Duration,
max_idle: Duration,
}
impl std::fmt::Debug for HostPoolBuilder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("HostPoolBuilder")
.field("name", &self.name)
.field("factory", if self.factory.is_some() { &"Some(Fn)" } else { &"None" })
.field("min_threads", &self.min_threads)
.field("max_threads", &self.max_threads)
.field("max_wait", &self.max_wait)
.field("max_idle", &self.max_idle)
.finish()
}
}
impl Default for HostPoolBuilder {
fn default() -> Self {
Self {
name: None,
factory: None,
min_threads: 1,
max_threads: 2,
max_wait: Duration::from_millis(100),
max_idle: Duration::from_secs(5 * 60),
}
}
}
impl HostPoolBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn name<T: AsRef<str>>(mut self, name: T) -> Self {
self.name = Some(name.as_ref().to_owned());
self
}
pub fn factory<F>(mut self, factory: F) -> Self
where
F: Fn() -> WapcHost + Send + Sync + 'static,
{
self.factory = Some(Box::new(factory));
self
}
pub fn min_threads(mut self, min: usize) -> Self {
self.min_threads = min;
self
}
pub fn max_threads(mut self, max: usize) -> Self {
self.max_threads = max;
self
}
pub fn max_idle(mut self, timeout: Duration) -> Self {
self.max_idle = timeout;
self
}
pub fn max_wait(mut self, duration: Duration) -> Self {
self.max_wait = duration;
self
}
pub fn build(mut self) -> HostPool {
#[allow(clippy::expect_used)]
let factory = self
.factory
.take()
.expect("A waPC host pool must have a factory function.");
HostPool::new(
self.name.unwrap_or_else(|| "waPC host pool".to_owned()),
factory,
self.min_threads,
self.max_threads,
self.max_wait,
self.max_idle,
)
}
}
#[cfg(test)]
mod tests {
use std::time::{Duration, Instant};
use tokio::join;
use wapc::WebAssemblyEngineProvider;
use super::*;
#[test_log::test(tokio::test)]
async fn test_basic() -> Result<()> {
#[derive(Default)]
struct Test {
host: Option<Arc<wapc::ModuleState>>,
}
impl WebAssemblyEngineProvider for Test {
fn init(
&mut self,
host: Arc<wapc::ModuleState>,
) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
self.host = Some(host);
Ok(())
}
fn call(
&mut self,
op_length: i32,
msg_length: i32,
) -> std::result::Result<i32, Box<dyn std::error::Error + Send + Sync>> {
println!("op len:{}", op_length);
println!("msg len:{}", msg_length);
std::thread::sleep(Duration::from_millis(100));
let host = self.host.take().unwrap();
host.set_guest_response(b"{}".to_vec());
self.host.replace(host);
Ok(1)
}
fn replace(&mut self, _bytes: &[u8]) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
}
let pool = HostPoolBuilder::new()
.name("test")
.factory(move || WapcHost::new(Box::new(Test::default()), None).unwrap())
.min_threads(5)
.max_threads(5)
.build();
let now = Instant::now();
let result = pool.call("test", b"hello world".to_vec()).await.unwrap();
assert_eq!(result, b"{}");
let _res = join!(
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
);
let duration = now.elapsed();
println!("Took {}ms", duration.as_millis());
assert!(duration.as_millis() < 600);
Ok(())
}
#[test_log::test(tokio::test)]
async fn test_elasticity() -> Result<()> {
#[derive(Default)]
struct Test {
host: Option<Arc<wapc::ModuleState>>,
}
impl WebAssemblyEngineProvider for Test {
fn init(
&mut self,
host: Arc<wapc::ModuleState>,
) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
self.host = Some(host);
Ok(())
}
fn call(&mut self, _: i32, _: i32) -> std::result::Result<i32, Box<dyn std::error::Error + Send + Sync>> {
std::thread::sleep(Duration::from_millis(100));
let host = self.host.take().unwrap();
host.set_guest_response(b"{}".to_vec());
self.host.replace(host);
Ok(1)
}
fn replace(&mut self, _bytes: &[u8]) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
}
let pool = HostPoolBuilder::new()
.name("test")
.factory(move || WapcHost::new(Box::new(Test::default()), None).unwrap())
.min_threads(1)
.max_threads(5)
.max_wait(Duration::from_millis(10))
.max_idle(Duration::from_secs(1))
.build();
assert_eq!(pool.num_active_workers(), 1);
let _ = futures::future::join_all(vec![
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
])
.await;
assert_eq!(pool.num_active_workers(), 2);
let _ = futures::future::join_all(vec![
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
pool.call("test", b"hello world".to_vec()),
])
.await;
assert_eq!(pool.num_active_workers(), 5);
std::thread::sleep(Duration::from_millis(1500));
assert_eq!(pool.num_active_workers(), 1);
Ok(())
}
}