1use std::time::{Duration, Instant};
2
3use crate::{
4 datagram::FirmwareVersionType,
5 error::AUTDDriverError,
6 firmware::{
7 driver::{Driver, Operation, OperationHandler, SenderOption, TimerStrategy},
8 v10::{V10, operation::OperationGenerator},
9 },
10};
11
12use autd3_core::{
13 datagram::{Datagram, DeviceFilter},
14 environment::Environment,
15 geometry::Geometry,
16 link::{Link, MsgId, RxMessage, TxMessage},
17 sleep::Sleep,
18};
19
20pub struct Sender<'a, L, S, T> {
22 pub(crate) msg_id: &'a mut MsgId,
23 pub(crate) link: &'a mut L,
24 pub(crate) geometry: &'a Geometry,
25 pub(crate) sent_flags: &'a mut [bool],
26 pub(crate) rx: &'a mut [RxMessage],
27 pub(crate) env: &'a Environment,
28 pub(crate) option: SenderOption,
29 pub(crate) timer_strategy: T,
30 pub(crate) _phantom: std::marker::PhantomData<S>,
31}
32
33impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
34 pub fn send<D: Datagram>(&mut self, s: D) -> Result<(), AUTDDriverError>
36 where
37 AUTDDriverError: From<D::Error>,
38 D::G: OperationGenerator,
39 AUTDDriverError: From<<<D::G as OperationGenerator>::O1 as Operation>::Error>
40 + From<<<D::G as OperationGenerator>::O2 as Operation>::Error>,
41 {
42 let timeout = self.option.timeout.unwrap_or(s.option().timeout);
43 let parallel_threshold = s.option().parallel_threshold;
44
45 let mut g = s.operation_generator(
46 self.geometry,
47 self.env,
48 &DeviceFilter::all_enabled(),
49 &V10.firmware_limits(),
50 )?;
51 let mut operations = self
52 .geometry
53 .iter()
54 .map(|dev| g.generate(dev))
55 .collect::<Vec<_>>();
56
57 self.send_impl(timeout, parallel_threshold, &mut operations)
58 }
59}
60
61impl<'a, L: Link, S: Sleep, T: TimerStrategy<S>> Sender<'a, L, S, T> {
62 pub(crate) fn send_impl<O1, O2>(
63 &mut self,
64 timeout: Duration,
65 parallel_threshold: usize,
66 operations: &mut [Option<(O1, O2)>],
67 ) -> Result<(), AUTDDriverError>
68 where
69 O1: Operation,
70 O2: Operation,
71 AUTDDriverError: From<O1::Error> + From<O2::Error>,
72 {
73 let strict = self.option.strict;
74
75 operations
76 .iter()
77 .zip(self.sent_flags.iter_mut())
78 .for_each(|(op, flag)| {
79 *flag = op.is_some();
80 });
81
82 let num_enabled = self.sent_flags.iter().filter(|x| **x).count();
83 let parallel = self
84 .option
85 .parallel
86 .is_parallel(num_enabled, parallel_threshold);
87
88 self.link.ensure_is_open()?;
89 self.link.update(self.geometry)?;
90
91 let mut send_timing = self.timer_strategy.initial();
92 loop {
93 let mut tx = self.link.alloc_tx_buffer()?;
94
95 self.msg_id.increment();
96 OperationHandler::pack(*self.msg_id, operations, self.geometry, &mut tx, parallel)?;
97
98 self.send_receive(tx, timeout, strict)?;
99
100 if OperationHandler::is_done(operations) {
101 return Ok(());
102 }
103
104 send_timing = self
105 .timer_strategy
106 .sleep(send_timing, self.option.send_interval);
107 }
108 }
109
110 fn send_receive(
111 &mut self,
112 tx: Vec<TxMessage>,
113 timeout: Duration,
114 strict: bool,
115 ) -> Result<(), AUTDDriverError> {
116 self.link.ensure_is_open()?;
117 self.link.send(tx)?;
118 self.wait_msg_processed(timeout, strict)
119 }
120
121 fn wait_msg_processed(
122 &mut self,
123 timeout: Duration,
124 strict: bool,
125 ) -> Result<(), AUTDDriverError> {
126 let start = Instant::now();
127 let mut receive_timing = self.timer_strategy.initial();
128 loop {
129 self.link.ensure_is_open()?;
130 self.link.receive(self.rx)?;
131
132 if crate::firmware::v10::cpu::check_if_msg_is_processed(*self.msg_id, self.rx)
133 .zip(self.sent_flags.iter())
134 .filter_map(|(r, sent)| sent.then_some(r))
135 .all(std::convert::identity)
136 {
137 return Ok(());
138 }
139
140 if start.elapsed() > timeout {
141 break;
142 }
143
144 receive_timing = self
145 .timer_strategy
146 .sleep(receive_timing, self.option.receive_interval);
147 }
148
149 self.rx
150 .iter()
151 .try_fold((), |_, r| {
152 crate::firmware::v10::cpu::check_firmware_err(r.ack())
153 })
154 .and_then(|_| {
155 if !strict && timeout == Duration::ZERO {
156 Ok(())
157 } else {
158 Err(AUTDDriverError::ConfirmResponseFailed)
159 }
160 })
161 }
162
163 pub(crate) fn fetch_firminfo(
164 &mut self,
165 ty: FirmwareVersionType,
166 ) -> Result<Vec<u8>, AUTDDriverError> {
167 self.send(ty).map_err(|_| {
168 AUTDDriverError::ReadFirmwareVersionFailed(
169 crate::firmware::v10::cpu::check_if_msg_is_processed(*self.msg_id, self.rx)
170 .collect(),
171 )
172 })?;
173 Ok(self.rx.iter().map(|rx| rx.data()).collect())
174 }
175}
176
177#[cfg(test)]
178mod tests {
179 use super::*;
180
181 use crate::firmware::driver::{FixedSchedule, ParallelMode};
182 use autd3_core::{
183 link::{Ack, LinkError, TxBufferPoolSync},
184 sleep::{Sleep, SpinSleeper, SpinWaitSleeper, StdSleeper},
185 };
186
187 #[derive(Default)]
188 struct MockLink {
189 pub is_open: bool,
190 pub send_cnt: usize,
191 pub recv_cnt: usize,
192 pub down: bool,
193 pub buffer_pool: TxBufferPoolSync,
194 }
195
196 impl Link for MockLink {
197 fn open(&mut self, geometry: &Geometry) -> Result<(), LinkError> {
198 self.is_open = true;
199 self.buffer_pool.init(geometry);
200 Ok(())
201 }
202
203 fn close(&mut self) -> Result<(), LinkError> {
204 self.is_open = false;
205 Ok(())
206 }
207
208 fn alloc_tx_buffer(&mut self) -> Result<Vec<TxMessage>, LinkError> {
209 Ok(self.buffer_pool.borrow())
210 }
211
212 fn send(&mut self, tx: Vec<TxMessage>) -> Result<(), LinkError> {
213 if !self.down {
214 self.send_cnt += 1;
215 }
216 self.buffer_pool.return_buffer(tx);
217 Ok(())
218 }
219
220 fn receive(&mut self, rx: &mut [RxMessage]) -> Result<(), LinkError> {
221 if self.recv_cnt > 10 {
222 return Err(LinkError::new("too many"));
223 }
224
225 if !self.down {
226 self.recv_cnt += 1;
227 }
228 rx.iter_mut().for_each(|r| {
229 *r = RxMessage::new(r.data(), Ack::new().with_msg_id(self.recv_cnt as u8))
230 });
231
232 Ok(())
233 }
234
235 fn is_open(&self) -> bool {
236 self.is_open
237 }
238 }
239
240 #[test]
241 fn test_close() -> anyhow::Result<()> {
242 let mut link = MockLink::default();
243 link.open(&Geometry::new(Vec::new()))?;
244
245 assert!(link.is_open());
246
247 link.close()?;
248
249 assert!(!link.is_open());
250
251 Ok(())
252 }
253
254 #[rstest::rstest]
255 #[case(StdSleeper)]
256 #[case(SpinSleeper::default())]
257 #[case(SpinWaitSleeper)]
258 #[test]
259 fn test_send_receive(#[case] sleeper: impl Sleep) {
260 let mut link = MockLink::default();
261 let mut geometry = crate::autd3_device::tests::create_geometry(1);
262 let mut sent_flags = vec![false; 1];
263 let mut rx = Vec::new();
264 let mut msg_id = MsgId::new(0);
265
266 assert!(link.open(&geometry).is_ok());
267 let mut sender = Sender {
268 msg_id: &mut msg_id,
269 link: &mut link,
270 geometry: &mut geometry,
271 sent_flags: &mut sent_flags,
272 rx: &mut rx,
273 env: &Environment::new(),
274 option: SenderOption {
275 send_interval: Duration::from_millis(1),
276 receive_interval: Duration::from_millis(1),
277 timeout: None,
278 parallel: ParallelMode::Auto,
279 strict: true,
280 },
281 timer_strategy: FixedSchedule(sleeper),
282 _phantom: std::marker::PhantomData,
283 };
284
285 let tx = sender.link.alloc_tx_buffer().unwrap();
286 assert_eq!(Ok(()), sender.send_receive(tx, Duration::ZERO, true));
287
288 let tx = sender.link.alloc_tx_buffer().unwrap();
289 assert_eq!(
290 Ok(()),
291 sender.send_receive(tx, Duration::from_millis(1), true)
292 );
293
294 sender.link.is_open = false;
295 let tx = sender.link.alloc_tx_buffer().unwrap();
296 assert_eq!(
297 Err(AUTDDriverError::Link(LinkError::closed())),
298 sender.send_receive(tx, Duration::ZERO, true),
299 );
300 }
301
302 #[rstest::rstest]
303 #[case(StdSleeper)]
304 #[case(SpinSleeper::default())]
305 #[case(SpinWaitSleeper)]
306 #[test]
307 fn test_wait_msg_processed<S: Sleep>(#[case] sleeper: S) {
308 let mut link = MockLink::default();
309 let mut geometry = crate::autd3_device::tests::create_geometry(1);
310 let mut sent_flags = vec![true; 1];
311 let mut rx = vec![RxMessage::new(0, Ack::new())];
312 let mut msg_id = MsgId::new(1);
313
314 assert!(link.open(&geometry).is_ok());
315 let mut sender = Sender {
316 msg_id: &mut msg_id,
317 link: &mut link,
318 geometry: &mut geometry,
319 sent_flags: &mut sent_flags,
320 rx: &mut rx,
321 env: &Environment::new(),
322 option: SenderOption {
323 send_interval: Duration::from_millis(1),
324 receive_interval: Duration::from_millis(1),
325 timeout: None,
326 parallel: ParallelMode::Auto,
327 strict: true,
328 },
329 timer_strategy: FixedSchedule(sleeper),
330 _phantom: std::marker::PhantomData::<S>,
331 };
332
333 assert_eq!(
334 Ok(()),
335 sender.wait_msg_processed(Duration::from_millis(10), true)
336 );
337
338 sender.link.recv_cnt = 0;
339 sender.link.is_open = false;
340 assert_eq!(
341 Err(AUTDDriverError::Link(LinkError::closed())),
342 sender.wait_msg_processed(Duration::from_millis(10), true)
343 );
344
345 sender.link.recv_cnt = 0;
346 sender.link.is_open = true;
347 sender.link.down = true;
348 assert_eq!(
349 Err(AUTDDriverError::ConfirmResponseFailed),
350 sender.wait_msg_processed(Duration::ZERO, true)
351 );
352
353 sender.link.recv_cnt = 0;
354 sender.link.is_open = true;
355 sender.link.down = true;
356 assert_eq!(
357 Err(AUTDDriverError::ConfirmResponseFailed),
358 sender.wait_msg_processed(Duration::from_secs(1), true)
359 );
360
361 sender.link.recv_cnt = 0;
362 sender.link.is_open = true;
363 sender.link.down = true;
364 assert_eq!(Ok(()), sender.wait_msg_processed(Duration::ZERO, false));
365
366 sender.link.down = false;
367 sender.link.recv_cnt = 0;
368 *sender.msg_id = MsgId::new(20);
369 assert_eq!(
370 Err(AUTDDriverError::Link(LinkError::new("too many"))),
371 sender.wait_msg_processed(Duration::from_secs(1), true)
372 );
373 }
374}