use std::collections::HashMap;
use std::sync::Arc;
use epics_base_rs::server::database::PvDatabase;
use epics_base_rs::types::EpicsValue;
use epics_ca_rs::client::CaClient;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use crate::error::{BridgeError, BridgeResult};
use super::cache::{PvCache, PvState};
/// Owns the upstream Channel Access side of the bridge.
///
/// For every subscribed upstream PV, a spawned task forwards monitor updates into the shared
/// [`PvCache`] and mirrors each value into the shadow [`PvDatabase`].
///
/// Dropping a tokio `JoinHandle` detaches its task rather than stopping it, so call
/// `shutdown` to stop the monitor tasks.
pub struct UpstreamManager {
    /// CA client used to create upstream channels.
    client: Arc<CaClient>,
    /// Shared cache receiving per-PV connection state and monitor updates.
    cache: Arc<RwLock<PvCache>>,
    /// Local database that receives a copy of each upstream monitor value.
    shadow_db: Arc<PvDatabase>,
    /// Monitor task handle per upstream PV name.
    tasks: HashMap<String, JoinHandle<()>>,
}
impl UpstreamManager {
pub async fn new(
cache: Arc<RwLock<PvCache>>,
shadow_db: Arc<PvDatabase>,
) -> BridgeResult<Self> {
let client = CaClient::new()
.await
.map_err(|e| BridgeError::PutRejected(format!("CaClient init: {e}")))?;
Ok(Self {
client: Arc::new(client),
cache,
shadow_db,
tasks: HashMap::new(),
})
}
pub fn subscription_count(&self) -> usize {
self.tasks.len()
}
pub fn is_subscribed(&self, name: &str) -> bool {
self.tasks.contains_key(name)
}
pub async fn ensure_subscribed(&mut self, upstream_name: &str) -> BridgeResult<()> {
if self.tasks.contains_key(upstream_name) {
return Ok(());
}
{
let mut cache = self.cache.write().await;
let entry = cache.get_or_create(upstream_name);
entry.write().await.set_state(PvState::Connecting);
}
self.shadow_db
.add_pv(upstream_name, EpicsValue::Double(0.0))
.await;
let channel = self.client.create_channel(upstream_name);
let mut monitor = channel
.subscribe()
.await
.map_err(|e| BridgeError::PutRejected(format!("subscribe failed: {e}")))?;
let cache_clone = self.cache.clone();
let db_clone = self.shadow_db.clone();
let name = upstream_name.to_string();
let task = tokio::spawn(async move {
while let Some(result) = monitor.recv().await {
let snapshot = match result {
Ok(s) => s,
Err(_) => continue,
};
if let Some(entry_arc) = cache_clone.read().await.get(&name) {
let mut entry = entry_arc.write().await;
if entry.state == PvState::Connecting {
entry.set_state(PvState::Inactive);
}
entry.update(snapshot.clone());
}
let _ = db_clone
.put_pv_and_post(&name, snapshot.value.clone())
.await;
}
if let Some(entry_arc) = cache_clone.read().await.get(&name) {
entry_arc.write().await.set_state(PvState::Disconnect);
}
});
self.tasks.insert(upstream_name.to_string(), task);
Ok(())
}
pub async fn unsubscribe(&mut self, upstream_name: &str) {
if let Some(task) = self.tasks.remove(upstream_name) {
task.abort();
}
}
pub async fn put(&self, upstream_name: &str, value: &EpicsValue) -> BridgeResult<()> {
let channel = self.client.create_channel(upstream_name);
channel
.put(value)
.await
.map_err(|e| BridgeError::PutRejected(format!("upstream put: {e}")))
}
pub async fn get(&self, upstream_name: &str) -> BridgeResult<EpicsValue> {
let channel = self.client.create_channel(upstream_name);
let (_dbf, value) = channel
.get()
.await
.map_err(|e| BridgeError::PutRejected(format!("upstream get: {e}")))?;
Ok(value)
}
pub async fn sweep_orphaned(&mut self) {
let cache = self.cache.read().await;
let live_names: Vec<String> = self
.tasks
.keys()
.filter(|name| cache.get(name).is_none())
.cloned()
.collect();
drop(cache);
for name in live_names {
self.unsubscribe(&name).await;
}
}
pub async fn shutdown(&mut self) {
for (_name, task) in self.tasks.drain() {
task.abort();
}
self.client.shutdown().await;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed manager tracks no subscriptions.
    ///
    /// NOTE(review): this initialises a real `CaClient`; confirm that is acceptable in CI.
    #[tokio::test]
    async fn manager_construct() {
        let cache = Arc::new(RwLock::new(PvCache::new()));
        let db = Arc::new(PvDatabase::new());
        // `expect` reports the actual init error, unlike `assert!(result.is_ok())`.
        let mgr = UpstreamManager::new(cache, db)
            .await
            .expect("UpstreamManager::new should succeed");
        assert_eq!(mgr.subscription_count(), 0);
        assert!(!mgr.is_subscribed("ANY"));
    }

    /// Smoke check that the cache entry type is reachable from this module and constructible.
    #[test]
    fn entry_imports() {
        let _ = super::super::cache::GwPvEntry::new_connecting("X");
    }
}