use alloc::{
collections::vec_deque::VecDeque, rc::Rc, string::String, vec::Vec,
};
use core::{
net::{IpAddr, SocketAddr},
prelude::rust_2024::*,
result::Result,
time::Duration,
};
use datex_core::{
std_sync::Mutex,
stdlib::{future::Future, pin::Pin},
};
use edge_http::ws::{
MAX_BASE64_KEY_LEN, MAX_BASE64_KEY_RESPONSE_LEN, NONCE_LEN,
};
use edge_nal_embassy::{
Dns as EmbassyDns, Tcp as EmbassyTcp, TcpBuffers, TcpSocket, TcpSocketRead,
TcpSocketWrite,
};
use edge_net::{
http::io::client::Connection,
nal::{AddrType, Dns, TcpSplit},
ws::{FrameHeader, FrameType},
};
use embassy_executor::Spawner;
use embassy_futures::select::{Either, select};
use embassy_net::Stack;
use embassy_sync::once_lock::OnceLock;
use datex_core::{
network::com_interfaces::com_interface::ComInterface, stdlib::sync::Arc,
};
use crate::hal::rng::RngHal;
use alloc::{boxed::Box, string::ToString, vec};
use core::ops::Deref;
use datex_core::{
channel::mpsc::{UnboundedReceiver, UnboundedSender},
network::{
com_hub::{
errors::InterfaceCreateError,
managers::interfaces_manager::ComInterfaceAsyncFactoryResult,
},
com_interfaces::{
com_interface::{
ComInterfaceEvent, ComInterfaceProxy,
error::ComInterfaceError,
factory::{ComInterfaceAsyncFactory, ComInterfaceSyncFactory},
properties::{InterfaceDirection, InterfaceProperties},
},
default_com_interfaces::websocket::websocket_common::{
WebSocketClientInterfaceSetupData, WebSocketError, parse_url,
},
},
},
task::spawn_with_panic_notify,
};
use log::{error, info};
use serde::Deserialize;
use url::Url;
/// Process-wide state required to create embedded WebSocket client interfaces.
///
/// Starts out as `None`; it is written by
/// `WebSocketClientInterfaceEmbeddedGlobalState::set_global_state` and read by
/// `WebSocketClientInterfaceSetupDataEmbedded::create_interface`, which fails
/// with an "invalid setup data" error while the value is still `None`.
///
/// NOTE(review): this is a `static mut` accessed through `unsafe` without any
/// synchronization visible in this file — presumably it is only touched from a
/// single-threaded executor; confirm.
static mut GLOBAL_STATE: Option<WebSocketClientInterfaceEmbeddedGlobalState> =
None;
/// Platform resources shared by every embedded WebSocket client interface.
///
/// An instance is installed once through `set_global_state` and then borrowed
/// by `create_interface` whenever a new interface is set up.
pub struct WebSocketClientInterfaceEmbeddedGlobalState {
/// Network stack handed to the DNS resolver and to the TCP layer of the
/// connection task.
pub stack: Stack<'static>,
/// Random number source; used to fill the WebSocket handshake nonce and to
/// produce the mask key of outgoing frames.
pub rng: Rc<dyn RngHal>,
}
impl WebSocketClientInterfaceEmbeddedGlobalState {
    /// Installs `global_state` as the state used by all subsequently created
    /// embedded WebSocket client interfaces, replacing any previous value.
    pub fn set_global_state(
        global_state: WebSocketClientInterfaceEmbeddedGlobalState,
    ) {
        let new_state = Some(global_state);
        // NOTE(review): plain write to a `static mut`; this presumably relies
        // on no other code reading or writing `GLOBAL_STATE` at the same
        // time — confirm for the target executor.
        unsafe {
            GLOBAL_STATE = new_state;
        }
    }
}
/// Setup data for the embedded WebSocket client interface.
///
/// Newtype around the shared `WebSocketClientInterfaceSetupData`; it
/// dereferences to the wrapped value, so its fields (such as `url`) can be
/// read directly.
#[derive(Deserialize)]
pub struct WebSocketClientInterfaceSetupDataEmbedded(
pub WebSocketClientInterfaceSetupData,
);
impl Deref for WebSocketClientInterfaceSetupDataEmbedded {
    type Target = WebSocketClientInterfaceSetupData;

