aranya_internal_rustls/
common_state.rs

1use alloc::boxed::Box;
2use alloc::vec::Vec;
3
4use pki_types::CertificateDer;
5
6use crate::crypto::SupportedKxGroup;
7use crate::enums::{AlertDescription, ContentType, HandshakeType, ProtocolVersion};
8use crate::error::{Error, InvalidMessage, PeerMisbehaved};
9use crate::hash_hs::HandshakeHash;
10use crate::log::{debug, error, warn};
11use crate::msgs::alert::AlertMessagePayload;
12use crate::msgs::base::Payload;
13use crate::msgs::codec::Codec;
14use crate::msgs::enums::{AlertLevel, KeyUpdateRequest};
15use crate::msgs::fragmenter::MessageFragmenter;
16use crate::msgs::handshake::{CertificateChain, HandshakeMessagePayload};
17use crate::msgs::message::{
18    Message, MessagePayload, OutboundChunks, OutboundOpaqueMessage, OutboundPlainMessage,
19    PlainMessage,
20};
21use crate::record_layer::PreEncryptAction;
22use crate::suites::{PartiallyExtractedSecrets, SupportedCipherSuite};
23#[cfg(feature = "tls12")]
24use crate::tls12::ConnectionSecrets;
25use crate::unbuffered::{EncryptError, InsufficientSizeError};
26use crate::vecbuf::ChunkVecBuffer;
27use crate::{quic, record_layer};
28
29/// Connection state common to both client and server connections.
30pub struct CommonState {
31    pub(crate) negotiated_version: Option<ProtocolVersion>,
32    pub(crate) chosen_psk_identity: Option<Vec<u8>>,
33    pub(crate) handshake_kind: Option<HandshakeKind>,
34    pub(crate) side: Side,
35    pub(crate) record_layer: record_layer::RecordLayer,
36    pub(crate) suite: Option<SupportedCipherSuite>,
37    pub(crate) kx_state: KxState,
38    pub(crate) alpn_protocol: Option<Vec<u8>>,
39    pub(crate) aligned_handshake: bool,
40    pub(crate) may_send_application_data: bool,
41    pub(crate) may_receive_application_data: bool,
42    pub(crate) early_traffic: bool,
43    sent_fatal_alert: bool,
44    /// If we signaled end of stream.
45    pub(crate) has_sent_close_notify: bool,
46    /// If the peer has signaled end of stream.
47    pub(crate) has_received_close_notify: bool,
48    #[cfg(feature = "std")]
49    pub(crate) has_seen_eof: bool,
50    pub(crate) peer_certificates: Option<CertificateChain<'static>>,
51    message_fragmenter: MessageFragmenter,
52    pub(crate) received_plaintext: ChunkVecBuffer,
53    pub(crate) sendable_tls: ChunkVecBuffer,
54    queued_key_update_message: Option<Vec<u8>>,
55
56    /// Protocol whose key schedule should be used. Unused for TLS < 1.3.
57    pub(crate) protocol: Protocol,
58    pub(crate) quic: quic::Quic,
59    pub(crate) enable_secret_extraction: bool,
60    temper_counters: TemperCounters,
61    pub(crate) refresh_traffic_keys_pending: bool,
62    pub(crate) fips: bool,
63}
64
65impl CommonState {
66    pub(crate) fn new(side: Side) -> Self {
67        Self {
68            negotiated_version: None,
69            chosen_psk_identity: None,
70            handshake_kind: None,
71            side,
72            record_layer: record_layer::RecordLayer::new(),
73            suite: None,
74            kx_state: KxState::default(),
75            alpn_protocol: None,
76            aligned_handshake: true,
77            may_send_application_data: false,
78            may_receive_application_data: false,
79            early_traffic: false,
80            sent_fatal_alert: false,
81            has_sent_close_notify: false,
82            has_received_close_notify: false,
83            #[cfg(feature = "std")]
84            has_seen_eof: false,
85            peer_certificates: None,
86            message_fragmenter: MessageFragmenter::default(),
87            received_plaintext: ChunkVecBuffer::new(Some(DEFAULT_RECEIVED_PLAINTEXT_LIMIT)),
88            sendable_tls: ChunkVecBuffer::new(Some(DEFAULT_BUFFER_LIMIT)),
89            queued_key_update_message: None,
90            protocol: Protocol::Tcp,
91            quic: quic::Quic::default(),
92            enable_secret_extraction: false,
93            temper_counters: TemperCounters::default(),
94            refresh_traffic_keys_pending: false,
95            fips: false,
96        }
97    }
98
99    /// Returns true if the caller should call [`Connection::write_tls`] as soon as possible.
100    ///
101    /// [`Connection::write_tls`]: crate::Connection::write_tls
102    pub fn wants_write(&self) -> bool {
103        !self.sendable_tls.is_empty()
104    }
105
106    /// Returns true if the connection is currently performing the TLS handshake.
107    ///
108    /// During this time plaintext written to the connection is buffered in memory. After
109    /// [`Connection::process_new_packets()`] has been called, this might start to return `false`
110    /// while the final handshake packets still need to be extracted from the connection's buffers.
111    ///
112    /// [`Connection::process_new_packets()`]: crate::Connection::process_new_packets
113    pub fn is_handshaking(&self) -> bool {
114        !(self.may_send_application_data && self.may_receive_application_data)
115    }
116
117    /// Retrieves the certificate chain or the raw public key used by the peer to authenticate.
118    ///
119    /// The order of the certificate chain is as it appears in the TLS
120    /// protocol: the first certificate relates to the peer, the
121    /// second certifies the first, the third certifies the second, and
122    /// so on.
123    ///
124    /// When using raw public keys, the first and only element is the raw public key.
125    ///
126    /// This is made available for both full and resumed handshakes.
127    ///
128    /// For clients, this is the certificate chain or the raw public key of the server.
129    ///
130    /// For servers, this is the certificate chain or the raw public key of the client,
131    /// if client authentication was completed.
132    ///
133    /// The return value is None until this value is available.
134    ///
135    /// Note: the return type of the 'certificate', when using raw public keys is `CertificateDer<'static>`
136    /// even though this should technically be a `SubjectPublicKeyInfoDer<'static>`.
137    /// This choice simplifies the API and ensures backwards compatibility.
138    pub fn peer_certificates(&self) -> Option<&[CertificateDer<'static>]> {
139        self.peer_certificates.as_deref()
140    }
141
142    /// Retrieves the protocol agreed with the peer via ALPN.
143    ///
144    /// A return value of `None` after handshake completion
145    /// means no protocol was agreed (because no protocols
146    /// were offered or accepted by the peer).
147    pub fn alpn_protocol(&self) -> Option<&[u8]> {
148        self.get_alpn_protocol()
149    }
150
151    /// Retrieves the ciphersuite agreed with the peer.
152    ///
153    /// This returns None until the ciphersuite is agreed.
154    pub fn negotiated_cipher_suite(&self) -> Option<SupportedCipherSuite> {
155        self.suite
156    }
157
158    /// Retrieves the key exchange group agreed with the peer.
159    ///
160    /// This function may return `None` depending on the state of the connection,
161    /// the type of handshake, and the protocol version.
162    ///
163    /// If [`CommonState::is_handshaking()`] is true this function will return `None`.
164    /// Similarly, if the [`CommonState::handshake_kind()`] is [`HandshakeKind::Resumed`]
165    /// and the [`CommonState::protocol_version()`] is TLS 1.2, then no key exchange will have
166    /// occurred and this function will return `None`.
167    pub fn negotiated_key_exchange_group(&self) -> Option<&'static dyn SupportedKxGroup> {
168        match self.kx_state {
169            KxState::Complete(group) => Some(group),
170            _ => None,
171        }
172    }
173
174    /// Retrieves the protocol version agreed with the peer.
175    ///
176    /// This returns `None` until the version is agreed.
177    pub fn protocol_version(&self) -> Option<ProtocolVersion> {
178        self.negotiated_version
179    }
180
181    /// Which kind of handshake was performed.
182    ///
183    /// This tells you whether the handshake was a resumption or not.
184    ///
185    /// This will return `None` before it is known which sort of
186    /// handshake occurred.
187    pub fn handshake_kind(&self) -> Option<HandshakeKind> {
188        self.handshake_kind
189    }
190
191    /// Returns the identity of the PSK agreed with the peer, if
192    /// any.
193    ///
194    /// It returns `None` before a PSK has been negotiated, or if
195    /// a PSK wasn't used.
196    pub fn chosen_psk_identity(&self) -> Option<&[u8]> {
197        self.chosen_psk_identity.as_deref()
198    }
199
200    pub(crate) fn is_tls13(&self) -> bool {
201        matches!(self.negotiated_version, Some(ProtocolVersion::TLSv1_3))
202    }
203
204    pub(crate) fn process_main_protocol<Data>(
205        &mut self,
206        msg: Message<'_>,
207        mut state: Box<dyn State<Data>>,
208        data: &mut Data,
209        sendable_plaintext: Option<&mut ChunkVecBuffer>,
210    ) -> Result<Box<dyn State<Data>>, Error> {
211        // For TLS1.2, outside of the handshake, send rejection alerts for
212        // renegotiation requests.  These can occur any time.
213        if self.may_receive_application_data && !self.is_tls13() {
214            let reject_ty = match self.side {
215                Side::Client => HandshakeType::HelloRequest,
216                Side::Server => HandshakeType::ClientHello,
217            };
218            if msg.is_handshake_type(reject_ty) {
219                self.temper_counters
220                    .received_renegotiation_request()?;
221                self.send_warning_alert(AlertDescription::NoRenegotiation);
222                return Ok(state);
223            }
224        }
225
226        let mut cx = Context {
227            common: self,
228            data,
229            sendable_plaintext,
230        };
231        match state.handle(&mut cx, msg) {
232            Ok(next) => {
233                state = next.into_owned();
234                Ok(state)
235            }
236            Err(e @ Error::InappropriateMessage { .. })
237            | Err(e @ Error::InappropriateHandshakeMessage { .. }) => {
238                Err(self.send_fatal_alert(AlertDescription::UnexpectedMessage, e))
239            }
240            Err(e) => Err(e),
241        }
242    }
243
244    pub(crate) fn write_plaintext(
245        &mut self,
246        payload: OutboundChunks<'_>,
247        outgoing_tls: &mut [u8],
248    ) -> Result<usize, EncryptError> {
249        if payload.is_empty() {
250            return Ok(0);
251        }
252
253        let fragments = self
254            .message_fragmenter
255            .fragment_payload(
256                ContentType::ApplicationData,
257                ProtocolVersion::TLSv1_2,
258                payload.clone(),
259            );
260
261        for f in 0..fragments.len() {
262            match self
263                .record_layer
264                .pre_encrypt_action(f as u64)
265            {
266                PreEncryptAction::Nothing => {}
267                PreEncryptAction::RefreshOrClose => match self.negotiated_version {
268                    Some(ProtocolVersion::TLSv1_3) => {
269                        // driven by caller, as we don't have the `State` here
270                        self.refresh_traffic_keys_pending = true;
271                    }
272                    _ => {
273                        error!(
274                            "traffic keys exhausted, closing connection to prevent security failure"
275                        );
276                        self.send_close_notify();
277                        return Err(EncryptError::EncryptExhausted);
278                    }
279                },
280                PreEncryptAction::Refuse => {
281                    return Err(EncryptError::EncryptExhausted);
282                }
283            }
284        }
285
286        self.perhaps_write_key_update();
287
288        self.check_required_size(outgoing_tls, fragments)?;
289
290        let fragments = self
291            .message_fragmenter
292            .fragment_payload(
293                ContentType::ApplicationData,
294                ProtocolVersion::TLSv1_2,
295                payload,
296            );
297
298        Ok(self.write_fragments(outgoing_tls, fragments))
299    }
300
301    // Changing the keys must not span any fragmented handshake
302    // messages.  Otherwise the defragmented messages will have
303    // been protected with two different record layer protections,
304    // which is illegal.  Not mentioned in RFC.
305    pub(crate) fn check_aligned_handshake(&mut self) -> Result<(), Error> {
306        if !self.aligned_handshake {
307            Err(self.send_fatal_alert(
308                AlertDescription::UnexpectedMessage,
309                PeerMisbehaved::KeyEpochWithPendingFragment,
310            ))
311        } else {
312            Ok(())
313        }
314    }
315
316    /// Fragment `m`, encrypt the fragments, and then queue
317    /// the encrypted fragments for sending.
318    pub(crate) fn send_msg_encrypt(&mut self, m: PlainMessage) {
319        let iter = self
320            .message_fragmenter
321            .fragment_message(&m);
322        for m in iter {
323            self.send_single_fragment(m);
324        }
325    }
326
327    /// Like send_msg_encrypt, but operate on an appdata directly.
328    fn send_appdata_encrypt(&mut self, payload: OutboundChunks<'_>, limit: Limit) -> usize {
329        // Here, the limit on sendable_tls applies to encrypted data,
330        // but we're respecting it for plaintext data -- so we'll
331        // be out by whatever the cipher+record overhead is.  That's a
332        // constant and predictable amount, so it's not a terrible issue.
333        let len = match limit {
334            #[cfg(feature = "std")]
335            Limit::Yes => self
336                .sendable_tls
337                .apply_limit(payload.len()),
338            Limit::No => payload.len(),
339        };
340
341        let iter = self
342            .message_fragmenter
343            .fragment_payload(
344                ContentType::ApplicationData,
345                ProtocolVersion::TLSv1_2,
346                payload.split_at(len).0,
347            );
348        for m in iter {
349            self.send_single_fragment(m);
350        }
351
352        len
353    }
354
355    fn send_single_fragment(&mut self, m: OutboundPlainMessage<'_>) {
356        if m.typ == ContentType::Alert {
357            // Alerts are always sendable -- never quashed by a PreEncryptAction.
358            let em = self.record_layer.encrypt_outgoing(m);
359            self.queue_tls_message(em);
360            return;
361        }
362
363        match self
364            .record_layer
365            .next_pre_encrypt_action()
366        {
367            PreEncryptAction::Nothing => {}
368
369            // Close connection once we start to run out of
370            // sequence space.
371            PreEncryptAction::RefreshOrClose => {
372                match self.negotiated_version {
373                    Some(ProtocolVersion::TLSv1_3) => {
374                        // driven by caller, as we don't have the `State` here
375                        self.refresh_traffic_keys_pending = true;
376                    }
377                    _ => {
378                        error!(
379                            "traffic keys exhausted, closing connection to prevent security failure"
380                        );
381                        self.send_close_notify();
382                        return;
383                    }
384                }
385            }
386
387            // Refuse to wrap counter at all costs.  This
388            // is basically untestable unfortunately.
389            PreEncryptAction::Refuse => {
390                return;
391            }
392        };
393
394        let em = self.record_layer.encrypt_outgoing(m);
395        self.queue_tls_message(em);
396    }
397
398    fn send_plain_non_buffering(&mut self, payload: OutboundChunks<'_>, limit: Limit) -> usize {
399        debug_assert!(self.may_send_application_data);
400        debug_assert!(self.record_layer.is_encrypting());
401
402        if payload.is_empty() {
403            // Don't send empty fragments.
404            return 0;
405        }
406
407        self.send_appdata_encrypt(payload, limit)
408    }
409
410    /// Mark the connection as ready to send application data.
411    ///
412    /// Also flush `sendable_plaintext` if it is `Some`.
413    pub(crate) fn start_outgoing_traffic(
414        &mut self,
415        sendable_plaintext: &mut Option<&mut ChunkVecBuffer>,
416    ) {
417        self.may_send_application_data = true;
418        if let Some(sendable_plaintext) = sendable_plaintext {
419            self.flush_plaintext(sendable_plaintext);
420        }
421    }
422
423    /// Mark the connection as ready to send and receive application data.
424    ///
425    /// Also flush `sendable_plaintext` if it is `Some`.
426    pub(crate) fn start_traffic(&mut self, sendable_plaintext: &mut Option<&mut ChunkVecBuffer>) {
427        self.may_receive_application_data = true;
428        self.start_outgoing_traffic(sendable_plaintext);
429    }
430
431    /// Send any buffered plaintext.  Plaintext is buffered if
432    /// written during handshake.
433    fn flush_plaintext(&mut self, sendable_plaintext: &mut ChunkVecBuffer) {
434        if !self.may_send_application_data {
435            return;
436        }
437
438        while let Some(buf) = sendable_plaintext.pop() {
439            self.send_plain_non_buffering(buf.as_slice().into(), Limit::No);
440        }
441    }
442
443    // Put m into sendable_tls for writing.
444    fn queue_tls_message(&mut self, m: OutboundOpaqueMessage) {
445        self.perhaps_write_key_update();
446        self.sendable_tls.append(m.encode());
447    }
448
449    pub(crate) fn perhaps_write_key_update(&mut self) {
450        if let Some(message) = self.queued_key_update_message.take() {
451            self.sendable_tls.append(message);
452        }
453    }
454
455    /// Send a raw TLS message, fragmenting it if needed.
456    pub(crate) fn send_msg(&mut self, m: Message<'_>, must_encrypt: bool) {
457        {
458            if let Protocol::Quic = self.protocol {
459                if let MessagePayload::Alert(alert) = m.payload {
460                    self.quic.alert = Some(alert.description);
461                } else {
462                    debug_assert!(
463                        matches!(
464                            m.payload,
465                            MessagePayload::Handshake { .. } | MessagePayload::HandshakeFlight(_)
466                        ),
467                        "QUIC uses TLS for the cryptographic handshake only"
468                    );
469                    let mut bytes = Vec::new();
470                    m.payload.encode(&mut bytes);
471                    self.quic
472                        .hs_queue
473                        .push_back((must_encrypt, bytes));
474                }
475                return;
476            }
477        }
478        if !must_encrypt {
479            let msg = &m.into();
480            let iter = self
481                .message_fragmenter
482                .fragment_message(msg);
483            for m in iter {
484                self.queue_tls_message(m.to_unencrypted_opaque());
485            }
486        } else {
487            self.send_msg_encrypt(m.into());
488        }
489    }
490
491    pub(crate) fn take_received_plaintext(&mut self, bytes: Payload<'_>) {
492        self.received_plaintext
493            .append(bytes.into_vec());
494    }
495
496    #[cfg(feature = "tls12")]
497    pub(crate) fn start_encryption_tls12(&mut self, secrets: &ConnectionSecrets, side: Side) {
498        let (dec, enc) = secrets.make_cipher_pair(side);
499        self.record_layer
500            .prepare_message_encrypter(
501                enc,
502                secrets
503                    .suite()
504                    .common
505                    .confidentiality_limit,
506            );
507        self.record_layer
508            .prepare_message_decrypter(dec);
509    }
510
511    pub(crate) fn missing_extension(&mut self, why: PeerMisbehaved) -> Error {
512        self.send_fatal_alert(AlertDescription::MissingExtension, why)
513    }
514
515    fn send_warning_alert(&mut self, desc: AlertDescription) {
516        warn!("Sending warning alert {:?}", desc);
517        self.send_warning_alert_no_log(desc);
518    }
519
520    pub(crate) fn process_alert(&mut self, alert: &AlertMessagePayload) -> Result<(), Error> {
521        // Reject unknown AlertLevels.
522        if let AlertLevel::Unknown(_) = alert.level {
523            return Err(self.send_fatal_alert(
524                AlertDescription::IllegalParameter,
525                Error::AlertReceived(alert.description),
526            ));
527        }
528
529        // If we get a CloseNotify, make a note to declare EOF to our
530        // caller.  But do not treat unauthenticated alerts like this.
531        if self.may_receive_application_data && alert.description == AlertDescription::CloseNotify {
532            self.has_received_close_notify = true;
533            return Ok(());
534        }
535
536        // Warnings are nonfatal for TLS1.2, but outlawed in TLS1.3
537        // (except, for no good reason, user_cancelled).
538        let err = Error::AlertReceived(alert.description);
539        if alert.level == AlertLevel::Warning {
540            self.temper_counters
541                .received_warning_alert()?;
542            if self.is_tls13() && alert.description != AlertDescription::UserCanceled {
543                return Err(self.send_fatal_alert(AlertDescription::DecodeError, err));
544            }
545
546            // Some implementations send pointless `user_canceled` alerts, don't log them
547            // in release mode (https://bugs.openjdk.org/browse/JDK-8323517).
548            if alert.description != AlertDescription::UserCanceled || cfg!(debug_assertions) {
549                warn!("TLS alert warning received: {alert:?}");
550            }
551
552            return Ok(());
553        }
554
555        Err(err)
556    }
557
558    pub(crate) fn send_cert_verify_error_alert(&mut self, err: Error) -> Error {
559        self.send_fatal_alert(
560            match &err {
561                Error::InvalidCertificate(e) => e.clone().into(),
562                Error::PeerMisbehaved(_) => AlertDescription::IllegalParameter,
563                _ => AlertDescription::HandshakeFailure,
564            },
565            err,
566        )
567    }
568
569    pub(crate) fn send_fatal_alert(
570        &mut self,
571        desc: AlertDescription,
572        err: impl Into<Error>,
573    ) -> Error {
574        debug_assert!(!self.sent_fatal_alert);
575        let m = Message::build_alert(AlertLevel::Fatal, desc);
576        self.send_msg(m, self.record_layer.is_encrypting());
577        self.sent_fatal_alert = true;
578        err.into()
579    }
580
581    /// Queues a `close_notify` warning alert to be sent in the next
582    /// [`Connection::write_tls`] call.  This informs the peer that the
583    /// connection is being closed.
584    ///
585    /// Does nothing if any `close_notify` or fatal alert was already sent.
586    ///
587    /// [`Connection::write_tls`]: crate::Connection::write_tls
588    pub fn send_close_notify(&mut self) {
589        if self.sent_fatal_alert {
590            return;
591        }
592        debug!("Sending warning alert {:?}", AlertDescription::CloseNotify);
593        self.sent_fatal_alert = true;
594        self.has_sent_close_notify = true;
595        self.send_warning_alert_no_log(AlertDescription::CloseNotify);
596    }
597
598    pub(crate) fn eager_send_close_notify(
599        &mut self,
600        outgoing_tls: &mut [u8],
601    ) -> Result<usize, EncryptError> {
602        self.send_close_notify();
603        self.check_required_size(outgoing_tls, [].into_iter())?;
604        Ok(self.write_fragments(outgoing_tls, [].into_iter()))
605    }
606
607    fn send_warning_alert_no_log(&mut self, desc: AlertDescription) {
608        let m = Message::build_alert(AlertLevel::Warning, desc);
609        self.send_msg(m, self.record_layer.is_encrypting());
610    }
611
612    fn check_required_size<'a>(
613        &self,
614        outgoing_tls: &mut [u8],
615        fragments: impl Iterator<Item = OutboundPlainMessage<'a>>,
616    ) -> Result<(), EncryptError> {
617        let mut required_size = self.sendable_tls.len();
618
619        for m in fragments {
620            required_size += m.encoded_len(&self.record_layer);
621        }
622
623        if required_size > outgoing_tls.len() {
624            return Err(EncryptError::InsufficientSize(InsufficientSizeError {
625                required_size,
626            }));
627        }
628
629        Ok(())
630    }
631
632    fn write_fragments<'a>(
633        &mut self,
634        outgoing_tls: &mut [u8],
635        fragments: impl Iterator<Item = OutboundPlainMessage<'a>>,
636    ) -> usize {
637        let mut written = 0;
638
639        // Any pre-existing encrypted messages in `sendable_tls` must
640        // be output before encrypting any of the `fragments`.
641        while let Some(message) = self.sendable_tls.pop() {
642            let len = message.len();
643            outgoing_tls[written..written + len].copy_from_slice(&message);
644            written += len;
645        }
646
647        for m in fragments {
648            let em = self
649                .record_layer
650                .encrypt_outgoing(m)
651                .encode();
652
653            let len = em.len();
654            outgoing_tls[written..written + len].copy_from_slice(&em);
655            written += len;
656        }
657
658        written
659    }
660
661    pub(crate) fn set_max_fragment_size(&mut self, new: Option<usize>) -> Result<(), Error> {
662        self.message_fragmenter
663            .set_max_fragment_size(new)
664    }
665
666    pub(crate) fn get_alpn_protocol(&self) -> Option<&[u8]> {
667        self.alpn_protocol
668            .as_ref()
669            .map(AsRef::as_ref)
670    }
671
672    /// Returns true if the caller should call [`Connection::read_tls`] as soon
673    /// as possible.
674    ///
675    /// If there is pending plaintext data to read with [`Connection::reader`],
676    /// this returns false.  If your application respects this mechanism,
677    /// only one full TLS message will be buffered by rustls.
678    ///
679    /// [`Connection::reader`]: crate::Connection::reader
680    /// [`Connection::read_tls`]: crate::Connection::read_tls
681    pub fn wants_read(&self) -> bool {
682        // We want to read more data all the time, except when we have unprocessed plaintext.
683        // This provides back-pressure to the TCP buffers. We also don't want to read more after
684        // the peer has sent us a close notification.
685        //
686        // In the handshake case we don't have readable plaintext before the handshake has
687        // completed, but also don't want to read if we still have sendable tls.
688        self.received_plaintext.is_empty()
689            && !self.has_received_close_notify
690            && (self.may_send_application_data || self.sendable_tls.is_empty())
691    }
692
693    pub(crate) fn current_io_state(&self) -> IoState {
694        IoState {
695            tls_bytes_to_write: self.sendable_tls.len(),
696            plaintext_bytes_to_read: self.received_plaintext.len(),
697            peer_has_closed: self.has_received_close_notify,
698        }
699    }
700
701    pub(crate) fn is_quic(&self) -> bool {
702        self.protocol == Protocol::Quic
703    }
704
705    pub(crate) fn should_update_key(
706        &mut self,
707        key_update_request: &KeyUpdateRequest,
708    ) -> Result<bool, Error> {
709        self.temper_counters
710            .received_key_update_request()?;
711
712        match key_update_request {
713            KeyUpdateRequest::UpdateNotRequested => Ok(false),
714            KeyUpdateRequest::UpdateRequested => Ok(self.queued_key_update_message.is_none()),
715            _ => Err(self.send_fatal_alert(
716                AlertDescription::IllegalParameter,
717                InvalidMessage::InvalidKeyUpdate,
718            )),
719        }
720    }
721
722    pub(crate) fn enqueue_key_update_notification(&mut self) {
723        let message = PlainMessage::from(Message::build_key_update_notify());
724        self.queued_key_update_message = Some(
725            self.record_layer
726                .encrypt_outgoing(message.borrow_outbound())
727                .encode(),
728        );
729    }
730
731    pub(crate) fn received_tls13_change_cipher_spec(&mut self) -> Result<(), Error> {
732        self.temper_counters
733            .received_tls13_change_cipher_spec()
734    }
735}
736
#[cfg(feature = "std")]
impl CommonState {
    /// Accepts application plaintext from the caller, subject to the
    /// per-connection buffer limits.
    ///
    /// Any pending key-update record is moved to the outgoing TLS queue first.
    /// The return value is how many bytes of `payload` were taken, which can be
    /// fewer than `payload.len()` when the buffers are full.
    pub(crate) fn buffer_plaintext(
        &mut self,
        payload: OutboundChunks<'_>,
        sendable_plaintext: &mut ChunkVecBuffer,
    ) -> usize {
        self.perhaps_write_key_update();
        self.send_plain(payload, Limit::Yes, sendable_plaintext)
    }

