1use core::ffi::c_void;
10use core::{cmp, ffi::CStr};
11
12extern crate alloc;
13use alloc::ffi::CString;
14use alloc::string::ToString;
15use alloc::vec;
16use alloc::vec::Vec;
17
18use crate::{Context, ErrorKind, OrtResult, Read, Write, common::utils::to_ascii, libc, ort_error};
19
20mod aead;
21mod ecdh;
22mod hkdf;
23mod hmac;
24mod sha2;
25
/// Master switch for `debug_print`; when `false` the helper returns without printing.
#[allow(unused)]
const DEBUG_LOG: bool = false;

// Record-layer content types, compared against the first byte of each record header
// and against the decrypted inner content type.
const REC_TYPE_CHANGE_CIPHER_SPEC: u8 = 20;
const REC_TYPE_ALERT: u8 = 21;
const REC_TYPE_HANDSHAKE: u8 = 22;
const REC_TYPE_APPDATA: u8 = 23;
/// Version field written into every outgoing record header.
const LEGACY_REC_VER: u16 = 0x0303;

// Handshake message types (first byte of a handshake message).
const HS_CLIENT_HELLO: u8 = 1;
const HS_SERVER_HELLO: u8 = 2;
const HS_FINISHED: u8 = 20;
/// The only cipher suite offered in the ClientHello and accepted from the server.
const CIPHER_TLS_AES_128_GCM_SHA256: u16 = 0x1301;
/// Version offered in `supported_versions` and required in the server's reply.
const TLS13: u16 = 0x0304;
/// The only key-exchange group offered and accepted.
const GROUP_X25519: u16 = 0x001d;

// Extension type codes used when building the ClientHello and parsing the ServerHello.
const EXT_SERVER_NAME: u16 = 0x0000;
const EXT_SUPPORTED_GROUPS: u16 = 0x000a;
const EXT_SIGNATURE_ALGS: u16 = 0x000d;
const EXT_SUPPORTED_VERSIONS: u16 = 0x002b;
const EXT_KEY_SHARE: u16 = 0x0033;

