1use std::{
2 io::{Read, Write},
3 net::{SocketAddr, TcpStream},
4 time::Duration,
5};
6
7use crate::{
8 MSG_CLOSE, MSG_CONFIG_GEOMETRY, MSG_ERROR, MSG_HELLO, MSG_OK, MSG_READ_DATA, MSG_SEND_DATA,
9 MSG_UPDATE_GEOMETRY, REMOTE_PROTOCOL_MAGIC, REMOTE_PROTOCOL_VERSION,
10};
11
12use autd3_core::{
13 geometry::Geometry,
14 link::{Link, LinkError, RxMessage, TxBufferPoolSync, TxMessage},
15};
16
17const REMOTE_HANDSHAKE_LEN: usize =
18 size_of::<u8>() + size_of::<u16>() + REMOTE_PROTOCOL_MAGIC.len();
19const fn handshake_payload() -> [u8; REMOTE_HANDSHAKE_LEN] {
20 let mut payload = [0u8; REMOTE_HANDSHAKE_LEN];
21 payload[0] = MSG_HELLO;
22
23 let version = REMOTE_PROTOCOL_VERSION.to_le_bytes();
24 let version_end = 1 + version.len();
25 let mut i = 1;
26 while i < version_end {
27 payload[i] = version[i - 1];
28 i += 1;
29 }
30 while i < REMOTE_HANDSHAKE_LEN {
31 payload[i] = REMOTE_PROTOCOL_MAGIC[i - version_end];
32 i += 1;
33 }
34 payload
35}
36
37struct RemoteInner {
38 stream: TcpStream,
39 last_geometry_version: usize,
40 tx_buffer_pool: TxBufferPoolSync,
41 buffer: Vec<u8>,
42}
43
44impl RemoteInner {
45 fn open(
46 addr: &SocketAddr,
47 timeout: Option<Duration>,
48 geometry: &Geometry,
49 ) -> Result<RemoteInner, LinkError> {
50 let mut stream = if let Some(timeout) = timeout {
51 TcpStream::connect_timeout(addr, timeout)
52 } else {
53 TcpStream::connect(addr)
54 }?;
55
56 stream.set_write_timeout(timeout)?;
57 stream.set_read_timeout(timeout)?;
58
59 Self::perform_handshake(&mut stream)?;
60 Self::send_geometry(&mut stream, MSG_CONFIG_GEOMETRY, geometry)?;
61 Self::wait_response(&mut stream)?;
62
63 let mut tx_buffer_pool = TxBufferPoolSync::default();
64 tx_buffer_pool.init(geometry);
65
66 Ok(Self {
67 stream,
68 last_geometry_version: geometry.version(),
69 tx_buffer_pool,
70 buffer: Vec::new(),
71 })
72 }
73
74 fn send_geometry(
75 stream: &mut TcpStream,
76 msg_type: u8,
77 geometry: &autd3_core::geometry::Geometry,
78 ) -> Result<(), LinkError> {
79 let num_devices = geometry.len() as u32;
80
81 let mut buffer = Vec::with_capacity(
82 size_of::<u8>()
83 + size_of::<u32>()
84 + (size_of::<f32>() * 3 + size_of::<f32>() * 4) * geometry.len(),
85 );
86 buffer.push(msg_type);
87 buffer.extend_from_slice(&num_devices.to_le_bytes());
88 geometry.iter().for_each(|dev| {
89 let pos = dev[0].position();
90 buffer.extend_from_slice(&pos.x.to_le_bytes());
91 buffer.extend_from_slice(&pos.y.to_le_bytes());
92 buffer.extend_from_slice(&pos.z.to_le_bytes());
93
94 let rot = dev.rotation();
95 buffer.extend_from_slice(&rot.w.to_le_bytes());
96 buffer.extend_from_slice(&rot.i.to_le_bytes());
97 buffer.extend_from_slice(&rot.j.to_le_bytes());
98 buffer.extend_from_slice(&rot.k.to_le_bytes());
99 });
100
101 stream.write_all(&buffer)?;
102
103 Ok(())
104 }
105
106 fn perform_handshake(stream: &mut TcpStream) -> Result<(), LinkError> {
107 const PAYLOAD: [u8; REMOTE_HANDSHAKE_LEN] = handshake_payload();
108 stream.write_all(&PAYLOAD)?;
109 Self::wait_response(stream)
110 }
111
112 fn wait_response(stream: &mut TcpStream) -> Result<(), LinkError> {
113 let mut status = [0u8; size_of::<u8>()];
114 stream.read_exact(&mut status)?;
115
116 match status[0] {
117 MSG_OK => Ok(()),
118 MSG_ERROR => {
119 let mut error_len_buf = [0u8; size_of::<u32>()];
120 stream.read_exact(&mut error_len_buf)?;
121 let error_len = u32::from_le_bytes(error_len_buf) as usize;
122
123 let mut error_msg = vec![0u8; error_len];
124 stream.read_exact(&mut error_msg)?;
125
126 let error_str = String::from_utf8_lossy(&error_msg);
127 Err(LinkError::new(format!("Server error: {}", error_str)))
128 }
129 msg => Err(LinkError::new(format!("Unknown response status: {}", msg))),
130 }
131 }
132
133 fn close(&mut self) -> Result<(), LinkError> {
134 self.stream.write_all(&[MSG_CLOSE])?;
135 Self::wait_response(&mut self.stream)?;
136 Ok(())
137 }
138
139 fn update(&mut self, geometry: &autd3_core::geometry::Geometry) -> Result<(), LinkError> {
140 if self.last_geometry_version == geometry.version() {
141 return Ok(());
142 }
143 self.last_geometry_version = geometry.version();
144 Self::send_geometry(&mut self.stream, MSG_UPDATE_GEOMETRY, geometry)?;
145 Self::wait_response(&mut self.stream)?;
146 Ok(())
147 }
148
149 fn alloc_tx_buffer(&mut self) -> Vec<TxMessage> {
150 self.tx_buffer_pool.borrow()
151 }
152
153 fn send(&mut self, tx: Vec<TxMessage>) -> Result<(), LinkError> {
154 let buffer_size = size_of::<u8>() + size_of::<TxMessage>() * tx.len();
155 if self.buffer.len() < buffer_size {
156 self.buffer.resize(buffer_size, 0);
157 }
158
159 self.buffer[0] = MSG_SEND_DATA;
160 unsafe {
161 std::ptr::copy_nonoverlapping(
162 tx.as_ptr() as *const u8,
163 self.buffer.as_mut_ptr().add(1),
164 size_of::<TxMessage>() * tx.len(),
165 );
166 }
167 self.tx_buffer_pool.return_buffer(tx);
168
169 self.stream.write_all(&self.buffer)?;
170 Self::wait_response(&mut self.stream)?;
171
172 Ok(())
173 }
174
175 fn receive(&mut self, rx: &mut [RxMessage]) -> Result<(), LinkError> {
176 self.stream.write_all(&[MSG_READ_DATA])?;
177 Self::wait_response(&mut self.stream)?;
178 rx.iter_mut()
179 .map(|msg| unsafe {
180 std::slice::from_raw_parts_mut(
181 msg as *mut RxMessage as *mut u8,
182 size_of::<RxMessage>(),
183 )
184 })
185 .try_for_each(|bytes| self.stream.read_exact(bytes))?;
186 Ok(())
187 }
188}
189
/// Connection options for a `Remote` link.
#[derive(Debug, Clone, Default)]
pub struct RemoteOption {
    /// Timeout used when connecting and for every socket read and write.
    /// `None` (the default) means operations block without a timeout.
    pub timeout: Option<Duration>,
}
196
197pub struct Remote {
201 addr: SocketAddr,
202 inner: Option<RemoteInner>,
203 option: RemoteOption,
204}
205
206impl Remote {
207 #[must_use]
209 pub const fn new(addr: SocketAddr, option: RemoteOption) -> Remote {
210 Remote {
211 addr,
212 inner: None,
213 option,
214 }
215 }
216}
217
218impl Link for Remote {
219 fn open(&mut self, geometry: &autd3_core::geometry::Geometry) -> Result<(), LinkError> {
220 self.inner = Some(RemoteInner::open(
221 &self.addr,
222 self.option.timeout,
223 geometry,
224 )?);
225 Ok(())
226 }
227
228 fn close(&mut self) -> Result<(), LinkError> {
229 if let Some(mut inner) = self.inner.take() {
230 inner.close()?;
231 }
232 Ok(())
233 }
234
235 fn update(&mut self, geometry: &autd3_core::geometry::Geometry) -> Result<(), LinkError> {
236 if let Some(inner) = self.inner.as_mut() {
237 inner.update(geometry)
238 } else {
239 Err(LinkError::closed())
240 }
241 }
242
243 fn alloc_tx_buffer(&mut self) -> Result<Vec<TxMessage>, LinkError> {
244 if let Some(inner) = self.inner.as_mut() {
245 Ok(inner.alloc_tx_buffer())
246 } else {
247 Err(LinkError::closed())
248 }
249 }
250
251 fn send(&mut self, tx: Vec<TxMessage>) -> Result<(), LinkError> {
252 if let Some(inner) = self.inner.as_mut() {
253 inner.send(tx)
254 } else {
255 Err(LinkError::closed())
256 }
257 }
258
259 fn receive(&mut self, rx: &mut [RxMessage]) -> Result<(), LinkError> {
260 if let Some(inner) = self.inner.as_mut() {
261 inner.receive(rx)
262 } else {
263 Err(LinkError::closed())
264 }
265 }
266
267 fn is_open(&self) -> bool {
268 self.inner.is_some()
269 }
270}
271
// `Remote` opts into `AsyncLink` with an empty impl, so it relies entirely on the
// trait's provided (default) items and overrides nothing.
// NOTE(review): presumably those defaults adapt the blocking `Link` implementation for
// async callers — confirm against `autd3_core::link::AsyncLink`.
impl autd3_core::link::AsyncLink for Remote {}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// The hello frame must be exactly `MSG_HELLO`, then the little-endian protocol
    /// version, then the protocol magic.
    #[test]
    fn handshake_payload_format() {
        let expected: Vec<u8> = std::iter::once(MSG_HELLO)
            .chain(REMOTE_PROTOCOL_VERSION.to_le_bytes())
            .chain(REMOTE_PROTOCOL_MAGIC.iter().copied())
            .collect();

        let actual = handshake_payload();

        assert_eq!(actual.len(), REMOTE_HANDSHAKE_LEN);
        assert_eq!(actual.as_slice(), expected.as_slice());
    }
}