use std::time::Duration;
use tokio::task::JoinHandle;
use epics_base_rs::error::{CaError, CaResult};
use epics_base_rs::runtime::task::spawn;
use epics_base_rs::types::{DbFieldType, EpicsValue};
use super::CaChannel;
/// Result of a queued read, i.e. what `CaChannel::get` resolves to: a
/// `(DbFieldType, EpicsValue)` pair or a CA error.
type GetOutput = CaResult<(DbFieldType, EpicsValue)>;

/// Which request a queued operation performs.
///
/// Stored beside the task handle so that, if the task cannot be joined, the
/// substitute error can still be reported under the matching kind.
#[derive(Copy, Clone)]
enum OpKind {
    Get,
    Put,
}

/// Value produced by a spawned operation task.
enum Outcome {
    /// Output of `CaChannel::get`.
    Get(GetOutput),
    /// Output of `CaChannel::put`.
    Put(CaResult<()>),
}
/// One operation queued in a `SyncGroup`.
struct SyncOp {
    /// Request kind; used to shape the error when the task fails to join.
    kind: OpKind,
    /// Handle of the spawned task performing the request.
    handle: JoinHandle<Outcome>,
    /// Outcome once `block` has taken it out of `handle`; while this is
    /// `Some`, `block` skips the op instead of awaiting its handle again.
    done: Option<Outcome>,
}

impl SyncOp {
    /// Reports whether the op is finished: its outcome has already been
    /// collected, or its task has stopped running.
    fn is_complete(&self) -> bool {
        match self.done {
            Some(_) => true,
            None => self.handle.is_finished(),
        }
    }
}
/// Non-blocking completion state of a sync group, as returned by
/// `SyncGroup::test`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncGroupStatus {
    /// Every queued operation has finished (an empty group also counts);
    /// the counterpart of the C client's `ECA_IODONE`.
    Done,
    /// At least one queued operation is still running.
    InProgress,
}

/// Snapshot of a sync group's progress, as returned by `SyncGroup::stat`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SyncGroupStat {
    /// Queued operations that have not finished yet.
    pub outstanding: usize,
    /// Queued operations that have finished.
    pub completed: usize,
}
/// A batch of channel gets and puts that are launched immediately and later
/// awaited together, modelled on the C client's synchronous groups.
///
/// Each `get`/`put` spawns its task right away. `block` waits for the whole
/// batch, `test`/`stat` poll it without waiting, and `reset` aborts it.
#[derive(Default)]
pub struct SyncGroup {
    /// Operations in the order they were queued.
    ops: Vec<SyncOp>,
}

