// oracledb-protocol 0.5.1
//
// Sans-I/O Oracle TNS/TTC protocol core for the oracledb crate.
#![forbid(unsafe_code)]

use crate::wire::ProtocolLimits;
use crate::{ProtocolError, Result};

/// Size in bytes of the fixed header that precedes every TNS packet payload.
///
/// [`TnsPacket::encode`] writes exactly this many header bytes, and
/// [`TnsPacket::parse_with_limits`] requires at least this many before reading the length field.
pub const TNS_HEADER_LEN: usize = 8;

/// One TNS packet: the header fields this module interprets, plus the raw payload.
///
/// Only the packet-type byte (header offset 4) and the flags byte (header offset 5) are stored.
/// The header's 16-bit big-endian length (offsets 0..2) is derived from the payload size when
/// encoding and validated when parsing. The remaining header bytes (offsets 2..4 and 6..8) are
/// written as zero by `encode` and ignored by `parse`.
/// NOTE(review): presumably these are the TNS checksum fields — confirm against the protocol spec.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TnsPacket {
    /// Packet type byte, stored at header offset 4.
    pub packet_type: u8,
    /// Flags byte, stored at header offset 5.
    pub flags: u8,
    /// Bytes following the header, up to the packet's declared total length.
    pub payload: Vec<u8>,
}

impl TnsPacket {
    /// Serializes the packet into its wire form: the fixed header followed by the payload.
    ///
    /// Header bytes written:
    /// - `0..2`: total packet length (header + payload) as a big-endian `u16`;
    /// - `2..4`: zero;
    /// - `4`: [`packet_type`](Self::packet_type);
    /// - `5`: [`flags`](Self::flags);
    /// - `6..8`: zero.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::PacketTooLarge`] if header plus payload exceeds `u16::MAX`
    /// bytes, because the total length must fit in the 16-bit header field.
    pub fn encode(&self) -> Result<Vec<u8>> {
        let length = TNS_HEADER_LEN + self.payload.len();
        let [len_hi, len_lo] = u16::try_from(length)
            .map_err(|_| ProtocolError::PacketTooLarge { length })?
            .to_be_bytes();
        // A fixed-size array keeps the header layout in one place. Changing
        // `TNS_HEADER_LEN` then causes a compile error here instead of a silent mismatch.
        let header: [u8; TNS_HEADER_LEN] =
            [len_hi, len_lo, 0, 0, self.packet_type, self.flags, 0, 0];

        let mut out = Vec::with_capacity(length);
        out.extend_from_slice(&header);
        out.extend_from_slice(&self.payload);
        Ok(out)
    }

    /// Parses one packet from the start of `input` using [`ProtocolLimits::DEFAULT`].
    ///
    /// # Errors
    ///
    /// Same as [`parse_with_limits`](Self::parse_with_limits).
    pub fn parse(input: &[u8]) -> Result<Self> {
        Self::parse_with_limits(input, ProtocolLimits::DEFAULT)
    }

