use std::error::Error;
use std::net::UdpSocket;
use std::time::Duration;
#[cfg(feature = "bebop")]
use crate::generated::schema::ServerConnectInfo;
use crate::helpers::get_internal_websocket::handle_websocket;
use crate::helpers::get_outer_websocket::wrap_get_outer_websocket;
use crate::helpers::scan_manager::ScanManager;
use crate::helpers::{
common::{get_setting_by_key, make_ping_message},
get_internal_websocket::{get_id, wrap_get_internal_websocket},
server_sender::{SenderStatus, ServerSenderTrait},
traits::date_time::now,
};
use crate::{helpers::metrics::Metrics, log_debug, log_error, AtomicWebsocketType, Settings};
#[cfg(feature = "bebop")]
use bebop::Record;
use tokio::sync::mpsc::Receiver;
use tokio_util::sync::CancellationToken;
use super::types::{save_key, RwServerSender, DB};
/// Configuration for [`AtomicClient`] connections.
///
/// Typically built from [`ClientOptions::default`] with selected fields overridden via
/// struct-update syntax, e.g. `ClientOptions { url, ..Default::default() }`.
#[derive(Clone)]
pub struct ClientOptions {
    // ---- connection target -------------------------------------------------------------
    /// Server URL used by `get_outer_connect`; when it is empty (and no valid server IP is
    /// known) that function reports `Disconnected` instead of starting a connection.
    pub url: String,
    /// Websocket flavour for this client (default: `AtomicWebsocketType::Internal`).
    /// NOTE(review): not read in this module — presumably consumed by the caller/builder.
    pub atomic_websocket_type: AtomicWebsocketType,
    /// Presumably toggles TLS for the connection (default: `true`); only exists with the
    /// `rustls` feature. NOTE(review): not read in this module.
    #[cfg(feature = "rustls")]
    pub use_tls: bool,

    // ---- liveness / retry --------------------------------------------------------------
    /// Base period, in seconds, of the ping-loop checkers (clamped to at least 1 there).
    /// Server silence longer than 2x this triggers a ping, longer than 4x a reconnect, and
    /// while reconnecting the check period doubles up to 8x this value.
    pub retry_seconds: u64,
    /// Presumably the connect timeout in seconds. NOTE(review): not read in this module.
    pub connect_timeout_seconds: u64,
    /// Whether the ping-loop checkers send a ping to a server that has gone quiet.
    pub use_ping: bool,
    /// When `true`, the ping-loop checkers keep the known server IP after detecting a
    /// dropped connection instead of removing it.
    pub use_keep_ip: bool,

    // ---- channel sizing ----------------------------------------------------------------
    // NOTE(review): the buffer sizes below are not read in this module; presumably they
    // size the sender's internal channels — confirm in the server_sender helpers.
    /// Capacity for the handler (incoming message) channel.
    pub handler_buffer_size: usize,
    /// Capacity for the status channel.
    pub status_buffer_size: usize,
    /// Per-connection buffer capacity.
    pub per_connection_buffer_size: usize,
    /// Spillover buffer capacity.
    pub spillover_buffer_size: usize,
}

impl Default for ClientOptions {
    /// Internal websocket, no URL, 30 s retry period, 3 s connect timeout, pinging enabled,
    /// server IP not kept, TLS on (with `rustls`), buffers of 256 / 8 / 8 / 1024.
    fn default() -> Self {
        Self {
            url: String::new(),
            atomic_websocket_type: AtomicWebsocketType::Internal,
            #[cfg(feature = "rustls")]
            use_tls: true,
            retry_seconds: 30,
            connect_timeout_seconds: 3,
            use_ping: true,
            use_keep_ip: false,
            handler_buffer_size: 256,
            status_buffer_size: 8,
            per_connection_buffer_size: 8,
            spillover_buffer_size: 1024,
        }
    }
}
/// Client handle: the shared server sender, the options it was created with, and the token
/// that stops its background ping-loop task.
pub struct AtomicClient {
    /// Sender state, cloned into the spawned connection and ping-loop tasks.
    pub server_sender: RwServerSender,
    /// Options passed (as clones) to every connect / ping-loop call.
    pub options: ClientOptions,
    /// Cancelled by [`AtomicClient::disconnect`] to end the ping-loop task.
    pub(crate) cancel_token: CancellationToken,
}

