1use core::{cmp, ffi::CStr};
10
11extern crate alloc;
12use alloc::ffi::CString;
13use alloc::string::ToString;
14use alloc::vec;
15use alloc::vec::Vec;
16
17use crate::{
18 Context, ErrorKind, OrtResult, Read, Write, common::utils::to_ascii, net::AsFd, ort_error,
19 syscall,
20};
21
22mod aead;
23mod ecdh;
24mod hkdf;
25mod hmac;
26mod sha2;
27
/// Compile-time switch for the hex dumps emitted by `debug_print` (debug builds only).
#[allow(unused)]
const DEBUG_LOG: bool = false;

/// Largest application payload placed in a single TLS record (2^14 bytes).
const MAX_PLAINTEXT_SIZE: usize = 1 << 14;

// TLS record content types.
const REC_TYPE_CHANGE_CIPHER_SPEC: u8 = 20;
const REC_TYPE_ALERT: u8 = 21;
const REC_TYPE_HANDSHAKE: u8 = 22;
const REC_TYPE_APPDATA: u8 = 23;
/// Version written in every record header (TLS 1.2 wire value, as TLS 1.3 requires).
const LEGACY_REC_VER: u16 = 0x0303;

// Handshake message types.
const HS_CLIENT_HELLO: u8 = 1;
const HS_SERVER_HELLO: u8 = 2;
const HS_FINISHED: u8 = 20;

/// The only cipher suite offered.
const CIPHER_TLS_AES_128_GCM_SHA256: u16 = 0x1301;
/// TLS 1.3 version code used in the supported_versions extension.
const TLS13: u16 = 0x0304;
/// The only key-exchange group offered.
const GROUP_X25519: u16 = 0x001d;

// Extension type codes.
const EXT_SERVER_NAME: u16 = 0x0000;
const EXT_SUPPORTED_GROUPS: u16 = 0x000a;
const EXT_SIGNATURE_ALGS: u16 = 0x000d;
const EXT_SUPPORTED_VERSIONS: u16 = 0x002b;
const EXT_KEY_SHARE: u16 = 0x0033;

