torrust-actix 4.1.2

A rich, fast and efficient BitTorrent Tracker.
use crate::config::enums::cluster_encoding::ClusterEncoding;
use crate::stats::enums::stats_event::StatsEvent;
use crate::tracker::structs::torrent_tracker::TorrentTracker;
use crate::websocket::encoding::encoder::{decode, encode};
use crate::websocket::structs::cluster_request::ClusterRequest;
use crate::websocket::structs::cluster_response::ClusterResponse;
use crate::websocket::structs::handshake::{HandshakeRequest, HandshakeResponse, CLUSTER_PROTOCOL_VERSION};
use futures_util::{SinkExt, StreamExt};
use log::{debug, error, info, warn};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tokio_tungstenite::{connect_async, tungstenite::Message};

/// One-shot sender through which the `ClusterResponse` for a single in-flight request is
/// delivered to the task waiting on it.
type PendingRequestSender = oneshot::Sender<ClusterResponse>;
/// Shared, lock-protected slot holding the optional unbounded sender of already-encoded
/// outbound frames (`Vec<u8>`); `None` means no sender is currently installed.
type SlaveSenderChannel = Arc<RwLock<Option<tokio::sync::mpsc::UnboundedSender<Vec<u8>>>>>;

/// Mutable state of the slave-side cluster client, shared through the `SLAVE_CLIENT` static.
pub struct SlaveClientState {
    /// Encoding stored after a successful handshake; reset to `None` when the connection ends.
    pub encoding: Option<ClusterEncoding>,
    /// Set to `true` after a successful handshake and back to `false` by the reconnect loop.
    pub connected: bool,
    /// Response senders for requests that are still awaiting a reply, keyed by request id.
    pub pending_requests: HashMap<u64, PendingRequestSender>,
    /// Most recent value returned by `next_request_id` (starts at 0).
    pub request_counter: u64,
}

impl SlaveClientState {
    pub fn new() -> Self {
        Self {
            encoding: None,
            connected: false,
            pending_requests: HashMap::new(),
            request_counter: 0,
        }
    }

    pub fn next_request_id(&mut self) -> u64 {
        self.request_counter = self.request_counter.wrapping_add(1);
        self.request_counter
    }
}

impl Default for SlaveClientState {
    fn default() -> Self {
        Self::new()
    }
}

/// Process-wide slave client state, lazily initialised to `SlaveClientState::new()` on first
/// access and guarded by a read/write lock.
pub static SLAVE_CLIENT: once_cell::sync::Lazy<Arc<RwLock<SlaveClientState>>> =
    once_cell::sync::Lazy::new(|| Arc::new(RwLock::new(SlaveClientState::new())));

/// Process-wide slot for the outbound frame sender, lazily initialised to `None`.
/// It is filled once a handshake with the master succeeds and cleared again by the
/// reconnect loop after the connection ends.
pub static SLAVE_SENDER: once_cell::sync::Lazy<SlaveSenderChannel> =
    once_cell::sync::Lazy::new(|| Arc::new(RwLock::new(None)));

/// Reports the `connected` flag currently stored in the shared slave client state.
pub fn is_connected() -> bool {
    let state = SLAVE_CLIENT.read();
    state.connected
}

/// Returns a copy of the encoding stored in the shared slave client state, or `None` when
/// no encoding is currently recorded.
pub fn get_encoding() -> Option<ClusterEncoding> {
    let state = SLAVE_CLIENT.read();
    state.encoding.as_ref().cloned()
}

