// autd3_driver/firmware/v10/transmission/sender.rs

1use std::time::{Duration, Instant};
2
3use crate::{
4    datagram::FirmwareVersionType,
5    error::AUTDDriverError,
6    firmware::{
7        driver::{Driver, Operation, OperationHandler, SenderOption, TimerStrategy},
8        v10::{V10, operation::OperationGenerator},
9    },
10};
11
12use autd3_core::{
13    datagram::{Datagram, DeviceFilter},
14    environment::Environment,
15    geometry::Geometry,
16    link::{Link, MsgId, RxMessage, TxMessage},
17    sleep::Sleep,
18};
19
20/// A struct to send the [`Datagram`] to the devices.
21pub struct Sender<'a, L, S, T> {
22    pub(crate) msg_id: &'a mut MsgId,
23    pub(crate) link: &'a mut L,
24    pub(crate) geometry: &'a Geometry,
25    pub(crate) sent_flags: &'a mut [bool],
26    pub(crate) rx: &'a mut [RxMessage],
27    pub(crate) env: &'a Environment,
28    pub(crate) option: SenderOption,
29    pub(crate) timer_strategy: T,
30    pub(crate) _phantom: std::marker::PhantomData<S>,
31}
32
33impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
34    /// Send the [`Datagram`] to the devices.
35    pub fn send<D: Datagram>(&mut self, s: D) -> Result<(), AUTDDriverError>
36    where
37        AUTDDriverError: From<D::Error>,
38        D::G: OperationGenerator,
39        AUTDDriverError: From<<<D::G as OperationGenerator>::O1 as Operation>::Error>
40            + From<<<D::G as OperationGenerator>::O2 as Operation>::Error>,
41    {
42        let timeout = self.option.timeout.unwrap_or(s.option().timeout);
43        let parallel_threshold = s.option().parallel_threshold;
44
45        let mut g = s.operation_generator(
46            self.geometry,
47            self.env,
48            &DeviceFilter::all_enabled(),
49            &V10.firmware_limits(),
50        )?;
51        let mut operations = self
52            .geometry
53            .iter()
54            .map(|dev| g.generate(dev))
55            .collect::<Vec<_>>();
56
57        self.send_impl(timeout, parallel_threshold, &mut operations)
58    }
59}
60
61impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
62    pub(crate) fn send_impl<O1, O2>(
63        &mut self,
64        timeout: Duration,
65        parallel_threshold: usize,
66        operations: &mut [Option<(O1, O2)>],
67    ) -> Result<(), AUTDDriverError>
68    where
69        O1: Operation,
70        O2: Operation,
71        AUTDDriverError: From<O1::Error> + From<O2::Error>,
72    {
73        let strict = self.option.strict;
74
75        operations
76            .iter()
77            .zip(self.sent_flags.iter_mut())
78            .for_each(|(op, flag)| {
79                *flag = op.is_some();
80            });
81
82        let num_enabled = self.sent_flags.iter().filter(|x| **x).count();
83        let parallel = self
84            .option
85            .parallel
86            .is_parallel(num_enabled, parallel_threshold);
87
88        self.link.ensure_is_open()?;
89        self.link.update(self.geometry)?;
90
91        let mut send_timing = self.timer_strategy.initial();
92        loop {
93            let mut tx = self.link.alloc_tx_buffer()?;
94
95            self.msg_id.increment();
96            OperationHandler::pack(*self.msg_id, operations, self.geometry, &mut tx, parallel)?;
97
98            self.send_receive(tx, timeout, strict)?;
99
100            if OperationHandler::is_done(operations) {
101                return Ok(());
102            }
103
104            send_timing = self
105                .timer_strategy
106                .sleep(send_timing, self.option.send_interval);
107        }
108    }
109
110    fn send_receive(
111        &mut self,
112        tx: Vec<TxMessage>,
113        timeout: Duration,
114        strict: bool,
115    ) -> Result<(), AUTDDriverError> {
116        self.link.ensure_is_open()?;
117        self.link.send(tx)?;
118        self.wait_msg_processed(timeout, strict)
119    }
120
121    fn wait_msg_processed(
122        &mut self,
123        timeout: Duration,
124        strict: bool,
125    ) -> Result<(), AUTDDriverError> {
126        let start = Instant::now();
127        let mut receive_timing = self.timer_strategy.initial();
128        loop {
129            self.link.ensure_is_open()?;
130            self.link.receive(self.rx)?;
131
132            if crate::firmware::v10::cpu::check_if_msg_is_processed(*self.msg_id, self.rx)
133                .zip(self.sent_flags.iter())
134                .filter_map(|(r, sent)| sent.then_some(r))
135                .all(std::convert::identity)
136            {
137                return Ok(());
138            }
139
140            if start.elapsed() > timeout {
141                break;
142            }
143
144            receive_timing = self
145                .timer_strategy
146                .sleep(receive_timing, self.option.receive_interval);
147        }
148
149        self.rx
150            .iter()
151            .try_fold((), |_, r| {
152                crate::firmware::v10::cpu::check_firmware_err(r.ack())
153            })
154            .and_then(|_| {
155                if !strict && timeout == Duration::ZERO {
156                    Ok(())
157                } else {
158                    Err(AUTDDriverError::ConfirmResponseFailed)
159                }
160            })
161    }
162
163    pub(crate) fn fetch_firminfo(
164        &mut self,
165        ty: FirmwareVersionType,
166    ) -> Result<Vec<u8>, AUTDDriverError> {
167        self.send(ty).map_err(|_| {
168            AUTDDriverError::ReadFirmwareVersionFailed(
169                crate::firmware::v10::cpu::check_if_msg_is_processed(*self.msg_id, self.rx)
170                    .collect(),
171            )
172        })?;
173        Ok(self.rx.iter().map(|rx| rx.data()).collect())
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    use crate::firmware::driver::{FixedSchedule, ParallelMode};
    use autd3_core::{
        link::{Ack, LinkError, TxBufferPoolSync},
        sleep::{Sleep, SpinSleeper, SpinWaitSleeper, StdSleeper},
    };

    /// In-memory [`Link`] used to drive the sender without hardware.
    ///
    /// - `send` increments `send_cnt` (unless `down`) and returns the buffer to the pool.
    /// - `receive` fails with `"too many"` once `recv_cnt` exceeds 10. Otherwise it
    ///   increments `recv_cnt` (unless `down`) and overwrites every [`RxMessage`]
    ///   with an ack whose message id is `recv_cnt as u8`.
    #[derive(Default)]
    struct MockLink {
        pub is_open: bool,
        pub send_cnt: usize,
        pub recv_cnt: usize,
        /// When set, the counters stop advancing, so the acknowledged message id
        /// no longer changes between receives.
        pub down: bool,
        pub buffer_pool: TxBufferPoolSync,
    }

    impl Link for MockLink {
        fn open(&mut self, geometry: &Geometry) -> Result<(), LinkError> {
            self.is_open = true;
            self.buffer_pool.init(geometry);
            Ok(())
        }

        fn close(&mut self) -> Result<(), LinkError> {
            self.is_open = false;
            Ok(())
        }

        fn alloc_tx_buffer(&mut self) -> Result<Vec<TxMessage>, LinkError> {
            Ok(self.buffer_pool.borrow())
        }

        fn send(&mut self, tx: Vec<TxMessage>) -> Result<(), LinkError> {
            if !self.down {
                self.send_cnt += 1;
            }
            self.buffer_pool.return_buffer(tx);
            Ok(())
        }

        fn receive(&mut self, rx: &mut [RxMessage]) -> Result<(), LinkError> {
            if self.recv_cnt > 10 {
                return Err(LinkError::new("too many"));
            }

            if !self.down {
                self.recv_cnt += 1;
            }
            rx.iter_mut().for_each(|r| {
                *r = RxMessage::new(r.data(), Ack::new().with_msg_id(self.recv_cnt as u8))
            });

            Ok(())
        }

        fn is_open(&self) -> bool {
            self.is_open
        }
    }

    #[test]
    fn test_close() -> anyhow::Result<()> {
        let mut link = MockLink::default();
        link.open(&Geometry::new(Vec::new()))?;

        assert!(link.is_open());

        link.close()?;

        assert!(!link.is_open());

        Ok(())
    }

    #[rstest::rstest]
    #[case(StdSleeper)]
    #[case(SpinSleeper::default())]
    #[case(SpinWaitSleeper)]
    #[test]
    fn test_send_receive(#[case] sleeper: impl Sleep) {
        let mut link = MockLink::default();
        // `Sender::geometry` is a shared borrow, so no mutable binding is needed.
        let geometry = crate::autd3_device::tests::create_geometry(1);
        let mut sent_flags = vec![false; 1];
        let mut rx = Vec::new();
        let mut msg_id = MsgId::new(0);

        assert!(link.open(&geometry).is_ok());
        let mut sender = Sender {
            msg_id: &mut msg_id,
            link: &mut link,
            geometry: &geometry,
            sent_flags: &mut sent_flags,
            rx: &mut rx,
            env: &Environment::new(),
            option: SenderOption {
                send_interval: Duration::from_millis(1),
                receive_interval: Duration::from_millis(1),
                timeout: None,
                parallel: ParallelMode::Auto,
                strict: true,
            },
            timer_strategy: FixedSchedule(sleeper),
            _phantom: std::marker::PhantomData,
        };

        let tx = sender.link.alloc_tx_buffer().unwrap();
        assert_eq!(Ok(()), sender.send_receive(tx, Duration::ZERO, true));

        let tx = sender.link.alloc_tx_buffer().unwrap();
        assert_eq!(
            Ok(()),
            sender.send_receive(tx, Duration::from_millis(1), true)
        );

        sender.link.is_open = false;
        let tx = sender.link.alloc_tx_buffer().unwrap();
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::closed())),
            sender.send_receive(tx, Duration::ZERO, true),
        );
    }

    #[rstest::rstest]
    #[case(StdSleeper)]
    #[case(SpinSleeper::default())]
    #[case(SpinWaitSleeper)]
    #[test]
    fn test_wait_msg_processed<S: Sleep>(#[case] sleeper: S) {
        let mut link = MockLink::default();
        // `Sender::geometry` is a shared borrow, so no mutable binding is needed.
        let geometry = crate::autd3_device::tests::create_geometry(1);
        let mut sent_flags = vec![true; 1];
        let mut rx = vec![RxMessage::new(0, Ack::new())];
        let mut msg_id = MsgId::new(1);

        assert!(link.open(&geometry).is_ok());
        let mut sender = Sender {
            msg_id: &mut msg_id,
            link: &mut link,
            geometry: &geometry,
            sent_flags: &mut sent_flags,
            rx: &mut rx,
            env: &Environment::new(),
            option: SenderOption {
                send_interval: Duration::from_millis(1),
                receive_interval: Duration::from_millis(1),
                timeout: None,
                parallel: ParallelMode::Auto,
                strict: true,
            },
            timer_strategy: FixedSchedule(sleeper),
            _phantom: std::marker::PhantomData::<S>,
        };

        assert_eq!(
            Ok(()),
            sender.wait_msg_processed(Duration::from_millis(10), true)
        );

        sender.link.recv_cnt = 0;
        sender.link.is_open = false;
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::closed())),
            sender.wait_msg_processed(Duration::from_millis(10), true)
        );

        sender.link.recv_cnt = 0;
        sender.link.is_open = true;
        sender.link.down = true;
        assert_eq!(
            Err(AUTDDriverError::ConfirmResponseFailed),
            sender.wait_msg_processed(Duration::ZERO, true)
        );

        sender.link.recv_cnt = 0;
        sender.link.is_open = true;
        sender.link.down = true;
        assert_eq!(
            Err(AUTDDriverError::ConfirmResponseFailed),
            sender.wait_msg_processed(Duration::from_secs(1), true)
        );

        sender.link.recv_cnt = 0;
        sender.link.is_open = true;
        sender.link.down = true;
        assert_eq!(Ok(()), sender.wait_msg_processed(Duration::ZERO, false));

        sender.link.down = false;
        sender.link.recv_cnt = 0;
        *sender.msg_id = MsgId::new(20);
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::new("too many"))),
            sender.wait_msg_processed(Duration::from_secs(1), true)
        );
    }
}