realflight_bridge/bridge/remote/async_impl.rs

//! Async implementation of the remote bridge for the RealFlight simulator.
2
3use std::net::ToSocketAddrs;
4use std::time::Duration;
5
6use log::error;
7use postcard::{from_bytes, to_stdvec};
8use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
9use tokio::net::TcpStream;
10use tokio::sync::Mutex;
11use tokio::time::timeout;
12
13use crate::bridge::AsyncBridge;
14use crate::{BridgeError, ControlInputs, SimulatorState};
15
16use super::{Request, RequestType, Response};
17
/// Connection timeout used by [`AsyncRemoteBridgeBuilder`] unless overridden
/// with [`AsyncRemoteBridgeBuilder::timeout`] (five seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
19
/// Configuration for creating an [`AsyncRemoteBridge`].
///
/// All setters are synchronous; nothing touches the network until the
/// asynchronous `build()` call, which performs the connection.
#[derive(Debug, Clone)]
pub struct AsyncRemoteBridgeBuilder {
    /// Server address as given by the caller, e.g. `"192.168.1.100:18083"`.
    address: String,
    /// Upper bound on how long `build()` waits for the TCP connection.
    connect_timeout: Duration,
}
28
29impl AsyncRemoteBridgeBuilder {
30    /// Creates a new builder with the specified address.
31    pub fn new(address: &str) -> Self {
32        Self {
33            address: address.to_string(),
34            connect_timeout: DEFAULT_TIMEOUT,
35        }
36    }
37
38    /// Sets the connection timeout.
39    #[must_use]
40    pub fn timeout(mut self, timeout: Duration) -> Self {
41        self.connect_timeout = timeout;
42        self
43    }
44
45    /// Builds the AsyncRemoteBridge, connecting to the server.
46    pub async fn build(self) -> Result<AsyncRemoteBridge, BridgeError> {
47        let addr = self
48            .address
49            .to_socket_addrs()
50            .map_err(|e| BridgeError::Initialization(format!("Invalid address: {}", e)))?
51            .next()
52            .ok_or_else(|| BridgeError::Initialization("Invalid address".into()))?;
53
54        let stream = timeout(self.connect_timeout, TcpStream::connect(addr))
55            .await
56            .map_err(|_| {
57                BridgeError::Initialization(format!(
58                    "Connection timeout after {:?}",
59                    self.connect_timeout
60                ))
61            })?
62            .map_err(|e| BridgeError::Initialization(format!("Connection failed: {}", e)))?;
63
64        stream
65            .set_nodelay(true)
66            .map_err(|e| BridgeError::Initialization(format!("Failed to set nodelay: {}", e)))?;
67
68        let (read_half, write_half) = stream.into_split();
69
70        Ok(AsyncRemoteBridge {
71            reader: Mutex::new(BufReader::new(read_half)),
72            writer: Mutex::new(BufWriter::new(write_half)),
73            response_buffer: Mutex::new(Vec::with_capacity(4096)),
74        })
75    }
76}
77
78/// Async client for interacting with a remote RealFlight simulator via a proxy server.
79///
80/// # Examples
81///
82/// ```no_run
83/// use realflight_bridge::{AsyncBridge, AsyncRemoteBridge, ControlInputs};
84/// use std::time::Duration;
85///
86/// #[tokio::main]
87/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
88///     // Connect to a remote proxy server
89///     let bridge = AsyncRemoteBridge::new("192.168.1.100:18083").await?;
90///
91///     // Or with custom timeout
92///     let bridge = AsyncRemoteBridge::builder("192.168.1.100:18083")
93///         .timeout(Duration::from_secs(10))
94///         .build()
95///         .await?;
96///
97///     // Create sample control inputs
98///     let inputs = ControlInputs::default();
99///
100///     // Exchange data with the simulator
101///     let state = bridge.exchange_data(&inputs).await?;
102///     println!("Current airspeed: {:?}", state.airspeed);
103///
104///     Ok(())
105/// }
106/// ```
107pub struct AsyncRemoteBridge {
108    reader: Mutex<BufReader<tokio::net::tcp::OwnedReadHalf>>,
109    writer: Mutex<BufWriter<tokio::net::tcp::OwnedWriteHalf>>,
110    response_buffer: Mutex<Vec<u8>>,
111}
112
113impl AsyncBridge for AsyncRemoteBridge {
114    async fn exchange_data(&self, control: &ControlInputs) -> Result<SimulatorState, BridgeError> {
115        let response = self
116            .send_request(RequestType::ExchangeData, Some(control))
117            .await?;
118        if let Some(state) = response.payload {
119            Ok(state)
120        } else {
121            error!("No payload in response: {:?}", response.status);
122            Err(BridgeError::SoapFault("No payload in response".to_string()))
123        }
124    }
125
126    async fn enable_rc(&self) -> Result<(), BridgeError> {
127        self.send_request(RequestType::EnableRC, None).await?;
128        Ok(())
129    }
130
131    async fn disable_rc(&self) -> Result<(), BridgeError> {
132        self.send_request(RequestType::DisableRC, None).await?;
133        Ok(())
134    }
135
136    async fn reset_aircraft(&self) -> Result<(), BridgeError> {
137        self.send_request(RequestType::ResetAircraft, None).await?;
138        Ok(())
139    }
140}
141
142impl AsyncRemoteBridge {
143    /// Creates a new AsyncRemoteBridge connected to the specified address.
144    pub async fn new(address: &str) -> Result<Self, BridgeError> {
145        AsyncRemoteBridgeBuilder::new(address).build().await
146    }
147
148    /// Returns a builder for custom configuration.
149    pub fn builder(address: &str) -> AsyncRemoteBridgeBuilder {
150        AsyncRemoteBridgeBuilder::new(address)
151    }
152
153    /// Sends a request to the server and receives a response.
154    async fn send_request(
155        &self,
156        request_type: RequestType,
157        payload: Option<&ControlInputs>,
158    ) -> Result<Response, BridgeError> {
159        let request = Request {
160            request_type,
161            payload: payload.cloned(),
162        };
163
164        // Serialize the request to a byte vector
165        let request_bytes = to_stdvec(&request)
166            .map_err(|e| BridgeError::SoapFault(format!("Serialization error: {}", e)))?;
167
168        let mut writer = self.writer.lock().await;
169
170        // Send the length of the request (4 bytes)
171        let length_bytes = (request_bytes.len() as u32).to_be_bytes();
172        writer.write_all(&length_bytes).await?;
173
174        // Send the serialized request data
175        writer.write_all(&request_bytes).await?;
176        writer.flush().await?;
177
178        drop(writer); // Release lock before reading
179
180        let mut reader = self.reader.lock().await;
181
182        // Read the response length (4 bytes)
183        let mut length_buffer = [0u8; 4];
184        reader.read_exact(&mut length_buffer).await?;
185        let response_length = u32::from_be_bytes(length_buffer) as usize;
186
187        // Read the response data into reusable buffer
188        let mut response_buffer = self.response_buffer.lock().await;
189        response_buffer.clear();
190        response_buffer.resize(response_length, 0);
191        reader.read_exact(&mut response_buffer).await?;
192
193        // Deserialize the response
194        let response: Response = from_bytes(&response_buffer)
195            .map_err(|e| BridgeError::SoapFault(format!("Deserialization error: {}", e)))?;
196
197        Ok(response)
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bridge::AsyncBridge;
    use crate::bridge::remote::Response;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread::JoinHandle;

    // ------------------------------------------------------------------------
    // Helpers
    // ------------------------------------------------------------------------

    /// Binds a listener on an OS-assigned localhost port and returns it along
    /// with its `ip:port` string.
    fn bind_local() -> (TcpListener, String) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap().to_string();
        (listener, addr)
    }

    /// Reads and discards one frame: a big-endian `u32` length, then that many bytes.
    fn read_request_frame(stream: &mut std::net::TcpStream) {
        let mut header = [0u8; 4];
        stream.read_exact(&mut header).unwrap();
        let mut body = vec![0u8; u32::from_be_bytes(header) as usize];
        stream.read_exact(&mut body).unwrap();
    }

    /// Writes `body` as one frame (big-endian `u32` length prefix), then flushes.
    fn write_frame(stream: &mut std::net::TcpStream, body: &[u8]) {
        let header = (body.len() as u32).to_be_bytes();
        stream.write_all(&header).unwrap();
        stream.write_all(body).unwrap();
        stream.flush().unwrap();
    }

    /// Consumes a single request from `stream` and answers it with `response`.
    fn mock_server_send_response(mut stream: std::net::TcpStream, response: Response) {
        read_request_frame(&mut stream);
        let encoded = to_stdvec(&response).unwrap();
        write_frame(&mut stream, &encoded);
    }

    /// Starts a server thread that accepts one connection and answers one request
    /// with the response produced by `make_response` (built on the server thread).
    fn spawn_replying_server(make_response: fn() -> Response) -> (String, JoinHandle<()>) {
        let (listener, addr) = bind_local();
        let server = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            mock_server_send_response(stream, make_response());
        });
        (addr, server)
    }

    // ------------------------------------------------------------------------
    // Connection
    // ------------------------------------------------------------------------

    #[tokio::test]
    async fn connects_to_server() {
        let (listener, addr) = bind_local();

        // Accept a single connection in the background and ignore the outcome.
        let server = std::thread::spawn(move || {
            let _ = listener.accept();
        });

        let outcome = AsyncRemoteBridge::builder(&addr)
            .timeout(Duration::from_secs(1))
            .build()
            .await;

        assert!(outcome.is_ok());
        let _ = server.join();
    }

    #[tokio::test]
    async fn builder_sets_timeout() {
        let wanted = Duration::from_millis(100);
        let builder = AsyncRemoteBridgeBuilder::new("127.0.0.1:12345").timeout(wanted);

        assert_eq!(builder.connect_timeout, wanted);
    }

    #[tokio::test]
    async fn connection_timeout_returns_error() {
        // A non-routable address, so the connect attempt cannot complete in time.
        let outcome = AsyncRemoteBridge::builder("10.255.255.1:12345")
            .timeout(Duration::from_millis(100))
            .build()
            .await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn invalid_address_returns_error() {
        let outcome = AsyncRemoteBridge::new("not-a-valid-address").await;
        assert!(outcome.is_err());
    }

    // ------------------------------------------------------------------------
    // Operations
    // ------------------------------------------------------------------------

    #[tokio::test]
    async fn enable_rc_succeeds() {
        let (addr, server) = spawn_replying_server(Response::success);

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        assert!(bridge.enable_rc().await.is_ok());

        let _ = server.join();
    }

    #[tokio::test]
    async fn disable_rc_succeeds() {
        let (addr, server) = spawn_replying_server(Response::success);

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        assert!(bridge.disable_rc().await.is_ok());

        let _ = server.join();
    }

    #[tokio::test]
    async fn reset_aircraft_succeeds() {
        let (addr, server) = spawn_replying_server(Response::success);

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        assert!(bridge.reset_aircraft().await.is_ok());

        let _ = server.join();
    }

    #[tokio::test]
    async fn exchange_data_succeeds() {
        let (addr, server) =
            spawn_replying_server(|| Response::success_with(SimulatorState::default()));

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        let control = ControlInputs::default();
        assert!(bridge.exchange_data(&control).await.is_ok());

        let _ = server.join();
    }

    // ------------------------------------------------------------------------
    // Error handling
    // ------------------------------------------------------------------------

    #[tokio::test]
    async fn exchange_data_no_payload_returns_error() {
        // The reply is a success but carries no state payload.
        let (addr, server) = spawn_replying_server(Response::success);

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        let control = ControlInputs::default();
        let outcome = bridge.exchange_data(&control).await;

        match outcome {
            Err(BridgeError::SoapFault(msg)) => assert!(msg.contains("No payload")),
            other => panic!("expected SoapFault, got {:?}", other),
        }
        let _ = server.join();
    }

    #[tokio::test]
    async fn malformed_response_returns_error() {
        let (listener, addr) = bind_local();

        let server = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            read_request_frame(&mut stream);
            // Four 0xFF bytes: a body the client is expected to fail to decode.
            write_frame(&mut stream, &[0xFF; 4]);
        });

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        let outcome = bridge.enable_rc().await;

        match outcome {
            Err(BridgeError::SoapFault(msg)) => assert!(msg.contains("Deserialization")),
            other => panic!("expected SoapFault, got {:?}", other),
        }
        let _ = server.join();
    }

    #[tokio::test]
    async fn server_disconnect_returns_error() {
        let (listener, addr) = bind_local();

        // Accept, then hang up immediately without sending anything back.
        let server = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            drop(stream);
        });

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        assert!(bridge.enable_rc().await.is_err());

        let _ = server.join();
    }
}
430}