impl AtomicClient {
    /// Ensures a client id is registered in `db`, then spawns the internal ping/reconnect
    /// loop (stopped by [`AtomicClient::disconnect`]).
    pub async fn internal_initialize(&self, db: DB) {
        // `db` is not used after this call, so hand it over instead of cloning it.
        self.regist_id(db).await;
        tokio::spawn(internal_ping_loop_cheker(
            self.server_sender.clone(),
            self.options.clone(),
            self.cancel_token.clone(),
        ));
    }

    /// Outer-connection counterpart of [`AtomicClient::internal_initialize`]. With the
    /// `rustls` feature it first makes sure a rustls crypto provider is installed.
    pub async fn outer_initialize(&self, db: DB) {
        #[cfg(feature = "rustls")]
        self.initial_rustls();
        // `db` is not used after this call, so hand it over instead of cloning it.
        self.regist_id(db).await;
        tokio::spawn(outer_ping_loop_cheker(
            self.server_sender.clone(),
            self.options.clone(),
            self.cancel_token.clone(),
        ));
    }

    /// Cancels the ping-loop task and asks the sender to remove its server IP.
    pub async fn disconnect(&self) {
        self.cancel_token.cancel();
        self.server_sender.remove_ip().await;
    }

    /// Starts an outer connection by delegating to the module-level `get_outer_connect`.
    ///
    /// # Errors
    /// Whatever the module-level `get_outer_connect` returns.
    pub async fn get_outer_connect(&self, db: DB) -> Result<(), Box<dyn Error>> {
        get_outer_connect(db, self.server_sender.clone(), self.options.clone()).await
    }

    /// Starts an internal connection, forwarding the caller's `input` to the module-level
    /// `get_internal_connect`.
    #[cfg(all(feature = "native-db", feature = "bebop"))]
    pub async fn get_internal_connect(
        &self,
        input: Option<ServerConnectInfo<'_>>,
        db: DB,
    ) -> Result<(), Box<dyn Error>> {
        get_internal_connect(input, db, self.server_sender.clone(), self.options.clone()).await
    }

    /// Starts an internal connection. Without `native-db`, `_input` is ignored and the
    /// module-level fallback `get_internal_connect` is called with `None`.
    #[cfg(all(not(feature = "native-db"), feature = "bebop"))]
    pub async fn get_internal_connect(
        &self,
        _input: Option<ServerConnectInfo<'_>>,
        db: DB,
    ) -> Result<(), Box<dyn Error>> {
        get_internal_connect(None, db, self.server_sender.clone(), self.options.clone()).await
    }

    /// Starts an internal connection. Without `bebop` there is no connection-info type, so
    /// `_input` is a placeholder and is ignored.
    #[cfg(not(feature = "bebop"))]
    pub async fn get_internal_connect(
        &self,
        _input: Option<()>,
        db: DB,
    ) -> Result<(), Box<dyn Error>> {
        get_internal_connect(None, db, self.server_sender.clone(), self.options.clone()).await
    }

    /// Installs rustls' `ring` crypto provider as the process default if no default is
    /// installed yet. A failed install is logged, not returned.
    #[cfg(feature = "rustls")]
    pub fn initial_rustls(&self) {
        use rustls::crypto::{ring, CryptoProvider};
        if CryptoProvider::get_default().is_none() {
            let provider = ring::default_provider();
            if let Err(e) = provider.install_default() {
                log_error!("Failed to install rustls crypto provider: {:?}", e);
            }
        }
    }