/// Length of the AES-GCM authentication tag appended to each record.
const AEAD_TAG_LEN: usize = 16;
66
/// Appends `v` to `buf` as two big-endian bytes.
fn put_u16(buf: &mut Vec<u8>, v: u16) {
    let [hi, lo] = v.to_be_bytes();
    buf.push(hi);
    buf.push(lo);
}
/// Appends the low 24 bits of `v` to `buf` as three big-endian bytes.
///
/// `v` is first truncated to `u32`; anything above 24 bits is dropped.
fn put_u24(buf: &mut Vec<u8>, v: usize) {
    let be = (v as u32).to_be_bytes();
    buf.extend_from_slice(&be[1..]);
}
75
76fn hkdf_expand_label<const N: usize>(prk: &[u8], label: &str, data: &[u8]) -> [u8; N] {
77 let mut info = Vec::with_capacity(2 + 1 + 6 + label.len() + 1 + data.len());
78 put_u16(&mut info, N as u16);
79 info.push(("tls13 ".len() + label.len()) as u8);
80 info.extend_from_slice("tls13 ".as_bytes());
81 info.extend_from_slice(label.as_bytes());
82 info.push(data.len() as u8);
83 info.extend_from_slice(data);
84
85 hkdf::hkdf_expand(prk, &info, N).try_into().unwrap()
86}
87
88fn digest_bytes(data: &[u8]) -> [u8; 32] {
89 let d = sha2::sha256(data);
90 let mut out = [0u8; 32];
91 out.copy_from_slice(d.as_ref());
92 out
93}
94
/// Builds the per-record AEAD nonce.
///
/// The 64-bit record sequence number is left-padded to 12 bytes and XORed with the static
/// IV. In effect, the last 8 IV bytes are XORed with `seq` in big-endian order.
fn nonce_xor(iv12: &[u8; 12], seq: u64) -> [u8; 12] {
    let mut nonce = *iv12;
    for (byte, s) in nonce[4..].iter_mut().zip(seq.to_be_bytes()) {
        *byte ^= s;
    }
    nonce
}
105
106pub struct TlsStream<T: Read + Write> {
108 io: T,
109 aead_enc: [u8; 16],
111 aead_dec: [u8; 16],
112 iv_enc: [u8; 12],
113 iv_dec: [u8; 12],
114 seq_enc: u64,
115 seq_dec: u64,
116 rbuf: Vec<u8>,
118 rpos: usize,
119}
120
121fn client_hello_body(sni_host: &str, client_pub: &[u8]) -> Vec<u8> {
122 let mut ch_body = Vec::with_capacity(512);
123
124 let mut random = [0u8; 32];
126 syscall::getrandom(&mut random);
127
128 let mut session_id = [0u8; 32];
129 syscall::getrandom(&mut session_id);
130
131 ch_body.extend_from_slice(&0x0303u16.to_be_bytes());
133 ch_body.extend_from_slice(&random);
135 ch_body.push(session_id.len() as u8);
137 ch_body.extend_from_slice(&session_id);
138 put_u16(&mut ch_body, 2);
140 put_u16(&mut ch_body, CIPHER_TLS_AES_128_GCM_SHA256);
141 ch_body.push(1);
143 ch_body.push(0);
144
145 let mut exts = Vec::with_capacity(512);
147
148 {
150 let host_bytes = sni_host.as_bytes();
151 let mut snl = Vec::with_capacity(3 + host_bytes.len());
152 snl.push(0); put_u16(&mut snl, host_bytes.len() as u16);
154 snl.extend_from_slice(host_bytes);
155
156 let mut sni = Vec::with_capacity(2 + snl.len());
157 put_u16(&mut sni, snl.len() as u16);
158 sni.extend_from_slice(&snl);
159
160 put_u16(&mut exts, EXT_SERVER_NAME);
161 put_u16(&mut exts, sni.len() as u16);
162 exts.extend_from_slice(&sni);
163 }
164
165 {
167 let mut sv = Vec::with_capacity(3);
168 sv.push(2); sv.extend_from_slice(&TLS13.to_be_bytes());
170 put_u16(&mut exts, EXT_SUPPORTED_VERSIONS);
171 put_u16(&mut exts, sv.len() as u16);
172 exts.extend_from_slice(&sv);
173 }
174
175 {
177 let mut sg = Vec::with_capacity(2 + 2);
178 put_u16(&mut sg, 2);
179 put_u16(&mut sg, GROUP_X25519);
180 put_u16(&mut exts, EXT_SUPPORTED_GROUPS);
181 put_u16(&mut exts, sg.len() as u16);
182 exts.extend_from_slice(&sg);
183 }
184
185 {
187 const ECDSA_SECP256R1_SHA256: u16 = 0x0403;
188 const RSA_PSS_RSAE_SHA256: u16 = 0x0804;
189 const RSA_PKCS1_SHA256: u16 = 0x0401;
190
191 let mut sa = Vec::with_capacity(2 + 6);
192 put_u16(&mut sa, 6);
193 put_u16(&mut sa, ECDSA_SECP256R1_SHA256);
194 put_u16(&mut sa, RSA_PSS_RSAE_SHA256);
195 put_u16(&mut sa, RSA_PKCS1_SHA256);
196
197 put_u16(&mut exts, EXT_SIGNATURE_ALGS);
198 put_u16(&mut exts, sa.len() as u16);
199 exts.extend_from_slice(&sa);
200 }
201
202 {
204 let mut ks = Vec::with_capacity(2 + 2 + 2 + 32);
205 let mut entry = Vec::with_capacity(2 + 2 + 32);
207 put_u16(&mut entry, GROUP_X25519);
208 put_u16(&mut entry, 32);
209 entry.extend_from_slice(client_pub);
210 put_u16(&mut ks, entry.len() as u16);
211 ks.extend_from_slice(&entry);
212
213 put_u16(&mut exts, EXT_KEY_SHARE);
214 put_u16(&mut exts, ks.len() as u16);
215 exts.extend_from_slice(&ks);
216 }
217
218 put_u16(&mut ch_body, exts.len() as u16);
220 ch_body.extend_from_slice(&exts);
221
222 ch_body
223}
224
225fn client_hello_msg(sni_host: &str, client_private_key: &[u8]) -> OrtResult<Vec<u8>> {
227 let client_pub_key = ecdh::x25519_public_key(client_private_key);
228 let client_pub_ref = &client_pub_key;
229 debug_print("Client public key", client_pub_ref);
230
231 let ch_body = client_hello_body(sni_host, client_pub_ref);
232
233 let mut ch_msg = Vec::with_capacity(4 + ch_body.len());
235 ch_msg.push(HS_CLIENT_HELLO);
236 put_u24(&mut ch_msg, ch_body.len());
237 ch_msg.extend_from_slice(&ch_body);
238
239 Ok(ch_msg)
240}
241
242fn read_server_hello<R: Read>(io: &mut R) -> OrtResult<(Vec<u8>, Vec<u8>)> {
244 let (typ, payload) = read_record_plain(io).context("read_record_plain in read_server_hello")?;
245 if typ != REC_TYPE_HANDSHAKE {
246 return Err(ort_error(ErrorKind::TlsExpectedHandshakeRecord, ""));
247 }
248 let sh_buf = payload;
249
250 let mut rd = &sh_buf[..];
252 let (sh_typ, sh_body, sh_full) =
253 read_handshake_message(&mut rd).context("read_handshake_message")?;
254 if sh_typ != HS_SERVER_HELLO {
255 return Err(ort_error(ErrorKind::TlsExpectedServerHello, ""));
256 }
257
258 Ok((sh_body.to_vec(), sh_full.to_vec()))
260}
261
/// Secrets and keys derived after the ServerHello, used during the encrypted handshake.
struct HandshakeState {
    /// HKDF-Extract output over the X25519 shared secret; the input to the master secret.
    handshake_secret: [u8; 32],
    /// SHA-256 of the empty string, reused as context for "derived" labels.
    empty_hash: [u8; 32],