/// Sends `request` to the master over the active slave connection and waits for the
/// response carrying the same request id.
///
/// The request is encoded with the encoding stored in the shared client state, registered
/// in `pending_requests` under `request.request_id`, and handed to the outbound channel.
/// The call then waits up to `cluster_request_timeout` seconds for the response.
///
/// # Errors
///
/// * `ForwardError::NotConnected` - the state is not connected, has no encoding, or no
///   outbound sender is installed.
/// * `ForwardError::EncodingError` - the request could not be encoded.
/// * `ForwardError::ConnectionLost` - the outbound channel rejected the frame, or the
///   response sender was dropped before a response arrived.
/// * `ForwardError::Timeout` - no response arrived within the configured timeout.
pub async fn send_request(
    tracker: &Arc<TorrentTracker>,
    request: ClusterRequest,
) -> Result<ClusterResponse, super::forwarder::ForwardError> {
    use super::forwarder::ForwardError;

    // Snapshot the connection state under a short read lock.
    let (connected, encoding) = {
        let state = SLAVE_CLIENT.read();
        (state.connected, state.encoding.clone())
    };
    if !connected {
        return Err(ForwardError::NotConnected);
    }
    let encoding = encoding.ok_or(ForwardError::NotConnected)?;
    let encoded = encode(&encoding, &request)
        .map_err(|e| ForwardError::EncodingError(e.to_string()))?;

    // Register the response channel before sending so a fast reply cannot be missed.
    let request_id = request.request_id;
    let (tx, rx) = oneshot::channel();
    {
        let mut state = SLAVE_CLIENT.write();
        state.pending_requests.insert(request_id, tx);
    }

    // Classify the failure at the moment of the send instead of re-reading the sender slot
    // afterwards: the reconnect loop may clear the slot in between.
    let send_result = {
        let sender_guard = SLAVE_SENDER.read();
        match sender_guard.as_ref() {
            Some(sender) => sender
                .send(encoded)
                .map_err(|_| ForwardError::ConnectionLost),
            None => Err(ForwardError::NotConnected),
        }
    };
    if let Err(send_error) = send_result {
        // Only one lock is held here; the sender lock was released above.
        {
            let mut state = SLAVE_CLIENT.write();
            state.pending_requests.remove(&request_id);
        }
        return Err(send_error);
    }
    tracker.update_stats(StatsEvent::WsRequestsSent, 1);

    let timeout_duration =
        Duration::from_secs(tracker.config.tracker_config.cluster_request_timeout);
    match timeout(timeout_duration, rx).await {
        Ok(Ok(response)) => {
            tracker.update_stats(StatsEvent::WsResponsesReceived, 1);
            Ok(response)
        }
        // The response sender was dropped without a reply; this is recorded under the
        // timeout statistic, as before.
        Ok(Err(_)) => {
            tracker.update_stats(StatsEvent::WsTimeouts, 1);
            Err(ForwardError::ConnectionLost)
        }
        Err(_) => {
            // Nobody will answer any more: drop the stale entry so the map does not grow.
            {
                let mut state = SLAVE_CLIENT.write();
                state.pending_requests.remove(&request_id);
            }
            tracker.update_stats(StatsEvent::WsTimeouts, 1);
            Err(ForwardError::Timeout)
        }
    }
}

pub async fn start_slave_client(tracker: Arc<TorrentTracker>) {
    let config = tracker.config.clone();
    let master_address = &config.tracker_config.cluster_master_address;
    let token = &config.tracker_config.cluster_token;
    let use_ssl = config.tracker_config.cluster_ssl;
    let reconnect_interval = config.tracker_config.cluster_reconnect_interval;
    let protocol = if use_ssl { "wss" } else { "ws" };
    let websocket_url = format!("{}://{}/cluster", protocol, master_address);
    let slave_id = uuid::Uuid::new_v4().to_string();
    info!("[WEBSOCKET SLAVE] Starting slave client, connecting to {}", websocket_url);
    info!("[WEBSOCKET SLAVE] Slave UUID: {}", slave_id);
    loop {
        match connect_to_master(
            &tracker,
            &websocket_url,
            token,
            &slave_id,
        ).await {
            Ok(()) => {
                info!("[WEBSOCKET SLAVE] Disconnected from master");
            }
            Err(e) => {
                error!("[WEBSOCKET SLAVE] Connection error: {}", e);
            }
        }
        {
            let mut state = SLAVE_CLIENT.write();
            state.connected = false;
            state.encoding = None;
            for (_, sender) in state.pending_requests.drain() {
                let _ = sender.send(ClusterResponse::error(0, "Connection lost".to_string()));
            }
        }
        {
            let mut sender_guard = SLAVE_SENDER.write();
            *sender_guard = None;
        }
        tracker.update_stats(StatsEvent::WsConnectionsActive, -1);
        tracker.update_stats(StatsEvent::WsReconnects, 1);
        info!(
            "[WEBSOCKET SLAVE] Reconnecting in {} seconds...",
            reconnect_interval
        );
        tokio::time::sleep(Duration::from_secs(reconnect_interval)).await;
    }
}

