oracledb_protocol/packet/
mod.rs1#![forbid(unsafe_code)]
2
3use crate::{ProtocolError, Result};

/// Length in bytes of the fixed header at the start of every TNS packet.
///
/// It is the sum of the header fields as laid out by `TnsPacket::encode`: the big-endian `u16`
/// packet length, a zeroed `u16`, the packet-type byte, the flags byte, and a second zeroed `u16`.
pub const TNS_HEADER_LEN: usize = 2 + 2 + 1 + 1 + 2;

/// A single TNS packet: the type and flags bytes from its header plus the payload after it.
///
/// The length field and the zeroed header bytes are not stored: `encode` derives them from the
/// payload, and `parse` checks the length and discards the rest.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TnsPacket {
    /// Packet type byte, stored at header offset 4.
    pub packet_type: u8,
    /// Flags byte, stored at header offset 5.
    pub flags: u8,
    /// Bytes following the header, up to the packet's declared length.
    pub payload: Vec<u8>,
}

14impl TnsPacket {
15 pub fn encode(&self) -> Result<Vec<u8>> {
16 let length = TNS_HEADER_LEN + self.payload.len();
17 let wire_length =
18 u16::try_from(length).map_err(|_| ProtocolError::PacketTooLarge { length })?;
19 let mut out = Vec::with_capacity(length);
20 out.extend_from_slice(&wire_length.to_be_bytes());
21 out.extend_from_slice(&0u16.to_be_bytes());
22 out.push(self.packet_type);
23 out.push(self.flags);
24 out.extend_from_slice(&0u16.to_be_bytes());
25 out.extend_from_slice(&self.payload);
26 Ok(out)
27 }
28
29 pub fn parse(input: &[u8]) -> Result<Self> {
30 let header = input
31 .get(..TNS_HEADER_LEN)
32 .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?;
33 let length_bytes = input
34 .get(..2)
35 .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?;
36 let declared = usize::from(u16::from_be_bytes(
37 length_bytes
38 .try_into()
39 .map_err(|_| ProtocolError::TruncatedHeader { got: input.len() })?,
40 ));
41 if declared < TNS_HEADER_LEN {
42 return Err(ProtocolError::InvalidPacketLength {
43 length: declared,
44 minimum: TNS_HEADER_LEN,
45 });
46 }
47 if declared > input.len() {
48 return Err(ProtocolError::IncompletePacket {
49 declared,
50 available: input.len(),
51 });
52 }
53
54 Ok(Self {
55 packet_type: *header
56 .get(4)
57 .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?,
58 flags: *header
59 .get(5)
60 .ok_or(ProtocolError::TruncatedHeader { got: input.len() })?,
61 payload: input
62 .get(TNS_HEADER_LEN..declared)
63 .ok_or(ProtocolError::IncompletePacket {
64 declared,
65 available: input.len(),
66 })?
67 .to_vec(),
68 })
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// The small "hello" packet shared by several tests (type 1, no flags).
    fn hello_packet() -> TnsPacket {
        TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: b"hello".to_vec(),
        }
    }

    #[test]
    fn packet_round_trips() {
        let packet = hello_packet();

        let encoded = packet.encode().expect("small packet should encode");
        assert_eq!(
            TnsPacket::parse(&encoded).expect("encoded packet should parse"),
            packet
        );
    }

    #[test]
    fn packet_encoder_writes_expected_header() {
        let encoded = hello_packet()
            .encode()
            .expect("small packet should encode");

        // Length 13 (8-byte header + 5-byte payload), zeroed u16, type 1, flags 0, zeroed u16.
        let mut expected = vec![0, 13, 0, 0, 1, 0, 0, 0];
        expected.extend_from_slice(b"hello");
        assert_eq!(encoded, expected);
    }

    #[test]
    fn packet_decoder_ignores_bytes_after_declared_length() {
        let packet = hello_packet();
        let mut bytes = packet.encode().expect("small packet should encode");
        bytes.extend_from_slice(&[0xDE, 0xAD]);

        assert_eq!(
            TnsPacket::parse(&bytes).expect("trailing bytes should not affect parsing"),
            packet
        );
    }

    #[test]
    fn packet_decoder_fails_closed_on_short_header() {
        assert!(matches!(
            TnsPacket::parse(&[0, 1, 2]),
            Err(ProtocolError::TruncatedHeader { got: 3 })
        ));
        assert!(matches!(
            TnsPacket::parse(&[]),
            Err(ProtocolError::TruncatedHeader { got: 0 })
        ));
    }

    #[test]
    fn packet_decoder_fails_closed_on_undersized_declared_length() {
        // A declared length of 4 is smaller than the header that carries it.
        assert!(matches!(
            TnsPacket::parse(&[0, 4, 0, 0, 1, 0, 0, 0]),
            Err(ProtocolError::InvalidPacketLength {
                length: 4,
                minimum: 8
            })
        ));
    }

    #[test]
    fn packet_decoder_fails_closed_on_incomplete_body() {
        let mut bytes = hello_packet()
            .encode()
            .expect("small packet should encode");
        *bytes
            .get_mut(1)
            .expect("encoded packet header should contain length byte") = 128;

        assert!(matches!(
            TnsPacket::parse(&bytes),
            Err(ProtocolError::IncompletePacket {
                declared: 128,
                available: 13
            })
        ));
    }

    #[test]
    fn packet_encoder_accepts_maximum_length() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: vec![0; usize::from(u16::MAX) - TNS_HEADER_LEN],
        };

        let encoded = packet.encode().expect("maximum-size packet should encode");
        assert_eq!(encoded.len(), usize::from(u16::MAX));
        assert_eq!(encoded.get(..2), Some(&[0xFF, 0xFF][..]));
        assert_eq!(
            TnsPacket::parse(&encoded).expect("maximum-size packet should parse"),
            packet
        );
    }

    #[test]
    fn packet_encoder_fails_closed_on_oversize_payload() {
        let packet = TnsPacket {
            packet_type: 1,
            flags: 0,
            payload: vec![0; usize::from(u16::MAX) + 1],
        };

        assert!(matches!(
            packet.encode(),
            Err(ProtocolError::PacketTooLarge { length }) if length == usize::from(u16::MAX) + 9
        ));
    }
}