Skip to main content

oracledb_protocol/packet/
mod.rs

1#![forbid(unsafe_code)]
2
3use crate::wire::ProtocolLimits;
4use crate::{ProtocolError, Result};
5
/// Size in bytes of the fixed header that precedes every packet payload.
///
/// Made up of a 2-byte length, 2 bytes written as zero, a 1-byte packet type,
/// a 1-byte flags field, and 2 more bytes written as zero.
pub const TNS_HEADER_LEN: usize = 2 + 2 + 1 + 1 + 2;
7
/// One decoded (or to-be-encoded) TNS packet.
///
/// Only the packet type, flags and payload are stored; the wire length field is
/// derived from `payload` when encoding and consumed when parsing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TnsPacket {
    /// Packet type byte (header offset 4).
    pub packet_type: u8,
    /// Flags byte (header offset 5).
    pub flags: u8,
    /// Everything after the fixed-size header.
    pub payload: Vec<u8>,
}
14
15impl TnsPacket {
16    pub fn encode(&self) -> Result<Vec<u8>> {
17        let length = TNS_HEADER_LEN + self.payload.len();
18        let wire_length =
19            u16::try_from(length).map_err(|_| ProtocolError::PacketTooLarge { length })?;
20        let mut out = Vec::with_capacity(length);
21        out.extend_from_slice(&wire_length.to_be_bytes());
22        out.extend_from_slice(&0u16.to_be_bytes());
23        out.push(self.packet_type);
24        out.push(self.flags);
25        out.extend_from_slice(&0u16.to_be_bytes());
26        out.extend_from_slice(&self.payload);
27        Ok(out)
28    }
29
30    pub fn parse(input: &[u8]) -> Result<Self> {
31        Self::parse_with_limits(input, ProtocolLimits::DEFAULT)
32    }
33
34    pub fn parse_with_limits(input: &[u8], limits: ProtocolLimits) -> Result<Self> {
35        let limits = limits.validate()?;
36        let header = input
37            .get(..TNS_HEADER_LEN)
38            .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?;
39        let length_bytes = input
40            .get(..2)
41            .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?;
42        let declared = usize::from(u16::from_be_bytes(
43            length_bytes
44                .try_into()
45                .map_err(|_| ProtocolError::TruncatedHeader { got: input.len() })?,
46        ));
47        if declared < TNS_HEADER_LEN {
48            return Err(ProtocolError::InvalidPacketLength {
49                length: declared,
50                minimum: TNS_HEADER_LEN,
51            });
52        }
53        limits.check_packet_bytes(declared)?;
54        if declared > input.len() {
55            return Err(ProtocolError::IncompletePacket {
56                declared,
57                available: input.len(),
58            });
59        }
60
61        Ok(Self {
62            packet_type: *header
63                .get(4)
64                .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?,
65            flags: *header
66                .get(5)
67                .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?,
68            payload: input
69                .get(TNS_HEADER_LEN..declared)
70                .ok_or(ProtocolError::IncompletePacket {
71                    declared,
72                    available: input.len(),
73                })?
74                .to_vec(),
75        })
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn packet_round_trips() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: b"hello".to_vec(),
        };

        let encoded = packet.encode().expect("small packet should encode");
        assert_eq!(
            TnsPacket::parse(&encoded).expect("encoded packet should parse"),
            packet
        );
    }

    #[test]
    fn packet_encoder_writes_big_endian_length_type_and_flags() {
        let encoded = TnsPacket {
            packet_type: 6,
            flags: 0x20,
            payload: b"hi".to_vec(),
        }
        .encode()
        .expect("small packet should encode");

        // Length = 8-byte header + 2-byte payload; bytes 2..4 and 6..8 are zero.
        assert_eq!(encoded, vec![0, 10, 0, 0, 6, 0x20, 0, 0, b'h', b'i']);
    }

    #[test]
    fn packet_encoder_accepts_maximum_length_payload() {
        let encoded = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: vec![0; usize::from(u16::MAX) - TNS_HEADER_LEN],
        }
        .encode()
        .expect("packet of exactly u16::MAX bytes should encode");

        assert_eq!(encoded.len(), usize::from(u16::MAX));
        assert_eq!(encoded.get(..2), Some(&[0xFF, 0xFF][..]));
    }

    #[test]
    fn packet_decoder_fails_closed_on_short_header() {
        assert!(matches!(
            TnsPacket::parse(&[0, 1, 2]),
            Err(ProtocolError::TruncatedHeader { got: 3 })
        ));
    }

    #[test]
    fn packet_decoder_rejects_declared_length_shorter_than_header() {
        for declared in [0u8, 7] {
            assert!(matches!(
                TnsPacket::parse(&[0, declared, 0, 0, 1, 0, 0, 0]),
                Err(ProtocolError::InvalidPacketLength {
                    length,
                    minimum: TNS_HEADER_LEN,
                }) if length == usize::from(declared)
            ));
        }
    }

    #[test]
    fn packet_decoder_fails_closed_on_incomplete_body() {
        let mut bytes = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: b"hello".to_vec(),
        }
        .encode()
        .expect("small packet should encode");
        *bytes
            .get_mut(1)
            .expect("encoded packet header should contain length byte") = 128;

        assert!(matches!(
            TnsPacket::parse(&bytes),
            Err(ProtocolError::IncompletePacket { .. })
        ));
    }

    #[test]
    fn packet_decoder_uses_protocol_limits_before_copying_payload() {
        let bytes = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: b"hello".to_vec(),
        }
        .encode()
        .expect("small packet should encode");
        let limits = ProtocolLimits {
            max_packet_bytes: bytes.len() - 1,
            max_frame_bytes: bytes.len() - 1,
            max_response_bytes: bytes.len() - 1,
            ..ProtocolLimits::DEFAULT
        };

        assert!(matches!(
            TnsPacket::parse_with_limits(&bytes, limits),
            Err(ProtocolError::ResourceLimit {
                limit: "packet_bytes",
                observed,
                maximum,
            }) if observed == bytes.len() && maximum == bytes.len() - 1
        ));
    }

    #[test]
    fn packet_encoder_fails_closed_on_oversize_payload() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: vec![0; usize::from(u16::MAX) + 1],
        };

        assert!(matches!(
            packet.encode(),
            Err(ProtocolError::PacketTooLarge { .. })
        ));
    }
}