1use std::future::Future;
2
3use autd3_core::{
4 geometry::{Geometry, Point3, Quaternion, UnitQuaternion},
5 link::{Ack, AsyncLink, LinkError, RxMessage, TxMessage},
6};
7use tokio::{
8 io::{AsyncReadExt, AsyncWriteExt},
9 net::{TcpListener, TcpStream},
10 select,
11};
12
13use crate::{
14 MSG_CLOSE, MSG_CONFIG_GEOMETRY, MSG_ERROR, MSG_HELLO, MSG_OK, MSG_READ_DATA, MSG_SEND_DATA,
15 MSG_UPDATE_GEOMETRY, REMOTE_PROTOCOL_MAGIC, REMOTE_PROTOCOL_VERSION,
16};
17
18pub struct RemoteServer<L: AsyncLink, F: Fn() -> L> {
20 link_factory: F,
21 link: Option<L>,
22 port: u16,
23 rx_buf: Option<Vec<RxMessage>>,
24 num_devices: usize,
25 shutdown: Option<Box<dyn Future<Output = ()> + Send + Unpin>>,
26 read_buffer: Vec<u8>,
27}
28
29impl<L: AsyncLink, F: Fn() -> L> RemoteServer<L, F> {
30 pub const fn new(port: u16, link_factory: F) -> Self {
37 Self {
38 link_factory,
39 link: None,
40 port,
41 num_devices: 0,
42 rx_buf: None,
43 shutdown: None,
44 read_buffer: Vec::new(),
45 }
46 }
47
48 pub fn with_graceful_shutdown<S>(self, signal: S) -> Self
54 where
55 S: Future<Output = ()> + Send + 'static,
56 {
57 Self {
58 shutdown: Some(Box::new(Box::pin(signal))),
59 ..self
60 }
61 }
62
63 pub async fn run(&mut self) -> Result<(), LinkError> {
78 let listener = TcpListener::bind(("0.0.0.0", self.port)).await?;
79 tracing::info!("Remote server listening on port {}", self.port);
80
81 if let Some(shutdown) = self.shutdown.take() {
82 select! {
83 result = self.accept_loop(&listener) => result,
84 _ = shutdown => {
85 tracing::info!("Shutdown signal received, stopping server");
86 Ok(())
87 },
88 }
89 } else {
90 self.accept_loop(&listener).await
91 }
92 }
93
94 async fn accept_loop(&mut self, listener: &TcpListener) -> Result<(), LinkError> {
95 loop {
96 let (stream, _) = listener.accept().await?;
97 tracing::info!("Client connected: {:?}", stream.peer_addr()?);
98 self.handle_client(stream).await;
99 tracing::info!("Client disconnected");
100 if let Some(mut link) = self.link.take()
101 && let Err(e) = AsyncLink::close(&mut link).await
102 {
103 tracing::error!("Error closing link: {}", e);
104 }
105 }
106 }
107
108 async fn handle_client(&mut self, mut stream: TcpStream) {
109 let mut handshake_completed = false;
110
111 loop {
112 let mut msg_type = [0u8; size_of::<u8>()];
113 if stream.read_exact(&mut msg_type).await.is_err() {
114 break;
115 }
116
117 let msg = msg_type[0];
118 let result = if msg == MSG_HELLO {
119 tracing::info!("Handling handshake...");
120 if handshake_completed {
121 tracing::error!("Handshake already completed");
122 Err(LinkError::new("Handshake already completed"))
123 } else {
124 match Self::handle_handshake(&mut stream).await {
125 Ok(()) => {
126 tracing::info!("Handshake completed");
127 handshake_completed = true;
128 Ok(())
129 }
130 Err(e) => {
131 tracing::error!("Handshake failed: {}", e);
132 Err(e)
133 }
134 }
135 }
136 } else if !handshake_completed {
137 Err(LinkError::new(
138 "Handshake is required before sending commands",
139 ))
140 } else {
141 match msg {
142 MSG_CONFIG_GEOMETRY => self.handle_config_geometry(&mut stream).await,
143 MSG_UPDATE_GEOMETRY => self.handle_update_geometry(&mut stream).await,
144 MSG_SEND_DATA => self.handle_send_data(&mut stream).await,
145 MSG_READ_DATA => self.handle_read_data(&mut stream).await,
146 MSG_CLOSE => self.handle_close(&mut stream).await,
147 other => Err(LinkError::new(format!("Unknown message type: {}", other))),
148 }
149 };
150
151 match result {
152 Ok(()) => {
153 if msg == MSG_CLOSE {
154 break;
155 }
156 }
157 Err(e) => {
158 tracing::error!("Error handling client request: {}", e);
159 let _ = self.send_error(&mut stream, &e).await;
160 if !handshake_completed || msg == MSG_CLOSE {
161 break;
162 }
163 }
164 }
165 }
166 }
167
168 async fn handle_handshake(stream: &mut TcpStream) -> Result<(), LinkError> {
169 let mut version_buf = [0u8; size_of::<u16>()];
170 stream.read_exact(&mut version_buf).await?;
171 let version = u16::from_le_bytes(version_buf);
172 if version != REMOTE_PROTOCOL_VERSION {
173 return Err(LinkError::new(format!(
174 "Unsupported protocol version: {}",
175 version
176 )));
177 }
178 tracing::info!("Client protocol version: {}", version);
179
180 let mut magic_buf = [0u8; REMOTE_PROTOCOL_MAGIC.len()];
181 stream.read_exact(&mut magic_buf).await?;
182 if &magic_buf != REMOTE_PROTOCOL_MAGIC {
183 tracing::error!("Invalid client magic: {:?}", magic_buf);
184 return Err(LinkError::new("Invalid client magic"));
185 }
186
187 stream.write_all(&[MSG_OK]).await?;
188 Ok(())
189 }
190
191 async fn handle_config_geometry(&mut self, stream: &mut TcpStream) -> Result<(), LinkError> {
192 if self.link.is_some() {
193 tracing::error!("Link is already open");
194 Err(LinkError::new("Link is already opened"))
195 } else {
196 let geometry = Self::read_geometry(stream).await?;
197 tracing::info!("Opening link...");
198
199 let mut link = (self.link_factory)();
200 AsyncLink::open(&mut link, &geometry).await?;
201 self.num_devices = geometry.num_devices();
202 tracing::info!(
203 "Link opened with {} device{}",
204 self.num_devices,
205 if self.num_devices == 1 { "" } else { "s" }
206 );
207
208 stream.write_all(&[MSG_OK]).await?;
209
210 self.link = Some(link);
211
212 Ok(())
213 }
214 }
215
216 async fn handle_update_geometry(&mut self, stream: &mut TcpStream) -> Result<(), LinkError> {
217 if let Some(link) = self.link.as_mut() {
218 let geometry = Self::read_geometry(stream).await?;
219 AsyncLink::update(link, &geometry).await?;
220 stream.write_all(&[MSG_OK]).await?;
221 Ok(())
222 } else {
223 Err(LinkError::closed())
224 }
225 }
226
227 async fn read_geometry(stream: &mut TcpStream) -> std::io::Result<Geometry> {
228 let mut num_devices_buf = [0u8; size_of::<u32>()];
229 stream.read_exact(&mut num_devices_buf).await?;
230 let num_devices = u32::from_le_bytes(num_devices_buf);
231
232 let mut devices = Vec::new();
233 for _ in 0..num_devices {
234 let mut pos_buf = [0u8; size_of::<f32>() * 3];
235 stream.read_exact(&mut pos_buf).await?;
236 let x = f32::from_le_bytes([pos_buf[0], pos_buf[1], pos_buf[2], pos_buf[3]]);
237 let y = f32::from_le_bytes([pos_buf[4], pos_buf[5], pos_buf[6], pos_buf[7]]);
238 let z = f32::from_le_bytes([pos_buf[8], pos_buf[9], pos_buf[10], pos_buf[11]]);
239
240 let mut rot_buf = [0u8; size_of::<f32>() * 4];
241 stream.read_exact(&mut rot_buf).await?;
242 let w = f32::from_le_bytes([rot_buf[0], rot_buf[1], rot_buf[2], rot_buf[3]]);
243 let i = f32::from_le_bytes([rot_buf[4], rot_buf[5], rot_buf[6], rot_buf[7]]);
244 let j = f32::from_le_bytes([rot_buf[8], rot_buf[9], rot_buf[10], rot_buf[11]]);
245 let k = f32::from_le_bytes([rot_buf[12], rot_buf[13], rot_buf[14], rot_buf[15]]);
246
247 devices.push(
248 autd3_core::devices::AUTD3 {
249 pos: Point3::new(x, y, z),
250 rot: UnitQuaternion::new_unchecked(Quaternion::new(w, i, j, k)),
251 }
252 .into(),
253 );
254 }
255
256 Ok(Geometry::new(devices))
257 }
258
259 async fn handle_send_data(&mut self, stream: &mut TcpStream) -> Result<(), LinkError> {
260 if let Some(link) = self.link.as_mut() {
261 let mut tx = AsyncLink::alloc_tx_buffer(link).await?;
262
263 for tx_msg in tx.iter_mut() {
264 let bytes = unsafe {
265 std::slice::from_raw_parts_mut(
266 tx_msg as *mut TxMessage as *mut u8,
267 size_of::<TxMessage>(),
268 )
269 };
270 stream.read_exact(bytes).await?;
271 }
272
273 AsyncLink::send(link, tx).await?;
274
275 stream.write_all(&[MSG_OK]).await?;
276
277 Ok(())
278 } else {
279 Err(LinkError::closed())
280 }
281 }
282
283 async fn handle_read_data(&mut self, stream: &mut TcpStream) -> Result<(), LinkError> {
284 let num_devices = self.num_devices;
285 let mut rx = match self.rx_buf.take() {
286 Some(buf) if buf.len() == num_devices => buf,
287 _ => vec![RxMessage::new(0, Ack::new(0, 0)); num_devices],
288 };
289 if let Some(link) = self.link.as_mut() {
290 AsyncLink::receive(link, &mut rx).await?;
291
292 let buffer_size = size_of::<u8>() + size_of::<RxMessage>() * rx.len();
293 if self.read_buffer.len() < buffer_size {
294 self.read_buffer.resize(buffer_size, 0);
295 }
296
297 self.read_buffer[0] = MSG_OK;
298 unsafe {
299 std::ptr::copy_nonoverlapping(
300 rx.as_ptr() as *const u8,
301 self.read_buffer.as_mut_ptr().add(1),
302 size_of::<RxMessage>() * rx.len(),
303 );
304 }
305 self.rx_buf = Some(rx);
306 stream.write_all(&self.read_buffer).await?;
307
308 Ok(())
309 } else {
310 Err(LinkError::closed())
311 }
312 }
313
314 async fn handle_close(&mut self, stream: &mut TcpStream) -> Result<(), LinkError> {
315 if let Some(link) = self.link.as_mut() {
316 AsyncLink::close(link).await?;
317 stream.write_all(&[MSG_OK]).await?;
318 Ok(())
319 } else {
320 Err(LinkError::closed())
321 }
322 }
323
324 async fn send_error(&self, stream: &mut TcpStream, error: &LinkError) -> std::io::Result<()> {
325 let error_msg = error.to_string();
326 let error_bytes = error_msg.as_bytes();
327 let error_len = error_bytes.len() as u32;
328
329 let mut buffer = Vec::with_capacity(size_of::<u8>() + size_of::<u32>() + error_bytes.len());
330 buffer.push(MSG_ERROR);
331 buffer.extend_from_slice(&error_len.to_le_bytes());
332 buffer.extend_from_slice(error_bytes);
333
334 stream.write_all(&buffer).await
335 }
336}