// infrarust 1.1.2
//
// A Rust adaptation of the Infrared Minecraft proxy.
// Documentation
use std::{
    collections::HashMap,
    time::{Duration, SystemTime},
};

use tracing::{debug, debug_span, instrument, Instrument};

use crate::{
    network::{
        packet::Packet,
        proxy_protocol::{errors::ProxyProtocolError, ProtocolResult},
    },
    server::motd::{generate_motd, MotdConfig},
    version::Version,
    ServerConnection, CONFIG,
};

use crate::telemetry::TELEMETRY;

use super::{backend::Server, ServerRequest};

/// In-memory cache of backend status responses.
///
/// Each entry is looked up by a `u64` key derived from the server and the
/// client's protocol version, and stays valid for `ttl` after it is stored.
pub struct StatusCache {
    /// Stored responses, indexed by the hashed cache key.
    entries: HashMap<u64, CacheEntry>,
    /// Lifetime granted to every newly stored response.
    ttl: Duration,
}

/// One cached status response plus the wall-clock time at which it expires.
struct CacheEntry {
    /// Packet returned to the caller on a cache hit.
    response: Packet,
    /// Point in time after which this entry must no longer be served.
    expires_at: SystemTime,
}

impl StatusCache {
    pub fn new(ttl: Duration) -> Self {
        Self {
            ttl,
            entries: HashMap::new(),
        }
    }

    #[instrument(name = "get_status_response", skip(self, server), fields(
        server_addr = %server.config.addresses.first().unwrap_or(&String::new()),
        protocol_version = ?req.protocol_version
    ))]
    pub async fn get_status_response(
        &mut self,
        server: &Server,
        req: &ServerRequest,
    ) -> ProtocolResult<Packet> {
        match self.try_get_status_response(server, req).await {
            Ok(response) => Ok(response),
            Err(e) => {
                TELEMETRY.record_protocol_error(
                    "status_fetch_failed",
                    &e.to_string(),
                    req.session_id,
                );
                Err(e)
            }
        }
    }

    #[instrument(name = "try_get_status_response", skip(self, server), fields(
        server_addr = %server.config.addresses.first().unwrap_or(&String::new()),
        protocol_version = ?req.protocol_version
    ))]
    pub async fn try_get_status_response(
        &mut self,
        server: &Server,
        req: &ServerRequest,
    ) -> ProtocolResult<Packet> {
        let key = self.cache_key(server, req.protocol_version);

        if let Some(entry) = self.entries.get(&key) {
            if entry.expires_at > SystemTime::now() {
                debug!("Cache hit, returning cached status response");
                return Ok(entry.response.clone());
            }
            debug!("Cache expired");
        } else {
            debug!("Cache miss");
        }

        let use_proxy_protocol = server.config.send_proxy_protocol.unwrap_or(false);

        let response = match if use_proxy_protocol {
            debug!("Using proxy protocol for status connection");
            server
                .dial_with_proxy_protocol(req.session_id, req.client_addr)
                .instrument(debug_span!("backend_server_connect_with_proxy"))
                .await
        } else {
            debug!("Using standard connection for status");
            server
                .dial(req.session_id)
                .instrument(debug_span!("backend_server_connect"))
                .await
        } {
            Ok(mut conn) => {
                self.fetch_status(&mut conn, req)
                    .instrument(debug_span!("fetch_server_status"))
                    .await?
            }
            Err(e) => {
                let guard = CONFIG.read();

                if guard.motds.unreachable.is_some() {
                    let motd = guard.motds.unreachable.clone().unwrap();

                    if motd.enabled && !motd.is_empty() {
                        return generate_motd(&motd, true);
                    } else if motd.enabled {
                        return generate_motd(&MotdConfig::default_unreachable(), true);
                    }
                }

                debug!("Failed to connect to server: {}", e);

                return Err(ProxyProtocolError::Other(format!(
                    "Failed to connect to server: {}",
                    e
                )));
            }
        };

        if let Some(motd) = &server.config.motd {
            let response_packet = generate_motd(motd, false)?;

            self.entries.insert(
                key,
                CacheEntry {
                    expires_at: SystemTime::now() + self.ttl,
                    response: response_packet.clone(),
                },
            );

            return Ok(response_packet);
        }

        debug!("Caching new status response");
        self.entries.insert(
            key,
            CacheEntry {
                expires_at: SystemTime::now() + self.ttl,
                response: response.clone(),
            },
        );

        Ok(response)
    }

    #[instrument(skip(self, conn), fields(
        packets_count = %req.read_packets.len()
    ))]
    pub async fn fetch_status(
        &self,
        conn: &mut ServerConnection,
        req: &ServerRequest,
    ) -> ProtocolResult<Packet> {
        debug!("ReadPacket: {:?}", req.read_packets[0]);
        debug!("ReadPacket: {:?}", req.read_packets[1]);
        conn.write_packet(&req.read_packets[0].clone()).await?;
        conn.write_packet(&req.read_packets[1].clone()).await?;
        conn.read_packet().await
    }

    fn cache_key(&self, server: &Server, version: Version) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        server.config.addresses[0].hash(&mut hasher);
        version.hash(&mut hasher);
        hasher.finish()
    }
}