1use std::io::{Cursor, Read, Write};
9
10use crate::types::{status_code::StatusCode, *};
11
12use super::{
13 message_chunk_info::ChunkInfo,
14 secure_channel::SecureChannel,
15 security_header::{
16 AsymmetricSecurityHeader, SecurityHeader, SequenceHeader, SymmetricSecurityHeader,
17 },
18 tcp_types::{
19 CHUNK_FINAL, CHUNK_FINAL_ERROR, CHUNK_INTERMEDIATE, CHUNK_MESSAGE,
20 CLOSE_SECURE_CHANNEL_MESSAGE, MIN_CHUNK_SIZE, OPEN_SECURE_CHANNEL_MESSAGE,
21 },
22};
23
/// Size in bytes of the fixed chunk header that precedes every chunk: the
/// 3-byte message type code, the 1-byte chunk type, then two `u32` fields
/// (message size and secure channel id).
pub const MESSAGE_CHUNK_HEADER_SIZE: usize = 3 + 1 + 4 + 4;
26
/// The kind of message a chunk belongs to, carried in the three-byte
/// message type code at the start of the chunk header.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MessageChunkType {
    /// A regular message chunk (encoded as `CHUNK_MESSAGE`).
    Message,
    /// An open secure channel chunk (encoded as `OPEN_SECURE_CHANNEL_MESSAGE`).
    OpenSecureChannel,
    /// A close secure channel chunk (encoded as `CLOSE_SECURE_CHANNEL_MESSAGE`).
    CloseSecureChannel,
}

