1use alloc::boxed::Box;
2use alloc::vec::Vec;
3
4use pki_types::CertificateDer;
5
6use crate::enums::{AlertDescription, ContentType, HandshakeType, ProtocolVersion};
7use crate::error::{Error, InvalidMessage, PeerMisbehaved};
8#[cfg(feature = "logging")]
9use crate::log::{debug, warn};
10use crate::msgs::alert::AlertMessagePayload;
11use crate::msgs::base::Payload;
12use crate::msgs::enums::{AlertLevel, KeyUpdateRequest};
13use crate::msgs::fragmenter::MessageFragmenter;
14use crate::msgs::handshake::CertificateChain;
15use crate::msgs::message::{
16 Message, MessagePayload, OutboundChunks, OutboundOpaqueMessage, OutboundPlainMessage,
17 PlainMessage,
18};
19use crate::suites::{PartiallyExtractedSecrets, SupportedCipherSuite};
20#[cfg(feature = "tls12")]
21use crate::tls12::ConnectionSecrets;
22use crate::unbuffered::{EncryptError, InsufficientSizeError};
23use crate::vecbuf::ChunkVecBuffer;
24use crate::{quic, record_layer};
25
/// Connection state shared by client and server connections.
///
/// Holds the negotiated parameters, the record layer, and the buffers of
/// received plaintext and outgoing TLS data.
pub struct CommonState {
    /// Protocol version agreed with the peer; `None` until one is negotiated.
    pub(crate) negotiated_version: Option<ProtocolVersion>,
    /// Which end of the connection this state belongs to.
    pub(crate) side: Side,
    /// Record protection: encrypter/decrypter and write-sequence accounting.
    pub(crate) record_layer: record_layer::RecordLayer,
    /// Negotiated cipher suite, if any (exposed via `negotiated_cipher_suite`).
    pub(crate) suite: Option<SupportedCipherSuite>,
    /// ALPN protocol exposed via `alpn_protocol`, if any.
    pub(crate) alpn_protocol: Option<Vec<u8>>,
    /// When false, `check_aligned_handshake` refuses a key change with
    /// `PeerMisbehaved::KeyEpochWithPendingFragment`.
    pub(crate) aligned_handshake: bool,
    /// Whether application data may be sent; until then, plaintext written by
    /// the caller is buffered instead of encrypted.
    pub(crate) may_send_application_data: bool,
    /// Whether application data may be received. Together with
    /// `may_send_application_data` this decides `is_handshaking`.
    pub(crate) may_receive_application_data: bool,
    /// Must be set whenever `send_early_plaintext` is used (debug-asserted).
    pub(crate) early_traffic: bool,
    /// Set once a fatal alert has been queued; queueing a second is a bug.
    sent_fatal_alert: bool,
    /// Set when the peer's `close_notify` alert has been accepted.
    pub(crate) has_received_close_notify: bool,
    /// NOTE(review): not read in this file; presumably tracks EOF on the
    /// underlying transport — confirm.
    #[cfg(feature = "std")]
    pub(crate) has_seen_eof: bool,
    /// NOTE(review): not used in this file; presumably counts received
    /// middlebox-compatibility ChangeCipherSpec records — confirm.
    pub(crate) received_middlebox_ccs: u8,
    /// Certificate chain presented by the peer, exposed via `peer_certificates`.
    pub(crate) peer_certificates: Option<CertificateChain<'static>>,
    /// Splits outgoing messages into records of the configured maximum size.
    message_fragmenter: MessageFragmenter,
    /// Decrypted application data waiting to be read by the caller.
    pub(crate) received_plaintext: ChunkVecBuffer,
    /// Encoded TLS records waiting to be written to the transport.
    pub(crate) sendable_tls: ChunkVecBuffer,
    /// An already-encrypted, encoded KeyUpdate record awaiting transmission.
    queued_key_update_message: Option<Vec<u8>>,

