autd3_driver/firmware/v10/transmission/sender.rs

1use std::time::{Duration, Instant};
2
3use crate::{
4    datagram::FirmwareVersionType,
5    error::AUTDDriverError,
6    firmware::{
7        driver::{Driver, Operation, OperationHandler, SenderOption, TimerStrategy},
8        v10::{V10, operation::OperationGenerator},
9    },
10};
11
12use autd3_core::{
13    datagram::{Datagram, DeviceFilter},
14    geometry::Geometry,
15    link::{Link, MsgId, RxMessage, TxMessage},
16    sleep::Sleep,
17};
18
19/// A struct to send the [`Datagram`] to the devices.
20pub struct Sender<'a, L, S, T> {
21    pub(crate) msg_id: &'a mut MsgId,
22    pub(crate) link: &'a mut L,
23    pub(crate) geometry: &'a Geometry,
24    pub(crate) sent_flags: &'a mut [bool],
25    pub(crate) rx: &'a mut [RxMessage],
26    pub(crate) option: SenderOption,
27    pub(crate) timer_strategy: T,
28    pub(crate) _phantom: std::marker::PhantomData<S>,
29}
30
31impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
32    /// Send the [`Datagram`] to the devices.
33    pub fn send<D: Datagram>(&mut self, s: D) -> Result<(), AUTDDriverError>
34    where
35        AUTDDriverError: From<D::Error>,
36        D::G: OperationGenerator,
37        AUTDDriverError: From<<<D::G as OperationGenerator>::O1 as Operation>::Error>
38            + From<<<D::G as OperationGenerator>::O2 as Operation>::Error>,
39    {
40        let timeout = self.option.timeout.unwrap_or(s.option().timeout);
41        let parallel_threshold = s.option().parallel_threshold;
42        let strict = self.option.strict;
43
44        let mut g = s.operation_generator(
45            self.geometry,
46            &DeviceFilter::all_enabled(),
47            &V10.firmware_limits(),
48        )?;
49        let mut operations = self
50            .geometry
51            .iter()
52            .map(|dev| g.generate(dev))
53            .collect::<Vec<_>>();
54
55        operations
56            .iter()
57            .zip(self.sent_flags.iter_mut())
58            .for_each(|(op, flag)| {
59                *flag = op.is_some();
60            });
61
62        let num_enabled = self.sent_flags.iter().filter(|x| **x).count();
63        let parallel = self
64            .option
65            .parallel
66            .is_parallel(num_enabled, parallel_threshold);
67
68        self.link.ensure_is_open()?;
69        self.link.update(self.geometry)?;
70
71        let mut send_timing = self.timer_strategy.initial();
72        loop {
73            let mut tx = self.link.alloc_tx_buffer()?;
74
75            self.msg_id.increment();
76            OperationHandler::pack(
77                *self.msg_id,
78                &mut operations,
79                self.geometry,
80                &mut tx,
81                parallel,
82            )?;
83
84            self.send_receive(tx, timeout, strict)?;
85
86            if OperationHandler::is_done(&operations) {
87                return Ok(());
88            }
89
90            send_timing = self
91                .timer_strategy
92                .sleep(send_timing, self.option.send_interval);
93        }
94    }
95}
96
97impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
98    fn send_receive(
99        &mut self,
100        tx: Vec<TxMessage>,
101        timeout: Duration,
102        strict: bool,
103    ) -> Result<(), AUTDDriverError> {
104        self.link.ensure_is_open()?;
105        self.link.send(tx)?;
106        Self::wait_msg_processed(
107            self.link,
108            self.msg_id,
109            self.rx,
110            self.sent_flags,
111            &self.option,
112            &self.timer_strategy,
113            timeout,
114            strict,
115        )
116    }
117
118    #[allow(clippy::too_many_arguments)]
119    pub(crate) fn wait_msg_processed(
120        link: &mut L,
121        msg_id: &mut MsgId,
122        rx: &mut [RxMessage],
123        sent_flags: &mut [bool],
124        option: &SenderOption,
125        timer_strategy: &T,
126        timeout: Duration,
127        strict: bool,
128    ) -> Result<(), AUTDDriverError> {
129        let start = Instant::now();
130        let mut receive_timing = timer_strategy.initial();
131        loop {
132            link.ensure_is_open()?;
133            link.receive(rx)?;
134
135            if crate::firmware::v10::cpu::check_if_msg_is_processed(*msg_id, rx)
136                .zip(sent_flags.iter())
137                .filter_map(|(r, sent)| sent.then_some(r))
138                .all(std::convert::identity)
139            {
140                return Ok(());
141            }
142
143            if start.elapsed() > timeout {
144                break;
145            }
146
147            receive_timing = timer_strategy.sleep(receive_timing, option.receive_interval);
148        }
149
150        rx.iter()
151            .try_fold((), |_, r| {
152                crate::firmware::v10::cpu::check_firmware_err(r.ack())
153            })
154            .and_then(|_| {
155                if !strict && timeout == Duration::ZERO {
156                    Ok(())
157                } else {
158                    Err(AUTDDriverError::ConfirmResponseFailed)
159                }
160            })
161    }
162
163    pub(crate) fn fetch_firminfo(
164        &mut self,
165        ty: FirmwareVersionType,
166    ) -> Result<Vec<u8>, AUTDDriverError> {
167        self.send(ty).map_err(|_| {
168            AUTDDriverError::ReadFirmwareVersionFailed(
169                crate::firmware::v10::cpu::check_if_msg_is_processed(*self.msg_id, self.rx)
170                    .collect(),
171            )
172        })?;
173        Ok(self.rx.iter().map(|rx| rx.data()).collect())
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    use crate::firmware::driver::{FixedSchedule, ParallelMode};
    use autd3_core::{
        link::{Ack, LinkError, TxBufferPoolSync},
        sleep::{Sleep, SpinSleeper, SpinWaitSleeper, StdSleeper},
    };

    /// In-memory [`Link`] used to drive the sender without hardware.
    ///
    /// While `down` is `false`, every `send`/`receive` bumps the matching counter. Each
    /// received message is stamped with `recv_cnt` (truncated to `u8`) as its message
    /// id. `receive` fails with "too many" once `recv_cnt` exceeds 10.
    #[derive(Default)]
    struct MockLink {
        pub is_open: bool,
        pub send_cnt: usize,
        pub recv_cnt: usize,
        pub down: bool,
        pub buffer_pool: TxBufferPoolSync,
    }

    impl Link for MockLink {
        fn open(&mut self, geometry: &Geometry) -> Result<(), LinkError> {
            self.buffer_pool.init(geometry);
            self.is_open = true;
            Ok(())
        }

        fn close(&mut self) -> Result<(), LinkError> {
            self.is_open = false;
            Ok(())
        }

        fn alloc_tx_buffer(&mut self) -> Result<Vec<TxMessage>, LinkError> {
            Ok(self.buffer_pool.borrow())
        }

        fn send(&mut self, tx: Vec<TxMessage>) -> Result<(), LinkError> {
            // A downed link swallows the frame without counting it.
            self.send_cnt += usize::from(!self.down);
            self.buffer_pool.return_buffer(tx);
            Ok(())
        }

        fn receive(&mut self, rx: &mut [RxMessage]) -> Result<(), LinkError> {
            if self.recv_cnt > 10 {
                return Err(LinkError::new("too many"));
            }

            if !self.down {
                self.recv_cnt += 1;
            }
            let stamped_id = self.recv_cnt as u8;
            for message in rx.iter_mut() {
                *message = RxMessage::new(message.data(), Ack::new().with_msg_id(stamped_id));
            }

            Ok(())
        }

        fn is_open(&self) -> bool {
            self.is_open
        }
    }

    /// Options shared by every test: 1 ms intervals, no timeout override, automatic
    /// parallelism and strict confirmation.
    fn default_test_option() -> SenderOption {
        SenderOption {
            send_interval: Duration::from_millis(1),
            receive_interval: Duration::from_millis(1),
            timeout: None,
            parallel: ParallelMode::Auto,
            strict: true,
        }
    }

    /// Calls [`Sender::wait_msg_processed`] with all of `sender`'s borrowed state.
    fn wait_processed<S: Sleep>(
        sender: &mut Sender<'_, MockLink, S, FixedSchedule<S>>,
        timeout: Duration,
        strict: bool,
    ) -> Result<(), AUTDDriverError> {
        Sender::<MockLink, S, FixedSchedule<S>>::wait_msg_processed(
            sender.link,
            sender.msg_id,
            sender.rx,
            sender.sent_flags,
            &sender.option,
            &sender.timer_strategy,
            timeout,
            strict,
        )
    }

    #[test]
    fn test_close() -> anyhow::Result<()> {
        let mut link = MockLink::default();

        link.open(&Geometry::new(Vec::new()))?;
        assert!(link.is_open());

        link.close()?;
        assert!(!link.is_open());

        Ok(())
    }

    #[rstest::rstest]
    #[case(StdSleeper)]
    #[case(SpinSleeper::default())]
    #[case(SpinWaitSleeper)]
    #[test]
    fn test_send_receive(#[case] sleeper: impl Sleep) {
        let geometry = crate::autd3_device::tests::create_geometry(1);
        let mut link = MockLink::default();
        let mut msg_id = MsgId::new(0);
        let mut sent_flags = vec![false; 1];
        let mut rx = Vec::new();

        assert!(link.open(&geometry).is_ok());
        let mut sender = Sender {
            msg_id: &mut msg_id,
            link: &mut link,
            geometry: &geometry,
            sent_flags: &mut sent_flags,
            rx: &mut rx,
            option: default_test_option(),
            timer_strategy: FixedSchedule(sleeper),
            _phantom: std::marker::PhantomData,
        };

        // With nothing to confirm, both a zero and a short timeout succeed.
        for timeout in [Duration::ZERO, Duration::from_millis(1)] {
            let tx = sender.link.alloc_tx_buffer().unwrap();
            assert_eq!(Ok(()), sender.send_receive(tx, timeout, true));
        }

        // A closed link is rejected before anything is sent.
        sender.link.is_open = false;
        let tx = sender.link.alloc_tx_buffer().unwrap();
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::closed())),
            sender.send_receive(tx, Duration::ZERO, true),
        );
    }

    #[rstest::rstest]
    #[case(StdSleeper)]
    #[case(SpinSleeper::default())]
    #[case(SpinWaitSleeper)]
    #[test]
    fn test_wait_msg_processed<S: Sleep>(#[case] sleeper: S) {
        let geometry = crate::autd3_device::tests::create_geometry(1);
        let mut link = MockLink::default();
        let mut msg_id = MsgId::new(1);
        let mut sent_flags = vec![true; 1];
        let mut rx = vec![RxMessage::new(0, Ack::new())];

        assert!(link.open(&geometry).is_ok());
        let mut sender = Sender {
            msg_id: &mut msg_id,
            link: &mut link,
            geometry: &geometry,
            sent_flags: &mut sent_flags,
            rx: &mut rx,
            option: default_test_option(),
            timer_strategy: FixedSchedule(sleeper),
            _phantom: std::marker::PhantomData::<S>,
        };

        // The first receive stamps message id 1, the id being waited for.
        assert_eq!(
            Ok(()),
            wait_processed(&mut sender, Duration::from_millis(10), true)
        );

        // A closed link is reported before anything is received.
        sender.link.recv_cnt = 0;
        sender.link.is_open = false;
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::closed())),
            wait_processed(&mut sender, Duration::from_millis(10), true)
        );

        // With the link down the stamped id never advances, so strict mode fails for a
        // zero timeout as well as after a long one.
        sender.link.is_open = true;
        sender.link.down = true;
        for timeout in [Duration::ZERO, Duration::from_secs(10)] {
            sender.link.recv_cnt = 0;
            assert_eq!(
                Err(AUTDDriverError::ConfirmResponseFailed),
                wait_processed(&mut sender, timeout, true)
            );
        }

        // Non-strict mode with a zero timeout tolerates the missing confirmation.
        sender.link.recv_cnt = 0;
        assert_eq!(Ok(()), wait_processed(&mut sender, Duration::ZERO, false));

        // An id the mock never stamps exhausts its receive budget.
        sender.link.down = false;
        sender.link.recv_cnt = 0;
        *sender.msg_id = MsgId::new(20);
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::new("too many"))),
            wait_processed(&mut sender, Duration::from_secs(10), true)
        );
    }
}