impl MessageChunkType {
    /// Returns `true` only for [`MessageChunkType::OpenSecureChannel`].
    pub fn is_open_secure_channel(&self) -> bool {
        matches!(self, MessageChunkType::OpenSecureChannel)
    }
}
39
/// The chunk type carried in the single byte that follows the message type
/// code in the chunk header.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MessageIsFinalType {
    /// A non-final chunk (encoded as `CHUNK_INTERMEDIATE`).
    Intermediate,
    /// The final chunk of a message (encoded as `CHUNK_FINAL`).
    Final,
    /// A final chunk flagged as an error (encoded as `CHUNK_FINAL_ERROR`).
    FinalError,
}
49
50#[derive(Debug, Clone, PartialEq)]
51pub struct MessageChunkHeader {
52 pub message_type: MessageChunkType,
54 pub is_final: MessageIsFinalType,
56 pub message_size: u32,
58 pub secure_channel_id: u32,
60}
61
62impl BinaryEncoder<MessageChunkHeader> for MessageChunkHeader {
63 fn byte_len(&self) -> usize {
64 MESSAGE_CHUNK_HEADER_SIZE
65 }
66
67 fn encode<S: Write>(&self, stream: &mut S) -> EncodingResult<usize> {
68 let message_type = match self.message_type {
69 MessageChunkType::Message => CHUNK_MESSAGE,
70 MessageChunkType::OpenSecureChannel => OPEN_SECURE_CHANNEL_MESSAGE,
71 MessageChunkType::CloseSecureChannel => CLOSE_SECURE_CHANNEL_MESSAGE,
72 };
73
74 let is_final = match self.is_final {
75 MessageIsFinalType::Intermediate => CHUNK_INTERMEDIATE,
76 MessageIsFinalType::Final => CHUNK_FINAL,
77 MessageIsFinalType::FinalError => CHUNK_FINAL_ERROR,
78 };
79
80 let mut size = 0;
81 size += process_encode_io_result(stream.write(message_type))?;
82 size += write_u8(stream, is_final)?;
83 size += write_u32(stream, self.message_size)?;
84 size += write_u32(stream, self.secure_channel_id)?;
85 assert_eq!(size, self.byte_len());
86 Ok(size)
87 }
88
89 fn decode<S: Read>(stream: &mut S, _: &DecodingOptions) -> EncodingResult<Self> {
90 let mut message_type_code = [0u8; 3];
91 process_decode_io_result(stream.read_exact(&mut message_type_code))?;
92 let message_type = if message_type_code == CHUNK_MESSAGE {
93 MessageChunkType::Message
94 } else if message_type_code == OPEN_SECURE_CHANNEL_MESSAGE {
95 MessageChunkType::OpenSecureChannel
96 } else if message_type_code == CLOSE_SECURE_CHANNEL_MESSAGE {
97 MessageChunkType::CloseSecureChannel
98 } else {
99 error!("Invalid message code");
100 return Err(StatusCode::BadDecodingError);
101 };
102
103 let chunk_type_code = read_u8(stream)?;
104 let is_final = match chunk_type_code {
105 CHUNK_FINAL => MessageIsFinalType::Final,
106 CHUNK_INTERMEDIATE => MessageIsFinalType::Intermediate,
107 CHUNK_FINAL_ERROR => MessageIsFinalType::FinalError,
108 _ => {
109 error!("Invalid chunk type");
110 return Err(StatusCode::BadDecodingError);
111 }
112 };
113
114 let message_size = read_u32(stream)?;
115 let secure_channel_id = read_u32(stream)?;
116
117 Ok(MessageChunkHeader {
118 message_type,
119 is_final,
120 message_size,
121 secure_channel_id,
122 })
123 }
124}
125
126impl MessageChunkHeader {}
127
/// One message chunk held as raw bytes.
///
/// `data` starts with the `MESSAGE_CHUNK_HEADER_SIZE`-byte chunk header,
/// followed by the rest of the chunk.
#[derive(Debug)]
pub struct MessageChunk {
    /// The complete encoded chunk, header included.
    pub data: Vec<u8>,
}
136
137impl BinaryEncoder<MessageChunk> for MessageChunk {
138 fn byte_len(&self) -> usize {
139 self.data.len()
140 }
141
142 fn encode<S: Write>(&self, stream: &mut S) -> EncodingResult<usize> {
143 stream.write(&self.data).map_err(|_| {
144 error!("Encoding error while writing to stream");
145 StatusCode::BadEncodingError
146 })
147 }
148
149 fn decode<S: Read>(
150 in_stream: &mut S,
151 decoding_options: &DecodingOptions,
152 ) -> EncodingResult<Self> {
153 let chunk_header =
155 MessageChunkHeader::decode(in_stream, decoding_options).map_err(|err| {
156 error!("Cannot decode chunk header {:?}", err);
157 StatusCode::BadCommunicationError
158 })?;
159
160 let message_size = chunk_header.message_size as usize;
161 if decoding_options.max_message_size > 0 && message_size > decoding_options.max_message_size
162 {
163 Err(StatusCode::BadTcpMessageTooLarge)
165 } else {
166 let data = vec![0u8; message_size];
168 let mut stream = Cursor::new(data);
169
170 let chunk_header_size = chunk_header.encode(&mut stream)?;
172 assert_eq!(chunk_header_size, MESSAGE_CHUNK_HEADER_SIZE);
173
174 let mut data = stream.into_inner();
176
177 let _ = in_stream.read_exact(&mut data[chunk_header_size..]);
179
180 Ok(MessageChunk { data })
181 }
182 }
183}
184
185impl MessageChunk {
186 pub fn new(
187 sequence_number: u32,
188 request_id: u32,
189 message_type: MessageChunkType,
190 is_final: MessageIsFinalType,
191 secure_channel: &SecureChannel,
192 data: &[u8],
193 ) -> Result<MessageChunk, StatusCode> {
194 let security_header = secure_channel.make_security_header(message_type);
196 let sequence_header = SequenceHeader {
197 sequence_number,
198 request_id,
199 };
200
201 let mut message_size = MESSAGE_CHUNK_HEADER_SIZE;
203 message_size += security_header.byte_len();
204 message_size += sequence_header.byte_len();
205 message_size += data.len();
206
207 trace!(
208 "Creating a chunk with a size of {}, data excluding padding & signature",
209 message_size
210 );
211 let secure_channel_id = secure_channel.secure_channel_id();
212 let chunk_header = MessageChunkHeader {
213 message_type,
214 is_final,
215 message_size: message_size as u32,
216 secure_channel_id,
217 };
218
219 let mut stream = Cursor::new(vec![0u8; message_size]);
220 let _ = chunk_header.encode(&mut stream);
222 let _ = security_header.encode(&mut stream);
224 let _ = sequence_header.encode(&mut stream);
226 let _ = stream.write(data);
228
229 Ok(MessageChunk {
230 data: stream.into_inner(),
231 })
232 }
233
234 pub fn body_size_from_message_size(
238 message_type: MessageChunkType,
239 secure_channel: &SecureChannel,
240 message_size: usize,
241 ) -> Result<usize, ()> {
242 if message_size < MIN_CHUNK_SIZE {
243 error!(
244 "message size {} is less than minimum allowed by the spec",
245 message_size
246 );
247 Err(())
248 } else {
249 let security_header = secure_channel.make_security_header(message_type);
250
251 let mut data_size = MESSAGE_CHUNK_HEADER_SIZE;
252 data_size += security_header.byte_len();
253 data_size += (SequenceHeader {
254 sequence_number: 0,
255 request_id: 0,
256 })
257 .byte_len();
258
259 let signature_size = secure_channel.signature_size(&security_header);
261 data_size += secure_channel
262 .padding_size(&security_header, 1, signature_size)
263 .0;
264
265 data_size += signature_size;
267
268 Ok(message_size - data_size)
270 }
271 }
272
273 pub fn message_header(
274 &self,
275 decoding_options: &DecodingOptions,
276 ) -> Result<MessageChunkHeader, StatusCode> {
277 let mut stream = Cursor::new(&self.data);
279 MessageChunkHeader::decode(&mut stream, decoding_options)
280 }
281
282 pub fn security_header(
283 &self,
284 decoding_options: &DecodingOptions,
285 ) -> Result<SecurityHeader, StatusCode> {
286 let mut stream = Cursor::new(&self.data);
288 let message_header = MessageChunkHeader::decode(&mut stream, decoding_options)?;
289 let security_header = if message_header.message_type == MessageChunkType::OpenSecureChannel
290 {
291 SecurityHeader::Asymmetric(AsymmetricSecurityHeader::decode(
292 &mut stream,
293 decoding_options,
294 )?)
295 } else {
296 SecurityHeader::Symmetric(SymmetricSecurityHeader::decode(
297 &mut stream,
298 decoding_options,
299 )?)
300 };
301 Ok(security_header)
302 }
303
304 pub fn is_open_secure_channel(&self, decoding_options: &DecodingOptions) -> bool {
305 if let Ok(message_header) = self.message_header(decoding_options) {
306 message_header.message_type.is_open_secure_channel()
307 } else {
308 false
309 }
310 }
311
312 pub fn chunk_info(
313 &self,
314 secure_channel: &SecureChannel,
315 ) -> std::result::Result<ChunkInfo, StatusCode> {
316 ChunkInfo::new(self, secure_channel)
317 }
318}