autd3_link_remote/
server.rs

1use std::future::Future;
2
3use autd3_core::{
4    geometry::{Geometry, Point3, Quaternion, UnitQuaternion},
5    link::{Ack, AsyncLink, LinkError, RxMessage, TxMessage},
6};
7use tokio::{
8    io::{AsyncReadExt, AsyncWriteExt},
9    net::{TcpListener, TcpStream},
10    select,
11};
12
13use crate::{
14    MSG_CLOSE, MSG_CONFIG_GEOMETRY, MSG_ERROR, MSG_HELLO, MSG_OK, MSG_READ_DATA, MSG_SEND_DATA,
15    MSG_UPDATE_GEOMETRY, REMOTE_PROTOCOL_MAGIC, REMOTE_PROTOCOL_VERSION,
16};
17
18/// A server that accepts connections from [`Remote`](crate::Remote) clients and forwards requests to a link.
19pub struct RemoteServer<L: AsyncLink, F: Fn() -> L> {
20    link_factory: F,
21    link: Option<L>,
22    port: u16,
23    rx_buf: Option<Vec<RxMessage>>,
24    num_devices: usize,
25    shutdown: Option<Box<dyn Future<Output = ()> + Send + Unpin>>,
26    read_buffer: Vec<u8>,
27}
28
29impl<L: AsyncLink, F: Fn() -> L> RemoteServer<L, F> {
30    /// Create a new [`RemoteServer`].
31    ///
32    /// # Arguments
33    ///
34    /// * `port` - The port to listen on
35    /// * `link` - A factory function that creates a new link instance
36    pub const fn new(port: u16, link_factory: F) -> Self {
37        Self {
38            link_factory,
39            link: None,
40            port,
41            num_devices: 0,
42            rx_buf: None,
43            shutdown: None,
44            read_buffer: Vec::new(),
45        }
46    }
47
48    /// Configure graceful shutdown with a custom shutdown signal.
49    ///
50    /// # Arguments
51    ///
52    /// * `signal` - A future that completes when the server should shut down
53    pub fn with_graceful_shutdown<S>(self, signal: S) -> Self
54    where
55        S: Future<Output = ()> + Send + 'static,
56    {
57        Self {
58            shutdown: Some(Box::new(Box::pin(signal))),
59            ..self
60        }
61    }
62
63    /// Run the server.
64    ///
65    /// This method listens for incoming connections asynchronously.
66    /// For each connection, it processes requests and forwards them to the link.
67    ///
68    /// If a shutdown signal is configured via [`with_graceful_shutdown`](Self::with_graceful_shutdown),
69    /// the server will gracefully shut down when the signal completes.
70    ///
71    /// # Errors
72    ///
73    /// Returns an error if:
74    /// - Failed to bind to the specified port
75    /// - Failed to accept a connection
76    /// - Failed to process a request
77    pub async fn run(&mut self) -> Result<(), LinkError> {
78        let listener = TcpListener::bind(("0.0.0.0", self.port)).await?;
79        tracing::info!("Remote server listening on port {}", self.port);
80
81        if let Some(shutdown) = self.shutdown.take() {
82            select! {
83                result = self.accept_loop(&listener) => result,
84                _ = shutdown => {
85                    tracing::info!("Shutdown signal received, stopping server");
86                    Ok(())
87                },
88            }
89        } else {
90            self.accept_loop(&listener).await
91        }
92    }
93
94    async fn accept_loop(&mut self, listener: &TcpListener) -> Result<(), LinkError> {
95        loop {
96            let (stream, _) = listener.accept().await?;
97            tracing::info!("Client connected: {:?}", stream.peer_addr()?);
98            self.handle_client(stream).await;
99            tracing::info!("Client disconnected");
100            if let Some(mut link) = self.link.take()
101                && let Err(e) = AsyncLink::close(&mut link).await
102            {
103                tracing::error!("Error closing link: {}", e);
104            }
105        }
106    }
107
108    async fn handle_client(&mut self, mut stream: TcpStream) {
109        let mut handshake_completed = false;
110
111        loop {
112            let mut msg_type = [0u8; size_of::<u8>()];
113            if stream.read_exact(&mut msg_type).await.is_err() {
114                break;
115            }
116
117            let msg = msg_type[0];
118            let result = if msg == MSG_HELLO {
119                tracing::info!("Handling handshake...");
120                if handshake_completed {
121                    tracing::error!("Handshake already completed");
122                    Err(LinkError::new("Handshake already completed"))
123                } else {
124                    match Self::handle_handshake(&mut stream).await {
125                        Ok(()) => {
126                            tracing::info!("Handshake completed");
127                            handshake_completed = true;
128                            Ok(())
129                        }
130                        Err(e) => {
131                            tracing::error!("Handshake failed: {}", e);
132                            Err(e)
133                        }
134                    }
135                }
136            } else if !handshake_completed {
137                Err(LinkError::new(
138                    "Handshake is required before sending commands",
139                ))
140            } else {
141                match msg {
142                    MSG_CONFIG_GEOMETRY => self.handle_config_geometry(&mut stream).await,
143                    MSG_UPDATE_GEOMETRY => self.handle_update_geometry(&mut stream).await,
144                    MSG_SEND_DATA => self.handle_send_data(&mut stream).await,
145                    MSG_READ_DATA => self.handle_read_data(&mut stream).await,
146                    MSG_CLOSE => self.handle_close(&mut stream).await,
147                    other => Err(LinkError::new(format!("Unknown message type: {}", other))),
148                }
149            };
150
151            match result {
152                Ok(()) => {
153                    if msg == MSG_CLOSE {
154                        break;
155                    }
156                }
157                Err(e) => {
158                    tracing::error!("Error handling client request: {}", e);
159                    let _ = self.send_error(&mut stream, &e).await;
160                    if !handshake_completed || msg == MSG_CLOSE {
161                        break;
162                    }
163                }
164            }
165        }
166    }
167
168    async fn handle_handshake(stream: &mut TcpStream) -> Result<(), LinkError> {
169        let mut version_buf = [0u8; size_of::<u16>()];
170        stream.read_exact(&mut version_buf).await?;
171        let version = u16::from_le_bytes(version_buf);
172        if version != REMOTE_PROTOCOL_VERSION {
173            return Err(LinkError::new(format!(
174                "Unsupported protocol version: {}",
175                version
176            )));
177        }
178        tracing::info!("Client protocol version: {}", version);
179
180        let mut magic_buf = [0u8; REMOTE_PROTOCOL_MAGIC.len()];
181        stream.read_exact(&mut magic_buf).await?;
182        if &magic_buf != REMOTE_PROTOCOL_MAGIC {
183            tracing::error!("Invalid client magic: {:?}", magic_buf);
184            return Err(LinkError::new("Invalid client magic"));
185        }
186
187        stream.write_all(&[MSG_OK]).await?;
188        Ok(())
189    }
190
191    async fn handle_config_geometry(&mut self, stream: &mut TcpStream) -> Result<(), LinkError> {
192        if self.link.is_some() {
193            tracing::error!("Link is already open");
194            Err(LinkError::new("Link is already opened"))
195        } else {
196            let geometry = Self::read_geometry(stream).await?;
197            tracing::info!("Opening link...");
198
199            let mut link = (self.link_factory)();
200            AsyncLink::open(&mut link, &geometry).await?;
201            self.num_devices = geometry.num_devices();
202            tracing::info!(
203                "Link opened with {} device{}",
204                self.num_devices,
205                if self.num_devices == 1 { "" } else { "s" }
206            );
207
208            stream.write_all(&[MSG_OK]).await?;
209
210            self.link = Some(link);
211
212            Ok(())
213        }
214    }
215
216    async fn handle_update_geometry(&mut self, stream: &mut TcpStream) -> Result<(), LinkError> {
217        if let Some(link) = self.link.as_mut() {
218            let geometry = Self::read_geometry(stream).await?;
219            AsyncLink::update(link, &geometry).await?;
220            stream.write_all(&[MSG_OK]).await?;
221            Ok(())
222        } else {
223            Err(LinkError::closed())
224        }
225    }
226
227    async fn read_geometry(stream: &mut TcpStream) -> std::io::Result<Geometry> {
228        let mut num_devices_buf = [0u8; size_of::<u32>()];
229        stream.read_exact(&mut num_devices_buf).await?;
230        let num_devices = u32::from_le_bytes(num_devices_buf);
231
232        let mut devices = Vec::new();
233        for _ in 0..num_devices {
234            let mut pos_buf = [0u8; size_of::<f32>() * 3];
235            stream.read_exact(&mut pos_buf).await?;
236            let x = f32::from_le_bytes([pos_buf[0], pos_buf[1], pos_buf[2], pos_buf[3]]);
237            let y = f32::from_le_bytes([pos_buf[4], pos_buf[5], pos_buf[6], pos_buf[7]]);
238            let z = f32::from_le_bytes([pos_buf[8], pos_buf[9], pos_buf[10], pos_buf[11]]);
239
240            let mut rot_buf = [0u8; size_of::<f32>() * 4];
241            stream.read_exact(&mut rot_buf).await?;
242            let w = f32::from_le_bytes([rot_buf[0], rot_buf[1], rot_buf[2], rot_buf[3]]);
243            let i = f32::from_le_bytes([rot_buf[4], rot_buf[5], rot_buf[6], rot_buf[7]]);
244            let j = f32::from_le_bytes([rot_buf[8], rot_buf[9], rot_buf[10], rot_buf[11]]);
245            let k = f32::from_le_bytes([rot_buf[12], rot_buf[13], rot_buf[14], rot_buf[15]]);
246
247            devices.push(
248                autd3_core::devices::AUTD3 {
249                    pos: Point3::new(x, y, z),
250                    rot: UnitQuaternion::new_unchecked(Quaternion::new(w, i, j, k)),
251                }
252                .into(),
253            );
254        }
255
256        Ok(Geometry::new(devices))
257    }
258
259    async fn handle_send_data(&mut self, stream: &mut TcpStream) -> Result<(), LinkError> {
260        if let Some(link) = self.link.as_mut() {
261            let mut tx = AsyncLink::alloc_tx_buffer(link).await?;
262
263            for tx_msg in tx.iter_mut() {
264                let bytes = unsafe {
265                    std::slice::from_raw_parts_mut(
266                        tx_msg as *mut TxMessage as *mut u8,
267                        size_of::<TxMessage>(),
268                    )
269                };
270                stream.read_exact(bytes).await?;
271            }
272
273            AsyncLink::send(link, tx).await?;
274
275            stream.write_all(&[MSG_OK]).await?;
276
277            Ok(())
278        } else {
279            Err(LinkError::closed())
280        }
281    }
282
283    async fn handle_read_data(&mut self, stream: &mut TcpStream) -> Result<(), LinkError> {
284        let num_devices = self.num_devices;
285        let mut rx = match self.rx_buf.take() {
286            Some(buf) if buf.len() == num_devices => buf,
287            _ => vec![RxMessage::new(0, Ack::new(0, 0)); num_devices],
288        };
289        if let Some(link) = self.link.as_mut() {
290            AsyncLink::receive(link, &mut rx).await?;
291
292            let buffer_size = size_of::<u8>() + size_of::<RxMessage>() * rx.len();
293            if self.read_buffer.len() < buffer_size {
294                self.read_buffer.resize(buffer_size, 0);
295            }
296
297            self.read_buffer[0] = MSG_OK;
298            unsafe {
299                std::ptr::copy_nonoverlapping(
300                    rx.as_ptr() as *const u8,
301                    self.read_buffer.as_mut_ptr().add(1),
302                    size_of::<RxMessage>() * rx.len(),
303                );
304            }
305            self.rx_buf = Some(rx);
306            stream.write_all(&self.read_buffer).await?;
307
308            Ok(())
309        } else {
310            Err(LinkError::closed())
311        }
312    }
313
314    async fn handle_close(&mut self, stream: &mut TcpStream) -> Result<(), LinkError> {
315        if let Some(link) = self.link.as_mut() {
316            AsyncLink::close(link).await?;
317            stream.write_all(&[MSG_OK]).await?;
318            Ok(())
319        } else {
320            Err(LinkError::closed())
321        }
322    }
323
324    async fn send_error(&self, stream: &mut TcpStream, error: &LinkError) -> std::io::Result<()> {
325        let error_msg = error.to_string();
326        let error_bytes = error_msg.as_bytes();
327        let error_len = error_bytes.len() as u32;
328
329        let mut buffer = Vec::with_capacity(size_of::<u8>() + size_of::<u32>() + error_bytes.len());
330        buffer.push(MSG_ERROR);
331        buffer.extend_from_slice(&error_len.to_le_bytes());
332        buffer.extend_from_slice(error_bytes);
333
334        stream.write_all(&buffer).await
335    }
336}