1use std::time::{Duration, Instant};
2
3use crate::{
4 datagram::FirmwareVersionType,
5 error::AUTDDriverError,
6 firmware::{
7 driver::{Driver, Operation, OperationHandler, SenderOption, TimerStrategy},
8 v10::{V10, operation::OperationGenerator},
9 },
10};
11
12use autd3_core::{
13 datagram::{Datagram, DeviceFilter},
14 geometry::Geometry,
15 link::{Link, MsgId, RxMessage, TxMessage},
16 sleep::Sleep,
17};
18
19pub struct Sender<'a, L, S, T> {
21 pub(crate) msg_id: &'a mut MsgId,
22 pub(crate) link: &'a mut L,
23 pub(crate) geometry: &'a Geometry,
24 pub(crate) sent_flags: &'a mut [bool],
25 pub(crate) rx: &'a mut [RxMessage],
26 pub(crate) option: SenderOption,
27 pub(crate) timer_strategy: T,
28 pub(crate) _phantom: std::marker::PhantomData<S>,
29}
30
31impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
32 pub fn send<D: Datagram>(&mut self, s: D) -> Result<(), AUTDDriverError>
34 where
35 AUTDDriverError: From<D::Error>,
36 D::G: OperationGenerator,
37 AUTDDriverError: From<<<D::G as OperationGenerator>::O1 as Operation>::Error>
38 + From<<<D::G as OperationGenerator>::O2 as Operation>::Error>,
39 {
40 let timeout = self.option.timeout.unwrap_or(s.option().timeout);
41 let parallel_threshold = s.option().parallel_threshold;
42 let strict = self.option.strict;
43
44 let mut g = s.operation_generator(
45 self.geometry,
46 &DeviceFilter::all_enabled(),
47 &V10.firmware_limits(),
48 )?;
49 let mut operations = self
50 .geometry
51 .iter()
52 .map(|dev| g.generate(dev))
53 .collect::<Vec<_>>();
54
55 operations
56 .iter()
57 .zip(self.sent_flags.iter_mut())
58 .for_each(|(op, flag)| {
59 *flag = op.is_some();
60 });
61
62 let num_enabled = self.sent_flags.iter().filter(|x| **x).count();
63 let parallel = self
64 .option
65 .parallel
66 .is_parallel(num_enabled, parallel_threshold);
67
68 self.link.ensure_is_open()?;
69 self.link.update(self.geometry)?;
70
71 let mut send_timing = self.timer_strategy.initial();
72 loop {
73 let mut tx = self.link.alloc_tx_buffer()?;
74
75 self.msg_id.increment();
76 OperationHandler::pack(
77 *self.msg_id,
78 &mut operations,
79 self.geometry,
80 &mut tx,
81 parallel,
82 )?;
83
84 self.send_receive(tx, timeout, strict)?;
85
86 if OperationHandler::is_done(&operations) {
87 return Ok(());
88 }
89
90 send_timing = self
91 .timer_strategy
92 .sleep(send_timing, self.option.send_interval);
93 }
94 }
95}
96
97impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
98 fn send_receive(
99 &mut self,
100 tx: Vec<TxMessage>,
101 timeout: Duration,
102 strict: bool,
103 ) -> Result<(), AUTDDriverError> {
104 self.link.ensure_is_open()?;
105 self.link.send(tx)?;
106 Self::wait_msg_processed(
107 self.link,
108 self.msg_id,
109 self.rx,
110 self.sent_flags,
111 &self.option,
112 &self.timer_strategy,
113 timeout,
114 strict,
115 )
116 }
117
118 #[allow(clippy::too_many_arguments)]
119 pub(crate) fn wait_msg_processed(
120 link: &mut L,
121 msg_id: &mut MsgId,
122 rx: &mut [RxMessage],
123 sent_flags: &mut [bool],
124 option: &SenderOption,
125 timer_strategy: &T,
126 timeout: Duration,
127 strict: bool,
128 ) -> Result<(), AUTDDriverError> {
129 let start = Instant::now();
130 let mut receive_timing = timer_strategy.initial();
131 loop {
132 link.ensure_is_open()?;
133 link.receive(rx)?;
134
135 if crate::firmware::v10::cpu::check_if_msg_is_processed(*msg_id, rx)
136 .zip(sent_flags.iter())
137 .filter_map(|(r, sent)| sent.then_some(r))
138 .all(std::convert::identity)
139 {
140 return Ok(());
141 }
142
143 if start.elapsed() > timeout {
144 break;
145 }
146
147 receive_timing = timer_strategy.sleep(receive_timing, option.receive_interval);
148 }
149
150 rx.iter()
151 .try_fold((), |_, r| {
152 crate::firmware::v10::cpu::check_firmware_err(r.ack())
153 })
154 .and_then(|_| {
155 if !strict && timeout == Duration::ZERO {
156 Ok(())
157 } else {
158 Err(AUTDDriverError::ConfirmResponseFailed)
159 }
160 })
161 }
162
163 pub(crate) fn fetch_firminfo(
164 &mut self,
165 ty: FirmwareVersionType,
166 ) -> Result<Vec<u8>, AUTDDriverError> {
167 self.send(ty).map_err(|_| {
168 AUTDDriverError::ReadFirmwareVersionFailed(
169 crate::firmware::v10::cpu::check_if_msg_is_processed(*self.msg_id, self.rx)
170 .collect(),
171 )
172 })?;
173 Ok(self.rx.iter().map(|rx| rx.data()).collect())
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    use crate::firmware::driver::{FixedSchedule, ParallelMode};
    use autd3_core::{
        link::{Ack, LinkError, TxBufferPoolSync},
        sleep::{Sleep, SpinSleeper, SpinWaitSleeper, StdSleeper},
    };

    /// In-memory [`Link`] used to drive [`Sender`] without hardware.
    ///
    /// - `send` returns the buffer to the pool.
    /// - `send` and `receive` count calls unless `down` is set.
    /// - `receive` stamps every [`RxMessage`] with `recv_cnt` as its ack message id.
    /// - `receive` fails with `"too many"` once `recv_cnt` exceeds 10.
    #[derive(Default)]
    struct MockLink {
        pub is_open: bool,
        pub send_cnt: usize,
        pub recv_cnt: usize,
        pub down: bool,
        pub buffer_pool: TxBufferPoolSync,
    }

    impl Link for MockLink {
        fn open(&mut self, geometry: &Geometry) -> Result<(), LinkError> {
            self.is_open = true;
            self.buffer_pool.init(geometry);
            Ok(())
        }

        fn close(&mut self) -> Result<(), LinkError> {
            self.is_open = false;
            Ok(())
        }

        fn alloc_tx_buffer(&mut self) -> Result<Vec<TxMessage>, LinkError> {
            Ok(self.buffer_pool.borrow())
        }

        fn send(&mut self, tx: Vec<TxMessage>) -> Result<(), LinkError> {
            if !self.down {
                self.send_cnt += 1;
            }
            self.buffer_pool.return_buffer(tx);
            Ok(())
        }

        fn receive(&mut self, rx: &mut [RxMessage]) -> Result<(), LinkError> {
            if self.recv_cnt > 10 {
                return Err(LinkError::new("too many"));
            }

            if !self.down {
                self.recv_cnt += 1;
            }
            rx.iter_mut().for_each(|r| {
                *r = RxMessage::new(r.data(), Ack::new().with_msg_id(self.recv_cnt as u8))
            });

            Ok(())
        }

        fn is_open(&self) -> bool {
            self.is_open
        }
    }

    /// Sender options shared by every test: 1 ms intervals, no timeout override,
    /// automatic parallelism, strict confirmation.
    fn test_option() -> SenderOption {
        SenderOption {
            send_interval: Duration::from_millis(1),
            receive_interval: Duration::from_millis(1),
            timeout: None,
            parallel: ParallelMode::Auto,
            strict: true,
        }
    }

    #[test]
    fn test_close() -> anyhow::Result<()> {
        let mut link = MockLink::default();
        link.open(&Geometry::new(Vec::new()))?;

        assert!(link.is_open());

        link.close()?;

        assert!(!link.is_open());

        Ok(())
    }

    #[rstest::rstest]
    #[case(StdSleeper)]
    #[case(SpinSleeper::default())]
    #[case(SpinWaitSleeper)]
    #[test]
    fn test_send_receive(#[case] sleeper: impl Sleep) {
        let mut link = MockLink::default();
        let geometry = crate::autd3_device::tests::create_geometry(1);
        let mut sent_flags = vec![false; 1];
        let mut rx = Vec::new();
        let mut msg_id = MsgId::new(0);

        assert!(link.open(&geometry).is_ok());
        let mut sender = Sender {
            msg_id: &mut msg_id,
            link: &mut link,
            geometry: &geometry,
            sent_flags: &mut sent_flags,
            rx: &mut rx,
            option: test_option(),
            timer_strategy: FixedSchedule(sleeper),
            _phantom: std::marker::PhantomData,
        };

        let tx = sender.link.alloc_tx_buffer().unwrap();
        assert_eq!(Ok(()), sender.send_receive(tx, Duration::ZERO, true));

        let tx = sender.link.alloc_tx_buffer().unwrap();
        assert_eq!(
            Ok(()),
            sender.send_receive(tx, Duration::from_millis(1), true)
        );

        sender.link.is_open = false;
        let tx = sender.link.alloc_tx_buffer().unwrap();
        assert_eq!(
            Err(AUTDDriverError::Link(LinkError::closed())),
            sender.send_receive(tx, Duration::ZERO, true),
        );
    }

    #[rstest::rstest]
    #[case(StdSleeper)]
    #[case(SpinSleeper::default())]
    #[case(SpinWaitSleeper)]
    #[test]
    fn test_wait_msg_processed<S: Sleep>(#[case] sleeper: S) {
        /// One `wait_msg_processed` call. It runs against the mock link state given here,
        /// after the receive counter has been reset.
        struct Scenario {
            name: &'static str,
            is_open: bool,
            down: bool,
            id: MsgId,
            timeout: Duration,
            strict: bool,
            expected: Result<(), AUTDDriverError>,
        }

        let mut link = MockLink::default();
        let geometry = crate::autd3_device::tests::create_geometry(1);
        let mut sent_flags = vec![true; 1];
        let mut rx = vec![RxMessage::new(0, Ack::new())];
        let mut msg_id = MsgId::new(1);

        assert!(link.open(&geometry).is_ok());
        let sender = Sender {
            msg_id: &mut msg_id,
            link: &mut link,
            geometry: &geometry,
            sent_flags: &mut sent_flags,
            rx: &mut rx,
            option: test_option(),
            timer_strategy: FixedSchedule(sleeper),
            _phantom: std::marker::PhantomData::<S>,
        };

        // Order matters: `rx` carries over from one scenario to the next.
        let scenarios = [
            Scenario {
                name: "acknowledged on first receive",
                is_open: true,
                down: false,
                id: MsgId::new(1),
                timeout: Duration::from_millis(10),
                strict: true,
                expected: Ok(()),
            },
            Scenario {
                name: "link closed",
                is_open: false,
                down: false,
                id: MsgId::new(1),
                timeout: Duration::from_millis(10),
                strict: true,
                expected: Err(AUTDDriverError::Link(LinkError::closed())),
            },
            Scenario {
                name: "no ack, zero timeout, strict",
                is_open: true,
                down: true,
                id: MsgId::new(1),
                timeout: Duration::ZERO,
                strict: true,
                expected: Err(AUTDDriverError::ConfirmResponseFailed),
            },
            Scenario {
                // The mock never acks while `down`, so this always runs to the timeout.
                // Keep it short; its length does not affect the outcome.
                name: "no ack, non-zero timeout, strict",
                is_open: true,
                down: true,
                id: MsgId::new(1),
                timeout: Duration::from_millis(10),
                strict: true,
                expected: Err(AUTDDriverError::ConfirmResponseFailed),
            },
            Scenario {
                name: "no ack, zero timeout, non-strict",
                is_open: true,
                down: true,
                id: MsgId::new(1),
                timeout: Duration::ZERO,
                strict: false,
                expected: Ok(()),
            },
            Scenario {
                // The ack id never reaches 20, so the mock link errors out after 11 receives.
                name: "link error while waiting",
                is_open: true,
                down: false,
                id: MsgId::new(20),
                timeout: Duration::from_secs(10),
                strict: true,
                expected: Err(AUTDDriverError::Link(LinkError::new("too many"))),
            },
        ];

        for Scenario {
            name,
            is_open,
            down,
            id,
            timeout,
            strict,
            expected,
        } in scenarios
        {
            sender.link.recv_cnt = 0;
            sender.link.is_open = is_open;
            sender.link.down = down;
            *sender.msg_id = id;
            assert_eq!(
                expected,
                Sender::wait_msg_processed(
                    sender.link,
                    sender.msg_id,
                    sender.rx,
                    sender.sent_flags,
                    &sender.option,
                    &sender.timer_strategy,
                    timeout,
                    strict
                ),
                "scenario: {name}"
            );
        }
    }
}