realflight_bridge/bridge/remote/
async_impl.rs1use std::net::ToSocketAddrs;
4use std::time::Duration;
5
6use log::error;
7use postcard::{from_bytes, to_stdvec};
8use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
9use tokio::net::TcpStream;
10use tokio::sync::Mutex;
11use tokio::time::timeout;
12
13use crate::bridge::AsyncBridge;
14use crate::{BridgeError, ControlInputs, SimulatorState};
15
16use super::{Request, RequestType, Response};
17
18const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
19
/// Builder that collects the connection settings for an [`AsyncRemoteBridge`].
///
/// Created with [`AsyncRemoteBridgeBuilder::new`]; the connection itself is
/// only opened when `build` is awaited.
#[derive(Debug, Clone)]
pub struct AsyncRemoteBridgeBuilder {
    /// Address of the remote bridge server, stored exactly as passed to `new`
    /// and resolved to a socket address when `build` runs.
    address: String,
    /// Upper bound on how long `build` waits for the TCP connection to be
    /// established. Initialised to `DEFAULT_TIMEOUT`.
    connect_timeout: Duration,
}
28
29impl AsyncRemoteBridgeBuilder {
30 pub fn new(address: &str) -> Self {
32 Self {
33 address: address.to_string(),
34 connect_timeout: DEFAULT_TIMEOUT,
35 }
36 }
37
38 #[must_use]
40 pub fn timeout(mut self, timeout: Duration) -> Self {
41 self.connect_timeout = timeout;
42 self
43 }
44
45 pub async fn build(self) -> Result<AsyncRemoteBridge, BridgeError> {
47 let addr = self
48 .address
49 .to_socket_addrs()
50 .map_err(|e| BridgeError::Initialization(format!("Invalid address: {}", e)))?
51 .next()
52 .ok_or_else(|| BridgeError::Initialization("Invalid address".into()))?;
53
54 let stream = timeout(self.connect_timeout, TcpStream::connect(addr))
55 .await
56 .map_err(|_| {
57 BridgeError::Initialization(format!(
58 "Connection timeout after {:?}",
59 self.connect_timeout
60 ))
61 })?
62 .map_err(|e| BridgeError::Initialization(format!("Connection failed: {}", e)))?;
63
64 stream
65 .set_nodelay(true)
66 .map_err(|e| BridgeError::Initialization(format!("Failed to set nodelay: {}", e)))?;
67
68 let (read_half, write_half) = stream.into_split();
69
70 Ok(AsyncRemoteBridge {
71 reader: Mutex::new(BufReader::new(read_half)),
72 writer: Mutex::new(BufWriter::new(write_half)),
73 response_buffer: Mutex::new(Vec::with_capacity(4096)),
74 })
75 }
76}
77
/// Asynchronous client that talks to a remote bridge server over one TCP
/// connection.
///
/// Each message on the wire is a 4-byte big-endian length followed by a
/// `postcard`-encoded body. All state sits behind `tokio::sync::Mutex` so the
/// bridge methods can take `&self`.
pub struct AsyncRemoteBridge {
    /// Buffered read half of the split TCP stream; responses are read from it.
    reader: Mutex<BufReader<tokio::net::tcp::OwnedReadHalf>>,
    /// Buffered write half of the split TCP stream; requests are written to it
    /// and flushed after every request.
    writer: Mutex<BufWriter<tokio::net::tcp::OwnedWriteHalf>>,
    /// Scratch buffer reused between requests to hold the raw response body
    /// before it is deserialized.
    response_buffer: Mutex<Vec<u8>>,
}
112
113impl AsyncBridge for AsyncRemoteBridge {
114 async fn exchange_data(&self, control: &ControlInputs) -> Result<SimulatorState, BridgeError> {
115 let response = self
116 .send_request(RequestType::ExchangeData, Some(control))
117 .await?;
118 if let Some(state) = response.payload {
119 Ok(state)
120 } else {
121 error!("No payload in response: {:?}", response.status);
122 Err(BridgeError::SoapFault("No payload in response".to_string()))
123 }
124 }
125
126 async fn enable_rc(&self) -> Result<(), BridgeError> {
127 self.send_request(RequestType::EnableRC, None).await?;
128 Ok(())
129 }
130
131 async fn disable_rc(&self) -> Result<(), BridgeError> {
132 self.send_request(RequestType::DisableRC, None).await?;
133 Ok(())
134 }
135
136 async fn reset_aircraft(&self) -> Result<(), BridgeError> {
137 self.send_request(RequestType::ResetAircraft, None).await?;
138 Ok(())
139 }
140}
141
142impl AsyncRemoteBridge {
143 pub async fn new(address: &str) -> Result<Self, BridgeError> {
145 AsyncRemoteBridgeBuilder::new(address).build().await
146 }
147
148 pub fn builder(address: &str) -> AsyncRemoteBridgeBuilder {
150 AsyncRemoteBridgeBuilder::new(address)
151 }
152
153 async fn send_request(
155 &self,
156 request_type: RequestType,
157 payload: Option<&ControlInputs>,
158 ) -> Result<Response, BridgeError> {
159 let request = Request {
160 request_type,
161 payload: payload.cloned(),
162 };
163
164 let request_bytes = to_stdvec(&request)
166 .map_err(|e| BridgeError::SoapFault(format!("Serialization error: {}", e)))?;
167
168 let mut writer = self.writer.lock().await;
169
170 let length_bytes = (request_bytes.len() as u32).to_be_bytes();
172 writer.write_all(&length_bytes).await?;
173
174 writer.write_all(&request_bytes).await?;
176 writer.flush().await?;
177
178 drop(writer); let mut reader = self.reader.lock().await;
181
182 let mut length_buffer = [0u8; 4];
184 reader.read_exact(&mut length_buffer).await?;
185 let response_length = u32::from_be_bytes(length_buffer) as usize;
186
187 let mut response_buffer = self.response_buffer.lock().await;
189 response_buffer.clear();
190 response_buffer.resize(response_length, 0);
191 reader.read_exact(&mut response_buffer).await?;
192
193 let response: Response = from_bytes(&response_buffer)
195 .map_err(|e| BridgeError::SoapFault(format!("Deserialization error: {}", e)))?;
196
197 Ok(response)
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bridge::AsyncBridge;
    use crate::bridge::remote::Response;
    use std::io::{Read, Write};
    use std::net::TcpListener;

    /// Binds a listener on an ephemeral loopback port and spawns a thread that
    /// accepts exactly one connection and passes it to `handler`.
    ///
    /// Returns the address to connect to and the server thread's handle.
    fn spawn_mock_server<F>(handler: F) -> (String, std::thread::JoinHandle<()>)
    where
        F: FnOnce(std::net::TcpStream) + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap().to_string();

        let handle = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            handler(stream);
        });

        (addr, handle)
    }

    /// Waits for the mock server thread and fails the test if it panicked, so
    /// a broken server-side `unwrap` is reported instead of silently ignored.
    fn join_server(handle: std::thread::JoinHandle<()>) {
        handle.join().expect("mock server thread panicked");
    }

    /// Reads and discards one length-prefixed request from `stream`.
    fn read_request(stream: &mut std::net::TcpStream) {
        let mut length_buffer = [0u8; 4];
        stream.read_exact(&mut length_buffer).unwrap();
        let msg_length = u32::from_be_bytes(length_buffer) as usize;
        let mut buffer = vec![0u8; msg_length];
        stream.read_exact(&mut buffer).unwrap();
    }

    /// Writes `body` to `stream` as a length-prefixed frame and flushes it.
    fn write_frame(stream: &mut std::net::TcpStream, body: &[u8]) {
        let length_bytes = (body.len() as u32).to_be_bytes();
        stream.write_all(&length_bytes).unwrap();
        stream.write_all(body).unwrap();
        stream.flush().unwrap();
    }

    /// Consumes one request from `stream` and answers it with `response`.
    fn mock_server_send_response(mut stream: std::net::TcpStream, response: Response) {
        read_request(&mut stream);
        let response_bytes = to_stdvec(&response).unwrap();
        write_frame(&mut stream, &response_bytes);
    }

    #[tokio::test]
    async fn connects_to_server() {
        // The server only needs to accept; the stream is dropped right away.
        let (addr, handle) = spawn_mock_server(|_stream| {});

        let result = AsyncRemoteBridge::builder(&addr)
            .timeout(Duration::from_secs(1))
            .build()
            .await;

        assert!(result.is_ok());
        join_server(handle);
    }

    #[tokio::test]
    async fn builder_sets_timeout() {
        let builder =
            AsyncRemoteBridgeBuilder::new("127.0.0.1:12345").timeout(Duration::from_millis(100));

        assert_eq!(builder.connect_timeout, Duration::from_millis(100));
    }

    #[tokio::test]
    async fn connection_timeout_returns_error() {
        // Connecting to this address is expected to fail or time out.
        let result = AsyncRemoteBridge::builder("10.255.255.1:12345")
            .timeout(Duration::from_millis(100))
            .build()
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn invalid_address_returns_error() {
        let result = AsyncRemoteBridge::new("not-a-valid-address").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn enable_rc_succeeds() {
        let (addr, handle) =
            spawn_mock_server(|stream| mock_server_send_response(stream, Response::success()));

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        let result = bridge.enable_rc().await;

        assert!(result.is_ok());
        join_server(handle);
    }

    #[tokio::test]
    async fn disable_rc_succeeds() {
        let (addr, handle) =
            spawn_mock_server(|stream| mock_server_send_response(stream, Response::success()));

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        let result = bridge.disable_rc().await;

        assert!(result.is_ok());
        join_server(handle);
    }

    #[tokio::test]
    async fn reset_aircraft_succeeds() {
        let (addr, handle) =
            spawn_mock_server(|stream| mock_server_send_response(stream, Response::success()));

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        let result = bridge.reset_aircraft().await;

        assert!(result.is_ok());
        join_server(handle);
    }

    #[tokio::test]
    async fn exchange_data_succeeds() {
        let (addr, handle) = spawn_mock_server(|stream| {
            let state = SimulatorState::default();
            mock_server_send_response(stream, Response::success_with(state));
        });

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        let control = ControlInputs::default();
        let result = bridge.exchange_data(&control).await;

        assert!(result.is_ok());
        join_server(handle);
    }

    #[tokio::test]
    async fn exchange_data_no_payload_returns_error() {
        // A success response without a payload must be reported as an error.
        let (addr, handle) =
            spawn_mock_server(|stream| mock_server_send_response(stream, Response::success()));

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        let control = ControlInputs::default();
        let result = bridge.exchange_data(&control).await;

        match result {
            Err(BridgeError::SoapFault(msg)) => {
                assert!(msg.contains("No payload"));
            }
            other => panic!("expected SoapFault, got {:?}", other),
        }
        join_server(handle);
    }

    #[tokio::test]
    async fn malformed_response_returns_error() {
        let (addr, handle) = spawn_mock_server(|mut stream| {
            read_request(&mut stream);

            // Correctly framed, but the body is not a valid encoded `Response`.
            let garbage = vec![0xFF, 0xFF, 0xFF, 0xFF];
            write_frame(&mut stream, &garbage);
        });

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        let result = bridge.enable_rc().await;

        match result {
            Err(BridgeError::SoapFault(msg)) => {
                assert!(msg.contains("Deserialization"));
            }
            other => panic!("expected SoapFault, got {:?}", other),
        }
        join_server(handle);
    }

    #[tokio::test]
    async fn server_disconnect_returns_error() {
        // Close the connection without ever answering.
        let (addr, handle) = spawn_mock_server(drop);

        let bridge = AsyncRemoteBridge::new(&addr).await.unwrap();
        let result = bridge.enable_rc().await;

        assert!(result.is_err());
        join_server(handle);
    }
}