/// Results of a successful `SyncGroup::block`, split by request kind.
///
/// Within each list, results appear in the order the operations were queued.
#[derive(Debug)]
pub struct SyncGroupResults {
    /// One entry per queued `get`.
    pub gets: Vec<GetOutput>,
    /// One entry per queued `put`.
    pub puts: Vec<CaResult<()>>,
}
impl SyncGroup {
pub fn new() -> Self {
Self::default()
}
pub fn get(&mut self, ch: &CaChannel) {
let ch = ch.clone();
let handle = spawn(async move { Outcome::Get(ch.get().await) });
self.ops.push(SyncOp {
kind: OpKind::Get,
handle,
done: None,
});
}
pub fn put(&mut self, ch: &CaChannel, value: EpicsValue) {
let ch = ch.clone();
let handle = spawn(async move { Outcome::Put(ch.put(&value).await) });
self.ops.push(SyncOp {
kind: OpKind::Put,
handle,
done: None,
});
}
pub async fn block(&mut self, timeout: Duration) -> CaResult<SyncGroupResults> {
let collect = async {
for op in self.ops.iter_mut() {
if op.done.is_none() {
let kind = op.kind;
let outcome = match (&mut op.handle).await {
Ok(o) => o,
Err(_) => match kind {
OpKind::Get => Outcome::Get(Err(CaError::Disconnected)),
OpKind::Put => Outcome::Put(Err(CaError::Disconnected)),
},
};
op.done = Some(outcome);
}
}
};
match tokio::time::timeout(timeout, collect).await {
Ok(()) => {
let mut gets = Vec::new();
let mut puts = Vec::new();
for op in std::mem::take(&mut self.ops) {
match op.done.expect("collect filled every op on success") {
Outcome::Get(r) => gets.push(r),
Outcome::Put(r) => puts.push(r),
}
}
Ok(SyncGroupResults { gets, puts })
}
Err(_) => {
self.reset();
Err(CaError::Timeout)
}
}
}
pub fn test(&self) -> SyncGroupStatus {
if self.ops.iter().all(SyncOp::is_complete) {
SyncGroupStatus::Done
} else {
SyncGroupStatus::InProgress
}
}
pub fn reset(&mut self) {
for op in &self.ops {
op.handle.abort();
}
self.ops.clear();
}
pub fn stat(&self) -> SyncGroupStat {
let completed = self.ops.iter().filter(|op| op.is_complete()).count();
SyncGroupStat {
outstanding: self.ops.len() - completed,
completed,
}
}
pub fn len(&self) -> usize {
self.ops.len()
}
pub fn is_empty(&self) -> bool {
self.ops.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    impl SyncGroup {
        /// Queues a fake get that resolves to `Long(val)` after `ms`
        /// milliseconds, so the group can be exercised without a live channel.
        fn push_delayed_get(&mut self, ms: u64, val: i32) {
            let delay = Duration::from_millis(ms);
            let handle = spawn(async move {
                tokio::time::sleep(delay).await;
                Outcome::Get(Ok((DbFieldType::Long, EpicsValue::Long(val))))
            });
            self.ops.push(SyncOp {
                kind: OpKind::Get,
                handle,
                done: None,
            });
        }
    }

    #[tokio::test]
    async fn empty_group_blocks_immediately() {
        let mut group = SyncGroup::new();
        assert!(group.is_empty());
        assert_eq!(group.test(), SyncGroupStatus::Done);

        let results = group
            .block(Duration::from_millis(50))
            .await
            .expect("empty group never times out");
        assert!(results.gets.is_empty() && results.puts.is_empty());
    }

    #[tokio::test]
    async fn reusable_block_waits_only_for_current_batch() {
        let mut group = SyncGroup::new();
        group.push_delayed_get(10, 1);
        group.push_delayed_get(20, 2);

        let first = group.block(Duration::from_secs(2)).await.unwrap();
        assert_eq!(first.gets.len(), 2, "first batch");
        assert!(group.is_empty(), "successful block clears the batch");

        group.push_delayed_get(10, 3);
        let second = group.block(Duration::from_secs(2)).await.unwrap();
        assert_eq!(second.gets.len(), 1, "second block waits only for the new op");
    }

    #[tokio::test]
    async fn test_reports_in_progress_then_done() {
        let mut group = SyncGroup::new();
        group.push_delayed_get(80, 1);
        assert_eq!(group.test(), SyncGroupStatus::InProgress);
        assert_eq!(group.stat().outstanding, 1);

        tokio::time::sleep(Duration::from_millis(150)).await;
        assert_eq!(group.test(), SyncGroupStatus::Done);
        assert_eq!(group.stat().completed, 1);

        let results = group.block(Duration::from_secs(1)).await.unwrap();
        assert_eq!(results.gets.len(), 1);
    }

    #[tokio::test]
    async fn block_timeout_discards_batch_like_c() {
        let mut group = SyncGroup::new();
        group.push_delayed_get(60_000, 1);

        let timed_out = group.block(Duration::from_millis(20)).await;
        assert!(matches!(timed_out, Err(CaError::Timeout)), "block times out");
        assert!(group.is_empty(), "timeout empties the batch");
        assert_eq!(group.test(), SyncGroupStatus::Done, "test() reports IODONE");
        assert_eq!(
            group.stat(),
            SyncGroupStat {
                outstanding: 0,
                completed: 0
            },
            "no outstanding ops after a timed-out block"
        );

        group.push_delayed_get(10, 7);
        let fresh = group
            .block(Duration::from_secs(2))
            .await
            .expect("fresh batch completes");
        assert_eq!(fresh.gets.len(), 1, "only the new op is awaited");
        assert!(
            matches!(fresh.gets[0], Ok((_, EpicsValue::Long(7)))),
            "the new op's result, not the discarded one"
        );
    }
}