torrust-actix 4.1.2

A rich, fast and efficient BitTorrent tracker.
use crate::config::enums::cluster_encoding::ClusterEncoding;
use serde::{de::DeserializeOwned, Serialize};

/// Error produced when converting a cluster message to or from its wire bytes.
///
/// Each variant carries the underlying serializer's error already rendered as a
/// string. This keeps the type independent of the concrete `serde_json` /
/// `rmp_serde` error types, and lets it be cloned and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EncodingError {
    /// Turning a value into bytes failed (produced by `encode`).
    SerializationError(String),
    /// Turning bytes back into a value failed (produced by `decode`).
    DeserializationError(String),
}

impl std::fmt::Display for EncodingError {
    /// Renders as `"<Serialization|Deserialization> error: <message>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EncodingError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
            EncodingError::DeserializationError(msg) => write!(f, "Deserialization error: {}", msg),
        }
    }
}

/// No `source()` is exposed: the original error was flattened to a string.
impl std::error::Error for EncodingError {}

pub fn encode<T: Serialize>(encoding: &ClusterEncoding, value: &T) -> Result<Vec<u8>, EncodingError> {
    match encoding {
        ClusterEncoding::binary => encode_binary(value),
        ClusterEncoding::json => encode_json(value),
        ClusterEncoding::msgpack => encode_msgpack(value),
    }
}

pub fn decode<T: DeserializeOwned>(encoding: &ClusterEncoding, data: &[u8]) -> Result<T, EncodingError> {
    match encoding {
        ClusterEncoding::binary => decode_binary(data),
        ClusterEncoding::json => decode_json(data),
        ClusterEncoding::msgpack => decode_msgpack(data),
    }
}

fn encode_binary<T: Serialize>(value: &T) -> Result<Vec<u8>, EncodingError> {
    rmp_serde::to_vec(value)
        .map_err(|e| EncodingError::SerializationError(e.to_string()))
}

fn decode_binary<T: DeserializeOwned>(data: &[u8]) -> Result<T, EncodingError> {
    rmp_serde::from_slice(data)
        .map_err(|e| EncodingError::DeserializationError(e.to_string()))
}

/// Encodes `value` as JSON bytes via `serde_json::to_vec`.
///
/// Any serializer failure is flattened to its message and reported as
/// `EncodingError::SerializationError`.
fn encode_json<T: Serialize>(value: &T) -> Result<Vec<u8>, EncodingError> {
    serde_json::to_vec(value)
        .map_err(|source| source.to_string())
        .map_err(EncodingError::SerializationError)
}

/// Parses a `T` from JSON bytes via `serde_json::from_slice`.
///
/// Any parse failure is flattened to its message and reported as
/// `EncodingError::DeserializationError`.
fn decode_json<T: DeserializeOwned>(data: &[u8]) -> Result<T, EncodingError> {
    let parsed: T = serde_json::from_slice(data)
        .map_err(|source| EncodingError::DeserializationError(source.to_string()))?;
    Ok(parsed)
}

/// Encodes `value` as MessagePack bytes via `rmp_serde::to_vec`.
///
/// Any serializer failure is flattened to its message and reported as
/// `EncodingError::SerializationError`.
///
/// NOTE(review): per the rmp_serde docs, `to_vec` writes structs positionally
/// (no field names), unlike `to_vec_named`. If that holds, reordering fields of
/// wire structs breaks compatibility between nodes — confirm before changing them.
fn encode_msgpack<T: Serialize>(value: &T) -> Result<Vec<u8>, EncodingError> {
    let bytes = rmp_serde::to_vec(value)
        .map_err(|source| EncodingError::SerializationError(source.to_string()))?;
    Ok(bytes)
}

/// Parses a `T` from MessagePack bytes via `rmp_serde::from_slice`.
///
/// Any decode failure is flattened to its message and reported as
/// `EncodingError::DeserializationError`.
fn decode_msgpack<T: DeserializeOwned>(data: &[u8]) -> Result<T, EncodingError> {
    rmp_serde::from_slice(data)
        .map_err(|source| source.to_string())
        .map_err(EncodingError::DeserializationError)
}

#[cfg(test)]
mod tests {
    //! Round-trip and failure-path tests for the cluster wire encodings.
    //!
    //! Messages are compared field by field because the visible tests only
    //! rely on the individual fields being comparable, not the whole structs.
    use super::*;
    use crate::websocket::enums::protocol_type::ProtocolType;
    use crate::websocket::enums::request_type::RequestType;
    use crate::websocket::structs::cluster_request::ClusterRequest;
    use crate::websocket::structs::cluster_response::ClusterResponse;
    use crate::websocket::structs::handshake::{HandshakeRequest, HandshakeResponse};
    use std::net::IpAddr;

