1use std::sync::Arc;
2
3use asyn_rs::error::AsynResult;
4use asyn_rs::port::{PortDriverBase, PortFlags};
5
6use crate::color::NDColorMode;
7use crate::ndarray::NDArray;
8use crate::ndarray_pool::NDArrayPool;
9use crate::params::ad_driver::ADDriverParams;
10use crate::plugin::channel::{NDArrayOutput, NDArraySender, QueuedArrayCounter};
11
12use super::{ADStatus, ImageMode, ShutterMode};
13
14pub struct ADDriverBase {
16 pub port_base: PortDriverBase,
17 pub params: ADDriverParams,
18 pub pool: Arc<NDArrayPool>,
19 pub array_output: NDArrayOutput,
20 pub queued_counter: Arc<QueuedArrayCounter>,
21}
22
23impl ADDriverBase {
24 pub fn new(
25 port_name: &str,
26 max_size_x: i32,
27 max_size_y: i32,
28 max_memory: usize,
29 ) -> AsynResult<Self> {
30 let mut port_base = PortDriverBase::new(
31 port_name,
32 1,
33 PortFlags {
34 can_block: true,
35 ..Default::default()
36 },
37 );
38
39 let params = ADDriverParams::create(&mut port_base)?;
40
41 port_base.set_string_param(params.base.port_name_self, 0, port_name.into())?;
44 port_base.set_string_param(params.base.ad_core_version, 0, env!("CARGO_PKG_VERSION").into())?;
45 port_base.set_string_param(params.base.driver_version, 0, "0.0.0".into())?;
46 port_base.set_string_param(params.base.codec, 0, String::new())?;
47
48 port_base.set_int32_param(params.max_size_x, 0, max_size_x)?;
49 port_base.set_int32_param(params.max_size_y, 0, max_size_y)?;
50 port_base.set_int32_param(params.size_x, 0, max_size_x)?;
51 port_base.set_int32_param(params.size_y, 0, max_size_y)?;
52 port_base.set_int32_param(params.bin_x, 0, 1)?;
53 port_base.set_int32_param(params.bin_y, 0, 1)?;
54 port_base.set_int32_param(params.image_mode, 0, ImageMode::Single as i32)?;
55 port_base.set_int32_param(params.num_images, 0, 1)?;
56 port_base.set_int32_param(params.num_exposures, 0, 1)?;
57 port_base.set_float64_param(params.acquire_time, 0, 1.0)?;
58 port_base.set_float64_param(params.acquire_period, 0, 1.0)?;
59 port_base.set_int32_param(params.status, 0, ADStatus::Idle as i32)?;
60 port_base.set_string_param(params.status_message, 0, "Idle".into())?;
61 port_base.set_int32_param(params.base.data_type, 0, 1)?; port_base.set_int32_param(params.base.color_mode, 0, NDColorMode::Mono as i32)?;
63 port_base.set_int32_param(params.base.array_callbacks, 0, 1)?;
64 port_base.set_float64_param(params.base.pool_max_memory, 0, max_memory as f64 / 1_048_576.0)?;
65 port_base.set_int32_param(params.base.array_size_x, 0, max_size_x)?;
67 port_base.set_int32_param(params.base.array_size_y, 0, max_size_y)?;
68 port_base.set_int32_param(params.base.array_size_z, 0, 0)?;
69 let initial_array_bytes = max_size_x as i64 * max_size_y as i64; port_base.set_int32_param(params.base.array_size, 0, initial_array_bytes as i32)?;
71
72 port_base.set_float64_param(params.gain, 0, 1.0)?;
73 port_base.set_int32_param(params.shutter_mode, 0, ShutterMode::None as i32)?;
74 port_base.set_float64_param(params.temperature, 0, 25.0)?;
75 port_base.set_float64_param(params.temperature_actual, 0, 25.0)?;
76
77 let pool = Arc::new(NDArrayPool::new(max_memory));
78
79 Ok(Self {
80 port_base,
81 params,
82 pool,
83 array_output: NDArrayOutput::new(),
84 queued_counter: Arc::new(QueuedArrayCounter::new()),
85 })
86 }
87
88 pub fn connect_downstream(&mut self, mut sender: NDArraySender) {
90 sender.set_queued_counter(self.queued_counter.clone());
91 self.array_output.add(sender);
92 }
93
94 pub fn publish_array(&mut self, array: Arc<NDArray>) -> AsynResult<()> {
96 let counter = self.port_base.get_int32_param(self.params.base.array_counter, 0)? + 1;
97 self.port_base
98 .set_int32_param(self.params.base.array_counter, 0, counter)?;
99
100 let info = array.info();
101 self.port_base
102 .set_int32_param(self.params.base.array_size_x, 0, info.x_size as i32)?;
103 self.port_base
104 .set_int32_param(self.params.base.array_size_y, 0, info.y_size as i32)?;
105 self.port_base
106 .set_int32_param(self.params.base.array_size_z, 0, info.color_size as i32)?;
107 self.port_base
108 .set_int32_param(self.params.base.array_size, 0, info.total_bytes as i32)?;
109
110 self.port_base.set_float64_param(
111 self.params.base.pool_used_memory,
112 0,
113 self.pool.allocated_bytes() as f64 / 1_048_576.0,
114 )?;
115
116 let callbacks_enabled =
117 self.port_base
118 .get_int32_param(self.params.base.array_callbacks, 0)?
119 != 0;
120
121 if callbacks_enabled {
122 self.port_base.set_generic_pointer_param(
123 self.params.base.ndarray_data,
124 0,
125 array.clone() as Arc<dyn std::any::Any + Send + Sync>,
126 )?;
127
128 self.array_output.publish(array);
129 }
130
131 self.port_base.call_param_callbacks(0)?;
132
133 Ok(())
134 }
135
136 pub fn set_shutter(&mut self, open: bool) -> AsynResult<()> {
138 let mode = ShutterMode::from_i32(
139 self.port_base.get_int32_param(self.params.shutter_mode, 0)?,
140 );
141
142 match mode {
143 ShutterMode::None => {}
144 ShutterMode::DetectorOnly | ShutterMode::EpicsAndDetector => {
145 self.port_base.set_int32_param(
146 self.params.shutter_control,
147 0,
148 if open { 1 } else { 0 },
149 )?;
150 }
151 ShutterMode::EpicsOnly => {
152 self.port_base.set_int32_param(
153 self.params.shutter_control_epics,
154 0,
155 if open { 1 } else { 0 },
156 )?;
157 }
158 }
159
160 self.port_base.set_int32_param(
161 self.params.shutter_status,
162 0,
163 if open { 1 } else { 0 },
164 )?;
165
166 Ok(())
167 }
168}
169
170pub trait ADDriver: asyn_rs::port::PortDriver {
172 fn ad_base(&self) -> &ADDriverBase;
173 fn ad_base_mut(&mut self) -> &mut ADDriverBase;
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDDataType, NDDimension};
    use crate::plugin::channel::ndarray_channel;

    /// Allocates a `UInt8` array with the given dimensions from the driver's
    /// pool, panicking if the allocation fails.
    fn alloc_u8(ad: &ADDriverBase, dims: Vec<NDDimension>) -> NDArray {
        ad.pool.alloc(dims, NDDataType::UInt8).unwrap()
    }

    /// The constructor mirrors the sensor size into the size parameters and
    /// starts in the idle state.
    #[test]
    fn test_new_sets_initial_params() {
        let ad = ADDriverBase::new("TEST", 1024, 768, 50_000_000).unwrap();

        let max_x = ad.port_base.get_int32_param(ad.params.max_size_x, 0).unwrap();
        assert_eq!(max_x, 1024);

        let max_y = ad.port_base.get_int32_param(ad.params.max_size_y, 0).unwrap();
        assert_eq!(max_y, 768);

        let size_x = ad.port_base.get_int32_param(ad.params.size_x, 0).unwrap();
        assert_eq!(size_x, 1024);

        let size_y = ad.port_base.get_int32_param(ad.params.size_y, 0).unwrap();
        assert_eq!(size_y, 768);

        let status = ad.port_base.get_int32_param(ad.params.status, 0).unwrap();
        assert_eq!(status, ADStatus::Idle as i32);
    }

    /// Publishing one array moves the array counter to 1.
    #[test]
    fn test_publish_array_increments_counter() {
        let mut ad = ADDriverBase::new("TEST", 256, 256, 50_000_000).unwrap();
        let frame = alloc_u8(&ad, vec![NDDimension::new(256), NDDimension::new(256)]);

        ad.publish_array(Arc::new(frame)).unwrap();

        let counter = ad
            .port_base
            .get_int32_param(ad.params.base.array_counter, 0)
            .unwrap();
        assert_eq!(counter, 1);
    }

    /// With array callbacks switched off the counter still advances, but the
    /// generic-pointer parameter does not receive the array.
    #[test]
    fn test_publish_array_skips_output_when_callbacks_disabled() {
        let mut ad = ADDriverBase::new("TEST", 64, 64, 1_000_000).unwrap();
        let (sender, _receiver) = ndarray_channel("DOWNSTREAM", 10);
        ad.connect_downstream(sender);

        ad.port_base
            .set_int32_param(ad.params.base.array_callbacks, 0, 0)
            .unwrap();

        let frame = alloc_u8(&ad, vec![NDDimension::new(64), NDDimension::new(64)]);
        ad.publish_array(Arc::new(frame)).unwrap();

        let counter = ad
            .port_base
            .get_int32_param(ad.params.base.array_counter, 0)
            .unwrap();
        assert_eq!(counter, 1);

        let gp = ad
            .port_base
            .get_generic_pointer_param(ad.params.base.ndarray_data, 0)
            .unwrap();
        assert!(gp.downcast_ref::<NDArray>().is_none());
    }

    /// With callbacks enabled (the default) the published array can be read
    /// back from the generic-pointer parameter.
    #[test]
    fn test_publish_sets_generic_pointer() {
        let mut ad = ADDriverBase::new("TEST", 8, 8, 1_000_000).unwrap();
        let frame = alloc_u8(&ad, vec![NDDimension::new(8), NDDimension::new(8)]);
        let expected_id = frame.unique_id;

        ad.publish_array(Arc::new(frame)).unwrap();

        let gp = ad
            .port_base
            .get_generic_pointer_param(ad.params.base.ndarray_data, 0)
            .unwrap();
        let recovered = gp.downcast_ref::<NDArray>().unwrap();
        assert_eq!(recovered.unique_id, expected_id);
    }

    /// In detector-only mode both the control and status parameters follow
    /// the requested shutter state.
    #[test]
    fn test_shutter_control_detector_mode() {
        let mut ad = ADDriverBase::new("TEST", 8, 8, 1_000_000).unwrap();
        ad.port_base
            .set_int32_param(ad.params.shutter_mode, 0, ShutterMode::DetectorOnly as i32)
            .unwrap();

        ad.set_shutter(true).unwrap();
        let control = ad
            .port_base
            .get_int32_param(ad.params.shutter_control, 0)
            .unwrap();
        assert_eq!(control, 1);
        let status = ad
            .port_base
            .get_int32_param(ad.params.shutter_status, 0)
            .unwrap();
        assert_eq!(status, 1);

        ad.set_shutter(false).unwrap();
        let control = ad
            .port_base
            .get_int32_param(ad.params.shutter_control, 0)
            .unwrap();
        assert_eq!(control, 0);
        let status = ad
            .port_base
            .get_int32_param(ad.params.shutter_status, 0)
            .unwrap();
        assert_eq!(status, 0);
    }

    /// In EPICS-only mode the EPICS control parameter receives the request.
    #[test]
    fn test_shutter_control_epics_mode() {
        let mut ad = ADDriverBase::new("TEST", 8, 8, 1_000_000).unwrap();
        ad.port_base
            .set_int32_param(ad.params.shutter_mode, 0, ShutterMode::EpicsOnly as i32)
            .unwrap();

        ad.set_shutter(true).unwrap();

        let control = ad
            .port_base
            .get_int32_param(ad.params.shutter_control_epics, 0)
            .unwrap();
        assert_eq!(control, 1);
    }

    /// Gain and temperature start at their constructor defaults.
    #[test]
    fn test_gain_and_temperature() {
        let ad = ADDriverBase::new("TEST", 8, 8, 1_000_000).unwrap();

        let gain = ad.port_base.get_float64_param(ad.params.gain, 0).unwrap();
        assert_eq!(gain, 1.0);

        let temperature = ad
            .port_base
            .get_float64_param(ad.params.temperature, 0)
            .unwrap();
        assert_eq!(temperature, 25.0);
    }

    /// A connected downstream receiver gets the published array.
    #[test]
    fn test_connect_downstream() {
        let mut ad = ADDriverBase::new("TEST", 8, 8, 1_000_000).unwrap();
        let (sender, mut receiver) = ndarray_channel("DOWNSTREAM", 10);
        ad.connect_downstream(sender);

        let frame = alloc_u8(&ad, vec![NDDimension::new(8), NDDimension::new(8)]);
        let expected_id = frame.unique_id;
        ad.publish_array(Arc::new(frame)).unwrap();

        let received = receiver.blocking_recv().unwrap();
        assert_eq!(received.unique_id, expected_id);
    }
}