Skip to main content

spvirit_server/
pva_server.rs

1//! High-level PVAccess server — builder pattern for typed records.
2//!
3//! # Example
4//!
5//! ```rust,ignore
6//! use spvirit_server::PvaServer;
7//!
8//! let server = PvaServer::builder()
9//!     .ai("SIM:TEMPERATURE", 22.5)
10//!     .ao("SIM:SETPOINT", 25.0)
11//!     .bo("SIM:ENABLE", false)
12//!     .build();
13//!
14//! server.run().await?;
15//! ```
16
17use std::collections::HashMap;
18use std::net::IpAddr;
19use std::sync::Arc;
20use std::time::Duration;
21
22use regex::Regex;
23use tracing::info;
24
25use spvirit_types::{NtScalar, NtScalarArray, ScalarArrayValue, ScalarValue};
26
27use crate::db::{load_db, parse_db};
28use crate::handler::PvListMode;
29use crate::monitor::MonitorRegistry;
30use crate::server::{run_pva_server_with_registry, PvaServerConfig};
31use crate::simple_store::{LinkDef, OnPutCallback, ScanCallback, SimplePvStore};
32use crate::types::{
33    DbCommonState, OutputMode, RecordData, RecordInstance, RecordType,
34};
35
36// ─── PvaServerBuilder ────────────────────────────────────────────────────
37
38/// Builder for [`PvaServer`].
39///
40/// ```rust,ignore
41/// let server = PvaServer::builder()
42///     .ai("TEMP:READBACK", 22.5)
43///     .ao("TEMP:SETPOINT", 25.0)
44///     .bo("HEATER:ON", false)
45///     .port(5075)
46///     .build();
47/// ```
48pub struct PvaServerBuilder {
49    records: HashMap<String, RecordInstance>,
50    on_put: HashMap<String, OnPutCallback>,
51    scans: Vec<(String, Duration, ScanCallback)>,
52    links: Vec<LinkDef>,
53    tcp_port: u16,
54    udp_port: u16,
55    listen_ip: Option<IpAddr>,
56    advertise_ip: Option<IpAddr>,
57    compute_alarms: bool,
58    beacon_period_secs: u64,
59    conn_timeout: Duration,
60    pvlist_mode: PvListMode,
61    pvlist_max: usize,
62    pvlist_allow_pattern: Option<Regex>,
63}
64
65impl PvaServerBuilder {
66    fn new() -> Self {
67        Self {
68            records: HashMap::new(),
69            on_put: HashMap::new(),
70            scans: Vec::new(),
71            links: Vec::new(),
72            tcp_port: 5075,
73            udp_port: 5076,
74            listen_ip: None,
75            advertise_ip: None,
76            compute_alarms: false,
77            beacon_period_secs: 15,
78            conn_timeout: Duration::from_secs(64000),
79            pvlist_mode: PvListMode::List,
80            pvlist_max: 1024,
81            pvlist_allow_pattern: None,
82        }
83    }
84
85    // ─── Typed record constructors ───────────────────────────────────
86
87    /// Add an `ai` (analog input, read-only) record.
88    pub fn ai(mut self, name: impl Into<String>, initial: f64) -> Self {
89        let name = name.into();
90        self.records.insert(
91            name.clone(),
92            make_scalar_record(&name, RecordType::Ai, ScalarValue::F64(initial)),
93        );
94        self
95    }
96
97    /// Add an `ao` (analog output, writable) record.
98    pub fn ao(mut self, name: impl Into<String>, initial: f64) -> Self {
99        let name = name.into();
100        self.records.insert(
101            name.clone(),
102            make_output_record(&name, RecordType::Ao, ScalarValue::F64(initial)),
103        );
104        self
105    }
106
107    /// Add a `bi` (binary input, read-only) record.
108    pub fn bi(mut self, name: impl Into<String>, initial: bool) -> Self {
109        let name = name.into();
110        self.records.insert(
111            name.clone(),
112            make_scalar_record(&name, RecordType::Bi, ScalarValue::Bool(initial)),
113        );
114        self
115    }
116
117    /// Add a `bo` (binary output, writable) record.
118    pub fn bo(mut self, name: impl Into<String>, initial: bool) -> Self {
119        let name = name.into();
120        self.records.insert(
121            name.clone(),
122            make_output_record(&name, RecordType::Bo, ScalarValue::Bool(initial)),
123        );
124        self
125    }
126
127    /// Add a `stringin` (string input, read-only) record.
128    pub fn string_in(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
129        let name = name.into();
130        self.records.insert(
131            name.clone(),
132            make_scalar_record(&name, RecordType::StringIn, ScalarValue::Str(initial.into())),
133        );
134        self
135    }
136
137    /// Add a `stringout` (string output, writable) record.
138    pub fn string_out(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
139        let name = name.into();
140        self.records.insert(
141            name.clone(),
142            make_output_record(&name, RecordType::StringOut, ScalarValue::Str(initial.into())),
143        );
144        self
145    }
146
147    /// Add a `waveform` record (array) with the given initial data.
148    pub fn waveform(
149        mut self,
150        name: impl Into<String>,
151        data: ScalarArrayValue,
152    ) -> Self {
153        let name = name.into();
154        let ftvl = data.type_label().trim_end_matches("[]").to_string();
155        let nelm = data.len();
156        self.records.insert(
157            name.clone(),
158            RecordInstance {
159                name: name.clone(),
160                record_type: RecordType::Waveform,
161                common: DbCommonState::default(),
162                data: RecordData::Waveform {
163                    nt: NtScalarArray::from_value(data),
164                    inp: None,
165                    ftvl,
166                    nelm,
167                    nord: nelm,
168                },
169                raw_fields: HashMap::new(),
170            },
171        );
172        self
173    }
174
175    // ─── .db file loading ────────────────────────────────────────────
176
177    /// Load records from an EPICS `.db` file.
178    pub fn db_file(mut self, path: impl AsRef<str>) -> Self {
179        match load_db(path.as_ref()) {
180            Ok(records) => {
181                self.records.extend(records);
182            }
183            Err(e) => {
184                tracing::error!("Failed to load db file '{}': {}", path.as_ref(), e);
185            }
186        }
187        self
188    }
189
190    /// Parse records from an EPICS `.db` string.
191    pub fn db_string(mut self, content: &str) -> Self {
192        match parse_db(content) {
193            Ok(records) => {
194                self.records.extend(records);
195            }
196            Err(e) => {
197                tracing::error!("Failed to parse db string: {}", e);
198            }
199        }
200        self
201    }
202
203    // ─── Callbacks ───────────────────────────────────────────────────
204
205    /// Register a callback invoked when a PUT is applied to the named PV.
206    pub fn on_put<F>(mut self, name: impl Into<String>, callback: F) -> Self
207    where
208        F: Fn(&str, &spvirit_codec::spvd_decode::DecodedValue) + Send + Sync + 'static,
209    {
210        self.on_put.insert(name.into(), Arc::new(callback));
211        self
212    }
213
214    /// Register a periodic scan callback that produces a new value for a PV.
215    pub fn scan<F>(
216        mut self,
217        name: impl Into<String>,
218        period: Duration,
219        callback: F,
220    ) -> Self
221    where
222        F: Fn(&str) -> ScalarValue + Send + Sync + 'static,
223    {
224        self.scans
225            .push((name.into(), period, Arc::new(callback)));
226        self
227    }
228
229    /// Link an output PV to one or more input PVs.
230    ///
231    /// Whenever any input PV changes (via `set_value`, protocol PUT, or
232    /// another link), the `compute` callback is invoked with the current
233    /// values of **all** inputs (in order) and the result is written to
234    /// the output PV.
235    ///
236    /// ```rust,ignore
237    /// .link("CALC:SUM", &["INPUT:A", "INPUT:B"], |values| {
238    ///     let a = values[0].as_f64().unwrap_or(0.0);
239    ///     let b = values[1].as_f64().unwrap_or(0.0);
240    ///     ScalarValue::F64(a + b)
241    /// })
242    /// ```
243    pub fn link<F>(
244        mut self,
245        output: impl Into<String>,
246        inputs: &[&str],
247        compute: F,
248    ) -> Self
249    where
250        F: Fn(&[ScalarValue]) -> ScalarValue + Send + Sync + 'static,
251    {
252        self.links.push(LinkDef {
253            output: output.into(),
254            inputs: inputs.iter().map(|s| s.to_string()).collect(),
255            compute: Arc::new(compute),
256        });
257        self
258    }
259
260    // ─── Configuration ───────────────────────────────────────────────
261
262    /// Set the TCP port (default 5075).
263    pub fn port(mut self, port: u16) -> Self {
264        self.tcp_port = port;
265        self
266    }
267
268    /// Set the UDP search port (default 5076).
269    pub fn udp_port(mut self, port: u16) -> Self {
270        self.udp_port = port;
271        self
272    }
273
274    /// Set the IP address to listen on.
275    pub fn listen_ip(mut self, ip: IpAddr) -> Self {
276        self.listen_ip = Some(ip);
277        self
278    }
279
280    /// Set the IP address to advertise in search responses.
281    pub fn advertise_ip(mut self, ip: IpAddr) -> Self {
282        self.advertise_ip = Some(ip);
283        self
284    }
285
286    /// Enable alarm computation from limits.
287    pub fn compute_alarms(mut self, enabled: bool) -> Self {
288        self.compute_alarms = enabled;
289        self
290    }
291
292    /// Set the beacon broadcast period in seconds (default 15).
293    pub fn beacon_period(mut self, secs: u64) -> Self {
294        self.beacon_period_secs = secs;
295        self
296    }
297
298    /// Set the idle connection timeout (default ~18 hours).
299    pub fn conn_timeout(mut self, timeout: Duration) -> Self {
300        self.conn_timeout = timeout;
301        self
302    }
303
304    /// Set the PV list mode (default [`PvListMode::List`]).
305    pub fn pvlist_mode(mut self, mode: PvListMode) -> Self {
306        self.pvlist_mode = mode;
307        self
308    }
309
310    /// Set the maximum number of PV names in pvlist responses (default 1024).
311    pub fn pvlist_max(mut self, max: usize) -> Self {
312        self.pvlist_max = max;
313        self
314    }
315
316    /// Set a regex filter for PV names exposed by pvlist.
317    pub fn pvlist_allow_pattern(mut self, pattern: Regex) -> Self {
318        self.pvlist_allow_pattern = Some(pattern);
319        self
320    }
321
322    /// Build the [`PvaServer`].
323    pub fn build(self) -> PvaServer {
324        let store = Arc::new(SimplePvStore::new(
325            self.records,
326            self.on_put,
327            self.links,
328            self.compute_alarms,
329        ));
330
331        let mut config = PvaServerConfig::default();
332        config.tcp_port = self.tcp_port;
333        config.udp_port = self.udp_port;
334        config.compute_alarms = self.compute_alarms;
335        if let Some(ip) = self.listen_ip {
336            config.listen_ip = ip;
337        }
338        config.advertise_ip = self.advertise_ip;
339        config.beacon_period_secs = self.beacon_period_secs;
340        config.conn_timeout = self.conn_timeout;
341        config.pvlist_mode = self.pvlist_mode;
342        config.pvlist_max = self.pvlist_max;
343        config.pvlist_allow_pattern = self.pvlist_allow_pattern;
344
345        PvaServer {
346            store,
347            config,
348            scans: self.scans,
349        }
350    }
351}
352
353// ─── PvaServer ───────────────────────────────────────────────────────────
354
355/// High-level PVAccess server.
356///
357/// Built via [`PvaServer::builder()`] with typed record constructors,
358/// `.db_file()` loading, `.on_put()` / `.scan()` callbacks, and a
359/// simple `.run()` to start serving.
360///
361/// ```rust,ignore
362/// let server = PvaServer::builder()
363///     .ai("SIM:TEMP", 22.5)
364///     .ao("SIM:SP", 25.0)
365///     .build();
366///
367/// // Read/write PVs from another task:
368/// let store = server.store();
369/// store.set_value("SIM:TEMP", ScalarValue::F64(23.1)).await;
370///
371/// server.run().await?;
372/// ```
373pub struct PvaServer {
374    store: Arc<SimplePvStore>,
375    config: PvaServerConfig,
376    scans: Vec<(String, Duration, ScanCallback)>,
377}
378
379impl PvaServer {
380    /// Create a builder for configuring a [`PvaServer`].
381    pub fn builder() -> PvaServerBuilder {
382        PvaServerBuilder::new()
383    }
384
385    /// Get a reference to the underlying store for runtime get/put.
386    pub fn store(&self) -> &Arc<SimplePvStore> {
387        &self.store
388    }
389
390    /// Start the PVA server (UDP search + TCP handler + beacon + scan tasks).
391    ///
392    /// This blocks until the server is shut down or an error occurs.
393    pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
394        // Create the monitor registry early so scan tasks can notify
395        // PVAccess monitor clients when values change.
396        let registry = Arc::new(MonitorRegistry::new());
397        self.store.set_registry(registry.clone()).await;
398
399        // Spawn scan tasks.
400        for (name, period, callback) in &self.scans {
401            let store = self.store.clone();
402            let name = name.clone();
403            let period = *period;
404            let callback = callback.clone();
405            tokio::spawn(async move {
406                let mut interval = tokio::time::interval(period);
407                loop {
408                    interval.tick().await;
409                    let new_val = callback(&name);
410                    store.set_value(&name, new_val).await;
411                }
412            });
413        }
414
415        let pv_count = self.store.pv_names().await.len();
416        info!(
417            "PvaServer starting: {} PVs on port {}",
418            pv_count, self.config.tcp_port
419        );
420
421        run_pva_server_with_registry(self.store, self.config, registry).await
422    }
423}
424
425// ─── Record construction helpers ─────────────────────────────────────────
426
427fn make_scalar_record(
428    name: &str,
429    record_type: RecordType,
430    value: ScalarValue,
431) -> RecordInstance {
432    let nt = NtScalar::from_value(value);
433    let data = match record_type {
434        RecordType::Ai => RecordData::Ai {
435            nt,
436            inp: None,
437            siml: None,
438            siol: None,
439            simm: false,
440        },
441        RecordType::Bi => RecordData::Bi {
442            nt,
443            inp: None,
444            znam: "Off".to_string(),
445            onam: "On".to_string(),
446            siml: None,
447            siol: None,
448            simm: false,
449        },
450        RecordType::StringIn => RecordData::StringIn {
451            nt,
452            inp: None,
453            siml: None,
454            siol: None,
455            simm: false,
456        },
457        _ => panic!("make_scalar_record: unsupported type {record_type:?}"),
458    };
459    RecordInstance {
460        name: name.to_string(),
461        record_type,
462        common: DbCommonState::default(),
463        data,
464        raw_fields: HashMap::new(),
465    }
466}
467
468fn make_output_record(
469    name: &str,
470    record_type: RecordType,
471    value: ScalarValue,
472) -> RecordInstance {
473    let nt = NtScalar::from_value(value);
474    let data = match record_type {
475        RecordType::Ao => RecordData::Ao {
476            nt,
477            out: None,
478            dol: None,
479            omsl: OutputMode::Supervisory,
480            drvl: None,
481            drvh: None,
482            oroc: None,
483            siml: None,
484            siol: None,
485            simm: false,
486        },
487        RecordType::Bo => RecordData::Bo {
488            nt,
489            out: None,
490            dol: None,
491            omsl: OutputMode::Supervisory,
492            znam: "Off".to_string(),
493            onam: "On".to_string(),
494            siml: None,
495            siol: None,
496            simm: false,
497        },
498        RecordType::StringOut => RecordData::StringOut {
499            nt,
500            out: None,
501            dol: None,
502            omsl: OutputMode::Supervisory,
503            siml: None,
504            siol: None,
505            simm: false,
506        },
507        _ => panic!("make_output_record: unsupported type {record_type:?}"),
508    };
509    RecordInstance {
510        name: name.to_string(),
511        record_type,
512        common: DbCommonState::default(),
513        data,
514        raw_fields: HashMap::new(),
515    }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives `fut` to completion on a fresh single-threaded Tokio runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(fut)
    }

    /// One record of each scalar type ends up in the store.
    #[test]
    fn builder_creates_records() {
        let server = PvaServer::builder()
            .ai("T:AI", 1.0)
            .ao("T:AO", 2.0)
            .bi("T:BI", true)
            .bo("T:BO", false)
            .string_in("T:SI", "hello")
            .string_out("T:SO", "world")
            .build();

        let names = block_on(server.store.pv_names());
        assert_eq!(names.len(), 6);
    }

    /// An untouched builder yields the documented default ports and flags.
    #[test]
    fn builder_defaults() {
        let server = PvaServer::builder().build();
        assert_eq!(server.config.tcp_port, 5075);
        assert_eq!(server.config.udp_port, 5076);
        assert!(!server.config.compute_alarms);
    }

    /// Port setters are reflected in the built configuration.
    #[test]
    fn builder_port_override() {
        let server = PvaServer::builder().port(9075).udp_port(9076).build();
        assert_eq!(server.config.tcp_port, 9075);
        assert_eq!(server.config.udp_port, 9076);
    }

    /// Records parsed from inline `.db` text are served.
    #[test]
    fn builder_db_string() {
        let db = r#"
            record(ai, "TEST:VAL") {
                field(VAL, "3.14")
            }
        "#;
        let server = PvaServer::builder().db_string(db).build();
        let value = block_on(server.store.get_value("TEST:VAL"));
        assert!(value.is_some());
    }

    /// A waveform record is registered under its name.
    #[test]
    fn builder_waveform() {
        let data = ScalarArrayValue::F64(vec![1.0, 2.0, 3.0]);
        let server = PvaServer::builder().waveform("T:WF", data).build();
        let names = block_on(server.store.pv_names());
        assert!(names.contains(&"T:WF".to_string()));
    }

    /// A registered scan is carried over to the built server.
    #[test]
    fn builder_scan_callback() {
        let server = PvaServer::builder()
            .ai("SCAN:V", 0.0)
            .scan("SCAN:V", Duration::from_secs(1), |_name| {
                ScalarValue::F64(42.0)
            })
            .build();
        assert_eq!(server.scans.len(), 1);
    }

    /// Registering a PUT callback does not disturb record creation.
    #[test]
    fn builder_on_put_callback() {
        let server = PvaServer::builder()
            .ao("PUT:V", 0.0)
            .on_put("PUT:V", |_name, _val| {})
            .build();
        // on_put is stored in the SimplePvStore, not directly inspectable,
        // but the server built without panic.
        let value = block_on(server.store.get_value("PUT:V"));
        assert!(value.is_some());
    }

    /// Values written through the store handle can be read back.
    #[test]
    fn store_runtime_get_set() {
        let server = PvaServer::builder().ao("RT:V", 0.0).build();
        let store = Arc::clone(server.store());
        block_on(async {
            assert_eq!(
                store.get_value("RT:V").await,
                Some(ScalarValue::F64(0.0))
            );
            store.set_value("RT:V", ScalarValue::F64(99.0)).await;
            assert_eq!(
                store.get_value("RT:V").await,
                Some(ScalarValue::F64(99.0))
            );
        });
    }

    /// Writing either input of a link recomputes the output.
    #[test]
    fn link_propagates_on_set_value() {
        let server = PvaServer::builder()
            .ao("INPUT:A", 1.0)
            .ao("INPUT:B", 2.0)
            .ai("CALC:SUM", 0.0)
            .link("CALC:SUM", &["INPUT:A", "INPUT:B"], |values| {
                // Non-f64 inputs count as zero.
                let as_f64 = |value: &ScalarValue| match value {
                    ScalarValue::F64(v) => *v,
                    _ => 0.0,
                };
                ScalarValue::F64(as_f64(&values[0]) + as_f64(&values[1]))
            })
            .build();

        let store = Arc::clone(server.store());
        block_on(async {
            // Writing INPUT:A should recompute CALC:SUM = 10 + 2.
            store.set_value("INPUT:A", ScalarValue::F64(10.0)).await;
            assert_eq!(
                store.get_value("CALC:SUM").await,
                Some(ScalarValue::F64(12.0))
            );

            // Writing INPUT:B should recompute CALC:SUM = 10 + 5.
            store.set_value("INPUT:B", ScalarValue::F64(5.0)).await;
            assert_eq!(
                store.get_value("CALC:SUM").await,
                Some(ScalarValue::F64(15.0))
            );
        });
    }
}