1use std::time::{Duration, Instant};
2
3use crate::{
4 datagram::FirmwareVersionType,
5 error::AUTDDriverError,
6 firmware::{
7 driver::{Driver, Operation, OperationHandler, SenderOption, TimerStrategy},
8 v12::{V12, operation::OperationGenerator},
9 },
10};
11
12use autd3_core::{
13 datagram::{Datagram, DeviceFilter},
14 geometry::Geometry,
15 link::{Link, MsgId, RxMessage, TxMessage},
16 sleep::Sleep,
17};
18
19pub struct Sender<'a, L: Link, S: Sleep, T: TimerStrategy<S>> {
21 pub(crate) msg_id: &'a mut MsgId,
22 pub(crate) link: &'a mut L,
23 pub(crate) geometry: &'a Geometry,
24 pub(crate) sent_flags: &'a mut [bool],
25 pub(crate) rx: &'a mut [RxMessage],
26 pub(crate) option: SenderOption,
27 pub(crate) timer_strategy: T,
28 pub(crate) _phantom: std::marker::PhantomData<S>,
29}
30
31impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
32 pub fn send<D: Datagram>(&mut self, s: D) -> Result<(), AUTDDriverError>
34 where
35 AUTDDriverError: From<D::Error>,
36 D::G: OperationGenerator,
37 AUTDDriverError: From<<<D::G as OperationGenerator>::O1 as Operation>::Error>
38 + From<<<D::G as OperationGenerator>::O2 as Operation>::Error>,
39 {
40 let timeout = self.option.timeout.unwrap_or(s.option().timeout);
41 let parallel_threshold = s.option().parallel_threshold;
42 let strict = self.option.strict;
43
44 let mut g = s.operation_generator(
45 self.geometry,
46 &DeviceFilter::all_enabled(),
47 &V12.firmware_limits(),
48 )?;
49 let mut operations = self
50 .geometry
51 .iter()
52 .map(|dev| g.generate(dev))
53 .collect::<Vec<_>>();
54
55 operations
56 .iter()
57 .zip(self.sent_flags.iter_mut())
58 .for_each(|(op, flag)| {
59 *flag = op.is_some();
60 });
61
62 let num_enabled = self.sent_flags.iter().filter(|x| **x).count();
63 let parallel = self
64 .option
65 .parallel
66 .is_parallel(num_enabled, parallel_threshold);
67
68 self.link.ensure_is_open()?;
69 self.link.update(self.geometry)?;
70
71 let mut send_timing = self.timer_strategy.initial();
72 loop {
73 let mut tx = self.link.alloc_tx_buffer()?;
74
75 self.msg_id.increment();
76 OperationHandler::pack(
77 *self.msg_id,
78 &mut operations,
79 self.geometry,
80 &mut tx,
81 parallel,
82 )?;
83
84 self.send_receive(tx, timeout, strict)?;
85
86 if OperationHandler::is_done(&operations) {
87 return Ok(());
88 }
89
90 send_timing = self
91 .timer_strategy
92 .sleep(send_timing, self.option.send_interval);
93 }
94 }
95}
96
97impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
98 fn send_receive(
99 &mut self,
100 tx: Vec<TxMessage>,
101 timeout: Duration,
102 strict: bool,
103 ) -> Result<(), AUTDDriverError> {
104 self.link.ensure_is_open()?;
105 self.link.send(tx)?;
106 Self::wait_msg_processed(
107 self.link,
108 self.msg_id,
109 self.rx,
110 self.sent_flags,
111 &self.option,
112 &self.timer_strategy,
113 timeout,
114 strict,
115 )
116 }
117
118 #[allow(clippy::too_many_arguments)]
119 pub(crate) fn wait_msg_processed(
120 link: &mut L,
121 msg_id: &mut MsgId,
122 rx: &mut [RxMessage],
123 sent_flags: &mut [bool],
124 option: &SenderOption,
125 timer_strategy: &T,
126 timeout: Duration,
127 strict: bool,
128 ) -> Result<(), AUTDDriverError> {
129 let start = Instant::now();
130 let mut receive_timing = timer_strategy.initial();
131 loop {
132 link.ensure_is_open()?;
133 link.receive(rx)?;
134
135 if crate::firmware::v12::cpu::check_if_msg_is_processed(*msg_id, rx)
136 .zip(sent_flags.iter())
137 .filter_map(|(r, sent)| sent.then_some(r))
138 .all(std::convert::identity)
139 {
140 break;
141 }
142
143 if start.elapsed() > timeout {
144 return if !strict && timeout == Duration::ZERO {
145 Ok(())
146 } else {
147 Err(AUTDDriverError::ConfirmResponseFailed)
148 };
149 }
150
151 receive_timing = timer_strategy.sleep(receive_timing, option.receive_interval);
152 }
153
154 rx.iter().try_fold((), |_, r| {
155 crate::firmware::v12::cpu::check_firmware_err(r.ack())
156 })
157 }
158
159 pub(crate) fn fetch_firminfo(
160 &mut self,
161 ty: FirmwareVersionType,
162 ) -> Result<Vec<u8>, AUTDDriverError> {
163 self.send(ty).map_err(|_| {
164 AUTDDriverError::ReadFirmwareVersionFailed(
165 crate::firmware::v12::cpu::check_if_msg_is_processed(*self.msg_id, self.rx)
166 .collect(),
167 )
168 })?;
169 Ok(self.rx.iter().map(|rx| rx.data()).collect())
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    use crate::firmware::driver::{FixedSchedule, ParallelMode};
    use autd3_core::{
        link::{Ack, LinkError, TxBufferPoolSync},
        sleep::{Sleep, SpinSleeper, SpinWaitSleeper, StdSleeper},
    };

    /// In-memory [`Link`] that fakes device responses.
    ///
    /// While `down` is false, `send` counts frames in `send_cnt` and `receive`
    /// increments `recv_cnt`. Every `receive` rewrites each `RxMessage` with an
    /// ack whose message id is `recv_cnt as u8`. Once `recv_cnt` exceeds 10,
    /// `receive` fails with `"too many"`.
    #[derive(Default)]
    struct MockLink {
        pub is_open: bool,
        pub send_cnt: usize,
        pub recv_cnt: usize,
        pub down: bool,
        pub buffer_pool: TxBufferPoolSync,
    }

    impl Link for MockLink {
        fn open(&mut self, geometry: &Geometry) -> Result<(), LinkError> {
            self.is_open = true;
            self.buffer_pool.init(geometry);
            Ok(())
        }

        fn close(&mut self) -> Result<(), LinkError> {
            self.is_open = false;
            Ok(())
        }

        fn alloc_tx_buffer(&mut self) -> Result<Vec<TxMessage>, LinkError> {
            Ok(self.buffer_pool.borrow())
        }

        fn send(&mut self, tx: Vec<TxMessage>) -> Result<(), LinkError> {
            if !self.down {
                self.send_cnt += 1;
            }
            self.buffer_pool.return_buffer(tx);
            Ok(())
        }

        fn receive(&mut self, rx: &mut [RxMessage]) -> Result<(), LinkError> {
            // Caps the number of receives so a never-acknowledged wait ends in an error.
            if self.recv_cnt > 10 {
                return Err(LinkError::new("too many"));
            }

            if !self.down {
                self.recv_cnt += 1;
            }
            for msg in rx.iter_mut() {
                *msg = RxMessage::new(msg.data(), Ack::new().with_msg_id(self.recv_cnt as u8));
            }

            Ok(())
        }

        fn is_open(&self) -> bool {
            self.is_open
        }
    }

    /// Options shared by every test sender: 1 ms send/receive intervals, no
    /// timeout override, automatic parallel mode, strict acknowledgement.
    fn default_option() -> SenderOption {
        SenderOption {
            send_interval: Duration::from_millis(1),
            receive_interval: Duration::from_millis(1),
            timeout: None,
            parallel: ParallelMode::Auto,
            strict: true,
        }
    }

    /// Runs `Sender::wait_msg_processed` on `sender`'s own fields.
    fn wait<S: Sleep, T: TimerStrategy<S>>(
        sender: &mut Sender<'_, MockLink, S, T>,
        timeout: Duration,
        strict: bool,
    ) -> Result<(), AUTDDriverError> {
        Sender::<MockLink, S, T>::wait_msg_processed(
            sender.link,
            sender.msg_id,
            sender.rx,
            sender.sent_flags,
            &sender.option,
            &sender.timer_strategy,
            timeout,
            strict,
        )
    }

    #[test]
    fn test_close() -> anyhow::Result<()> {
        let mut link = MockLink::default();
        link.open(&Geometry::new(Vec::new()))?;
        assert!(link.is_open());

        link.close()?;
        assert!(!link.is_open());

        Ok(())
    }

    #[rstest::rstest]
    #[case(StdSleeper)]
    #[case(SpinSleeper::default())]
    #[case(SpinWaitSleeper)]
    #[test]
    fn test_send_receive(#[case] sleeper: impl Sleep) {
        let mut link = MockLink::default();
        let geometry = crate::autd3_device::tests::create_geometry(1);
        let mut sent_flags = vec![false; 1];
        let mut rx = Vec::new();
        let mut msg_id = MsgId::new(0);

        assert!(link.open(&geometry).is_ok());
        let mut sender = Sender {
            msg_id: &mut msg_id,
            link: &mut link,
            geometry: &geometry,
            sent_flags: &mut sent_flags,
            rx: &mut rx,
            option: default_option(),
            timer_strategy: FixedSchedule(sleeper),
            _phantom: std::marker::PhantomData,
        };

        // Nothing is flagged as sent, so acknowledgement succeeds immediately.
        let tx = sender.link.alloc_tx_buffer().unwrap();
        assert_eq!(Ok(()), sender.send_receive(tx, Duration::ZERO, true));

        let tx = sender.link.alloc_tx_buffer().unwrap();
        assert_eq!(
            Ok(()),
            sender.send_receive(tx, Duration::from_millis(1), true)
        );

        // A closed link is rejected before anything is sent.
        sender.link.is_open = false;
        let tx = sender.link.alloc_tx_buffer().unwrap();
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::closed())),
            sender.send_receive(tx, Duration::ZERO, true),
        );
    }

    #[rstest::rstest]
    #[case(StdSleeper)]
    #[case(SpinSleeper::default())]
    #[case(SpinWaitSleeper)]
    #[test]
    fn test_wait_msg_processed(#[case] sleeper: impl Sleep) {
        let mut link = MockLink::default();
        let geometry = crate::autd3_device::tests::create_geometry(1);
        let mut sent_flags = vec![true; 1];
        let mut rx = vec![RxMessage::new(0, Ack::new())];
        let mut msg_id = MsgId::new(1);

        assert!(link.open(&geometry).is_ok());
        let mut sender = Sender {
            msg_id: &mut msg_id,
            link: &mut link,
            geometry: &geometry,
            sent_flags: &mut sent_flags,
            rx: &mut rx,
            option: default_option(),
            timer_strategy: FixedSchedule(sleeper),
            _phantom: std::marker::PhantomData,
        };

        // The first receive stamps message id 1, the id being waited on.
        assert_eq!(Ok(()), wait(&mut sender, Duration::from_millis(10), true));

        // A closed link fails before any receive.
        sender.link.recv_cnt = 0;
        sender.link.is_open = false;
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::closed())),
            wait(&mut sender, Duration::from_millis(10), true)
        );

        // A down link keeps stamping id 0, so the wait times out.
        sender.link.recv_cnt = 0;
        sender.link.is_open = true;
        sender.link.down = true;
        assert_eq!(
            Err(AUTDDriverError::ConfirmResponseFailed),
            wait(&mut sender, Duration::from_millis(10), true)
        );

        // A zero timeout still fails in strict mode...
        sender.link.recv_cnt = 0;
        sender.link.is_open = true;
        sender.link.down = true;
        assert_eq!(
            Err(AUTDDriverError::ConfirmResponseFailed),
            wait(&mut sender, Duration::ZERO, true)
        );

        // ...but is tolerated in non-strict mode.
        sender.link.recv_cnt = 0;
        sender.link.is_open = true;
        sender.link.down = true;
        assert_eq!(Ok(()), wait(&mut sender, Duration::ZERO, false));

        // Ids 1..=11 never match 20, so the mock's receive cap is hit first.
        sender.link.down = false;
        sender.link.recv_cnt = 0;
        *sender.msg_id = MsgId::new(20);
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::new("too many"))),
            wait(&mut sender, Duration::from_secs(10), true)
        );
    }
}