use std::sync::{atomic::AtomicBool, Arc};
use futures_util::{SinkExt, StreamExt};
use std::time::Duration;
use tokio::{
net::TcpStream,
sync::mpsc,
time::{sleep, timeout},
};
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
#[cfg(feature = "bebop")]
use crate::generated::schema::Category;
#[cfg(feature = "bebop")]
use crate::helpers::common::get_data_schema;
use crate::{
helpers::{
common::{make_disconnect_message, make_ping_message},
server_sender::{SenderStatus, ServerSenderTrait},
traits::atomic::FlagAtomic,
},
log_debug, log_error, Settings,
};
use super::{
internal_client::ClientOptions,
types::{save_key, RwServerSender, DB},
};
/// Drop guard that clears `is_try_connect` on the shared sender when a
/// connection session is abandoned before reaching its explicit cleanup
/// (for example when the `handle_websocket` future is dropped mid-session).
///
/// Call [`TryConnectGuard::disarm`] after resetting the flag by hand. While
/// the guard is still armed, dropping it schedules the reset on the Tokio runtime.
struct TryConnectGuard {
    server_sender: RwServerSender,
    /// Set once the owner has reset the flag itself; suppresses the reset on drop.
    disarmed: bool,
}

impl TryConnectGuard {
    /// Creates an armed guard for `server_sender`.
    fn new(server_sender: RwServerSender) -> Self {
        Self {
            server_sender,
            disarmed: false,
        }
    }

    /// Marks the guard as no longer responsible for resetting the flag.
    fn disarm(&mut self) {
        self.disarmed = true;
    }
}

impl Drop for TryConnectGuard {
    fn drop(&mut self) {
        if self.disarmed {
            return;
        }
        // `drop` cannot await the async write lock, so the reset has to run in a
        // spawned task. `tokio::spawn` panics when no runtime context is
        // available, and a panic inside `drop` during unwinding aborts the
        // process. Spawn only when a runtime handle can be obtained; otherwise
        // the reset is skipped (best effort).
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            let server_sender = self.server_sender.clone();
            handle.spawn(async move {
                server_sender.write().await.is_try_connect = false;
            });
        }
    }
}
/// Runs [`get_internal_websocket`] and reports whether it finished without error.
///
/// An error is logged here rather than propagated, so callers only get a
/// success flag.
pub async fn wrap_get_internal_websocket(
    db: DB,
    server_sender: RwServerSender,
    server_ip: String,
    options: ClientOptions,
) -> bool {
    let outcome = get_internal_websocket(db, server_sender, server_ip, options).await;
    if let Err(e) = &outcome {
        log_error!("Error getting websocket: {:?}", e);
    }
    outcome.is_ok()
}
pub async fn get_internal_websocket(
db: DB,
server_sender: RwServerSender,
server_ip: String,
options: ClientOptions,
) -> tokio_tungstenite::tungstenite::Result<()> {
log_debug!("Connecting to {}", server_ip);
match timeout(
Duration::from_secs(options.connect_timeout_seconds),
connect_async(&server_ip),
)
.await
{
Ok(Ok((ws_stream, _))) => {
if let Err(err) = handle_websocket(
db,
server_sender.clone(),
options,
server_ip.clone(),
ws_stream,
)
.await
{
server_sender.write().await.is_try_connect = false;
log_error!("Error handling websocket: {:?}", err);
}
}
Err(e) => {
server_sender.remove_ip_if_valid_server_ip(&server_ip).await;
log_error!("Error connecting to {}: {:?}", server_ip, e);
}
Ok(Err(e)) => {
server_sender.remove_ip_if_valid_server_ip(&server_ip).await;
log_error!("Error connecting to {}: {:?}", server_ip, e);
}
}
log_debug!("Connection session ended for {}", server_ip);
Ok(())
}
/// Extracts the host part of `server_ip` for the disconnect reply.
///
/// Processing steps:
/// * strips a leading `scheme://`;
/// * drops everything after the authority (path, query, fragment);
/// * drops any `user@` prefix;
/// * drops the port.
///
/// A bracketed IPv6 literal such as `[::1]:9000` yields the bare address
/// (`::1`). Input without a scheme is treated as a bare `host[:port]`.
// Also compiled under `cfg(test)` so the parsing can be unit-tested without the
// `bebop` feature; `allow(dead_code)` silences the unused warning in that build.
#[cfg(any(feature = "bebop", test))]
#[cfg_attr(not(feature = "bebop"), allow(dead_code))]
fn disconnect_peer(server_ip: &str) -> &str {
    let without_scheme = server_ip
        .split_once("://")
        .map_or(server_ip, |(_, rest)| rest);
    let authority = without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(without_scheme);
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host);
    if let Some(bracketed) = host_port.strip_prefix('[') {
        return bracketed.split_once(']').map_or(bracketed, |(host, _)| host);
    }
    host_port.split_once(':').map_or(host_port, |(host, _)| host)
}

