/// Result type used throughout this module; errors are reported as `wapc::errors::Error`.
type Result<T> = std::result::Result<T, wapc::errors::Error>;
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use crossbeam::channel::{Receiver as SyncReceiver, SendTimeoutError, Sender as SyncSender};
7use rusty_pool::ThreadPool;
8use tokio::sync::oneshot::Sender as OneshotSender;
9use wapc::WapcHost;
10
11use crate::errors::Error;
12
/// A pool of waPC hosts served by worker threads.
///
/// Each worker builds its own `WapcHost` through `factory` and takes requests from a
/// shared channel (`tx`/`rx`). Requests are submitted with `HostPool::call`.
#[must_use]
pub struct HostPool {
    /// Pool name; used as the thread pool's name and in log messages.
    pub name: String,
    /// The thread pool running the workers; `None` once `HostPool::shutdown` has run.
    pool: Option<ThreadPool>,
    /// Creates the `WapcHost` owned by each newly started worker.
    factory: Arc<dyn Fn() -> WapcHost + Send + Sync + 'static>,
    /// Upper bound on the number of workers.
    max_threads: usize,
    /// How long `call` waits for a worker to accept a request before starting another worker.
    max_wait: Duration,
    /// Idle timeout applied to workers started on demand by `call`.
    max_idle: Duration,
    /// Sending half of the request channel, used by `call`.
    tx: SyncSender<WorkerMessage>,
    /// Receiving half of the request channel; a clone is moved into every worker.
    rx: SyncReceiver<WorkerMessage>,
}
27
28impl std::fmt::Debug for HostPool {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("HostPool")
31 .field("name", &self.name)
32 .field("tx", &self.tx)
33 .field("rx", &self.rx)
34 .finish()
35 }
36}
37
/// A request sent from `HostPool::call` to a worker thread.
type WorkerMessage = (
    // Reply channel: the worker sends the host call's result back through it.
    OneshotSender<std::result::Result<Vec<u8>, wapc::errors::Error>>,
    // Name of the operation passed to `WapcHost::call`.
    String,
    // Payload bytes passed to `WapcHost::call`.
    Vec<u8>,
);
43
44impl HostPool {
45 pub fn new<N, F>(
47 name: N,
48 factory: F,
49 min_threads: usize,
50 max_threads: usize,
51 max_wait: Duration,
52 max_idle: Duration,
53 ) -> Self
54 where
55 N: AsRef<str>,
56 F: Fn() -> WapcHost + Send + Sync + 'static,
57 {
58 debug!("Creating new wapc host pool with size {}", max_threads);
59 let arcfn = Arc::new(factory);
60 let pool = rusty_pool::Builder::new()
61 .name(name.as_ref().to_owned())
62 .core_size(min_threads)
63 .max_size(max_threads)
64 .keep_alive(Duration::from_millis(0))
65 .build();
66
67 let (tx, rx) = crossbeam::channel::bounded::<WorkerMessage>(1);
68
69 let pool = Self {
70 name: name.as_ref().to_owned(),
71 factory: arcfn,
72 pool: Some(pool),
73 max_threads,
74 max_wait,
75 max_idle,
76 tx,
77 rx,
78 };
79
80 for _ in 0..min_threads {
81 pool.spawn(None).unwrap();
82 }
83
84 pool
85 }
86
87 #[must_use]
89 pub fn num_active_workers(&self) -> usize {
90 self.pool.as_ref().map_or(0, |pool| pool.get_current_worker_count())
91 }
92
93 fn spawn(&self, max_idle: Option<Duration>) -> Result<()> {
94 self.pool.as_ref().map_or_else(
95 || Err(Error::NoPool.into()),
96 |pool| {
97 let name = self.name.clone();
98 let i = pool.get_current_worker_count();
99 let factory = self.factory.clone();
100 let rx = self.rx.clone();
101 pool.execute(move || {
102 trace!("Host thread {}.{} started...", name, i);
103 let host = factory();
104 loop {
105 let message = max_idle.map_or_else(
106 || rx.recv().map_err(|e| e.to_string()),
107 |duration| rx.recv_timeout(duration).map_err(|e| e.to_string()),
108 );
109 if let Err(e) = message {
110 debug!("Host thread {}.{} closing: {}", name, i, e);
111 break;
112 }
113 let (tx, op, payload) = message.unwrap();
114 trace!(
115 "Host thread {}.{} received call for {} with {} byte payload",
116 name,
117 i,
118 op,
119 payload.len()
120 );
121 let result = host.call(&op, &payload);
122 if tx.send(result).is_err() {
123 error!("Host thread {}.{} failed when returning a value...", name, i);
124 }
125 }
126
127 trace!("Host thread {}.{} stopped.", name, i);
128 });
129 Ok(())
130 },
131 )
132 }
133
134 pub async fn call<T: AsRef<str> + Sync + Send>(&self, op: T, payload: Vec<u8>) -> Result<Vec<u8>> {
136 let (tx, rx) = tokio::sync::oneshot::channel();
137 let result = match self
139 .tx
140 .send_timeout((tx, op.as_ref().to_owned(), payload), self.max_wait)
141 {
142 Ok(_) => Ok(()),
143 Err(e) => {
144 let args = match e {
146 SendTimeoutError::Timeout(args) => {
147 debug!("Timeout on pool '{}'", self.name);
148 args
149 }
150 SendTimeoutError::Disconnected(args) => {
151 warn!("Pool worker disconnected on pool '{}'", self.name);
152 args
153 }
154 };
155 if self.num_active_workers() < self.max_threads {
157 if let Err(e) = self.spawn(Some(self.max_idle)) {
158 error!("Error spawning worker for host pool '{}': {}", self.name, e);
159 };
160 }
161 self.tx.send(args)
163 }
164 };
165 if let Err(e) = result {
166 return Err(wapc::errors::Error::General(e.to_string()));
167 }
168 match rx.await {
169 Ok(res) => res,
170 Err(e) => Err(wapc::errors::Error::General(e.to_string())),
171 }
172 }
173
174 pub fn shutdown(&mut self) -> Result<()> {
176 let pool = self
177 .pool
178 .take()
179 .ok_or_else(|| wapc::errors::Error::from(crate::errors::Error::NoPool))?;
180
181 pool.shutdown_join();
182 Ok(())
183 }
184}
185
/// Builder for [`HostPool`]; a factory must be supplied before calling `build`.
#[must_use]
pub struct HostPoolBuilder {
    /// Pool name; `build` falls back to "waPC host pool" when unset.
    name: Option<String>,
    /// Creates the `WapcHost` for each worker; `build` panics if this is unset.
    factory: Option<Box<dyn Fn() -> WapcHost + Send + Sync + 'static>>,
    /// Number of workers started when the pool is built.
    min_threads: usize,
    /// Maximum number of workers the pool may run.
    max_threads: usize,
    /// How long a call waits for a worker to accept a request before starting another worker.
    max_wait: Duration,
    /// Idle timeout for workers started on demand.
    max_idle: Duration,
}
196
197impl std::fmt::Debug for HostPoolBuilder {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 f.debug_struct("HostPoolBuilder")
200 .field("name", &self.name)
201 .field("factory", if self.factory.is_some() { &"Some(Fn)" } else { &"None" })
202 .field("min_threads", &self.min_threads)
203 .field("max_threads", &self.max_threads)
204 .field("max_wait", &self.max_wait)
205 .field("max_idle", &self.max_idle)
206 .finish()
207 }
208}
209
210impl Default for HostPoolBuilder {
211 fn default() -> Self {
212 Self {
213 name: None,
214 factory: None,
215 min_threads: 1,
216 max_threads: 2,
217 max_wait: Duration::from_millis(100),
218 max_idle: Duration::from_secs(5 * 60),
219 }
220 }
221}
222
223impl HostPoolBuilder {
224 pub fn new() -> Self {
232 Self::default()
233 }
234
235 pub fn name<T: AsRef<str>>(mut self, name: T) -> Self {
243 self.name = Some(name.as_ref().to_owned());
244 self
245 }
246
247 pub fn factory<F>(mut self, factory: F) -> Self
263 where
264 F: Fn() -> WapcHost + Send + Sync + 'static,
265 {
266 self.factory = Some(Box::new(factory));
267 self
268 }
269
270 pub fn min_threads(mut self, min: usize) -> Self {
278 self.min_threads = min;
279 self
280 }
281
282 pub fn max_threads(mut self, max: usize) -> Self {
290 self.max_threads = max;
291 self
292 }
293
294 pub fn max_idle(mut self, timeout: Duration) -> Self {
303 self.max_idle = timeout;
304 self
305 }
306
307 pub fn max_wait(mut self, duration: Duration) -> Self {
316 self.max_wait = duration;
317 self
318 }
319
320 pub fn build(mut self) -> HostPool {
336 #[allow(clippy::expect_used)]
337 let factory = self
338 .factory
339 .take()
340 .expect("A waPC host pool must have a factory function.");
341 HostPool::new(
342 self.name.unwrap_or_else(|| "waPC host pool".to_owned()),
343 factory,
344 self.min_threads,
345 self.max_threads,
346 self.max_wait,
347 self.max_idle,
348 )
349 }
350}
351
#[cfg(test)]
mod tests {

