use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use epics_base_rs::server::snapshot::Snapshot;
use tokio::sync::RwLock;
/// Lifecycle state of a PV tracked by the gateway cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PvState {
    /// Not reachable. `PvCache::cleanup` demotes entries that stayed
    /// `Connecting` too long into this state.
    Dead,
    /// Freshly created entry whose connection has not been established yet.
    Connecting,
    /// Exists, but currently has no subscribers.
    Inactive,
    /// Exists and has at least one subscriber.
    Active,
    /// NOTE(review): presumably a PV whose upstream connection was lost —
    /// confirm with the code that sets this state.
    Disconnect,
}

impl PvState {
    /// Returns `true` when the PV is known to exist, i.e. the state is
    /// [`PvState::Inactive`] or [`PvState::Active`]; every other state
    /// counts as "not existent" (used for alive/dead time accounting).
    pub fn is_existent(self) -> bool {
        match self {
            Self::Inactive | Self::Active => true,
            Self::Dead | Self::Connecting | Self::Disconnect => false,
        }
    }
}
/// One cached PV as seen by the gateway: its lifecycle state, last value,
/// subscribers and alive/dead time statistics.
#[derive(Debug)]
pub struct GwPvEntry {
    /// PV name; `PvCache::insert` uses it as the cache key.
    pub name: String,
    /// Current lifecycle state. Prefer `set_state` over assigning directly so
    /// the time statistics stay consistent.
    pub state: PvState,
    /// Most recent value stored by `update`, if any.
    pub cached: Option<Snapshot>,
    /// Subscriber ids, kept free of duplicates by `add_subscriber`.
    /// NOTE(review): presumably client subscription ids — confirm with callers.
    pub subscribers: Vec<u32>,
    /// When the entry entered its current `state`.
    pub state_since: Instant,
    /// Number of times `update` has been called.
    pub event_count: u64,
    /// Accumulated time spent in existent states (`Inactive`/`Active`),
    /// counting only completed intervals — the current one is not included.
    pub total_alive: Duration,
    /// Accumulated time spent in non-existent states (`Dead`, `Connecting`,
    /// `Disconnect`), counting only completed intervals.
    pub total_dead: Duration,
}
impl GwPvEntry {
    /// Builds a new entry for `name` that starts out in
    /// [`PvState::Connecting`], with nothing cached, nobody subscribed and
    /// all statistics at zero.
    pub fn new_connecting(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self {
            name,
            state: PvState::Connecting,
            cached: None,
            subscribers: Vec::new(),
            state_since: Instant::now(),
            event_count: 0,
            total_alive: Duration::ZERO,
            total_dead: Duration::ZERO,
        }
    }

    /// Moves the entry to `new`.
    ///
    /// Requesting the state the entry is already in is a no-op. Otherwise the
    /// time spent in the outgoing state is credited to `total_alive` (when
    /// that state is existent) or `total_dead`, and the state timer restarts.
    pub fn set_state(&mut self, new: PvState) {
        if new == self.state {
            return;
        }
        let spent = self.state_since.elapsed();
        // Pick the accumulator that owns the interval that is ending now.
        let bucket = if self.state.is_existent() {
            &mut self.total_alive
        } else {
            &mut self.total_dead
        };
        *bucket = bucket.saturating_add(spent);
        self.state = new;
        self.state_since = Instant::now();
    }

    /// Registers subscriber `sid`; an id that is already present is not
    /// added twice. An `Inactive` entry that now has subscribers becomes
    /// `Active`.
    pub fn add_subscriber(&mut self, sid: u32) {
        if !self.subscribers.iter().any(|&existing| existing == sid) {
            self.subscribers.push(sid);
        }
        let has_subscribers = !self.subscribers.is_empty();
        if has_subscribers && self.state == PvState::Inactive {
            self.set_state(PvState::Active);
        }
    }