    /// Parses one packet from the start of `input`, enforcing `limits`.
    ///
    /// Only the first `declared` bytes are consumed, where `declared` is the header's length
    /// field. Any bytes after them are ignored. The consumed length is therefore
    /// `TNS_HEADER_LEN + packet.payload.len()`.
    ///
    /// The declared length is checked against `limits` before the payload is copied.
    ///
    /// # Errors
    ///
    /// Checks run in this order:
    /// 1. Any error from `ProtocolLimits::validate` if `limits` itself is rejected.
    /// 2. [`ProtocolError::TruncatedHeader`] if `input` is shorter than [`TNS_HEADER_LEN`].
    /// 3. [`ProtocolError::InvalidPacketLength`] if the declared length is smaller than the
    ///    header.
    /// 4. The error from `ProtocolLimits::check_packet_bytes` if the declared length exceeds the
    ///    packet limit. This module's tests observe a `ProtocolError::ResourceLimit`.
    /// 5. [`ProtocolError::IncompletePacket`] if `input` holds fewer bytes than declared.
    pub fn parse_with_limits(input: &[u8], limits: ProtocolLimits) -> Result<Self> {
        let limits = limits.validate()?;

        // Borrow the fixed-size header once. Every header field below is read from this array,
        // so there are no further bounds checks or unreachable error paths for it.
        let header: &[u8; TNS_HEADER_LEN] = input
            .get(..TNS_HEADER_LEN)
            .and_then(|bytes| bytes.try_into().ok())
            .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?;
        let [len_hi, len_lo, _, _, packet_type, flags, _, _] = *header;
        let declared = usize::from(u16::from_be_bytes([len_hi, len_lo]));

        if declared < TNS_HEADER_LEN {
            return Err(ProtocolError::InvalidPacketLength {
                length: declared,
                minimum: TNS_HEADER_LEN,
            });
        }
        // Enforce the resource limit before allocating the payload copy.
        limits.check_packet_bytes(declared)?;

        // `declared >= TNS_HEADER_LEN` holds here. This range fails only when `input` is shorter
        // than the declared length, i.e. the packet has not fully arrived.
        let payload = input
            .get(TNS_HEADER_LEN..declared)
            .ok_or(ProtocolError::IncompletePacket {
                declared,
                available: input.len(),
            })?
            .to_vec();

        Ok(Self {
            packet_type,
            flags,
            payload,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the small packet shared by most tests: type 1, no flags, payload `b"hello"`.
    fn hello_packet() -> TnsPacket {
        TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: b"hello".to_vec(),
        }
    }

    /// Wire encoding of [`hello_packet`].
    fn encoded_hello() -> Vec<u8> {
        hello_packet()
            .encode()
            .expect("small packet should encode")
    }

    #[test]
    fn packet_round_trips() {
        let packet = hello_packet();

        let encoded = packet.encode().expect("small packet should encode");
        assert_eq!(
            TnsPacket::parse(&encoded).expect("encoded packet should parse"),
            packet
        );
    }

    #[test]
    fn empty_payload_round_trips() {
        let packet = TnsPacket {
            packet_type: 2,
            flags: 0,
            payload: Vec::new(),
        };

        let encoded = packet.encode().expect("header-only packet should encode");
        assert_eq!(encoded.len(), TNS_HEADER_LEN);
        assert_eq!(
            TnsPacket::parse(&encoded).expect("header-only packet should parse"),
            packet
        );
    }

    #[test]
    fn packet_encoder_writes_expected_header_layout() {
        // Distinct type and flags values pin their header offsets (4 and 5).
        let packet = TnsPacket {
            packet_type: 6,
            flags: 0x20,
            payload: vec![0xAB],
        };

        assert_eq!(
            packet.encode().expect("small packet should encode"),
            [0, 9, 0, 0, 6, 0x20, 0, 0, 0xAB]
        );
    }

    #[test]
    fn packet_decoder_fails_closed_on_short_header() {
        assert!(matches!(
            TnsPacket::parse(&[0, 1, 2]),
            Err(ProtocolError::TruncatedHeader { got: 3 })
        ));
    }

    #[test]
    fn packet_decoder_rejects_length_smaller_than_header() {
        let mut bytes = encoded_hello();
        // Declare a total length of 7, one byte shorter than the header itself.
        *bytes
            .get_mut(1)
            .expect("encoded packet header should contain length byte") = 7;

        assert!(matches!(
            TnsPacket::parse(&bytes),
            Err(ProtocolError::InvalidPacketLength {
                length: 7,
                minimum: 8,
            })
        ));
    }

    #[test]
    fn packet_decoder_fails_closed_on_incomplete_body() {
        let mut bytes = encoded_hello();
        *bytes
            .get_mut(1)
            .expect("encoded packet header should contain length byte") = 128;

        assert!(matches!(
            TnsPacket::parse(&bytes),
            Err(ProtocolError::IncompletePacket { .. })
        ));
    }

    #[test]
    fn packet_decoder_consumes_only_declared_length() {
        let packet = hello_packet();
        let mut bytes = packet.encode().expect("small packet should encode");
        // Bytes past the declared length (e.g. the start of the next packet) are ignored.
        bytes.extend_from_slice(b"trailing");

        assert_eq!(
            TnsPacket::parse(&bytes).expect("packet with trailing bytes should parse"),
            packet
        );
    }

    #[test]
    fn packet_decoder_uses_protocol_limits_before_copying_payload() {
        let bytes = encoded_hello();
        let limits = ProtocolLimits {
            max_packet_bytes: bytes.len() - 1,
            max_frame_bytes: bytes.len() - 1,
            max_response_bytes: bytes.len() - 1,
            ..ProtocolLimits::DEFAULT
        };

        assert!(matches!(
            TnsPacket::parse_with_limits(&bytes, limits),
            Err(ProtocolError::ResourceLimit {
                limit: "packet_bytes",
                observed,
                maximum,
            }) if observed == bytes.len() && maximum == bytes.len() - 1
        ));
    }

    #[test]
    fn packet_encoder_accepts_maximum_length() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: vec![0; usize::from(u16::MAX) - TNS_HEADER_LEN],
        };

        let encoded = packet
            .encode()
            .expect("maximum-length packet should encode");
        assert_eq!(encoded.len(), usize::from(u16::MAX));
        assert_eq!(encoded.get(..2), Some([0xFF, 0xFF].as_slice()));
    }

    #[test]
    fn packet_encoder_fails_closed_on_oversize_payload() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: vec![0; usize::from(u16::MAX) + 1],
        };

        assert!(matches!(
            packet.encode(),
            Err(ProtocolError::PacketTooLarge { .. })
        ));
    }
}