Skip to main content

ad_core/driver/
ndarray_driver.rs

1use std::path::Path;
2use std::sync::Arc;
3
4use asyn_rs::error::AsynResult;
5use asyn_rs::port::{PortDriverBase, PortFlags};
6
7use crate::ndarray::NDArray;
8use crate::ndarray_pool::NDArrayPool;
9use crate::params::ndarray_driver::NDArrayDriverParams;
10use crate::plugin::channel::{NDArrayOutput, NDArraySender, QueuedArrayCounter};
11
12/// Base state for asynNDArrayDriver (file handling, attribute mgmt, pool).
13pub struct NDArrayDriverBase {
14    pub port_base: PortDriverBase,
15    pub params: NDArrayDriverParams,
16    pub pool: Arc<NDArrayPool>,
17    pub array_output: NDArrayOutput,
18    pub queued_counter: Arc<QueuedArrayCounter>,
19}
20
21impl NDArrayDriverBase {
22    pub fn new(port_name: &str, max_memory: usize) -> AsynResult<Self> {
23        let mut port_base = PortDriverBase::new(
24            port_name,
25            1,
26            PortFlags {
27                can_block: true,
28                ..Default::default()
29            },
30        );
31
32        let params = NDArrayDriverParams::create(&mut port_base)?;
33
34        port_base.set_int32_param(params.array_callbacks, 0, 1)?;
35        port_base.set_float64_param(params.pool_max_memory, 0, max_memory as f64 / 1_048_576.0)?;
36
37        let pool = Arc::new(NDArrayPool::new(max_memory));
38
39        Ok(Self {
40            port_base,
41            params,
42            pool,
43            array_output: NDArrayOutput::new(),
44            queued_counter: Arc::new(QueuedArrayCounter::new()),
45        })
46    }
47
48    /// Connect a downstream channel-based receiver.
49    pub fn connect_downstream(&mut self, mut sender: NDArraySender) {
50        sender.set_queued_counter(self.queued_counter.clone());
51        self.array_output.add(sender);
52    }
53
54    /// Number of connected downstream channels.
55    pub fn num_plugins(&self) -> usize {
56        self.array_output.num_senders()
57    }
58
59    /// Publish an array: update counters, push to plugins and channel outputs.
60    pub fn publish_array(&mut self, array: Arc<NDArray>) -> AsynResult<()> {
61        let counter = self.port_base.get_int32_param(self.params.array_counter, 0)? + 1;
62        self.port_base
63            .set_int32_param(self.params.array_counter, 0, counter)?;
64
65        let info = array.info();
66        self.port_base
67            .set_int32_param(self.params.array_size_x, 0, info.x_size as i32)?;
68        self.port_base
69            .set_int32_param(self.params.array_size_y, 0, info.y_size as i32)?;
70        self.port_base
71            .set_int32_param(self.params.array_size_z, 0, info.color_size as i32)?;
72        self.port_base
73            .set_int32_param(self.params.array_size, 0, info.total_bytes as i32)?;
74        self.port_base
75            .set_int32_param(self.params.unique_id, 0, array.unique_id)?;
76
77        // Update pool stats
78        self.port_base.set_float64_param(
79            self.params.pool_used_memory,
80            0,
81            self.pool.allocated_bytes() as f64 / 1_048_576.0,
82        )?;
83        self.port_base.set_int32_param(
84            self.params.pool_free_buffers,
85            0,
86            self.pool.num_free_buffers() as i32,
87        )?;
88        self.port_base.set_int32_param(
89            self.params.pool_alloc_buffers,
90            0,
91            self.pool.num_alloc_buffers() as i32,
92        )?;
93
94        let callbacks_enabled =
95            self.port_base
96                .get_int32_param(self.params.array_callbacks, 0)?
97                != 0;
98
99        if callbacks_enabled {
100            self.port_base.set_generic_pointer_param(
101                self.params.ndarray_data,
102                0,
103                array.clone() as Arc<dyn std::any::Any + Send + Sync>,
104            )?;
105
106            self.array_output.publish(array);
107        }
108
109        self.port_base.call_param_callbacks(0)?;
110
111        Ok(())
112    }
113
114    /// Construct a file path from template, path, name, and number.
115    pub fn create_file_name(&mut self) -> AsynResult<String> {
116        let path = self.port_base.get_string_param(self.params.file_path, 0)?;
117        let name = self.port_base.get_string_param(self.params.file_name, 0)?;
118        let number = self.port_base.get_int32_param(self.params.file_number, 0)?;
119        let template = self.port_base.get_string_param(self.params.file_template, 0)?;
120
121        let full = if template.is_empty() {
122            format!("{}{}{:04}", path, name, number)
123        } else {
124            // Simple template: replace %s with path+name, %d with number
125            template
126                .replace("%s%s", &format!("{}{}", path, name))
127                .replace("%d", &number.to_string())
128        };
129
130        self.port_base
131            .set_string_param(self.params.full_file_name, 0, full.clone())?;
132
133        Ok(full)
134    }
135
136    /// Check if the file path directory exists.
137    pub fn check_path(&mut self) -> AsynResult<bool> {
138        let path = self.port_base.get_string_param(self.params.file_path, 0)?;
139        let exists = Path::new(&path).is_dir();
140        self.port_base
141            .set_int32_param(self.params.file_path_exists, 0, exists as i32)?;
142        Ok(exists)
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDDataType, NDDimension};
    use crate::plugin::channel::ndarray_channel;

    /// Pool memory limit, in bytes, used by every test driver.
    const POOL_BYTES: usize = 1_000_000;

    /// Build a fresh driver base on port "TEST".
    fn test_driver() -> NDArrayDriverBase {
        NDArrayDriverBase::new("TEST", POOL_BYTES).unwrap()
    }

    /// The platform temporary directory as a string, so the path tests do not
    /// depend on a Unix-only `/tmp`.
    fn temp_dir_string() -> String {
        std::env::temp_dir().to_string_lossy().into_owned()
    }

    #[test]
    fn test_new_sets_callbacks_enabled() {
        let drv = test_driver();
        assert_eq!(
            drv.port_base
                .get_int32_param(drv.params.array_callbacks, 0)
                .unwrap(),
            1,
        );
    }

    #[test]
    fn test_publish_array() {
        let mut drv = test_driver();
        let arr = drv
            .pool
            .alloc(
                vec![NDDimension::new(64), NDDimension::new(64)],
                NDDataType::UInt8,
            )
            .unwrap();
        drv.publish_array(Arc::new(arr)).unwrap();
        assert_eq!(
            drv.port_base
                .get_int32_param(drv.params.array_counter, 0)
                .unwrap(),
            1,
        );
    }

    #[test]
    fn test_publish_updates_size_info() {
        let mut drv = test_driver();
        let arr = drv
            .pool
            .alloc(
                vec![NDDimension::new(320), NDDimension::new(240)],
                NDDataType::UInt16,
            )
            .unwrap();
        drv.publish_array(Arc::new(arr)).unwrap();
        assert_eq!(
            drv.port_base
                .get_int32_param(drv.params.array_size_x, 0)
                .unwrap(),
            320,
        );
        assert_eq!(
            drv.port_base
                .get_int32_param(drv.params.array_size_y, 0)
                .unwrap(),
            240,
        );
    }

    #[test]
    fn test_create_file_name_default() {
        let mut drv = test_driver();
        // The path is only used as text here; nothing touches the filesystem.
        drv.port_base
            .set_string_param(drv.params.file_path, 0, "/tmp/".into())
            .unwrap();
        drv.port_base
            .set_string_param(drv.params.file_name, 0, "test_".into())
            .unwrap();
        drv.port_base
            .set_int32_param(drv.params.file_number, 0, 42)
            .unwrap();
        drv.port_base
            .set_string_param(drv.params.file_template, 0, "".into())
            .unwrap();

        let name = drv.create_file_name().unwrap();
        assert_eq!(name, "/tmp/test_0042");
    }

    #[test]
    fn test_create_file_name_template() {
        let mut drv = test_driver();
        drv.port_base
            .set_string_param(drv.params.file_path, 0, "/tmp/".into())
            .unwrap();
        drv.port_base
            .set_string_param(drv.params.file_name, 0, "test_".into())
            .unwrap();
        drv.port_base
            .set_int32_param(drv.params.file_number, 0, 42)
            .unwrap();
        drv.port_base
            .set_string_param(drv.params.file_template, 0, "%s%s%d".into())
            .unwrap();

        let name = drv.create_file_name().unwrap();
        assert_eq!(name, "/tmp/test_42");
        // The generated name is also stored in the full_file_name parameter.
        assert_eq!(
            drv.port_base
                .get_string_param(drv.params.full_file_name, 0)
                .unwrap(),
            name,
        );
    }

    #[test]
    fn test_check_path_exists() {
        let mut drv = test_driver();
        drv.port_base
            .set_string_param(drv.params.file_path, 0, temp_dir_string())
            .unwrap();
        assert!(drv.check_path().unwrap());
        assert_eq!(
            drv.port_base
                .get_int32_param(drv.params.file_path_exists, 0)
                .unwrap(),
            1,
        );
    }

    #[test]
    fn test_check_path_not_exists() {
        let mut drv = test_driver();
        drv.port_base
            .set_string_param(drv.params.file_path, 0, "/nonexistent_path_xyz".into())
            .unwrap();
        assert!(!drv.check_path().unwrap());
        assert_eq!(
            drv.port_base
                .get_int32_param(drv.params.file_path_exists, 0)
                .unwrap(),
            0,
        );
    }

    #[test]
    fn test_connect_downstream() {
        let mut drv = test_driver();
        let (sender, mut receiver) = ndarray_channel("DOWNSTREAM", 10);
        drv.connect_downstream(sender);
        assert_eq!(drv.num_plugins(), 1);

        let arr = drv
            .pool
            .alloc(vec![NDDimension::new(8)], NDDataType::UInt8)
            .unwrap();
        let id = arr.unique_id;
        drv.publish_array(Arc::new(arr)).unwrap();

        let received = receiver.blocking_recv().unwrap();
        assert_eq!(received.unique_id, id);
    }
}