    /// Drops subscriber `sid` (absent ids are ignored). An `Active` entry
    /// left without subscribers falls back to `Inactive`.
    pub fn remove_subscriber(&mut self, sid: u32) {
        self.subscribers.retain(|&existing| existing != sid);
        let unwatched = self.subscribers.is_empty();
        if unwatched && self.state == PvState::Active {
            self.set_state(PvState::Inactive);
        }
    }

    /// Number of distinct subscribers currently registered.
    pub fn subscriber_count(&self) -> usize {
        self.subscribers.len()
    }

    /// Stores `snap` as the latest cached value and bumps `event_count`.
    pub fn update(&mut self, snap: Snapshot) {
        self.cached = Some(snap);
        self.event_count += 1;
    }

    /// How long the entry has been in its current state.
    pub fn time_in_state(&self) -> Duration {
        self.state_since.elapsed()
    }
}
/// Per-state timeouts applied by `PvCache::cleanup`.
#[derive(Debug, Clone, Copy)]
pub struct CacheTimeouts {
    /// Maximum time an entry may stay `Connecting` before it is demoted to `Dead`.
    pub connect_timeout: Duration,
    /// Time after which an `Inactive` entry is dropped from the cache.
    pub inactive_timeout: Duration,
    /// Time after which a `Dead` entry is dropped from the cache.
    pub dead_timeout: Duration,
    /// Time after which a `Disconnect` entry is dropped from the cache.
    pub disconnect_timeout: Duration,
}

impl Default for CacheTimeouts {
    /// Connect: 1 s; inactive: 2 h; dead: 2 min; disconnect: 2 h.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;
        Self {
            connect_timeout: Duration::from_secs(1),
            inactive_timeout: Duration::from_secs(2 * HOUR),
            dead_timeout: Duration::from_secs(2 * MINUTE),
            disconnect_timeout: Duration::from_secs(2 * HOUR),
        }
    }
}
/// Gateway PV cache: maps PV names to shared, individually locked entries.
///
/// Each entry lives behind an `Arc<RwLock<_>>`, so handles returned by
/// `get` / `get_or_create` / `insert` can be used without keeping a borrow
/// of the cache itself.
#[derive(Debug, Default)]
pub struct PvCache {
    // Entries keyed by PV name (`GwPvEntry::name` when added via `insert`).
    entries: HashMap<String, Arc<RwLock<GwPvEntry>>>,
}
impl PvCache {
pub fn new() -> Self {
Self::default()
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn get(&self, name: &str) -> Option<Arc<RwLock<GwPvEntry>>> {
self.entries.get(name).cloned()
}
pub fn insert(&mut self, entry: GwPvEntry) -> Arc<RwLock<GwPvEntry>> {
let name = entry.name.clone();
let arc = Arc::new(RwLock::new(entry));
self.entries.insert(name, arc.clone());
arc
}
pub fn get_or_create(&mut self, name: &str) -> Arc<RwLock<GwPvEntry>> {
if let Some(arc) = self.entries.get(name) {
return arc.clone();
}
self.insert(GwPvEntry::new_connecting(name.to_string()))
}
pub fn remove(&mut self, name: &str) -> Option<Arc<RwLock<GwPvEntry>>> {
self.entries.remove(name)
}
pub fn names(&self) -> Vec<String> {
self.entries.keys().cloned().collect()
}
pub async fn count_by_state(&self, state: PvState) -> usize {
let entries: Vec<Arc<RwLock<GwPvEntry>>> = self.entries.values().cloned().collect();
let mut count = 0;
for entry in entries {
if entry.read().await.state == state {
count += 1;
}
}
count
}
pub async fn count_states(&self) -> (usize, usize, usize, usize, usize) {
let entries: Vec<Arc<RwLock<GwPvEntry>>> = self.entries.values().cloned().collect();
let mut connecting = 0;
let mut active = 0;
let mut inactive = 0;
let mut dead = 0;
let mut disconnect = 0;
for entry in entries {
match entry.read().await.state {
PvState::Connecting => connecting += 1,
PvState::Active => active += 1,
PvState::Inactive => inactive += 1,
PvState::Dead => dead += 1,
PvState::Disconnect => disconnect += 1,
}
}
(connecting, active, inactive, dead, disconnect)
}
pub async fn cleanup(&mut self, timeouts: &CacheTimeouts) -> Vec<String> {
let mut to_remove = Vec::new();
let mut to_demote: Vec<String> = Vec::new();
for (name, entry) in &self.entries {
let entry_guard = entry.read().await;
let elapsed = entry_guard.time_in_state();
match entry_guard.state {
PvState::Connecting => {
if elapsed > timeouts.connect_timeout {
to_demote.push(name.clone());
}
}
PvState::Inactive => {
if elapsed > timeouts.inactive_timeout {
to_remove.push(name.clone());
}
}
PvState::Dead => {
if elapsed > timeouts.dead_timeout {
to_remove.push(name.clone());
}
}
PvState::Disconnect => {
if elapsed > timeouts.disconnect_timeout {
to_remove.push(name.clone());
}
}
PvState::Active => { }
}
}
for name in &to_demote {
if let Some(arc) = self.entries.get(name) {
arc.write().await.set_state(PvState::Dead);
}
}
for name in &to_remove {
self.entries.remove(name);
}
to_remove
}
}
// Unit tests for the PV state machine, entry bookkeeping and cache cleanup.
#[cfg(test)]
mod tests {
    use super::*;
    use epics_base_rs::types::EpicsValue;
    use std::time::SystemTime;