    use std::time::{Duration, Instant};

    use tokio::join;
    use wapc::WebAssemblyEngineProvider;

    use super::*;

    /// Checks that calls return the host's response and that concurrent calls are spread
    /// across the 5 workers instead of running one after another.
    #[test_log::test(tokio::test)]
    async fn test_basic() -> Result<()> {
        // Fake engine: every call sleeps 100ms and then answers `{}`.
        #[derive(Default)]
        struct Test {
            host: Option<Arc<wapc::ModuleState>>,
        }
        impl WebAssemblyEngineProvider for Test {
            fn init(
                &mut self,
                host: Arc<wapc::ModuleState>,
            ) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
                self.host = Some(host);
                Ok(())
            }

            fn call(
                &mut self,
                op_length: i32,
                msg_length: i32,
            ) -> std::result::Result<i32, Box<dyn std::error::Error + Send + Sync>> {
                println!("op len:{}", op_length);
                println!("msg len:{}", msg_length);
                std::thread::sleep(Duration::from_millis(100));
                let host = self.host.take().unwrap();
                host.set_guest_response(b"{}".to_vec());
                self.host.replace(host);
                Ok(1)
            }

            fn replace(&mut self, _bytes: &[u8]) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
                Ok(())
            }
        }
        let pool = HostPoolBuilder::new()
            .name("test")
            .factory(move || WapcHost::new(Box::new(Test::default()), None).unwrap())
            .min_threads(5)
            .max_threads(5)
            .build();

        let now = Instant::now();
        let result = pool.call("test", b"hello world".to_vec()).await.unwrap();
        assert_eq!(result, b"{}");
        // Eight 100ms calls against 5 workers. If they were served one at a time they
        // would take about 800ms on their own, so staying under 600ms overall (including
        // the first call) indicates parallel service.
        let _res = join!(
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
        );
        let duration = now.elapsed();
        println!("Took {}ms", duration.as_millis());
        assert!(duration.as_millis() < 600);

        Ok(())
    }

    /// Checks that the pool grows beyond `min_threads` under load (capped at `max_threads`)
    /// and shrinks back once the extra workers pass `max_idle` without a request.
    #[test_log::test(tokio::test)]
    async fn test_elasticity() -> Result<()> {
        // Fake engine: every call sleeps 100ms and then answers `{}`.
        #[derive(Default)]
        struct Test {
            host: Option<Arc<wapc::ModuleState>>,
        }
        impl WebAssemblyEngineProvider for Test {
            fn init(
                &mut self,
                host: Arc<wapc::ModuleState>,
            ) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
                self.host = Some(host);
                Ok(())
            }

            fn call(&mut self, _: i32, _: i32) -> std::result::Result<i32, Box<dyn std::error::Error + Send + Sync>> {
                std::thread::sleep(Duration::from_millis(100));
                let host = self.host.take().unwrap();
                host.set_guest_response(b"{}".to_vec());
                self.host.replace(host);
                Ok(1)
            }

            fn replace(&mut self, _bytes: &[u8]) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
                Ok(())
            }
        }
        // A short max_wait (10ms) makes calls time out quickly and start extra workers.
        // A short max_idle (1s) lets those workers exit within the test.
        let pool = HostPoolBuilder::new()
            .name("test")
            .factory(move || WapcHost::new(Box::new(Test::default()), None).unwrap())
            .min_threads(1)
            .max_threads(5)
            .max_wait(Duration::from_millis(10))
            .max_idle(Duration::from_secs(1))
            .build();
        assert_eq!(pool.num_active_workers(), 1);
        let _ = futures::future::join_all(vec![
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
        ])
        .await;
        assert_eq!(pool.num_active_workers(), 2);
        let _ = futures::future::join_all(vec![
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
            pool.call("test", b"hello world".to_vec()),
        ])
        .await;
        // Growth is capped at max_threads.
        assert_eq!(pool.num_active_workers(), 5);
        // Wait longer than max_idle so the on-demand workers time out. This blocks the
        // runtime thread, but the workers run on the pool's own threads.
        std::thread::sleep(Duration::from_millis(1500));
        assert_eq!(pool.num_active_workers(), 1);

        Ok(())
    }
}