    /// Encrypts `data` as early data and queues it, subject to buffer limits.
    ///
    /// Only valid while early traffic is permitted and the record layer is
    /// encrypting; both conditions are debug-asserted. Returns the number of bytes
    /// of `data` accepted. Empty input is never sent and yields `0`.
    pub(crate) fn send_early_plaintext(&mut self, data: &[u8]) -> usize {
        debug_assert!(self.early_traffic);
        debug_assert!(self.record_layer.is_encrypting());

        if data.is_empty() {
            // Don't send empty fragments.
            0
        } else {
            self.send_appdata_encrypt(data.into(), Limit::Yes)
        }
    }

    /// Encrypts and sends `payload`, or parks it in `sendable_plaintext` while
    /// application data may not be sent yet.
    ///
    /// `limit` selects whether the per-connection buffer limits apply. The result
    /// is the number of bytes of `payload` taken, which may be smaller than
    /// `payload.len()` when a limit was hit.
    fn send_plain(
        &mut self,
        payload: OutboundChunks<'_>,
        limit: Limit,
        sendable_plaintext: &mut ChunkVecBuffer,
    ) -> usize {
        if self.may_send_application_data {
            return self.send_plain_non_buffering(payload, limit);
        }

        // Handshake not finished: keep the plaintext so it can be flushed once
        // sending application data becomes possible.
        match limit {
            Limit::Yes => sendable_plaintext.append_limited_copy(payload),
            Limit::No => sendable_plaintext.append(payload.to_vec()),
        }
    }
}
789
/// Describes which sort of handshake happened.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum HandshakeKind {
    /// A full handshake.
    ///
    /// This is the typical TLS connection initiation process when resumption is
    /// not available, and the initial `ClientHello` was accepted by the server.
    Full,

    /// A full TLS1.3 handshake, with an extra round-trip for a `HelloRetryRequest`.
    ///
    /// The server can respond with a `HelloRetryRequest` if the initial `ClientHello`
    /// is unacceptable for one of several reasons. The most likely reason is that the
    /// client offered no supported key shares.
    FullWithHelloRetryRequest,

    /// A resumed handshake.
    ///
    /// Resumed handshakes involve fewer round trips and less cryptography than
    /// full ones, but can only happen when the peers have previously done a full
    /// handshake together, and then remember data about it.
    Resumed,

    /// A handshake keyed by a pre-shared key (PSK), distinct from a [`Resumed`] one.
    ///
    /// The identity of the PSK agreed with the peer is reported by
    /// `CommonState::chosen_psk_identity`.
    ///
    /// [`Resumed`]: HandshakeKind::Resumed
    Psk,
}
816
/// A snapshot of the connection's I/O state, handed back by
/// [`Connection::process_new_packets`] so the caller knows what to do next.
///
/// [`Connection::process_new_packets`]: crate::Connection::process_new_packets
#[derive(Debug, Eq, PartialEq)]
pub struct IoState {
    /// Count of TLS bytes ready to be written out.
    tls_bytes_to_write: usize,
    /// Count of decrypted plaintext bytes ready to be read.
    plaintext_bytes_to_read: usize,
    /// Whether the peer's `close_notify` alert has been received.
    peer_has_closed: bool,
}

