Skip to main content

ad_core/driver/
ad_driver.rs

1use std::sync::Arc;
2
3use asyn_rs::error::AsynResult;
4use asyn_rs::port::{PortDriverBase, PortFlags};
5
6use crate::color::NDColorMode;
7use crate::ndarray::NDArray;
8use crate::ndarray_pool::NDArrayPool;
9use crate::params::ad_driver::ADDriverParams;
10use crate::plugin::channel::{NDArrayOutput, NDArraySender, QueuedArrayCounter};
11
12use super::{ADStatus, ImageMode, ShutterMode};
13
/// Base state for ADDriver (extends NDArrayDriver with detector-specific params).
///
/// Bundles the asyn port, the registered parameter handles, the NDArray pool and the
/// downstream channel outputs that a concrete detector driver works with.
pub struct ADDriverBase {
    /// Underlying asyn port; parameter values are read and written through its
    /// `get_*_param` / `set_*_param` methods.
    pub port_base: PortDriverBase,
    /// Handles for the detector parameters, with the base NDArrayDriver parameters
    /// reachable through `params.base`.
    pub params: ADDriverParams,
    /// Pool used to allocate frames; its allocated byte count is reported (in MiB)
    /// through the PoolUsedMemory parameter on each publish.
    pub pool: Arc<NDArrayPool>,
    /// Channel senders that receive each published array when ArrayCallbacks is non-zero.
    pub array_output: NDArrayOutput,
    /// Queued-array counter shared with every sender attached via `connect_downstream`.
    pub queued_counter: Arc<QueuedArrayCounter>,
}
22
23impl ADDriverBase {
24    pub fn new(
25        port_name: &str,
26        max_size_x: i32,
27        max_size_y: i32,
28        max_memory: usize,
29    ) -> AsynResult<Self> {
30        let mut port_base = PortDriverBase::new(
31            port_name,
32            1,
33            PortFlags {
34                can_block: true,
35                ..Default::default()
36            },
37        );
38
39        let params = ADDriverParams::create(&mut port_base)?;
40
41        // Set initial values
42        // Identity strings
43        port_base.set_string_param(params.base.port_name_self, 0, port_name.into())?;
44        port_base.set_string_param(params.base.ad_core_version, 0, env!("CARGO_PKG_VERSION").into())?;
45        port_base.set_string_param(params.base.driver_version, 0, "0.0.0".into())?;
46        port_base.set_string_param(params.base.codec, 0, String::new())?;
47
48        port_base.set_int32_param(params.max_size_x, 0, max_size_x)?;
49        port_base.set_int32_param(params.max_size_y, 0, max_size_y)?;
50        port_base.set_int32_param(params.size_x, 0, max_size_x)?;
51        port_base.set_int32_param(params.size_y, 0, max_size_y)?;
52        port_base.set_int32_param(params.bin_x, 0, 1)?;
53        port_base.set_int32_param(params.bin_y, 0, 1)?;
54        port_base.set_int32_param(params.image_mode, 0, ImageMode::Single as i32)?;
55        port_base.set_int32_param(params.num_images, 0, 1)?;
56        port_base.set_int32_param(params.num_exposures, 0, 1)?;
57        port_base.set_float64_param(params.acquire_time, 0, 1.0)?;
58        port_base.set_float64_param(params.acquire_period, 0, 1.0)?;
59        port_base.set_int32_param(params.status, 0, ADStatus::Idle as i32)?;
60        port_base.set_string_param(params.status_message, 0, "Idle".into())?;
61        port_base.set_int32_param(params.base.data_type, 0, 1)?; // UInt8
62        port_base.set_int32_param(params.base.color_mode, 0, NDColorMode::Mono as i32)?;
63        port_base.set_int32_param(params.base.array_callbacks, 0, 1)?;
64        port_base.set_float64_param(params.base.pool_max_memory, 0, max_memory as f64 / 1_048_576.0)?;
65        // Initial array size based on detector dimensions and data type (UInt8)
66        port_base.set_int32_param(params.base.array_size_x, 0, max_size_x)?;
67        port_base.set_int32_param(params.base.array_size_y, 0, max_size_y)?;
68        port_base.set_int32_param(params.base.array_size_z, 0, 0)?;
69        let initial_array_bytes = max_size_x as i64 * max_size_y as i64; // UInt8 = 1 byte/element
70        port_base.set_int32_param(params.base.array_size, 0, initial_array_bytes as i32)?;
71
72        port_base.set_float64_param(params.gain, 0, 1.0)?;
73        port_base.set_int32_param(params.shutter_mode, 0, ShutterMode::None as i32)?;
74        port_base.set_float64_param(params.temperature, 0, 25.0)?;
75        port_base.set_float64_param(params.temperature_actual, 0, 25.0)?;
76
77        let pool = Arc::new(NDArrayPool::new(max_memory));
78
79        Ok(Self {
80            port_base,
81            params,
82            pool,
83            array_output: NDArrayOutput::new(),
84            queued_counter: Arc::new(QueuedArrayCounter::new()),
85        })
86    }
87
88    /// Connect a downstream channel-based receiver.
89    pub fn connect_downstream(&mut self, mut sender: NDArraySender) {
90        sender.set_queued_counter(self.queued_counter.clone());
91        self.array_output.add(sender);
92    }
93
94    /// Publish an array: update counters, push to plugins and channel outputs, fire callbacks.
95    pub fn publish_array(&mut self, array: Arc<NDArray>) -> AsynResult<()> {
96        let counter = self.port_base.get_int32_param(self.params.base.array_counter, 0)? + 1;
97        self.port_base
98            .set_int32_param(self.params.base.array_counter, 0, counter)?;
99
100        let info = array.info();
101        self.port_base
102            .set_int32_param(self.params.base.array_size_x, 0, info.x_size as i32)?;
103        self.port_base
104            .set_int32_param(self.params.base.array_size_y, 0, info.y_size as i32)?;
105        self.port_base
106            .set_int32_param(self.params.base.array_size_z, 0, info.color_size as i32)?;
107        self.port_base
108            .set_int32_param(self.params.base.array_size, 0, info.total_bytes as i32)?;
109
110        self.port_base.set_float64_param(
111            self.params.base.pool_used_memory,
112            0,
113            self.pool.allocated_bytes() as f64 / 1_048_576.0,
114        )?;
115
116        let callbacks_enabled =
117            self.port_base
118                .get_int32_param(self.params.base.array_callbacks, 0)?
119                != 0;
120
121        if callbacks_enabled {
122            self.port_base.set_generic_pointer_param(
123                self.params.base.ndarray_data,
124                0,
125                array.clone() as Arc<dyn std::any::Any + Send + Sync>,
126            )?;
127
128            self.array_output.publish(array);
129        }
130
131        self.port_base.call_param_callbacks(0)?;
132
133        Ok(())
134    }
135
136    /// Set shutter state (open/close). In C++ this dispatches based on shutter mode.
137    pub fn set_shutter(&mut self, open: bool) -> AsynResult<()> {
138        let mode = ShutterMode::from_i32(
139            self.port_base.get_int32_param(self.params.shutter_mode, 0)?,
140        );
141
142        match mode {
143            ShutterMode::None => {}
144            ShutterMode::DetectorOnly | ShutterMode::EpicsAndDetector => {
145                self.port_base.set_int32_param(
146                    self.params.shutter_control,
147                    0,
148                    if open { 1 } else { 0 },
149                )?;
150            }
151            ShutterMode::EpicsOnly => {
152                self.port_base.set_int32_param(
153                    self.params.shutter_control_epics,
154                    0,
155                    if open { 1 } else { 0 },
156                )?;
157            }
158        }
159
160        self.port_base.set_int32_param(
161            self.params.shutter_status,
162            0,
163            if open { 1 } else { 0 },
164        )?;
165
166        Ok(())
167    }
168}
169
/// Trait for areaDetector drivers.
///
/// Implementors are asyn port drivers that expose an [`ADDriverBase`] through these
/// accessors, giving shared code access to the common detector state.
pub trait ADDriver: asyn_rs::port::PortDriver {
    /// Shared access to the driver's [`ADDriverBase`].
    fn ad_base(&self) -> &ADDriverBase;
    /// Mutable access to the driver's [`ADDriverBase`].
    fn ad_base_mut(&mut self) -> &mut ADDriverBase;
}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDDataType, NDDimension};
    use crate::plugin::channel::ndarray_channel;

    /// Allocate a UInt8 array with the given dimensions from the driver's own pool.
    fn alloc_u8(ad: &ADDriverBase, dims: Vec<NDDimension>) -> NDArray {
        ad.pool.alloc(dims, NDDataType::UInt8).unwrap()
    }

    #[test]
    fn test_new_sets_initial_params() {
        let ad = ADDriverBase::new("TEST", 1024, 768, 50_000_000).unwrap();
        let int_param = |reason| ad.port_base.get_int32_param(reason, 0).unwrap();

        assert_eq!(int_param(ad.params.max_size_x), 1024);
        assert_eq!(int_param(ad.params.max_size_y), 768);
        assert_eq!(int_param(ad.params.size_x), 1024);
        assert_eq!(int_param(ad.params.size_y), 768);
        assert_eq!(int_param(ad.params.status), ADStatus::Idle as i32);
    }

    #[test]
    fn test_publish_array_increments_counter() {
        let mut ad = ADDriverBase::new("TEST", 256, 256, 50_000_000).unwrap();
        let frame = alloc_u8(&ad, vec![NDDimension::new(256), NDDimension::new(256)]);

        ad.publish_array(Arc::new(frame)).unwrap();

        let counter = ad
            .port_base
            .get_int32_param(ad.params.base.array_counter, 0)
            .unwrap();
        assert_eq!(counter, 1);
    }

    #[test]
    fn test_publish_array_skips_output_when_callbacks_disabled() {
        let mut ad = ADDriverBase::new("TEST", 64, 64, 1_000_000).unwrap();
        // Keep the receiver bound so the channel stays open for the whole test.
        let (sender, _receiver) = ndarray_channel("DOWNSTREAM", 10);
        ad.connect_downstream(sender);
        ad.port_base
            .set_int32_param(ad.params.base.array_callbacks, 0, 0)
            .unwrap();

        let frame = alloc_u8(&ad, vec![NDDimension::new(64), NDDimension::new(64)]);
        ad.publish_array(Arc::new(frame)).unwrap();

        // The counter advances regardless of ArrayCallbacks...
        let counter = ad
            .port_base
            .get_int32_param(ad.params.base.array_counter, 0)
            .unwrap();
        assert_eq!(counter, 1);

        // ...but NDArrayData must still hold its default value, not an NDArray.
        let pointer = ad
            .port_base
            .get_generic_pointer_param(ad.params.base.ndarray_data, 0)
            .unwrap();
        assert!(pointer.downcast_ref::<NDArray>().is_none());
    }

    #[test]
    fn test_publish_sets_generic_pointer() {
        let mut ad = ADDriverBase::new("TEST", 8, 8, 1_000_000).unwrap();
        let frame = alloc_u8(&ad, vec![NDDimension::new(8), NDDimension::new(8)]);
        let expected_id = frame.unique_id;

        ad.publish_array(Arc::new(frame)).unwrap();

        let pointer = ad
            .port_base
            .get_generic_pointer_param(ad.params.base.ndarray_data, 0)
            .unwrap();
        let stored = pointer.downcast_ref::<NDArray>().unwrap();
        assert_eq!(stored.unique_id, expected_id);
    }

    #[test]
    fn test_shutter_control_detector_mode() {
        let mut ad = ADDriverBase::new("TEST", 8, 8, 1_000_000).unwrap();
        ad.port_base
            .set_int32_param(ad.params.shutter_mode, 0, ShutterMode::DetectorOnly as i32)
            .unwrap();

        // Open first, then close; both control and status must follow each request.
        for open in [true, false] {
            ad.set_shutter(open).unwrap();
            let expected = if open { 1 } else { 0 };

            let control = ad
                .port_base
                .get_int32_param(ad.params.shutter_control, 0)
                .unwrap();
            assert_eq!(control, expected);

            let status = ad
                .port_base
                .get_int32_param(ad.params.shutter_status, 0)
                .unwrap();
            assert_eq!(status, expected);
        }
    }

    #[test]
    fn test_shutter_control_epics_mode() {
        let mut ad = ADDriverBase::new("TEST", 8, 8, 1_000_000).unwrap();
        ad.port_base
            .set_int32_param(ad.params.shutter_mode, 0, ShutterMode::EpicsOnly as i32)
            .unwrap();

        ad.set_shutter(true).unwrap();

        let epics_control = ad
            .port_base
            .get_int32_param(ad.params.shutter_control_epics, 0)
            .unwrap();
        assert_eq!(epics_control, 1);
    }

    #[test]
    fn test_gain_and_temperature() {
        let ad = ADDriverBase::new("TEST", 8, 8, 1_000_000).unwrap();
        let float_param = |reason| ad.port_base.get_float64_param(reason, 0).unwrap();

        assert_eq!(float_param(ad.params.gain), 1.0);
        assert_eq!(float_param(ad.params.temperature), 25.0);
    }

    #[test]
    fn test_connect_downstream() {
        let mut ad = ADDriverBase::new("TEST", 8, 8, 1_000_000).unwrap();
        let (sender, mut receiver) = ndarray_channel("DOWNSTREAM", 10);
        ad.connect_downstream(sender);

        let frame = alloc_u8(&ad, vec![NDDimension::new(8), NDDimension::new(8)]);
        let expected_id = frame.unique_id;
        ad.publish_array(Arc::new(frame)).unwrap();

        let delivered = receiver.blocking_recv().unwrap();
        assert_eq!(delivered.unique_id, expected_id);
    }
}