1use core::ffi::c_void;
10use core::{cmp, ffi::CStr};
11
12extern crate alloc;
13use alloc::ffi::CString;
14use alloc::format;
15use alloc::string::ToString;
16use alloc::vec;
17use alloc::vec::Vec;
18
19use crate::{
20 ErrorKind, OrtResult, Read, Write, common::utils::to_ascii, libc, ort_error, ort_from_err,
21};
22
23mod aead;
24mod ecdh;
25mod hkdf;
26mod hmac;
27mod sha2;
28
/// Record content type: change_cipher_spec (RFC 8446 §5.1).
const REC_TYPE_CHANGE_CIPHER_SPEC: u8 = 20;
/// Record content type: alert.
const REC_TYPE_ALERT: u8 = 21;
/// Record content type: handshake.
const REC_TYPE_HANDSHAKE: u8 = 22;
/// Record content type: application_data (also the outer type of every encrypted record).
const REC_TYPE_APPDATA: u8 = 23;
/// legacy_record_version written into every record header.
const LEGACY_REC_VER: u16 = 0x0303;

/// Handshake message type: ClientHello.
const HS_CLIENT_HELLO: u8 = 1;
/// Handshake message type: ServerHello.
const HS_SERVER_HELLO: u8 = 2;
/// Handshake message type: Finished.
const HS_FINISHED: u8 = 20;

/// TLS_AES_128_GCM_SHA256, the only cipher suite this client offers.
const CIPHER_TLS_AES_128_GCM_SHA256: u16 = 0x1301;
/// TLS 1.3 as carried in the supported_versions extension.
const TLS13: u16 = 0x0304;
/// Named group X25519, the only key-exchange group offered.
const GROUP_X25519: u16 = 0x001d;

/// Extension type: server_name (SNI).
const EXT_SERVER_NAME: u16 = 0x0000;
/// Extension type: supported_groups.
const EXT_SUPPORTED_GROUPS: u16 = 0x000a;
/// Extension type: signature_algorithms.
const EXT_SIGNATURE_ALGS: u16 = 0x000d;
/// Extension type: supported_versions.
const EXT_SUPPORTED_VERSIONS: u16 = 0x002b;
/// Extension type: key_share.
const EXT_KEY_SHARE: u16 = 0x0033;

