// datex_native/com_interfaces/websocket/
websocket_server.rs1use datex_core::{derive_setup_data};
2use core::{
3 result::Result, str::FromStr,
4};
5use std::net::SocketAddr;
6use std::sync::Arc;
7use futures_util::{SinkExt, StreamExt};
8use futures_util::stream::{SplitSink, SplitStream};
9use log::{error, info};
10use tokio::net::{TcpListener, TcpStream};
11use tungstenite::Message;
12use tokio_tungstenite::{accept_async, WebSocketStream};
13use futures::lock::Mutex;
14use datex_core::{
15 network::{
16 com_hub::errors::ComInterfaceCreateError,
17 com_interfaces::com_interface::{
18 factory::{
19 ComInterfaceAsyncFactory, ComInterfaceAsyncFactoryResult,
20 },
21 properties::{InterfaceDirection, ComInterfaceProperties},
22 },
23 },
24};
25use datex_core::global::dxb_block::DXBBlock;
26use datex_core::network::com_interfaces::com_interface::factory::{ComInterfaceConfiguration, SendCallback, SendFailure, SocketProperties, SocketConfiguration};
27use datex_core::network::com_interfaces::default_setup_data::websocket::websocket_server::WebSocketServerInterfaceSetupData;
28
29derive_setup_data!(WebSocketServerInterfaceSetupDataNative, WebSocketServerInterfaceSetupData);
30
31
32impl WebSocketServerInterfaceSetupDataNative {
33 async fn create_interface(self) -> Result<ComInterfaceConfiguration, ComInterfaceCreateError> {
34 let addr = SocketAddr::from_str(&self.bind_address)
35 .map_err(ComInterfaceCreateError::invalid_setup_data)?;
36
37 let listener = TcpListener::bind(&addr).await.map_err(|err| {
38 ComInterfaceCreateError::connection_error_with_details(err)
39 })?;
40
41 info!("WebSocket Server listening on {addr}");
42
43 Ok(ComInterfaceConfiguration::new_multi_socket(
44 ComInterfaceProperties {
45 name: Some(addr.to_string()),
46 connectable_interfaces: WebSocketServerInterfaceSetupData::get_clients_setup_data(self.0.accept_addresses)?,
47 ..Self::get_default_properties()
48 },
49 async gen move {
50 loop {
51 match Self::get_next_websocket_connection(&listener).await {
53 Ok((mut read, write)) => {
54 info!("Accepted new WebSocket connection");
55 yield Ok(SocketConfiguration::new_in_out(
57 SocketProperties::new(InterfaceDirection::InOut, 1),
58 async gen move {
60 loop {
62 match read.next().await {
63 Some(Ok(Message::Binary(bin))) => {
64 yield Ok(bin);
65 }
66 Some(Ok(_)) => {
67 error!("Invalid message type received");
68 return yield Err(());
69 }
70 Some(Err(e)) => {
71 error!("WebSocket error from {addr}: {e}");
72 return yield Err(())
73 }
74 None => {
75 return;
77 }
78 }
79 }
80 },
81 SendCallback::new_async(move |block: DXBBlock| {
83 let write = write.clone();
84 async move {
85 write
86 .lock()
87 .await
88 .send(Message::Binary(block.to_bytes())).await
89 .map_err(|e| {
90 error!("WebSocket write error: {e}");
91 SendFailure(Box::new(block))
92 })
93 }
94 })
95 ));
96 }
97 Err(_) => {
98 continue;
100 }
101 }
102 }
103 }
104 ))
105 }
106
107 async fn get_next_websocket_connection(listener: &TcpListener) -> Result<
108 (SplitStream<WebSocketStream<TcpStream>>, Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>),
109 ()
110 > {
111 let next_socket = listener.accept().await;
113 match next_socket {
114 Ok((stream, addr)) => {
115 match accept_async(stream).await {
116 Ok(ws_stream) => {
117 let (write, read) = ws_stream.split();
118 let write = Arc::new(Mutex::new(write));
119 Ok((read, write))
120 }
121 Err(e) => {
122 error!("WebSocket handshake failed with {addr}: {e}");
123 Err(())
124 }
125 }
126 }
127 Err(e) => {
128 error!("Failed to accept connection: {e}");
129 Err(())
130 }
131 }
132 }
133}
134
135impl ComInterfaceAsyncFactory for WebSocketServerInterfaceSetupDataNative {
136 fn create_interface(self) -> ComInterfaceAsyncFactoryResult {
137 Box::pin(self.create_interface())
138 }
139
140 fn get_default_properties() -> ComInterfaceProperties {
141 WebSocketServerInterfaceSetupData::get_default_properties()
142 }
143}
144
#[cfg(test)]
mod tests {
    use std::assert_matches;

    // `ComInterfaceCreateError` and the setup-data types come in through the parent's imports.
    use super::*;

    /// Creating the interface binds the listener and names the interface after the
    /// configured bind address.
    #[tokio::test]
    async fn test_construct() {
        // Port 0 lets the OS choose a free port, so the test cannot collide with another
        // process (or a concurrent test run) that already holds a fixed port.
        let address = "127.0.0.1:0".to_string();

        let interface_configuration =
            WebSocketServerInterfaceSetupDataNative(WebSocketServerInterfaceSetupData {
                bind_address: address.clone(),
                accept_addresses: None,
            })
            .create_interface()
            .await
            .unwrap();

        assert_eq!(interface_configuration.properties.name, Some(address));
    }

    /// A bind address that is not a valid `SocketAddr` is rejected as invalid setup data.
    #[tokio::test]
    async fn test_construct_invalid_address() {
        assert_matches!(
            WebSocketServerInterfaceSetupDataNative(WebSocketServerInterfaceSetupData {
                bind_address: "1.2.3".to_string(),
                accept_addresses: None,
            })
            .create_interface()
            .await,
            Err(ComInterfaceCreateError::InvalidSetupData(_))
        );
    }
}