    #[test]
    fn test_encode_decode_json() {
        let request = ClusterRequest::new(
            1,
            ProtocolType::Http,
            RequestType::Announce,
            "127.0.0.1".parse::<IpAddr>().unwrap(),
            6969,
            vec![1, 2, 3, 4],
        );
        let encoded = encode(&ClusterEncoding::json, &request).unwrap();
        let decoded: ClusterRequest = decode(&ClusterEncoding::json, &encoded).unwrap();
        assert_eq!(request.request_id, decoded.request_id);
        assert_eq!(request.protocol, decoded.protocol);
        assert_eq!(request.payload, decoded.payload);
    }

    #[test]
    fn test_encode_decode_msgpack() {
        let request = ClusterRequest::new(
            1,
            ProtocolType::Udp,
            RequestType::Scrape,
            "::1".parse::<IpAddr>().unwrap(),
            6969,
            vec![5, 6, 7, 8],
        );
        let encoded = encode(&ClusterEncoding::msgpack, &request).unwrap();
        let decoded: ClusterRequest = decode(&ClusterEncoding::msgpack, &encoded).unwrap();
        assert_eq!(request.request_id, decoded.request_id);
        assert_eq!(request.protocol, decoded.protocol);
        assert_eq!(request.payload, decoded.payload);
    }

    #[test]
    fn test_encode_decode_binary() {
        let request = ClusterRequest::new(
            1,
            ProtocolType::Api,
            RequestType::ApiCall {
                endpoint: "/api/stats".to_string(),
                method: "GET".to_string(),
            },
            "192.168.1.1".parse::<IpAddr>().unwrap(),
            8080,
            vec![],
        );
        let encoded = encode(&ClusterEncoding::binary, &request).unwrap();
        let decoded: ClusterRequest = decode(&ClusterEncoding::binary, &encoded).unwrap();
        assert_eq!(request.request_id, decoded.request_id);
        assert_eq!(request.protocol, decoded.protocol);
        assert_eq!(request.payload, decoded.payload);
    }

    #[test]
    fn test_encoding_error_display_serialization() {
        let error = EncodingError::SerializationError("test error".to_string());
        assert_eq!(format!("{}", error), "Serialization error: test error");
    }

    #[test]
    fn test_encoding_error_display_deserialization() {
        let error = EncodingError::DeserializationError("invalid data".to_string());
        assert_eq!(format!("{}", error), "Deserialization error: invalid data");
    }

    #[test]
    fn test_encoding_error_debug() {
        let error = EncodingError::SerializationError("test".to_string());
        let debug_str = format!("{:?}", error);
        assert!(debug_str.contains("SerializationError"));
    }

    #[test]
    fn test_decode_invalid_json() {
        let invalid_json = b"not valid json at all";
        let result: Result<ClusterRequest, _> = decode(&ClusterEncoding::json, invalid_json);
        assert!(result.is_err());
        if let Err(EncodingError::DeserializationError(msg)) = result {
            assert!(!msg.is_empty());
        } else {
            panic!("Expected DeserializationError");
        }
    }

    #[test]
    fn test_decode_invalid_msgpack() {
        let invalid_data = b"\xff\xff\xff";
        let result: Result<ClusterRequest, _> = decode(&ClusterEncoding::msgpack, invalid_data);
        // Failures must surface as the decode-side variant, not just "some error".
        assert!(matches!(&result, Err(EncodingError::DeserializationError(msg)) if !msg.is_empty()));
    }

    #[test]
    fn test_decode_invalid_binary() {
        let invalid_data = b"\x00\x01\x02";
        let result: Result<ClusterRequest, _> = decode(&ClusterEncoding::binary, invalid_data);
        assert!(matches!(&result, Err(EncodingError::DeserializationError(msg)) if !msg.is_empty()));
    }

    #[test]
    fn test_decode_empty_input_fails_for_every_encoding() {
        for encoding in &[ClusterEncoding::json, ClusterEncoding::msgpack, ClusterEncoding::binary] {
            let result: Result<ClusterRequest, _> = decode(encoding, &[]);
            assert!(matches!(result, Err(EncodingError::DeserializationError(_))));
        }
    }

    #[test]
    fn test_encode_json_reports_serialization_error() {
        // JSON object keys must be strings, so a map keyed by byte vectors
        // cannot be encoded; this exercises the SerializationError path.
        let mut map = std::collections::HashMap::new();
        map.insert(vec![1u8, 2, 3], 1u8);
        let result = encode(&ClusterEncoding::json, &map);
        assert!(matches!(&result, Err(EncodingError::SerializationError(msg)) if !msg.is_empty()));
    }

    #[test]
    fn test_encode_decode_cluster_response_json() {
        let response = ClusterResponse::success(42, vec![0xde, 0xad, 0xbe, 0xef]);
        let encoded = encode(&ClusterEncoding::json, &response).unwrap();
        let decoded: ClusterResponse = decode(&ClusterEncoding::json, &encoded).unwrap();
        assert_eq!(response.request_id, decoded.request_id);
        assert_eq!(response.success, decoded.success);
        assert_eq!(response.payload, decoded.payload);
    }

