// Documentation
//! Internal protocol implementation for Nuvo.
//!
//! The protocol follows a simple binary format:
//! - Handshake: HELLO (client) -> ACCEPT/REJECT (server)
//! - Data exchange: DATA packets containing an authentication token and payload.

use std::io;
use std::net::SocketAddr;

use rand::RngCore;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};

/// Magic bytes to identify Nuvo protocol ("NV").
///
/// Written as the first two bytes of every message header and checked on read.
const MAGIC: [u8; 2] = *b"NV";
/// Current protocol version.
///
/// Third byte of the header; a header carrying any other value is rejected on read.
const VERSION: u8 = 1;

/// Message types.
///
/// The message type is the fourth (last) byte of the header.
/// HELLO: sent by the client to open the handshake; carries the password.
const TYPE_HELLO: u8 = 1;
/// ACCEPT: server reply to a successful HELLO; carries the session token.
const TYPE_ACCEPT: u8 = 2;
/// REJECT: server reply to a failed HELLO; carries a textual reason.
const TYPE_REJECT: u8 = 3;
/// DATA: carries the session token followed by a payload.
const TYPE_DATA: u8 = 4;

/// Largest password, in bytes, that may be sent or accepted in a HELLO message.
const MAX_PASSWORD_LEN: usize = 1024;
/// Largest token, in bytes, that may be sent or accepted in an ACCEPT or DATA message.
const MAX_TOKEN_LEN: usize = 64;
/// Length, in bytes, of the tokens produced by `generate_token`.
const TOKEN_LEN: usize = 32;
/// Largest DATA payload, in bytes, that may be sent or accepted (8 MiB).
const MAX_PAYLOAD_LEN: usize = 8 * 1024 * 1024;

/// Internal receiver implementation that wraps a [`TcpListener`].
///
/// Created by [`rx`]; each call to [`Receiver::accept`] takes one incoming
/// connection and runs the server side of the handshake on it.
pub struct Receiver {
    /// Listening socket that incoming connections are accepted from.
    listener: TcpListener,
    /// Password a client must present in its HELLO; `None` skips the check.
    expected_password: Option<String>,
}

/// Internal session implementation that wraps a [`TcpStream`] and maintains a session token.
///
/// Produced by [`tx`] on the client side and by [`Receiver::accept`] on the server side.
pub struct Session {
    /// Connected stream that all DATA packets are written to and read from.
    stream: TcpStream,
    /// Token attached to every outgoing DATA packet and required on every incoming one.
    token: Vec<u8>,
    /// Address of the remote end of `stream`.
    peer: SocketAddr,
}

/// Binds a listener on `0.0.0.0:port` and wraps it in a [`Receiver`].
///
/// `expected_password` is copied into the receiver; with `None` the receiver
/// accepts a HELLO regardless of the password it carries.
///
/// # Errors
///
/// Returns the I/O error reported while binding the listener.
pub async fn rx(port: u16, expected_password: Option<&str>) -> io::Result<Receiver> {
    let listener = TcpListener::bind(("0.0.0.0", port)).await?;
    let expected_password = expected_password.map(str::to_owned);
    Ok(Receiver {
        listener,
        expected_password,
    })
}

/// Connects to `ip:port` and runs the client side of the handshake, yielding a [`Session`].
///
/// A HELLO carrying `password` is sent, then the server's reply is read. On
/// ACCEPT the returned session holds the token issued by the server.
///
/// # Errors
///
/// Returns any I/O error from connecting or from the handshake exchange,
/// including the error produced when the server answers with REJECT.
pub async fn tx(ip: &str, port: u16, password: &str) -> io::Result<Session> {
    let mut stream = TcpStream::connect((ip, port)).await?;

    send_hello(&mut stream, password).await?;
    let (token, peer) = read_accept(&mut stream).await?;

    Ok(Session {
        stream,
        token,
        peer,
    })
}