impl IoState {
    /// Number of bytes [`Connection::write_tls`] would produce if it were
    /// called at this moment.
    ///
    /// Any non-zero count implies [`CommonState::wants_write`].
    ///
    /// [`Connection::write_tls`]: crate::Connection::write_tls
    pub fn tls_bytes_to_write(&self) -> usize {
        self.tls_bytes_to_write
    }

    /// Number of plaintext bytes that [`std::io::Read`] can hand out without
    /// any further I/O taking place.
    pub fn plaintext_bytes_to_read(&self) -> usize {
        self.plaintext_bytes_to_read
    }

    /// Whether the peer has sent a close_notify alert.
    ///
    /// close_notify is how TLS securely half-closes a connection: the peer is
    /// announcing that it will send no more data on this connection.
    ///
    /// Once every received byte has been consumed, the same condition is
    /// reported by [`std::io::Read`] returning `Ok(0)`.
    pub fn peer_has_closed(&self) -> bool {
        self.peer_has_closed
    }
}
855
/// One state of a connection's state machine, generic over the side-specific
/// `Data` reachable through [`Context`].
///
/// States live in a `Box`; handling a message consumes the current box and
/// yields the box for the next state.
pub(crate) trait State<Data>: Send + Sync {
    /// Processes `message`, consuming this state and returning the state to
    /// move to next (or an error).
    ///
    /// The returned state is only bounded by `'m`, so it may still borrow from
    /// the message's data; [`State::into_owned`] detaches a state from such
    /// borrows.
    fn handle<'m>(
        self: Box<Self>,
        cx: &mut Context<'_, Data>,
        message: Message<'m>,
    ) -> Result<Box<dyn State<Data> + 'm>, Error>
    where
        Self: 'm;

    /// Writes keying material for `_label` / `_context` into `_output`.
    ///
    /// The default fails with [`Error::HandshakeNotComplete`]; states that
    /// support exporting keying material override it.
    fn export_keying_material(
        &self,
        _output: &mut [u8],
        _label: &[u8],
        _context: Option<&[u8]>,
    ) -> Result<(), Error> {
        Err(Error::HandshakeNotComplete)
    }

    /// Extracts the connection's secrets.
    ///
    /// The default fails with [`Error::HandshakeNotComplete`]; states that
    /// hold extractable secrets override it.
    fn extract_secrets(&self) -> Result<PartiallyExtractedSecrets, Error> {
        Err(Error::HandshakeNotComplete)
    }

    /// Queues a key update request via `_common`.
    ///
    /// The default fails with [`Error::HandshakeNotComplete`]; states able to
    /// perform a key update override it.
    fn send_key_update_request(&mut self, _common: &mut CommonState) -> Result<(), Error> {
        Err(Error::HandshakeNotComplete)
    }

    /// Hook for reacting to a decryption error; the default does nothing.
    // NOTE(review): presumably invoked by the connection when record
    // decryption fails — confirm at the call site.
    fn handle_decrypt_error(&self) {}

    /// Converts this state into one that borrows nothing (`'static`).
    fn into_owned(self: Box<Self>) -> Box<dyn State<Data> + 'static>;
}
886
/// Mutable context passed to [`State::handle`]: the shared connection state,
/// the side-specific `Data`, and any plaintext buffered during the handshake.
pub(crate) struct Context<'a, Data> {
    /// Connection state common to both client and server connections.
    pub(crate) common: &'a mut CommonState,
    /// Side-specific data of type `Data`.
    pub(crate) data: &'a mut Data,
    /// Buffered plaintext. This is `Some` if any plaintext was written during handshake and `None`
    /// otherwise.
    pub(crate) sendable_plaintext: Option<&'a mut ChunkVecBuffer>,
}
894
/// Which end of a connection we are.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Side {
    /// The party that initiates the connection.
    Client,
    /// The party that waits for a client to connect.
    Server,
}

