1use crate::config::{TlsCipherSuite, TlsConfig, TlsVerifier};
2use crate::handshake::{ClientHandshake, ServerHandshake};
3use crate::key_schedule::{KeySchedule, ReadKeySchedule, WriteKeySchedule};
4use crate::record::{ClientRecord, ServerRecord};
5use crate::record_reader::RecordReader;
6use crate::write_buffer::WriteBuffer;
7use crate::TlsError;
8use crate::{
9 alert::*,
10 handshake::{certificate::CertificateRef, certificate_request::CertificateRequest},
11};
12use core::fmt::Debug;
13use embedded_io::Error as _;
14use embedded_io::{Read as BlockingRead, Write as BlockingWrite};
15use embedded_io_async::{Read as AsyncRead, Write as AsyncWrite};
16use rand_core::{CryptoRng, RngCore};
17
18use crate::application_data::ApplicationData;
19use crate::buffer::CryptoBuffer;
26use digest::generic_array::typenum::Unsigned;
27use p256::ecdh::EphemeralSecret;
28
29use crate::content_types::ContentType;
30use crate::parse_buffer::ParseBuffer;
37use aes_gcm::aead::{AeadCore, AeadInPlace, KeyInit};
38
39pub(crate) fn decrypt_record<CipherSuite>(
40 key_schedule: &mut ReadKeySchedule<CipherSuite>,
41 record: ServerRecord<'_, CipherSuite>,
42 mut cb: impl FnMut(
43 &mut ReadKeySchedule<CipherSuite>,
44 ServerRecord<'_, CipherSuite>,
45 ) -> Result<(), TlsError>,
46) -> Result<(), TlsError>
47where
48 CipherSuite: TlsCipherSuite,
49{
50 if let ServerRecord::ApplicationData(ApplicationData {
51 header,
52 data: mut app_data,
53 }) = record
54 {
55 let server_key = key_schedule.get_key()?;
56 let nonce = key_schedule.get_nonce()?;
57
58 let crypto = <CipherSuite::Cipher as KeyInit>::new(&server_key);
59 crypto
60 .decrypt_in_place(&nonce, header.data(), &mut app_data)
61 .map_err(|_| TlsError::CryptoError)?;
62
63 let padding = app_data
64 .as_slice()
65 .iter()
66 .enumerate()
67 .rfind(|(_, b)| **b != 0);
68 if let Some((index, _)) = padding {
69 app_data.truncate(index + 1);
70 };
71
72 let content_type =
73 ContentType::of(*app_data.as_slice().last().unwrap()).ok_or(TlsError::InvalidRecord)?;
74
75 trace!("Decrypting: content type = {:?}", content_type);
76
77 app_data.truncate(app_data.len() - 1);
79
80 let mut buf = ParseBuffer::new(app_data.as_slice());
81 match content_type {
82 ContentType::Handshake => {
83 while buf.remaining() > 0 {
85 let inner = ServerHandshake::read(&mut buf, key_schedule.transcript_hash())?;
86 cb(key_schedule, ServerRecord::Handshake(inner))?;
87 }
88 }
89 ContentType::ApplicationData => {
90 let inner = ApplicationData::new(app_data, header);
91 cb(key_schedule, ServerRecord::ApplicationData(inner))?;
92 }
93 ContentType::Alert => {
94 let alert = Alert::parse(&mut buf)?;
95 cb(key_schedule, ServerRecord::Alert(alert))?;
96 }
97 _ => return Err(TlsError::Unimplemented),
98 }
99 key_schedule.increment_counter();
100 } else {
101 trace!("Not decrypting: content_type = {:?}", record.content_type());
102 cb(key_schedule, record)?;
103 }
104 Ok(())
105}
106
107pub(crate) fn encrypt<CipherSuite>(
108 key_schedule: &WriteKeySchedule<CipherSuite>,
109 buf: &mut CryptoBuffer<'_>,
110) -> Result<(), TlsError>
111where
112 CipherSuite: TlsCipherSuite,
113{
114 let client_key = key_schedule.get_key()?;
115 let nonce = key_schedule.get_nonce()?;
116 let crypto = <CipherSuite::Cipher as KeyInit>::new(&client_key);
121 let len = buf.len() + <CipherSuite::Cipher as AeadCore>::TagSize::to_usize();
122
123 if len > buf.capacity() {
124 return Err(TlsError::InsufficientSpace);
125 }
126
127 trace!("output size {}", len);
128 let len_bytes = (len as u16).to_be_bytes();
129 let additional_data = [
130 ContentType::ApplicationData as u8,
131 0x03,
132 0x03,
133 len_bytes[0],
134 len_bytes[1],
135 ];
136
137 crypto
138 .encrypt_in_place(&nonce, &additional_data, buf)
139 .map_err(|_| TlsError::InvalidApplicationData)
140}
141
142pub struct Handshake<CipherSuite, Verifier>
143where
144 CipherSuite: TlsCipherSuite,
145{
146 traffic_hash: Option<CipherSuite::Hash>,
147 secret: Option<EphemeralSecret>,
148 certificate_request: Option<CertificateRequest>,
149 verifier: Verifier,
150}
151
152impl<'v, CipherSuite, Verifier> Handshake<CipherSuite, Verifier>
153where
154 CipherSuite: TlsCipherSuite,
155 Verifier: TlsVerifier<'v, CipherSuite>,
156{
157 pub fn new(verifier: Verifier) -> Handshake<CipherSuite, Verifier> {
158 Handshake {
159 traffic_hash: None,
160 secret: None,
161 certificate_request: None,
162 verifier,
163 }
164 }
165}
166
/// Steps of the client-side handshake state machine.
///
/// `ClientCert` is entered only when the server sent a CertificateRequest. Otherwise
/// `ServerVerify` goes straight to `ClientFinished`. `ApplicationData` is terminal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum State {
    /// Send the ClientHello.
    ClientHello,
    /// Read and process the ServerHello.
    ServerHello,
    /// Process the encrypted server flight, up to and including its Finished message.
    ServerVerify,
    /// Send the client Certificate the server asked for.
    ClientCert,
    /// Send the client Finished message and derive the master secret.
    ClientFinished,
    /// Handshake complete.
    ApplicationData,
}
177
178impl<'a> State {
179 #[allow(clippy::too_many_arguments)]
180 pub async fn process<'v, Transport, CipherSuite, RNG, Verifier>(
181 self,
182 transport: &mut Transport,
183 handshake: &mut Handshake<CipherSuite, Verifier>,
184 record_reader: &mut RecordReader<'_, CipherSuite>,
185 tx_buf: &mut WriteBuffer<'_>,
186 key_schedule: &mut KeySchedule<CipherSuite>,
187 config: &TlsConfig<'a, CipherSuite>,
188 rng: &mut RNG,
189 ) -> Result<State, TlsError>
190 where
191 Transport: AsyncRead + AsyncWrite + 'a,
192 RNG: CryptoRng + RngCore + 'a,
193 CipherSuite: TlsCipherSuite,
194 Verifier: TlsVerifier<'v, CipherSuite>,
195 {
196 match self {
197 State::ClientHello => {
198 let (state, tx) = client_hello(key_schedule, config, rng, tx_buf, handshake)?;
199
200 respond(tx, transport, key_schedule).await?;
201
202 Ok(state)
203 }
204 State::ServerHello => {
205 let record = record_reader
206 .read(transport, key_schedule.read_state())
207 .await?;
208
209 let result = process_server_hello(handshake, key_schedule, record);
210
211 handle_processing_error(result, transport, key_schedule, tx_buf).await
212 }
213 State::ServerVerify => {
214 let record = record_reader
215 .read(transport, key_schedule.read_state())
216 .await?;
217
218 let result = process_server_verify(handshake, key_schedule, config, record);
219
220 handle_processing_error(result, transport, key_schedule, tx_buf).await
221 }
222 State::ClientCert => {
223 let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?;
224
225 respond(tx, transport, key_schedule).await?;
226
227 Ok(state)
228 }
229 State::ClientFinished => {
230 let tx = client_finished(key_schedule, tx_buf)?;
231
232 respond(tx, transport, key_schedule).await?;
233
234 client_finished_finalize(key_schedule, handshake)
235 }
236 State::ApplicationData => Ok(State::ApplicationData),
237 }
238 }
239
240 #[allow(clippy::too_many_arguments)]
241 pub fn process_blocking<'v, Transport, CipherSuite, RNG, Verifier>(
242 self,
243 transport: &mut Transport,
244 handshake: &mut Handshake<CipherSuite, Verifier>,
245 record_reader: &mut RecordReader<'_, CipherSuite>,
246 tx_buf: &mut WriteBuffer,
247 key_schedule: &mut KeySchedule<CipherSuite>,
248 config: &TlsConfig<'a, CipherSuite>,
249 rng: &mut RNG,
250 ) -> Result<State, TlsError>
251 where
252 Transport: BlockingRead + BlockingWrite + 'a,
253 RNG: CryptoRng + RngCore,
254 CipherSuite: TlsCipherSuite + 'static,
255 Verifier: TlsVerifier<'v, CipherSuite>,
256 {
257 match self {
258 State::ClientHello => {
259 let (state, tx) = client_hello(key_schedule, config, rng, tx_buf, handshake)?;
260
261 respond_blocking(tx, transport, key_schedule)?;
262
263 Ok(state)
264 }
265 State::ServerHello => {
266 let record = record_reader.read_blocking(transport, key_schedule.read_state())?;
267
268 let result = process_server_hello(handshake, key_schedule, record);
269
270 handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
271 }
272 State::ServerVerify => {
273 let record = record_reader.read_blocking(transport, key_schedule.read_state())?;
274
275 let result = process_server_verify(handshake, key_schedule, config, record);
276
277 handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
278 }
279 State::ClientCert => {
280 let (state, tx) = client_cert(handshake, key_schedule, config, tx_buf)?;
281
282 respond_blocking(tx, transport, key_schedule)?;
283
284 Ok(state)
285 }
286 State::ClientFinished => {
287 let tx = client_finished(key_schedule, tx_buf)?;
288
289 respond_blocking(tx, transport, key_schedule)?;
290
291 client_finished_finalize(key_schedule, handshake)
292 }
293 State::ApplicationData => Ok(State::ApplicationData),
294 }
295 }
296}
297
298fn handle_processing_error_blocking<CipherSuite>(
299 result: Result<State, TlsError>,
300 transport: &mut impl BlockingWrite,
301 key_schedule: &mut KeySchedule<CipherSuite>,
302 tx_buf: &mut WriteBuffer,
303) -> Result<State, TlsError>
304where
305 CipherSuite: TlsCipherSuite,
306{
307 if let Err(TlsError::AbortHandshake(level, description)) = result {
308 let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
309 let tx = tx_buf.write_record(
310 &ClientRecord::Alert(Alert { level, description }, false),
311 write_key_schedule,
312 Some(read_key_schedule),
313 )?;
314
315 respond_blocking(tx, transport, key_schedule)?;
316 }
317
318 result
319}
320
321fn respond_blocking<CipherSuite>(
322 tx: &[u8],
323 transport: &mut impl BlockingWrite,
324 key_schedule: &mut KeySchedule<CipherSuite>,
325) -> Result<(), TlsError>
326where
327 CipherSuite: TlsCipherSuite,
328{
329 transport
330 .write_all(tx)
331 .map_err(|e| TlsError::Io(e.kind()))?;
332
333 key_schedule.write_state().increment_counter();
334
335 transport.flush().map_err(|e| TlsError::Io(e.kind()))?;
336
337 Ok(())
338}
339
340async fn handle_processing_error<'a, CipherSuite>(
341 result: Result<State, TlsError>,
342 transport: &mut impl AsyncWrite,
343 key_schedule: &mut KeySchedule<CipherSuite>,
344 tx_buf: &mut WriteBuffer<'a>,
345) -> Result<State, TlsError>
346where
347 CipherSuite: TlsCipherSuite,
348{
349 if let Err(TlsError::AbortHandshake(level, description)) = result {
350 let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
351 let tx = tx_buf.write_record(
352 &ClientRecord::Alert(Alert { level, description }, false),
353 write_key_schedule,
354 Some(read_key_schedule),
355 )?;
356
357 respond(tx, transport, key_schedule).await?;
358 }
359
360 result
361}
362
363async fn respond<CipherSuite>(
364 tx: &[u8],
365 transport: &mut impl AsyncWrite,
366 key_schedule: &mut KeySchedule<CipherSuite>,
367) -> Result<(), TlsError>
368where
369 CipherSuite: TlsCipherSuite,
370{
371 transport
372 .write_all(tx)
373 .await
374 .map_err(|e| TlsError::Io(e.kind()))?;
375
376 key_schedule.write_state().increment_counter();
377
378 transport
379 .flush()
380 .await
381 .map_err(|e| TlsError::Io(e.kind()))?;
382
383 Ok(())
384}
385
386fn client_hello<'r, CipherSuite, RNG, Verifier>(
387 key_schedule: &mut KeySchedule<CipherSuite>,
388 config: &TlsConfig<CipherSuite>,
389 rng: &mut RNG,
390 tx_buf: &'r mut WriteBuffer,
391 handshake: &mut Handshake<CipherSuite, Verifier>,
392) -> Result<(State, &'r [u8]), TlsError>
393where
394 RNG: CryptoRng + RngCore,
395 CipherSuite: TlsCipherSuite,
396{
397 key_schedule.initialize_early_secret(config.psk.as_ref().map(|p| p.0))?;
398 let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
399 let client_hello = ClientRecord::client_hello(config, rng);
400 let slice = tx_buf.write_record(&client_hello, write_key_schedule, Some(read_key_schedule))?;
401
402 if let ClientRecord::Handshake(ClientHandshake::ClientHello(client_hello), _) = client_hello {
403 handshake.secret.replace(client_hello.secret);
404 Ok((State::ServerHello, slice))
405 } else {
406 Err(TlsError::EncodeError)
407 }
408}
409
410fn process_server_hello<CipherSuite, Verifier>(
411 handshake: &mut Handshake<CipherSuite, Verifier>,
412 key_schedule: &mut KeySchedule<CipherSuite>,
413 record: ServerRecord<'_, CipherSuite>,
414) -> Result<State, TlsError>
415where
416 CipherSuite: TlsCipherSuite,
417{
418 match record {
419 ServerRecord::Handshake(server_handshake) => match server_handshake {
420 ServerHandshake::ServerHello(server_hello) => {
421 trace!("********* ServerHello");
422 let secret = handshake.secret.take().ok_or(TlsError::InvalidHandshake)?;
423 let shared = server_hello
424 .calculate_shared_secret(&secret)
425 .ok_or(TlsError::InvalidKeyShare)?;
426 key_schedule.initialize_handshake_secret(shared.raw_secret_bytes())?;
427 Ok(State::ServerVerify)
428 }
429 _ => Err(TlsError::InvalidHandshake),
430 },
431 ServerRecord::Alert(alert) => {
432 Err(TlsError::HandshakeAborted(alert.level, alert.description))
433 }
434 _ => Err(TlsError::InvalidRecord),
435 }
436}
437
438fn process_server_verify<'a, 'v, CipherSuite, Verifier>(
439 handshake: &mut Handshake<CipherSuite, Verifier>,
440 key_schedule: &mut KeySchedule<CipherSuite>,
441 config: &TlsConfig<'a, CipherSuite>,
442 record: ServerRecord<'_, CipherSuite>,
443) -> Result<State, TlsError>
444where
445 CipherSuite: TlsCipherSuite,
446 Verifier: TlsVerifier<'v, CipherSuite>,
447{
448 let mut state = State::ServerVerify;
449 decrypt_record(key_schedule.read_state(), record, |key_schedule, record| {
450 match record {
451 ServerRecord::Handshake(server_handshake) => {
452 match server_handshake {
453 ServerHandshake::EncryptedExtensions(_) => {}
454 ServerHandshake::Certificate(certificate) => {
455 let transcript = key_schedule.transcript_hash();
456 handshake.verifier.verify_certificate(
457 transcript,
458 &config.ca,
459 certificate,
460 )?;
461 debug!("Certificate verified!");
462 }
463 ServerHandshake::CertificateVerify(verify) => {
464 handshake.verifier.verify_signature(verify)?;
465 debug!("Signature verified!");
466 }
467 ServerHandshake::CertificateRequest(request) => {
468 handshake.certificate_request.replace(request.try_into()?);
469 }
470 ServerHandshake::Finished(finished) => {
471 if !key_schedule.verify_server_finished(&finished)? {
472 warn!("Server signature verification failed");
473 return Err(TlsError::InvalidSignature);
474 }
475
476 state = if handshake.certificate_request.is_some() {
478 State::ClientCert
479 } else {
480 handshake
481 .traffic_hash
482 .replace(key_schedule.transcript_hash().clone());
483 State::ClientFinished
484 };
485 }
486 _ => return Err(TlsError::InvalidHandshake),
487 }
488 }
489 ServerRecord::ChangeCipherSpec(_) => {}
490 _ => return Err(TlsError::InvalidRecord),
491 }
492
493 Ok(())
494 })?;
495 Ok(state)
496}
497
498fn client_cert<'r, CipherSuite, Verifier>(
499 handshake: &mut Handshake<CipherSuite, Verifier>,
500 key_schedule: &mut KeySchedule<CipherSuite>,
501 config: &TlsConfig<CipherSuite>,
502 buffer: &'r mut WriteBuffer,
503) -> Result<(State, &'r [u8]), TlsError>
504where
505 CipherSuite: TlsCipherSuite,
506{
507 handshake
508 .traffic_hash
509 .replace(key_schedule.transcript_hash().clone());
510
511 let request_context = &handshake
512 .certificate_request
513 .as_ref()
514 .ok_or(TlsError::InvalidHandshake)?
515 .request_context;
516
517 let mut certificate = CertificateRef::with_context(request_context);
518 if let Some(cert) = &config.cert {
519 certificate.add(cert.into())?;
520 }
521 let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
522
523 buffer
524 .write_record(
525 &ClientRecord::Handshake(ClientHandshake::ClientCert(certificate), true),
526 write_key_schedule,
527 Some(read_key_schedule),
528 )
529 .map(|slice| (State::ClientFinished, slice))
530}
531
532fn client_finished<'r, CipherSuite>(
533 key_schedule: &mut KeySchedule<CipherSuite>,
534 buffer: &'r mut WriteBuffer,
535) -> Result<&'r [u8], TlsError>
536where
537 CipherSuite: TlsCipherSuite,
538{
539 let client_finished = key_schedule
540 .create_client_finished()
541 .map_err(|_| TlsError::InvalidHandshake)?;
542
543 let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
544
545 buffer.write_record(
546 &ClientRecord::Handshake(ClientHandshake::Finished(client_finished), true),
547 write_key_schedule,
548 Some(read_key_schedule),
549 )
550}
551
552fn client_finished_finalize<CipherSuite, Verifier>(
553 key_schedule: &mut KeySchedule<CipherSuite>,
554 handshake: &mut Handshake<CipherSuite, Verifier>,
555) -> Result<State, TlsError>
556where
557 CipherSuite: TlsCipherSuite,
558{
559 key_schedule.replace_transcript_hash(
560 handshake
561 .traffic_hash
562 .take()
563 .ok_or(TlsError::InvalidHandshake)?,
564 );
565 key_schedule.initialize_master_secret()?;
566
567 Ok(State::ApplicationData)
568}