    /// Transport in use: `Tcp` frames messages as TLS records, `Quic` hands
    /// them to `quic` instead (see `send_msg`).
    pub(crate) protocol: Protocol,
    /// QUIC-side state: the last alert description and queued handshake bytes.
    pub(crate) quic: quic::Quic,
    /// NOTE(review): not read in this file; presumably gates secret
    /// extraction — confirm.
    pub(crate) enable_secret_extraction: bool,
}
54
55impl CommonState {
56 pub(crate) fn new(side: Side) -> Self {
57 Self {
58 negotiated_version: None,
59 side,
60 record_layer: record_layer::RecordLayer::new(),
61 suite: None,
62 alpn_protocol: None,
63 aligned_handshake: true,
64 may_send_application_data: false,
65 may_receive_application_data: false,
66 early_traffic: false,
67 sent_fatal_alert: false,
68 has_received_close_notify: false,
69 #[cfg(feature = "std")]
70 has_seen_eof: false,
71 received_middlebox_ccs: 0,
72 peer_certificates: None,
73 message_fragmenter: MessageFragmenter::default(),
74 received_plaintext: ChunkVecBuffer::new(Some(DEFAULT_RECEIVED_PLAINTEXT_LIMIT)),
75 sendable_tls: ChunkVecBuffer::new(Some(DEFAULT_BUFFER_LIMIT)),
76 queued_key_update_message: None,
77 protocol: Protocol::Tcp,
78 quic: quic::Quic::default(),
79 enable_secret_extraction: false,
80 }
81 }
82
83 pub fn wants_write(&self) -> bool {
87 !self.sendable_tls.is_empty()
88 }
89
90 pub fn is_handshaking(&self) -> bool {
98 !(self.may_send_application_data && self.may_receive_application_data)
99 }
100
101 pub fn peer_certificates(&self) -> Option<&[CertificateDer<'static>]> {
117 self.peer_certificates.as_deref()
118 }
119
120 pub fn alpn_protocol(&self) -> Option<&[u8]> {
126 self.get_alpn_protocol()
127 }
128
129 pub fn negotiated_cipher_suite(&self) -> Option<SupportedCipherSuite> {
133 self.suite
134 }
135
136 pub fn protocol_version(&self) -> Option<ProtocolVersion> {
140 self.negotiated_version
141 }
142
143 pub(crate) fn is_tls13(&self) -> bool {
144 matches!(self.negotiated_version, Some(ProtocolVersion::TLSv1_3))
145 }
146
147 pub(crate) fn process_main_protocol<Data>(
148 &mut self,
149 msg: Message,
150 mut state: Box<dyn State<Data>>,
151 data: &mut Data,
152 sendable_plaintext: Option<&mut ChunkVecBuffer>,
153 ) -> Result<Box<dyn State<Data>>, Error> {
154 if self.may_receive_application_data && !self.is_tls13() {
157 let reject_ty = match self.side {
158 Side::Client => HandshakeType::HelloRequest,
159 Side::Server => HandshakeType::ClientHello,
160 };
161 if msg.is_handshake_type(reject_ty) {
162 self.send_warning_alert(AlertDescription::NoRenegotiation);
163 return Ok(state);
164 }
165 }
166
167 let mut cx = Context {
168 common: self,
169 data,
170 sendable_plaintext,
171 };
172 match state.handle(&mut cx, msg) {
173 Ok(next) => {
174 state = next.into_owned();
175 Ok(state)
176 }
177 Err(e @ Error::InappropriateMessage { .. })
178 | Err(e @ Error::InappropriateHandshakeMessage { .. }) => {
179 Err(self.send_fatal_alert(AlertDescription::UnexpectedMessage, e))
180 }
181 Err(e) => Err(e),
182 }
183 }
184
185 pub(crate) fn write_plaintext(
186 &mut self,
187 payload: OutboundChunks<'_>,
188 outgoing_tls: &mut [u8],
189 ) -> Result<usize, EncryptError> {
190 if payload.is_empty() {
191 return Ok(0);
192 }
193
194 let fragments = self
195 .message_fragmenter
196 .fragment_payload(
197 ContentType::ApplicationData,
198 ProtocolVersion::TLSv1_2,
199 payload.clone(),
200 );
201
202 let remaining_encryptions = self
203 .record_layer
204 .remaining_write_seq()
205 .ok_or(EncryptError::EncryptExhausted)?;
206
207 if fragments.len() as u64 > remaining_encryptions.get() {
208 return Err(EncryptError::EncryptExhausted);
209 }
210
211 self.check_required_size(
212 outgoing_tls,
213 self.queued_key_update_message
214 .as_deref(),
215 fragments,
216 )?;
217
218 let fragments = self
219 .message_fragmenter
220 .fragment_payload(
221 ContentType::ApplicationData,
222 ProtocolVersion::TLSv1_2,
223 payload,
224 );
225
226 let opt_msg = self.queued_key_update_message.take();
227 let written = self.write_fragments(outgoing_tls, opt_msg, fragments);
228
229 Ok(written)
230 }
231
232 pub(crate) fn check_aligned_handshake(&mut self) -> Result<(), Error> {
237 if !self.aligned_handshake {
238 Err(self.send_fatal_alert(
239 AlertDescription::UnexpectedMessage,
240 PeerMisbehaved::KeyEpochWithPendingFragment,
241 ))
242 } else {
243 Ok(())
244 }
245 }
246
247 pub(crate) fn send_msg_encrypt(&mut self, m: PlainMessage) {
250 let iter = self
251 .message_fragmenter
252 .fragment_message(&m);
253 for m in iter {
254 self.send_single_fragment(m);
255 }
256 }
257
258 fn send_appdata_encrypt(&mut self, payload: OutboundChunks<'_>, limit: Limit) -> usize {
260 let len = match limit {
265 #[cfg(feature = "std")]
266 Limit::Yes => self
267 .sendable_tls
268 .apply_limit(payload.len()),
269 Limit::No => payload.len(),
270 };
271
272 let iter = self
273 .message_fragmenter
274 .fragment_payload(
275 ContentType::ApplicationData,
276 ProtocolVersion::TLSv1_2,
277 payload.split_at(len).0,
278 );
279 for m in iter {
280 self.send_single_fragment(m);
281 }
282
283 len
284 }
285
286 fn send_single_fragment(&mut self, m: OutboundPlainMessage) {
287 if self
290 .record_layer
291 .wants_close_before_encrypt()
292 {
293 self.send_close_notify();
294 }
295
296 if self.record_layer.encrypt_exhausted() {
299 return;
300 }
301
302 let em = self.record_layer.encrypt_outgoing(m);
303 self.queue_tls_message(em);
304 }
305
306 fn send_plain_non_buffering(&mut self, payload: OutboundChunks<'_>, limit: Limit) -> usize {
307 debug_assert!(self.may_send_application_data);
308 debug_assert!(self.record_layer.is_encrypting());
309
310 if payload.is_empty() {
311 return 0;
313 }
314
315 self.send_appdata_encrypt(payload, limit)
316 }
317
318 pub(crate) fn start_outgoing_traffic(
322 &mut self,
323 sendable_plaintext: &mut Option<&mut ChunkVecBuffer>,
324 ) {
325 self.may_send_application_data = true;
326 if let Some(sendable_plaintext) = sendable_plaintext {
327 self.flush_plaintext(sendable_plaintext);
328 }
329 }
330
331 pub(crate) fn start_traffic(&mut self, sendable_plaintext: &mut Option<&mut ChunkVecBuffer>) {
335 self.may_receive_application_data = true;
336 self.start_outgoing_traffic(sendable_plaintext);
337 }
338
339 fn flush_plaintext(&mut self, sendable_plaintext: &mut ChunkVecBuffer) {
342 if !self.may_send_application_data {
343 return;
344 }
345
346 while let Some(buf) = sendable_plaintext.pop() {
347 self.send_plain_non_buffering(buf.as_slice().into(), Limit::No);
348 }
349 }
350
351 fn queue_tls_message(&mut self, m: OutboundOpaqueMessage) {
353 self.sendable_tls.append(m.encode());
354 }
355
356 pub(crate) fn send_msg(&mut self, m: Message, must_encrypt: bool) {
358 {
359 if let Protocol::Quic = self.protocol {
360 if let MessagePayload::Alert(alert) = m.payload {
361 self.quic.alert = Some(alert.description);
362 } else {
363 debug_assert!(
364 matches!(m.payload, MessagePayload::Handshake { .. }),
365 "QUIC uses TLS for the cryptographic handshake only"
366 );
367 let mut bytes = Vec::new();
368 m.payload.encode(&mut bytes);
369 self.quic
370 .hs_queue
371 .push_back((must_encrypt, bytes));
372 }
373 return;
374 }
375 }
376 if !must_encrypt {
377 let msg = &m.into();
378 let iter = self
379 .message_fragmenter
380 .fragment_message(msg);
381 for m in iter {
382 self.queue_tls_message(m.to_unencrypted_opaque());
383 }
384 } else {
385 self.send_msg_encrypt(m.into());
386 }
387 }
388
389 pub(crate) fn take_received_plaintext(&mut self, bytes: Payload) {
390 self.received_plaintext
391 .append(bytes.into_vec());
392 }
393
394 #[cfg(feature = "tls12")]
395 pub(crate) fn start_encryption_tls12(&mut self, secrets: &ConnectionSecrets, side: Side) {
396 let (dec, enc) = secrets.make_cipher_pair(side);
397 self.record_layer
398 .prepare_message_encrypter(enc);
399 self.record_layer
400 .prepare_message_decrypter(dec);
401 }
402
403 pub(crate) fn missing_extension(&mut self, why: PeerMisbehaved) -> Error {
404 self.send_fatal_alert(AlertDescription::MissingExtension, why)
405 }
406
407 fn send_warning_alert(&mut self, desc: AlertDescription) {
408 warn!("Sending warning alert {:?}", desc);
409 self.send_warning_alert_no_log(desc);
410 }
411
412 pub(crate) fn process_alert(&mut self, alert: &AlertMessagePayload) -> Result<(), Error> {
413 if let AlertLevel::Unknown(_) = alert.level {
415 return Err(self.send_fatal_alert(
416 AlertDescription::IllegalParameter,
417 Error::AlertReceived(alert.description),
418 ));
419 }
420
421 if alert.description == AlertDescription::CloseNotify {
424 self.has_received_close_notify = true;
425 return Ok(());
426 }
427
428 let err = Error::AlertReceived(alert.description);
431 if alert.level == AlertLevel::Warning {
432 if self.is_tls13() && alert.description != AlertDescription::UserCanceled {
433 return Err(self.send_fatal_alert(AlertDescription::DecodeError, err));
434 } else {
435 warn!("TLS alert warning received: {:?}", alert);
436 return Ok(());
437 }
438 }
439
440 Err(err)
441 }
442
443 pub(crate) fn send_cert_verify_error_alert(&mut self, err: Error) -> Error {
444 self.send_fatal_alert(
445 match &err {
446 Error::InvalidCertificate(e) => e.clone().into(),
447 Error::PeerMisbehaved(_) => AlertDescription::IllegalParameter,
448 _ => AlertDescription::HandshakeFailure,
449 },
450 err,
451 )
452 }
453
454 pub(crate) fn send_fatal_alert(
455 &mut self,
456 desc: AlertDescription,
457 err: impl Into<Error>,
458 ) -> Error {
459 debug_assert!(!self.sent_fatal_alert);
460 let m = Message::build_alert(AlertLevel::Fatal, desc);
461 self.send_msg(m, self.record_layer.is_encrypting());
462 self.sent_fatal_alert = true;
463 err.into()
464 }
465
466 pub fn send_close_notify(&mut self) {
472 debug!("Sending warning alert {:?}", AlertDescription::CloseNotify);
473 self.send_warning_alert_no_log(AlertDescription::CloseNotify);
474 }
475
476 pub(crate) fn eager_send_close_notify(
477 &mut self,
478 outgoing_tls: &mut [u8],
479 ) -> Result<usize, EncryptError> {
480 debug_assert!(self.record_layer.is_encrypting());
481
482 let m = Message::build_alert(AlertLevel::Warning, AlertDescription::CloseNotify).into();
483
484 let iter = self
485 .message_fragmenter
486 .fragment_message(&m);
487
488 self.check_required_size(outgoing_tls, None, iter)?;
489
490 debug!("Sending warning alert {:?}", AlertDescription::CloseNotify);
491
492 let iter = self
493 .message_fragmenter
494 .fragment_message(&m);
495
496 let written = self.write_fragments(outgoing_tls, None, iter);
497
498 Ok(written)
499 }
500
501 fn send_warning_alert_no_log(&mut self, desc: AlertDescription) {
502 let m = Message::build_alert(AlertLevel::Warning, desc);
503 self.send_msg(m, self.record_layer.is_encrypting());
504 }
505
506 fn check_required_size<'a>(
507 &self,
508 outgoing_tls: &mut [u8],
509 opt_msg: Option<&[u8]>,
510 fragments: impl Iterator<Item = OutboundPlainMessage<'a>>,
511 ) -> Result<(), EncryptError> {
512 let mut required_size = 0;
513 if let Some(message) = opt_msg {
514 required_size += message.len();
515 }
516
517 for m in fragments {
518 required_size += m.encoded_len(&self.record_layer);
519 }
520
521 if required_size > outgoing_tls.len() {
522 return Err(EncryptError::InsufficientSize(InsufficientSizeError {
523 required_size,
524 }));
525 }
526
527 Ok(())
528 }
529
530 fn write_fragments<'a>(
531 &mut self,
532 outgoing_tls: &mut [u8],
533 opt_msg: Option<Vec<u8>>,
534 fragments: impl Iterator<Item = OutboundPlainMessage<'a>>,
535 ) -> usize {
536 let mut written = 0;
537
538 if let Some(message) = opt_msg {
539 let len = message.len();
540 outgoing_tls[written..written + len].copy_from_slice(&message);
541 written += len;
542 }
543
544 for m in fragments {
545 let em = self
546 .record_layer
547 .encrypt_outgoing(m)
548 .encode();
549
550 let len = em.len();
551 outgoing_tls[written..written + len].copy_from_slice(&em);
552 written += len;
553 }
554
555 written
556 }
557
558 pub(crate) fn set_max_fragment_size(&mut self, new: Option<usize>) -> Result<(), Error> {
559 self.message_fragmenter
560 .set_max_fragment_size(new)
561 }
562
563 pub(crate) fn get_alpn_protocol(&self) -> Option<&[u8]> {
564 self.alpn_protocol
565 .as_ref()
566 .map(AsRef::as_ref)
567 }
568
569 pub fn wants_read(&self) -> bool {
579 self.received_plaintext.is_empty()
586 && !self.has_received_close_notify
587 && (self.may_send_application_data || self.sendable_tls.is_empty())
588 }
589
590 pub(crate) fn current_io_state(&self) -> IoState {
591 IoState {
592 tls_bytes_to_write: self.sendable_tls.len(),
593 plaintext_bytes_to_read: self.received_plaintext.len(),
594 peer_has_closed: self.has_received_close_notify,
595 }
596 }
597
598 pub(crate) fn is_quic(&self) -> bool {
599 self.protocol == Protocol::Quic
600 }
601
602 pub(crate) fn should_update_key(
603 &mut self,
604 key_update_request: &KeyUpdateRequest,
605 ) -> Result<bool, Error> {
606 match key_update_request {
607 KeyUpdateRequest::UpdateNotRequested => Ok(false),
608 KeyUpdateRequest::UpdateRequested => Ok(self.queued_key_update_message.is_none()),
609 _ => Err(self.send_fatal_alert(
610 AlertDescription::IllegalParameter,
611 InvalidMessage::InvalidKeyUpdate,
612 )),
613 }
614 }
615
616 pub(crate) fn enqueue_key_update_notification(&mut self) {
617 let message = PlainMessage::from(Message::build_key_update_notify());
618 self.queued_key_update_message = Some(
619 self.record_layer
620 .encrypt_outgoing(message.borrow_outbound())
621 .encode(),
622 );
623 }
624}
625
#[cfg(feature = "std")]
impl CommonState {
    /// Accepts application plaintext from the caller, subject to the buffer
    /// limit, and returns how many bytes were taken.
    ///
    /// A pending key-update record is moved into `sendable_tls` first. The
    /// data is then encrypted right away, or copied into `sendable_plaintext`
    /// if sending is not yet permitted.
    pub(crate) fn buffer_plaintext(
        &mut self,
        payload: OutboundChunks<'_>,
        sendable_plaintext: &mut ChunkVecBuffer,
    ) -> usize {
        self.perhaps_write_key_update();
        self.send_plain(payload, Limit::Yes, sendable_plaintext)
    }

