Skip to main content

ad_plugins_rs/
time_series.rs

1use std::sync::Arc;
2use std::time::Instant;
3
4use asyn_rs::param::ParamType;
5use asyn_rs::port::{PortDriver, PortDriverBase, PortFlags};
6use asyn_rs::runtime::config::RuntimeConfig;
7use asyn_rs::runtime::port::{PortRuntimeHandle, create_port_runtime};
8use asyn_rs::user::AsynUser;
9use parking_lot::Mutex;
10
11// ===== Stats-specific channel definitions =====
12
/// How many scalar NDStats quantities are tracked as time series.
pub const NUM_STATS_TS_CHANNELS: usize = 23;

/// Names of the NDStats time-series channels.
///
/// Entry `i` names channel `i` of a TS port built from this list (each name
/// becomes a `TS_CHAN_<name>` waveform param). Presumably passed as
/// `channel_names` to `create_ts_port_runtime` by the stats plugin — confirm
/// against the caller.
pub const STATS_TS_CHANNEL_NAMES: [&str; NUM_STATS_TS_CHANNELS] = [
    // Minimum / maximum and where they occur.
    "TSMinValue",
    "TSMinX",
    "TSMinY",
    "TSMaxValue",
    "TSMaxX",
    "TSMaxY",
    // Whole-array intensity statistics.
    "TSMeanValue",
    "TSSigma",
    "TSTotal",
    "TSNet",
    // Centroid and higher moments.
    "TSCentroidTotal",
    "TSCentroidX",
    "TSCentroidY",
    "TSSigmaX",
    "TSSigmaY",
    "TSSigmaXY",
    "TSSkewX",
    "TSSkewY",
    "TSKurtosisX",
    "TSKurtosisY",
    "TSEccentricity",
    "TSOrientation",
    // Per-sample timestamp.
    "TSTimestamp",
];
42
43// ===== Generic time series data =====
44
/// One input sample pushed from a plugin processor to a TS port driver.
///
/// `values[i]` is the sample for channel `i`; the length must match the
/// channel count configured on the driver (extra values are ignored and
/// missing ones are not accumulated).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TimeSeriesData {
    /// Per-channel sample values for this input time point.
    pub values: Vec<f64>,
}
50
51/// Sender from plugin -> TS port.
52pub type TimeSeriesSender = tokio::sync::mpsc::Sender<TimeSeriesData>;
53/// Receiver in TS port background thread.
54pub type TimeSeriesReceiver = tokio::sync::mpsc::Receiver<TimeSeriesData>;
55
56/// Registry for pending TS receivers, keyed by upstream plugin port name.
57/// NDStatsConfigure etc. store receivers here; NDTimeSeriesConfigure picks them up.
58pub struct TsReceiverRegistry {
59    inner: std::sync::Mutex<std::collections::HashMap<String, (TimeSeriesReceiver, Vec<String>)>>,
60}
61
62impl TsReceiverRegistry {
63    pub fn new() -> Self {
64        Self {
65            inner: std::sync::Mutex::new(std::collections::HashMap::new()),
66        }
67    }
68
69    /// Store a receiver and its channel names for a given upstream port.
70    pub fn store(
71        &self,
72        upstream_port: &str,
73        receiver: TimeSeriesReceiver,
74        channel_names: Vec<String>,
75    ) {
76        let mut map = self.inner.lock().unwrap();
77        map.insert(upstream_port.to_string(), (receiver, channel_names));
78    }
79
80    /// Take a receiver for the given upstream port (returns None if not found or already taken).
81    pub fn take(&self, upstream_port: &str) -> Option<(TimeSeriesReceiver, Vec<String>)> {
82        let mut map = self.inner.lock().unwrap();
83        map.remove(upstream_port)
84    }
85}
86
87impl Default for TsReceiverRegistry {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
/// Accumulation mode for time series.
///
/// Mirrors C++ `TSAcquireMode`: `OneShot` == `TSAcquireModeFixed` (acquisition
/// stops once `num_points` output points are collected); `RingBuffer` ==
/// `TSAcquireModeCircular` (the buffer wraps and acquisition continues).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimeSeriesMode {
    /// C++ `TSAcquireModeFixed`: stop at `num_points`.
    OneShot,
    /// C++ `TSAcquireModeCircular`: wrap and keep acquiring.
    RingBuffer,
}

/// Time-series accumulator: a fixed-capacity buffer of successive scalar values.
///
/// In [`TimeSeriesMode::OneShot`] values beyond capacity are ignored; in
/// [`TimeSeriesMode::RingBuffer`] the oldest value is overwritten.
pub struct TimeSeries {
    /// Capacity requested via [`TimeSeries::new`] / [`TimeSeries::resize`].
    pub num_points: usize,
    /// Accumulation mode; change it through [`TimeSeries::set_mode`] to reset data.
    pub mode: TimeSeriesMode,
    buffer: Vec<f64>,
    /// Slot the next value goes to. In OneShot mode this equals `count`; in
    /// RingBuffer mode it is kept below the capacity, so it cannot overflow.
    write_pos: usize,
    /// Number of valid values stored (never exceeds the capacity).
    count: usize,
}

impl TimeSeries {
    pub fn new(num_points: usize, mode: TimeSeriesMode) -> Self {
        Self {
            num_points,
            mode,
            buffer: vec![0.0; num_points],
            write_pos: 0,
            count: 0,
        }
    }