    /// Ensures a `CLIENT_ID` setting exists, storing a freshly generated `nanoid!()` when it
    /// is missing.
    ///
    /// The db lock is held for the whole check-and-insert. Transaction, read, insert and
    /// commit failures are logged and the method returns without retrying.
    #[cfg(feature = "native-db")]
    pub async fn regist_id(&self, db: DB) {
        let db = db.lock().await;
        let Ok(reader) = db.r_transaction() else {
            log_error!("Failed to create r_transaction for regist_id");
            return;
        };
        let data = match reader.get().primary::<Settings>(save_key::CLIENT_ID) {
            Ok(data) => data,
            Err(e) => {
                log_error!("Failed to get ClientId: {:?}", e);
                return;
            }
        };
        // Release the read transaction before opening the write transaction.
        drop(reader);
        if data.is_none() {
            use nanoid::nanoid;
            let Ok(writer) = db.rw_transaction() else {
                log_error!("Failed to create rw_transaction for regist_id");
                return;
            };
            if let Err(e) = writer.insert::<Settings>(Settings {
                key: save_key::CLIENT_ID.to_owned(),
                value: nanoid!().as_bytes().to_vec(),
            }) {
                log_error!("Failed to insert ClientId: {:?}", e);
                return;
            }
            if let Err(e) = writer.commit() {
                log_error!("Failed to commit ClientId: {:?}", e);
            }
        }
    }

    /// Ensures a `CLIENT_ID` entry exists in the key/value store, inserting a freshly
    /// generated `nanoid!()` when it is missing. The lock is held for the check-and-insert.
    #[cfg(not(feature = "native-db"))]
    pub async fn regist_id(&self, db: DB) {
        let mut db = db.lock().await;
        if db.get(save_key::CLIENT_ID).is_none() {
            use nanoid::nanoid;
            db.insert(
                save_key::CLIENT_ID.to_owned(),
                nanoid!().as_bytes().to_vec(),
            );
        }
    }

    /// Receiver for connection-status updates, as handed out by the sender.
    /// NOTE(review): returns `Option` — presumably `None` once the receiver has been taken.
    pub async fn get_status_receiver(&self) -> Option<Receiver<SenderStatus>> {
        self.server_sender.get_status_receiver().await
    }

    /// Receiver for raw incoming message bytes, as handed out by the sender.
    /// NOTE(review): returns `Option` — presumably `None` once the receiver has been taken.
    pub async fn get_handle_message_receiver(&self) -> Option<Receiver<Vec<u8>>> {
        self.server_sender.get_handle_message_receiver().await
    }

