1use std::path::Path;
2use std::sync::Arc;
3
4use asyn_rs::error::AsynResult;
5use asyn_rs::port::{PortDriverBase, PortFlags};
6
7use crate::ndarray::NDArray;
8use crate::ndarray_pool::NDArrayPool;
9use crate::params::ndarray_driver::NDArrayDriverParams;
10use crate::plugin::channel::{NDArrayOutput, NDArraySender, QueuedArrayCounter};
11
12pub struct NDArrayDriverBase {
14 pub port_base: PortDriverBase,
15 pub params: NDArrayDriverParams,
16 pub pool: Arc<NDArrayPool>,
17 pub array_output: NDArrayOutput,
18 pub queued_counter: Arc<QueuedArrayCounter>,
19}
20
21impl NDArrayDriverBase {
22 pub fn new(port_name: &str, max_memory: usize) -> AsynResult<Self> {
23 let mut port_base = PortDriverBase::new(
24 port_name,
25 1,
26 PortFlags {
27 can_block: true,
28 ..Default::default()
29 },
30 );
31
32 let params = NDArrayDriverParams::create(&mut port_base)?;
33
34 port_base.set_int32_param(params.array_callbacks, 0, 1)?;
35 port_base.set_float64_param(params.pool_max_memory, 0, max_memory as f64 / 1_048_576.0)?;
36
37 let pool = Arc::new(NDArrayPool::new(max_memory));
38
39 Ok(Self {
40 port_base,
41 params,
42 pool,
43 array_output: NDArrayOutput::new(),
44 queued_counter: Arc::new(QueuedArrayCounter::new()),
45 })
46 }
47
48 pub fn connect_downstream(&mut self, mut sender: NDArraySender) {
50 sender.set_queued_counter(self.queued_counter.clone());
51 self.array_output.add(sender);
52 }
53
54 pub fn num_plugins(&self) -> usize {
56 self.array_output.num_senders()
57 }
58
59 pub fn publish_array(&mut self, array: Arc<NDArray>) -> AsynResult<()> {
61 let counter = self.port_base.get_int32_param(self.params.array_counter, 0)? + 1;
62 self.port_base
63 .set_int32_param(self.params.array_counter, 0, counter)?;
64
65 let info = array.info();
66 self.port_base
67 .set_int32_param(self.params.array_size_x, 0, info.x_size as i32)?;
68 self.port_base
69 .set_int32_param(self.params.array_size_y, 0, info.y_size as i32)?;
70 self.port_base
71 .set_int32_param(self.params.array_size_z, 0, info.color_size as i32)?;
72 self.port_base
73 .set_int32_param(self.params.array_size, 0, info.total_bytes as i32)?;
74 self.port_base
75 .set_int32_param(self.params.unique_id, 0, array.unique_id)?;
76
77 self.port_base.set_float64_param(
79 self.params.pool_used_memory,
80 0,
81 self.pool.allocated_bytes() as f64 / 1_048_576.0,
82 )?;
83 self.port_base.set_int32_param(
84 self.params.pool_free_buffers,
85 0,
86 self.pool.num_free_buffers() as i32,
87 )?;
88 self.port_base.set_int32_param(
89 self.params.pool_alloc_buffers,
90 0,
91 self.pool.num_alloc_buffers() as i32,
92 )?;
93
94 let callbacks_enabled =
95 self.port_base
96 .get_int32_param(self.params.array_callbacks, 0)?
97 != 0;
98
99 if callbacks_enabled {
100 self.port_base.set_generic_pointer_param(
101 self.params.ndarray_data,
102 0,
103 array.clone() as Arc<dyn std::any::Any + Send + Sync>,
104 )?;
105
106 self.array_output.publish(array);
107 }
108
109 self.port_base.call_param_callbacks(0)?;
110
111 Ok(())
112 }
113
114 pub fn create_file_name(&mut self) -> AsynResult<String> {
116 let path = self.port_base.get_string_param(self.params.file_path, 0)?;
117 let name = self.port_base.get_string_param(self.params.file_name, 0)?;
118 let number = self.port_base.get_int32_param(self.params.file_number, 0)?;
119 let template = self.port_base.get_string_param(self.params.file_template, 0)?;
120
121 let full = if template.is_empty() {
122 format!("{}{}{:04}", path, name, number)
123 } else {
124 template
126 .replace("%s%s", &format!("{}{}", path, name))
127 .replace("%d", &number.to_string())
128 };
129
130 self.port_base
131 .set_string_param(self.params.full_file_name, 0, full.clone())?;
132
133 Ok(full)
134 }
135
136 pub fn check_path(&mut self) -> AsynResult<bool> {
138 let path = self.port_base.get_string_param(self.params.file_path, 0)?;
139 let exists = Path::new(&path).is_dir();
140 self.port_base
141 .set_int32_param(self.params.file_path_exists, 0, exists as i32)?;
142 Ok(exists)
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDDataType, NDDimension};
    use crate::plugin::channel::ndarray_channel;

    /// Fixture shared by every test: a driver on port "TEST" with a 1 MB pool.
    fn make_driver() -> NDArrayDriverBase {
        NDArrayDriverBase::new("TEST", 1_000_000).unwrap()
    }

    #[test]
    fn test_new_sets_callbacks_enabled() {
        let driver = make_driver();
        let callbacks = driver
            .port_base
            .get_int32_param(driver.params.array_callbacks, 0)
            .unwrap();
        assert_eq!(callbacks, 1);
    }

    #[test]
    fn test_publish_array() {
        let mut driver = make_driver();
        let dims = vec![NDDimension::new(64), NDDimension::new(64)];
        let array = driver.pool.alloc(dims, NDDataType::UInt8).unwrap();

        driver.publish_array(Arc::new(array)).unwrap();

        let counter = driver
            .port_base
            .get_int32_param(driver.params.array_counter, 0)
            .unwrap();
        assert_eq!(counter, 1);
    }

    #[test]
    fn test_publish_updates_size_info() {
        let mut driver = make_driver();
        let dims = vec![NDDimension::new(320), NDDimension::new(240)];
        let array = driver.pool.alloc(dims, NDDataType::UInt16).unwrap();

        driver.publish_array(Arc::new(array)).unwrap();

        let size_x = driver
            .port_base
            .get_int32_param(driver.params.array_size_x, 0)
            .unwrap();
        assert_eq!(size_x, 320);
        let size_y = driver
            .port_base
            .get_int32_param(driver.params.array_size_y, 0)
            .unwrap();
        assert_eq!(size_y, 240);
    }

    #[test]
    fn test_create_file_name_default() {
        let mut driver = make_driver();
        driver
            .port_base
            .set_string_param(driver.params.file_path, 0, "/tmp/".into())
            .unwrap();
        driver
            .port_base
            .set_string_param(driver.params.file_name, 0, "test_".into())
            .unwrap();
        driver
            .port_base
            .set_int32_param(driver.params.file_number, 0, 42)
            .unwrap();
        driver
            .port_base
            .set_string_param(driver.params.file_template, 0, "".into())
            .unwrap();

        let full_name = driver.create_file_name().unwrap();
        assert_eq!(full_name, "/tmp/test_0042");
    }

    #[test]
    fn test_check_path_exists() {
        let mut driver = make_driver();
        driver
            .port_base
            .set_string_param(driver.params.file_path, 0, "/tmp".into())
            .unwrap();
        assert!(driver.check_path().unwrap());
    }

    #[test]
    fn test_check_path_not_exists() {
        let mut driver = make_driver();
        driver
            .port_base
            .set_string_param(driver.params.file_path, 0, "/nonexistent_path_xyz".into())
            .unwrap();
        assert!(!driver.check_path().unwrap());
    }

    #[test]
    fn test_connect_downstream() {
        let mut driver = make_driver();
        let (sender, mut receiver) = ndarray_channel("DOWNSTREAM", 10);
        driver.connect_downstream(sender);
        assert_eq!(driver.num_plugins(), 1);

        let array = driver
            .pool
            .alloc(vec![NDDimension::new(8)], NDDataType::UInt8)
            .unwrap();
        let expected_id = array.unique_id;
        driver.publish_array(Arc::new(array)).unwrap();

        let received = receiver.blocking_recv().unwrap();
        assert_eq!(received.unique_id, expected_id);
    }
}