    /// Storage actually available for indexing. This is the buffer length,
    /// which stays correct even if the public `num_points` field is modified
    /// directly.
    fn capacity(&self) -> usize {
        self.buffer.len()
    }

    /// Add a value (e.g., mean of an array) to the time series.
    ///
    /// A zero-capacity series silently ignores the value.
    pub fn add_value(&mut self, value: f64) {
        let capacity = self.capacity();
        if capacity == 0 {
            // Nothing can be stored, and `% capacity` below would panic.
            return;
        }
        match self.mode {
            TimeSeriesMode::OneShot => {
                if self.count < capacity {
                    self.buffer[self.count] = value;
                    self.count += 1;
                    self.write_pos = self.count;
                }
            }
            TimeSeriesMode::RingBuffer => {
                // `%` tolerates a write_pos left over from OneShot data.
                let slot = self.write_pos % capacity;
                self.buffer[slot] = value;
                self.write_pos = (slot + 1) % capacity;
                self.count = (self.count + 1).min(capacity);
            }
        }
    }

    /// Get the accumulated values, oldest first.
    pub fn values(&self) -> Vec<f64> {
        let capacity = self.capacity();
        match self.mode {
            TimeSeriesMode::RingBuffer if capacity > 0 && self.count == capacity => {
                // Full ring: the oldest value sits at the next write slot.
                let (newer, older) = self.buffer.split_at(self.write_pos % capacity);
                older.iter().chain(newer).copied().collect()
            }
            TimeSeriesMode::OneShot | TimeSeriesMode::RingBuffer => {
                self.buffer[..self.count].to_vec()
            }
        }
    }

    /// Number of values currently stored.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Clear all data, keeping the capacity.
    pub fn reset(&mut self) {
        self.buffer.fill(0.0);
        self.write_pos = 0;
        self.count = 0;
    }

    /// Resize the buffer. Resets all data.
    pub fn resize(&mut self, num_points: usize) {
        self.num_points = num_points;
        self.buffer.clear();
        self.buffer.resize(num_points, 0.0);
        self.write_pos = 0;
        self.count = 0;
    }

