Skip to main content

spvirit_client/
pva_client.rs

1//! High-level PVAccess client — one-liner get, put, monitor, info.
2//!
3//! # Example
4//!
5//! ```rust,ignore
6//! use spvirit_client::PvaClient;
7//!
8//! let client = PvaClient::builder().build();
9//! let result = client.pvget("MY:PV").await?;
10//! client.pvput("MY:PV", 42.0).await?;
11//! ```
12
13use std::net::SocketAddr;
14use std::ops::ControlFlow;
15use std::sync::atomic::{AtomicU32, Ordering};
16use std::time::Duration;
17
18use serde_json::Value;
19use tokio::io::{AsyncReadExt, AsyncWriteExt};
20use tokio::net::tcp::OwnedWriteHalf;
21use tokio::task::JoinHandle;
22use tokio::time::{interval, Instant};
23
24use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand};
25use spvirit_codec::spvd_decode::{DecodedValue, PvdDecoder, StructureDesc};
26use spvirit_codec::spvd_encode::encode_pv_request;
27use spvirit_codec::spvirit_encode::{
28    encode_control_message, encode_get_field_request, encode_monitor_request, encode_put_request,
29};
30
31use crate::client::{ensure_status_ok, establish_channel, pvget as low_level_pvget, ChannelConn};
32use crate::put_encode::encode_put_payload;
33use crate::search::resolve_pv_server;
34use crate::transport::{read_packet, read_until};
35use crate::types::{PvGetError, PvGetResult, PvOptions};
36
/// PVA protocol version written into every header this module encodes.
const PVA_VERSION: u8 = 2;
/// QoS / subcommand flag: INIT (bit 3, i.e. `0x08`).
const QOS_INIT: u8 = 1 << 3;

/// Process-wide source of I/O operation ids; the first id handed out is 1.
static NEXT_IOID: AtomicU32 = AtomicU32::new(1);