    // Client direction.
    /// Client handshake traffic secret ("c hs traffic").
    client_hs_ts: [u8; 32],
    /// AES-128-GCM key for the client's encrypted handshake records.
    aead_enc_hs: [u8; 16],
    /// Static IV for the client's encrypted handshake records.
    client_handshake_iv: [u8; 12],

    // Server direction.
    /// Server handshake traffic secret ("s hs traffic").
    server_hs_ts: [u8; 32],
    /// AES-128-GCM key for the server's encrypted handshake records.
    aead_dec_hs: [u8; 16],
    /// Static IV for the server's encrypted handshake records.
    server_handshake_iv: [u8; 12],
}
272
/// AES-128-GCM keys and static IVs for application data after the handshake.
struct ApplicationKeys {
    /// Key for records the client sends.
    aead_app_enc: [u8; 16],
    /// Static IV for records the client sends.
    iv_enc: [u8; 12],
    /// Key for records the client receives.
    aead_app_dec: [u8; 16],
    /// Static IV for records the client receives.
    iv_dec: [u8; 12],
}
279
280impl<T: Read + Write> TlsStream<T> {
281 pub fn connect(mut io: T, sni_host: &str) -> OrtResult<Self> {
282 let mut transcript = Vec::with_capacity(8192);
285
286 let mut client_private_key = [0u8; 32];
288 syscall::getrandom(&mut client_private_key);
289 debug_print("Client private key", &client_private_key);
290
291 debug_print("MSG -> ClientHello", &[]);
292 Self::send_client_hello(&mut io, sni_host, &mut transcript, &client_private_key)?;
293
294 debug_print("MSG <- ServerHello", &[]);
295 let sh_body = Self::receive_server_hello(&mut io, &mut transcript)?;
296
297 let handshake = Self::derive_handshake_keys(&client_private_key, &sh_body, &transcript)?;
298
299 debug_print("MSG <- ChangeCipherSpec (dummy)", &[]);
300 Self::receive_dummy_change_cipher_spec(&mut io)?;
301
302 let mut seq_dec_hs = 0u64;
303 let mut seq_enc_hs = 0u64;
304
305 let mut is_finished: bool = false;
306 while !is_finished {
307 debug_print("MSG <- Server flight", &[]);
308 is_finished = Self::receive_server_encrypted_flight(
309 &mut io,
310 &mut seq_dec_hs,
311 &handshake,
312 &mut transcript,
313 )?;
314 }
315
316 let ApplicationKeys {
317 aead_app_enc,
318 aead_app_dec,
319 iv_enc: caiv,
320 iv_dec: saiv,
321 } = Self::derive_application_keys(
322 &handshake.handshake_secret,
323 &handshake.empty_hash,
324 &transcript,
325 );
326
327 let seq_app_enc = 0u64;
328 let seq_app_dec = 0u64;
329
330 debug_print("MSG -> ClientFinished", &[]);
335 Self::send_client_finished(&mut io, &handshake, &mut transcript, &mut seq_enc_hs)?;
336
337 debug_print("TLS connect done", &[]);
338 Ok(TlsStream {
339 io,
340 aead_enc: aead_app_enc,
341 aead_dec: aead_app_dec,
342 iv_enc: caiv,
343 iv_dec: saiv,
344 seq_enc: seq_app_enc,
345 seq_dec: seq_app_dec,
346 rbuf: Vec::with_capacity(16 * 1024),
347 rpos: 0,
348 })
349 }
350
351 pub fn has_buffered_data(&self) -> bool {
352 self.rpos < self.rbuf.len()
363 }
364
365 fn send_client_hello<W: Write>(
366 io: &mut W,
367 sni_host: &str,
368 transcript: &mut Vec<u8>,
369 client_private_key: &[u8; 32],
370 ) -> OrtResult<()> {
371 let ch_msg = client_hello_msg(sni_host, client_private_key)?;
372 write_record_plain(io, REC_TYPE_HANDSHAKE, &ch_msg).context("write ClientHello")?;
373 transcript.extend_from_slice(&ch_msg);
374 Ok(())
375 }
376
377 fn receive_server_hello<R: Read>(io: &mut R, transcript: &mut Vec<u8>) -> OrtResult<Vec<u8>> {
378 let (sh_body, sh_full) = read_server_hello(io)?;
379 transcript.extend_from_slice(&sh_full);
380 Ok(sh_body)
381 }
382
383 fn receive_dummy_change_cipher_spec<R: Read>(io: &mut R) -> OrtResult<()> {
384 let (typ, _) =
386 read_record_plain(io).context("read_record_plain for dummy change cipher")?;
387 if typ != REC_TYPE_CHANGE_CIPHER_SPEC {
388 return Err(ort_error(ErrorKind::TlsExpectedChangeCipherSpec, ""));
389 }
390 Ok(())
391 }
392
393 fn receive_server_encrypted_flight<R: Read>(
396 io: &mut R,
397 seq_dec_hs: &mut u64,
398 handshake: &HandshakeState,
399 transcript: &mut Vec<u8>,
400 ) -> OrtResult<bool> {
401 let (typ, ct, _inner_type) = read_record_cipher(
402 io,
403 &handshake.aead_dec_hs,
404 &handshake.server_handshake_iv,
405 seq_dec_hs,
406 )?;
407 if typ != REC_TYPE_APPDATA {
408 return Err(ort_error(ErrorKind::TlsExpectedEncryptedRecords, ""));
409 }
410
411 let mut p = &ct[..];
414 while !p.is_empty() {
415 let (mtyp, body, full) = match read_handshake_message(&mut p) {
416 Ok(x) => x,
417 Err(_) => {
418 return Err(ort_error(ErrorKind::TlsBadHandshakeFragment, ""));
419 }
420 };
421 transcript.extend_from_slice(full);
422 debug_print("handshake message (type is first byte)", full);
423
424 if mtyp == HS_FINISHED {
425 let s_finished_key =
427 hkdf_expand_label::<32>(&handshake.server_hs_ts, "finished", &[]);
428
429 let thash = digest_bytes(&transcript[..transcript.len() - full.len()]);
430 let expected = hmac::sign(&s_finished_key, &thash);
431 if expected.as_slice() != body {
432 return Err(ort_error(ErrorKind::TlsFinishedVerifyFailed, ""));
433 }
434 return Ok(true);
436 }
437 }
439 Ok(false)
440 }
441
442 fn derive_handshake_keys(
443 client_private_key: &[u8; 32],
444 sh_body: &[u8],
445 transcript: &[u8],
446 ) -> OrtResult<HandshakeState> {
447 let (cipher, server_public_key_bytes) = parse_server_hello_for_keys(sh_body)?;
449 debug_print("Server public key", &server_public_key_bytes);
450 if cipher != CIPHER_TLS_AES_128_GCM_SHA256 {
451 return Err(ort_error(
452 ErrorKind::TlsUnsupportedCipher,
453 "server picked unsupported cipher",
454 ));
455 }
456
457 let hs_shared_secret = ecdh::x25519_agreement(client_private_key, &server_public_key_bytes);
459 debug_print("hs shared secret", &hs_shared_secret);
460
461 let empty_hash = digest_bytes(&[]);
463 debug_print("empty_hash", &empty_hash);
464
465 let zero: [u8; 32] = [0u8; 32];
466 let early_secret = hkdf::hkdf_extract(&zero, &zero);
467
468 let derived_secret_bytes = hkdf_expand_label::<32>(&early_secret, "derived", &empty_hash);
469 debug_print("derived", &derived_secret_bytes);
470
471 let handshake_secret = hkdf::hkdf_extract(&derived_secret_bytes, &hs_shared_secret);
472 debug_print("handshake_secret", &handshake_secret);
473
474 let ch_sh_hash = digest_bytes(transcript);
475 debug_print("digest bytes", &ch_sh_hash);
476
477 let c_hs_ts = hkdf_expand_label(&handshake_secret, "c hs traffic", &ch_sh_hash);
478 let s_hs_ts = hkdf_expand_label(&handshake_secret, "s hs traffic", &ch_sh_hash);
479
480 debug_print("c hs traffic", &c_hs_ts);
481 debug_print("s hs traffic", &s_hs_ts);
482
483 let client_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&c_hs_ts, "key", &[])
485 .as_slice()[..16]
486 .try_into()
487 .unwrap();
488 debug_print("client_handshake_key", &client_handshake_key);
489 let client_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&c_hs_ts, "iv", &[]).as_slice()
490 [..12]
491 .try_into()
492 .unwrap();
493 debug_print("client_handshake_iv", &client_handshake_iv);
494
495 let server_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&s_hs_ts, "key", &[])
496 .as_slice()[..16]
497 .try_into()
498 .unwrap();
499 debug_print("server_handshake_key", &server_handshake_key);
500 let server_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&s_hs_ts, "iv", &[]).as_slice()
501 [..12]
502 .try_into()
503 .unwrap();
504 debug_print("server_handshake_iv", &server_handshake_iv);
505
506 Ok(HandshakeState {
507 handshake_secret,
508 client_hs_ts: c_hs_ts,
509 server_hs_ts: s_hs_ts,
510 client_handshake_iv,
511 server_handshake_iv,
512 aead_enc_hs: client_handshake_key,
513 aead_dec_hs: server_handshake_key,
514 empty_hash,
515 })
516 }
517
518 fn derive_application_keys(
519 handshake_secret: &[u8; 32],
520 empty_hash: &[u8; 32],
521 transcript: &[u8],
522 ) -> ApplicationKeys {
523 let derived2_bytes = hkdf_expand_label::<32>(handshake_secret, "derived", empty_hash);
524 debug_print("derived2_bytes", &derived2_bytes);
525
526 let zero: [u8; 32] = [0u8; 32];
527 let master_secret = hkdf::hkdf_extract(&derived2_bytes, &zero);
528 let thash_srv_fin = digest_bytes(transcript);
529
530 let c_ap_ts = hkdf_expand_label::<32>(&master_secret, "c ap traffic", &thash_srv_fin);
531 let s_ap_ts = hkdf_expand_label::<32>(&master_secret, "s ap traffic", &thash_srv_fin);
532 debug_print("c_ap_ts", &c_ap_ts);
533 debug_print("s_ap_ts", &s_ap_ts);
534
535 let cak: [u8; 16] = hkdf_expand_label::<16>(&c_ap_ts, "key", &[]).as_slice()[..16]
536 .try_into()
537 .unwrap();
538 let caiv: [u8; 12] = hkdf_expand_label::<12>(&c_ap_ts, "iv", &[]).as_slice()[..12]
539 .try_into()
540 .unwrap();
541 debug_print("cak", &cak);
542 debug_print("caiv", &caiv);
543
544 let sak: [u8; 16] = hkdf_expand_label::<16>(&s_ap_ts, "key", &[]).as_slice()[..16]
545 .try_into()
546 .unwrap();
547 let saiv: [u8; 12] = hkdf_expand_label::<12>(&s_ap_ts, "iv", &[]).as_slice()[..12]
548 .try_into()
549 .unwrap();
550 debug_print("sak", &sak);
551 debug_print("saiv", &saiv);
552
553 ApplicationKeys {
554 aead_app_enc: cak,
555 aead_app_dec: sak,
556 iv_enc: caiv,
557 iv_dec: saiv,
558 }
559 }
560
561 fn send_client_finished<W: Write>(
562 io: &mut W,
563 handshake: &HandshakeState,
564 transcript: &mut Vec<u8>,
565 seq_enc_hs: &mut u64,
566 ) -> OrtResult<()> {
567 let c_finished_key = hkdf_expand_label::<32>(&handshake.client_hs_ts, "finished", &[]);
568 debug_print("c_finished", &c_finished_key);
569
570 let thash_client_fin = digest_bytes(transcript.as_slice());
571 let verify_data = hmac::sign(&c_finished_key, &thash_client_fin);
572 debug_print("verify_data", &verify_data);
573
574 let mut fin = Vec::with_capacity(4 + verify_data.as_ref().len());
575 fin.push(HS_FINISHED);
576 put_u24(&mut fin, verify_data.as_ref().len());
577 fin.extend_from_slice(verify_data.as_ref());
578
579 transcript.extend_from_slice(&fin);
581
582 write_record_cipher(
583 io,
584 REC_TYPE_HANDSHAKE,
585 &fin,
586 &handshake.aead_enc_hs,
587 &handshake.client_handshake_iv,
588 seq_enc_hs,
589 )
590 .context("write_record_cipher write_all failed")?;
591
592 Ok(())
593 }
594}
595
596impl<T: Read + Write> Write for TlsStream<T> {
597 fn write(&mut self, buf: &[u8]) -> OrtResult<usize> {
598 let mut bytes_sent = 0;
599 for chunk in buf.chunks(MAX_PLAINTEXT_SIZE) {
600 write_record_cipher(
601 &mut self.io,
602 REC_TYPE_APPDATA,
603 chunk,
604 &self.aead_enc,
605 &self.iv_enc,
606 &mut self.seq_enc,
607 )?;
608 bytes_sent += chunk.len();
609 }
610 Ok(bytes_sent)
611 }
612 fn flush(&mut self) -> OrtResult<()> {
613 self.io.flush()
614 }
615}
616
617impl<T: Read + Write> Read for TlsStream<T> {
618 fn read(&mut self, out: &mut [u8]) -> OrtResult<usize> {
619 if self.rpos < self.rbuf.len() {
620 debug_print("TlsStream.read using buf", &[]);
621
622 let n = cmp::min(out.len(), self.rbuf.len() - self.rpos);
623 out[..n].copy_from_slice(&self.rbuf[self.rpos..self.rpos + n]);
624 self.rpos += n;
625 if self.rpos == self.rbuf.len() {
626 self.rbuf.clear();
627 self.rpos = 0;
628 }
629 return Ok(n);
630 }
631 loop {
632 let (typ, plaintext, inner_type) = read_record_cipher(
633 &mut self.io,
634 &self.aead_dec,
635 &self.iv_dec,
636 &mut self.seq_dec,
637 )?;
638 if typ != REC_TYPE_APPDATA {
639 continue;
641 }
642 if plaintext.is_empty() {
644 continue;
645 }
646 if inner_type == REC_TYPE_HANDSHAKE {
647 continue;
649 }
650 if inner_type == REC_TYPE_ALERT {
651 let level = match plaintext[0] {
652 1 => "warning",
653 2 => "fatal",
654 _ => "unknown",
655 };
656 let err_level = CString::new(level.to_string() + " alert: ").unwrap();
657
658 let mut err_code_buf: [u8; 5] = [0u8; 5];
661 let len = to_ascii(plaintext[1] as usize, &mut err_code_buf);
662 let err_code = unsafe { CStr::from_bytes_with_nul_unchecked(&err_code_buf[..len]) };
663 syscall::write(2, err_level.as_ptr().cast(), err_level.count_bytes());
664 syscall::write(2, err_code.as_ptr().cast(), err_code.count_bytes());
665
666 return Err(ort_error(ErrorKind::TlsAlertReceived, ""));
667 }
668 if inner_type != REC_TYPE_APPDATA {
669 }
672 if plaintext.is_empty() {
673 continue;
674 }
675
676 self.rbuf.extend_from_slice(&plaintext);
677 self.rpos = 0;
678 let n = cmp::min(out.len(), self.rbuf.len());
680 out[..n].copy_from_slice(&self.rbuf[..n]);
681 self.rpos = n;
682 if n == self.rbuf.len() {
683 self.rbuf.clear();
684 self.rpos = 0;
685 }
686 return Ok(n);
687 }
688 }
689}
690
691impl<T: Read + Write + AsFd> AsFd for TlsStream<T> {
692 fn as_fd(&self) -> i32 {
693 self.io.as_fd()
694 }
695}
696
697fn write_record_plain<W: Write>(w: &mut W, typ: u8, body: &[u8]) -> OrtResult<()> {
700 let mut hdr = [0u8; 5];
701 hdr[0] = typ;
702 hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
703 hdr[3..5].copy_from_slice(&(body.len() as u16).to_be_bytes());
704 w.write_all(&hdr)?;
705 w.write_all(body)?;
706 Ok(())
707}
708
709fn read_exact_n<R: Read>(r: &mut R, n: usize) -> OrtResult<Vec<u8>> {
710 let mut buf = vec![0u8; n];
711 r.read_exact(&mut buf)?;
712 Ok(buf)
713}
714
715fn read_record_plain<R: Read>(r: &mut R) -> OrtResult<(u8, Vec<u8>)> {
716 let hdr = read_exact_n(r, 5)?; let typ = hdr[0];
718 let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
719 let body = read_exact_n(r, len)?;
720 debug_print("read_record_plain hdr", &hdr);
721 debug_print("read_record_plain body", &body);
722 Ok((typ, body))
724}
725
726fn write_record_cipher<W: Write>(
727 w: &mut W,
728 outer_type: u8,
729 inner: &[u8],
730 key: &[u8; 16],
731 iv12: &[u8; 12],
732 seq: &mut u64,
733) -> OrtResult<()> {
734 let total_len = inner.len() + 1 + AEAD_TAG_LEN;
736 let mut plain = Vec::with_capacity(total_len);
737 plain.extend_from_slice(inner);
738 plain.push(outer_type);
739
740 debug_print("write_record_cipher plaintext", &plain);
741
742 let nonce = nonce_xor(iv12, *seq);
743 *seq = seq.wrapping_add(1);
744
745 let mut hdr = [0u8; 5];
746 hdr[0] = REC_TYPE_APPDATA;
747 hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
748 hdr[3..5].copy_from_slice(&(total_len as u16).to_be_bytes());
749
750 let out = aead::aes_128_gcm_encrypt(key, &nonce, &hdr, &plain).unwrap();
751
752 debug_print("write_record_cipher header", &hdr);
753 w.write_all(&hdr)?;
757 w.write_all(&out)?;
758 Ok(())
759}
760
761fn read_record_cipher<R: Read>(
762 r: &mut R,
763 key: &[u8; 16],
764 iv12: &[u8; 12],
765 seq: &mut u64,
766) -> OrtResult<(u8, Vec<u8>, u8)> {
767 let hdr = read_exact_n(r, 5)?;
768 let typ = hdr[0];
769 let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
770 let ciphertext = read_exact_n(r, len)?;
771 if len < AEAD_TAG_LEN {
772 return Err(ort_error(ErrorKind::TlsRecordTooShort, "short record"));
773 }
774 debug_print("read_record_cipher hdr", &hdr);
775 debug_print("read_record_cipher ct", &ciphertext);
776
777 let nonce = nonce_xor(iv12, *seq);
785 *seq = seq.wrapping_add(1);
786
787 let mut out = match aead::aes_128_gcm_decrypt(key, &nonce, &hdr, &ciphertext) {
788 Ok(out) => out,
789 Err(s) => {
790 return Err(ort_error(ErrorKind::TlsAes128GcmDecryptFailed, s));
791 }
792 };
793
794 debug_print("read_record_cipher plaintext hdr", &hdr);
795 debug_print("read_record_cipher plaintext", &out);
796
797 if out.is_empty() {
798 return Ok((typ, ciphertext, 0));
799 }
800 let inner_type = *out.last().unwrap();
802 out.truncate(out.len() - 1);
803 Ok((typ, out, inner_type))
804}
805
806fn read_handshake_message<'a>(rd: &mut &'a [u8]) -> OrtResult<(u8, &'a [u8], &'a [u8])> {
809 if rd.len() < 4 {
810 return Err(ort_error(ErrorKind::TlsHandshakeHeaderTooShort, ""));
811 }
812 let typ = rd[0];
813 let len = ((rd[1] as usize) << 16) | ((rd[2] as usize) << 8) | rd[3] as usize;
814 if rd.len() < 4 + len {
815 return Err(ort_error(ErrorKind::TlsHandshakeBodyTooShort, ""));
816 }
817 let full = &rd[..4 + len];
818 let body = &rd[4..4 + len];
819 *rd = &rd[4 + len..];
820 Ok((typ, body, full))
821}
822
823fn parse_server_hello_for_keys(sh: &[u8]) -> OrtResult<(u16, [u8; 32])> {
824 if sh.len() < 2 + 32 + 1 + 2 + 1 + 2 {
826 return Err(ort_error(ErrorKind::TlsServerHelloTooShort, ""));
827 }
828 let mut p = sh;
829
830 p = &p[2..]; p = &p[32..]; let sid_len = p[0] as usize;
833 p = &p[1..];
834 if p.len() < sid_len + 2 + 1 + 2 {
835 return Err(ort_error(ErrorKind::TlsServerHelloSessionIdInvalid, ""));
836 }
837 p = &p[sid_len..];
838 let cipher = u16::from_be_bytes([p[0], p[1]]);
839 p = &p[2..];
840 let _comp = p[0];
841 p = &p[1..];
842 let ext_len = u16::from_be_bytes([p[0], p[1]]) as usize;
843 p = &p[2..];
844 if p.len() < ext_len {
845 return Err(ort_error(ErrorKind::TlsServerHelloExtTooShort, ""));
846 }
847 let mut ex = &p[..ext_len];
848
849 let mut server_pub = None;
850
851 while !ex.is_empty() {
852 if ex.len() < 4 {
853 return Err(ort_error(ErrorKind::TlsExtensionHeaderTooShort, ""));
854 }
855 let et = u16::from_be_bytes([ex[0], ex[1]]);
856 let el = u16::from_be_bytes([ex[2], ex[3]]) as usize;
857 ex = &ex[4..];
858 if ex.len() < el {
859 return Err(ort_error(ErrorKind::TlsExtensionLengthInvalid, ""));
860 }
861 let ed = &ex[..el];
862 ex = &ex[el..];
863
864 match et {
865 EXT_KEY_SHARE => {
866 if ed.len() < 2 + 2 + 32 {
868 return Err(ort_error(ErrorKind::TlsKeyShareServerHelloInvalid, ""));
869 }
870 let grp = u16::from_be_bytes([ed[0], ed[1]]);
871 if grp != GROUP_X25519 {
872 return Err(ort_error(
873 ErrorKind::TlsServerGroupUnsupported,
874 "server group != x25519",
875 ));
876 }
877 let kx_len = u16::from_be_bytes([ed[2], ed[3]]) as usize;
878 if ed.len() < 4 + kx_len || kx_len != 32 {
879 return Err(ort_error(ErrorKind::TlsKeyShareLengthInvalid, ""));
880 }
881 let mut pk = [0u8; 32];
882 pk.copy_from_slice(&ed[4..4 + 32]);
883 server_pub = Some(pk);
884 }
885 EXT_SUPPORTED_VERSIONS
886 if (ed.len() != 2 || u16::from_be_bytes([ed[0], ed[1]]) != TLS13) =>
887 {
888 return Err(ort_error(ErrorKind::TlsServerNotTls13, ""));
889 }
890 _ => {}
891 }
892 }
893
894 let sp = server_pub.ok_or_else(|| ort_error(ErrorKind::TlsMissingServerKey, ""))?;
895 Ok((cipher, sp))
896}
897
898#[allow(unused)]
899fn debug_print(name: &str, value: &[u8]) {
900 #[cfg(debug_assertions)]
901 {
902 if !DEBUG_LOG {
903 return;
904 }
905 let c_str = CString::new(name).unwrap();
906 if !value.is_empty() {
907 crate::utils::print_hex(c_str.as_c_str(), value);
908 } else {
909 crate::utils::print_string(c_str.as_c_str(), "");
910 }
911 }
912}
913
#[cfg(test)]
pub mod tests {
    extern crate alloc;
    use alloc::vec::Vec;