    /// Change the accumulation mode. Resets all data.
    pub fn set_mode(&mut self, mode: TimeSeriesMode) {
        self.mode = mode;
        self.reset();
    }
}
186
187// ===== Time Series Port Driver =====
188
/// Param indices for the TS port.
///
/// Derives `Clone` so callers that need a copy of the indices (e.g. to build
/// a PV registry after the driver has been moved) do not have to copy it
/// field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TSParams {
    pub ts_acquire: usize,
    pub ts_read: usize,
    pub ts_num_points: usize,
    pub ts_current_point: usize,
    pub ts_time_per_point: usize,
    pub ts_averaging_time: usize,
    pub ts_num_average: usize,
    pub ts_elapsed_time: usize,
    pub ts_acquire_mode: usize,
    pub ts_time_axis: usize,
    /// Per-channel waveform param indices (length = num_channels).
    pub ts_channels: Vec<usize>,
    /// Channel names (kept for registry building).
    pub channel_names: Vec<String>,
    /// Generic time series waveform (for NDTimeSeries.template).
    pub ts_time_series: usize,
    /// Timestamp waveform (for NDTimeSeries.template).
    pub ts_timestamp: usize,
}
210
211/// Shared state between the data ingestion thread and the TS port driver.
212pub struct SharedTsState {
213    pub buffers: Vec<TimeSeries>,
214    pub acquiring: bool,
215    pub start_time: Option<Instant>,
216    pub num_points: usize,
217    pub mode: TimeSeriesMode,
218    /// Number of input samples averaged into one output time point
219    /// (C++ `numAverage_`). `1` means each input sample is one output point.
220    pub num_average: usize,
221    /// Running per-channel sum of the input samples for the in-progress
222    /// output point (C++ `averageStore_`).
223    average_store: Vec<f64>,
224    /// Number of input samples accumulated into `average_store` so far
225    /// (C++ `numAveraged_`).
226    num_averaged: usize,
227}
228
229impl SharedTsState {
230    fn new(num_channels: usize, num_points: usize) -> Self {
231        let buffers = (0..num_channels)
232            .map(|_| TimeSeries::new(num_points, TimeSeriesMode::OneShot))
233            .collect();
234        Self {
235            buffers,
236            acquiring: false,
237            start_time: None,
238            num_points,
239            mode: TimeSeriesMode::OneShot,
240            num_average: 1,
241            average_store: vec![0.0; num_channels],
242            num_averaged: 0,
243        }
244    }
245
246    /// Reset the running average accumulator (C++ resets `numAveraged_` and
247    /// `averageStore_` on start/erase, resize, mode change).
248    fn reset_average(&mut self) {
249        for v in &mut self.average_store {
250            *v = 0.0;
251        }
252        self.num_averaged = 0;
253    }
254
255    /// Accumulate one input sample vector and, once `num_average` samples
256    /// have been collected, push the per-channel average into the buffers.
257    ///
258    /// Port of the inner loop of C++ `doAddToTimeSeriesT`: each call is one
259    /// input time point; `num_average` of them produce one output point.
260    /// Returns `true` when an output point was emitted this call.
261    fn accumulate(&mut self, values: &[f64]) -> bool {
262        let n = values.len().min(self.average_store.len());
263        for i in 0..n {
264            self.average_store[i] += values[i];
265        }
266        self.num_averaged += 1;
267        if self.num_averaged < self.num_average.max(1) {
268            return false;
269        }
270        let divisor = self.num_averaged as f64;
271        let nb = n.min(self.buffers.len());
272        for i in 0..nb {
273            self.buffers[i].add_value(self.average_store[i] / divisor);
274        }
275        self.reset_average();
276        true
277    }
278}
279
280/// TS port driver: standalone asyn PortDriver for time series waveforms.
281///
282/// Generic over the number of channels — Stats uses 23, ROIStat uses
283/// a different set, and NDTimeSeries standalone can use any count.
284pub struct TimeSeriesPortDriver {
285    base: PortDriverBase,
286    params: TSParams,
287    shared: Arc<Mutex<SharedTsState>>,
288    num_channels: usize,
289    time_per_point: f64,
290}
291
292impl TimeSeriesPortDriver {
293    fn new(
294        port_name: &str,
295        channel_names: &[&str],
296        num_points: usize,
297        shared: Arc<Mutex<SharedTsState>>,
298    ) -> Self {
299        let num_channels = channel_names.len();
300        let mut base = PortDriverBase::new(
301            port_name,
302            1,
303            PortFlags {
304                multi_device: false,
305                can_block: false,
306                destructible: true,
307            },
308        );
309
310        // NDPluginBase params (NDTimeSeries.template includes NDPluginBase.template)
311        let _ = ad_core_rs::params::ndarray_driver::NDArrayDriverParams::create(&mut base);
312        let _ = ad_core_rs::plugin::params::PluginBaseParams::create(&mut base);
313
314        // Register control params
315        let ts_acquire = base.create_param("TS_ACQUIRE", ParamType::Int32).unwrap();
316        let _ = base.set_int32_param(ts_acquire, 0, 0);
317        let ts_read = base.create_param("TS_READ", ParamType::Int32).unwrap();
318        let ts_num_points = base
319            .create_param("TS_NUM_POINTS", ParamType::Int32)
320            .unwrap();
321        let _ = base.set_int32_param(ts_num_points, 0, num_points as i32);
322        let ts_current_point = base
323            .create_param("TS_CURRENT_POINT", ParamType::Int32)
324            .unwrap();
325        let _ = base.set_int32_param(ts_current_point, 0, 0);
326        let ts_time_per_point = base
327            .create_param("TS_TIME_PER_POINT", ParamType::Float64)
328            .unwrap();
329        let ts_averaging_time = base
330            .create_param("TS_AVERAGING_TIME", ParamType::Float64)
331            .unwrap();
332        let ts_num_average = base
333            .create_param("TS_NUM_AVERAGE", ParamType::Int32)
334            .unwrap();
335        let _ = base.set_int32_param(ts_num_average, 0, 1);
336        let ts_elapsed_time = base
337            .create_param("TS_ELAPSED_TIME", ParamType::Float64)
338            .unwrap();
339        let ts_acquire_mode = base
340            .create_param("TS_ACQUIRE_MODE", ParamType::Int32)
341            .unwrap();
342        let _ = base.set_int32_param(ts_acquire_mode, 0, 0);
343        let ts_time_axis = base
344            .create_param("TS_TIME_AXIS", ParamType::Float64Array)
345            .unwrap();
346
347        // Initialize time axis (scaled by time_per_point, default 1.0)
348        let time_per_point = 1.0;
349        let time_axis: Vec<f64> = (0..num_points).map(|i| i as f64 * time_per_point).collect();
350        let _ = base.params.set_float64_array(ts_time_axis, 0, time_axis);
351
352        // Channel waveform params — one Float64Array per channel
353        let mut ts_channels = Vec::with_capacity(num_channels);
354        for name in channel_names {
355            let param_name = format!("TS_CHAN_{name}");
356            let idx = base
357                .create_param(&param_name, ParamType::Float64Array)
358                .unwrap();
359            let _ = base.params.set_float64_array(idx, 0, vec![0.0; num_points]);
360            ts_channels.push(idx);
361        }
362
363        // Generic time series and timestamp waveform params
364        let ts_time_series = base
365            .create_param("TS_TIME_SERIES", ParamType::Float64Array)
366            .unwrap();
367        let ts_timestamp = base
368            .create_param("TS_TIMESTAMP", ParamType::Float64Array)
369            .unwrap();
370
371        let params = TSParams {
372            ts_acquire,
373            ts_read,
374            ts_num_points,
375            ts_current_point,
376            ts_time_per_point,
377            ts_averaging_time,
378            ts_num_average,
379            ts_elapsed_time,
380            ts_acquire_mode,
381            ts_time_axis,
382            ts_channels,
383            channel_names: channel_names.iter().map(|s| s.to_string()).collect(),
384            ts_time_series,
385            ts_timestamp,
386        };
387
388        Self {
389            base,
390            params,
391            shared,
392            num_channels,
393            time_per_point,
394        }
395    }
396
397    /// Build the time axis for the current mode.
398    ///
399    /// Fixed (OneShot) mode uses an ascending axis `i * time_per_point`;
400    /// Circular (RingBuffer) mode uses a signed axis ending at 0, so the
401    /// most recent point is t=0 and older points are negative — C++
402    /// `createAxisArray`: `timeAxis_[i] = -(numTimePoints-1-i)*timePerPoint`.
403    fn build_time_axis(&self, num_points: usize, mode: TimeSeriesMode) -> Vec<f64> {
404        (0..num_points)
405            .map(|i| match mode {
406                TimeSeriesMode::OneShot => i as f64 * self.time_per_point,
407                TimeSeriesMode::RingBuffer => {
408                    -((num_points.saturating_sub(1) - i) as f64) * self.time_per_point
409                }
410            })
411            .collect()
412    }
413
414    /// Recompute and publish the time axis param for the current mode.
415    fn refresh_time_axis(&mut self) {
416        let (num_points, mode) = {
417            let s = self.shared.lock();
418            (s.num_points, s.mode)
419        };
420        let axis = self.build_time_axis(num_points, mode);
421        let _ = self
422            .base
423            .params
424            .set_float64_array(self.params.ts_time_axis, 0, axis);
425    }
426
427    /// Copy buffer data to Float64Array params and call callbacks.
428    fn update_waveform_params(&mut self) {
429        let state = self.shared.lock();
430        let num_points = state.num_points;
431
432        // Update per-channel waveform params
433        for (i, buf) in state.buffers.iter().enumerate() {
434            let mut values = buf.values();
435            values.resize(num_points, 0.0);
436            let _ = self
437                .base
438                .params
439                .set_float64_array(self.params.ts_channels[i], 0, values);
440        }
441
442        // Update current point
443        let current_point = state.buffers[0].count();
444        let _ = self
445            .base
446            .set_int32_param(self.params.ts_current_point, 0, current_point as i32);
447
448        // Update elapsed time
449        if let Some(start) = state.start_time {
450            let elapsed = start.elapsed().as_secs_f64();
451            let _ = self
452                .base
453                .set_float64_param(self.params.ts_elapsed_time, 0, elapsed);
454        }
455
456        // Update acquire status (may have auto-stopped)
457        let acquiring = state.acquiring;
458        drop(state);
459
460        let _ = self
461            .base
462            .set_int32_param(self.params.ts_acquire, 0, if acquiring { 1 } else { 0 });
463
464        // Notify listeners
465        let _ = self.base.call_param_callbacks(0);
466    }
467}
468
469impl PortDriver for TimeSeriesPortDriver {
470    fn base(&self) -> &PortDriverBase {
471        &self.base
472    }
473
474    fn base_mut(&mut self) -> &mut PortDriverBase {
475        &mut self.base
476    }
477
478    fn write_int32(&mut self, user: &mut AsynUser, value: i32) -> asyn_rs::error::AsynResult<()> {
479        let reason = user.reason;
480
481        if reason == self.params.ts_acquire {
482            let mut state = self.shared.lock();
483            if value != 0 {
484                // Start acquiring
485                if !state.acquiring {
486                    // If buffers are empty, this is Erase/Start
487                    if state.buffers[0].count() == 0 {
488                        for buf in state.buffers.iter_mut() {
489                            buf.reset();
490                        }
491                    }
492                    // Start always resets the running average accumulator
493                    // (C++ doTimeSeriesCallbacks / start path).
494                    state.reset_average();
495                    state.acquiring = true;
496                    state.start_time = Some(Instant::now());
497                }
498            } else {
499                // Stop
500                state.acquiring = false;
501            }
502            drop(state);
503            self.base.set_int32_param(reason, 0, value)?;
504            self.base.call_param_callbacks(0)?;
505        } else if reason == self.params.ts_read {
506            // Trigger waveform update
507            self.update_waveform_params();
508        } else if reason == self.params.ts_num_points {
509            let new_size = value.max(1) as usize;
510            {
511                let mut state = self.shared.lock();
512                state.num_points = new_size;
513                for buf in state.buffers.iter_mut() {
514                    buf.resize(new_size);
515                }
516                state.reset_average();
517                state.acquiring = false;
518            }
519
520            // Rebuild the time axis for the current mode.
521            self.refresh_time_axis();
522
523            // Re-initialize channel waveforms
524            for i in 0..self.num_channels {
525                let _ = self.base.params.set_float64_array(
526                    self.params.ts_channels[i],
527                    0,
528                    vec![0.0; new_size],
529                );
530            }
531
532            self.base.set_int32_param(reason, 0, value)?;
533            self.base
534                .set_int32_param(self.params.ts_current_point, 0, 0)?;
535            self.base.set_int32_param(self.params.ts_acquire, 0, 0)?;
536            self.base.call_param_callbacks(0)?;
537        } else if reason == self.params.ts_num_average {
538            // numAverage: input samples averaged per output time point
539            // (C++ P_TSNumAverage). Resets the running accumulator.
540            let n = value.max(1) as usize;
541            {
542                let mut state = self.shared.lock();
543                state.num_average = n;
544                state.reset_average();
545            }
546            self.base.set_int32_param(reason, 0, n as i32)?;
547            self.base.call_param_callbacks(0)?;
548        } else if reason == self.params.ts_acquire_mode {
549            // 0 == TSAcquireModeFixed (OneShot), 1 == TSAcquireModeCircular.
550            let mode = if value == 0 {
551                TimeSeriesMode::OneShot
552            } else {
553                TimeSeriesMode::RingBuffer
554            };
555            {
556                let mut state = self.shared.lock();
557                state.mode = mode;
558                for buf in state.buffers.iter_mut() {
559                    buf.set_mode(mode);
560                }
561                state.reset_average();
562                state.acquiring = false;
563            }
564            // Circular mode flips the time axis to a signed (ending-at-0) one.
565            self.refresh_time_axis();
566
567            self.base.set_int32_param(reason, 0, value)?;
568            self.base.set_int32_param(self.params.ts_acquire, 0, 0)?;
569            self.base.call_param_callbacks(0)?;
570        } else {
571            // Default: store in param cache
572            self.base.set_int32_param(reason, user.addr, value)?;
573            self.base.call_param_callbacks(user.addr)?;
574        }
575
576        Ok(())
577    }
578
579    fn write_float64(&mut self, user: &mut AsynUser, value: f64) -> asyn_rs::error::AsynResult<()> {
580        let reason = user.reason;
581        if reason == self.params.ts_time_per_point {
582            self.time_per_point = value;
583            self.base.set_float64_param(reason, user.addr, value)?;
584            // Rebuild the time axis with the new scaling for the current mode.
585            self.refresh_time_axis();
586            self.base.call_param_callbacks(user.addr)?;
587        } else {
588            self.base.set_float64_param(reason, user.addr, value)?;
589            self.base.call_param_callbacks(user.addr)?;
590        }
591        Ok(())
592    }
593
594    fn read_float64_array(
595        &mut self,
596        user: &AsynUser,
597        buf: &mut [f64],
598    ) -> asyn_rs::error::AsynResult<usize> {
599        let data = self.base.params.get_float64_array(user.reason, user.addr)?;
600        let n = data.len().min(buf.len());
601        buf[..n].copy_from_slice(&data[..n]);
602        Ok(n)
603    }
604}
605
606/// Background thread that receives data from a plugin and accumulates into
607/// shared buffers.
608///
609/// Each received `TimeSeriesData` is one input time point. `num_average`
610/// consecutive input points are averaged into one output time point
611/// (C++ `doAddToTimeSeriesT`). In `OneShot`/Fixed mode acquisition auto-stops
612/// once `num_points` output points exist; in `RingBuffer`/Circular mode the
613/// `TimeSeries` ring buffer wraps and acquisition continues.
614fn ts_data_thread(shared: Arc<Mutex<SharedTsState>>, mut data_rx: TimeSeriesReceiver) {
615    while let Some(data) = data_rx.blocking_recv() {
616        let mut state = shared.lock();
617        if !state.acquiring {
618            continue;
619        }
620        let emitted = state.accumulate(&data.values);
621        // Auto-stop for Fixed (OneShot) mode once num_points output points
622        // have been collected. Only check when an output point was emitted.
623        if emitted
624            && state.mode == TimeSeriesMode::OneShot
625            && state.buffers[0].count() >= state.num_points
626        {
627            state.acquiring = false;
628        }
629    }
630}
631
632/// Create a TS port runtime.
633///
634/// `channel_names` defines the number and names of time series channels.
635/// Returns the port runtime handle, the TS params (for building a registry),
636/// and thread join handles for the actor and data ingestion threads.
637pub fn create_ts_port_runtime(
638    port_name: &str,
639    channel_names: &[&str],
640    num_points: usize,
641    data_rx: TimeSeriesReceiver,
642) -> (
643    PortRuntimeHandle,
644    TSParams,
645    std::thread::JoinHandle<()>,
646    std::thread::JoinHandle<()>,
647) {
648    let num_channels = channel_names.len();
649    let shared = Arc::new(Mutex::new(SharedTsState::new(num_channels, num_points)));
650
651    let driver = TimeSeriesPortDriver::new(port_name, channel_names, num_points, shared.clone());
652
653    // Capture params before the driver is moved into the actor
654    let ts_params = TSParams {
655        ts_acquire: driver.params.ts_acquire,
656        ts_read: driver.params.ts_read,
657        ts_num_points: driver.params.ts_num_points,
658        ts_current_point: driver.params.ts_current_point,
659        ts_time_per_point: driver.params.ts_time_per_point,
660        ts_averaging_time: driver.params.ts_averaging_time,
661        ts_num_average: driver.params.ts_num_average,
662        ts_elapsed_time: driver.params.ts_elapsed_time,
663        ts_acquire_mode: driver.params.ts_acquire_mode,
664        ts_time_axis: driver.params.ts_time_axis,
665        ts_channels: driver.params.ts_channels.clone(),
666        channel_names: driver.params.channel_names.clone(),
667        ts_time_series: driver.params.ts_time_series,
668        ts_timestamp: driver.params.ts_timestamp,
669    };
670
671    let (runtime_handle, actor_jh) = create_port_runtime(driver, RuntimeConfig::default());
672
673    // Spawn data ingestion thread
674    let data_jh = std::thread::Builder::new()
675        .name(format!("ts-data-{port_name}"))
676        .spawn(move || {
677            ts_data_thread(shared, data_rx);
678        })
679        .expect("failed to spawn TS data thread");
680
681    (runtime_handle, ts_params, actor_jh, data_jh)
682}
683
684#[cfg(test)]
685mod tests {
686    use super::*;
687
688    #[test]
689    fn test_one_shot() {
690        let mut ts = TimeSeries::new(5, TimeSeriesMode::OneShot);
691        for i in 0..5 {
692            ts.add_value(i as f64);
693        }
694        assert_eq!(ts.count(), 5);
695        assert_eq!(ts.values(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
696
697        // Adding beyond capacity is a no-op
698        ts.add_value(99.0);
699        assert_eq!(ts.count(), 5);
700    }
701
702    #[test]
703    fn test_ring_buffer() {
704        let mut ts = TimeSeries::new(4, TimeSeriesMode::RingBuffer);
705        for i in 0..6 {
706            ts.add_value(i as f64);
707        }
708        assert_eq!(ts.count(), 4);
709        // Should contain [2, 3, 4, 5] in order
710        assert_eq!(ts.values(), vec![2.0, 3.0, 4.0, 5.0]);
711    }
712
713    #[test]
714    fn test_ring_buffer_partial() {
715        let mut ts = TimeSeries::new(4, TimeSeriesMode::RingBuffer);
716        ts.add_value(10.0);
717        ts.add_value(20.0);
718        assert_eq!(ts.count(), 2);
719        assert_eq!(ts.values(), vec![10.0, 20.0]);
720    }
721
722    #[test]
723    fn test_reset() {
724        let mut ts = TimeSeries::new(3, TimeSeriesMode::OneShot);
725        ts.add_value(1.0);
726        ts.add_value(2.0);
727        ts.reset();
728        assert_eq!(ts.count(), 0);
729        assert!(ts.values().is_empty());
730    }
731
732    #[test]
733    fn test_resize() {
734        let mut ts = TimeSeries::new(5, TimeSeriesMode::OneShot);
735        ts.add_value(1.0);
736        ts.add_value(2.0);
737        ts.resize(3);
738        assert_eq!(ts.num_points, 3);
739        assert_eq!(ts.count(), 0);
740        assert!(ts.values().is_empty());
741    }
742
743    #[test]
744    fn test_set_mode() {
745        let mut ts = TimeSeries::new(5, TimeSeriesMode::OneShot);
746        ts.add_value(1.0);
747        ts.set_mode(TimeSeriesMode::RingBuffer);
748        assert_eq!(ts.mode, TimeSeriesMode::RingBuffer);
749        assert_eq!(ts.count(), 0);
750    }
751
752    // --- TS port driver tests (using a small channel set for simplicity) ---
753
754    const TEST_CHANNELS: [&str; 3] = ["ChA", "ChB", "ChC"];
755
756    #[test]
757    fn test_shared_ts_state_init() {
758        let state = SharedTsState::new(3, 100);
759        assert_eq!(state.buffers.len(), 3);
760        assert_eq!(state.num_points, 100);
761        assert!(!state.acquiring);
762        assert_eq!(state.mode, TimeSeriesMode::OneShot);
763    }
764
765    #[test]
766    fn test_ts_port_driver_create() {
767        let shared = Arc::new(Mutex::new(SharedTsState::new(3, 100)));
768        let driver = TimeSeriesPortDriver::new("TEST_TS", &TEST_CHANNELS, 100, shared);
769        assert_eq!(driver.base().port_name, "TEST_TS");
770        assert_eq!(driver.num_channels, 3);
771        assert!(!driver.base().flags.multi_device);
772    }
773
774    #[test]
775    fn test_ts_port_driver_write_acquire() {
776        let shared = Arc::new(Mutex::new(SharedTsState::new(3, 100)));
777        let mut driver = TimeSeriesPortDriver::new("TEST_TS", &TEST_CHANNELS, 100, shared.clone());
778
779        // Start acquiring
780        let mut user = AsynUser::new(driver.params.ts_acquire);
781        driver.write_int32(&mut user, 1).unwrap();
782        assert!(shared.lock().acquiring);
783
784        // Stop acquiring
785        driver.write_int32(&mut user, 0).unwrap();
786        assert!(!shared.lock().acquiring);
787    }
788
789    #[test]
790    fn test_ts_port_driver_write_num_points() {
791        let shared = Arc::new(Mutex::new(SharedTsState::new(3, 100)));
792        let mut driver = TimeSeriesPortDriver::new("TEST_TS", &TEST_CHANNELS, 100, shared.clone());
793
794        let mut user = AsynUser::new(driver.params.ts_num_points);
795        driver.write_int32(&mut user, 50).unwrap();
796
797        let state = shared.lock();
798        assert_eq!(state.num_points, 50);
799        for buf in &state.buffers {
800            assert_eq!(buf.num_points, 50);
801        }
802    }
803
804    #[test]
805    fn test_ts_port_driver_write_mode() {
806        let shared = Arc::new(Mutex::new(SharedTsState::new(3, 100)));
807        let mut driver = TimeSeriesPortDriver::new("TEST_TS", &TEST_CHANNELS, 100, shared.clone());
808
809        let mut user = AsynUser::new(driver.params.ts_acquire_mode);
810        driver.write_int32(&mut user, 1).unwrap();
811
812        let state = shared.lock();
813        assert_eq!(state.mode, TimeSeriesMode::RingBuffer);
814        for buf in &state.buffers {
815            assert_eq!(buf.mode, TimeSeriesMode::RingBuffer);
816        }
817    }
818
819    #[test]
820    fn test_ts_port_driver_update_waveforms() {
821        let shared = Arc::new(Mutex::new(SharedTsState::new(3, 10)));
822        let mut driver = TimeSeriesPortDriver::new("TEST_TS", &TEST_CHANNELS, 10, shared.clone());
823
824        // Add some data
825        {
826            let mut state = shared.lock();
827            state.acquiring = true;
828            state.start_time = Some(Instant::now());
829            for buf in state.buffers.iter_mut() {
830                buf.add_value(42.0);
831                buf.add_value(43.0);
832            }
833        }
834
835        // Trigger update
836        driver.update_waveform_params();
837
838        // Check current point was updated
839        let cp = driver
840            .base
841            .get_int32_param(driver.params.ts_current_point, 0)
842            .unwrap();
843        assert_eq!(cp, 2);
844
845        // Check waveform data was written
846        let data = driver
847            .base
848            .params
849            .get_float64_array(driver.params.ts_channels[0], 0)
850            .unwrap();
851        assert_eq!(data[0], 42.0);
852        assert_eq!(data[1], 43.0);
853    }
854
855    #[test]
856    fn test_ts_port_driver_read_array() {
857        let shared = Arc::new(Mutex::new(SharedTsState::new(3, 5)));
858        let mut driver = TimeSeriesPortDriver::new("TEST_TS", &TEST_CHANNELS, 5, shared);
859
860        let user = AsynUser::new(driver.params.ts_time_axis);
861        let mut buf = vec![0.0; 5];
862        let n = driver.read_float64_array(&user, &mut buf).unwrap();
863        assert_eq!(n, 5);
864        assert_eq!(buf, vec![0.0, 1.0, 2.0, 3.0, 4.0]);
865    }
866
867    #[test]
868    fn test_ts_data_ingestion_oneshot() {
869        let shared = Arc::new(Mutex::new(SharedTsState::new(3, 3)));
870        let (tx, rx) = tokio::sync::mpsc::channel(16);
871
872        // Start acquiring
873        shared.lock().acquiring = true;
874
875        let shared_clone = shared.clone();
876        let jh = std::thread::spawn(move || ts_data_thread(shared_clone, rx));
877
878        // Send data
879        tx.blocking_send(TimeSeriesData {
880            values: vec![1.0, 10.0, 100.0],
881        })
882        .unwrap();
883        tx.blocking_send(TimeSeriesData {
884            values: vec![2.0, 20.0, 200.0],
885        })
886        .unwrap();
887        tx.blocking_send(TimeSeriesData {
888            values: vec![3.0, 30.0, 300.0],
889        })
890        .unwrap();
891        tx.blocking_send(TimeSeriesData {
892            values: vec![4.0, 40.0, 400.0],
893        })
894        .unwrap(); // beyond capacity
895
896        // Close channel and wait for thread
897        drop(tx);
898        jh.join().unwrap();
899
900        let state = shared.lock();
901        assert_eq!(state.buffers[0].count(), 3);
902        assert_eq!(state.buffers[0].values(), vec![1.0, 2.0, 3.0]);
903        assert_eq!(state.buffers[1].values(), vec![10.0, 20.0, 30.0]);
904        assert_eq!(state.buffers[2].values(), vec![100.0, 200.0, 300.0]);
905        assert!(!state.acquiring); // auto-stopped
906    }
907
908    #[test]
909    fn test_ts_data_ingestion_not_acquiring() {
910        let shared = Arc::new(Mutex::new(SharedTsState::new(3, 10)));
911        let (tx, rx) = tokio::sync::mpsc::channel(16);
912
913        // Not acquiring (default)
914        let shared_clone = shared.clone();
915        let jh = std::thread::spawn(move || ts_data_thread(shared_clone, rx));
916
917        tx.blocking_send(TimeSeriesData {
918            values: vec![1.0, 2.0, 3.0],
919        })
920        .unwrap();
921
922        drop(tx);
923        jh.join().unwrap();
924
925        let state = shared.lock();
926        assert_eq!(state.buffers[0].count(), 0);
927    }
928
929    #[test]
930    fn test_num_average_averages_input_samples() {
931        // numAverage = 3: every 3 input samples produce one averaged output
932        // point. Channel A inputs 0,1,2 -> mean 1; 3,4,5 -> mean 4.
933        let mut state = SharedTsState::new(1, 10);
934        state.num_average = 3;
935        assert!(!state.accumulate(&[0.0]));
936        assert!(!state.accumulate(&[1.0]));
937        assert!(state.accumulate(&[2.0])); // emits mean of 0,1,2 = 1
938        assert!(!state.accumulate(&[3.0]));
939        assert!(!state.accumulate(&[4.0]));
940        assert!(state.accumulate(&[5.0])); // emits mean of 3,4,5 = 4
941        let vals = state.buffers[0].values();
942        assert_eq!(vals.len(), 2);
943        assert!((vals[0] - 1.0).abs() < 1e-10);
944        assert!((vals[1] - 4.0).abs() < 1e-10);
945    }
946
947    #[test]
948    fn test_num_average_one_is_passthrough() {
949        // numAverage = 1: each input sample is one output point unchanged.
950        let mut state = SharedTsState::new(2, 10);
951        state.num_average = 1;
952        assert!(state.accumulate(&[5.0, 50.0]));
953        assert!(state.accumulate(&[6.0, 60.0]));
954        assert_eq!(state.buffers[0].values(), vec![5.0, 6.0]);
955        assert_eq!(state.buffers[1].values(), vec![50.0, 60.0]);
956    }
957
958    #[test]
959    fn test_num_average_drives_ingestion_thread() {
960        // The data thread must average numAverage=2 samples per output point.
961        let shared = Arc::new(Mutex::new(SharedTsState::new(1, 5)));
962        {
963            let mut s = shared.lock();
964            s.num_average = 2;
965            s.acquiring = true;
966        }
967        let (tx, rx) = tokio::sync::mpsc::channel(16);
968        let shared_clone = shared.clone();
969        let jh = std::thread::spawn(move || ts_data_thread(shared_clone, rx));
970
971        for v in [10.0, 20.0, 30.0, 40.0] {
972            tx.blocking_send(TimeSeriesData { values: vec![v] })
973                .unwrap();
974        }
975        drop(tx);
976        jh.join().unwrap();
977
978        // 4 input samples / numAverage 2 -> 2 output points: 15, 35.
979        let state = shared.lock();
980        let vals = state.buffers[0].values();
981        assert_eq!(vals.len(), 2);
982        assert!((vals[0] - 15.0).abs() < 1e-10);
983        assert!((vals[1] - 35.0).abs() < 1e-10);
984    }
985
986    #[test]
987    fn test_fixed_mode_stops_at_num_points() {
988        // Fixed (OneShot) mode: acquisition auto-stops once num_points output
989        // points are collected, even with more input pending.
990        let shared = Arc::new(Mutex::new(SharedTsState::new(1, 3)));
991        {
992            let mut s = shared.lock();
993            s.num_average = 1;
994            s.mode = TimeSeriesMode::OneShot;
995            s.acquiring = true;
996        }
997        let (tx, rx) = tokio::sync::mpsc::channel(16);
998        let shared_clone = shared.clone();
999        let jh = std::thread::spawn(move || ts_data_thread(shared_clone, rx));
1000        for v in [1.0, 2.0, 3.0, 4.0, 5.0] {
1001            tx.blocking_send(TimeSeriesData { values: vec![v] })
1002                .unwrap();
1003        }
1004        drop(tx);
1005        jh.join().unwrap();
1006
1007        let state = shared.lock();
1008        assert!(!state.acquiring, "Fixed mode must auto-stop");
1009        assert_eq!(state.buffers[0].count(), 3);
1010        assert_eq!(state.buffers[0].values(), vec![1.0, 2.0, 3.0]);
1011    }
1012
1013    #[test]
1014    fn test_circular_mode_wraps_and_keeps_acquiring() {
1015        // Circular (RingBuffer) mode: the buffer wraps; acquisition does not
1016        // auto-stop.
1017        let shared = Arc::new(Mutex::new(SharedTsState::new(1, 3)));
1018        {
1019            let mut s = shared.lock();
1020            s.num_average = 1;
1021            s.mode = TimeSeriesMode::RingBuffer;
1022            for buf in s.buffers.iter_mut() {
1023                buf.set_mode(TimeSeriesMode::RingBuffer);
1024            }
1025            s.acquiring = true;
1026        }
1027        let (tx, rx) = tokio::sync::mpsc::channel(16);
1028        let shared_clone = shared.clone();
1029        let jh = std::thread::spawn(move || ts_data_thread(shared_clone, rx));
1030        for v in [1.0, 2.0, 3.0, 4.0, 5.0] {
1031            tx.blocking_send(TimeSeriesData { values: vec![v] })
1032                .unwrap();
1033        }
1034        drop(tx);
1035        jh.join().unwrap();
1036
1037        let state = shared.lock();
1038        assert!(state.acquiring, "Circular mode must keep acquiring");
1039        // Ring buffer of 3 keeps the last 3 points: 3,4,5.
1040        assert_eq!(state.buffers[0].values(), vec![3.0, 4.0, 5.0]);
1041    }
1042
1043    #[test]
1044    fn test_acquire_mode_param_drives_behavior_and_axis() {
1045        // Writing TS_ACQUIRE_MODE must switch the buffer mode AND flip the
1046        // time axis from ascending (Fixed) to signed-ending-at-0 (Circular).
1047        let shared = Arc::new(Mutex::new(SharedTsState::new(1, 4)));
1048        let mut driver = TimeSeriesPortDriver::new("TEST_TS_MODE", &["Ch0"], 4, shared.clone());
1049
1050        // Fixed mode axis: 0, 1, 2, 3.
1051        let axis = driver
1052            .base
1053            .params
1054            .get_float64_array(driver.params.ts_time_axis, 0)
1055            .unwrap();
1056        assert_eq!(&*axis, &[0.0, 1.0, 2.0, 3.0]);
1057
1058        // Switch to Circular mode.
1059        let mut user = AsynUser::new(driver.params.ts_acquire_mode);
1060        driver.write_int32(&mut user, 1).unwrap();
1061        assert_eq!(shared.lock().mode, TimeSeriesMode::RingBuffer);
1062
1063        // Circular axis: -3, -2, -1, 0 (most recent point is t=0).
1064        let axis = driver
1065            .base
1066            .params
1067            .get_float64_array(driver.params.ts_time_axis, 0)
1068            .unwrap();
1069        assert_eq!(&*axis, &[-3.0, -2.0, -1.0, 0.0]);
1070    }
1071
1072    #[test]
1073    fn test_num_average_param_drives_state() {
1074        // Writing TS_NUM_AVERAGE must update SharedTsState::num_average.
1075        let shared = Arc::new(Mutex::new(SharedTsState::new(1, 10)));
1076        let mut driver = TimeSeriesPortDriver::new("TEST_TS_NAVG", &["Ch0"], 10, shared.clone());
1077        let mut user = AsynUser::new(driver.params.ts_num_average);
1078        driver.write_int32(&mut user, 5).unwrap();
1079        assert_eq!(shared.lock().num_average, 5);
1080        // A value of 0 is clamped to 1.
1081        driver.write_int32(&mut user, 0).unwrap();
1082        assert_eq!(shared.lock().num_average, 1);
1083    }
1084
1085    #[test]
1086    fn test_create_ts_port_runtime() {
1087        let (_tx, rx) = tokio::sync::mpsc::channel(16);
1088        let (handle, params, _actor_jh, _data_jh) =
1089            create_ts_port_runtime("TEST_TS_RT", &TEST_CHANNELS, 100, rx);
1090        assert_eq!(handle.port_name(), "TEST_TS_RT");
1091        assert_eq!(params.ts_channels.len(), 3);
1092        handle.shutdown();
1093    }
1094}