use std::time::Duration;
use epics_base_rs::error::{CaError, CaResult};
use epics_base_rs::types::{DbFieldType, EpicsValue};
use super::CaChannel;
/// A boxed, type-erased, `Send` future that yields the result of one queued
/// channel operation.
type BoxedCaFuture<T> =
    std::pin::Pin<Box<dyn std::future::Future<Output = CaResult<T>> + Send>>;

/// A queued channel read, resolving to the `(DbFieldType, EpicsValue)` pair.
type GetFuture = BoxedCaFuture<(DbFieldType, EpicsValue)>;

/// A queued channel write, resolving to `()` on success.
type PutFuture = BoxedCaFuture<()>;
/// A batch of channel reads and writes that are executed together.
///
/// Operations queued with [`SyncGroup::get`] and [`SyncGroup::put`] are stored
/// as boxed futures. They are not polled until [`SyncGroup::block`] consumes the
/// group and drives them all concurrently.
#[derive(Default)]
pub struct SyncGroup {
    /// Pending reads, in the order they were queued.
    gets: Vec<GetFuture>,
    /// Pending writes, in the order they were queued.
    puts: Vec<PutFuture>,
}

// The fields hold `dyn Future` trait objects, which do not implement `Debug`,
// so the impl cannot be derived. Report how many operations are pending instead,
// so public `SyncGroup` values (and types that embed them) stay debuggable.
impl std::fmt::Debug for SyncGroup {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SyncGroup")
            .field("pending_gets", &self.gets.len())
            .field("pending_puts", &self.puts.len())
            .finish()
    }
}
/// The outcome of [`SyncGroup::block`] when every queued operation finished in time.
///
/// Each operation reports its own success or failure. Both vectors hold one
/// entry per queued operation, in the order the operations were queued.
#[derive(Debug)]
pub struct SyncGroupResults {
    /// One entry per queued read: the `(DbFieldType, EpicsValue)` pair, or that
    /// read's error.
    pub gets: Vec<CaResult<(DbFieldType, EpicsValue)>>,
    /// One entry per queued write: `Ok(())`, or that write's error.
    pub puts: Vec<CaResult<()>>,
}
impl SyncGroup {
pub fn new() -> Self {
Self::default()
}
pub fn get(&mut self, ch: &CaChannel) {
let ch = ch.clone();
self.gets.push(Box::pin(async move { ch.get().await }));
}
pub fn put(&mut self, ch: &CaChannel, value: EpicsValue) {
let ch = ch.clone();
self.puts
.push(Box::pin(async move { ch.put(&value).await }));
}
pub async fn block(self, timeout: Duration) -> CaResult<SyncGroupResults> {
let SyncGroup { gets, puts } = self;
let get_join = futures_util::future::join_all(gets);
let put_join = futures_util::future::join_all(puts);
let combined = async { tokio::join!(get_join, put_join) };
let (gets_res, puts_res) = tokio::time::timeout(timeout, combined)
.await
.map_err(|_| CaError::Timeout)?;
Ok(SyncGroupResults {
gets: gets_res,
puts: puts_res,
})
}
pub fn len(&self) -> usize {
self.gets.len() + self.puts.len()
}
pub fn is_empty(&self) -> bool {
self.gets.is_empty() && self.puts.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty group has nothing to wait on, so `block` must succeed even with
    /// a short timeout and must return empty result vectors.
    #[test]
    fn empty_group_blocks_immediately() {
        let group = SyncGroup::new();
        assert_eq!(group.len(), 0);
        assert!(group.is_empty());

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let outcome = runtime.block_on(group.block(Duration::from_millis(50)));
        let results = outcome.expect("empty group should never time out");
        assert!(results.gets.is_empty());
        assert!(results.puts.is_empty());
    }
}