Skip to main content

oracledb_protocol/packet/mod.rs

1#![forbid(unsafe_code)]
2
3use crate::{ProtocolError, Result};
4
/// Size in bytes of the fixed header that precedes every TNS packet payload.
///
/// Written out field by field to mirror the layout used by this module:
/// a 2-byte length, 2 bytes written as zero, a 1-byte packet type,
/// a 1-byte flags field, and 2 more bytes written as zero.
pub const TNS_HEADER_LEN: usize = 2 + 2 + 1 + 1 + 2;

/// One TNS packet: the header fields this crate keeps, plus the raw body.
///
/// The length field is not stored; it is derived from `payload` when the
/// packet is encoded and used to delimit `payload` when it is parsed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TnsPacket {
    /// Packet type byte (header offset 4).
    pub packet_type: u8,
    /// Flags byte (header offset 5).
    pub flags: u8,
    /// Body bytes that follow the header.
    pub payload: Vec<u8>,
}
13
14impl TnsPacket {
15    pub fn encode(&self) -> Result<Vec<u8>> {
16        let length = TNS_HEADER_LEN + self.payload.len();
17        let wire_length =
18            u16::try_from(length).map_err(|_| ProtocolError::PacketTooLarge { length })?;
19        let mut out = Vec::with_capacity(length);
20        out.extend_from_slice(&wire_length.to_be_bytes());
21        out.extend_from_slice(&0u16.to_be_bytes());
22        out.push(self.packet_type);
23        out.push(self.flags);
24        out.extend_from_slice(&0u16.to_be_bytes());
25        out.extend_from_slice(&self.payload);
26        Ok(out)
27    }
28
29    pub fn parse(input: &[u8]) -> Result<Self> {
30        let header = input
31            .get(..TNS_HEADER_LEN)
32            .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?;
33        let length_bytes = input
34            .get(..2)
35            .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?;
36        let declared = usize::from(u16::from_be_bytes(
37            length_bytes
38                .try_into()
39                .map_err(|_| ProtocolError::TruncatedHeader { got: input.len() })?,
40        ));
41        if declared < TNS_HEADER_LEN {
42            return Err(ProtocolError::InvalidPacketLength {
43                length: declared,
44                minimum: TNS_HEADER_LEN,
45            });
46        }
47        if declared > input.len() {
48            return Err(ProtocolError::IncompletePacket {
49                declared,
50                available: input.len(),
51            });
52        }
53
54        Ok(Self {
55            packet_type: *header
56                .get(4)
57                .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?,
58            flags: *header
59                .get(5)
60                .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?,
61            payload: input
62                .get(TNS_HEADER_LEN..declared)
63                .ok_or(ProtocolError::IncompletePacket {
64                    declared,
65                    available: input.len(),
66                })?
67                .to_vec(),
68        })
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn packet_round_trips() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: b"hello".to_vec(),
        };

        let encoded = packet.encode().expect("small packet should encode");
        assert_eq!(
            TnsPacket::parse(&encoded).expect("encoded packet should parse"),
            packet
        );
    }

    #[test]
    fn header_only_packet_round_trips() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: Vec::new(),
        };

        let encoded = packet.encode().expect("header-only packet should encode");
        assert_eq!(encoded, [0, 8, 0, 0, 1, 0, 0, 0]);
        assert_eq!(
            TnsPacket::parse(&encoded).expect("header-only packet should parse"),
            packet
        );
    }

    #[test]
    fn packet_encoder_writes_big_endian_header() {
        // Pins the exact wire layout: length (BE u16), two zero bytes, type,
        // flags, two zero bytes, then the payload.
        let encoded = TnsPacket {
            packet_type: 6,
            flags: 0x20,
            payload: b"hello".to_vec(),
        }
        .encode()
        .expect("small packet should encode");

        assert_eq!(
            encoded,
            [0, 13, 0, 0, 6, 0x20, 0, 0, b'h', b'e', b'l', b'l', b'o']
        );
    }

    #[test]
    fn packet_decoder_ignores_bytes_after_declared_length() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: b"hello".to_vec(),
        };
        let mut bytes = packet.encode().expect("small packet should encode");
        bytes.extend_from_slice(b"start of the next packet");

        assert_eq!(
            TnsPacket::parse(&bytes).expect("leading packet should parse"),
            packet
        );
    }

    #[test]
    fn packet_decoder_fails_closed_on_short_header() {
        assert!(matches!(
            TnsPacket::parse(&[0, 1, 2]),
            Err(ProtocolError::TruncatedHeader { got: 3 })
        ));
    }

    #[test]
    fn packet_decoder_fails_closed_on_length_below_header() {
        let mut bytes = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: b"hello".to_vec(),
        }
        .encode()
        .expect("small packet should encode");
        // Declare a total length (4) that cannot even cover the 8-byte header.
        *bytes
            .get_mut(1)
            .expect("encoded packet header should contain length byte") = 4;

        assert!(matches!(
            TnsPacket::parse(&bytes),
            Err(ProtocolError::InvalidPacketLength {
                length: 4,
                minimum: 8
            })
        ));
    }

    #[test]
    fn packet_decoder_fails_closed_on_incomplete_body() {
        let mut bytes = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: b"hello".to_vec(),
        }
        .encode()
        .expect("small packet should encode");
        *bytes
            .get_mut(1)
            .expect("encoded packet header should contain length byte") = 128;

        assert!(matches!(
            TnsPacket::parse(&bytes),
            Err(ProtocolError::IncompletePacket { .. })
        ));
    }

    #[test]
    fn packet_encoder_accepts_maximum_length() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: vec![0; usize::from(u16::MAX) - TNS_HEADER_LEN],
        };

        let encoded = packet
            .encode()
            .expect("packet at the u16 length limit should encode");
        assert_eq!(encoded.len(), usize::from(u16::MAX));
        assert_eq!(encoded.get(..2), Some(&[0xff, 0xff][..]));
        assert_eq!(
            TnsPacket::parse(&encoded).expect("maximum-length packet should parse"),
            packet
        );
    }

    #[test]
    fn packet_encoder_fails_closed_one_byte_over_maximum() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: vec![0; usize::from(u16::MAX) - TNS_HEADER_LEN + 1],
        };

        assert!(matches!(
            packet.encode(),
            Err(ProtocolError::PacketTooLarge { length: 65_536 })
        ));
    }

    #[test]
    fn packet_encoder_fails_closed_on_oversize_payload() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: vec![0; usize::from(u16::MAX) + 1],
        };

        assert!(matches!(
            packet.encode(),
            Err(ProtocolError::PacketTooLarge { .. })
        ));
    }
}