impl Side {
    /// The opposite end of the connection: `Server` for a `Client`, and
    /// `Client` for a `Server`.
    pub(crate) fn peer(&self) -> Self {
        match *self {
            Side::Server => Side::Client,
            Side::Client => Side::Server,
        }
    }
}
912
/// The transport a connection is used with.
// NOTE(review): meanings below are inferred from the variant names — confirm
// against the places that branch on `Protocol`.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub(crate) enum Protocol {
    /// Ordinary TLS over a byte stream such as TCP.
    Tcp,
    /// TLS carried inside QUIC.
    Quic,
}
918
/// Whether the per-connection buffer limits apply when plaintext is buffered
/// or sent.
enum Limit {
    /// Honour the buffer limits: data beyond the limit may not be accepted.
    #[cfg(feature = "std")]
    Yes,
    /// Ignore the buffer limits and accept everything.
    No,
}
924
925/// Tracking technically-allowed protocol actions
926/// that we limit to avoid denial-of-service vectors.
927struct TemperCounters {
928    allowed_warning_alerts: u8,
929    allowed_renegotiation_requests: u8,
930    allowed_key_update_requests: u8,
931    allowed_middlebox_ccs: u8,
932}
933
934impl TemperCounters {
935    fn received_warning_alert(&mut self) -> Result<(), Error> {
936        match self.allowed_warning_alerts {
937            0 => Err(PeerMisbehaved::TooManyWarningAlertsReceived.into()),
938            _ => {
939                self.allowed_warning_alerts -= 1;
940                Ok(())
941            }
942        }
943    }
944
945    fn received_renegotiation_request(&mut self) -> Result<(), Error> {
946        match self.allowed_renegotiation_requests {
947            0 => Err(PeerMisbehaved::TooManyRenegotiationRequests.into()),
948            _ => {
949                self.allowed_renegotiation_requests -= 1;
950                Ok(())
951            }
952        }
953    }
954
955    fn received_key_update_request(&mut self) -> Result<(), Error> {
956        match self.allowed_key_update_requests {
957            0 => Err(PeerMisbehaved::TooManyKeyUpdateRequests.into()),
958            _ => {
959                self.allowed_key_update_requests -= 1;
960                Ok(())
961            }
962        }
963    }
964
965    fn received_tls13_change_cipher_spec(&mut self) -> Result<(), Error> {
966        match self.allowed_middlebox_ccs {
967            0 => Err(PeerMisbehaved::IllegalMiddleboxChangeCipherSpec.into()),
968            _ => {
969                self.allowed_middlebox_ccs -= 1;
970                Ok(())
971            }
972        }
973    }
974}
975
976impl Default for TemperCounters {
977    fn default() -> Self {
978        Self {
979            // cf. BoringSSL `kMaxWarningAlerts`
980            // <https://github.com/google/boringssl/blob/dec5989b793c56ad4dd32173bd2d8595ca78b398/ssl/tls_record.cc#L137-L139>
981            allowed_warning_alerts: 4,
982
983            // we rebuff renegotiation requests with a `NoRenegotiation` warning alerts.
984            // a second request after this is fatal.
985            allowed_renegotiation_requests: 1,
986
987            // cf. BoringSSL `kMaxKeyUpdates`
988            // <https://github.com/google/boringssl/blob/dec5989b793c56ad4dd32173bd2d8595ca78b398/ssl/tls13_both.cc#L35-L38>
989            allowed_key_update_requests: 32,
990
991            // At most two CCS are allowed: one after each ClientHello (recall a second
992            // ClientHello happens after a HelloRetryRequest).
993            //
994            // note BoringSSL allows up to 32.
995            allowed_middlebox_ccs: 2,
996        }
997    }
998}
999
1000#[derive(Debug, Default)]
1001pub(crate) enum KxState {
1002    #[default]
1003    None,
1004    Start(&'static dyn SupportedKxGroup),
1005    Complete(&'static dyn SupportedKxGroup),
1006}
1007
1008impl KxState {
1009    pub(crate) fn complete(&mut self) {
1010        debug_assert!(matches!(self, Self::Start(_)));
1011        if let Self::Start(group) = self {
1012            *self = Self::Complete(*group);
1013        }
1014    }
1015}
1016
1017pub(crate) struct HandshakeFlight<'a, const TLS13: bool> {
1018    pub(crate) transcript: &'a mut HandshakeHash,
1019    body: Vec<u8>,
1020}
1021
1022impl<'a, const TLS13: bool> HandshakeFlight<'a, TLS13> {
1023    pub(crate) fn new(transcript: &'a mut HandshakeHash) -> Self {
1024        Self {
1025            transcript,
1026            body: Vec::new(),
1027        }
1028    }
1029
1030    pub(crate) fn add(&mut self, hs: HandshakeMessagePayload<'_>) {
1031        let start_len = self.body.len();
1032        hs.encode(&mut self.body);
1033        self.transcript
1034            .add(&self.body[start_len..]);
1035    }
1036
1037    pub(crate) fn finish(self, common: &mut CommonState) {
1038        common.send_msg(
1039            Message {
1040                version: match TLS13 {
1041                    true => ProtocolVersion::TLSv1_3,
1042                    false => ProtocolVersion::TLSv1_2,
1043                },
1044                payload: MessagePayload::HandshakeFlight(Payload::new(self.body)),
1045            },
1046            TLS13,
1047        );
1048    }
1049}
1050
/// A [`HandshakeFlight`] sent as a TLS 1.2 record (`TLS13 = false`).
#[cfg(feature = "tls12")]
pub(crate) type HandshakeFlightTls12<'a> = HandshakeFlight<'a, false>;
/// A [`HandshakeFlight`] sent as a TLS 1.3 record (`TLS13 = true`).
pub(crate) type HandshakeFlightTls13<'a> = HandshakeFlight<'a, true>;
1054
/// 16 KiB.
// NOTE(review): judging by the name, this is the default cap on received
// plaintext held for the application — confirm against its use sites.
const DEFAULT_RECEIVED_PLAINTEXT_LIMIT: usize = 16 * 1024;
/// 64 KiB.
// NOTE(review): judging by the name, this is the default size limit for the
// connection's buffers — confirm against its use sites.
pub(crate) const DEFAULT_BUFFER_LIMIT: usize = 64 * 1024;