    /// Shared metrics handle of the sender (an `Arc` clone, not a snapshot).
    pub async fn metrics(&self) -> std::sync::Arc<Metrics> {
        self.server_sender.read().await.metrics.clone()
    }
}
/// Background liveness monitor for the internal connection; runs until `cancel_token` is
/// cancelled.
///
/// Each round sleeps `current_retry_seconds`, then measures how long the server has been
/// silent (`now().timestamp()` minus `server_received_times`; a non-positive
/// `server_received_times` is never treated as silent):
/// - more than 4 x `retry_seconds`: sends `Disconnected`, calls
///   `remove_ip_if_valid_server_ip("")` unless `use_keep_ip` is set, sends `Reconnecting`,
///   bumps the reconnection metric, spawns `get_internal_connect`, and doubles the sleep
///   (capped at 8 x `retry_seconds`);
/// - more than 2 x `retry_seconds`: sends a ping when `use_ping` is set;
/// - otherwise: resets the sleep to `retry_seconds`.
async fn internal_ping_loop_cheker(
    server_sender: RwServerSender,
    options: ClientOptions,
    cancel_token: CancellationToken,
) {
    let retry_seconds = options.retry_seconds.max(1);
    let use_keep_ip = options.use_keep_ip;
    // Saturate/clamp so a huge `retry_seconds` can neither overflow these products (a panic
    // in debug builds, silent wrap in release) nor wrap negative as an i64 threshold, which
    // would force a reconnect on every tick.
    let max_retry_seconds = retry_seconds.saturating_mul(8);
    let reconnect_after = retry_seconds.saturating_mul(4).min(i64::MAX as u64) as i64;
    let ping_after = retry_seconds.saturating_mul(2).min(i64::MAX as u64) as i64;
    let mut current_retry_seconds = retry_seconds;
    loop {
        tokio::select! {
            _ = cancel_token.cancelled() => {
                log_debug!("internal_ping_loop_cheker cancelled");
                break;
            }
            _ = tokio::time::sleep(Duration::from_secs(current_retry_seconds)) => {}
        }
        // The read guard is a temporary, released at the end of this statement, so no lock
        // is held across the awaits below.
        let last_received = server_sender.read().await.server_received_times;
        let silent_for = if last_received > 0 {
            Some(now().timestamp().saturating_sub(last_received))
        } else {
            None
        };
        match silent_for {
            Some(silent) if silent > reconnect_after => {
                server_sender.send_status(SenderStatus::Disconnected).await;
                if !use_keep_ip {
                    server_sender.remove_ip_if_valid_server_ip("").await;
                }
                server_sender.send_status(SenderStatus::Reconnecting).await;
                let (metrics, db) = {
                    let guard = server_sender.read().await;
                    (guard.metrics.clone(), guard.db.clone())
                };
                metrics.inc_reconnections();
                let server_sender = server_sender.clone();
                let options = options.clone();
                tokio::spawn(async move {
                    if let Err(e) = get_internal_connect(None, db, server_sender, options).await {
                        log_error!("Internal reconnection failed: {:?}", e);
                    }
                });
                // Exponential back-off while the server stays unreachable.
                current_retry_seconds = current_retry_seconds
                    .saturating_mul(2)
                    .min(max_retry_seconds);
            }
            Some(silent) if silent > ping_after => {
                if options.use_ping {
                    log_debug!("Try ping from loop checker");
                    let db = server_sender.read().await.db.clone();
                    let id: String = get_id(db).await;
                    server_sender.send(make_ping_message(&id)).await;
                }
            }
            _ => current_retry_seconds = retry_seconds,
        }
        log_debug!("loop server checker finish");
    }
}
/// Background liveness monitor for the outer connection; runs until `cancel_token` is
/// cancelled.
///
/// Each round sleeps `current_retry_seconds`, then measures how long the server has been
/// silent (`now().timestamp()` minus `server_received_times`; a non-positive
/// `server_received_times` is never treated as silent):
/// - more than 4 x `retry_seconds`: sends `Disconnected`, calls `remove_ip()` unless
///   `use_keep_ip` is set, sends `Reconnecting`, bumps the reconnection metric, spawns
///   `get_outer_connect`, and doubles the sleep (capped at 8 x `retry_seconds`);
/// - more than 2 x `retry_seconds`: logs the timestamps and sends a ping when `use_ping`
///   is set;
/// - otherwise: resets the sleep to `retry_seconds`.
async fn outer_ping_loop_cheker(
    server_sender: RwServerSender,
    options: ClientOptions,
    cancel_token: CancellationToken,
) {
    let retry_seconds = options.retry_seconds.max(1);
    let use_keep_ip = options.use_keep_ip;
    // Saturate/clamp so a huge `retry_seconds` can neither overflow these products (a panic
    // in debug builds, silent wrap in release) nor wrap negative as an i64 threshold, which
    // would force a reconnect on every tick.
    let max_retry_seconds = retry_seconds.saturating_mul(8);
    let reconnect_after = retry_seconds.saturating_mul(4).min(i64::MAX as u64) as i64;
    let ping_after = retry_seconds.saturating_mul(2).min(i64::MAX as u64) as i64;
    let mut current_retry_seconds = retry_seconds;
    loop {
        tokio::select! {
            _ = cancel_token.cancelled() => {
                log_debug!("outer_ping_loop_cheker cancelled");
                break;
            }
            _ = tokio::time::sleep(Duration::from_secs(current_retry_seconds)) => {}
        }
        // The read guard is a temporary, released at the end of this statement, so no lock
        // is held across the awaits below.
        let last_received = server_sender.read().await.server_received_times;
        let silent_for = if last_received > 0 {
            Some(now().timestamp().saturating_sub(last_received))
        } else {
            None
        };
        match silent_for {
            Some(silent) if silent > reconnect_after => {
                server_sender.send_status(SenderStatus::Disconnected).await;
                if !use_keep_ip {
                    server_sender.remove_ip().await;
                }
                server_sender.send_status(SenderStatus::Reconnecting).await;
                let (metrics, db) = {
                    let guard = server_sender.read().await;
                    (guard.metrics.clone(), guard.db.clone())
                };
                metrics.inc_reconnections();
                let server_sender = server_sender.clone();
                let options = options.clone();
                tokio::spawn(async move {
                    if let Err(e) = get_outer_connect(db, server_sender, options).await {
                        log_error!("External reconnection failed: {:?}", e);
                    }
                });
                // Exponential back-off while the server stays unreachable.
                current_retry_seconds = current_retry_seconds
                    .saturating_mul(2)
                    .min(max_retry_seconds);
            }
            Some(silent) if silent > ping_after => {
                log_debug!(
                    "send: {:?}, current: {:?}",
                    last_received,
                    now().timestamp()
                );
                if options.use_ping {
                    log_debug!("Try ping from loop checker");
                    let db = server_sender.read().await.db.clone();
                    let id: String = get_id(db).await;
                    server_sender.send(make_ping_message(&id)).await;
                }
            }
            _ => current_retry_seconds = retry_seconds,
        }
        log_debug!("loop server checker finish");
    }
}
/// Starts an outer connection in the background unless one is already in progress or
/// established.
///
/// - Returns immediately if `is_try_connect` is set.
/// - Reports `Connected` and returns if the sender already has a valid server IP.
/// - Reports `Disconnected` if `options.url` is empty (and there is still no valid IP).
/// - Otherwise reports `Connecting` and spawns `wrap_get_outer_websocket`.
///
/// # Errors
/// Only a failure reading the stored `SERVER_CONNECT_INFO` setting is returned; the value
/// itself is only logged at debug level.
pub async fn get_outer_connect(
    db: DB,
    server_sender: RwServerSender,
    options: ClientOptions,
) -> Result<(), Box<dyn Error>> {
    let attempt_in_flight = server_sender.read().await.is_try_connect;
    if attempt_in_flight {
        return Ok(());
    }

    let already_connected = server_sender.is_valid_server_ip().await;
    if already_connected {
        server_sender.send_status(SenderStatus::Connected).await;
        return Ok(());
    }

    let stored_info =
        get_setting_by_key(db.clone(), save_key::SERVER_CONNECT_INFO.to_owned()).await?;
    log_debug!("server_connect_info: {:?}", stored_info);

    // `&&` short-circuits: the IP is only re-checked when there is no URL to dial.
    let nothing_to_dial = options.url.is_empty() && !server_sender.is_valid_server_ip().await;
    if nothing_to_dial {
        server_sender.send_status(SenderStatus::Disconnected).await;
        return Ok(());
    }

    server_sender.send_status(SenderStatus::Connecting).await;
    tokio::spawn(wrap_get_outer_websocket(db, server_sender, options));
    Ok(())
}
/// Starts (or re-announces) an internal connection, using native-db for persisted settings.
///
/// 1. Returns immediately if `is_try_connect` is set, or reports `Connected` if the sender
///    already has a valid server IP.
/// 2. Loads the stored `SERVER_CONNECT_INFO`. If `input` is given and nothing is stored yet,
///    persists `input.port` with an empty server IP.
/// 3. Resolves the target. With `input`, it uses `input.port` plus the stored server IP,
///    which is empty when none is stored or it cannot be decoded. Without `input`, it uses
///    the stored record.
/// 4. Reports `Disconnected` if there is no target, the stored record cannot be decoded, or
///    no local IP address is available.
/// 5. Reports `Connecting`. An empty server IP runs a `ScanManager` port scan (awaited
///    here) and handles the resulting stream on a spawned task; a handler error is logged
///    and clears `is_try_connect`. Otherwise `wrap_get_internal_websocket` is spawned for
///    the known IP.
///
/// # Errors
/// Propagates failures from `get_setting_by_key` and from the transaction, serialization,
/// insert and commit steps of step 2.
#[cfg(all(feature = "native-db", feature = "bebop"))]
pub async fn get_internal_connect(
    input: Option<ServerConnectInfo<'_>>,
    db: DB,
    server_sender: RwServerSender,
    options: ClientOptions,
) -> Result<(), Box<dyn Error>> {
    if server_sender.read().await.is_try_connect {
        return Ok(());
    }
    if server_sender.is_valid_server_ip().await {
        server_sender.send_status(SenderStatus::Connected).await;
        return Ok(());
    }
    let server_connect_info =
        get_setting_by_key(db.clone(), save_key::SERVER_CONNECT_INFO.to_owned()).await?;
    log_debug!("server_connect_info: {:?}", server_connect_info);
    // Caller supplied connection info and nothing is stored yet: persist the port (with an
    // empty IP) so later calls without `input` can still resolve a target.
    if let (Some(input_ref), None) = (input.as_ref(), server_connect_info.as_ref()) {
        let db_guard = db.lock().await;
        let writer = db_guard.rw_transaction()?;
        let mut value = Vec::new();
        ServerConnectInfo {
            server_ip: "",
            port: input_ref.port,
        }
        .serialize(&mut value)?;
        writer.insert::<Settings>(Settings {
            key: save_key::SERVER_CONNECT_INFO.to_owned(),
            value,
        })?;
        writer.commit()?;
        drop(db_guard);
    }
    if input.is_none() && server_connect_info.is_none() {
        server_sender.send_status(SenderStatus::Disconnected).await;
        return Ok(());
    }
    let connect_info_data = match input.as_ref() {
        Some(info) => {
            // Prefer the stored server IP; an empty IP means "scan for the server" below.
            let server_ip = match server_connect_info.as_ref() {
                Some(stored) => match ServerConnectInfo::deserialize(&stored.value) {
                    Ok(stored_info) => stored_info.server_ip,
                    Err(e) => {
                        // Log undecodable stored info (as the `None` arm does), then fall
                        // back to scanning.
                        log_error!("Failed to deserialize ServerConnectInfo: {:?}", e);
                        ""
                    }
                },
                None => "",
            };
            ServerConnectInfo {
                server_ip,
                port: info.port,
            }
        }
        None => {
            // Defensive: the `input.is_none() && server_connect_info.is_none()` case has
            // already returned above.
            let Some(ref stored_info) = server_connect_info else {
                server_sender.send_status(SenderStatus::Disconnected).await;
                return Ok(());
            };
            match ServerConnectInfo::deserialize(&stored_info.value) {
                Ok(info) => info,
                Err(e) => {
                    log_error!("Failed to deserialize ServerConnectInfo: {:?}", e);
                    server_sender.send_status(SenderStatus::Disconnected).await;
                    return Ok(());
                }
            }
        }
    };
    if get_ip_address().is_empty() {
        server_sender.send_status(SenderStatus::Disconnected).await;
        return Ok(());
    }
    server_sender.send_status(SenderStatus::Connecting).await;
    // Last use of `db`, `server_sender` and `options`: move them into the spawned task
    // rather than cloning.
    match connect_info_data.server_ip {
        "" => {
            let (server_ip, ws_stream) = ScanManager::new(connect_info_data.port).run().await;
            tokio::spawn(async move {
                if let Err(error) =
                    handle_websocket(db, server_sender.clone(), options, server_ip, ws_stream).await
                {
                    log_error!("Error handling websocket: {:?}", error);
                    server_sender.write().await.is_try_connect = false;
                }
            });
        }
        server_ip => {
            tokio::spawn(wrap_get_internal_websocket(
                db,
                server_sender,
                server_ip.into(),
                options,
            ));
        }
    }
    Ok(())
}
#[cfg(not(all(feature = "native-db", feature = "bebop")))]
pub async fn get_internal_connect(
_input: Option<()>,
db: DB,
server_sender: RwServerSender,
options: ClientOptions,
) -> Result<(), Box<dyn Error>> {
if server_sender.read().await.is_try_connect {
return Ok(());
}
if server_sender.is_valid_server_ip().await {
server_sender.send_status(SenderStatus::Connected).await;
return Ok(());
}
let ip = get_ip_address();
if ip.is_empty() {
server_sender.send_status(SenderStatus::Disconnected).await;
return Ok(());
}
server_sender.send_status(SenderStatus::Connecting).await;
let (server_ip, ws_stream) = ScanManager::new("9000").run().await;
tokio::spawn(async move {
if let Err(error) =
handle_websocket(db, server_sender.clone(), options, server_ip, ws_stream).await
{
log_error!("Error handling websocket: {:?}", error);
server_sender.write().await.is_try_connect = false;
}
});
Ok(())
}
/// Best-effort local IPv4 address of the interface that routes to the public internet.
///
/// A UDP socket is bound to `0.0.0.0:0` and "connected" to `8.8.8.8:80`; `connect` on UDP
/// only selects a route, so the socket's local address reveals the outgoing interface.
/// Returns an empty string if any step fails (e.g. no network route is available).
pub fn get_ip_address() -> String {
    fn routed_local_ip() -> std::io::Result<std::net::IpAddr> {
        let probe = UdpSocket::bind("0.0.0.0:0")?;
        probe.connect("8.8.8.8:80")?;
        Ok(probe.local_addr()?.ip())
    }

    routed_local_ip()
        .map(|ip| ip.to_string())
        .unwrap_or_default()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every default of `ClientOptions`, including the channel capacities.
    #[test]
    fn test_client_options_default() {
        let options = ClientOptions::default();
        assert!(options.use_ping);
        assert_eq!(options.url, "");
        assert_eq!(options.retry_seconds, 30);
        assert!(!options.use_keep_ip);
        assert_eq!(options.connect_timeout_seconds, 3);
        assert!(matches!(
            options.atomic_websocket_type,
            AtomicWebsocketType::Internal
        ));
        assert_eq!(options.handler_buffer_size, 256);
        assert_eq!(options.status_buffer_size, 8);
        assert_eq!(options.per_connection_buffer_size, 8);
        assert_eq!(options.spillover_buffer_size, 1024);
    }

    #[cfg(feature = "rustls")]
    #[test]
    fn test_client_options_default_with_tls() {
        let options = ClientOptions::default();
        assert!(options.use_tls);
    }

    /// `Clone` must carry over both the explicitly set fields and the defaulted ones.
    #[test]
    fn test_client_options_clone() {
        let options = ClientOptions {
            use_ping: false,
            url: "ws://example.com:9000".to_string(),
            retry_seconds: 60,
            use_keep_ip: true,
            connect_timeout_seconds: 10,
            atomic_websocket_type: AtomicWebsocketType::External,
            #[cfg(feature = "rustls")]
            use_tls: false,
            ..Default::default()
        };
        let cloned = options.clone();
        assert!(!cloned.use_ping);
        assert_eq!(cloned.url, "ws://example.com:9000");
        assert_eq!(cloned.retry_seconds, 60);
        assert!(cloned.use_keep_ip);
        assert_eq!(cloned.connect_timeout_seconds, 10);
        assert!(matches!(
            cloned.atomic_websocket_type,
            AtomicWebsocketType::External
        ));
        assert_eq!(cloned.handler_buffer_size, options.handler_buffer_size);
        assert_eq!(cloned.status_buffer_size, options.status_buffer_size);
        assert_eq!(
            cloned.per_connection_buffer_size,
            options.per_connection_buffer_size
        );
        assert_eq!(cloned.spillover_buffer_size, options.spillover_buffer_size);
    }

    #[test]
    fn test_client_options_custom_values() {
        let options = ClientOptions {
            use_ping: false,
            url: "192.168.1.100:9000".to_string(),
            retry_seconds: 5,
            use_keep_ip: true,
            connect_timeout_seconds: 1,
            atomic_websocket_type: AtomicWebsocketType::Internal,
            #[cfg(feature = "rustls")]
            use_tls: true,
            ..Default::default()
        };
        assert!(!options.use_ping);
        assert_eq!(options.url, "192.168.1.100:9000");
        assert_eq!(options.retry_seconds, 5);
        assert!(options.use_keep_ip);
        assert_eq!(options.connect_timeout_seconds, 1);
    }

    /// An empty result means no route was available (e.g. an offline sandbox). Otherwise the
    /// probe socket is bound to `0.0.0.0`, so the result must parse as an IPv4 address.
    #[test]
    fn test_get_ip_address_format() {
        let ip = get_ip_address();
        if !ip.is_empty() {
            assert!(
                ip.parse::<std::net::Ipv4Addr>().is_ok(),
                "expected an IPv4 address, got {ip:?}"
            );
        }
    }
}