/// Bytes added to each encrypted record for the AEAD tag; shorter records are rejected.
const AEAD_TAG_LEN: usize = 16;
61
/// Appends `v` to `buf` as two bytes, most significant byte first.
fn put_u16(buf: &mut Vec<u8>, v: u16) {
    let [hi, lo] = v.to_be_bytes();
    buf.push(hi);
    buf.push(lo);
}
/// Appends the low 24 bits of `v` to `buf` as three big-endian bytes.
///
/// Bits above the 24th are silently discarded.
fn put_u24(buf: &mut Vec<u8>, v: usize) {
    // Narrow to 32 bits, then drop the most significant of the four bytes.
    let be = (v as u32).to_be_bytes();
    buf.extend_from_slice(&be[1..]);
}
70
71fn hkdf_expand_label<const N: usize>(prk: &[u8], label: &str, data: &[u8]) -> [u8; N] {
72 let mut info = Vec::with_capacity(2 + 1 + 6 + label.len() + 1 + data.len());
73 put_u16(&mut info, N as u16);
74 info.push(("tls13 ".len() + label.len()) as u8);
75 info.extend_from_slice("tls13 ".as_bytes());
76 info.extend_from_slice(label.as_bytes());
77 info.push(data.len() as u8);
78 info.extend_from_slice(data);
79
80 hkdf::hkdf_expand(prk, &info, N).try_into().unwrap()
81}
82
83fn digest_bytes(data: &[u8]) -> [u8; 32] {
84 let d = sha2::sha256(data);
85 let mut out = [0u8; 32];
86 out.copy_from_slice(d.as_ref());
87 out
88}
89
/// Builds the per-record AEAD nonce from the static IV and the record sequence number.
///
/// The 64-bit sequence number is encoded big-endian, left-padded with zeros to
/// 12 bytes, and XORed with `iv12`. Only the last 8 bytes of the IV can therefore
/// change. Works entirely on the stack: no allocation and no fallible conversion.
fn nonce_xor(iv12: &[u8; 12], seq: u64) -> [u8; 12] {
    let mut nonce = *iv12;
    let seq_be = seq.to_be_bytes();
    // The first 4 bytes are XORed with the zero padding, i.e. left unchanged.
    for (n, s) in nonce[4..].iter_mut().zip(seq_be.iter()) {
        *n ^= *s;
    }
    nonce
}
100
/// An established TLS connection layered over a transport `T`.
///
/// Built by `TlsStream::connect`. Writes are encrypted into records and reads
/// are decrypted from records using the keys stored here.
pub struct TlsStream<T: Read + Write> {
    /// Underlying transport that carries the TLS records.
    io: T,
    /// AES-128-GCM key for records this side sends.
    aead_enc: [u8; 16],
    /// AES-128-GCM key for records this side receives.
    aead_dec: [u8; 16],
    /// Static IV combined with `seq_enc` to form the nonce for sent records.
    iv_enc: [u8; 12],
    /// Static IV combined with `seq_dec` to form the nonce for received records.
    iv_dec: [u8; 12],
    /// Sequence number of the next record to send.
    seq_enc: u64,
    /// Sequence number of the next record to receive.
    seq_dec: u64,
    /// Decrypted application data not yet handed to the caller of `read`.
    rbuf: Vec<u8>,
    /// Offset into `rbuf` of the first byte not yet returned.
    rpos: usize,
}
115
116fn client_hello_body(sni_host: &str, client_pub: &[u8]) -> Vec<u8> {
117 let mut ch_body = Vec::with_capacity(512);
118
119 let mut random = [0u8; 32];
121 let got_bytes = unsafe { libc::getrandom(random.as_mut_ptr() as *mut c_void, 32, 0) };
122 debug_assert_eq!(got_bytes, 32);
123
124 let mut session_id = [0u8; 32];
125 let got_bytes = unsafe { libc::getrandom(session_id.as_mut_ptr() as *mut c_void, 32, 0) };
126 debug_assert_eq!(got_bytes, 32);
127
128 ch_body.extend_from_slice(&0x0303u16.to_be_bytes());
130 ch_body.extend_from_slice(&random);
132 ch_body.push(session_id.len() as u8);
134 ch_body.extend_from_slice(&session_id);
135 put_u16(&mut ch_body, 2);
137 put_u16(&mut ch_body, CIPHER_TLS_AES_128_GCM_SHA256);
138 ch_body.push(1);
140 ch_body.push(0);
141
142 let mut exts = Vec::with_capacity(512);
144
145 {
147 let host_bytes = sni_host.as_bytes();
148 let mut snl = Vec::with_capacity(3 + host_bytes.len());
149 snl.push(0); put_u16(&mut snl, host_bytes.len() as u16);
151 snl.extend_from_slice(host_bytes);
152
153 let mut sni = Vec::with_capacity(2 + snl.len());
154 put_u16(&mut sni, snl.len() as u16);
155 sni.extend_from_slice(&snl);
156
157 put_u16(&mut exts, EXT_SERVER_NAME);
158 put_u16(&mut exts, sni.len() as u16);
159 exts.extend_from_slice(&sni);
160 }
161
162 {
164 let mut sv = Vec::with_capacity(3);
165 sv.push(2); sv.extend_from_slice(&TLS13.to_be_bytes());
167 put_u16(&mut exts, EXT_SUPPORTED_VERSIONS);
168 put_u16(&mut exts, sv.len() as u16);
169 exts.extend_from_slice(&sv);
170 }
171
172 {
174 let mut sg = Vec::with_capacity(2 + 2);
175 put_u16(&mut sg, 2);
176 put_u16(&mut sg, GROUP_X25519);
177 put_u16(&mut exts, EXT_SUPPORTED_GROUPS);
178 put_u16(&mut exts, sg.len() as u16);
179 exts.extend_from_slice(&sg);
180 }
181
182 {
184 const ECDSA_SECP256R1_SHA256: u16 = 0x0403;
185 const RSA_PSS_RSAE_SHA256: u16 = 0x0804;
186 const RSA_PKCS1_SHA256: u16 = 0x0401;
187
188 let mut sa = Vec::with_capacity(2 + 6);
189 put_u16(&mut sa, 6);
190 put_u16(&mut sa, ECDSA_SECP256R1_SHA256);
191 put_u16(&mut sa, RSA_PSS_RSAE_SHA256);
192 put_u16(&mut sa, RSA_PKCS1_SHA256);
193
194 put_u16(&mut exts, EXT_SIGNATURE_ALGS);
195 put_u16(&mut exts, sa.len() as u16);
196 exts.extend_from_slice(&sa);
197 }
198
199 {
201 let mut ks = Vec::with_capacity(2 + 2 + 2 + 32);
202 let mut entry = Vec::with_capacity(2 + 2 + 32);
204 put_u16(&mut entry, GROUP_X25519);
205 put_u16(&mut entry, 32);
206 entry.extend_from_slice(client_pub);
207 put_u16(&mut ks, entry.len() as u16);
208 ks.extend_from_slice(&entry);
209
210 put_u16(&mut exts, EXT_KEY_SHARE);
211 put_u16(&mut exts, ks.len() as u16);
212 exts.extend_from_slice(&ks);
213 }
214
215 put_u16(&mut ch_body, exts.len() as u16);
217 ch_body.extend_from_slice(&exts);
218
219 ch_body
220}
221
222fn client_hello_msg(sni_host: &str, client_private_key: &[u8]) -> OrtResult<Vec<u8>> {
224 let client_pub_key = ecdh::x25519_public_key(client_private_key);
225 let client_pub_ref = &client_pub_key;
226 debug_print("Client public key", client_pub_ref);
227
228 let ch_body = client_hello_body(sni_host, client_pub_ref);
229
230 let mut ch_msg = Vec::with_capacity(4 + ch_body.len());
232 ch_msg.push(HS_CLIENT_HELLO);
233 put_u24(&mut ch_msg, ch_body.len());
234 ch_msg.extend_from_slice(&ch_body);
235
236 Ok(ch_msg)
237}
238
239fn read_server_hello<R: Read>(io: &mut R) -> OrtResult<(Vec<u8>, Vec<u8>)> {
241 let (typ, payload) = read_record_plain(io).context("read_record_plain in read_server_hello")?;
242 if typ != REC_TYPE_HANDSHAKE {
243 return Err(ort_error(ErrorKind::TlsExpectedHandshakeRecord, ""));
244 }
245 let sh_buf = payload;
246
247 let mut rd = &sh_buf[..];
249 let (sh_typ, sh_body, sh_full) =
250 read_handshake_message(&mut rd).context("read_handshake_message")?;
251 if sh_typ != HS_SERVER_HELLO {
252 return Err(ort_error(ErrorKind::TlsExpectedServerHello, ""));
253 }
254
255 Ok((sh_body.to_vec(), sh_full.to_vec()))
257}
258
/// Secrets and keys protecting the encrypted part of the handshake, as produced
/// by `TlsStream::derive_handshake_keys`.
struct HandshakeState {
    /// Handshake secret; the application traffic keys are later derived from it.
    handshake_secret: [u8; 32],
    /// Client handshake traffic secret (label "c hs traffic").
    client_hs_ts: [u8; 32],
    /// Server handshake traffic secret (label "s hs traffic").
    server_hs_ts: [u8; 32],
    /// IV for handshake records sent by the client.
    client_handshake_iv: [u8; 12],
    /// IV for handshake records sent by the server.
    server_handshake_iv: [u8; 12],
    /// AEAD key for handshake records this client sends.
    aead_enc_hs: [u8; 16],
    /// AEAD key for handshake records this client receives.
    aead_dec_hs: [u8; 16],
    /// Hash of the empty input, reused when deriving the next secret.
    empty_hash: [u8; 32],
}
269
/// Application-traffic keys and IVs produced by `TlsStream::derive_application_keys`.
struct ApplicationKeys {
    /// AEAD key for application records sent by the client.
    aead_app_enc: [u8; 16],
    /// AEAD key for application records received from the server.
    aead_app_dec: [u8; 16],
    /// IV for application records sent by the client.
    iv_enc: [u8; 12],
    /// IV for application records received from the server.
    iv_dec: [u8; 12],
}
276
277impl<T: Read + Write> TlsStream<T> {
278 pub fn connect(mut io: T, sni_host: &str) -> OrtResult<Self> {
279 let mut transcript = Vec::with_capacity(8192);
282
283 let mut client_private_key = [0u8; 32];
285 let _ = unsafe { libc::getrandom(client_private_key.as_mut_ptr() as *mut c_void, 32, 0) };
286 debug_print("Client private key", &client_private_key);
287
288 debug_print("MSG -> ClientHello", &[]);
289 Self::send_client_hello(&mut io, sni_host, &mut transcript, &client_private_key)?;
290
291 debug_print("MSG <- ServerHello", &[]);
292 let sh_body = Self::receive_server_hello(&mut io, &mut transcript)?;
293
294 let handshake = Self::derive_handshake_keys(&client_private_key, &sh_body, &transcript)?;
295
296 debug_print("MSG <- ChangeCipherSpec (dummy)", &[]);
297 Self::receive_dummy_change_cipher_spec(&mut io)?;
298
299 let mut seq_dec_hs = 0u64;
300 let mut seq_enc_hs = 0u64;
301
302 let mut is_finished: bool = false;
303 while !is_finished {
304 debug_print("MSG <- Server flight", &[]);
305 is_finished = Self::receive_server_encrypted_flight(
306 &mut io,
307 &mut seq_dec_hs,
308 &handshake,
309 &mut transcript,
310 )?;
311 }
312
313 let ApplicationKeys {
314 aead_app_enc,
315 aead_app_dec,
316 iv_enc: caiv,
317 iv_dec: saiv,
318 } = Self::derive_application_keys(
319 &handshake.handshake_secret,
320 &handshake.empty_hash,
321 &transcript,
322 );
323
324 let seq_app_enc = 0u64;
325 let seq_app_dec = 0u64;
326
327 debug_print("MSG -> ClientFinished", &[]);
332 Self::send_client_finished(&mut io, &handshake, &mut transcript, &mut seq_enc_hs)?;
333
334 debug_print("TLS connect done", &[]);
335 Ok(TlsStream {
336 io,
337 aead_enc: aead_app_enc,
338 aead_dec: aead_app_dec,
339 iv_enc: caiv,
340 iv_dec: saiv,
341 seq_enc: seq_app_enc,
342 seq_dec: seq_app_dec,
343 rbuf: Vec::with_capacity(16 * 1024),
344 rpos: 0,
345 })
346 }
347
348 fn send_client_hello<W: Write>(
349 io: &mut W,
350 sni_host: &str,
351 transcript: &mut Vec<u8>,
352 client_private_key: &[u8; 32],
353 ) -> OrtResult<()> {
354 let ch_msg = client_hello_msg(sni_host, client_private_key)?;
355 write_record_plain(io, REC_TYPE_HANDSHAKE, &ch_msg).context("write ClientHello")?;
356 transcript.extend_from_slice(&ch_msg);
357 Ok(())
358 }
359
360 fn receive_server_hello<R: Read>(io: &mut R, transcript: &mut Vec<u8>) -> OrtResult<Vec<u8>> {
361 let (sh_body, sh_full) = read_server_hello(io)?;
362 transcript.extend_from_slice(&sh_full);
363 Ok(sh_body)
364 }
365
366 fn receive_dummy_change_cipher_spec<R: Read>(io: &mut R) -> OrtResult<()> {
367 let (typ, _) =
369 read_record_plain(io).context("read_record_plain for dummy change cipher")?;
370 if typ != REC_TYPE_CHANGE_CIPHER_SPEC {
371 return Err(ort_error(ErrorKind::TlsExpectedChangeCipherSpec, ""));
372 }
373 Ok(())
374 }
375
376 fn receive_server_encrypted_flight<R: Read>(
379 io: &mut R,
380 seq_dec_hs: &mut u64,
381 handshake: &HandshakeState,
382 transcript: &mut Vec<u8>,
383 ) -> OrtResult<bool> {
384 let (typ, ct, _inner_type) = read_record_cipher(
385 io,
386 &handshake.aead_dec_hs,
387 &handshake.server_handshake_iv,
388 seq_dec_hs,
389 )?;
390 if typ != REC_TYPE_APPDATA {
391 return Err(ort_error(ErrorKind::TlsExpectedEncryptedRecords, ""));
392 }
393
394 let mut p = &ct[..];
397 while !p.is_empty() {
398 let (mtyp, body, full) = match read_handshake_message(&mut p) {
399 Ok(x) => x,
400 Err(_) => {
401 return Err(ort_error(ErrorKind::TlsBadHandshakeFragment, ""));
402 }
403 };
404 transcript.extend_from_slice(full);
405 debug_print("handshake message (type is first byte)", full);
406
407 if mtyp == HS_FINISHED {
408 let s_finished_key =
410 hkdf_expand_label::<32>(&handshake.server_hs_ts, "finished", &[]);
411
412 let thash = digest_bytes(&transcript[..transcript.len() - full.len()]);
413 let expected = hmac::sign(&s_finished_key, &thash);
414 if expected.as_slice() != body {
415 return Err(ort_error(ErrorKind::TlsFinishedVerifyFailed, ""));
416 }
417 return Ok(true);
419 }
420 }
422 Ok(false)
423 }
424
425 fn derive_handshake_keys(
426 client_private_key: &[u8; 32],
427 sh_body: &[u8],
428 transcript: &[u8],
429 ) -> OrtResult<HandshakeState> {
430 let (cipher, server_public_key_bytes) = parse_server_hello_for_keys(sh_body)?;
432 debug_print("Server public key", &server_public_key_bytes);
433 if cipher != CIPHER_TLS_AES_128_GCM_SHA256 {
434 return Err(ort_error(
435 ErrorKind::TlsUnsupportedCipher,
436 "server picked unsupported cipher",
437 ));
438 }
439
440 let hs_shared_secret = ecdh::x25519_agreement(client_private_key, &server_public_key_bytes);
442 debug_print("hs shared secret", &hs_shared_secret);
443
444 let empty_hash = digest_bytes(&[]);
446 debug_print("empty_hash", &empty_hash);
447
448 let zero: [u8; 32] = [0u8; 32];
449 let early_secret = hkdf::hkdf_extract(&zero, &zero);
450
451 let derived_secret_bytes = hkdf_expand_label::<32>(&early_secret, "derived", &empty_hash);
452 debug_print("derived", &derived_secret_bytes);
453
454 let handshake_secret = hkdf::hkdf_extract(&derived_secret_bytes, &hs_shared_secret);
455 debug_print("handshake_secret", &handshake_secret);
456
457 let ch_sh_hash = digest_bytes(transcript);
458 debug_print("digest bytes", &ch_sh_hash);
459
460 let c_hs_ts = hkdf_expand_label(&handshake_secret, "c hs traffic", &ch_sh_hash);
461 let s_hs_ts = hkdf_expand_label(&handshake_secret, "s hs traffic", &ch_sh_hash);
462
463 debug_print("c hs traffic", &c_hs_ts);
464 debug_print("s hs traffic", &s_hs_ts);
465
466 let client_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&c_hs_ts, "key", &[])
468 .as_slice()[..16]
469 .try_into()
470 .unwrap();
471 debug_print("client_handshake_key", &client_handshake_key);
472 let client_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&c_hs_ts, "iv", &[]).as_slice()
473 [..12]
474 .try_into()
475 .unwrap();
476 debug_print("client_handshake_iv", &client_handshake_iv);
477
478 let server_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&s_hs_ts, "key", &[])
479 .as_slice()[..16]
480 .try_into()
481 .unwrap();
482 debug_print("server_handshake_key", &server_handshake_key);
483 let server_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&s_hs_ts, "iv", &[]).as_slice()
484 [..12]
485 .try_into()
486 .unwrap();
487 debug_print("server_handshake_iv", &server_handshake_iv);
488
489 Ok(HandshakeState {
490 handshake_secret,
491 client_hs_ts: c_hs_ts,
492 server_hs_ts: s_hs_ts,
493 client_handshake_iv,
494 server_handshake_iv,
495 aead_enc_hs: client_handshake_key,
496 aead_dec_hs: server_handshake_key,
497 empty_hash,
498 })
499 }
500
501 fn derive_application_keys(
502 handshake_secret: &[u8; 32],
503 empty_hash: &[u8; 32],
504 transcript: &[u8],
505 ) -> ApplicationKeys {
506 let derived2_bytes = hkdf_expand_label::<32>(handshake_secret, "derived", empty_hash);
507 debug_print("derived2_bytes", &derived2_bytes);
508
509 let zero: [u8; 32] = [0u8; 32];
510 let master_secret = hkdf::hkdf_extract(&derived2_bytes, &zero);
511 let thash_srv_fin = digest_bytes(transcript);
512
513 let c_ap_ts = hkdf_expand_label::<32>(&master_secret, "c ap traffic", &thash_srv_fin);
514 let s_ap_ts = hkdf_expand_label::<32>(&master_secret, "s ap traffic", &thash_srv_fin);
515 debug_print("c_ap_ts", &c_ap_ts);
516 debug_print("s_ap_ts", &s_ap_ts);
517
518 let cak: [u8; 16] = hkdf_expand_label::<16>(&c_ap_ts, "key", &[]).as_slice()[..16]
519 .try_into()
520 .unwrap();
521 let caiv: [u8; 12] = hkdf_expand_label::<12>(&c_ap_ts, "iv", &[]).as_slice()[..12]
522 .try_into()
523 .unwrap();
524 debug_print("cak", &cak);
525 debug_print("caiv", &caiv);
526
527 let sak: [u8; 16] = hkdf_expand_label::<16>(&s_ap_ts, "key", &[]).as_slice()[..16]
528 .try_into()
529 .unwrap();
530 let saiv: [u8; 12] = hkdf_expand_label::<12>(&s_ap_ts, "iv", &[]).as_slice()[..12]
531 .try_into()
532 .unwrap();
533 debug_print("sak", &sak);
534 debug_print("saiv", &saiv);
535
536 ApplicationKeys {
537 aead_app_enc: cak,
538 aead_app_dec: sak,
539 iv_enc: caiv,
540 iv_dec: saiv,
541 }
542 }
543
544 fn send_client_finished<W: Write>(
545 io: &mut W,
546 handshake: &HandshakeState,
547 transcript: &mut Vec<u8>,
548 seq_enc_hs: &mut u64,
549 ) -> OrtResult<()> {
550 let c_finished_key = hkdf_expand_label::<32>(&handshake.client_hs_ts, "finished", &[]);
551 debug_print("c_finished", &c_finished_key);
552
553 let thash_client_fin = digest_bytes(transcript.as_slice());
554 let verify_data = hmac::sign(&c_finished_key, &thash_client_fin);
555 debug_print("verify_data", &verify_data);
556
557 let mut fin = Vec::with_capacity(4 + verify_data.as_ref().len());
558 fin.push(HS_FINISHED);
559 put_u24(&mut fin, verify_data.as_ref().len());
560 fin.extend_from_slice(verify_data.as_ref());
561
562 transcript.extend_from_slice(&fin);
564
565 write_record_cipher(
566 io,
567 REC_TYPE_HANDSHAKE,
568 &fin,
569 &handshake.aead_enc_hs,
570 &handshake.client_handshake_iv,
571 seq_enc_hs,
572 )
573 .context("write_record_cipher write_all failed")?;
574
575 Ok(())
576 }
577}
578
579impl<T: Read + Write> Write for TlsStream<T> {
580 fn write(&mut self, buf: &[u8]) -> OrtResult<usize> {
581 write_record_cipher(
582 &mut self.io,
583 REC_TYPE_APPDATA,
584 buf,
585 &self.aead_enc,
586 &self.iv_enc,
587 &mut self.seq_enc,
588 )
589 .map(|_| buf.len())
590 }
591 fn flush(&mut self) -> OrtResult<()> {
592 self.io.flush()
593 }
594}
595
596impl<T: Read + Write> Read for TlsStream<T> {
597 fn read(&mut self, out: &mut [u8]) -> OrtResult<usize> {
598 if self.rpos < self.rbuf.len() {
599 debug_print("TlsStream.read using buf", &[]);
600
601 let n = cmp::min(out.len(), self.rbuf.len() - self.rpos);
602 out[..n].copy_from_slice(&self.rbuf[self.rpos..self.rpos + n]);
603 self.rpos += n;
604 if self.rpos == self.rbuf.len() {
605 self.rbuf.clear();
606 self.rpos = 0;
607 }
608 return Ok(n);
609 }
610 loop {
611 let (typ, plaintext, inner_type) = read_record_cipher(
612 &mut self.io,
613 &self.aead_dec,
614 &self.iv_dec,
615 &mut self.seq_dec,
616 )?;
617 if typ != REC_TYPE_APPDATA {
618 continue;
620 }
621 if plaintext.is_empty() {
623 continue;
624 }
625 if inner_type == REC_TYPE_HANDSHAKE {
626 continue;
628 }
629 if inner_type == REC_TYPE_ALERT {
630 let level = match plaintext[0] {
631 1 => "warning",
632 2 => "fatal",
633 _ => "unknown",
634 };
635 let err_level = CString::new(level.to_string() + " alert: ").unwrap();
636
637 let mut err_code_buf: [u8; 5] = [0u8; 5];
640 let len = to_ascii(plaintext[1] as usize, &mut err_code_buf);
641 let err_code = unsafe { CStr::from_bytes_with_nul_unchecked(&err_code_buf[..len]) };
642 unsafe {
643 libc::write(2, err_level.as_ptr().cast(), err_level.count_bytes());
644 libc::write(2, err_code.as_ptr().cast(), err_code.count_bytes());
645 }
646
647 return Err(ort_error(ErrorKind::TlsAlertReceived, ""));
648 }
649 if inner_type != REC_TYPE_APPDATA {
650 }
653 if plaintext.is_empty() {
654 continue;
655 }
656
657 self.rbuf.extend_from_slice(&plaintext);
658 self.rpos = 0;
659 let n = cmp::min(out.len(), self.rbuf.len());
661 out[..n].copy_from_slice(&self.rbuf[..n]);
662 self.rpos = n;
663 if n == self.rbuf.len() {
664 self.rbuf.clear();
665 self.rpos = 0;
666 }
667 return Ok(n);
668 }
669 }
670}
671
672fn write_record_plain<W: Write>(w: &mut W, typ: u8, body: &[u8]) -> OrtResult<()> {
675 let mut hdr = [0u8; 5];
676 hdr[0] = typ;
677 hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
678 hdr[3..5].copy_from_slice(&(body.len() as u16).to_be_bytes());
679 w.write_all(&hdr)?;
680 w.write_all(body)?;
681 Ok(())
682}
683
684fn read_exact_n<R: Read>(r: &mut R, n: usize) -> OrtResult<Vec<u8>> {
685 let mut buf = vec![0u8; n];
686 r.read_exact(&mut buf)?;
687 Ok(buf)
688}
689
690fn read_record_plain<R: Read>(r: &mut R) -> OrtResult<(u8, Vec<u8>)> {
691 let hdr = read_exact_n(r, 5)?; let typ = hdr[0];
693 let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
694 let body = read_exact_n(r, len)?;
695 debug_print("read_record_plain hdr", &hdr);
696 debug_print("read_record_plain body", &body);
697 Ok((typ, body))
699}
700
701fn write_record_cipher<W: Write>(
702 w: &mut W,
703 outer_type: u8,
704 inner: &[u8],
705 key: &[u8; 16],
706 iv12: &[u8; 12],
707 seq: &mut u64,
708) -> OrtResult<()> {
709 let total_len = inner.len() + 1 + AEAD_TAG_LEN;
711 let mut plain = Vec::with_capacity(total_len);
712 plain.extend_from_slice(inner);
713 plain.push(outer_type);
714
715 debug_print("write_record_cipher plaintext", &plain);
716
717 let nonce = nonce_xor(iv12, *seq);
718 *seq = seq.wrapping_add(1);
719
720 let mut hdr = [0u8; 5];
721 hdr[0] = REC_TYPE_APPDATA;
722 hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
723 hdr[3..5].copy_from_slice(&(total_len as u16).to_be_bytes());
724
725 let out = aead::aes_128_gcm_encrypt(key, &nonce, &hdr, &plain).unwrap();
726
727 debug_print("write_record_cipher header", &hdr);
728 w.write_all(&hdr)?;
732 w.write_all(&out)?;
733 Ok(())
734}
735
736fn read_record_cipher<R: Read>(
737 r: &mut R,
738 key: &[u8; 16],
739 iv12: &[u8; 12],
740 seq: &mut u64,
741) -> OrtResult<(u8, Vec<u8>, u8)> {
742 let hdr = read_exact_n(r, 5)?;
743 let typ = hdr[0];
744 let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
745 let ciphertext = read_exact_n(r, len)?;
746 if len < AEAD_TAG_LEN {
747 return Err(ort_error(ErrorKind::TlsRecordTooShort, "short record"));
748 }
749 debug_print("read_record_cipher hdr", &hdr);
750 debug_print("read_record_cipher ct", &ciphertext);
751
752 let nonce = nonce_xor(iv12, *seq);
760 *seq = seq.wrapping_add(1);
761
762 let mut out = match aead::aes_128_gcm_decrypt(key, &nonce, &hdr, &ciphertext) {
763 Ok(out) => out,
764 Err(s) => {
765 return Err(ort_error(ErrorKind::TlsAes128GcmDecryptFailed, s));
766 }
767 };
768
769 debug_print("read_record_cipher plaintext hdr", &hdr);
770 debug_print("read_record_cipher plaintext", &out);
771
772 if out.is_empty() {
773 return Ok((typ, ciphertext, 0));
774 }
775 let inner_type = *out.last().unwrap();
777 out.truncate(out.len() - 1);
778 Ok((typ, out, inner_type))
779}
780
781fn read_handshake_message<'a>(rd: &mut &'a [u8]) -> OrtResult<(u8, &'a [u8], &'a [u8])> {
784 if rd.len() < 4 {
785 return Err(ort_error(ErrorKind::TlsHandshakeHeaderTooShort, ""));
786 }
787 let typ = rd[0];
788 let len = ((rd[1] as usize) << 16) | ((rd[2] as usize) << 8) | rd[3] as usize;
789 if rd.len() < 4 + len {
790 return Err(ort_error(ErrorKind::TlsHandshakeBodyTooShort, ""));
791 }
792 let full = &rd[..4 + len];
793 let body = &rd[4..4 + len];
794 *rd = &rd[4 + len..];
795 Ok((typ, body, full))
796}
797
798fn parse_server_hello_for_keys(sh: &[u8]) -> OrtResult<(u16, [u8; 32])> {
799 if sh.len() < 2 + 32 + 1 + 2 + 1 + 2 {
801 return Err(ort_error(ErrorKind::TlsServerHelloTooShort, ""));
802 }
803 let mut p = sh;
804
805 p = &p[2..]; p = &p[32..]; let sid_len = p[0] as usize;
808 p = &p[1..];
809 if p.len() < sid_len + 2 + 1 + 2 {
810 return Err(ort_error(ErrorKind::TlsServerHelloSessionIdInvalid, ""));
811 }
812 p = &p[sid_len..];
813 let cipher = u16::from_be_bytes([p[0], p[1]]);
814 p = &p[2..];
815 let _comp = p[0];
816 p = &p[1..];
817 let ext_len = u16::from_be_bytes([p[0], p[1]]) as usize;
818 p = &p[2..];
819 if p.len() < ext_len {
820 return Err(ort_error(ErrorKind::TlsServerHelloExtTooShort, ""));
821 }
822 let mut ex = &p[..ext_len];
823
824 let mut server_pub = None;
825
826 while !ex.is_empty() {
827 if ex.len() < 4 {
828 return Err(ort_error(ErrorKind::TlsExtensionHeaderTooShort, ""));
829 }
830 let et = u16::from_be_bytes([ex[0], ex[1]]);
831 let el = u16::from_be_bytes([ex[2], ex[3]]) as usize;
832 ex = &ex[4..];
833 if ex.len() < el {
834 return Err(ort_error(ErrorKind::TlsExtensionLengthInvalid, ""));
835 }
836 let ed = &ex[..el];
837 ex = &ex[el..];
838
839 match et {
840 EXT_KEY_SHARE => {
841 if ed.len() < 2 + 2 + 32 {
843 return Err(ort_error(ErrorKind::TlsKeyShareServerHelloInvalid, ""));
844 }
845 let grp = u16::from_be_bytes([ed[0], ed[1]]);
846 if grp != GROUP_X25519 {
847 return Err(ort_error(
848 ErrorKind::TlsServerGroupUnsupported,
849 "server group != x25519",
850 ));
851 }
852 let kx_len = u16::from_be_bytes([ed[2], ed[3]]) as usize;
853 if ed.len() < 4 + kx_len || kx_len != 32 {
854 return Err(ort_error(ErrorKind::TlsKeyShareLengthInvalid, ""));
855 }
856 let mut pk = [0u8; 32];
857 pk.copy_from_slice(&ed[4..4 + 32]);
858 server_pub = Some(pk);
859 }
860 EXT_SUPPORTED_VERSIONS => {
861 if ed.len() != 2 || u16::from_be_bytes([ed[0], ed[1]]) != TLS13 {
862 return Err(ort_error(ErrorKind::TlsServerNotTls13, ""));
863 }
864 }
865 _ => {}
866 }
867 }
868
869 let sp = server_pub.ok_or_else(|| ort_error(ErrorKind::TlsMissingServerKey, ""))?;
870 Ok((cipher, sp))
871}
872
873#[allow(unused)]
874fn debug_print(name: &str, value: &[u8]) {
875 #[cfg(debug_assertions)]
876 {
877 if !DEBUG_LOG {
878 return;
879 }
880 let c_str = CString::new(name).unwrap();
881 if !value.is_empty() {
882 crate::utils::print_hex(c_str.as_c_str(), value);
883 } else {
884 crate::utils::print_string(c_str.as_c_str(), "");
885 }
886 }
887}
888
/// Hex-decoding helpers shared by the unit tests.
#[cfg(test)]
pub mod tests {
    extern crate alloc;
    use alloc::vec::Vec;