    /// Encrypts and queues early data, subject to the buffer limit, and
    /// returns the number of bytes accepted.
    ///
    /// Valid only while `early_traffic` is set and the record layer encrypts
    /// (both debug-asserted). Empty input produces no record.
    pub(crate) fn send_early_plaintext(&mut self, data: &[u8]) -> usize {
        debug_assert!(self.early_traffic);
        debug_assert!(self.record_layer.is_encrypting());

        if data.is_empty() {
            // Never emit an empty record.
            0
        } else {
            self.send_appdata_encrypt(data.into(), Limit::Yes)
        }
    }

    /// Sends `payload` immediately if application data may be sent. Otherwise
    /// it is copied into `sendable_plaintext` to be flushed once traffic
    /// starts. Returns the number of bytes accepted under `limit`.
    fn send_plain(
        &mut self,
        payload: OutboundChunks<'_>,
        limit: Limit,
        sendable_plaintext: &mut ChunkVecBuffer,
    ) -> usize {
        if self.may_send_application_data {
            return self.send_plain_non_buffering(payload, limit);
        }

        // Still handshaking: hold the plaintext until traffic can flow.
        match limit {
            Limit::Yes => sendable_plaintext.append_limited_copy(payload),
            Limit::No => sendable_plaintext.append(payload.to_vec()),
        }
    }