impl Receiver {
    /// Accepts a new connection and performs the server-side handshake.
    ///
    /// Reads the client's HELLO, checks its password against the configured one
    /// (when there is one), then replies with ACCEPT and a freshly generated
    /// session token.
    ///
    /// # Errors
    ///
    /// - Any I/O or protocol error raised while accepting the connection or
    ///   exchanging handshake messages.
    /// - [`io::ErrorKind::PermissionDenied`] when the password does not match.
    ///   A REJECT is sent to the client first; if sending it fails, that I/O
    ///   error is returned instead.
    pub async fn accept(&self) -> io::Result<Session> {
        let (mut stream, peer) = self.listener.accept().await?;
        let password = read_hello(&mut stream).await?;

        if let Some(expected) = &self.expected_password {
            let expected_bytes = expected.as_bytes();
            // The password comes from an untrusted peer. Slice `!=` stops at the
            // first differing byte, so how long it takes depends on how much of a
            // guess is correct. Fold over every byte pair instead, so equal-length
            // inputs always do the same amount of comparison work.
            let matches = password.len() == expected_bytes.len()
                && password
                    .iter()
                    .zip(expected_bytes)
                    .fold(0u8, |diff, (a, b)| diff | (a ^ b))
                    == 0;
            if !matches {
                send_reject(&mut stream, "invalid password").await?;
                return Err(io::Error::new(
                    io::ErrorKind::PermissionDenied,
                    "invalid password",
                ));
            }
        }

        let token = generate_token();
        send_accept(&mut stream, &token).await?;
        Ok(Session { stream, token, peer })
    }
}

impl Session {
    /// Borrows the token this session attaches to its DATA packets.
    #[allow(dead_code)]
    pub fn token(&self) -> &[u8] {
        self.token.as_slice()
    }

    /// Address of the remote endpoint of this session.
    pub fn peer_addr(&self) -> SocketAddr {
        let Self { peer, .. } = self;
        *peer
    }

    /// Writes `payload` to the peer as one DATA packet carrying the session token.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying DATA write reports.
    pub async fn send(&mut self, payload: &[u8]) -> io::Result<()> {
        let Self { stream, token, .. } = self;
        send_data(stream, token.as_slice(), payload).await
    }

    /// Reads one DATA packet from the peer and returns its payload.
    ///
    /// The token carried by the packet must equal the session token.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying DATA read reports.
    pub async fn recv(&mut self) -> io::Result<Vec<u8>> {
        let Self { stream, token, .. } = self;
        read_data(stream, token.as_slice()).await
    }
}

// --- Protocol Helper Functions ---

/// Writes a HELLO message: header, big-endian `u16` password length, password bytes.
///
/// # Errors
///
/// `InvalidInput` when the password exceeds `MAX_PASSWORD_LEN` bytes (nothing is
/// written in that case); otherwise any I/O error from the stream.
async fn send_hello(stream: &mut TcpStream, password: &str) -> io::Result<()> {
    let secret = password.as_bytes();
    let secret_len = secret.len();
    if secret_len > MAX_PASSWORD_LEN {
        let err = io::Error::new(io::ErrorKind::InvalidInput, "password too long");
        return Err(err);
    }

    write_header(stream, TYPE_HELLO).await?;
    // Cannot truncate: the guard above keeps the length within MAX_PASSWORD_LEN.
    write_u16(stream, secret_len as u16).await?;
    stream.write_all(secret).await?;
    stream.flush().await
}

/// Reads a HELLO message and returns the raw password bytes it carries.
///
/// # Errors
///
/// `InvalidData` when the message is not a HELLO or announces a password longer
/// than `MAX_PASSWORD_LEN`; otherwise any error from reading the stream.
async fn read_hello(stream: &mut TcpStream) -> io::Result<Vec<u8>> {
    if read_header(stream).await? != TYPE_HELLO {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "expected HELLO"));
    }

    let password_len = usize::from(read_u16(stream).await?);
    // Check the announced length before allocating a buffer for it.
    if password_len > MAX_PASSWORD_LEN {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "password too long",
        ));
    }

    let mut password = vec![0u8; password_len];
    stream.read_exact(&mut password).await?;
    Ok(password)
}

/// Writes an ACCEPT message: header, big-endian `u16` token length, token bytes.
///
/// # Errors
///
/// `InvalidInput` when the token exceeds `MAX_TOKEN_LEN` bytes (nothing is
/// written in that case); otherwise any I/O error from the stream.
async fn send_accept(stream: &mut TcpStream, token: &[u8]) -> io::Result<()> {
    let token_len = token.len();
    if token_len > MAX_TOKEN_LEN {
        let err = io::Error::new(io::ErrorKind::InvalidInput, "token too long");
        return Err(err);
    }

    write_header(stream, TYPE_ACCEPT).await?;
    // Cannot truncate: the guard above keeps the length within MAX_TOKEN_LEN.
    write_u16(stream, token_len as u16).await?;
    stream.write_all(token).await?;
    stream.flush().await
}