    /// Parses exactly 64 hex digits (optional `0x`/`0X` prefix) into a 32-byte array.
    ///
    /// # Panics
    /// Panics if the digit count is not 64 or a character is not hex.
    pub fn string_to_bytes(s: &str) -> [u8; 32] {
        let digits = strip_hex_prefix(s.as_bytes());
        assert!(
            digits.len() == 64,
            "hex string must be exactly 64 hex chars (32 bytes)"
        );

        let mut out = [0u8; 32];
        for (dst, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
            *dst = (hex_val(pair[0]) << 4) | hex_val(pair[1]);
        }
        out
    }

    /// Parses an even-length hex string (optional `0x`/`0X` prefix) into bytes.
    ///
    /// # Panics
    /// Panics on odd length or a non-hex character.
    pub fn hex_to_vec(s: &str) -> Vec<u8> {
        let digits = strip_hex_prefix(s.as_bytes());
        assert_eq!(digits.len() % 2, 0, "hex string must have even length");
        digits
            .chunks_exact(2)
            .map(|pair| (hex_val(pair[0]) << 4) | hex_val(pair[1]))
            .collect()
    }

    /// Drops a leading `0x` or `0X`, if present.
    fn strip_hex_prefix(bytes: &[u8]) -> &[u8] {
        match bytes {
            [b'0', b'x' | b'X', rest @ ..] => rest,
            _ => bytes,
        }
    }

    /// Value of one ASCII hex digit; panics on anything else.
    fn hex_val(b: u8) -> u8 {
        match b {
            b'0'..=b'9' => b - b'0',
            b'a'..=b'f' => b - b'a' + 10,
            b'A'..=b'F' => b - b'A' + 10,
            _ => panic!("invalid hex character"),
        }
    }
}