    // Builds a `Snapshot` holding a `Double` value. The two `0` arguments are
    // passed straight through to `Snapshot::new`; NOTE(review): presumably
    // alarm status/severity — confirm against `epics_base_rs`.
    fn dummy_snapshot(v: f64) -> Snapshot {
        Snapshot::new(EpicsValue::Double(v), 0, 0, SystemTime::now())
    }

    // Only `Inactive` and `Active` count as "existent".
    #[test]
    fn pv_state_is_existent() {
        assert!(PvState::Inactive.is_existent());
        assert!(PvState::Active.is_existent());
        assert!(!PvState::Dead.is_existent());
        assert!(!PvState::Connecting.is_existent());
        assert!(!PvState::Disconnect.is_existent());
    }

    // Adding subscribers to an `Inactive` entry activates it; duplicates are
    // ignored; removing the last subscriber drops it back to `Inactive`.
    #[test]
    fn entry_subscriber_lifecycle() {
        let mut e = GwPvEntry::new_connecting("TEMP");
        assert_eq!(e.state, PvState::Connecting);
        assert_eq!(e.subscriber_count(), 0);
        e.set_state(PvState::Inactive);
        assert_eq!(e.state, PvState::Inactive);
        e.add_subscriber(1);
        assert_eq!(e.state, PvState::Active);
        assert_eq!(e.subscriber_count(), 1);
        e.add_subscriber(2);
        assert_eq!(e.state, PvState::Active);
        assert_eq!(e.subscriber_count(), 2);
        // Duplicate id: count must not change.
        e.add_subscriber(2);
        assert_eq!(e.subscriber_count(), 2);
        e.remove_subscriber(1);
        assert_eq!(e.state, PvState::Active);
        assert_eq!(e.subscriber_count(), 1);
        e.remove_subscriber(2);
        assert_eq!(e.state, PvState::Inactive);
        assert_eq!(e.subscriber_count(), 0);
    }

    // Each `update` stores the snapshot and bumps `event_count` by one.
    #[test]
    fn entry_update_increments_event_count() {
        let mut e = GwPvEntry::new_connecting("TEMP");
        assert_eq!(e.event_count, 0);
        assert!(e.cached.is_none());
        e.update(dummy_snapshot(1.0));
        assert_eq!(e.event_count, 1);
        assert!(e.cached.is_some());
        e.update(dummy_snapshot(2.0));
        assert_eq!(e.event_count, 2);
    }