/// Writes a REJECT message: header, big-endian `u16` reason length, reason bytes.
///
/// # Errors
///
/// `InvalidInput` when `reason` is longer than `u16::MAX` bytes (nothing is
/// written in that case); otherwise any I/O error from the stream.
async fn send_reject(stream: &mut TcpStream, reason: &str) -> io::Result<()> {
    let bytes = reason.as_bytes();
    // The length prefix is a u16. Without this guard the `as u16` cast below
    // would wrap for longer reasons while the full body was still written,
    // leaving the peer with a length that does not match the bytes that follow.
    if bytes.len() > usize::from(u16::MAX) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "reason too long",
        ));
    }

    write_header(stream, TYPE_REJECT).await?;
    write_u16(stream, bytes.len() as u16).await?;
    stream.write_all(bytes).await?;
    stream.flush().await?;
    Ok(())
}

/// Reads the server's reply to a HELLO.
///
/// On ACCEPT, returns the token carried by the message together with the
/// stream's peer address.
///
/// # Errors
///
/// - `PermissionDenied` carrying the server-supplied reason when the reply is a REJECT.
/// - `InvalidData` when the reply is neither ACCEPT nor REJECT, or the announced
///   token is longer than `MAX_TOKEN_LEN`.
/// - Any error from reading the stream or querying its peer address.
async fn read_accept(stream: &mut TcpStream) -> io::Result<(Vec<u8>, SocketAddr)> {
    match read_header(stream).await? {
        TYPE_ACCEPT => {}
        TYPE_REJECT => {
            let reason = read_string(stream, MAX_PAYLOAD_LEN).await?;
            return Err(io::Error::new(io::ErrorKind::PermissionDenied, reason));
        }
        _ => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "expected ACCEPT",
            ));
        }
    }

    let token_len = usize::from(read_u16(stream).await?);
    // Check the announced length before allocating a buffer for it.
    if token_len > MAX_TOKEN_LEN {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "token too long"));
    }

    let mut token = vec![0u8; token_len];
    stream.read_exact(&mut token).await?;

    let peer = stream.peer_addr()?;
    Ok((token, peer))
}

/// Writes a DATA message.
///
/// Layout after the header: big-endian `u16` token length, token bytes,
/// big-endian `u32` payload length, payload bytes.
///
/// # Errors
///
/// `InvalidInput` when the token exceeds `MAX_TOKEN_LEN` or the payload exceeds
/// `MAX_PAYLOAD_LEN` (nothing is written in either case); otherwise any I/O
/// error from the stream.
async fn send_data(stream: &mut TcpStream, token: &[u8], payload: &[u8]) -> io::Result<()> {
    let reject = |msg: &'static str| io::Error::new(io::ErrorKind::InvalidInput, msg);

    let (token_len, payload_len) = (token.len(), payload.len());
    if token_len > MAX_TOKEN_LEN {
        return Err(reject("token too long"));
    }
    if payload_len > MAX_PAYLOAD_LEN {
        return Err(reject("payload too large"));
    }

    write_header(stream, TYPE_DATA).await?;

    // Both casts are lossless: the guards above bound the lengths well below
    // the range of u16 and u32 respectively.
    write_u16(stream, token_len as u16).await?;
    stream.write_all(token).await?;
    write_u32(stream, payload_len as u32).await?;
    stream.write_all(payload).await?;

    stream.flush().await
}

