embedded_tls/
connection.rs

1use crate::config::{TlsCipherSuite, TlsConfig, TlsVerifier};
2use crate::handshake::{ClientHandshake, ServerHandshake};
3use crate::key_schedule::{KeySchedule, ReadKeySchedule, WriteKeySchedule};
4use crate::record::{ClientRecord, ServerRecord};
5use crate::record_reader::RecordReader;
6use crate::write_buffer::WriteBuffer;
7use crate::TlsError;
8use crate::{
9    alert::*,
10    handshake::{certificate::CertificateRef, certificate_request::CertificateRequest},
11};
12use core::fmt::Debug;
13use embedded_io::Error as _;
14use embedded_io::{Read as BlockingRead, Write as BlockingWrite};
15use embedded_io_async::{Read as AsyncRead, Write as AsyncWrite};
16use rand_core::{CryptoRng, RngCore};
17
18use crate::application_data::ApplicationData;
19// use crate::handshake::certificate_request::CertificateRequest;
20// use crate::handshake::certificate_verify::CertificateVerify;
21// use crate::handshake::encrypted_extensions::EncryptedExtensions;
22// use crate::handshake::finished::Finished;
23// use crate::handshake::new_session_ticket::NewSessionTicket;
24// use crate::handshake::server_hello::ServerHello;
25use crate::buffer::CryptoBuffer;
26use digest::generic_array::typenum::Unsigned;
27use p256::ecdh::EphemeralSecret;
28
29use crate::content_types::ContentType;
30// use crate::handshake::certificate_request::CertificateRequest;
31// use crate::handshake::certificate_verify::CertificateVerify;
32// use crate::handshake::encrypted_extensions::EncryptedExtensions;
33// use crate::handshake::finished::Finished;
34// use crate::handshake::new_session_ticket::NewSessionTicket;
35// use crate::handshake::server_hello::ServerHello;
36use crate::parse_buffer::ParseBuffer;
37use aes_gcm::aead::{AeadCore, AeadInPlace, KeyInit};
38
39pub(crate) fn decrypt_record<CipherSuite>(
40    key_schedule: &mut ReadKeySchedule<CipherSuite>,
41    record: ServerRecord<'_, CipherSuite>,
42    mut cb: impl FnMut(
43        &mut ReadKeySchedule<CipherSuite>,
44        ServerRecord<'_, CipherSuite>,
45    ) -> Result<(), TlsError>,
46) -> Result<(), TlsError>
47where
48    CipherSuite: TlsCipherSuite,
49{
50    if let ServerRecord::ApplicationData(ApplicationData {
51        header,
52        data: mut app_data,
53    }) = record
54    {
55        let server_key = key_schedule.get_key()?;
56        let nonce = key_schedule.get_nonce()?;
57
58        let crypto = <CipherSuite::Cipher as KeyInit>::new(&server_key);
59        crypto
60            .decrypt_in_place(&nonce, header.data(), &mut app_data)
61            .map_err(|_| TlsError::CryptoError)?;
62
63        let padding = app_data
64            .as_slice()
65            .iter()
66            .enumerate()
67            .rfind(|(_, b)| **b != 0);
68        if let Some((index, _)) = padding {
69            app_data.truncate(index + 1);
70        };
71
72        let content_type =
73            ContentType::of(*app_data.as_slice().last().unwrap()).ok_or(TlsError::InvalidRecord)?;
74
75        trace!("Decrypting: content type = {:?}", content_type);
76
77        // Remove the content type
78        app_data.truncate(app_data.len() - 1);
79
80        let mut buf = ParseBuffer::new(app_data.as_slice());
81        match content_type {
82            ContentType::Handshake => {
83                // Decode potentially coalesced handshake messages
84                while buf.remaining() > 0 {
85                    let inner = ServerHandshake::read(&mut buf, key_schedule.transcript_hash())?;
86                    cb(key_schedule, ServerRecord::Handshake(inner))?;
87                }
88            }
89            ContentType::ApplicationData => {
90                let inner = ApplicationData::new(app_data, header);
91                cb(key_schedule, ServerRecord::ApplicationData(inner))?;
92            }
93            ContentType::Alert => {
94                let alert = Alert::parse(&mut buf)?;
95                cb(key_schedule, ServerRecord::Alert(alert))?;
96            }
97            _ => return Err(TlsError::Unimplemented),
98        }
99        key_schedule.increment_counter();
100    } else {
101        trace!("Not decrypting: content_type = {:?}", record.content_type());
102        cb(key_schedule, record)?;
103    }
104    Ok(())
105}
106
107pub(crate) fn encrypt<CipherSuite>(
108    key_schedule: &WriteKeySchedule<CipherSuite>,
109    buf: &mut CryptoBuffer<'_>,
110) -> Result<(), TlsError>
111where
112    CipherSuite: TlsCipherSuite,
113{
114    let client_key = key_schedule.get_key()?;
115    let nonce = key_schedule.get_nonce()?;
116    // trace!("encrypt key {:02x?}", client_key);
117    // trace!("encrypt nonce {:02x?}", nonce);
118    // trace!("plaintext {} {:02x?}", buf.len(), buf.as_slice(),);
119    //let crypto = Aes128Gcm::new_varkey(&self.key_schedule.get_client_key()).unwrap();
120    let crypto = <CipherSuite::Cipher as KeyInit>::new(&client_key);
121    let len = buf.len() + <CipherSuite::Cipher as AeadCore>::TagSize::to_usize();
122
123    if len > buf.capacity() {
124        return Err(TlsError::InsufficientSpace);
125    }
126
127    trace!("output size {}", len);
128    let len_bytes = (len as u16).to_be_bytes();
129    let additional_data = [
130        ContentType::ApplicationData as u8,
131        0x03,
132        0x03,
133        len_bytes[0],
134        len_bytes[1],
135    ];
136
137    crypto
138        .encrypt_in_place(&nonce, &additional_data, buf)
139        .map_err(|_| TlsError::InvalidApplicationData)
140}
141
142pub struct Handshake<CipherSuite, Verifier>
143where
144    CipherSuite: TlsCipherSuite,
145{
146    traffic_hash: Option<CipherSuite::Hash>,
147    secret: Option<EphemeralSecret>,
148    certificate_request: Option<CertificateRequest>,
149    verifier: Verifier,
150}
151
152impl<'v, CipherSuite, Verifier> Handshake<CipherSuite, Verifier>
153where
154    CipherSuite: TlsCipherSuite,
155    Verifier: TlsVerifier<'v, CipherSuite>,
156{
157    pub fn new(verifier: Verifier) -> Handshake<CipherSuite, Verifier> {
158        Handshake {
159            traffic_hash: None,
160            secret: None,
161            certificate_request: None,
162            verifier,
163        }
164    }
165}
166
/// Steps of the client handshake state machine driven by `State::process`
/// and `State::process_blocking`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum State {
    /// Build and send the ClientHello; the next state is `ServerHello`.
    ClientHello,
    /// Read one record and expect a ServerHello handshake message in it.
    ServerHello,
    /// Read and decrypt one record of server handshake messages; the state
    /// stays here until the server Finished message has been verified.
    ServerVerify,
    /// Send the client certificate message answering a certificate request.
    ClientCert,
    /// Send the client Finished message and initialize the master secret.
    ClientFinished,
    /// Handshake complete; processing this state returns it unchanged.
    ApplicationData,
}
177
178impl<'a> State {
179    #[allow(clippy::too_many_arguments)]
180    pub async fn process<'v, Transport, CipherSuite, RNG, Verifier>(
181        self,
182        transport: &mut Transport,
183        handshake: &mut Handshake<CipherSuite, Verifier>,
184        record_reader: &mut RecordReader<'_, CipherSuite>,
185        tx_buf: &mut WriteBuffer<'_>,
186        key_schedule: &mut KeySchedule<CipherSuite>,
187        config: &TlsConfig<'a, CipherSuite>,
188        rng: &mut RNG,
189    ) -> Result<State, TlsError>
190    where
191        Transport: AsyncRead + AsyncWrite + 'a,
192        RNG: CryptoRng + RngCore + 'a,
193        CipherSuite: TlsCipherSuite,
194        Verifier: TlsVerifier<'v, CipherSuite>,
195    {
196        match self {
197            State::ClientHello => {
198                let (state, tx) = client_hello(key_schedule, config, rng, tx_buf, handshake)?;
199
200                respond(tx, transport, key_schedule).await?;
201
202                Ok(state)
203            }
204            State::ServerHello => {
205                let record = record_reader
206                    .read(transport, key_schedule.read_state())
207                    .await?;
208
209                let result = process_server_hello(handshake, key_schedule, record);
210
211                handle_processing_error(result, transport, key_schedule, tx_buf).await
212            }
213            State::ServerVerify => {
214                let record = record_reader
215                    .read(transport, key_schedule.read_state())
216                    .await?;
217
218                let result = process_server_verify(handshake, key_schedule, config, record);
219
220                handle_processing_error(result, transport, key_schedule, tx_buf).await
221            }
222            State::ClientCert => {
223                let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?;
224
225                respond(tx, transport, key_schedule).await?;
226
227                Ok(state)
228            }
229            State::ClientFinished => {
230                let tx = client_finished(key_schedule, tx_buf)?;
231
232                respond(tx, transport, key_schedule).await?;
233
234                client_finished_finalize(key_schedule, handshake)
235            }
236            State::ApplicationData => Ok(State::ApplicationData),
237        }
238    }
239
240    #[allow(clippy::too_many_arguments)]
241    pub fn process_blocking<'v, Transport, CipherSuite, RNG, Verifier>(
242        self,
243        transport: &mut Transport,
244        handshake: &mut Handshake<CipherSuite, Verifier>,
245        record_reader: &mut RecordReader<'_, CipherSuite>,
246        tx_buf: &mut WriteBuffer,
247        key_schedule: &mut KeySchedule<CipherSuite>,
248        config: &TlsConfig<'a, CipherSuite>,
249        rng: &mut RNG,
250    ) -> Result<State, TlsError>
251    where
252        Transport: BlockingRead + BlockingWrite + 'a,
253        RNG: CryptoRng + RngCore,
254        CipherSuite: TlsCipherSuite + 'static,
255        Verifier: TlsVerifier<'v, CipherSuite>,
256    {
257        match self {
258            State::ClientHello => {
259                let (state, tx) = client_hello(key_schedule, config, rng, tx_buf, handshake)?;
260
261                respond_blocking(tx, transport, key_schedule)?;
262
263                Ok(state)
264            }
265            State::ServerHello => {
266                let record = record_reader.read_blocking(transport, key_schedule.read_state())?;
267
268                let result = process_server_hello(handshake, key_schedule, record);
269
270                handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
271            }
272            State::ServerVerify => {
273                let record = record_reader.read_blocking(transport, key_schedule.read_state())?;
274
275                let result = process_server_verify(handshake, key_schedule, config, record);
276
277                handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
278            }
279            State::ClientCert => {
280                let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?;
281
282                respond_blocking(tx, transport, key_schedule)?;
283
284                Ok(state)
285            }
286            State::ClientFinished => {
287                let tx = client_finished(key_schedule, tx_buf)?;
288
289                respond_blocking(tx, transport, key_schedule)?;
290
291                client_finished_finalize(key_schedule, handshake)
292            }
293            State::ApplicationData => Ok(State::ApplicationData),
294        }
295    }
296}
297
298fn handle_processing_error_blocking<CipherSuite>(
299    result: Result<State, TlsError>,
300    transport: &mut impl BlockingWrite,
301    key_schedule: &mut KeySchedule<CipherSuite>,
302    tx_buf: &mut WriteBuffer,
303) -> Result<State, TlsError>
304where
305    CipherSuite: TlsCipherSuite,
306{
307    if let Err(TlsError::AbortHandshake(level, description)) = result {
308        let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
309        let tx = tx_buf.write_record(
310            &ClientRecord::Alert(Alert { level, description }, false),
311            write_key_schedule,
312            Some(read_key_schedule),
313        )?;
314
315        respond_blocking(tx, transport, key_schedule)?;
316    }
317
318    result
319}
320
321fn respond_blocking<CipherSuite>(
322    tx: &[u8],
323    transport: &mut impl BlockingWrite,
324    key_schedule: &mut KeySchedule<CipherSuite>,
325) -> Result<(), TlsError>
326where
327    CipherSuite: TlsCipherSuite,
328{
329    transport
330        .write_all(tx)
331        .map_err(|e| TlsError::Io(e.kind()))?;
332
333    key_schedule.write_state().increment_counter();
334
335    transport.flush().map_err(|e| TlsError::Io(e.kind()))?;
336
337    Ok(())
338}
339
340async fn handle_processing_error<'a, CipherSuite>(
341    result: Result<State, TlsError>,
342    transport: &mut impl AsyncWrite,
343    key_schedule: &mut KeySchedule<CipherSuite>,
344    tx_buf: &mut WriteBuffer<'a>,
345) -> Result<State, TlsError>
346where
347    CipherSuite: TlsCipherSuite,
348{
349    if let Err(TlsError::AbortHandshake(level, description)) = result {
350        let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
351        let tx = tx_buf.write_record(
352            &ClientRecord::Alert(Alert { level, description }, false),
353            write_key_schedule,
354            Some(read_key_schedule),
355        )?;
356
357        respond(tx, transport, key_schedule).await?;
358    }
359
360    result
361}
362
363async fn respond<CipherSuite>(
364    tx: &[u8],
365    transport: &mut impl AsyncWrite,
366    key_schedule: &mut KeySchedule<CipherSuite>,
367) -> Result<(), TlsError>
368where
369    CipherSuite: TlsCipherSuite,
370{
371    transport
372        .write_all(tx)
373        .await
374        .map_err(|e| TlsError::Io(e.kind()))?;
375
376    key_schedule.write_state().increment_counter();
377
378    transport
379        .flush()
380        .await
381        .map_err(|e| TlsError::Io(e.kind()))?;
382
383    Ok(())
384}
385
386fn client_hello<'r, CipherSuite, RNG, Verifier>(
387    key_schedule: &mut KeySchedule<CipherSuite>,
388    config: &TlsConfig<CipherSuite>,
389    rng: &mut RNG,
390    tx_buf: &'r mut WriteBuffer,
391    handshake: &mut Handshake<CipherSuite, Verifier>,
392) -> Result<(State, &'r [u8]), TlsError>
393where
394    RNG: CryptoRng + RngCore,
395    CipherSuite: TlsCipherSuite,
396{
397    key_schedule.initialize_early_secret(config.psk.as_ref().map(|p| p.0))?;
398    let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
399    let client_hello = ClientRecord::client_hello(config, rng);
400    let slice = tx_buf.write_record(&client_hello, write_key_schedule, Some(read_key_schedule))?;
401
402    if let ClientRecord::Handshake(ClientHandshake::ClientHello(client_hello), _) = client_hello {
403        handshake.secret.replace(client_hello.secret);
404        Ok((State::ServerHello, slice))
405    } else {
406        Err(TlsError::EncodeError)
407    }
408}
409
410fn process_server_hello<CipherSuite, Verifier>(
411    handshake: &mut Handshake<CipherSuite, Verifier>,
412    key_schedule: &mut KeySchedule<CipherSuite>,
413    record: ServerRecord<'_, CipherSuite>,
414) -> Result<State, TlsError>
415where
416    CipherSuite: TlsCipherSuite,
417{
418    match record {
419        ServerRecord::Handshake(server_handshake) => match server_handshake {
420            ServerHandshake::ServerHello(server_hello) => {
421                trace!("********* ServerHello");
422                let secret = handshake.secret.take().ok_or(TlsError::InvalidHandshake)?;
423                let shared = server_hello
424                    .calculate_shared_secret(&secret)
425                    .ok_or(TlsError::InvalidKeyShare)?;
426                key_schedule.initialize_handshake_secret(shared.raw_secret_bytes())?;
427                Ok(State::ServerVerify)
428            }
429            _ => Err(TlsError::InvalidHandshake),
430        },
431        ServerRecord::Alert(alert) => {
432            Err(TlsError::HandshakeAborted(alert.level, alert.description))
433        }
434        _ => Err(TlsError::InvalidRecord),
435    }
436}
437
438fn process_server_verify<'a, 'v, CipherSuite, Verifier>(
439    handshake: &mut Handshake<CipherSuite, Verifier>,
440    key_schedule: &mut KeySchedule<CipherSuite>,
441    config: &TlsConfig<'a, CipherSuite>,
442    record: ServerRecord<'_, CipherSuite>,
443) -> Result<State, TlsError>
444where
445    CipherSuite: TlsCipherSuite,
446    Verifier: TlsVerifier<'v, CipherSuite>,
447{
448    let mut state = State::ServerVerify;
449    decrypt_record(key_schedule.read_state(), record, |key_schedule, record| {
450        match record {
451            ServerRecord::Handshake(server_handshake) => {
452                match server_handshake {
453                    ServerHandshake::EncryptedExtensions(_) => {}
454                    ServerHandshake::Certificate(certificate) => {
455                        let transcript = key_schedule.transcript_hash();
456                        handshake.verifier.verify_certificate(
457                            transcript,
458                            &config.ca,
459                            certificate,
460                        )?;
461                        debug!("Certificate verified!");
462                    }
463                    ServerHandshake::CertificateVerify(verify) => {
464                        handshake.verifier.verify_signature(verify)?;
465                        debug!("Signature verified!");
466                    }
467                    ServerHandshake::CertificateRequest(request) => {
468                        handshake.certificate_request.replace(request.try_into()?);
469                    }
470                    ServerHandshake::Finished(finished) => {
471                        if !key_schedule.verify_server_finished(&finished)? {
472                            warn!("Server signature verification failed");
473                            return Err(TlsError::InvalidSignature);
474                        }
475
476                        // trace!("server verified {}", verified);
477                        state = if handshake.certificate_request.is_some() {
478                            State::ClientCert
479                        } else {
480                            handshake
481                                .traffic_hash
482                                .replace(key_schedule.transcript_hash().clone());
483                            State::ClientFinished
484                        };
485                    }
486                    _ => return Err(TlsError::InvalidHandshake),
487                }
488            }
489            ServerRecord::ChangeCipherSpec(_) => {}
490            _ => return Err(TlsError::InvalidRecord),
491        }
492
493        Ok(())
494    })?;
495    Ok(state)
496}
497
498fn client_cert<'r, CipherSuite, Verifier>(
499    handshake: &mut Handshake<CipherSuite, Verifier>,
500    key_schedule: &mut KeySchedule<CipherSuite>,
501    config: &TlsConfig<CipherSuite>,
502    buffer: &'r mut WriteBuffer,
503) -> Result<(State, &'r [u8]), TlsError>
504where
505    CipherSuite: TlsCipherSuite,
506{
507    handshake
508        .traffic_hash
509        .replace(key_schedule.transcript_hash().clone());
510
511    let request_context = &handshake
512        .certificate_request
513        .as_ref()
514        .ok_or(TlsError::InvalidHandshake)?
515        .request_context;
516
517    let mut certificate = CertificateRef::with_context(request_context);
518    if let Some(cert) = &config.cert {
519        certificate.add(cert.into())?;
520    }
521    let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
522
523    buffer
524        .write_record(
525            &ClientRecord::Handshake(ClientHandshake::ClientCert(certificate), true),
526            write_key_schedule,
527            Some(read_key_schedule),
528        )
529        .map(|slice| (State::ClientFinished, slice))
530}
531
532fn client_finished<'r, CipherSuite>(
533    key_schedule: &mut KeySchedule<CipherSuite>,
534    buffer: &'r mut WriteBuffer,
535) -> Result<&'r [u8], TlsError>
536where
537    CipherSuite: TlsCipherSuite,
538{
539    let client_finished = key_schedule
540        .create_client_finished()
541        .map_err(|_| TlsError::InvalidHandshake)?;
542
543    let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
544
545    buffer.write_record(
546        &ClientRecord::Handshake(ClientHandshake::Finished(client_finished), true),
547        write_key_schedule,
548        Some(read_key_schedule),
549    )
550}
551
552fn client_finished_finalize<CipherSuite, Verifier>(
553    key_schedule: &mut KeySchedule<CipherSuite>,
554    handshake: &mut Handshake<CipherSuite, Verifier>,
555) -> Result<State, TlsError>
556where
557    CipherSuite: TlsCipherSuite,
558{
559    key_schedule.replace_transcript_hash(
560        handshake
561            .traffic_hash
562            .take()
563            .ok_or(TlsError::InvalidHandshake)?,
564    );
565    key_schedule.initialize_master_secret()?;
566
567    Ok(State::ApplicationData)
568}