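//! SASL client transport for Thrift: negotiates SASL over a TCP channel, then splits the
//! channel into framed read/write transports.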
use super::PlainClient;
use byteorder::{BigEndian, ReadBytesExt as _, WriteBytesExt as _};
use std::io::{Read as _, Write as _};
use thrift::transport::{
    ReadHalf, TFramedReadTransport, TFramedWriteTransport, TIoChannel as _, TTcpChannel, WriteHalf,
};

pub type TSaslClientReadTransport = TFramedReadTransport<ReadHalf<TTcpChannel>>;
pub type TSaslClientWriteTransport = TFramedWriteTransport<WriteHalf<TTcpChannel>>;

/// Status byte that prefixes every SASL negotiation frame.
///
/// Each discriminant is the byte sent on the wire (the sender casts with `as u8`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NegotiationStatus {
    /// First frame of a negotiation; its payload carries the mechanism name.
    Start = 1,
    /// Negotiation continues; the payload is the next challenge or response.
    Ok = 2,
    /// Failure status; a received frame with it is reported as a peer failure.
    Bad = 3,
    /// Failure status; a received frame with it is reported as a peer failure.
    Error = 4,
    /// The sender's SASL side has finished.
    Complete = 5,
}

impl NegotiationStatus {
    /// Returns the protocol name of this status (e.g. `"OK"`), used in error messages.
    ///
    /// Every name is a string literal, so the result is `&'static str` instead of
    /// being tied to the borrow of `self`; callers can keep it after the status is gone.
    fn name(&self) -> &'static str {
        match self {
            NegotiationStatus::Start => "START",
            NegotiationStatus::Ok => "OK",
            NegotiationStatus::Bad => "BAD",
            NegotiationStatus::Error => "ERROR",
            NegotiationStatus::Complete => "COMPLETE",
        }
    }
}

impl TryFrom<u8> for NegotiationStatus {
    type Error = anyhow::Error;

    fn try_from(value: u8) -> Result<Self, anyhow::Error> {
        match value {
            1 => Ok(Self::Start),
            2 => Ok(Self::Ok),
            3 => Ok(Self::Bad),
            4 => Ok(Self::Error),
            5 => Ok(Self::Complete),
            _ => Err(anyhow::anyhow!("Invalid status {}", value)),
        }
    }
}

/// One SASL negotiation frame received from the peer: a status byte followed by a
/// length-prefixed payload.
#[derive(Debug)]
struct SaslResponse {
    /// Decoded status byte of the frame.
    status: NegotiationStatus,
    /// Frame body; its length comes from the frame header.
    payload: Vec<u8>,
}

/// Client-side Thrift transport that runs a SASL negotiation over a TCP channel before
/// handing the channel out as framed read/write transports.
///
/// NOTE(review): the derived `Debug` also formats `sasl_client`; if `PlainClient`'s
/// `Debug` output includes the password, printing this value would leak it — confirm.
#[derive(Debug)]
pub struct TSaslClientTransport {
    /// TCP channel used for the handshake and later split into framed halves.
    channel: TTcpChannel,
    /// Mechanism name sent in the START frame; copied from `sasl_client` at construction.
    mechanism: String,
    /// SASL client that produces the initial response and answers server challenges.
    sasl_client: PlainClient,
}

impl TSaslClientTransport {
    pub fn new(channel: TTcpChannel, username: String, password: Vec<u8>) -> Self {
        let sasl_client = PlainClient::new(None, username, password);
        Self {
            channel,
            mechanism: sasl_client.mechanism_name().to_owned(),
            sasl_client,
        }
    }

    pub fn open(mut self) -> anyhow::Result<Self> {
        if self.sasl_client.is_complete() {
            Err(anyhow::anyhow!("SASL transport already open"))?
        }
        self.handle_sasl_start_message()?;
        let mut message = None;
        while !self.sasl_client.is_complete() {
            let _message = self.receive_sasl_message()?;
            match _message.status {
                NegotiationStatus::Ok => {
                    let response = self.sasl_client.step(&_message.payload)?;
                    self.send_sasl_message(
                        if self.sasl_client.is_complete() {
                            NegotiationStatus::Complete
                        } else {
                            NegotiationStatus::Ok
                        },
                        &response,
                    )?;
                }
                NegotiationStatus::Complete => (),
                _ => Err(anyhow::anyhow!(
                    "Expected COMPLETE or OK, got {}",
                    _message.status.name()
                ))?,
            }
            message = Some(_message);
        }
        if message.is_none() || message.is_some_and(|m| m.status == NegotiationStatus::Ok) {
            let _message = self.receive_sasl_message()?;
            if _message.status != NegotiationStatus::Complete {
                Err(anyhow::anyhow!(
                    "Expected SASL COMPLETE, but got {}",
                    _message.status.name()
                ))?
            }
        }
        Ok(self)
    }

    fn handle_sasl_start_message(&mut self) -> anyhow::Result<()> {
        let mut initial_response = vec![];
        if self.sasl_client.has_initial_response() {
            initial_response = self.sasl_client.step(&initial_response)?;
        }
        let mechanism = self.mechanism.to_owned();
        self.send_sasl_message(NegotiationStatus::Start, mechanism.as_bytes())?;
        self.send_sasl_message(
            if self.sasl_client.is_complete() {
                NegotiationStatus::Complete
            } else {
                NegotiationStatus::Ok
            },
            &initial_response,
        )?;
        self.channel.flush()?;
        Ok(())
    }

    fn send_sasl_message(
        &mut self,
        status: NegotiationStatus,
        payload: &[u8],
    ) -> anyhow::Result<()> {
        self.channel.write_u8(status as u8)?;
        self.channel.write_u32::<BigEndian>(payload.len() as u32)?;
        self.channel.write_all(payload)?;
        self.channel.flush()?;
        Ok(())
    }

    fn receive_sasl_message(&mut self) -> anyhow::Result<SaslResponse> {
        let status = NegotiationStatus::try_from(self.channel.read_u8()?)?;
        let payload_bytes = self.channel.read_i32::<BigEndian>()?;
        if !(0..=104857600).contains(&payload_bytes) {
            Err(anyhow::anyhow!(
                "Invalid payload header length: {}",
                payload_bytes
            ))?
        }
        let mut payload = vec![0; payload_bytes as usize];
        self.channel.read_exact(&mut payload)?;
        if ![NegotiationStatus::Bad, NegotiationStatus::Error].contains(&status) {
            Ok(SaslResponse { status, payload })
        } else {
            Err(anyhow::anyhow!(
                "Peer indicated failure: {}",
                String::from_utf8(payload)?
            ))?
        }
    }

    pub fn split(self) -> anyhow::Result<(TSaslClientReadTransport, TSaslClientWriteTransport)> {
        let (in_channel, out_channel) = self.channel.split()?;
        let read_transport = TFramedReadTransport::new(in_channel);
        let write_transport = TFramedWriteTransport::new(out_channel);
        Ok((read_transport, write_transport))
    }
}