/// Reads a DATA message and returns its payload.
///
/// The token carried by the message must equal `expected_token`; the check
/// happens before the payload length is read.
///
/// # Errors
///
/// - `InvalidData` when the message is not DATA, the announced token is longer
///   than `MAX_TOKEN_LEN`, or the announced payload is larger than `MAX_PAYLOAD_LEN`.
/// - `PermissionDenied` when the token differs from `expected_token`.
/// - Any error from reading the stream.
async fn read_data(stream: &mut TcpStream, expected_token: &[u8]) -> io::Result<Vec<u8>> {
    let malformed = |msg: &'static str| io::Error::new(io::ErrorKind::InvalidData, msg);

    if read_header(stream).await? != TYPE_DATA {
        return Err(malformed("expected DATA"));
    }

    let token_len = usize::from(read_u16(stream).await?);
    if token_len > MAX_TOKEN_LEN {
        return Err(malformed("token too long"));
    }

    let mut presented = vec![0u8; token_len];
    stream.read_exact(&mut presented).await?;
    if presented.as_slice() != expected_token {
        return Err(io::Error::new(
            io::ErrorKind::PermissionDenied,
            "invalid token",
        ));
    }

    let payload_len = read_u32(stream).await? as usize;
    // Check the announced size before allocating a buffer for it.
    if payload_len > MAX_PAYLOAD_LEN {
        return Err(malformed("payload too large"));
    }

    let mut payload = vec![0u8; payload_len];
    stream.read_exact(&mut payload).await?;
    Ok(payload)
}

/// Writes the 4-byte message header: the two `MAGIC` bytes, `VERSION`, then `msg_type`.
///
/// # Errors
///
/// Returns any I/O error from writing to the stream.
async fn write_header(stream: &mut TcpStream, msg_type: u8) -> io::Result<()> {
    // Assemble the header first so it goes out in one `write_all` call instead
    // of two back-to-back small writes on the unbuffered stream.
    let header = [MAGIC[0], MAGIC[1], VERSION, msg_type];
    stream.write_all(&header).await?;
    Ok(())
}

/// Reads a 4-byte message header and returns its message-type byte.
///
/// The magic bytes are validated first, then the version.
///
/// # Errors
///
/// `InvalidData` when the magic bytes or the version do not match; otherwise
/// any error from reading the stream.
async fn read_header(stream: &mut TcpStream) -> io::Result<u8> {
    let mut raw = [0u8; 4];
    stream.read_exact(&mut raw).await?;

    let [magic_a, magic_b, version, msg_type] = raw;
    if [magic_a, magic_b] != MAGIC {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid magic"));
    }
    if version != VERSION {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "unsupported version",
        ));
    }

    Ok(msg_type)
}

/// Writes `value` to the stream as two big-endian bytes.
async fn write_u16(stream: &mut TcpStream, value: u16) -> io::Result<()> {
    let encoded: [u8; 2] = value.to_be_bytes();
    stream.write_all(&encoded).await
}

/// Writes `value` to the stream as four big-endian bytes.
async fn write_u32(stream: &mut TcpStream, value: u32) -> io::Result<()> {
    let encoded: [u8; 4] = value.to_be_bytes();
    stream.write_all(&encoded).await
}

/// Reads two bytes from the stream and decodes them as a big-endian `u16`.
async fn read_u16(stream: &mut TcpStream) -> io::Result<u16> {
    let mut raw = [0u8; 2];
    stream.read_exact(&mut raw).await?;
    let value = u16::from_be_bytes(raw);
    Ok(value)
}

/// Reads four bytes from the stream and decodes them as a big-endian `u32`.
async fn read_u32(stream: &mut TcpStream) -> io::Result<u32> {
    let mut raw = [0u8; 4];
    stream.read_exact(&mut raw).await?;
    let value = u32::from_be_bytes(raw);
    Ok(value)
}

/// Reads a string prefixed by a big-endian `u16` byte length.
///
/// Bytes that are not valid UTF-8 are replaced rather than rejected.
///
/// # Errors
///
/// `InvalidData` when the announced length exceeds `max_len`; otherwise any
/// error from reading the stream.
async fn read_string(stream: &mut TcpStream, max_len: usize) -> io::Result<String> {
    let byte_len = usize::from(read_u16(stream).await?);
    if byte_len > max_len {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "string too long"));
    }

    let mut raw = vec![0u8; byte_len];
    stream.read_exact(&mut raw).await?;

    let text = String::from_utf8_lossy(&raw);
    Ok(text.into_owned())
}

/// Produces a fresh session token of `TOKEN_LEN` bytes filled by `OsRng`.
fn generate_token() -> Vec<u8> {
    let mut bytes = [0u8; TOKEN_LEN];
    rand::rngs::OsRng.fill_bytes(&mut bytes);
    bytes.to_vec()
}