/// Length in bytes of the AES-GCM authentication tag appended to each record.
const AEAD_TAG_LEN: usize = 16;
63
/// Appends `v` to `buf` as two big-endian (network byte order) bytes.
fn put_u16(buf: &mut Vec<u8>, v: u16) {
    let [hi, lo] = v.to_be_bytes();
    buf.push(hi);
    buf.push(lo);
}
/// Appends the low 24 bits of `v` to `buf` in big-endian order (the length
/// encoding of handshake message headers). Bits above bit 23 are discarded.
fn put_u24(buf: &mut Vec<u8>, v: usize) {
    let wide = (v as u32).to_be_bytes();
    // Drop the most significant byte of the 32-bit encoding.
    buf.extend_from_slice(&wide[1..]);
}
72
73fn hkdf_expand_label<const N: usize>(prk: &[u8], label: &str, data: &[u8]) -> [u8; N] {
74 let mut info = Vec::with_capacity(2 + 1 + 6 + label.len() + 1 + data.len());
75 put_u16(&mut info, N as u16);
76 let full_label = "tls13 ".to_string() + label;
77 info.push(full_label.len() as u8);
78 info.extend_from_slice(full_label.as_bytes());
79 info.push(data.len() as u8);
80 info.extend_from_slice(data);
81
82 hkdf::hkdf_expand(prk, &info, N).try_into().unwrap()
83}
84
85fn digest_bytes(data: &[u8]) -> [u8; 32] {
86 let d = sha2::sha256(data);
87 let mut out = [0u8; 32];
88 out.copy_from_slice(d.as_ref());
89 out
90}
91
/// Builds the per-record AEAD nonce (RFC 8446 §5.3): the 64-bit sequence
/// number, left-padded with zeros to 12 bytes, XORed with the static IV.
fn nonce_xor(iv12: &[u8; 12], seq: u64) -> [u8; 12] {
    let mut nonce = *iv12;
    // Only the last 8 bytes overlap the sequence number; the first 4 XOR with zero.
    for (byte, seq_byte) in nonce[4..].iter_mut().zip(seq.to_be_bytes()) {
        *byte ^= seq_byte;
    }
    nonce
}
102
103pub struct TlsStream<T: Read + Write> {
105 io: T,
106 aead_enc: [u8; 16],
108 aead_dec: [u8; 16],
109 iv_enc: [u8; 12],
110 iv_dec: [u8; 12],
111 seq_enc: u64,
112 seq_dec: u64,
113 rbuf: Vec<u8>,
115 rpos: usize,
116}
117
118fn client_hello_body(sni_host: &str, client_pub: &[u8]) -> Vec<u8> {
119 let mut ch_body = Vec::with_capacity(512);
120
121 let mut random = [0u8; 32];
123 let got_bytes = unsafe { libc::getrandom(random.as_mut_ptr() as *mut c_void, 32, 0) };
124 debug_assert_eq!(got_bytes, 32);
125
126 let mut session_id = [0u8; 32];
127 let got_bytes = unsafe { libc::getrandom(session_id.as_mut_ptr() as *mut c_void, 32, 0) };
128 debug_assert_eq!(got_bytes, 32);
129
130 ch_body.extend_from_slice(&0x0303u16.to_be_bytes());
132 ch_body.extend_from_slice(&random);
134 ch_body.push(session_id.len() as u8);
136 ch_body.extend_from_slice(&session_id);
137 put_u16(&mut ch_body, 2);
139 put_u16(&mut ch_body, CIPHER_TLS_AES_128_GCM_SHA256);
140 ch_body.push(1);
142 ch_body.push(0);
143
144 let mut exts = Vec::with_capacity(512);
146
147 {
149 let host_bytes = sni_host.as_bytes();
150 let mut snl = Vec::with_capacity(3 + host_bytes.len());
151 snl.push(0); put_u16(&mut snl, host_bytes.len() as u16);
153 snl.extend_from_slice(host_bytes);
154
155 let mut sni = Vec::with_capacity(2 + snl.len());
156 put_u16(&mut sni, snl.len() as u16);
157 sni.extend_from_slice(&snl);
158
159 put_u16(&mut exts, EXT_SERVER_NAME);
160 put_u16(&mut exts, sni.len() as u16);
161 exts.extend_from_slice(&sni);
162 }
163
164 {
166 let mut sv = Vec::with_capacity(3);
167 sv.push(2); sv.extend_from_slice(&TLS13.to_be_bytes());
169 put_u16(&mut exts, EXT_SUPPORTED_VERSIONS);
170 put_u16(&mut exts, sv.len() as u16);
171 exts.extend_from_slice(&sv);
172 }
173
174 {
176 let mut sg = Vec::with_capacity(2 + 2);
177 put_u16(&mut sg, 2);
178 put_u16(&mut sg, GROUP_X25519);
179 put_u16(&mut exts, EXT_SUPPORTED_GROUPS);
180 put_u16(&mut exts, sg.len() as u16);
181 exts.extend_from_slice(&sg);
182 }
183
184 {
186 const ECDSA_SECP256R1_SHA256: u16 = 0x0403;
187 const RSA_PSS_RSAE_SHA256: u16 = 0x0804;
188 const RSA_PKCS1_SHA256: u16 = 0x0401;
189
190 let mut sa = Vec::with_capacity(2 + 6);
191 put_u16(&mut sa, 6);
192 put_u16(&mut sa, ECDSA_SECP256R1_SHA256);
193 put_u16(&mut sa, RSA_PSS_RSAE_SHA256);
194 put_u16(&mut sa, RSA_PKCS1_SHA256);
195
196 put_u16(&mut exts, EXT_SIGNATURE_ALGS);
197 put_u16(&mut exts, sa.len() as u16);
198 exts.extend_from_slice(&sa);
199 }
200
201 {
203 let mut ks = Vec::with_capacity(2 + 2 + 2 + 32);
204 let mut entry = Vec::with_capacity(2 + 2 + 32);
206 put_u16(&mut entry, GROUP_X25519);
207 put_u16(&mut entry, 32);
208 entry.extend_from_slice(client_pub);
209 put_u16(&mut ks, entry.len() as u16);
210 ks.extend_from_slice(&entry);
211
212 put_u16(&mut exts, EXT_KEY_SHARE);
213 put_u16(&mut exts, ks.len() as u16);
214 exts.extend_from_slice(&ks);
215 }
216
217 put_u16(&mut ch_body, exts.len() as u16);
219 ch_body.extend_from_slice(&exts);
220
221 ch_body
222}
223
224fn client_hello_msg(sni_host: &str, client_private_key: &[u8]) -> OrtResult<Vec<u8>> {
226 let client_pub_key = ecdh::x25519_public_key(client_private_key);
227 let client_pub_ref = &client_pub_key;
228 debug_print("Client public key", client_pub_ref);
229
230 let ch_body = client_hello_body(sni_host, client_pub_ref);
231
232 let mut ch_msg = Vec::with_capacity(4 + ch_body.len());
234 ch_msg.push(HS_CLIENT_HELLO);
235 put_u24(&mut ch_msg, ch_body.len());
236 ch_msg.extend_from_slice(&ch_body);
237
238 Ok(ch_msg)
239}
240
241fn read_server_hello<R: Read>(io: &mut R) -> OrtResult<(Vec<u8>, Vec<u8>)> {
243 let (typ, payload) = read_record_plain(io)
244 .map_err(|e| ort_from_err(ErrorKind::TlsRecordTooShort, "read_record_plain", e))?;
245 if typ != REC_TYPE_HANDSHAKE {
246 return Err(ort_error(ErrorKind::TlsExpectedHandshakeRecord, ""));
247 }
248 let sh_buf = payload;
249
250 let mut rd = &sh_buf[..];
252 let (sh_typ, sh_body, sh_full) = read_handshake_message(&mut rd).map_err(|e| {
253 ort_from_err(
254 ErrorKind::TlsHandshakeHeaderTooShort,
255 "read_handshake_message",
256 e,
257 )
258 })?;
259 if sh_typ != HS_SERVER_HELLO {
260 return Err(ort_error(ErrorKind::TlsExpectedServerHello, ""));
261 }
262
263 Ok((sh_body.to_vec(), sh_full.to_vec()))
265}
266
/// Secrets and traffic keys derived once the ServerHello has been processed.
struct HandshakeState {
    /// Handshake secret; input to the master-secret derivation.
    handshake_secret: [u8; 32],
    /// Client handshake traffic secret ("c hs traffic").
    client_hs_ts: [u8; 32],
    /// Server handshake traffic secret ("s hs traffic").
    server_hs_ts: [u8; 32],
    /// IV for handshake records the client encrypts.
    client_handshake_iv: [u8; 12],
    /// IV for handshake records the server encrypts.
    server_handshake_iv: [u8; 12],
    /// AES-128-GCM key for client-to-server handshake records.
    aead_enc_hs: [u8; 16],
    /// AES-128-GCM key for server-to-client handshake records.
    aead_dec_hs: [u8; 16],
    /// SHA-256 of the empty string, used as context for the "derived" labels.
    empty_hash: [u8; 32],
}
277
/// AES-128-GCM keys and IVs for application data, one pair per direction.
struct ApplicationKeys {
    /// Client-to-server application key.
    aead_app_enc: [u8; 16],
    /// Server-to-client application key.
    aead_app_dec: [u8; 16],
    /// Client-to-server application IV.
    iv_enc: [u8; 12],
    /// Server-to-client application IV.
    iv_dec: [u8; 12],
}
284
285impl<T: Read + Write> TlsStream<T> {
286 pub fn connect(mut io: T, sni_host: &str) -> OrtResult<Self> {
287 let mut transcript = Vec::with_capacity(1024);
289
290 let mut client_private_key = [0u8; 32];
292 let _ = unsafe { libc::getrandom(client_private_key.as_mut_ptr() as *mut c_void, 32, 0) };
293 debug_print("Client private key", &client_private_key);
294
295 Self::send_client_hello(&mut io, sni_host, &mut transcript, &client_private_key)?;
296
297 let sh_body = Self::receive_server_hello(&mut io, &mut transcript)?;
298
299 let handshake = Self::derive_handshake_keys(&client_private_key, &sh_body, &transcript)?;
300
301 Self::receive_dummy_change_cipher_spec(&mut io)?;
302
303 let mut seq_dec_hs = 0u64;
304 let mut seq_enc_hs = 0u64;
305
306 Self::receive_server_encrypted_flight(
307 &mut io,
308 &mut seq_dec_hs,
309 &handshake,
310 &mut transcript,
311 )?;
312
313 let ApplicationKeys {
314 aead_app_enc,
315 aead_app_dec,
316 iv_enc: caiv,
317 iv_dec: saiv,
318 } = Self::derive_application_keys(
319 &handshake.handshake_secret,
320 &handshake.empty_hash,
321 &transcript,
322 );
323
324 let seq_app_enc = 0u64;
325 let seq_app_dec = 0u64;
326
327 Self::send_client_finished(&mut io, &handshake, &mut transcript, &mut seq_enc_hs)?;
332
333 Ok(TlsStream {
334 io,
335 aead_enc: aead_app_enc,
336 aead_dec: aead_app_dec,
337 iv_enc: caiv,
338 iv_dec: saiv,
339 seq_enc: seq_app_enc,
340 seq_dec: seq_app_dec,
341 rbuf: Vec::with_capacity(16 * 1024),
342 rpos: 0,
343 })
344 }
345
346 fn send_client_hello<W: Write>(
347 io: &mut W,
348 sni_host: &str,
349 transcript: &mut Vec<u8>,
350 client_private_key: &[u8; 32],
351 ) -> OrtResult<()> {
352 let ch_msg = client_hello_msg(sni_host, client_private_key)?;
353 write_record_plain(io, REC_TYPE_HANDSHAKE, &ch_msg)
354 .map_err(|e| ort_from_err(ErrorKind::SocketWriteFailed, "write ClientHello", e))?;
355 transcript.extend_from_slice(&ch_msg);
356 Ok(())
357 }
358
359 fn receive_server_hello<R: Read>(io: &mut R, transcript: &mut Vec<u8>) -> OrtResult<Vec<u8>> {
360 let (sh_body, sh_full) = read_server_hello(io)?;
361 transcript.extend_from_slice(&sh_full);
362 Ok(sh_body)
363 }
364
365 fn receive_dummy_change_cipher_spec<R: Read>(io: &mut R) -> OrtResult<()> {
366 let (typ, _) = read_record_plain(io)
368 .map_err(|e| ort_from_err(ErrorKind::TlsRecordTooShort, "read_record_plain", e))?;
369 if typ != REC_TYPE_CHANGE_CIPHER_SPEC {
370 return Err(ort_error(ErrorKind::TlsExpectedChangeCipherSpec, ""));
371 }
372 Ok(())
373 }
374
375 fn receive_server_encrypted_flight<R: Read>(
376 io: &mut R,
377 seq_dec_hs: &mut u64,
378 handshake: &HandshakeState,
379 transcript: &mut Vec<u8>,
380 ) -> OrtResult<()> {
381 let (typ, ct, _inner_type) = read_record_cipher(
382 io,
383 &handshake.aead_dec_hs,
384 &handshake.server_handshake_iv,
385 seq_dec_hs,
386 )
387 .map_err(|e| ort_from_err(ErrorKind::TlsRecordTooShort, "read_record_cipher", e))?;
388 if typ != REC_TYPE_APPDATA {
389 return Err(ort_error(ErrorKind::TlsExpectedEncryptedRecords, ""));
390 }
391
392 let mut p = &ct[..];
395 while !p.is_empty() {
396 let (mtyp, body, full) = match read_handshake_message(&mut p) {
399 Ok(x) => x,
400 Err(_) => {
401 return Err(ort_error(ErrorKind::TlsBadHandshakeFragment, ""));
402 }
403 };
404 transcript.extend_from_slice(full);
405
406 if mtyp == HS_FINISHED {
407 let s_finished_key =
409 hkdf_expand_label::<32>(&handshake.server_hs_ts, "finished", &[]);
410
411 let thash = digest_bytes(&transcript[..transcript.len() - full.len()]);
412 let expected = hmac::sign(&s_finished_key, &thash);
413 if expected.as_slice() != body {
414 return Err(ort_error(ErrorKind::TlsFinishedVerifyFailed, ""));
415 }
416 break;
418 }
419 }
421 Ok(())
422 }
423
424 fn derive_handshake_keys(
425 client_private_key: &[u8; 32],
426 sh_body: &[u8],
427 transcript: &[u8],
428 ) -> OrtResult<HandshakeState> {
429 let (cipher, server_public_key_bytes) =
431 parse_server_hello_for_keys(sh_body).map_err(|e| {
432 ort_from_err(
433 ErrorKind::TlsServerHelloTooShort,
434 "parse_server_hello_for_keys",
435 e,
436 )
437 })?;
438 debug_print("Server public key", &server_public_key_bytes);
439 if cipher != CIPHER_TLS_AES_128_GCM_SHA256 {
440 return Err(ort_error(
441 ErrorKind::TlsUnsupportedCipher,
442 "server picked unsupported cipher",
443 ));
444 }
445
446 let hs_shared_secret = ecdh::x25519_agreement(client_private_key, &server_public_key_bytes);
448 debug_print("hs shared secret", &hs_shared_secret);
449
450 let empty_hash = digest_bytes(&[]);
452 debug_print("empty_hash", &empty_hash);
453
454 let zero: [u8; 32] = [0u8; 32];
455 let early_secret = hkdf::hkdf_extract(&zero, &zero);
456
457 let derived_secret_bytes = hkdf_expand_label::<32>(&early_secret, "derived", &empty_hash);
458 debug_print("derived", &derived_secret_bytes);
459
460 let handshake_secret = hkdf::hkdf_extract(&derived_secret_bytes, &hs_shared_secret);
461 debug_print("handshake_secret", &handshake_secret);
462
463 let ch_sh_hash = digest_bytes(transcript);
464 debug_print("digest bytes", &ch_sh_hash);
465
466 let c_hs_ts = hkdf_expand_label(&handshake_secret, "c hs traffic", &ch_sh_hash);
467 let s_hs_ts = hkdf_expand_label(&handshake_secret, "s hs traffic", &ch_sh_hash);
468
469 debug_print("c hs traffic", &c_hs_ts);
470 debug_print("s hs traffic", &s_hs_ts);
471
472 let client_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&c_hs_ts, "key", &[])
474 .as_slice()[..16]
475 .try_into()
476 .unwrap();
477 debug_print("client_handshake_key", &client_handshake_key);
478 let client_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&c_hs_ts, "iv", &[]).as_slice()
479 [..12]
480 .try_into()
481 .unwrap();
482 debug_print("client_handshake_iv", &client_handshake_iv);
483
484 let server_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&s_hs_ts, "key", &[])
485 .as_slice()[..16]
486 .try_into()
487 .unwrap();
488 debug_print("server_handshake_key", &server_handshake_key);
489 let server_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&s_hs_ts, "iv", &[]).as_slice()
490 [..12]
491 .try_into()
492 .unwrap();
493 debug_print("server_handshake_iv", &server_handshake_iv);
494
495 Ok(HandshakeState {
496 handshake_secret,
497 client_hs_ts: c_hs_ts,
498 server_hs_ts: s_hs_ts,
499 client_handshake_iv,
500 server_handshake_iv,
501 aead_enc_hs: client_handshake_key,
502 aead_dec_hs: server_handshake_key,
503 empty_hash,
504 })
505 }
506
507 fn derive_application_keys(
508 handshake_secret: &[u8; 32],
509 empty_hash: &[u8; 32],
510 transcript: &[u8],
511 ) -> ApplicationKeys {
512 let derived2_bytes = hkdf_expand_label::<32>(handshake_secret, "derived", empty_hash);
513 debug_print("derived2_bytes", &derived2_bytes);
514
515 let zero: [u8; 32] = [0u8; 32];
516 let master_secret = hkdf::hkdf_extract(&derived2_bytes, &zero);
517 let thash_srv_fin = digest_bytes(transcript);
518
519 let c_ap_ts = hkdf_expand_label::<32>(&master_secret, "c ap traffic", &thash_srv_fin);
520 let s_ap_ts = hkdf_expand_label::<32>(&master_secret, "s ap traffic", &thash_srv_fin);
521 debug_print("c_ap_ts", &c_ap_ts);
522 debug_print("s_ap_ts", &s_ap_ts);
523
524 let cak: [u8; 16] = hkdf_expand_label::<16>(&c_ap_ts, "key", &[]).as_slice()[..16]
525 .try_into()
526 .unwrap();
527 let caiv: [u8; 12] = hkdf_expand_label::<12>(&c_ap_ts, "iv", &[]).as_slice()[..12]
528 .try_into()
529 .unwrap();
530 debug_print("cak", &cak);
531 debug_print("caiv", &caiv);
532
533 let sak: [u8; 16] = hkdf_expand_label::<16>(&s_ap_ts, "key", &[]).as_slice()[..16]
534 .try_into()
535 .unwrap();
536 let saiv: [u8; 12] = hkdf_expand_label::<12>(&s_ap_ts, "iv", &[]).as_slice()[..12]
537 .try_into()
538 .unwrap();
539 debug_print("sak", &sak);
540 debug_print("saiv", &saiv);
541
542 ApplicationKeys {
543 aead_app_enc: cak,
544 aead_app_dec: sak,
545 iv_enc: caiv,
546 iv_dec: saiv,
547 }
548 }
549
550 fn send_client_finished<W: Write>(
551 io: &mut W,
552 handshake: &HandshakeState,
553 transcript: &mut Vec<u8>,
554 seq_enc_hs: &mut u64,
555 ) -> OrtResult<()> {
556 let c_finished_key = hkdf_expand_label::<32>(&handshake.client_hs_ts, "finished", &[]);
557 debug_print("c_finished", &c_finished_key);
558
559 let thash_client_fin = digest_bytes(transcript.as_slice());
560 let verify_data = hmac::sign(&c_finished_key, &thash_client_fin);
561 debug_print("verify_data", &verify_data);
562
563 let mut fin = Vec::with_capacity(4 + verify_data.as_ref().len());
564 fin.push(HS_FINISHED);
565 put_u24(&mut fin, verify_data.as_ref().len());
566 fin.extend_from_slice(verify_data.as_ref());
567
568 transcript.extend_from_slice(&fin);
570
571 write_record_cipher(
572 io,
573 REC_TYPE_HANDSHAKE,
574 &fin,
575 &handshake.aead_enc_hs,
576 &handshake.client_handshake_iv,
577 seq_enc_hs,
578 )
579 .map_err(|e| ort_from_err(ErrorKind::TlsRecordTooShort, "read_record_cipher", e))?;
580
581 Ok(())
582 }
583}
584
585impl<T: Read + Write> Write for TlsStream<T> {
586 fn write(&mut self, buf: &[u8]) -> OrtResult<usize> {
587 write_record_cipher(
588 &mut self.io,
589 REC_TYPE_APPDATA,
590 buf,
591 &self.aead_enc,
592 &self.iv_enc,
593 &mut self.seq_enc,
594 )
595 .map(|_| buf.len())
596 }
597 fn flush(&mut self) -> OrtResult<()> {
598 self.io.flush()
599 }
600}
601
602impl<T: Read + Write> Read for TlsStream<T> {
603 fn read(&mut self, out: &mut [u8]) -> OrtResult<usize> {
604 if self.rpos < self.rbuf.len() {
605 let n = cmp::min(out.len(), self.rbuf.len() - self.rpos);
606 out[..n].copy_from_slice(&self.rbuf[self.rpos..self.rpos + n]);
607 self.rpos += n;
608 if self.rpos == self.rbuf.len() {
609 self.rbuf.clear();
610 self.rpos = 0;
611 }
612 return Ok(n);
613 }
614 loop {
615 let (typ, plaintext, inner_type) = read_record_cipher(
616 &mut self.io,
617 &self.aead_dec,
618 &self.iv_dec,
619 &mut self.seq_dec,
620 )?;
621 if typ != REC_TYPE_APPDATA {
622 continue;
624 }
625 if plaintext.is_empty() {
627 continue;
628 }
629 if inner_type == REC_TYPE_HANDSHAKE {
630 continue;
632 }
633 if inner_type == REC_TYPE_ALERT {
634 let level = match plaintext[0] {
635 1 => "warning",
636 2 => "fatal",
637 _ => "unknown",
638 };
639 let err_level = CString::new(level.to_string() + " alert: ").unwrap();
640
641 let mut err_code_buf: [u8; 5] = [0u8; 5];
644 let len = to_ascii(plaintext[1] as usize, &mut err_code_buf);
645 let err_code = unsafe { CStr::from_bytes_with_nul_unchecked(&err_code_buf[..len]) };
646 unsafe {
647 libc::write(2, err_level.as_ptr().cast(), err_level.count_bytes());
648 libc::write(2, err_code.as_ptr().cast(), err_code.count_bytes());
649 }
650
651 return Err(ort_error(ErrorKind::TlsAlertReceived, ""));
652 }
653 if inner_type != REC_TYPE_APPDATA {
654 }
657 if plaintext.is_empty() {
658 continue;
659 }
660
661 self.rbuf.extend_from_slice(&plaintext);
662 self.rpos = 0;
663 let n = cmp::min(out.len(), self.rbuf.len());
665 out[..n].copy_from_slice(&self.rbuf[..n]);
666 self.rpos = n;
667 if n == self.rbuf.len() {
668 self.rbuf.clear();
669 self.rpos = 0;
670 }
671 return Ok(n);
672 }
673 }
674}
675
676fn write_record_plain<W: Write>(w: &mut W, typ: u8, body: &[u8]) -> OrtResult<()> {
679 let mut hdr = [0u8; 5];
680 hdr[0] = typ;
681 hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
682 hdr[3..5].copy_from_slice(&(body.len() as u16).to_be_bytes());
683 w.write_all(&hdr)?;
684 w.write_all(body)?;
685 Ok(())
686}
687
688fn read_exact_n<R: Read>(r: &mut R, n: usize) -> OrtResult<Vec<u8>> {
689 let mut buf = vec![0u8; n];
690 r.read_exact(&mut buf)?;
691 Ok(buf)
692}
693
694fn read_record_plain<R: Read>(r: &mut R) -> OrtResult<(u8, Vec<u8>)> {
695 let hdr = read_exact_n(r, 5)?; let typ = hdr[0];
697 let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
698 let body = read_exact_n(r, len)?;
699 Ok((typ, body))
701}
702
703fn write_record_cipher<W: Write>(
704 w: &mut W,
705 outer_type: u8,
706 inner: &[u8],
707 key: &[u8; 16],
708 iv12: &[u8; 12],
709 seq: &mut u64,
710) -> OrtResult<()> {
711 let total_len = inner.len() + 1 + AEAD_TAG_LEN;
713 let mut plain = Vec::with_capacity(total_len);
714 plain.extend_from_slice(inner);
715 plain.push(outer_type);
716
717 debug_print("write_record_cipher plaintext", &plain);
718
719 let nonce = nonce_xor(iv12, *seq);
720 *seq = seq.wrapping_add(1);
721
722 let mut hdr = [0u8; 5];
723 hdr[0] = REC_TYPE_APPDATA;
724 hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
725 hdr[3..5].copy_from_slice(&(total_len as u16).to_be_bytes());
726
727 let out = aead::aes_128_gcm_encrypt(key, &nonce, &hdr, &plain).unwrap();
728
729 debug_print("write_record_cipher header", &hdr);
730 let final_label = format!("write_record_cipher final {total_len}");
731 debug_print(final_label.as_str(), &out);
732
733 w.write_all(&hdr)?;
734 w.write_all(&out)?;
735 Ok(())
736}
737
738fn read_record_cipher<R: Read>(
739 r: &mut R,
740 key: &[u8; 16],
741 iv12: &[u8; 12],
742 seq: &mut u64,
743) -> OrtResult<(u8, Vec<u8>, u8)> {
744 let hdr = read_exact_n(r, 5)?;
745 let typ = hdr[0];
746 let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
747 let ciphertext = read_exact_n(r, len)?;
748 if len < AEAD_TAG_LEN {
749 return Err(ort_error(ErrorKind::TlsRecordTooShort, "short record"));
750 }
751 debug_print("read_record_cipher hdr", &hdr);
752 debug_print("read_record_cipher ct", &ciphertext);
753
754 let nonce = nonce_xor(iv12, *seq);
757 *seq = seq.wrapping_add(1);
758
759 let mut out = aead::aes_128_gcm_decrypt(key, &nonce, &hdr, &ciphertext).unwrap();
760
761 debug_print("read_record_cipher plaintext hdr", &hdr);
762 debug_print("read_record_cipher plaintext", &out);
763
764 if out.is_empty() {
765 return Ok((typ, ciphertext, 0));
766 }
767 let inner_type = *out.last().unwrap();
769 out.truncate(out.len() - 1);
770 Ok((typ, out, inner_type))
771}
772
773fn read_handshake_message<'a>(rd: &mut &'a [u8]) -> OrtResult<(u8, &'a [u8], &'a [u8])> {
776 if rd.len() < 4 {
777 return Err(ort_error(ErrorKind::TlsHandshakeHeaderTooShort, ""));
778 }
779 let typ = rd[0];
780 let len = ((rd[1] as usize) << 16) | ((rd[2] as usize) << 8) | rd[3] as usize;
781 if rd.len() < 4 + len {
782 return Err(ort_error(ErrorKind::TlsHandshakeBodyTooShort, ""));
783 }
784 let full = &rd[..4 + len];
785 let body = &rd[4..4 + len];
786 *rd = &rd[4 + len..];
787 Ok((typ, body, full))
788}
789
790fn parse_server_hello_for_keys(sh: &[u8]) -> OrtResult<(u16, [u8; 32])> {
791 if sh.len() < 2 + 32 + 1 + 2 + 1 + 2 {
793 return Err(ort_error(ErrorKind::TlsServerHelloTooShort, ""));
794 }
795 let mut p = sh;
796
797 p = &p[2..]; p = &p[32..]; let sid_len = p[0] as usize;
800 p = &p[1..];
801 if p.len() < sid_len + 2 + 1 + 2 {
802 return Err(ort_error(ErrorKind::TlsServerHelloSessionIdInvalid, ""));
803 }
804 p = &p[sid_len..];
805 let cipher = u16::from_be_bytes([p[0], p[1]]);
806 p = &p[2..];
807 let _comp = p[0];
808 p = &p[1..];
809 let ext_len = u16::from_be_bytes([p[0], p[1]]) as usize;
810 p = &p[2..];
811 if p.len() < ext_len {
812 return Err(ort_error(ErrorKind::TlsServerHelloExtTooShort, ""));
813 }
814 let mut ex = &p[..ext_len];
815
816 let mut server_pub = None;
817
818 while !ex.is_empty() {
819 if ex.len() < 4 {
820 return Err(ort_error(ErrorKind::TlsExtensionHeaderTooShort, ""));
821 }
822 let et = u16::from_be_bytes([ex[0], ex[1]]);
823 let el = u16::from_be_bytes([ex[2], ex[3]]) as usize;
824 ex = &ex[4..];
825 if ex.len() < el {
826 return Err(ort_error(ErrorKind::TlsExtensionLengthInvalid, ""));
827 }
828 let ed = &ex[..el];
829 ex = &ex[el..];
830
831 match et {
832 EXT_KEY_SHARE => {
833 if ed.len() < 2 + 2 + 32 {
835 return Err(ort_error(ErrorKind::TlsKeyShareServerHelloInvalid, ""));
836 }
837 let grp = u16::from_be_bytes([ed[0], ed[1]]);
838 if grp != GROUP_X25519 {
839 return Err(ort_error(
840 ErrorKind::TlsServerGroupUnsupported,
841 "server group != x25519",
842 ));
843 }
844 let kx_len = u16::from_be_bytes([ed[2], ed[3]]) as usize;
845 if ed.len() < 4 + kx_len || kx_len != 32 {
846 return Err(ort_error(ErrorKind::TlsKeyShareLengthInvalid, ""));
847 }
848 let mut pk = [0u8; 32];
849 pk.copy_from_slice(&ed[4..4 + 32]);
850 server_pub = Some(pk);
851 }
852 EXT_SUPPORTED_VERSIONS => {
853 if ed.len() != 2 || u16::from_be_bytes([ed[0], ed[1]]) != TLS13 {
854 return Err(ort_error(ErrorKind::TlsServerNotTls13, ""));
855 }
856 }
857 _ => {}
858 }
859 }
860
861 let sp = server_pub.ok_or_else(|| ort_error(ErrorKind::TlsMissingServerKey, ""))?;
862 Ok((cipher, sp))
863}
864
/// Debug hook for dumping a labelled byte buffer during the handshake.
///
/// Currently a no-op, so secrets passed to it are never emitted.
fn debug_print(_label: &str, _bytes: &[u8]) {}
866
#[cfg(test)]
pub mod tests {
    extern crate alloc;
    use alloc::vec::Vec;