/// Runs one WebSocket session over an already-established connection
/// (`bebop` wire format).
///
/// Only one session runs at a time. If `is_try_connect` is already set on
/// `server_sender`, this returns immediately. Otherwise it registers an
/// outbound channel with `server_sender` and either sends an initial ping
/// (`options.use_ping`) or reports `Connected` right away. It then runs two
/// sides:
///
/// * a spawned reader task decodes incoming frames and handles them by category:
///   - undecodable frames are logged and skipped;
///   - `Pong` schedules another ping after `options.retry_seconds`, with at
///     most one pending at a time;
///   - `Disconnect` queues a disconnect reply and stops the reader;
///   - anything else goes to `send_handle_message`;
/// * the current task writes queued outbound messages until one of these happens:
///   - the channel closes;
///   - a send fails;
///   - a `Disconnect` message has been written;
///   - the reader stops.
///
/// On exit it closes the sink, aborts the reader task and clears
/// `is_try_connect`. Always returns `Ok(())`; failures are logged.
#[cfg(feature = "bebop")]
pub async fn handle_websocket(
    db: DB,
    server_sender: RwServerSender,
    options: ClientOptions,
    server_ip: String,
    ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
) -> tokio_tungstenite::tungstenite::Result<()> {
    // Claim the single "session in progress" slot, or back off if it is taken.
    {
        let mut guard = server_sender.write().await;
        if guard.is_try_connect {
            return Ok(());
        }
        guard.is_try_connect = true;
    }
    // Releases the slot if this future is dropped before the cleanup at the end.
    let mut connect_guard = TryConnectGuard::new(server_sender.clone());
    let (mut ostream, mut istream) = ws_stream.split();
    log_debug!("Connected to {} for web socket", server_ip);
    let (sx, mut rx) = mpsc::channel(options.per_connection_buffer_size);
    let id = get_id(db).await;
    server_sender.add(sx.clone(), &server_ip).await;
    // With pings enabled, `Connected` is reported on the first decoded reply.
    let mut is_first = true;
    let use_ping = options.use_ping;
    if use_ping {
        log_debug!("Client send message: {:?}", make_ping_message(&id));
        server_sender.send(make_ping_message(&id)).await;
    } else {
        is_first = false;
        server_sender.write_received_times().await;
        server_sender.send_status(SenderStatus::Connected).await;
    }
    let retry_seconds = options.retry_seconds;
    let server_sender_clone = server_sender.clone();
    let server_ip_clone = server_ip.clone();
    let (stream_end_tx, mut stream_end_rx) = tokio::sync::oneshot::channel::<()>();
    let reader = tokio::spawn(async move {
        let server_ip = server_ip_clone;
        let server_sender = server_sender_clone;
        let is_wait_ping = Arc::new(AtomicBool::new(false));
        while let Some(Ok(message)) = istream.next().await {
            server_sender.write_received_times().await;
            let value = message.into_data();
            let data = match get_data_schema(&value) {
                Ok(data) => data,
                Err(e) => {
                    log_error!("Error getting data schema: {:?}", e);
                    continue;
                }
            };
            if is_first {
                is_first = false;
                server_sender.send_status(SenderStatus::Connected).await;
            }
            log_debug!("Client receive message: {:?}", data);
            if data.category == Category::Pong as u16 {
                // At most one delayed ping is pending at a time.
                if !is_wait_ping.is_true() {
                    is_wait_ping.set_bool(true);
                    let server_sender_clone = server_sender.clone();
                    let is_wait_ping_clone = is_wait_ping.clone();
                    // Clone the id only when a ping is actually scheduled.
                    let id = id.clone();
                    tokio::spawn(async move {
                        sleep(Duration::from_secs(retry_seconds)).await;
                        is_wait_ping_clone.set_bool(false);
                        server_sender_clone.send(make_ping_message(&id)).await;
                    });
                }
                continue;
            } else if data.category == Category::Disconnect as u16 {
                let _ = sx
                    .send(make_disconnect_message(disconnect_peer(&server_ip)))
                    .await;
                break;
            }
            server_sender.send_handle_message(data).await;
        }
        let _ = stream_end_tx.send(());
    });
    loop {
        tokio::select! {
            // Poll the outbound queue first. When the reader queues the
            // disconnect reply and then signals end-of-stream, both branches
            // are ready at once; an unbiased select could pick the end signal
            // and tear down without ever sending the reply.
            biased;
            msg = rx.recv() => {
                let Some(message) = msg else { break };
                if let Err(e) = ostream.send(message.clone()).await {
                    log_error!("Error sending message: {:?}", e);
                    break;
                }
                let raw = message.into_data();
                let data = match get_data_schema(&raw) {
                    Ok(data) => data,
                    Err(e) => {
                        log_error!("Error getting data schema: {:?}", e);
                        rx.close();
                        break;
                    }
                };
                log_debug!("Send message: {:?}", data);
                // Stop once a disconnect notice has gone out.
                if data.category == Category::Disconnect as u16 {
                    rx.close();
                    break;
                }
            }
            _ = &mut stream_end_rx => break,
        }
    }
    log_debug!("WebSocket closed");
    // Send a close frame (closing also flushes) rather than dropping the socket abruptly.
    let _ = timeout(Duration::from_secs(1), ostream.close()).await;
    // Without this the reader task keeps its half of the socket alive and keeps
    // dispatching messages after the session is considered over. Delayed ping
    // tasks it spawned are separate tasks and are not cancelled here.
    reader.abort();
    server_sender.write().await.is_try_connect = false;
    connect_guard.disarm();
    Ok(())
}
/// Runs one WebSocket session over an already-established connection
/// (raw-bytes mode, without the `bebop` feature).
///
/// Only one session runs at a time. If `is_try_connect` is already set on
/// `server_sender`, this returns immediately. Otherwise it registers an
/// outbound channel with `server_sender` and reports `Connected` right away.
/// It then runs two sides:
///
/// * a spawned reader task forwards the payload of every incoming frame to
///   `send_handle_message`;
/// * the current task writes queued outbound messages until the channel
///   closes, a send fails, or the reader stops.
///
/// On exit it closes the sink, aborts the reader task and clears
/// `is_try_connect`. Always returns `Ok(())`; failures are logged.
#[cfg(not(feature = "bebop"))]
pub async fn handle_websocket(
    // The client id is only needed for bebop pings; unused in this mode.
    _db: DB,
    server_sender: RwServerSender,
    options: ClientOptions,
    server_ip: String,
    ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
) -> tokio_tungstenite::tungstenite::Result<()> {
    // Claim the single "session in progress" slot, or back off if it is taken.
    {
        let mut guard = server_sender.write().await;
        if guard.is_try_connect {
            return Ok(());
        }
        guard.is_try_connect = true;
    }
    // Releases the slot if this future is dropped before the cleanup at the end.
    let mut connect_guard = TryConnectGuard::new(server_sender.clone());
    let (mut ostream, mut istream) = ws_stream.split();
    log_debug!("Connected to {} for web socket", server_ip);
    let (sx, mut rx) = mpsc::channel(options.per_connection_buffer_size);
    server_sender.add(sx.clone(), &server_ip).await;
    server_sender.write_received_times().await;
    server_sender.send_status(SenderStatus::Connected).await;
    let server_sender_clone = server_sender.clone();
    let (stream_end_tx, mut stream_end_rx) = tokio::sync::oneshot::channel::<()>();
    let reader = tokio::spawn(async move {
        let server_sender = server_sender_clone;
        while let Some(Ok(message)) = istream.next().await {
            server_sender.write_received_times().await;
            let value = message.into_data();
            server_sender.send_handle_message(value.to_vec()).await;
        }
        let _ = stream_end_tx.send(());
    });
    loop {
        tokio::select! {
            msg = rx.recv() => {
                let Some(message) = msg else { break };
                if let Err(e) = ostream.send(message).await {
                    log_error!("Error sending message: {:?}", e);
                    break;
                }
            }
            _ = &mut stream_end_rx => break,
        }
    }
    log_debug!("WebSocket closed");
    // Send a close frame (closing also flushes) rather than dropping the socket abruptly.
    let _ = timeout(Duration::from_secs(1), ostream.close()).await;
    // Without this the reader task keeps its half of the socket alive and keeps
    // dispatching messages after the session is considered over.
    reader.abort();
    server_sender.write().await.is_try_connect = false;
    connect_guard.disarm();
    Ok(())
}
/// Reads the persisted client id from the native-db store.
///
/// Returns an empty string in any of these cases:
/// * the read transaction cannot be opened;
/// * the `CLIENT_ID` entry is missing or the lookup fails;
/// * the stored bytes are not valid UTF-8.
#[cfg(feature = "native-db")]
pub async fn get_id(db: DB) -> String {
    let store = db.lock().await;
    let Ok(txn) = store.r_transaction() else {
        return String::new();
    };
    // Bind the lookup result first so no temporary borrowing `txn` outlives it.
    let lookup = txn.get().primary::<Settings>(save_key::CLIENT_ID);
    match lookup {
        Ok(Some(entry)) => String::from_utf8(entry.value).unwrap_or_default(),
        _ => String::new(),
    }
}
/// Reads the persisted client id from the in-memory key/value store.
///
/// Returns an empty string when `CLIENT_ID` is absent or its bytes are not
/// valid UTF-8.
#[cfg(not(feature = "native-db"))]
pub async fn get_id(db: DB) -> String {
    let store = db.lock().await;
    let stored = store.get(save_key::CLIENT_ID);
    match stored {
        Some(bytes) => String::from_utf8(bytes.clone()).unwrap_or_default(),
        None => String::new(),
    }
}