autd/controller/
autd_logic.rs

1/*
2 * File: autd_logic.rs
3 * Project: controller
4 * Created Date: 30/12/2020
5 * Author: Shun Suzuki
6 * -----
7 * Last Modified: 08/03/2021
8 * Modified By: Shun Suzuki (suzuki@hapis.k.u-tokyo.ac.jp)
9 * -----
10 * Copyright (c) 2020 Hapis Lab. All rights reserved.
11 *
12 */
13
14use std::{error::Error, mem::size_of, ptr::copy_nonoverlapping};
15
16use crate::{
17    core::{configuration::Configuration, consts::*, *},
18    gain::Gain,
19    geometry::Geometry,
20    link::Link,
21    modulation::Modulation,
22    sequence::PointSequence,
23    Float,
24};
25
26use super::{GainPtr, ModPtr};
27
28pub struct AUTDLogic<L: Link> {
29    geometry: Geometry,
30    link: L,
31    rx_data: Vec<u8>,
32    seq_mode: bool,
33    pub(crate) silent_mode: bool,
34    config: Configuration,
35}
36
37impl<L: Link> AUTDLogic<L> {
38    pub fn geometry(&self) -> &Geometry {
39        &self.geometry
40    }
41}
42
43impl<L: Link> AUTDLogic<L> {
44    pub(crate) fn new(geometry: Geometry, link: L) -> Self {
45        Self {
46            geometry,
47            link,
48            rx_data: vec![],
49            seq_mode: false,
50            silent_mode: true,
51            config: Configuration::default(),
52        }
53    }
54
55    pub(crate) fn is_open(&self) -> bool {
56        self.link.is_open()
57    }
58
59    pub(crate) fn build_gain<G: Gain>(&self, g: &mut G) {
60        g.build(&self.geometry)
61    }
62
63    pub(crate) fn build_modulation<M: Modulation>(&self, m: &mut M) {
64        m.build(self.config)
65    }
66
67    pub(crate) fn build_gain_ptr(&self, g: &mut GainPtr) {
68        g.build(&self.geometry)
69    }
70
71    pub(crate) fn build_modulation_ptr(&self, m: &mut ModPtr) {
72        m.build(self.config)
73    }
74
75    pub(crate) fn send_gain_mod<G: Gain, M: Modulation>(
76        &mut self,
77        g: Option<&G>,
78        m: Option<&mut M>,
79    ) -> Result<u8, Box<dyn Error>> {
80        if g.is_some() {
81            self.seq_mode = false;
82        }
83
84        let dev_num = self.geometry.num_devices();
85        let (msg_id, body) = Self::make_body(g, m, dev_num, self.silent_mode, self.seq_mode);
86        self.link.send(&body)?;
87        Ok(msg_id)
88    }
89
90    pub(crate) fn send_gain_mod_ptr(
91        &mut self,
92        g: Option<GainPtr>,
93        m: Option<&mut ModPtr>,
94    ) -> Result<u8, Box<dyn Error>> {
95        if g.is_some() {
96            self.seq_mode = false;
97        }
98
99        let dev_num = self.geometry.num_devices();
100        let (msg_id, body) = Self::make_body_ptr(g, m, dev_num, self.silent_mode, self.seq_mode);
101        self.link.send(&body)?;
102        Ok(msg_id)
103    }
104
105    pub(crate) fn send_gain_mod_blocking<G: Gain, M: Modulation>(
106        &mut self,
107        g: Option<&G>,
108        m: Option<&mut M>,
109    ) -> Result<bool, Box<dyn Error>> {
110        let msg_id = self.send_gain_mod(g, m)?;
111        let dev_num = self.geometry.num_devices();
112        self.wait_msg_processed(dev_num, msg_id, 0xFF, 200)
113    }
114
115    pub(crate) fn send_seq_blocking(
116        &mut self,
117        seq: &mut PointSequence,
118    ) -> Result<bool, Box<dyn Error>> {
119        self.seq_mode = true;
120
121        let (msg_id, body) = Self::make_seq_body(seq, &self.geometry, self.silent_mode);
122        self.link.send(&body)?;
123
124        let dev_num = self.geometry.num_devices();
125        if *seq.sent() == seq.control_points().len() {
126            self.wait_msg_processed(dev_num, 0xC0, 0xE0, 2000)
127        } else {
128            self.wait_msg_processed(dev_num, msg_id, 0xFF, 200)
129        }
130    }
131
132    pub(crate) fn send_data(&mut self, data: &[u8]) -> Result<u8, Box<dyn Error>> {
133        let msg_id = data[0];
134        self.link.send(&data)?;
135        Ok(msg_id)
136    }
137
138    fn send_header_blocking(
139        &mut self,
140        command: CommandType,
141        max_trial: usize,
142    ) -> Result<bool, Box<dyn Error>> {
143        let header = RxGlobalHeader::new_with_cmd(command);
144        let dev_num = self.geometry.num_devices();
145        unsafe {
146            self.link.send(Self::convert_to_u8_slice(&header))?;
147        }
148        self.wait_msg_processed(dev_num, command as u8, 0xFF, max_trial)
149    }
150
151    pub(crate) fn calibrate(&mut self, config: Configuration) -> Result<bool, Box<dyn Error>> {
152        self.config = config;
153        self.send_header_blocking(CommandType::CmdInitRefClock, 5000)
154    }
155
156    pub(crate) fn calibrate_seq(&mut self) -> Result<(), Box<dyn Error>> {
157        let rx_data = &self.rx_data;
158        let mut laps = Vec::with_capacity(rx_data.len() / 2);
159        for j in 0..laps.capacity() {
160            let lap_raw = ((rx_data[2 * j + 1] as u16) << 8) | rx_data[2 * j] as u16;
161            laps.push(lap_raw & 0x03FF);
162        }
163        let minimum = laps.iter().min().unwrap();
164        let diffs = laps.iter().map(|&d| d - minimum).collect::<Vec<_>>();
165        let diff_max = *diffs.iter().max().unwrap();
166        let diffs: Vec<u16> = if diff_max == 0 {
167            return Ok(());
168        } else if diff_max > 500 {
169            let laps = laps
170                .iter()
171                .map(|&d| if d < 500 { d + 1000 } else { d })
172                .collect::<Vec<_>>();
173            let minimum = laps.iter().min().unwrap();
174            laps.iter().map(|d| d - minimum).collect()
175        } else {
176            diffs
177        };
178
179        let dev_num = diffs.len();
180        let calib_body = Self::make_calib_body(diffs);
181        self.link.send(&calib_body)?;
182        self.wait_msg_processed(dev_num, 0xE0, 0xE0, 200)?;
183
184        Ok(())
185    }
186
187    pub(crate) fn clear(&mut self) -> Result<bool, Box<dyn Error>> {
188        self.send_header_blocking(CommandType::CmdClear, 5000)
189    }
190
191    pub(crate) fn close(&mut self) -> Result<bool, Box<dyn Error>> {
192        self.clear()?;
193        self.link.close()?;
194        Ok(true)
195    }
196
197    #[allow(clippy::needless_range_loop)]
198    pub(crate) fn firmware_info_list(&mut self) -> Result<Vec<FirmwareInfo>, Box<dyn Error>> {
199        let size = self.geometry.num_devices();
200
201        let mut cpu_versions: Vec<u16> = vec![0x0000; size];
202        let mut fpga_versions: Vec<u16> = vec![0x0000; size];
203
204        self.send_header_blocking(CommandType::CmdReadCpuVerLsb, 50)?;
205        for i in 0..size {
206            cpu_versions[i] = self.rx_data[2 * i] as u16;
207        }
208
209        self.send_header_blocking(CommandType::CmdReadCpuVerMsb, 50)?;
210        for i in 0..size {
211            cpu_versions[i] |= (self.rx_data[2 * i] as u16) << 8;
212        }
213
214        self.send_header_blocking(CommandType::CmdReadFpgaVerLsb, 50)?;
215        for i in 0..size {
216            fpga_versions[i] = self.rx_data[2 * i] as u16;
217        }
218
219        self.send_header_blocking(CommandType::CmdReadFpgaVerMsb, 50)?;
220        for i in 0..size {
221            fpga_versions[i] |= (self.rx_data[2 * i] as u16) << 8;
222        }
223
224        let mut res = Vec::with_capacity(size);
225        for i in 0..size {
226            let firm_info = FirmwareInfo::new(i as u16, cpu_versions[i], fpga_versions[i]);
227            res.push(firm_info);
228        }
229
230        Ok(res)
231    }
232
233    fn make_calib_body(diffs: Vec<u16>) -> Vec<u8> {
234        let header = RxGlobalHeader::new_with_cmd(CommandType::CmdCalibSeqClock);
235        let mut body =
236            vec![0x00; size_of::<RxGlobalHeader>() + NUM_TRANS_IN_UNIT * 2 * diffs.len()];
237        unsafe {
238            copy_nonoverlapping(
239                &header as *const RxGlobalHeader as *const u8,
240                body.as_mut_ptr(),
241                size_of::<RxGlobalHeader>(),
242            );
243            let mut cursor = size_of::<RxGlobalHeader>();
244            for diff in diffs {
245                body[cursor] = (diff & 0x00FF) as u8;
246                body[cursor + 1] = ((diff & 0xFF00) >> 8) as u8;
247                cursor += NUM_TRANS_IN_UNIT * 2;
248            }
249        }
250        body
251    }
252
253    fn wait_msg_processed(
254        &mut self,
255        dev_num: usize,
256        msg_id: u8,
257        mask: u8,
258        max_trial: usize,
259    ) -> Result<bool, Box<dyn Error>> {
260        let buffer_len = dev_num * INPUT_FRAME_SIZE;
261
262        self.rx_data.resize(buffer_len, 0x00);
263        for _ in 0..max_trial {
264            self.link.read(&mut self.rx_data, buffer_len)?;
265
266            let processed = (0..dev_num)
267                .map(|dev| self.rx_data[dev as usize * INPUT_FRAME_SIZE + 1])
268                .filter(|&proc_id| (proc_id & mask) == msg_id)
269                .count();
270
271            if processed == dev_num {
272                return Ok(true);
273            }
274
275            let wait_t = (EC_TRAFFIC_DELAY * 1000.0 / EC_DEVICE_PER_FRAME as Float
276                * dev_num as Float)
277                .ceil() as u64;
278            let wait_t = 1.max(wait_t);
279            std::thread::sleep(std::time::Duration::from_millis(wait_t));
280        }
281        Ok(false)
282    }
283
284    fn make_body<G: Gain, M: Modulation>(
285        gain: Option<&G>,
286        m: Option<&mut M>,
287        num_devices: usize,
288        is_silent: bool,
289        is_seq_mode: bool,
290    ) -> (u8, Vec<u8>) {
291        let num_bodies = if gain.is_some() { num_devices } else { 0 };
292        let size = size_of::<RxGlobalHeader>() + NUM_TRANS_IN_UNIT * 2 * num_bodies;
293
294        let mut body = vec![0x00; size];
295        let mut ctrl_flags = RxGlobalControlFlags::NONE;
296        if is_silent {
297            ctrl_flags |= RxGlobalControlFlags::SILENT;
298        }
299        if is_seq_mode {
300            ctrl_flags |= RxGlobalControlFlags::SEQ_MODE;
301        }
302
303        let mod_data = if let Some(modulation) = m {
304            let sent = *modulation.sent();
305            let mod_size = num::clamp(modulation.buffer().len() - sent, 0, MOD_FRAME_SIZE);
306            if sent == 0 {
307                ctrl_flags |= RxGlobalControlFlags::LOOP_BEGIN;
308            }
309            if sent + mod_size >= modulation.buffer().len() {
310                ctrl_flags |= RxGlobalControlFlags::LOOP_END;
311            }
312            *modulation.sent() += mod_size;
313            &modulation.buffer()[sent..(sent + mod_size)]
314        } else {
315            &[]
316        };
317        let msg_id = unsafe {
318            let header = RxGlobalHeader::new_op(ctrl_flags, mod_data);
319            let src_ptr = &header as *const RxGlobalHeader as *const u8;
320            let dst_ptr = body.as_mut_ptr();
321            copy_nonoverlapping(src_ptr, dst_ptr, size_of::<RxGlobalHeader>());
322            header.msg_id
323        };
324
325        if let Some(gain) = gain {
326            let mut cursor = size_of::<RxGlobalHeader>();
327            let byte_size = NUM_TRANS_IN_UNIT * 2;
328            let gain_ptr = gain.get_data().as_ptr() as *const u8;
329            unsafe {
330                for i in 0..num_devices {
331                    let src_ptr = gain_ptr.add(i * byte_size);
332                    let dst_ptr = body.as_mut_ptr().add(cursor);
333
334                    copy_nonoverlapping(src_ptr, dst_ptr, byte_size);
335                    cursor += byte_size;
336                }
337            }
338        }
339
340        (msg_id, body)
341    }
342
343    pub(crate) fn make_body_ptr(
344        gain: Option<GainPtr>,
345        m: Option<&mut ModPtr>,
346        num_devices: usize,
347        is_silent: bool,
348        is_seq_mode: bool,
349    ) -> (u8, Vec<u8>) {
350        let num_bodies = if gain.is_some() { num_devices } else { 0 };
351        let size = size_of::<RxGlobalHeader>() + NUM_TRANS_IN_UNIT * 2 * num_bodies;
352
353        let mut body = vec![0x00; size];
354        let mut ctrl_flags = RxGlobalControlFlags::NONE;
355        if is_silent {
356            ctrl_flags |= RxGlobalControlFlags::SILENT;
357        }
358        if is_seq_mode {
359            ctrl_flags |= RxGlobalControlFlags::SEQ_MODE;
360        }
361
362        let mod_data = if let Some(modulation) = m {
363            let sent = *modulation.sent();
364            let mod_size = num::clamp(modulation.buffer().len() - sent, 0, MOD_FRAME_SIZE);
365            if sent == 0 {
366                ctrl_flags |= RxGlobalControlFlags::LOOP_BEGIN;
367            }
368            if sent + mod_size >= modulation.buffer().len() {
369                ctrl_flags |= RxGlobalControlFlags::LOOP_END;
370            }
371            *modulation.sent() += mod_size;
372            &modulation.buffer()[sent..(sent + mod_size)]
373        } else {
374            &[]
375        };
376        let msg_id = unsafe {
377            let header = RxGlobalHeader::new_op(ctrl_flags, mod_data);
378            let src_ptr = &header as *const RxGlobalHeader as *const u8;
379            let dst_ptr = body.as_mut_ptr();
380            copy_nonoverlapping(src_ptr, dst_ptr, size_of::<RxGlobalHeader>());
381            header.msg_id
382        };
383
384        if let Some(gain) = gain {
385            let mut cursor = size_of::<RxGlobalHeader>();
386            let byte_size = NUM_TRANS_IN_UNIT * 2;
387            let gain_ptr = gain.get_data().as_ptr() as *const u8;
388            unsafe {
389                for i in 0..num_devices {
390                    let src_ptr = gain_ptr.add(i * byte_size);
391                    let dst_ptr = body.as_mut_ptr().add(cursor);
392
393                    copy_nonoverlapping(src_ptr, dst_ptr, byte_size);
394                    cursor += byte_size;
395                }
396            }
397        }
398        (msg_id, body)
399    }
400
401    fn make_seq_body(
402        seq: &mut PointSequence,
403        geometry: &Geometry,
404        is_silent: bool,
405    ) -> (u8, Vec<u8>) {
406        let num_devices = geometry.num_devices();
407        let size = size_of::<RxGlobalHeader>() + NUM_TRANS_IN_UNIT * 2 * num_devices;
408
409        let sent = *seq.sent();
410
411        let mut body = vec![0x00; size];
412        let send_size = num::clamp(seq.control_points().len() - sent, 0, 40);
413
414        let mut ctrl_flags = RxGlobalControlFlags::SEQ_MODE;
415        if is_silent {
416            ctrl_flags |= RxGlobalControlFlags::SILENT;
417        }
418
419        if sent == 0 {
420            ctrl_flags |= RxGlobalControlFlags::SEQ_BEGIN;
421        }
422        if sent + send_size >= seq.control_points().len() {
423            ctrl_flags |= RxGlobalControlFlags::SEQ_END;
424        }
425        let msg_id = unsafe {
426            let header =
427                RxGlobalHeader::new_seq(ctrl_flags, send_size as u16, seq.sampling_freq_div());
428            let src_ptr = &header as *const RxGlobalHeader as *const u8;
429            let dst_ptr = body.as_mut_ptr();
430            copy_nonoverlapping(src_ptr, dst_ptr, size_of::<RxGlobalHeader>());
431            header.msg_id
432        };
433
434        let mut cursor = size_of::<RxGlobalHeader>();
435        let fixed_num_unit: Float = geometry.wavelength() / 256.0;
436        unsafe {
437            for device in 0..num_devices {
438                let mut foci = Vec::with_capacity(send_size as usize * 10);
439                for i in 0..(send_size as usize) {
440                    let v64 = geometry.local_position(device, seq.control_points()[sent + i]);
441                    let x = (v64.x / fixed_num_unit) as i32 as u32;
442                    let y = (v64.y / fixed_num_unit) as i32 as u32;
443                    let z = (v64.z / fixed_num_unit) as i32 as u32;
444                    foci.push((x & 0x000000FF) as u8);
445                    foci.push(((x & 0x0000FF00) >> 8) as u8);
446                    foci.push((((x & 0x80000000) >> 24) | ((x & 0x007F0000) >> 16)) as u8);
447                    foci.push((y & 0x000000FF) as u8);
448                    foci.push(((y & 0x0000FF00) >> 8) as u8);
449                    foci.push((((y & 0x80000000) >> 24) | ((y & 0x007F0000) >> 16)) as u8);
450                    foci.push((z & 0x000000FF) as u8);
451                    foci.push(((z & 0x0000FF00) >> 8) as u8);
452                    foci.push((((z & 0x80000000) >> 24) | ((z & 0x007F0000) >> 16)) as u8);
453                    foci.push(0xFF); // amp
454                }
455                let src_ptr = foci.as_ptr() as *const u8;
456                let dst_ptr = body.as_mut_ptr().add(cursor);
457
458                copy_nonoverlapping(src_ptr, dst_ptr, foci.len());
459                cursor += NUM_TRANS_IN_UNIT * 2;
460            }
461        }
462        *seq.sent() += send_size;
463        (msg_id, body)
464    }
465
466    unsafe fn convert_to_u8_slice<T: Sized>(p: &T) -> &[u8] {
467        ::std::slice::from_raw_parts((p as *const T) as *const u8, ::std::mem::size_of::<T>())
468    }
469}