1use std::io;
4use std::os::fd::{AsFd, AsRawFd};
5
6use bitfield_struct::bitfield;
7
8use crate::error::{Error, InvalidMessage, PeerMisbehaved, Result};
9use crate::ffi::{recv_tls_record, send_tls_control_message};
10use crate::tls::{
11 AlertDescription, AlertLevel, ContentType, HandshakeType, KeyUpdateRequest, Peer,
12 ProtocolVersion, TlsSession,
13};
14use crate::utils::Buffer;
15
16#[derive(Debug)]
17pub struct Context<C: TlsSession> {
19 state: State,
21
22 buffer: Buffer,
24
25 session: C,
27}
28
29impl<C: TlsSession> Context<C> {
30 pub fn new(session: C, buffer: Option<Buffer>) -> Self {
34 Self {
35 state: State::new(),
36 buffer: buffer.unwrap_or_default(),
37 session,
38 }
39 }
40
41 pub const fn state(&self) -> &State {
43 &self.state
44 }
45
46 pub const fn buffer(&self) -> &Buffer {
48 &self.buffer
49 }
50
51 pub const fn buffer_mut(&mut self) -> &mut Buffer {
53 &mut self.buffer
54 }
55
56 pub fn handle_io_error<S: AsFd>(&mut self, socket: &S, err: io::Error) -> io::Result<()> {
79 if err.raw_os_error() == Some(libc::EIO) {
80 crate::trace!("Received EIO, handling TLS control message");
81
82 self.handle_tls_control_message(socket)?;
83
84 return Ok(());
85 }
86
87 if err.kind() == io::ErrorKind::BrokenPipe {
88 crate::trace!("The underlying stream is closed (BrokenPipe)");
89
90 } else {
93 self.send_tls_alert(socket, AlertLevel::Fatal, AlertDescription::InternalError);
94 }
95
96 self.state.set_is_read_closed(true);
97 self.state.set_is_write_closed(true);
98
99 Err(err)
100 }
101
102 #[allow(clippy::too_many_lines)]
103 fn handle_tls_control_message<S: AsFd>(&mut self, socket: &S) -> Result<()> {
111 match recv_tls_record(socket.as_fd().as_raw_fd(), &mut self.buffer) {
112 Ok(ContentType::Handshake) => {
113 return self.handle_tls_control_message_handshake(socket);
114 }
115 Ok(ContentType::Alert) => {
116 if let &[level, desc] = self.buffer.unfilled_initialized() {
117 return self.handle_tls_control_message_alert(
118 socket,
119 AlertLevel::from_int(level),
120 AlertDescription::from_int(desc),
121 );
122 }
123
124 crate::error!(
128 "Invalid alert message received: {:?}, {:?}",
129 self.buffer.unfilled_initialized(),
130 self.buffer
131 );
132
133 return self.abort(
134 socket,
135 InvalidMessage::MessageTooLarge,
136 InvalidMessage::MessageTooLarge.description(),
137 );
138 }
139 Ok(ContentType::ChangeCipherSpec) => {
140 crate::warn!("Received unexpected ChangeCipherSpec message");
150
151 return self.abort(
152 socket,
153 PeerMisbehaved::IllegalMiddleboxChangeCipherSpec,
154 PeerMisbehaved::IllegalMiddleboxChangeCipherSpec.description(),
155 );
156 }
157 Ok(ContentType::ApplicationData) => {
158 crate::warn!(
161 "Received {} bytes of application data, unexpected usage",
162 self.buffer.unfilled_initialized().len()
163 );
164
165 self.buffer.set_filled_all();
166 }
167 Ok(_content_type) => {
168 crate::error!(
169 "Received unexpected TLS control message: content_type={_content_type:?}",
170 );
171
172 return self.abort(
173 socket,
174 InvalidMessage::InvalidContentType,
175 InvalidMessage::InvalidContentType.description(),
176 );
177 }
178 Err(error) => {
179 crate::error!("Failed to receive TLS control message: {error}");
180
181 return self.abort(
182 socket,
183 Error::General(error),
184 AlertDescription::InternalError,
185 );
186 }
187 }
188
189 Ok(())
190 }
191
192 #[allow(clippy::too_many_lines)]
193 fn handle_tls_control_message_handshake<S: AsFd>(&mut self, socket: &S) -> Result<()> {
195 let mut messages =
196 HandshakeMessagesIter::new(self.buffer.unfilled_initialized()).enumerate();
197
198 while let Some((idx, payload)) = messages.next() {
199 let Ok((handshake_type, payload)) = payload else {
200 return self.abort(
201 socket,
202 InvalidMessage::MessageTooShort,
203 InvalidMessage::MessageTooShort.description(),
204 );
205 };
206
207 match handshake_type {
208 HandshakeType::KeyUpdate
209 if self.session.protocol_version() == ProtocolVersion::TLSv1_3 =>
210 {
211 if idx != 0 || messages.next().is_some() {
212 crate::error!(
213 "RFC 8446, section 5.1: Handshake messages MUST NOT span key changes."
214 );
215
216 return self.abort(
217 socket,
218 PeerMisbehaved::KeyEpochWithPendingFragment,
219 PeerMisbehaved::KeyEpochWithPendingFragment.description(),
220 );
221 }
222
223 let &[payload] = payload else {
224 crate::error!(
225 "Received invalid KeyUpdate message, expected 1 byte payload, got: \
226 {:?}",
227 payload
228 );
229
230 return self.abort(
231 socket,
232 InvalidMessage::InvalidKeyUpdate,
233 InvalidMessage::InvalidKeyUpdate.description(),
234 );
235 };
236
237 let key_update_request = KeyUpdateRequest::from_int(payload);
238
239 if let Err(error) = self
240 .session
241 .update_rx_secret()
242 .and_then(|secret| secret.set(socket))
243 {
244 return self.abort(socket, error, AlertDescription::InternalError);
245 }
246
247 match key_update_request {
248 KeyUpdateRequest::UpdateNotRequested => {}
249 KeyUpdateRequest::UpdateRequested => {
250 if let Err(error) = send_tls_control_message(
252 socket.as_fd().as_raw_fd(),
253 ContentType::Handshake,
254 &mut [
255 HandshakeType::KeyUpdate.to_int(), 0,
257 0,
258 1, KeyUpdateRequest::UpdateNotRequested.to_int(),
260 ],
261 )
262 .map_err(Error::KeyUpdateFailed)
263 {
264 crate::error!("Failed to send KeyUpdate message: {error}");
266
267 return self.abort(socket, error, AlertDescription::InternalError);
268 }
269
270 if let Err(error) = self
271 .session
272 .update_tx_secret()
273 .and_then(|secret| secret.set(socket))
274 {
275 crate::error!("Failed to update TX secret: {error}");
276
277 return self.abort(socket, error, AlertDescription::InternalError);
278 }
279 }
280 KeyUpdateRequest::Unknown(_payload) => {
281 crate::warn!(
282 "Received KeyUpdate message with unknown request value: {_payload}"
283 );
284
285 return self.abort(
286 socket,
287 InvalidMessage::InvalidKeyUpdate,
288 InvalidMessage::InvalidKeyUpdate.description(),
289 );
290 }
291 }
292 }
293 HandshakeType::NewSessionTicket
294 if self.session.protocol_version() == ProtocolVersion::TLSv1_3 =>
295 {
296 if self.session.peer() != Peer::Client {
297 crate::warn!("TLS 1.2 peer sent a TLS 1.3 NewSessionTicket message");
298
299 return self.abort(
300 socket,
301 InvalidMessage::UnexpectedMessage(
302 "TLS 1.2 peer sent a TLS 1.3 NewSessionTicket message",
303 ),
304 AlertDescription::UnexpectedMessage,
305 );
306 }
307
308 if let Err(error) = self
309 .session
310 .handle_new_session_ticket(payload)
311 {
312 return self.abort(socket, error, AlertDescription::InternalError);
313 }
314 }
315 _ if self.session.protocol_version() == ProtocolVersion::TLSv1_3 => {
316 crate::error!(
317 "Unexpected handshake message for a TLS 1.3 connection: \
318 typ={handshake_type:?}",
319 );
320
321 return self.abort(
322 socket,
323 InvalidMessage::UnexpectedMessage(
324 "expected KeyUpdate or NewSessionTicket message",
325 ),
326 AlertDescription::UnexpectedMessage,
327 );
328 }
329 _ => {
330 crate::error!(
331 "Unexpected handshake message: ver={:?}, typ={handshake_type:?}",
332 self.session.protocol_version()
333 );
334
335 return self.abort(
336 socket,
337 InvalidMessage::UnexpectedMessage(
338 "handshake messages are not expected on TLS 1.2 connections",
339 ),
340 AlertDescription::UnexpectedMessage,
341 );
342 }
343 }
344 }
345
346 Ok(())
347 }
348
349 fn handle_tls_control_message_alert<S: AsFd>(
351 &mut self,
352 socket: &S,
353 level: AlertLevel,
354 desc: AlertDescription,
355 ) -> Result<()> {
356 match desc {
357 AlertDescription::CloseNotify
358 if self.session.protocol_version() == ProtocolVersion::TLSv1_2 =>
359 {
360 crate::trace!("Received `close_notify` alert, should shutdown the TLS stream");
366
367 self.shutdown(socket);
368 }
369 AlertDescription::CloseNotify => {
370 crate::trace!(
382 "Received `close_notify` alert, should shutdown the read side of TLS stream"
383 );
384
385 self.state.set_is_read_closed(true);
386 }
387 _ if self.session.protocol_version() == ProtocolVersion::TLSv1_2
388 && level == AlertLevel::Warning =>
389 {
390 crate::warn!("Received non fatal alert, level={level:?}, desc: {desc:?}");
395 }
396 _ => {
397 crate::error!("Received fatal alert, desc: {desc:?}");
401
402 self.state.set_is_read_closed(true);
403 self.state.set_is_write_closed(true);
404
405 return Err(Error::AlertReceived(desc));
406 }
407 }
408
409 Ok(())
410 }
411
412 pub fn shutdown<S: AsFd>(&mut self, socket: &S) {
415 crate::trace!("Shutting down the TLS stream with `close_notify` alert...");
416
417 self.send_tls_alert(socket, AlertLevel::Warning, AlertDescription::CloseNotify);
418
419 if self.session.protocol_version() == ProtocolVersion::TLSv1_2 {
420 self.state.set_is_read_closed(true);
422 }
423
424 self.state.set_is_write_closed(true);
425 }
426
427 fn abort<T, S, E, D>(&mut self, socket: &S, error: E, description: D) -> Result<T>
429 where
430 S: AsFd,
431 E: Into<Error>,
432 D: Into<AlertDescription>,
433 {
434 crate::trace!("Aborting the TLS stream with fatal alert...");
435
436 self.send_tls_alert(socket, AlertLevel::Fatal, description.into());
437
438 self.state.set_is_read_closed(true);
439 self.state.set_is_write_closed(true);
440
441 Err(error.into())
442 }
443
444 fn send_tls_alert<S: AsFd>(
446 &mut self,
447 socket: &S,
448 level: AlertLevel,
449 description: AlertDescription,
450 ) {
451 if !self.state.is_write_closed() {
452 let _ = send_tls_control_message(
453 socket.as_fd().as_raw_fd(),
454 ContentType::Alert,
455 &mut [level.to_int(), description.to_int()],
456 )
457 .inspect_err(|_e| {
458 crate::trace!("Failed to send alert: {_e}");
459 });
460 }
461 }
462}
463
/// Closure flags for the two directions of the stream, packed into a single
/// `u8` by `bitfield_struct` (bit order follows field order).
#[bitfield(u8)]
pub struct State {
    /// Whether the read side of the stream has been closed.
    pub is_read_closed: bool,