    // `get_or_create` returns the same shared handle for a repeated name.
    #[tokio::test]
    async fn cache_get_or_create() {
        let mut cache = PvCache::new();
        assert!(cache.is_empty());
        let arc1 = cache.get_or_create("TEMP");
        assert_eq!(cache.len(), 1);
        assert_eq!(arc1.read().await.state, PvState::Connecting);
        let arc2 = cache.get_or_create("TEMP");
        assert!(Arc::ptr_eq(&arc1, &arc2));
        assert_eq!(cache.len(), 1);
        cache.get_or_create("PRESSURE");
        assert_eq!(cache.len(), 2);
    }

    // State changes made through returned handles are visible to the counters.
    #[tokio::test]
    async fn cache_count_by_state() {
        let mut cache = PvCache::new();
        let a = cache.insert(GwPvEntry::new_connecting("A"));
        let b = cache.insert(GwPvEntry::new_connecting("B"));
        let _c = cache.insert(GwPvEntry::new_connecting("C"));
        a.write().await.set_state(PvState::Active);
        b.write().await.set_state(PvState::Inactive);
        assert_eq!(cache.count_by_state(PvState::Connecting).await, 1);
        assert_eq!(cache.count_by_state(PvState::Inactive).await, 1);
        assert_eq!(cache.count_by_state(PvState::Active).await, 1);
        assert_eq!(cache.count_by_state(PvState::Dead).await, 0);
    }

    // A `Dead` entry older than `dead_timeout` is removed; `Active` is kept.
    #[tokio::test]
    async fn cache_cleanup_removes_expired() {
        let mut cache = PvCache::new();
        let dead = cache.insert(GwPvEntry::new_connecting("DEAD"));
        let active = cache.insert(GwPvEntry::new_connecting("ALIVE"));
        {
            let mut e = dead.write().await;
            // Fields are poked directly to backdate the state timer.
            // NOTE(review): `Instant - Duration` panics if the result would
            // precede the clock's origin (e.g. shortly after boot on some
            // platforms); `Instant::checked_sub` would avoid that.
            e.state = PvState::Dead;
            e.state_since = Instant::now() - Duration::from_secs(60 * 60);
        }
        {
            let mut e = active.write().await;
            e.state = PvState::Active;
        }
        let timeouts = CacheTimeouts::default();
        let removed = cache.cleanup(&timeouts).await;
        assert_eq!(removed, vec!["DEAD".to_string()]);
        assert!(cache.get("DEAD").is_none());
        assert!(cache.get("ALIVE").is_some());
    }

    // A stuck `Connecting` entry is first demoted to `Dead` (not removed),
    // then removed once it has been `Dead` past `dead_timeout`.
    #[tokio::test]
    async fn cache_cleanup_demotes_connecting_to_dead() {
        let mut cache = PvCache::new();
        let stuck = cache.insert(GwPvEntry::new_connecting("STUCK"));
        {
            let mut e = stuck.write().await;
            e.state_since = Instant::now() - Duration::from_secs(5);
        }
        let timeouts = CacheTimeouts::default();
        let removed = cache.cleanup(&timeouts).await;
        assert!(removed.is_empty());
        assert!(cache.get("STUCK").is_some());
        assert_eq!(stuck.read().await.state, PvState::Dead);
        {
            // Backdate the `Dead` timer past the 2-minute default.
            let mut e = stuck.write().await;
            e.state_since = Instant::now() - Duration::from_secs(60 * 5);
        }
        let removed = cache.cleanup(&timeouts).await;
        assert_eq!(removed, vec!["STUCK".to_string()]);
        assert!(cache.get("STUCK").is_none());
    }
}