/// Opens one WebSocket connection to the master, performs the handshake and then relays
/// traffic until the connection ends.
///
/// Steps:
/// 1. Connect to `websocket_url` and send a JSON-serialised `HandshakeRequest` built from
///    `token` and `slave_id` as a binary frame.
/// 2. Parse the first reply (binary or text frame) as a `HandshakeResponse`. A rejected
///    handshake is reported as an error; a protocol version mismatch is only logged.
/// 3. Mark the shared state as connected, store the encoding (defaulting to
///    `ClusterEncoding::binary` when the master did not name one) and install a fresh
///    outbound channel in `SLAVE_SENDER`.
/// 4. Spawn a task that writes queued frames to the socket, and read incoming frames,
///    passing binary ones to `handle_response`, until the stream ends, a close frame
///    arrives or a read error occurs.
///
/// # Errors
///
/// Returns an error when connecting, serialising or sending the handshake fails, when the
/// handshake reply is missing, malformed or of an unexpected frame type, or when the master
/// rejects the handshake. Once the handshake has succeeded the function returns `Ok(())`
/// regardless of how the connection ended.
async fn connect_to_master(
    tracker: &Arc<TorrentTracker>,
    websocket_url: &str,
    token: &str,
    slave_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    debug!("[WEBSOCKET SLAVE] Connecting to master: {}", websocket_url);
    let (stream, _http_response) = connect_async(websocket_url).await?;
    let (mut sink, mut source) = stream.split();
    info!("[WEBSOCKET SLAVE] Connected, sending handshake...");

    // Handshake request.
    let request = HandshakeRequest::new(token.to_string(), slave_id.to_string());
    let payload = serde_json::to_vec(&request)?;
    sink.send(Message::Binary(payload.into())).await?;

    // Handshake reply: the very first frame must carry the JSON response.
    let reply: HandshakeResponse = match source.next().await {
        None => return Err("Connection closed during handshake".into()),
        Some(Err(e)) => return Err(format!("WebSocket error during handshake: {}", e).into()),
        Some(Ok(Message::Binary(data))) => serde_json::from_slice(&data)?,
        Some(Ok(Message::Text(text))) => serde_json::from_str(&text)?,
        Some(Ok(_)) => return Err("Unexpected message type during handshake".into()),
    };
    if !reply.success {
        let error_msg = reply.error.unwrap_or_else(|| "Unknown error".to_string());
        error!("[WEBSOCKET SLAVE] Handshake failed: {}", error_msg);
        tracker.update_stats(StatsEvent::WsAuthFailed, 1);
        return Err(format!("Handshake failed: {}", error_msg).into());
    }
    if reply.version != CLUSTER_PROTOCOL_VERSION {
        warn!(
            "[WEBSOCKET SLAVE] Protocol version mismatch: master={}, slave={}",
            reply.version, CLUSTER_PROTOCOL_VERSION
        );
    }
    let encoding = reply.encoding.unwrap_or(ClusterEncoding::binary);
    let master_id = reply.master_id.unwrap_or_else(|| "unknown".to_string());
    info!(
        "[WEBSOCKET SLAVE] Handshake successful, connected to master UUID: {}, using encoding: {:?}",
        master_id, encoding
    );
    tracker.update_stats(StatsEvent::WsAuthSuccess, 1);
    tracker.update_stats(StatsEvent::WsConnectionsActive, 1);

    // Publish the connection: state first, then the outbound channel.
    {
        let mut state = SLAVE_CLIENT.write();
        state.connected = true;
        state.encoding = Some(encoding.clone());
    }
    let (outbound_tx, mut outbound_rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
    {
        let mut sender_slot = SLAVE_SENDER.write();
        *sender_slot = Some(outbound_tx);
    }

    // Writer task: forwards queued frames until the channel closes or a write fails.
    let writer_task = tokio::spawn(async move {
        while let Some(frame) = outbound_rx.recv().await {
            if sink.send(Message::Binary(frame.into())).await.is_err() {
                break;
            }
        }
    });

    // Reader loop: runs on the current task until the stream ends.
    loop {
        let incoming = match source.next().await {
            Some(incoming) => incoming,
            None => break,
        };
        match incoming {
            Ok(Message::Binary(data)) => handle_response(&encoding, &data),
            Ok(Message::Ping(_)) => debug!("[WEBSOCKET SLAVE] Received ping"),
            Ok(Message::Pong(_)) => debug!("[WEBSOCKET SLAVE] Received pong"),
            Ok(Message::Close(_)) => {
                info!("[WEBSOCKET SLAVE] Received close from master");
                break;
            }
            Err(e) => {
                error!("[WEBSOCKET SLAVE] WebSocket error: {}", e);
                break;
            }
            // Any other frame type is ignored.
            Ok(_) => {}
        }
    }
    writer_task.abort();
    Ok(())
}

/// Decodes one incoming frame as a `ClusterResponse` and hands it to the task waiting for
/// that request id.
///
/// A frame that fails to decode is logged and dropped. A response whose request id has no
/// pending entry is logged as a warning. If the waiting side has already gone away, the
/// failed delivery is silently ignored.
fn handle_response(encoding: &ClusterEncoding, data: &[u8]) {
    let decoded: Result<ClusterResponse, _> = decode(encoding, data);
    let response = match decoded {
        Err(e) => {
            error!("[WEBSOCKET SLAVE] Failed to decode response: {}", e);
            return;
        }
        Ok(parsed) => parsed,
    };
    let mut state = SLAVE_CLIENT.write();
    match state.pending_requests.remove(&response.request_id) {
        Some(waiter) => {
            let _ = waiter.send(response);
        }
        None => {
            warn!(
                "[WEBSOCKET SLAVE] Received response for unknown request: {}",
                response.request_id
            );
        }
    }
}