autd3_link_remote/
link.rs

1use std::{
2    io::{Read, Write},
3    net::{SocketAddr, TcpStream},
4    time::Duration,
5};
6
7use crate::{
8    MSG_CLOSE, MSG_CONFIG_GEOMETRY, MSG_ERROR, MSG_HELLO, MSG_OK, MSG_READ_DATA, MSG_SEND_DATA,
9    MSG_UPDATE_GEOMETRY, REMOTE_PROTOCOL_MAGIC, REMOTE_PROTOCOL_VERSION,
10};
11
12use autd3_core::{
13    geometry::Geometry,
14    link::{Link, LinkError, RxMessage, TxBufferPoolSync, TxMessage},
15};
16
17const REMOTE_HANDSHAKE_LEN: usize =
18    size_of::<u8>() + size_of::<u16>() + REMOTE_PROTOCOL_MAGIC.len();
19const fn handshake_payload() -> [u8; REMOTE_HANDSHAKE_LEN] {
20    let mut payload = [0u8; REMOTE_HANDSHAKE_LEN];
21    payload[0] = MSG_HELLO;
22
23    let version = REMOTE_PROTOCOL_VERSION.to_le_bytes();
24    let version_end = 1 + version.len();
25    let mut i = 1;
26    while i < version_end {
27        payload[i] = version[i - 1];
28        i += 1;
29    }
30    while i < REMOTE_HANDSHAKE_LEN {
31        payload[i] = REMOTE_PROTOCOL_MAGIC[i - version_end];
32        i += 1;
33    }
34    payload
35}
36
37struct RemoteInner {
38    stream: TcpStream,
39    last_geometry_version: usize,
40    tx_buffer_pool: TxBufferPoolSync,
41    buffer: Vec<u8>,
42}
43
44impl RemoteInner {
45    fn open(
46        addr: &SocketAddr,
47        timeout: Option<Duration>,
48        geometry: &Geometry,
49    ) -> Result<RemoteInner, LinkError> {
50        let mut stream = if let Some(timeout) = timeout {
51            TcpStream::connect_timeout(addr, timeout)
52        } else {
53            TcpStream::connect(addr)
54        }?;
55
56        stream.set_write_timeout(timeout)?;
57        stream.set_read_timeout(timeout)?;
58
59        Self::perform_handshake(&mut stream)?;
60        Self::send_geometry(&mut stream, MSG_CONFIG_GEOMETRY, geometry)?;
61        Self::wait_response(&mut stream)?;
62
63        let mut tx_buffer_pool = TxBufferPoolSync::default();
64        tx_buffer_pool.init(geometry);
65
66        Ok(Self {
67            stream,
68            last_geometry_version: geometry.version(),
69            tx_buffer_pool,
70            buffer: Vec::new(),
71        })
72    }
73
74    fn send_geometry(
75        stream: &mut TcpStream,
76        msg_type: u8,
77        geometry: &autd3_core::geometry::Geometry,
78    ) -> Result<(), LinkError> {
79        let num_devices = geometry.len() as u32;
80
81        let mut buffer = Vec::with_capacity(
82            size_of::<u8>()
83                + size_of::<u32>()
84                + (size_of::<f32>() * 3 + size_of::<f32>() * 4) * geometry.len(),
85        );
86        buffer.push(msg_type);
87        buffer.extend_from_slice(&num_devices.to_le_bytes());
88        geometry.iter().for_each(|dev| {
89            let pos = dev[0].position();
90            buffer.extend_from_slice(&pos.x.to_le_bytes());
91            buffer.extend_from_slice(&pos.y.to_le_bytes());
92            buffer.extend_from_slice(&pos.z.to_le_bytes());
93
94            let rot = dev.rotation();
95            buffer.extend_from_slice(&rot.w.to_le_bytes());
96            buffer.extend_from_slice(&rot.i.to_le_bytes());
97            buffer.extend_from_slice(&rot.j.to_le_bytes());
98            buffer.extend_from_slice(&rot.k.to_le_bytes());
99        });
100
101        stream.write_all(&buffer)?;
102
103        Ok(())
104    }
105
106    fn perform_handshake(stream: &mut TcpStream) -> Result<(), LinkError> {
107        const PAYLOAD: [u8; REMOTE_HANDSHAKE_LEN] = handshake_payload();
108        stream.write_all(&PAYLOAD)?;
109        Self::wait_response(stream)
110    }
111
112    fn wait_response(stream: &mut TcpStream) -> Result<(), LinkError> {
113        let mut status = [0u8; size_of::<u8>()];
114        stream.read_exact(&mut status)?;
115
116        match status[0] {
117            MSG_OK => Ok(()),
118            MSG_ERROR => {
119                let mut error_len_buf = [0u8; size_of::<u32>()];
120                stream.read_exact(&mut error_len_buf)?;
121                let error_len = u32::from_le_bytes(error_len_buf) as usize;
122
123                let mut error_msg = vec![0u8; error_len];
124                stream.read_exact(&mut error_msg)?;
125
126                let error_str = String::from_utf8_lossy(&error_msg);
127                Err(LinkError::new(format!("Server error: {}", error_str)))
128            }
129            msg => Err(LinkError::new(format!("Unknown response status: {}", msg))),
130        }
131    }
132
133    fn close(&mut self) -> Result<(), LinkError> {
134        self.stream.write_all(&[MSG_CLOSE])?;
135        Self::wait_response(&mut self.stream)?;
136        Ok(())
137    }
138
139    fn update(&mut self, geometry: &autd3_core::geometry::Geometry) -> Result<(), LinkError> {
140        if self.last_geometry_version == geometry.version() {
141            return Ok(());
142        }
143        self.last_geometry_version = geometry.version();
144        Self::send_geometry(&mut self.stream, MSG_UPDATE_GEOMETRY, geometry)?;
145        Self::wait_response(&mut self.stream)?;
146        Ok(())
147    }
148
149    fn alloc_tx_buffer(&mut self) -> Vec<TxMessage> {
150        self.tx_buffer_pool.borrow()
151    }
152
153    fn send(&mut self, tx: Vec<TxMessage>) -> Result<(), LinkError> {
154        let buffer_size = size_of::<u8>() + size_of::<TxMessage>() * tx.len();
155        if self.buffer.len() < buffer_size {
156            self.buffer.resize(buffer_size, 0);
157        }
158
159        self.buffer[0] = MSG_SEND_DATA;
160        unsafe {
161            std::ptr::copy_nonoverlapping(
162                tx.as_ptr() as *const u8,
163                self.buffer.as_mut_ptr().add(1),
164                size_of::<TxMessage>() * tx.len(),
165            );
166        }
167        self.tx_buffer_pool.return_buffer(tx);
168
169        self.stream.write_all(&self.buffer)?;
170        Self::wait_response(&mut self.stream)?;
171
172        Ok(())
173    }
174
175    fn receive(&mut self, rx: &mut [RxMessage]) -> Result<(), LinkError> {
176        self.stream.write_all(&[MSG_READ_DATA])?;
177        Self::wait_response(&mut self.stream)?;
178        rx.iter_mut()
179            .map(|msg| unsafe {
180                std::slice::from_raw_parts_mut(
181                    msg as *mut RxMessage as *mut u8,
182                    size_of::<RxMessage>(),
183                )
184            })
185            .try_for_each(|bytes| self.stream.read_exact(bytes))?;
186        Ok(())
187    }
188}
189
/// Options for [`Remote`].
#[derive(Clone, Debug, Default)]
pub struct RemoteOption {
    /// Timeout for connecting and for every read/write on the connection.
    /// The default is `None`, which means no timeout.
    pub timeout: Option<Duration>,
}
196
197/// A [`Link`] for a remote server or [`AUTD3 Simulator`].
198///
199/// [`AUTD3 Simulator`]: https://github.com/shinolab/autd3-server
200pub struct Remote {
201    addr: SocketAddr,
202    inner: Option<RemoteInner>,
203    option: RemoteOption,
204}
205
206impl Remote {
207    /// Creates a new [`Remote`].
208    #[must_use]
209    pub const fn new(addr: SocketAddr, option: RemoteOption) -> Remote {
210        Remote {
211            addr,
212            inner: None,
213            option,
214        }
215    }
216}
217
218impl Link for Remote {
219    fn open(&mut self, geometry: &autd3_core::geometry::Geometry) -> Result<(), LinkError> {
220        self.inner = Some(RemoteInner::open(
221            &self.addr,
222            self.option.timeout,
223            geometry,
224        )?);
225        Ok(())
226    }
227
228    fn close(&mut self) -> Result<(), LinkError> {
229        if let Some(mut inner) = self.inner.take() {
230            inner.close()?;
231        }
232        Ok(())
233    }
234
235    fn update(&mut self, geometry: &autd3_core::geometry::Geometry) -> Result<(), LinkError> {
236        if let Some(inner) = self.inner.as_mut() {
237            inner.update(geometry)
238        } else {
239            Err(LinkError::closed())
240        }
241    }
242
243    fn alloc_tx_buffer(&mut self) -> Result<Vec<TxMessage>, LinkError> {
244        if let Some(inner) = self.inner.as_mut() {
245            Ok(inner.alloc_tx_buffer())
246        } else {
247            Err(LinkError::closed())
248        }
249    }
250
251    fn send(&mut self, tx: Vec<TxMessage>) -> Result<(), LinkError> {
252        if let Some(inner) = self.inner.as_mut() {
253            inner.send(tx)
254        } else {
255            Err(LinkError::closed())
256        }
257    }
258
259    fn receive(&mut self, rx: &mut [RxMessage]) -> Result<(), LinkError> {
260        if let Some(inner) = self.inner.as_mut() {
261            inner.receive(rx)
262        } else {
263            Err(LinkError::closed())
264        }
265    }
266
267    fn is_open(&self) -> bool {
268        self.inner.is_some()
269    }
270}
271
272impl autd3_core::link::AsyncLink for Remote {}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// The handshake must be `[MSG_HELLO, version (LE) ..., magic ...]` with the declared length.
    #[test]
    fn handshake_payload_format() {
        let payload = handshake_payload();
        assert_eq!(payload.len(), REMOTE_HANDSHAKE_LEN);

        let version = REMOTE_PROTOCOL_VERSION.to_le_bytes();
        let (header, magic) = payload.split_at(1 + version.len());
        assert_eq!(header[0], MSG_HELLO);
        assert_eq!(&header[1..], version.as_slice());
        assert_eq!(magic, REMOTE_PROTOCOL_MAGIC.as_slice());
    }
}
290}