    /// Moves a queued, already-encrypted key-update record (if any) into
    /// `sendable_tls`.
    pub(crate) fn perhaps_write_key_update(&mut self) {
        if let Some(record) = self.queued_key_update_message.take() {
            self.sendable_tls.append(record);
        }
    }
}
684
/// Snapshot of a connection's pending I/O, as produced by
/// `CommonState::current_io_state`.
#[derive(Debug, Eq, PartialEq)]
pub struct IoState {
    /// Bytes of encoded TLS data waiting to be written to the transport.
    tls_bytes_to_write: usize,
    /// Bytes of decrypted plaintext waiting to be read by the caller.
    plaintext_bytes_to_read: usize,
    /// Whether the peer's `close_notify` alert has been received.
    peer_has_closed: bool,
}
695
696impl IoState {
697 pub fn tls_bytes_to_write(&self) -> usize {
702 self.tls_bytes_to_write
703 }
704
705 pub fn plaintext_bytes_to_read(&self) -> usize {
708 self.plaintext_bytes_to_read
709 }
710
711 pub fn peer_has_closed(&self) -> bool {
720 self.peer_has_closed
721 }
722}
723
/// One state of the connection's handshake/traffic state machine.
///
/// `Data` is connection-specific data passed through `Context`.
pub(crate) trait State<Data>: Send + Sync {
    /// Processes `message`, consuming this state and returning the next one.
    ///
    /// The returned state may borrow from the message (lifetime `'m`); use
    /// `into_owned` to obtain a `'static` state that can be stored.
    fn handle<'m>(
        self: Box<Self>,
        cx: &mut Context<'_, Data>,
        message: Message<'m>,
    ) -> Result<Box<dyn State<Data> + 'm>, Error>
    where
        Self: 'm;

    /// Writes exported keying material into `_output`.
    ///
    /// The default fails with `Error::HandshakeNotComplete`; presumably only
    /// states reached after the handshake override it.
    fn export_keying_material(
        &self,
        _output: &mut [u8],
        _label: &[u8],
        _context: Option<&[u8]>,
    ) -> Result<(), Error> {
        Err(Error::HandshakeNotComplete)
    }

    /// Extracts the connection's traffic secrets.
    ///
    /// The default fails with `Error::HandshakeNotComplete`.
    fn extract_secrets(&self) -> Result<PartiallyExtractedSecrets, Error> {
        Err(Error::HandshakeNotComplete)
    }

    /// Hook for reacting to a record decryption failure; the default does
    /// nothing. NOTE(review): its caller is not in this file.
    fn handle_decrypt_error(&self) {}

    /// Converts this state into one that borrows nothing (`'static`).
    fn into_owned(self: Box<Self>) -> Box<dyn State<Data> + 'static>;
}
750
/// Everything a `State` can access while handling one message.
pub(crate) struct Context<'a, Data> {
    /// State shared by client and server connections.
    pub(crate) common: &'a mut CommonState,
    /// Connection-specific data supplied by the caller (presumably client- or
    /// server-specific).
    pub(crate) data: &'a mut Data,
    /// The caller's buffer of plaintext written before sending was permitted,
    /// if the caller has one.
    pub(crate) sendable_plaintext: Option<&'a mut ChunkVecBuffer>,
}
758
/// Which end of a connection a `CommonState` belongs to.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Side {
    /// The client end of the connection.
    Client,
    /// The server end of the connection.
    Server,
}
767
768impl Side {
769 pub(crate) fn peer(&self) -> Self {
770 match self {
771 Self::Client => Self::Server,
772 Self::Server => Self::Client,
773 }
774 }
775}
776
/// Transport that carries the TLS handshake.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub(crate) enum Protocol {
    /// Ordinary TLS: messages are framed as TLS records (the default in
    /// `CommonState::new`).
    Tcp,
    /// QUIC: alerts and handshake messages are handed to `CommonState::quic`
    /// instead of being framed as records.
    Quic,
}
782
/// Whether outgoing application data is subject to the buffer size limit.
enum Limit {
    /// Accept only as much as the relevant buffer limit allows (std only).
    #[cfg(feature = "std")]
    Yes,
    /// Accept everything.
    No,
}
788
/// Default cap, in bytes, on decrypted plaintext held for the caller (16 KiB);
/// applied to `received_plaintext` in `CommonState::new`.
const DEFAULT_RECEIVED_PLAINTEXT_LIMIT: usize = 1 << 14;
/// Default buffer cap in bytes (64 KiB); applied to `sendable_tls` in
/// `CommonState::new`.
pub(crate) const DEFAULT_BUFFER_LIMIT: usize = 1 << 16;