    /// Drops a leading `0x` / `0X`, if present.
    fn strip_hex_prefix(bytes: &[u8]) -> &[u8] {
        match bytes {
            [b'0', b'x' | b'X', rest @ ..] => rest,
            _ => bytes,
        }
    }

    /// Decodes exactly 64 hex digits (optionally `0x`-prefixed) into 32 bytes.
    ///
    /// Panics on any other length or on a non-hex character.
    pub fn string_to_bytes(s: &str) -> [u8; 32] {
        let digits = strip_hex_prefix(s.as_bytes());
        assert!(
            digits.len() == 64,
            "hex string must be exactly 64 hex chars (32 bytes)"
        );

        let mut out = [0u8; 32];
        for (dst, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
            *dst = (hex_val(pair[0]) << 4) | hex_val(pair[1]);
        }
        out
    }

    /// Decodes an even-length hex string (optionally `0x`-prefixed) into bytes.
    ///
    /// Panics on odd length or on a non-hex character.
    pub fn hex_to_vec(s: &str) -> Vec<u8> {
        let digits = strip_hex_prefix(s.as_bytes());
        assert_eq!(digits.len() % 2, 0, "hex string must have even length");

        let mut out = Vec::with_capacity(digits.len() / 2);
        out.extend(
            digits
                .chunks_exact(2)
                .map(|pair| (hex_val(pair[0]) << 4) | hex_val(pair[1])),
        );
        out
    }

    /// Value of one ASCII hex digit; panics on anything else.
    fn hex_val(b: u8) -> u8 {
        match b {
            b'0'..=b'9' => b - b'0',
            b'a'..=b'f' => 10 + (b - b'a'),
            b'A'..=b'F' => 10 + (b - b'A'),
            _ => panic!("invalid hex character"),
        }
    }
}