    #[test]
    fn test_encode_decode_cluster_response_msgpack() {
        let response = ClusterResponse::error(99, "Something failed".to_string());
        let encoded = encode(&ClusterEncoding::msgpack, &response).unwrap();
        let decoded: ClusterResponse = decode(&ClusterEncoding::msgpack, &encoded).unwrap();
        assert_eq!(response.request_id, decoded.request_id);
        assert!(!decoded.success);
        assert_eq!(response.error_message, decoded.error_message);
    }

    #[test]
    fn test_encode_decode_handshake_request_json() {
        let request = HandshakeRequest::new("secret_token".to_string(), "slave-001".to_string());
        let encoded = encode(&ClusterEncoding::json, &request).unwrap();
        let decoded: HandshakeRequest = decode(&ClusterEncoding::json, &encoded).unwrap();
        assert_eq!(request.token, decoded.token);
        assert_eq!(request.slave_id, decoded.slave_id);
        assert_eq!(request.version, decoded.version);
    }

    #[test]
    fn test_encode_decode_handshake_response_binary() {
        let response = HandshakeResponse::success(ClusterEncoding::binary, "master-001".to_string());
        let encoded = encode(&ClusterEncoding::binary, &response).unwrap();
        let decoded: HandshakeResponse = decode(&ClusterEncoding::binary, &encoded).unwrap();
        assert!(decoded.success);
        assert_eq!(response.encoding, decoded.encoding);
        assert_eq!(response.master_id, decoded.master_id);
    }

    #[test]
    fn test_encode_decode_handshake_failure_msgpack() {
        let response = HandshakeResponse::failure("Invalid token".to_string());
        let encoded = encode(&ClusterEncoding::msgpack, &response).unwrap();
        let decoded: HandshakeResponse = decode(&ClusterEncoding::msgpack, &encoded).unwrap();
        assert!(!decoded.success);
        assert_eq!(response.error, decoded.error);
    }

    #[test]
    fn test_encode_empty_payload() {
        let request = ClusterRequest::new(
            0,
            ProtocolType::Http,
            RequestType::Announce,
            "0.0.0.0".parse::<IpAddr>().unwrap(),
            0,
            vec![],
        );
        for encoding in &[ClusterEncoding::json, ClusterEncoding::msgpack, ClusterEncoding::binary] {
            let encoded = encode(encoding, &request).unwrap();
            let decoded: ClusterRequest = decode(encoding, &encoded).unwrap();
            assert!(decoded.payload.is_empty());
        }
    }

    #[test]
    fn test_encode_large_payload() {
        let large_payload: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
        let request = ClusterRequest::new(
            u64::MAX,
            ProtocolType::Udp,
            RequestType::UdpPacket,
            "255.255.255.255".parse::<IpAddr>().unwrap(),
            65535,
            large_payload.clone(),
        );
        let encoded = encode(&ClusterEncoding::binary, &request).unwrap();
        let decoded: ClusterRequest = decode(&ClusterEncoding::binary, &encoded).unwrap();
        assert_eq!(decoded.request_id, u64::MAX);
        assert_eq!(decoded.payload.len(), 10000);
        assert_eq!(decoded.payload, large_payload);
    }

    #[test]
    fn test_binary_and_msgpack_produce_same_output() {
        // The `binary` encoding is currently MessagePack; this pins that contract.
        let request = ClusterRequest::new(
            123,
            ProtocolType::Http,
            RequestType::Announce,
            "10.0.0.1".parse::<IpAddr>().unwrap(),
            6881,
            vec![1, 2, 3],
        );
        let binary_encoded = encode(&ClusterEncoding::binary, &request).unwrap();
        let msgpack_encoded = encode(&ClusterEncoding::msgpack, &request).unwrap();
        assert_eq!(binary_encoded, msgpack_encoded);
    }

    #[test]
    fn test_json_encoding_is_human_readable() {
        let response = ClusterResponse::success(42, vec![1, 2, 3]);
        let encoded = encode(&ClusterEncoding::json, &response).unwrap();
        let json_str = String::from_utf8(encoded).unwrap();
        assert!(json_str.contains("request_id"));
        assert!(json_str.contains("42"));
        assert!(json_str.contains("success"));
        assert!(json_str.contains("true"));
    }

    #[test]
    fn test_msgpack_is_more_compact_than_json() {
        let request = ClusterRequest::new(
            999999,
            ProtocolType::Api,
            RequestType::ApiCall {
                endpoint: "/api/v1/very/long/endpoint/path".to_string(),
                method: "POST".to_string(),
            },
            "192.168.100.200".parse::<IpAddr>().unwrap(),
            12345,
            vec![0; 100],
        );
        let json_encoded = encode(&ClusterEncoding::json, &request).unwrap();
        let msgpack_encoded = encode(&ClusterEncoding::msgpack, &request).unwrap();
        assert!(msgpack_encoded.len() < json_encoded.len());
    }
}