use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use tokio::net::TcpStream;
use tokio::sync::Semaphore;
use tokio::time::{timeout, MissedTickBehavior};
use tokio_tungstenite::{connect_async, tungstenite, MaybeTlsStream, WebSocketStream};
use crate::helpers::traits::connection_state::ConnectionManager;
use crate::log_debug;
use crate::server_sender::get_ip_address;
/// Number of connection attempts that [`ScanManager::new`] allows to run at once
/// (passed through to `ScanManager::with_concurrency` as the semaphore size).
const DEFAULT_MAX_CONCURRENT_CONNECTIONS: usize = 50;
/// Per-host bookkeeping kept by the scanner, keyed by the host's `ws://` URL.
#[derive(Debug)]
pub struct ConnectionState {
    /// Status recorded for this host (maintained through the
    /// `ConnectionManager` trait methods — see that trait for exact semantics).
    pub status: WebSocketStatus,
    /// While `true`, the scanner treats this host as claimed by a pending or
    /// running attempt and does not start another one for it.
    pub is_connecting: bool,
    /// Stream from a successful attempt, presumably stored by `end_connection`
    /// — confirm against the `ConnectionManager` implementation.
    pub ws_stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
}
/// Repeatedly probes every host of the /24 subnet of `get_ip_address()` for a
/// WebSocket server on a fixed port, and hands back the first stream that
/// connects (see `run`).
pub struct ScanManager {
// `ws://A.B.C.N:port` URLs for N = 1 through 254, built once at construction.
scan_ips: Vec<String>,
// Per-URL attempt state, shared with the spawned connection tasks.
connection_states: Arc<DashMap<String, ConnectionState>>,
// Bounds how many connection attempts run at the same time.
semaphore: Arc<Semaphore>,
// Set once any attempt connects; stops further scanning and queued attempts.
stop_flag: Arc<AtomicBool>,
}
impl ScanManager {
pub fn new(port: &str) -> Self {
Self::with_concurrency(port, DEFAULT_MAX_CONCURRENT_CONNECTIONS)
}
pub fn with_concurrency(port: &str, max_concurrent: usize) -> Self {
let mut scan_ips = Vec::new();
let ip = get_ip_address();
let ips = ip.split('.').collect::<Vec<&str>>();
for sub_ip in 1..255 {
let ip = format!("ws://{}.{}.{}.{}:{}", ips[0], ips[1], ips[2], sub_ip, port);
scan_ips.push(ip);
}
Self {
scan_ips,
connection_states: Arc::new(DashMap::new()),
semaphore: Arc::new(Semaphore::new(max_concurrent)),
stop_flag: Arc::new(AtomicBool::new(false)),
}
}
fn is_connecting_allowed(&self, server_ip: &str) -> bool {
if self.stop_flag.load(Ordering::Acquire) {
return false;
}
if let Some(state) = self.connection_states.get(server_ip) {
if state.is_connecting {
return false;
}
}
if self.connection_states.is_connected() {
return false;
}
true
}
fn get_scannable_ips(&self) -> Vec<String> {
self.scan_ips
.iter()
.filter(|server_ip| {
if let Some(state) = self.connection_states.get(*server_ip) {
if state.is_connecting {
return false;
}
}
true
})
.cloned()
.collect()
}
async fn scan_network(&mut self) {
if self.stop_flag.load(Ordering::Acquire) {
return;
}
let scan_list: Vec<String> = self.get_scannable_ips();
for server_ip in scan_list {
if !self.is_connecting_allowed(&server_ip) {
continue;
}
let connection_states = self.connection_states.clone();
let semaphore = self.semaphore.clone();
let stop_flag = self.stop_flag.clone();
tokio::spawn(async move {
if stop_flag.load(Ordering::Acquire) {
return;
}
let _permit = semaphore.acquire().await;
if stop_flag.load(Ordering::Acquire) {
return;
}
connection_states.start_connection(&server_ip);
let status = check_connection(server_ip.clone()).await;
log_debug!("server_ip: {}, {:?}", server_ip, status);
if status.0 == WebSocketStatus::Connected {
stop_flag.store(true, Ordering::Release);
}
connection_states.end_connection(&server_ip, status);
});
}
}
pub async fn run(&mut self) -> (String, WebSocketStream<MaybeTlsStream<TcpStream>>) {
let mut interval =
tokio::time::interval_at(tokio::time::Instant::now(), Duration::from_secs(2));
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
loop {
interval.tick().await;
if let Some(state) = self.connection_states.get_connected_ip() {
return state;
}
self.scan_network().await;
}
}
}
async fn check_connection(
server_ip: String,
) -> (
WebSocketStatus,
Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) {
match timeout(Duration::from_secs(10), connect_async(&server_ip)).await {
Ok(result) => match result {
Ok((ws_stream, _)) => {
(WebSocketStatus::Connected, Some(ws_stream))
}
Err(e) => {
match e {
tungstenite::Error::Io(e) => match e.kind() {
std::io::ErrorKind::ConnectionRefused => {
(WebSocketStatus::ConnectionRefused, None)
}
_ => {
(WebSocketStatus::Timeout, None)
}
},
_ => (WebSocketStatus::Timeout, None),
}
}
},
Err(_) => {
(WebSocketStatus::Timeout, None)
}
}
}
/// State or outcome of a WebSocket connection attempt to one scanned host.
///
/// `check_connection` only produces `Connected`, `ConnectionRefused` and
/// `Timeout`; every failure other than a refused TCP connection (including
/// handshake and protocol errors) is reported as `Timeout`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WebSocketStatus {
    /// Not produced by `check_connection`; presumably set while an attempt is
    /// in flight (via `start_connection`) — confirm against that trait.
    Connecting,
    /// The WebSocket handshake succeeded.
    Connected,
    /// The TCP connection was refused by the host.
    ConnectionRefused,
    /// The attempt hit the connect deadline or failed for any other reason.
    Timeout,
}