    /// Gives read access to the wrapped setup data.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
impl WebSocketClientInterfaceSetupDataEmbedded {
async fn create_interface(
self,
com_interface_proxy: ComInterfaceProxy,
) -> Result<InterfaceProperties, InterfaceCreateError> {
let global_state = unsafe {GLOBAL_STATE.as_ref()}.ok_or_else(|| {
InterfaceCreateError::invalid_setup_data("websocket-client cannot be created via factory, missing global state")
})?;
let url = parse_url(&self.url).map_err(|_| {
InterfaceCreateError::invalid_setup_data("Invalid WebSocket URL")
})?;
let connection_data = ConnectionData {
host: url.host_str().unwrap().to_string(),
port: url.port().unwrap_or(80),
ip: {
let dns = EmbassyDns::new(global_state.stack.clone());
dns.get_host_by_name(url.host_str().unwrap(), AddrType::IPv4)
.await
.unwrap()
},
};
info!(
"Connecting to WebSocket server at {}:{} (IP: {})",
connection_data.host, connection_data.port, connection_data.ip
);
let (uuid, sender) = com_interface_proxy
.create_and_init_socket(InterfaceDirection::InOut, 1);
info!("Connection upgraded to WS, starting traffic now");
info!(
"Opening TCP connection to {}:{}",
connection_data.ip, connection_data.port
);
let connect_result =
Arc::new(OnceLock::<Result<(), WebSocketError>>::new());
spawn_with_panic_notify(
&com_interface_proxy.async_context,
listen(
global_state.stack.clone(),
connection_data,
com_interface_proxy.event_receiver,
sender,
connect_result.clone(),
global_state.rng.clone(),
),
);
connect_result.get().await.clone().map_err(|e| {
InterfaceCreateError::InterfaceError(
ComInterfaceError::connection_error_with_details(e),
)
})?;
Ok(InterfaceProperties {
name: Some(url.to_string()),
created_sockets: Some(vec![uuid]),
..Self::get_default_properties()
})
}
}
/// Target of a single WebSocket connection attempt, as resolved by
/// `create_interface` and consumed by `connect`.
struct ConnectionData {
/// Host name taken from the URL; `connect` passes it twice to the upgrade
/// request (presumably as the Host and Origin headers — confirm against
/// edge-http).
host: String,
/// TCP port to connect to; `create_interface` falls back to 80 when the URL
/// does not specify one.
port: u16,
/// Address the host name resolved to (looked up as IPv4).
ip: IpAddr,
}
async fn connect<'a>(
connection_data: ConnectionData,
buf: &'a mut [u8],
tcp: &'a EmbassyTcp<'a, 10>,
rng: Rc<dyn RngHal>,
) -> Result<TcpSocket<'a, 10, 1024, 1024>, ()> {
let mut conn: Connection<_> = Connection::new(
buf,
tcp,
SocketAddr::new(connection_data.ip, connection_data.port),
);
let mut nonce = [0_u8; NONCE_LEN];
rng.fill(&mut nonce);
let mut buf = [0_u8; MAX_BASE64_KEY_LEN];
info!("Initiating WS upgrade request");
conn.initiate_ws_upgrade_request(
Some(&connection_data.host),
Some(&connection_data.host),
"/",
None,
&nonce,
&mut buf,
)
.await
.map_err(|_| ())?;
info!("Waiting for WS upgrade response");
conn.initiate_response().await.unwrap();
info!("Waiting for WS upgrade confirmation");
let mut buf = [0_u8; MAX_BASE64_KEY_RESPONSE_LEN];
if !conn.is_ws_upgrade_accepted(&nonce, &mut buf).unwrap() {
error!("WS upgrade failed");
}
conn.complete().await.unwrap();
let (socket, _buf) = conn.release();
info!("TCP connection established and upgraded to WebSocket");
Ok(socket)
}
/// Connection task: establishes the WebSocket connection, publishes the
/// outcome through `connect_result`, then drives the receive and send loops
/// until either of them finishes.
///
/// * `stack` – network stack used for the TCP layer.
/// * `connection_data` – target of the connection attempt.
/// * `receiver` – interface events consumed by the send loop.
/// * `sender` – channel that receives the payload of incoming binary frames.
/// * `connect_result` – set to `Ok(())` after a successful handshake, or to
///   `Err(WebSocketError::ConnectionError)` if connecting failed.
/// * `rng` – random source for the handshake nonce and frame mask keys.
#[embassy_executor::task]
async fn listen(
    stack: Stack<'static>,
    connection_data: ConnectionData,
    receiver: UnboundedReceiver<ComInterfaceEvent>,
    sender: UnboundedSender<Vec<u8>>,
    connect_result: Arc<OnceLock<Result<(), WebSocketError>>>,
    rng: Rc<dyn RngHal>,
) {
    let tcp_buffers = TcpBuffers::<10, 1024, 1024>::default();
    let tcp = EmbassyTcp::new(stack, &tcp_buffers);
    // Scratch space for the HTTP upgrade exchange, kept on the heap.
    let mut http_buf = Box::new([0_u8; 8192]);

    let mut socket =
        match connect(connection_data, http_buf.as_mut(), &tcp, rng.clone())
            .await
        {
            Ok(socket) => socket,
            Err(_) => {
                connect_result
                    .get_or_init(|| Err(WebSocketError::ConnectionError));
                return;
            }
        };
    connect_result.get_or_init(|| Ok(()));

    // Run both directions concurrently; whichever loop ends first stops the
    // whole task.
    let (read_half, write_half) = socket.split();
    let incoming = receive_loop(read_half, sender);
    let outgoing = send_loop(write_half, receiver, rng);
    match select(incoming, outgoing).await {
        Either::First(_) => {
            info!("receive_loop stopped");
        }
        Either::Second(_) => {
            info!("send_loop stopped");
        }
    }
    info!("Websocket loop stopped");
}
/// Forwards outgoing blocks from `receiver` to the socket as masked binary
/// WebSocket frames.
///
/// Runs until the event channel yields `None` or a `Destroy` event arrives, in
/// which case it returns `Ok(())`. Events other than `SendBlock` and `Destroy`
/// are logged and skipped.
///
/// # Errors
///
/// Returns `Err(())` as soon as writing a frame header or payload fails.
async fn send_loop<'a>(
    mut socket_write: TcpSocketWrite<'a>,
    mut receiver: UnboundedReceiver<ComInterfaceEvent>,
    rng: Rc<dyn RngHal>,
) -> Result<(), ()> {
    while let Some(event) = receiver.next().await {
        match event {
            ComInterfaceEvent::SendBlock(block, _) => {
                // Serialize once and derive the announced length from the
                // exact bytes that are written, so header and payload can
                // never disagree.
                let payload = block.to_bytes();
                let header = FrameHeader {
                    // Presumably `false` marks the frame as not fragmented —
                    // confirm against edge-ws.
                    frame_type: FrameType::Binary(false),
                    payload_len: payload.len() as _,
                    mask_key: rng.random().into(),
                };
                header.send(&mut socket_write).await.map_err(|_| ())?;
                header
                    .send_payload(&mut socket_write, &payload)
                    .await
                    .map_err(|_| ())?;
            }
            ComInterfaceEvent::Destroy => {
                info!("send_loop received Destroy event, stopping");
                break;
            }
            _ => {
                error!("Unhandled event in send_loop: {:?}", event);
            }
        }
    }
    Ok(())
}
async fn receive_loop<'a>(
mut socket_rc: TcpSocketRead<'a>,
mut sender: UnboundedSender<Vec<u8>>,
) -> Result<!, ()> {
let mut buf = [0_u8; 256];
loop {
let header = FrameHeader::recv(&mut socket_rc).await.map_err(|_| ())?;
let payload = header
.recv_payload(&mut socket_rc, buf.as_mut())
.await
.map_err(|_| ())?;
match header.frame_type {
FrameType::Text(_) => {
info!(
"Got {header}, with payload \"{}\"",
core::str::from_utf8(payload).unwrap()
);
}
FrameType::Binary(_) => {
info!("Got {header}, with payload {payload:?}");
sender.start_send(payload.to_vec())?;
}
_ => {
error!("Unexpected {}", header);
}
}
if !header.frame_type.is_final() {
error!("Unexpected fragmented frame");
}
}
}
impl ComInterfaceAsyncFactory for WebSocketClientInterfaceSetupDataEmbedded {
    /// Boxes the asynchronous interface setup so it can be returned through
    /// the factory trait.
    fn create_interface(
        self,
        com_interface_proxy: ComInterfaceProxy,
    ) -> ComInterfaceAsyncFactoryResult {
        let setup_data = self;
        // Method-call syntax resolves to the inherent `create_interface`,
        // which performs the actual connection setup.
        let creation = async move {
            setup_data.create_interface(com_interface_proxy).await
        };
        Box::pin(creation)
    }

    /// Baseline properties for this interface type: type
    /// `"websocket-client"`, channel `"websocket"`, a round-trip time of
    /// 40 ms and a `max_bandwidth` of 1000; everything else uses the
    /// `InterfaceProperties` defaults.
    fn get_default_properties() -> InterfaceProperties {
        let interface_type = "websocket-client".to_string();
        let channel = "websocket".to_string();
        InterfaceProperties {
            interface_type,
            channel,
            round_trip_time: Duration::from_millis(40),
            max_bandwidth: 1000,
            ..InterfaceProperties::default()
        }
    }
}