    /// Whether the write side of the stream has been closed; once set, no
    /// further alerts are sent.
    pub is_write_closed: bool,

    /// Unused padding bits that fill out the `u8`.
    #[bits(6)]
    _reserved: u8,
}
476
477impl State {
478 #[must_use]
481 pub const fn is_closed(&self) -> bool {
482 self.is_read_closed() && self.is_write_closed()
483 }
484}
485
/// Iterator over the handshake messages packed into one record payload.
struct HandshakeMessagesIter<'buf> {
    /// `Ok(Some(rest))` while bytes remain, `Ok(None)` once everything has been
    /// consumed, and `Err(())` after a malformed message was found.
    inner: Result<Option<&'buf [u8]>, ()>,
}
489
490impl<'a> HandshakeMessagesIter<'a> {
491 const fn new(payloads: &'a [u8]) -> Self {
492 Self {
493 inner: Ok(Some(payloads)),
494 }
495 }
496}
497
498impl<'a> Iterator for HandshakeMessagesIter<'a> {
499 type Item = Result<(HandshakeType, &'a [u8]), ()>;
500
501 fn next(&mut self) -> Option<Self::Item> {
502 match self.inner {
503 Ok(None) => None,
504 Ok(Some(&[typ, a, b, c, ref rest @ ..])) => {
505 let handshake_type = HandshakeType::from_int(typ);
506 let payload_length = u32::from_be_bytes([0, a, b, c]) as usize;
507
508 let Some((payload, rest)) = rest.split_at_checked(payload_length) else {
509 crate::error!(
510 "Received truncated handshake message payload, expected: \
511 {payload_length}, actual: {}",
512 rest.len()
513 );
514
515 self.inner = Err(());
516
517 return Some(Err(()));
518 };
519
520 if rest.is_empty() {
521 self.inner = Ok(None);
522 } else {
523 self.inner = Ok(Some(rest));
524 }
525
526 Some(Ok((handshake_type, payload)))
527 }
528 Ok(Some(_truncated)) => {
529 crate::error!("Received truncated handshake message payload: {_truncated:?}");
530
531 self.inner = Err(());
532
533 Some(Err(()))
534 }
535 Err(()) => Some(Err(())),
536 }
537 }
538}