/// Hand out the next I/O operation id.
///
/// Each call returns the counter's current value and advances it by one
/// (wrapping on overflow, as `fetch_add` does).
fn alloc_ioid() -> u32 {
    let counter = &NEXT_IOID;
    counter.fetch_add(1, Ordering::Relaxed)
}
46
47// ─── PvaClientBuilder ────────────────────────────────────────────────────────
48
/// Builder for [`PvaClient`].
///
/// Every setter consumes the builder and returns it, so calls can be chained.
/// The builder is `Clone` and `Debug` (like [`PvaClient`] itself) so a
/// partially configured builder can be reused as a template and inspected
/// in logs.
///
/// ```rust,ignore
/// let client = PvaClient::builder()
///     .timeout(Duration::from_secs(10))
///     .port(5075)
///     .build();
/// ```
#[derive(Clone, Debug)]
pub struct PvaClientBuilder {
    /// UDP search port (default 5076).
    udp_port: u16,
    /// TCP port (default 5075).
    tcp_port: u16,
    /// Operation timeout (default 5 s).
    timeout: Duration,
    /// When `true`, UDP broadcast search is disabled (name servers only).
    no_broadcast: bool,
    /// PVA name-server addresses used for TCP search.
    name_servers: Vec<SocketAddr>,
    /// Optional override for the authentication user.
    authnz_user: Option<String>,
    /// Optional override for the authentication host.
    authnz_host: Option<String>,
}
66
67impl PvaClientBuilder {
68    fn new() -> Self {
69        Self {
70            udp_port: 5076,
71            tcp_port: 5075,
72            timeout: Duration::from_secs(5),
73            no_broadcast: false,
74            name_servers: Vec::new(),
75            authnz_user: None,
76            authnz_host: None,
77        }
78    }
79
80    /// Set the TCP port (default 5075).
81    pub fn port(mut self, port: u16) -> Self {
82        self.tcp_port = port;
83        self
84    }
85
86    /// Set the UDP search port (default 5076).
87    pub fn udp_port(mut self, port: u16) -> Self {
88        self.udp_port = port;
89        self
90    }
91
92    /// Set the operation timeout (default 5 s).
93    pub fn timeout(mut self, timeout: Duration) -> Self {
94        self.timeout = timeout;
95        self
96    }
97
98    /// Disable UDP broadcast search (use name servers only).
99    pub fn no_broadcast(mut self) -> Self {
100        self.no_broadcast = true;
101        self
102    }
103
104    /// Add a PVA name-server address for TCP search.
105    pub fn name_server(mut self, addr: SocketAddr) -> Self {
106        self.name_servers.push(addr);
107        self
108    }
109
110    /// Override the authentication user.
111    pub fn authnz_user(mut self, user: impl Into<String>) -> Self {
112        self.authnz_user = Some(user.into());
113        self
114    }
115
116    /// Override the authentication host.
117    pub fn authnz_host(mut self, host: impl Into<String>) -> Self {
118        self.authnz_host = Some(host.into());
119        self
120    }
121
122    /// Build the [`PvaClient`].
123    pub fn build(self) -> PvaClient {
124        PvaClient {
125            udp_port: self.udp_port,
126            tcp_port: self.tcp_port,
127            timeout: self.timeout,
128            no_broadcast: self.no_broadcast,
129            name_servers: self.name_servers,
130            authnz_user: self.authnz_user,
131            authnz_host: self.authnz_host,
132        }
133    }
134}
135
136// ─── PvaClient ───────────────────────────────────────────────────────────────
137
/// High-level PVAccess client.
///
/// Provides `pvget`, `pvput`, `pvmonitor`, and `pvinfo` methods that hide
/// the underlying protocol handshake.
///
/// The client only stores configuration; the methods visible in this file
/// resolve the PV and open a channel on each call. Construct it through
/// [`PvaClient::builder`].
///
/// ```rust,ignore
/// let client = PvaClient::builder().build();
/// let val = client.pvget("MY:PV").await?;
/// ```
#[derive(Clone, Debug)]
pub struct PvaClient {
    /// UDP search port.
    udp_port: u16,
    /// TCP port.
    tcp_port: u16,
    /// Timeout applied to each operation.
    timeout: Duration,
    /// When `true`, UDP broadcast search is disabled (name servers only).
    no_broadcast: bool,
    /// PVA name-server addresses used for TCP search.
    name_servers: Vec<SocketAddr>,
    /// Optional override for the authentication user.
    authnz_user: Option<String>,
    /// Optional override for the authentication host.
    authnz_host: Option<String>,
}
157
158impl PvaClient {
159    /// Create a builder for configuring a [`PvaClient`].
160    pub fn builder() -> PvaClientBuilder {
161        PvaClientBuilder::new()
162    }
163
164    /// Build [`PvOptions`] for a given PV name, inheriting client-level settings.
165    fn opts(&self, pv_name: &str) -> PvOptions {
166        let mut o = PvOptions::new(pv_name.to_string());
167        o.udp_port = self.udp_port;
168        o.tcp_port = self.tcp_port;
169        o.timeout = self.timeout;
170        o.no_broadcast = self.no_broadcast;
171        o.name_servers.clone_from(&self.name_servers);
172        o.authnz_user.clone_from(&self.authnz_user);
173        o.authnz_host.clone_from(&self.authnz_host);
174        o
175    }
176
177    /// Resolve a PV server and establish a channel, returning the raw connection.
178    async fn open_channel(&self, pv_name: &str) -> Result<ChannelConn, PvGetError> {
179        let opts = self.opts(pv_name);
180        let target = resolve_pv_server(&opts).await?;
181        establish_channel(target, &opts).await
182    }
183
184    // ─── pvget ───────────────────────────────────────────────────────────
185
186    /// Fetch the current value of a PV.
187    pub async fn pvget(&self, pv_name: &str) -> Result<PvGetResult, PvGetError> {
188        let opts = self.opts(pv_name);
189        low_level_pvget(&opts).await
190    }
191
192    // ─── pvput ───────────────────────────────────────────────────────────
193
194    /// Write a value to a PV.
195    ///
196    /// Accepts anything convertible to `serde_json::Value`:
197    /// ```rust,ignore
198    /// client.pvput("MY:PV", 42.0).await?;
199    /// client.pvput("MY:PV", "hello").await?;
200    /// client.pvput("MY:PV", serde_json::json!({"value": 1.5})).await?;
201    /// ```
202    pub async fn pvput(
203        &self,
204        pv_name: &str,
205        value: impl Into<Value>,
206    ) -> Result<(), PvGetError> {
207        let json_val = value.into();
208        let ChannelConn {
209            mut stream,
210            sid,
211            version: _,
212            is_be,
213        } = self.open_channel(pv_name).await?;
214
215        let ioid = alloc_ioid();
216
217        // PUT INIT — send pvRequest for "field(value)"
218        let pv_request = encode_pv_request(&["value"], is_be);
219        let init = encode_put_request(sid, ioid, QOS_INIT, &pv_request, PVA_VERSION, is_be);
220        stream.write_all(&init).await?;
221
222        // Read INIT response — extract introspection
223        let init_bytes = read_until(&mut stream, self.timeout, |cmd| {
224            matches!(cmd, PvaPacketCommand::Op(op) if op.command == 11 && (op.subcmd & 0x08) != 0)
225        })
226        .await?;
227
228        let desc = decode_init_introspection(&init_bytes, "PUT")?;
229
230        // Encode and send the value
231        let payload = encode_put_payload(&desc, &json_val, is_be)
232            .map_err(|e| PvGetError::Protocol(format!("put encode: {e}")))?;
233        let req = encode_put_request(sid, ioid, 0x00, &payload, PVA_VERSION, is_be);
234        stream.write_all(&req).await?;
235
236        // Read PUT response — verify status
237        let resp_bytes = read_until(&mut stream, self.timeout, |cmd| {
238            matches!(cmd, PvaPacketCommand::Op(op) if op.command == 11 && op.subcmd == 0x00)
239        })
240        .await?;
241        ensure_status_ok(&resp_bytes, is_be, "PUT")?;
242
243        Ok(())
244    }
245
246    // ─── open_put_channel ────────────────────────────────────────────────
247
248    /// Open a persistent channel for high-rate PUT streaming.
249    ///
250    /// Resolves the PV, establishes a channel, and completes the PUT INIT
251    /// handshake. The returned [`PvaChannel`] is ready for immediate
252    /// [`put`](PvaChannel::put) calls.
253    pub async fn open_put_channel(&self, pv_name: &str) -> Result<PvaChannel, PvGetError> {
254        let ChannelConn {
255            mut stream,
256            sid,
257            version,
258            is_be,
259        } = self.open_channel(pv_name).await?;
260
261        let ioid = alloc_ioid();
262
263        // PUT INIT
264        let pv_request = encode_pv_request(&["value"], is_be);
265        let init = encode_put_request(sid, ioid, QOS_INIT, &pv_request, PVA_VERSION, is_be);
266        stream.write_all(&init).await?;
267
268        let init_bytes = read_until(&mut stream, self.timeout, |cmd| {
269            matches!(cmd, PvaPacketCommand::Op(op) if op.command == 11 && (op.subcmd & 0x08) != 0)
270        })
271        .await?;
272
273        let desc = decode_init_introspection(&init_bytes, "PUT")?;
274
275        // Split stream; background reader logs PUT errors
276        let (mut reader, writer) = stream.into_split();
277        let reader_is_be = is_be;
278        let reader_handle = tokio::spawn(async move {
279            loop {
280                let mut header = [0u8; 8];
281                if reader.read_exact(&mut header).await.is_err() {
282                    break;
283                }
284                let hdr = spvirit_codec::epics_decode::PvaHeader::new(&header);
285                let len = if hdr.flags.is_control {
286                    0usize
287                } else {
288                    hdr.payload_length as usize
289                };
290                let mut payload = vec![0u8; len];
291                if len > 0 && reader.read_exact(&mut payload).await.is_err() {
292                    break;
293                }
294                if hdr.command == 11 && !hdr.flags.is_control && len >= 5 {
295                    if let Some(st) =
296                        spvirit_codec::epics_decode::decode_status(&payload[5..], reader_is_be).0
297                    {
298                        if st.code != 0 {
299                            let msg = st.message.unwrap_or_else(|| format!("code={}", st.code));
300                            eprintln!("PvaChannel put error: {msg}");
301                        }
302                    }
303                }
304            }
305        });
306
307        Ok(PvaChannel {
308            writer,
309            sid,
310            ioid,
311            version,
312            is_be,
313            put_desc: desc,
314            echo_token: 1,
315            last_echo: Instant::now(),
316            _reader_handle: reader_handle,
317        })
318    }
319
320    // ─── pvmonitor ───────────────────────────────────────────────────────
321
322    /// Subscribe to a PV and receive live updates via a callback.
323    ///
324    /// The callback returns [`ControlFlow::Continue`] to keep listening or
325    /// [`ControlFlow::Break`] to stop the subscription.
326    ///
327    /// ```rust,ignore
328    /// use std::ops::ControlFlow;
329    ///
330    /// client.pvmonitor("MY:PV", |value| {
331    ///     println!("{value:?}");
332    ///     ControlFlow::Continue(())
333    /// }).await?;
334    /// ```
335    pub async fn pvmonitor<F>(
336        &self,
337        pv_name: &str,
338        mut callback: F,
339    ) -> Result<(), PvGetError>
340    where
341        F: FnMut(&DecodedValue) -> ControlFlow<()>,
342    {
343        let ChannelConn {
344            mut stream,
345            sid,
346            version: _,
347            is_be,
348        } = self.open_channel(pv_name).await?;
349
350        let ioid = alloc_ioid();
351        let decoder = PvdDecoder::new(is_be);
352
353        // MONITOR INIT — request value + alarm + timeStamp
354        let pv_request = encode_pv_request(&["value", "alarm", "timeStamp"], is_be);
355        let init = encode_monitor_request(sid, ioid, QOS_INIT, &pv_request, PVA_VERSION, is_be);
356        stream.write_all(&init).await?;
357
358        // Read INIT response — extract introspection
359        let init_bytes = read_until(&mut stream, self.timeout, |cmd| {
360            matches!(cmd, PvaPacketCommand::Op(op) if op.command == 13 && (op.subcmd & 0x08) != 0)
361        })
362        .await?;
363
364        let field_desc = decode_init_introspection(&init_bytes, "MONITOR")?;
365
366        // Start subscription (non-pipeline: START 0x04 + GET 0x40 = 0x44)
367        let start = encode_monitor_request(sid, ioid, 0x44, &[], PVA_VERSION, is_be);
368        stream.write_all(&start).await?;
369
370        // Event loop — with echo keepalive and timeout resilience
371        let mut echo_interval = interval(Duration::from_secs(10));
372        let mut echo_token: u32 = 1;
373
374        loop {
375            tokio::select! {
376                _ = echo_interval.tick() => {
377                    let msg = encode_control_message(false, is_be, PVA_VERSION, 3, echo_token);
378                    echo_token = echo_token.wrapping_add(1);
379                    let _ = stream.write_all(&msg).await;
380                }
381                res = read_packet(&mut stream, self.timeout) => {
382                    let bytes = match res {
383                        Ok(b) => b,
384                        Err(PvGetError::Timeout(_)) => continue,
385                        Err(e) => return Err(e),
386                    };
387                    let mut pkt = PvaPacket::new(&bytes);
388                    if let Some(PvaPacketCommand::Op(op)) = pkt.decode_payload() {
389                        if op.command == 13 && op.ioid == ioid && op.subcmd == 0x00 {
390                            let payload = &bytes[8..]; // skip header
391                            let pos = 5; // skip ioid(4) + subcmd(1)
392                            if let Some((decoded, _)) =
393                                decoder.decode_structure_with_bitset(&payload[pos..], &field_desc)
394                            {
395                                if callback(&decoded).is_break() {
396                                    return Ok(());
397                                }
398                            }
399                        }
400                    }
401                }
402            }
403        }
404    }
405
406    // ─── pvinfo ──────────────────────────────────────────────────────────
407
408    /// Retrieve the field/structure description (introspection) for a PV.
409    pub async fn pvinfo(&self, pv_name: &str) -> Result<StructureDesc, PvGetError> {
410        let ChannelConn {
411            mut stream,
412            sid,
413            version: _,
414            is_be,
415        } = self.open_channel(pv_name).await?;
416
417        let ioid = alloc_ioid();
418        let msg = encode_get_field_request(sid, ioid, None, PVA_VERSION, is_be);
419        stream.write_all(&msg).await?;
420
421        let resp_bytes = read_until(&mut stream, self.timeout, |cmd| {
422            matches!(cmd, PvaPacketCommand::Op(op) if op.command == 17)
423        })
424        .await?;
425
426        decode_init_introspection(&resp_bytes, "GET_FIELD")
427    }
428
429    // ─── pvlist ──────────────────────────────────────────────────────────
430
431    /// List PV names served by a specific server (via `__pvlist` GET).
432    pub async fn pvlist(&self, server_addr: SocketAddr) -> Result<Vec<String>, PvGetError> {
433        let opts = self.opts("__pvlist");
434        crate::pvlist::pvlist(&opts, server_addr).await
435    }
436
437    /// List PV names with automatic fallback through all strategies.
438    ///
439    /// Tries: `__pvlist` → GET_FIELD (opt-in) → Server RPC → Server GET.
440    pub async fn pvlist_with_fallback(
441        &self,
442        server_addr: SocketAddr,
443    ) -> Result<(Vec<String>, crate::pvlist::PvListSource), PvGetError> {
444        let opts = self.opts("__pvlist");
445        crate::pvlist::pvlist_with_fallback(&opts, server_addr).await
446    }
447}
448
449// ─── PvaChannel ──────────────────────────────────────────────────────────────
450
/// A persistent PVA channel for high-rate streaming PUT operations.
///
/// Created via [`PvaClient::open_put_channel`], this keeps the TCP connection
/// open and reuses the PUT introspection for repeated writes without
/// per-operation handshake overhead.
///
/// Dropping the channel aborts the background reader task that was spawned
/// when the channel was opened.
///
/// # Example
///
/// ```rust,ignore
/// let client = PvaClient::builder().build();
/// let mut channel = client.open_put_channel("MY:PV").await?;
/// for value in 0..100 {
///     channel.put(value as f64).await?;
/// }
/// ```
pub struct PvaChannel {
    /// Write half of the split connection; all requests go out through it.
    writer: OwnedWriteHalf,
    /// Channel id taken from the established `ChannelConn`.
    sid: u32,
    /// I/O id allocated for the PUT operation during INIT, reused for every put.
    ioid: u32,
    /// Protocol version taken from the established `ChannelConn`.
    version: u8,
    /// Byte-order flag handed to the encoders (presumably big-endian when
    /// `true` — confirm against the codec).
    is_be: bool,
    /// Introspection obtained from the PUT INIT response.
    put_desc: StructureDesc,
    /// Token for the next echo keepalive; incremented (wrapping) per echo.
    echo_token: u32,
    /// When the last echo keepalive was sent (or the channel was opened).
    last_echo: Instant,
    /// Background task draining server responses; aborted on drop.
    _reader_handle: JoinHandle<()>,
}
477
478impl PvaChannel {
479    /// Write a value over the persistent channel.
480    ///
481    /// Automatically sends echo keepalive pings when more than 10 seconds
482    /// have elapsed since the last one.
483    pub async fn put(&mut self, value: impl Into<Value>) -> Result<(), PvGetError> {
484        // Echo keepalive
485        if self.last_echo.elapsed() >= Duration::from_secs(10) {
486            let msg = encode_control_message(false, self.is_be, self.version, 3, self.echo_token);
487            self.echo_token = self.echo_token.wrapping_add(1);
488            let _ = self.writer.write_all(&msg).await;
489            self.last_echo = Instant::now();
490        }
491
492        let json_val = value.into();
493        let payload = encode_put_payload(&self.put_desc, &json_val, self.is_be)
494            .map_err(|e| PvGetError::Protocol(format!("put encode: {e}")))?;
495        let req = encode_put_request(self.sid, self.ioid, 0x00, &payload, self.version, self.is_be);
496        self.writer.write_all(&req).await?;
497        Ok(())
498    }
499
500    /// Returns the PUT introspection for this channel.
501    pub fn introspection(&self) -> &StructureDesc {
502        &self.put_desc
503    }
504}
505
/// Stops the background reader task when the channel goes away, so the
/// spawned task does not outlive the `PvaChannel` that owns its handle.
impl Drop for PvaChannel {
    fn drop(&mut self) {
        // Abort rather than await: `drop` is synchronous.
        self._reader_handle.abort();
    }
}
511
512// ─── Standalone convenience functions ────────────────────────────────────────
513
514/// Write a value to a PV (one-shot).
515///
516/// ```rust,ignore
517/// use spvirit_client::{pvput, PvOptions};
518///
519/// pvput(&PvOptions::new("MY:PV".into()), 42.0).await?;
520/// ```
521pub async fn pvput(opts: &PvOptions, value: impl Into<Value>) -> Result<(), PvGetError> {
522    let client = client_from_opts(opts);
523    client.pvput(&opts.pv_name, value).await
524}
525
526/// Subscribe to a PV and receive live updates (one-shot).
527///
528/// The callback returns [`ControlFlow::Continue`] to keep listening or
529/// [`ControlFlow::Break`] to stop.
530pub async fn pvmonitor<F>(opts: &PvOptions, callback: F) -> Result<(), PvGetError>
531where
532    F: FnMut(&DecodedValue) -> ControlFlow<()>,
533{
534    let client = client_from_opts(opts);
535    client.pvmonitor(&opts.pv_name, callback).await
536}
537
538/// Retrieve the field/structure description for a PV (one-shot).
539pub async fn pvinfo(opts: &PvOptions) -> Result<StructureDesc, PvGetError> {
540    let client = client_from_opts(opts);
541    client.pvinfo(&opts.pv_name).await
542}
543
544// ─── Internal helpers ────────────────────────────────────────────────────────
545
546/// Build a PvaClient inheriting configuration from PvOptions.
547pub fn client_from_opts(opts: &PvOptions) -> PvaClient {
548    let mut b = PvaClient::builder()
549        .port(opts.tcp_port)
550        .udp_port(opts.udp_port)
551        .timeout(opts.timeout);
552    if opts.no_broadcast {
553        b = b.no_broadcast();
554    }
555    for ns in &opts.name_servers {
556        b = b.name_server(*ns);
557    }
558    if let Some(ref u) = opts.authnz_user {
559        b = b.authnz_user(u.clone());
560    }
561    if let Some(ref h) = opts.authnz_host {
562        b = b.authnz_host(h.clone());
563    }
564    b.build()
565}
566
567/// Decode an INIT response to extract the introspection StructureDesc.
568fn decode_init_introspection(raw: &[u8], label: &str) -> Result<StructureDesc, PvGetError> {
569    let mut pkt = PvaPacket::new(raw);
570    let cmd = pkt.decode_payload().ok_or_else(|| {
571        PvGetError::Decode(format!("{label} init response decode failed"))
572    })?;
573
574    match cmd {
575        PvaPacketCommand::Op(op) => {
576            if let Some(ref st) = op.status {
577                if st.is_error() {
578                    let msg = st
579                        .message
580                        .clone()
581                        .unwrap_or_else(|| format!("code={}", st.code));
582                    return Err(PvGetError::Protocol(format!("{label} init error: {msg}")));
583                }
584            }
585            op.introspection
586                .ok_or_else(|| PvGetError::Decode(format!("missing {label} introspection")))
587        }
588        _ => Err(PvGetError::Protocol(format!(
589            "unexpected {label} init response"
590        ))),
591    }
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built client carries the documented defaults.
    #[test]
    fn builder_defaults() {
        let c = PvaClient::builder().build();
        assert_eq!(c.tcp_port, 5075);
        assert_eq!(c.udp_port, 5076);
        assert_eq!(c.timeout, Duration::from_secs(5));
        assert!(!c.no_broadcast);
        assert!(c.name_servers.is_empty());
    }

    /// Every builder setter is reflected in the built client.
    #[test]
    fn builder_overrides() {
        let ns: SocketAddr = "127.0.0.1:5075".parse().unwrap();
        let c = PvaClient::builder()
            .port(9075)
            .udp_port(9076)
            .timeout(Duration::from_secs(10))
            .no_broadcast()
            .name_server(ns)
            .authnz_user("testuser")
            .authnz_host("testhost")
            .build();
        assert_eq!(c.tcp_port, 9075);
        assert_eq!(c.udp_port, 9076);
        assert_eq!(c.timeout, Duration::from_secs(10));
        assert!(c.no_broadcast);
        assert_eq!(c.name_servers, vec![ns]);
        assert_eq!(c.authnz_user.as_deref(), Some("testuser"));
        assert_eq!(c.authnz_host.as_deref(), Some("testhost"));
    }

    /// `opts` copies every client-level setting, not just ports and timeout.
    #[test]
    fn opts_inherits_client_config() {
        let ns: SocketAddr = "127.0.0.1:5075".parse().unwrap();
        let c = PvaClient::builder()
            .port(9075)
            .udp_port(9076)
            .timeout(Duration::from_secs(10))
            .no_broadcast()
            .name_server(ns)
            .authnz_user("testuser")
            .authnz_host("testhost")
            .build();
        let o = c.opts("TEST:PV");
        assert_eq!(o.pv_name, "TEST:PV");
        assert_eq!(o.tcp_port, 9075);
        assert_eq!(o.udp_port, 9076);
        assert_eq!(o.timeout, Duration::from_secs(10));
        assert!(o.no_broadcast);
        assert_eq!(o.name_servers, vec![ns]);
        assert_eq!(o.authnz_user.as_deref(), Some("testuser"));
        assert_eq!(o.authnz_host.as_deref(), Some("testhost"));
    }

    /// `client_from_opts` carries over every option it reads from `PvOptions`.
    #[test]
    fn client_from_opts_roundtrip() {
        let ns: SocketAddr = "127.0.0.1:6075".parse().unwrap();
        let mut opts = PvOptions::new("X:Y".into());
        opts.tcp_port = 8075;
        opts.udp_port = 8076;
        opts.timeout = Duration::from_secs(3);
        opts.no_broadcast = true;
        opts.name_servers = vec![ns];
        opts.authnz_user = Some("user".to_string());
        opts.authnz_host = Some("host".to_string());
        let c = client_from_opts(&opts);
        assert_eq!(c.tcp_port, 8075);
        assert_eq!(c.udp_port, 8076);
        assert_eq!(c.timeout, Duration::from_secs(3));
        assert!(c.no_broadcast);
        assert_eq!(c.name_servers, vec![ns]);
        assert_eq!(c.authnz_user.as_deref(), Some("user"));
        assert_eq!(c.authnz_host.as_deref(), Some("host"));
    }

    #[test]
    fn pv_get_options_alias_works() {
        // PvGetOptions is a type alias for PvOptions — verify it compiles and works
        let opts: crate::types::PvGetOptions = PvOptions::new("ALIAS:TEST".into());
        assert_eq!(opts.pv_name, "ALIAS:TEST");
    }
}