    /// Decodes exactly 64 hex digits (optionally prefixed with `0x`/`0X`) into
    /// 32 bytes.
    ///
    /// # Panics
    /// Panics on a wrong length or a non-hex character.
    pub fn string_to_bytes(s: &str) -> [u8; 32] {
        let digits = strip_hex_prefix(s);
        assert!(
            digits.len() == 64,
            "hex string must be exactly 64 hex chars (32 bytes)"
        );

        let mut out = [0u8; 32];
        for (dst, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
            *dst = decode_pair(pair);
        }
        out
    }

    /// Decodes an even-length hex string (optionally prefixed with `0x`/`0X`)
    /// into bytes.
    ///
    /// # Panics
    /// Panics on an odd length or a non-hex character.
    pub fn hex_to_vec(s: &str) -> Vec<u8> {
        let digits = strip_hex_prefix(s);
        assert_eq!(digits.len() % 2, 0, "hex string must have even length");
        digits.chunks_exact(2).map(decode_pair).collect()
    }

    /// Returns the bytes of `s` without a leading `0x`/`0X`.
    fn strip_hex_prefix(s: &str) -> &[u8] {
        match s.as_bytes() {
            [b'0', b'x' | b'X', rest @ ..] => rest,
            all => all,
        }
    }

    /// Combines two hex digits (high nibble first) into one byte.
    fn decode_pair(pair: &[u8]) -> u8 {
        let hi = hex_val(pair[0]);
        let lo = hex_val(pair[1]);
        (hi << 4) | lo
    }

    /// Value of a single ASCII hex digit; panics on anything else.
    fn hex_val(b: u8) -> u8 {
        match char::from(b).to_digit(16) {
            Some(d) => d as u8,
            None => panic!("invalid hex character"),
        }
    }
}