use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use regex::Regex;
use tracing::info;
use spvirit_types::{NtScalar, NtScalarArray, ScalarArrayValue, ScalarValue};
use crate::db::{load_db, parse_db};
use crate::handler::PvListMode;
use crate::monitor::MonitorRegistry;
use crate::server::{PvaServerConfig, run_pva_server_with_registry};
use crate::simple_store::{LinkDef, OnPutCallback, ScanCallback, SimplePvStore};
use crate::types::{DbCommonState, OutputMode, RecordData, RecordInstance, RecordType};
/// Fluent builder for a [`PvaServer`].
///
/// Obtain one with [`PvaServer::builder`], chain record, callback and network
/// methods, then call `build`. Every chained method consumes the builder and
/// returns the updated one, so the returned value must be kept: `#[must_use]`
/// turns an accidentally dropped intermediate builder into a compiler warning.
#[must_use = "builder methods return the updated builder; keep chaining and call `.build()`"]
pub struct PvaServerBuilder {
    /// Records keyed by PV name; a later registration under the same name replaces the earlier one.
    records: HashMap<String, RecordInstance>,
    /// Put callbacks keyed by PV name, handed to the store by `build`.
    on_put: HashMap<String, OnPutCallback>,
    /// `(PV name, period, callback)` entries that `PvaServer::run` turns into periodic tasks.
    scans: Vec<(String, Duration, ScanCallback)>,
    /// Link definitions handed to the store by `build`.
    links: Vec<LinkDef>,
    /// TCP port copied into the server config (default 5075).
    tcp_port: u16,
    /// UDP port copied into the server config (default 5076).
    udp_port: u16,
    /// Optional override for the config's `listen_ip`; `None` keeps the config default.
    listen_ip: Option<IpAddr>,
    /// Copied as-is into the config's `advertise_ip` (default `None`).
    advertise_ip: Option<IpAddr>,
    /// Alarm-computation flag passed to both the store and the config (default `false`).
    compute_alarms: bool,
    /// Beacon period in seconds (default 15).
    beacon_period_secs: u64,
    /// Connection timeout copied into the config (default 64000 s).
    conn_timeout: Duration,
    /// PV list mode copied into the config (default `PvListMode::List`).
    pvlist_mode: PvListMode,
    /// Value copied into the config's `pvlist_max` (default 1024).
    pvlist_max: usize,
    /// Optional regex copied into the config's `pvlist_allow_pattern` (default `None`).
    pvlist_allow_pattern: Option<Regex>,
}
impl PvaServerBuilder {
fn new() -> Self {
Self {
records: HashMap::new(),
on_put: HashMap::new(),
scans: Vec::new(),
links: Vec::new(),
tcp_port: 5075,
udp_port: 5076,
listen_ip: None,
advertise_ip: None,
compute_alarms: false,
beacon_period_secs: 15,
conn_timeout: Duration::from_secs(64000),
pvlist_mode: PvListMode::List,
pvlist_max: 1024,
pvlist_allow_pattern: None,
}
}
pub fn ai(mut self, name: impl Into<String>, initial: f64) -> Self {
let name = name.into();
self.records.insert(
name.clone(),
make_scalar_record(&name, RecordType::Ai, ScalarValue::F64(initial)),
);
self
}
pub fn ao(mut self, name: impl Into<String>, initial: f64) -> Self {
let name = name.into();
self.records.insert(
name.clone(),
make_output_record(&name, RecordType::Ao, ScalarValue::F64(initial)),
);
self
}
pub fn bi(mut self, name: impl Into<String>, initial: bool) -> Self {
let name = name.into();
self.records.insert(
name.clone(),
make_scalar_record(&name, RecordType::Bi, ScalarValue::Bool(initial)),
);
self
}
pub fn bo(mut self, name: impl Into<String>, initial: bool) -> Self {
let name = name.into();
self.records.insert(
name.clone(),
make_output_record(&name, RecordType::Bo, ScalarValue::Bool(initial)),
);
self
}
pub fn string_in(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
let name = name.into();
self.records.insert(
name.clone(),
make_scalar_record(
&name,
RecordType::StringIn,
ScalarValue::Str(initial.into()),
),
);
self
}
pub fn string_out(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
let name = name.into();
self.records.insert(
name.clone(),
make_output_record(
&name,
RecordType::StringOut,
ScalarValue::Str(initial.into()),
),
);
self
}
pub fn waveform(mut self, name: impl Into<String>, data: ScalarArrayValue) -> Self {
let name = name.into();
let ftvl = data.type_label().trim_end_matches("[]").to_string();
let nelm = data.len();
self.records.insert(
name.clone(),
RecordInstance {
name: name.clone(),
record_type: RecordType::Waveform,
common: DbCommonState::default(),
data: RecordData::Waveform {
nt: NtScalarArray::from_value(data),
inp: None,
ftvl,
nelm,
nord: nelm,
},
raw_fields: HashMap::new(),
},
);
self
}
pub fn db_file(mut self, path: impl AsRef<str>) -> Self {
match load_db(path.as_ref()) {
Ok(records) => {
self.records.extend(records);
}
Err(e) => {
tracing::error!("Failed to load db file '{}': {}", path.as_ref(), e);
}
}
self
}
pub fn db_string(mut self, content: &str) -> Self {
match parse_db(content) {
Ok(records) => {
self.records.extend(records);
}
Err(e) => {
tracing::error!("Failed to parse db string: {}", e);
}
}
self
}
pub fn on_put<F>(mut self, name: impl Into<String>, callback: F) -> Self
where
F: Fn(&str, &spvirit_codec::spvd_decode::DecodedValue) + Send + Sync + 'static,
{
self.on_put.insert(name.into(), Arc::new(callback));
self
}
pub fn scan<F>(mut self, name: impl Into<String>, period: Duration, callback: F) -> Self
where
F: Fn(&str) -> ScalarValue + Send + Sync + 'static,
{
self.scans.push((name.into(), period, Arc::new(callback)));
self
}
pub fn link<F>(mut self, output: impl Into<String>, inputs: &[&str], compute: F) -> Self
where
F: Fn(&[ScalarValue]) -> ScalarValue + Send + Sync + 'static,
{
self.links.push(LinkDef {
output: output.into(),
inputs: inputs.iter().map(|s| s.to_string()).collect(),
compute: Arc::new(compute),
});
self
}
pub fn port(mut self, port: u16) -> Self {
self.tcp_port = port;
self
}
pub fn udp_port(mut self, port: u16) -> Self {
self.udp_port = port;
self
}
pub fn listen_ip(mut self, ip: IpAddr) -> Self {
self.listen_ip = Some(ip);
self
}
pub fn advertise_ip(mut self, ip: IpAddr) -> Self {
self.advertise_ip = Some(ip);
self
}
pub fn compute_alarms(mut self, enabled: bool) -> Self {
self.compute_alarms = enabled;
self
}
pub fn beacon_period(mut self, secs: u64) -> Self {
self.beacon_period_secs = secs;
self
}
pub fn conn_timeout(mut self, timeout: Duration) -> Self {
self.conn_timeout = timeout;
self
}
pub fn pvlist_mode(mut self, mode: PvListMode) -> Self {
self.pvlist_mode = mode;
self
}
pub fn pvlist_max(mut self, max: usize) -> Self {
self.pvlist_max = max;
self
}
pub fn pvlist_allow_pattern(mut self, pattern: Regex) -> Self {
self.pvlist_allow_pattern = Some(pattern);
self
}
pub fn build(self) -> PvaServer {
let store = Arc::new(SimplePvStore::new(
self.records,
self.on_put,
self.links,
self.compute_alarms,
));
let mut config = PvaServerConfig::default();
config.tcp_port = self.tcp_port;
config.udp_port = self.udp_port;
config.compute_alarms = self.compute_alarms;
if let Some(ip) = self.listen_ip {
config.listen_ip = ip;
}
config.advertise_ip = self.advertise_ip;
config.beacon_period_secs = self.beacon_period_secs;
config.conn_timeout = self.conn_timeout;
config.pvlist_mode = self.pvlist_mode;
config.pvlist_max = self.pvlist_max;
config.pvlist_allow_pattern = self.pvlist_allow_pattern;
PvaServer {
store,
config,
scans: self.scans,
}
}
}
/// A configured PVA server, produced by [`PvaServerBuilder::build`].
///
/// Building only creates the store and config. Nothing is served and no scans
/// run until [`PvaServer::run`] is awaited, so an unused value is flagged by
/// `#[must_use]`.
#[must_use = "a PvaServer does nothing until `.run().await` is called"]
pub struct PvaServer {
    /// Shared record store, also exposed through `PvaServer::store`.
    store: Arc<SimplePvStore>,
    /// Settings passed to the server loop by `run`.
    config: PvaServerConfig,
    /// `(PV name, period, callback)` scan definitions started by `run`.
    scans: Vec<(String, Duration, ScanCallback)>,
}
impl PvaServer {
pub fn builder() -> PvaServerBuilder {
PvaServerBuilder::new()
}
pub fn store(&self) -> &Arc<SimplePvStore> {
&self.store
}
pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
let registry = Arc::new(MonitorRegistry::new());
self.store.set_registry(registry.clone()).await;
for (name, period, callback) in &self.scans {
let store = self.store.clone();
let name = name.clone();
let period = *period;
let callback = callback.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(period);
loop {
interval.tick().await;
let new_val = callback(&name);
store.set_value(&name, new_val).await;
}
});
}
let pv_count = self.store.pv_names().await.len();
info!(
"PvaServer starting: {} PVs on port {}",
pv_count, self.config.tcp_port
);
run_pva_server_with_registry(self.store, self.config, registry).await
}
}
/// Builds a `RecordInstance` for one of the scalar input kinds: `Ai`, `Bi` or `StringIn`.
///
/// Construction details:
/// - `value` is wrapped with `NtScalar::from_value`.
/// - `inp`, `siml` and `siol` start as `None`, and `simm` starts as `false`.
/// - `Bi` records get the state names `"Off"` / `"On"`.
/// - Common state is `DbCommonState::default()` and there are no raw fields.
///
/// # Panics
///
/// Panics for any other `record_type`.
fn make_scalar_record(name: &str, record_type: RecordType, value: ScalarValue) -> RecordInstance {
    let scalar = NtScalar::from_value(value);
    let record_data = match record_type {
        RecordType::Ai => RecordData::Ai {
            nt: scalar,
            inp: None,
            siml: None,
            siol: None,
            simm: false,
        },
        RecordType::Bi => RecordData::Bi {
            nt: scalar,
            inp: None,
            znam: String::from("Off"),
            onam: String::from("On"),
            siml: None,
            siol: None,
            simm: false,
        },
        RecordType::StringIn => RecordData::StringIn {
            nt: scalar,
            inp: None,
            siml: None,
            siol: None,
            simm: false,
        },
        _ => panic!("make_scalar_record: unsupported type {record_type:?}"),
    };
    RecordInstance {
        data: record_data,
        name: name.to_owned(),
        record_type,
        common: DbCommonState::default(),
        raw_fields: HashMap::new(),
    }
}
/// Builds a `RecordInstance` for one of the scalar output kinds: `Ao`, `Bo` or `StringOut`.
///
/// Construction details:
/// - `value` is wrapped with `NtScalar::from_value`.
/// - `out`, `dol`, `siml` and `siol` start as `None`.
/// - `omsl` is `OutputMode::Supervisory` and `simm` is `false`.
/// - `Ao` also starts with `drvl`, `drvh` and `oroc` unset.
/// - `Bo` gets the state names `"Off"` / `"On"`.
///
/// # Panics
///
/// Panics for any other `record_type`.
fn make_output_record(name: &str, record_type: RecordType, value: ScalarValue) -> RecordInstance {
    let scalar = NtScalar::from_value(value);
    let record_data = match record_type {
        RecordType::Ao => RecordData::Ao {
            nt: scalar,
            omsl: OutputMode::Supervisory,
            out: None,
            dol: None,
            drvl: None,
            drvh: None,
            oroc: None,
            siml: None,
            siol: None,
            simm: false,
        },
        RecordType::Bo => RecordData::Bo {
            nt: scalar,
            omsl: OutputMode::Supervisory,
            out: None,
            dol: None,
            znam: String::from("Off"),
            onam: String::from("On"),
            siml: None,
            siol: None,
            simm: false,
        },
        RecordType::StringOut => RecordData::StringOut {
            nt: scalar,
            omsl: OutputMode::Supervisory,
            out: None,
            dol: None,
            siml: None,
            siol: None,
            simm: false,
        },
        _ => panic!("make_output_record: unsupported type {record_type:?}"),
    };
    RecordInstance {
        data: record_data,
        name: name.to_owned(),
        record_type,
        common: DbCommonState::default(),
        raw_fields: HashMap::new(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-threaded runtime used to drive the store's async API from sync tests.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    /// Reads an `F64` scalar, falling back to `0.0` for any other variant.
    fn f64_or_zero(value: &ScalarValue) -> f64 {
        if let ScalarValue::F64(v) = value {
            *v
        } else {
            0.0
        }
    }

    #[test]
    fn builder_creates_records() {
        let server = PvaServer::builder()
            .ai("T:AI", 1.0)
            .ao("T:AO", 2.0)
            .bi("T:BI", true)
            .bo("T:BO", false)
            .string_in("T:SI", "hello")
            .string_out("T:SO", "world")
            .build();
        let rt = runtime();
        let names = rt.block_on(server.store.pv_names());
        assert_eq!(names.len(), 6);
    }

    #[test]
    fn builder_defaults() {
        let config = PvaServer::builder().build().config;
        assert_eq!(config.tcp_port, 5075);
        assert_eq!(config.udp_port, 5076);
        assert!(!config.compute_alarms);
    }

    #[test]
    fn builder_port_override() {
        let config = PvaServer::builder().port(9075).udp_port(9076).build().config;
        assert_eq!(config.tcp_port, 9075);
        assert_eq!(config.udp_port, 9076);
    }

    #[test]
    fn builder_db_string() {
        let db = r#"
record(ai, "TEST:VAL") {
field(VAL, "3.14")
}
"#;
        let server = PvaServer::builder().db_string(db).build();
        let rt = runtime();
        let value = rt.block_on(server.store.get_value("TEST:VAL"));
        assert!(value.is_some());
    }

    #[test]
    fn builder_waveform() {
        let samples = ScalarArrayValue::F64(vec![1.0, 2.0, 3.0]);
        let server = PvaServer::builder().waveform("T:WF", samples).build();
        let rt = runtime();
        let names = rt.block_on(server.store.pv_names());
        assert!(names.iter().any(|pv| pv == "T:WF"));
    }

    #[test]
    fn builder_scan_callback() {
        let server = PvaServer::builder()
            .ai("SCAN:V", 0.0)
            .scan("SCAN:V", Duration::from_secs(1), |_name| {
                ScalarValue::F64(42.0)
            })
            .build();
        assert_eq!(server.scans.len(), 1);
    }

    #[test]
    fn builder_on_put_callback() {
        let server = PvaServer::builder()
            .ao("PUT:V", 0.0)
            .on_put("PUT:V", |_name, _val| {})
            .build();
        let rt = runtime();
        let value = rt.block_on(server.store.get_value("PUT:V"));
        assert!(value.is_some());
    }

    #[test]
    fn store_runtime_get_set() {
        let server = PvaServer::builder().ao("RT:V", 0.0).build();
        let rt = runtime();
        let store = Arc::clone(server.store());
        rt.block_on(async move {
            assert_eq!(store.get_value("RT:V").await, Some(ScalarValue::F64(0.0)));
            store.set_value("RT:V", ScalarValue::F64(99.0)).await;
            assert_eq!(store.get_value("RT:V").await, Some(ScalarValue::F64(99.0)));
        });
    }

    #[test]
    fn link_propagates_on_set_value() {
        let server = PvaServer::builder()
            .ao("INPUT:A", 1.0)
            .ao("INPUT:B", 2.0)
            .ai("CALC:SUM", 0.0)
            .link("CALC:SUM", &["INPUT:A", "INPUT:B"], |values| {
                ScalarValue::F64(f64_or_zero(&values[0]) + f64_or_zero(&values[1]))
            })
            .build();
        let rt = runtime();
        let store = Arc::clone(server.store());
        rt.block_on(async move {
            store.set_value("INPUT:A", ScalarValue::F64(10.0)).await;
            assert_eq!(
                store.get_value("CALC:SUM").await,
                Some(ScalarValue::F64(12.0))
            );
            store.set_value("INPUT:B", ScalarValue::F64(5.0)).await;
            assert_eq!(
                store.get_value("CALC:SUM").await,
                Some(ScalarValue::F64(15.0))
            );
        });
    }
}