autd3_driver/firmware/v12/transmission/
sender.rs

1use std::time::{Duration, Instant};
2
3use crate::{
4    datagram::FirmwareVersionType,
5    error::AUTDDriverError,
6    firmware::{
7        driver::{Driver, Operation, OperationHandler, SenderOption, TimerStrategy},
8        v12::{V12, operation::OperationGenerator},
9    },
10};
11
12use autd3_core::{
13    datagram::{Datagram, DeviceFilter},
14    geometry::Geometry,
15    link::{Link, MsgId, RxMessage, TxMessage},
16    sleep::Sleep,
17};
18
19/// A struct to send the [`Datagram`] to the devices.
20pub struct Sender<'a, L: Link, S: Sleep, T: TimerStrategy<S>> {
21    pub(crate) msg_id: &'a mut MsgId,
22    pub(crate) link: &'a mut L,
23    pub(crate) geometry: &'a Geometry,
24    pub(crate) sent_flags: &'a mut [bool],
25    pub(crate) rx: &'a mut [RxMessage],
26    pub(crate) option: SenderOption,
27    pub(crate) timer_strategy: T,
28    pub(crate) _phantom: std::marker::PhantomData<S>,
29}
30
31impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
32    /// Send the [`Datagram`] to the devices.
33    pub fn send<D: Datagram>(&mut self, s: D) -> Result<(), AUTDDriverError>
34    where
35        AUTDDriverError: From<D::Error>,
36        D::G: OperationGenerator,
37        AUTDDriverError: From<<<D::G as OperationGenerator>::O1 as Operation>::Error>
38            + From<<<D::G as OperationGenerator>::O2 as Operation>::Error>,
39    {
40        let timeout = self.option.timeout.unwrap_or(s.option().timeout);
41        let parallel_threshold = s.option().parallel_threshold;
42        let strict = self.option.strict;
43
44        let mut g = s.operation_generator(
45            self.geometry,
46            &DeviceFilter::all_enabled(),
47            &V12.firmware_limits(),
48        )?;
49        let mut operations = self
50            .geometry
51            .iter()
52            .map(|dev| g.generate(dev))
53            .collect::<Vec<_>>();
54
55        operations
56            .iter()
57            .zip(self.sent_flags.iter_mut())
58            .for_each(|(op, flag)| {
59                *flag = op.is_some();
60            });
61
62        let num_enabled = self.sent_flags.iter().filter(|x| **x).count();
63        let parallel = self
64            .option
65            .parallel
66            .is_parallel(num_enabled, parallel_threshold);
67
68        self.link.ensure_is_open()?;
69        self.link.update(self.geometry)?;
70
71        let mut send_timing = self.timer_strategy.initial();
72        loop {
73            let mut tx = self.link.alloc_tx_buffer()?;
74
75            self.msg_id.increment();
76            OperationHandler::pack(
77                *self.msg_id,
78                &mut operations,
79                self.geometry,
80                &mut tx,
81                parallel,
82            )?;
83
84            self.send_receive(tx, timeout, strict)?;
85
86            if OperationHandler::is_done(&operations) {
87                return Ok(());
88            }
89
90            send_timing = self
91                .timer_strategy
92                .sleep(send_timing, self.option.send_interval);
93        }
94    }
95}
96
97impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
98    fn send_receive(
99        &mut self,
100        tx: Vec<TxMessage>,
101        timeout: Duration,
102        strict: bool,
103    ) -> Result<(), AUTDDriverError> {
104        self.link.ensure_is_open()?;
105        self.link.send(tx)?;
106        Self::wait_msg_processed(
107            self.link,
108            self.msg_id,
109            self.rx,
110            self.sent_flags,
111            &self.option,
112            &self.timer_strategy,
113            timeout,
114            strict,
115        )
116    }
117
118    #[allow(clippy::too_many_arguments)]
119    pub(crate) fn wait_msg_processed(
120        link: &mut L,
121        msg_id: &mut MsgId,
122        rx: &mut [RxMessage],
123        sent_flags: &mut [bool],
124        option: &SenderOption,
125        timer_strategy: &T,
126        timeout: Duration,
127        strict: bool,
128    ) -> Result<(), AUTDDriverError> {
129        let start = Instant::now();
130        let mut receive_timing = timer_strategy.initial();
131        loop {
132            link.ensure_is_open()?;
133            link.receive(rx)?;
134
135            if crate::firmware::v12::cpu::check_if_msg_is_processed(*msg_id, rx)
136                .zip(sent_flags.iter())
137                .filter_map(|(r, sent)| sent.then_some(r))
138                .all(std::convert::identity)
139            {
140                break;
141            }
142
143            if start.elapsed() > timeout {
144                return if !strict && timeout == Duration::ZERO {
145                    Ok(())
146                } else {
147                    Err(AUTDDriverError::ConfirmResponseFailed)
148                };
149            }
150
151            receive_timing = timer_strategy.sleep(receive_timing, option.receive_interval);
152        }
153
154        rx.iter().try_fold((), |_, r| {
155            crate::firmware::v12::cpu::check_firmware_err(r.ack())
156        })
157    }
158
159    pub(crate) fn fetch_firminfo(
160        &mut self,
161        ty: FirmwareVersionType,
162    ) -> Result<Vec<u8>, AUTDDriverError> {
163        self.send(ty).map_err(|_| {
164            AUTDDriverError::ReadFirmwareVersionFailed(
165                crate::firmware::v12::cpu::check_if_msg_is_processed(*self.msg_id, self.rx)
166                    .collect(),
167            )
168        })?;
169        Ok(self.rx.iter().map(|rx| rx.data()).collect())
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    use crate::firmware::driver::{FixedSchedule, ParallelMode};
    use autd3_core::{
        link::{Ack, LinkError, TxBufferPoolSync},
        sleep::{Sleep, SpinSleeper, SpinWaitSleeper, StdSleeper},
    };

    /// In-memory [`Link`] used to drive [`Sender`] in tests.
    ///
    /// - `send` counts frames in `send_cnt` (unless `down`) and returns the
    ///   buffer to the pool.
    /// - `receive` fails with `"too many"` once `recv_cnt` exceeds 10.
    ///   Otherwise it increments `recv_cnt` (unless `down`) and overwrites
    ///   every rx slot with an ack whose message id is `recv_cnt` truncated to
    ///   `u8`.
    #[derive(Default)]
    struct MockLink {
        pub is_open: bool,
        pub send_cnt: usize,
        pub recv_cnt: usize,
        pub down: bool,
        pub buffer_pool: TxBufferPoolSync,
    }

    impl Link for MockLink {
        fn open(&mut self, geometry: &Geometry) -> Result<(), LinkError> {
            self.buffer_pool.init(geometry);
            self.is_open = true;
            Ok(())
        }

        fn close(&mut self) -> Result<(), LinkError> {
            self.is_open = false;
            Ok(())
        }

        fn alloc_tx_buffer(&mut self) -> Result<Vec<TxMessage>, LinkError> {
            Ok(self.buffer_pool.borrow())
        }

        fn send(&mut self, tx: Vec<TxMessage>) -> Result<(), LinkError> {
            self.send_cnt += usize::from(!self.down);
            self.buffer_pool.return_buffer(tx);
            Ok(())
        }

        fn receive(&mut self, rx: &mut [RxMessage]) -> Result<(), LinkError> {
            if self.recv_cnt > 10 {
                return Err(LinkError::new("too many"));
            }
            if !self.down {
                self.recv_cnt += 1;
            }

            let ack_id = self.recv_cnt as u8;
            for slot in rx.iter_mut() {
                *slot = RxMessage::new(slot.data(), Ack::new().with_msg_id(ack_id));
            }
            Ok(())
        }

        fn is_open(&self) -> bool {
            self.is_open
        }
    }

    /// Options shared by the tests: 1 ms send/receive intervals, no timeout
    /// override, automatic parallelism, strict confirmation.
    fn sender_option() -> SenderOption {
        SenderOption {
            send_interval: Duration::from_millis(1),
            receive_interval: Duration::from_millis(1),
            timeout: None,
            parallel: ParallelMode::Auto,
            strict: true,
        }
    }

    #[test]
    fn test_close() -> anyhow::Result<()> {
        let mut link = MockLink::default();

        link.open(&Geometry::new(Vec::new()))?;
        assert!(link.is_open());

        link.close()?;
        assert!(!link.is_open());

        Ok(())
    }

    #[rstest::rstest]
    #[case(StdSleeper)]
    #[case(SpinSleeper::default())]
    #[case(SpinWaitSleeper)]
    #[test]
    fn test_send_receive(#[case] sleeper: impl Sleep) {
        let geometry = crate::autd3_device::tests::create_geometry(1);
        let mut link = MockLink::default();
        assert!(link.open(&geometry).is_ok());

        let mut msg_id = MsgId::new(0);
        let mut sent_flags = vec![false; 1];
        let mut rx = Vec::new();
        let mut sender = Sender {
            msg_id: &mut msg_id,
            link: &mut link,
            geometry: &geometry,
            sent_flags: &mut sent_flags,
            rx: &mut rx,
            option: sender_option(),
            timer_strategy: FixedSchedule(sleeper),
            _phantom: std::marker::PhantomData,
        };

        // Nothing is flagged as sent, so both waits succeed immediately.
        for timeout in [Duration::ZERO, Duration::from_millis(1)] {
            let tx = sender.link.alloc_tx_buffer().unwrap();
            assert_eq!(Ok(()), sender.send_receive(tx, timeout, true));
        }

        // A closed link is rejected before sending.
        sender.link.is_open = false;
        let tx = sender.link.alloc_tx_buffer().unwrap();
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::closed())),
            sender.send_receive(tx, Duration::ZERO, true),
        );
    }

    #[rstest::rstest]
    #[case(StdSleeper)]
    #[case(SpinSleeper::default())]
    #[case(SpinWaitSleeper)]
    #[test]
    fn test_wait_msg_processed(#[case] sleeper: impl Sleep) {
        // Expands to a `wait_msg_processed` call on the fields of `$sender`.
        macro_rules! wait {
            ($sender:ident, $timeout:expr, $strict:expr) => {
                Sender::wait_msg_processed(
                    $sender.link,
                    $sender.msg_id,
                    $sender.rx,
                    $sender.sent_flags,
                    &$sender.option,
                    &$sender.timer_strategy,
                    $timeout,
                    $strict,
                )
            };
        }

        let geometry = crate::autd3_device::tests::create_geometry(1);
        let mut link = MockLink::default();
        assert!(link.open(&geometry).is_ok());

        let mut msg_id = MsgId::new(1);
        let mut sent_flags = vec![true; 1];
        let mut rx = vec![RxMessage::new(0, Ack::new())];
        let sender = Sender {
            msg_id: &mut msg_id,
            link: &mut link,
            geometry: &geometry,
            sent_flags: &mut sent_flags,
            rx: &mut rx,
            option: sender_option(),
            timer_strategy: FixedSchedule(sleeper),
            _phantom: std::marker::PhantomData,
        };

        // The first receive stamps ack id 1, matching msg_id 1.
        assert_eq!(Ok(()), wait!(sender, Duration::from_millis(10), true));

        // A closed link is reported before anything is received.
        sender.link.recv_cnt = 0;
        sender.link.is_open = false;
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::closed())),
            wait!(sender, Duration::from_millis(10), true)
        );

        // While the link is down the ack id stays 0 and never matches.
        sender.link.recv_cnt = 0;
        sender.link.is_open = true;
        sender.link.down = true;
        assert_eq!(
            Err(AUTDDriverError::ConfirmResponseFailed),
            wait!(sender, Duration::from_millis(10), true)
        );

        // A zero timeout still fails in strict mode...
        sender.link.recv_cnt = 0;
        sender.link.is_open = true;
        sender.link.down = true;
        assert_eq!(
            Err(AUTDDriverError::ConfirmResponseFailed),
            wait!(sender, Duration::ZERO, true)
        );

        // ...but is accepted without confirmation in non-strict mode.
        sender.link.recv_cnt = 0;
        sender.link.is_open = true;
        sender.link.down = true;
        assert_eq!(Ok(()), wait!(sender, Duration::ZERO, false));

        // Ack ids 1..=11 never reach 20, so the mock's receive limit trips first.
        sender.link.down = false;
        sender.link.recv_cnt = 0;
        *sender.msg_id = MsgId::new(20);
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::new("too many"))),
            wait!(sender, Duration::from_secs(10), true)
        );
    }
}