ort_openrouter_cli/net/
tls.rs

1//! ort: Open Router CLI
2//! https://github.com/grahamking/ort
3//!
4//! MIT License
5//! Copyright (c) 2025 Graham King
6//
7//! ---------------------- Minimal TLS 1.3 client (AES-128-GCM + X25519) -------
8
9use core::ffi::c_void;
10use core::{cmp, ffi::CStr};
11
12extern crate alloc;
13use alloc::ffi::CString;
14use alloc::format;
15use alloc::string::ToString;
16use alloc::vec;
17use alloc::vec::Vec;
18
19use crate::{
20    ErrorKind, OrtResult, Read, Write, common::utils::to_ascii, libc, ort_error, ort_from_err,
21};
22
23mod aead;
24mod ecdh;
25mod hkdf;
26mod hmac;
27mod sha2;
28
29//const DEBUG_LOG: bool = false;
30
// Record content types: the first byte of every 5-byte record header.
const REC_TYPE_CHANGE_CIPHER_SPEC: u8 = 20; // 0x14
const REC_TYPE_ALERT: u8 = 21; // 0x15
const REC_TYPE_HANDSHAKE: u8 = 22; // 0x16
const REC_TYPE_APPDATA: u8 = 23; // 0x17
// Version bytes written into every record header by this client.
const LEGACY_REC_VER: u16 = 0x0303;

// Handshake message types: the first byte of every handshake message header.
const HS_CLIENT_HELLO: u8 = 1;
const HS_SERVER_HELLO: u8 = 2;
//const HS_NEW_SESSION_TICKET: u8 = 4;
//const HS_ENCRYPTED_EXTENSIONS: u8 = 8;
//const HS_CERTIFICATE: u8 = 11;
//const HS_CERT_VERIFY: u8 = 15;
const HS_FINISHED: u8 = 20; // 0x14

// TLS_AES_128_GCM_SHA256: the only cipher suite offered in the ClientHello.
const CIPHER_TLS_AES_128_GCM_SHA256: u16 = 0x1301;
// supported_versions (TLS 1.3)
const TLS13: u16 = 0x0304;
// supported group: x25519
const GROUP_X25519: u16 = 0x001d;

// Extension type codes used when building the ClientHello and parsing the ServerHello.
const EXT_SERVER_NAME: u16 = 0x0000;
const EXT_SUPPORTED_GROUPS: u16 = 0x000a;
const EXT_SIGNATURE_ALGS: u16 = 0x000d;
//const EXT_ALPN: u16 = 0x0010;
const EXT_SUPPORTED_VERSIONS: u16 = 0x002b;
//const EXT_PSK_MODES: u16 = 0x002d;
const EXT_KEY_SHARE: u16 = 0x0033;

// AEAD tag length (GCM), in bytes; added to every encrypted record's length.
const AEAD_TAG_LEN: usize = 16;
63
64// Tiny helper to write BE ints
/// Appends `v` to `buf` as two bytes, most significant byte first.
fn put_u16(buf: &mut Vec<u8>, v: u16) {
    let [hi, lo] = v.to_be_bytes();
    buf.push(hi);
    buf.push(lo);
}
/// Appends the low 24 bits of `v` to `buf` as three big-endian bytes.
///
/// Bits above the 24th are silently discarded.
fn put_u24(buf: &mut Vec<u8>, v: usize) {
    // Narrow to 32 bits, then skip the most significant of the four bytes.
    let be = (v as u32).to_be_bytes();
    buf.extend_from_slice(&be[1..]);
}
72
73fn hkdf_expand_label<const N: usize>(prk: &[u8], label: &str, data: &[u8]) -> [u8; N] {
74    let mut info = Vec::with_capacity(2 + 1 + 6 + label.len() + 1 + data.len());
75    put_u16(&mut info, N as u16);
76    let full_label = "tls13 ".to_string() + label;
77    info.push(full_label.len() as u8);
78    info.extend_from_slice(full_label.as_bytes());
79    info.push(data.len() as u8);
80    info.extend_from_slice(data);
81
82    hkdf::hkdf_expand(prk, &info, N).try_into().unwrap()
83}
84
85fn digest_bytes(data: &[u8]) -> [u8; 32] {
86    let d = sha2::sha256(data);
87    let mut out = [0u8; 32];
88    out.copy_from_slice(d.as_ref());
89    out
90}
91
92// AEAD nonce = iv XOR seq (seq in BE on the rightmost 8 bytes)
/// Builds the per-record AEAD nonce: `iv12` XOR the sequence number.
///
/// `seq` is treated as a 12-byte big-endian value (four leading zero bytes
/// followed by the eight bytes of the `u64`), so only the last eight bytes of
/// the IV are modified. Runs once per record, so it works on a stack copy of
/// the IV and never allocates.
fn nonce_xor(iv12: &[u8; 12], seq: u64) -> [u8; 12] {
    let mut nonce = *iv12;
    // The first four bytes are XORed with zero, i.e. left untouched.
    for (n, s) in nonce[4..].iter_mut().zip(seq.to_be_bytes()) {
        *n ^= s;
    }
    nonce
}
102
103// Very small record writer/reader after handshake
/// An established TLS connection wrapping the transport `io`.
///
/// Created by `TlsStream::connect`; its `Write` and `Read` implementations
/// encrypt outgoing and decrypt incoming application-data records.
pub struct TlsStream<T: Read + Write> {
    // Underlying transport that records are written to and read from.
    io: T,
    // Application traffic
    // AES-128-GCM key for records we send.
    aead_enc: [u8; 16],
    // AES-128-GCM key for records we receive.
    aead_dec: [u8; 16],
    // IV combined with `seq_enc` to form the nonce of each sent record.
    iv_enc: [u8; 12],
    // IV combined with `seq_dec` to form the nonce of each received record.
    iv_dec: [u8; 12],
    // Number of records sent so far with the application keys.
    seq_enc: u64,
    // Number of records received so far with the application keys.
    seq_dec: u64,
    // read buffer for decrypted application data
    rbuf: Vec<u8>,
    // Offset of the first byte in `rbuf` not yet handed to the caller.
    rpos: usize,
}
117
118fn client_hello_body(sni_host: &str, client_pub: &[u8]) -> Vec<u8> {
119    let mut ch_body = Vec::with_capacity(512);
120
121    // X25519
122    let mut random = [0u8; 32];
123    let got_bytes = unsafe { libc::getrandom(random.as_mut_ptr() as *mut c_void, 32, 0) };
124    debug_assert_eq!(got_bytes, 32);
125
126    let mut session_id = [0u8; 32];
127    let got_bytes = unsafe { libc::getrandom(session_id.as_mut_ptr() as *mut c_void, 32, 0) };
128    debug_assert_eq!(got_bytes, 32);
129
130    // legacy_version
131    ch_body.extend_from_slice(&0x0303u16.to_be_bytes());
132    // random
133    ch_body.extend_from_slice(&random);
134    // legacy_session_id
135    ch_body.push(session_id.len() as u8);
136    ch_body.extend_from_slice(&session_id);
137    // cipher_suites: only TLS_AES_128_GCM_SHA256
138    put_u16(&mut ch_body, 2);
139    put_u16(&mut ch_body, CIPHER_TLS_AES_128_GCM_SHA256);
140    // legacy_compression_methods: null
141    ch_body.push(1);
142    ch_body.push(0);
143
144    // --- extensions ---
145    let mut exts = Vec::with_capacity(512);
146
147    // server_name
148    {
149        let host_bytes = sni_host.as_bytes();
150        let mut snl = Vec::with_capacity(3 + host_bytes.len());
151        snl.push(0); // host_name
152        put_u16(&mut snl, host_bytes.len() as u16);
153        snl.extend_from_slice(host_bytes);
154
155        let mut sni = Vec::with_capacity(2 + snl.len());
156        put_u16(&mut sni, snl.len() as u16);
157        sni.extend_from_slice(&snl);
158
159        put_u16(&mut exts, EXT_SERVER_NAME);
160        put_u16(&mut exts, sni.len() as u16);
161        exts.extend_from_slice(&sni);
162    }
163
164    // supported_versions: TLS 1.3
165    {
166        let mut sv = Vec::with_capacity(3);
167        sv.push(2); // length in bytes
168        sv.extend_from_slice(&TLS13.to_be_bytes());
169        put_u16(&mut exts, EXT_SUPPORTED_VERSIONS);
170        put_u16(&mut exts, sv.len() as u16);
171        exts.extend_from_slice(&sv);
172    }
173
174    // supported_groups: x25519
175    {
176        let mut sg = Vec::with_capacity(2 + 2);
177        put_u16(&mut sg, 2);
178        put_u16(&mut sg, GROUP_X25519);
179        put_u16(&mut exts, EXT_SUPPORTED_GROUPS);
180        put_u16(&mut exts, sg.len() as u16);
181        exts.extend_from_slice(&sg);
182    }
183
184    // signature_algorithms: minimal list
185    {
186        const ECDSA_SECP256R1_SHA256: u16 = 0x0403;
187        const RSA_PSS_RSAE_SHA256: u16 = 0x0804;
188        const RSA_PKCS1_SHA256: u16 = 0x0401;
189
190        let mut sa = Vec::with_capacity(2 + 6);
191        put_u16(&mut sa, 6);
192        put_u16(&mut sa, ECDSA_SECP256R1_SHA256);
193        put_u16(&mut sa, RSA_PSS_RSAE_SHA256);
194        put_u16(&mut sa, RSA_PKCS1_SHA256);
195
196        put_u16(&mut exts, EXT_SIGNATURE_ALGS);
197        put_u16(&mut exts, sa.len() as u16);
198        exts.extend_from_slice(&sa);
199    }
200
201    // key_share: x25519
202    {
203        let mut ks = Vec::with_capacity(2 + 2 + 2 + 32);
204        // client_shares vector
205        let mut entry = Vec::with_capacity(2 + 2 + 32);
206        put_u16(&mut entry, GROUP_X25519);
207        put_u16(&mut entry, 32);
208        entry.extend_from_slice(client_pub);
209        put_u16(&mut ks, entry.len() as u16);
210        ks.extend_from_slice(&entry);
211
212        put_u16(&mut exts, EXT_KEY_SHARE);
213        put_u16(&mut exts, ks.len() as u16);
214        exts.extend_from_slice(&ks);
215    }
216
217    // add extensions to CH
218    put_u16(&mut ch_body, exts.len() as u16);
219    ch_body.extend_from_slice(&exts);
220
221    ch_body
222}
223
224/// --- Build ClientHello (single cipher: TLS_AES_128_GCM_SHA256) ---
225fn client_hello_msg(sni_host: &str, client_private_key: &[u8]) -> OrtResult<Vec<u8>> {
226    let client_pub_key = ecdh::x25519_public_key(client_private_key);
227    let client_pub_ref = &client_pub_key;
228    debug_print("Client public key", client_pub_ref);
229
230    let ch_body = client_hello_body(sni_host, client_pub_ref);
231
232    // Handshake framing: ClientHello
233    let mut ch_msg = Vec::with_capacity(4 + ch_body.len());
234    ch_msg.push(HS_CLIENT_HELLO);
235    put_u24(&mut ch_msg, ch_body.len());
236    ch_msg.extend_from_slice(&ch_body);
237
238    Ok(ch_msg)
239}
240
241/// Read ServerHello (plaintext Handshake record)
242fn read_server_hello<R: Read>(io: &mut R) -> OrtResult<(Vec<u8>, Vec<u8>)> {
243    let (typ, payload) = read_record_plain(io)
244        .map_err(|e| ort_from_err(ErrorKind::TlsRecordTooShort, "read_record_plain", e))?;
245    if typ != REC_TYPE_HANDSHAKE {
246        return Err(ort_error(ErrorKind::TlsExpectedHandshakeRecord, ""));
247    }
248    let sh_buf = payload;
249
250    // There can be multiple handshake messages; we need the ServerHello bytes specifically
251    let mut rd = &sh_buf[..];
252    let (sh_typ, sh_body, sh_full) = read_handshake_message(&mut rd).map_err(|e| {
253        ort_from_err(
254            ErrorKind::TlsHandshakeHeaderTooShort,
255            "read_handshake_message",
256            e,
257        )
258    })?;
259    if sh_typ != HS_SERVER_HELLO {
260        return Err(ort_error(ErrorKind::TlsExpectedServerHello, ""));
261    }
262
263    // TODO: later remove the copy. The slices are into sh_buf
264    Ok((sh_body.to_vec(), sh_full.to_vec()))
265}
266
/// Secrets and keys produced by `derive_handshake_keys`, used to protect the
/// encrypted part of the handshake and to derive the application keys.
struct HandshakeState {
    // HKDF-Extract of the "derived" secret with the X25519 shared secret.
    handshake_secret: [u8; 32],
    // Secret expanded with label "c hs traffic"; source of the client Finished key.
    client_hs_ts: [u8; 32],
    // Secret expanded with label "s hs traffic"; source of the server Finished key.
    server_hs_ts: [u8; 32],
    // IV for handshake records we send.
    client_handshake_iv: [u8; 12],
    // IV for handshake records we receive.
    server_handshake_iv: [u8; 12],
    // AES-128-GCM key for handshake records we send (client handshake key).
    aead_enc_hs: [u8; 16],
    // AES-128-GCM key for handshake records we receive (server handshake key).
    aead_dec_hs: [u8; 16],
    // Digest of the empty input, reused when deriving the application keys.
    empty_hash: [u8; 32],
}
277
/// Application-traffic keys and IVs produced by `derive_application_keys`.
struct ApplicationKeys {
    // Key for records we send (from the "c ap traffic" secret).
    aead_app_enc: [u8; 16],
    // Key for records we receive (from the "s ap traffic" secret).
    aead_app_dec: [u8; 16],
    // IV for records we send.
    iv_enc: [u8; 12],
    // IV for records we receive.
    iv_dec: [u8; 12],
}
284
285impl<T: Read + Write> TlsStream<T> {
286    pub fn connect(mut io: T, sni_host: &str) -> OrtResult<Self> {
287        // transcript = full Handshake message encodings (headers + bodies)
288        let mut transcript = Vec::with_capacity(1024);
289
290        // A private key is simply random bytes. /dev/urandom is cryptographically secure.
291        let mut client_private_key = [0u8; 32];
292        let _ = unsafe { libc::getrandom(client_private_key.as_mut_ptr() as *mut c_void, 32, 0) };
293        debug_print("Client private key", &client_private_key);
294
295        Self::send_client_hello(&mut io, sni_host, &mut transcript, &client_private_key)?;
296
297        let sh_body = Self::receive_server_hello(&mut io, &mut transcript)?;
298
299        let handshake = Self::derive_handshake_keys(&client_private_key, &sh_body, &transcript)?;
300
301        Self::receive_dummy_change_cipher_spec(&mut io)?;
302
303        let mut seq_dec_hs = 0u64;
304        let mut seq_enc_hs = 0u64;
305
306        Self::receive_server_encrypted_flight(
307            &mut io,
308            &mut seq_dec_hs,
309            &handshake,
310            &mut transcript,
311        )?;
312
313        let ApplicationKeys {
314            aead_app_enc,
315            aead_app_dec,
316            iv_enc: caiv,
317            iv_dec: saiv,
318        } = Self::derive_application_keys(
319            &handshake.handshake_secret,
320            &handshake.empty_hash,
321            &transcript,
322        );
323
324        let seq_app_enc = 0u64;
325        let seq_app_dec = 0u64;
326
327        // Client Change Cipher Spec
328        // This is optional, to "confuse middleboxes" which expect TLS 1.2. Works without.
329        //write_record_plain(&mut io, REC_TYPE_CHANGE_CIPHER_SPEC, &[0x01])?;
330
331        Self::send_client_finished(&mut io, &handshake, &mut transcript, &mut seq_enc_hs)?;
332
333        Ok(TlsStream {
334            io,
335            aead_enc: aead_app_enc,
336            aead_dec: aead_app_dec,
337            iv_enc: caiv,
338            iv_dec: saiv,
339            seq_enc: seq_app_enc,
340            seq_dec: seq_app_dec,
341            rbuf: Vec::with_capacity(16 * 1024),
342            rpos: 0,
343        })
344    }
345
346    fn send_client_hello<W: Write>(
347        io: &mut W,
348        sni_host: &str,
349        transcript: &mut Vec<u8>,
350        client_private_key: &[u8; 32],
351    ) -> OrtResult<()> {
352        let ch_msg = client_hello_msg(sni_host, client_private_key)?;
353        write_record_plain(io, REC_TYPE_HANDSHAKE, &ch_msg)
354            .map_err(|e| ort_from_err(ErrorKind::SocketWriteFailed, "write ClientHello", e))?;
355        transcript.extend_from_slice(&ch_msg);
356        Ok(())
357    }
358
359    fn receive_server_hello<R: Read>(io: &mut R, transcript: &mut Vec<u8>) -> OrtResult<Vec<u8>> {
360        let (sh_body, sh_full) = read_server_hello(io)?;
361        transcript.extend_from_slice(&sh_full);
362        Ok(sh_body)
363    }
364
365    fn receive_dummy_change_cipher_spec<R: Read>(io: &mut R) -> OrtResult<()> {
366        // Some servers send TLS 1.2-style ChangeCipherSpec for middlebox compatibility.
367        let (typ, _) = read_record_plain(io)
368            .map_err(|e| ort_from_err(ErrorKind::TlsRecordTooShort, "read_record_plain", e))?;
369        if typ != REC_TYPE_CHANGE_CIPHER_SPEC {
370            return Err(ort_error(ErrorKind::TlsExpectedChangeCipherSpec, ""));
371        }
372        Ok(())
373    }
374
375    fn receive_server_encrypted_flight<R: Read>(
376        io: &mut R,
377        seq_dec_hs: &mut u64,
378        handshake: &HandshakeState,
379        transcript: &mut Vec<u8>,
380    ) -> OrtResult<()> {
381        let (typ, ct, _inner_type) = read_record_cipher(
382            io,
383            &handshake.aead_dec_hs,
384            &handshake.server_handshake_iv,
385            seq_dec_hs,
386        )
387        .map_err(|e| ort_from_err(ErrorKind::TlsRecordTooShort, "read_record_cipher", e))?;
388        if typ != REC_TYPE_APPDATA {
389            return Err(ort_error(ErrorKind::TlsExpectedEncryptedRecords, ""));
390        }
391
392        // Decrypted TLSInnerPlaintext: ... | content_type
393        // May contain multiple handshake messages; parse & append to transcript.
394        let mut p = &ct[..];
395        while !p.is_empty() {
396            // On TLS 1.3: content_type is last byte; but ring decrypt gives only plaintext,
397            // here ct already stripped of content-type 0x16 by read_record_cipher().
398            let (mtyp, body, full) = match read_handshake_message(&mut p) {
399                Ok(x) => x,
400                Err(_) => {
401                    return Err(ort_error(ErrorKind::TlsBadHandshakeFragment, ""));
402                }
403            };
404            transcript.extend_from_slice(full);
405
406            if mtyp == HS_FINISHED {
407                // verify server Finished
408                let s_finished_key =
409                    hkdf_expand_label::<32>(&handshake.server_hs_ts, "finished", &[]);
410
411                let thash = digest_bytes(&transcript[..transcript.len() - full.len()]);
412                let expected = hmac::sign(&s_finished_key, &thash);
413                if expected.as_slice() != body {
414                    return Err(ort_error(ErrorKind::TlsFinishedVerifyFailed, ""));
415                }
416                // Done collecting server handshake.
417                break;
418            }
419            // Ignore other handshake types’ contents (no cert validation).
420        }
421        Ok(())
422    }
423
424    fn derive_handshake_keys(
425        client_private_key: &[u8; 32],
426        sh_body: &[u8],
427        transcript: &[u8],
428    ) -> OrtResult<HandshakeState> {
429        // Parse minimal ServerHello to get cipher & key_share
430        let (cipher, server_public_key_bytes) =
431            parse_server_hello_for_keys(sh_body).map_err(|e| {
432                ort_from_err(
433                    ErrorKind::TlsServerHelloTooShort,
434                    "parse_server_hello_for_keys",
435                    e,
436                )
437            })?;
438        debug_print("Server public key", &server_public_key_bytes);
439        if cipher != CIPHER_TLS_AES_128_GCM_SHA256 {
440            return Err(ort_error(
441                ErrorKind::TlsUnsupportedCipher,
442                "server picked unsupported cipher",
443            ));
444        }
445
446        // ECDH(X25519) shared secret
447        let hs_shared_secret = ecdh::x25519_agreement(client_private_key, &server_public_key_bytes);
448        debug_print("hs shared secret", &hs_shared_secret);
449
450        // Same as: `echo -n "" | openssl sha256`
451        let empty_hash = digest_bytes(&[]);
452        debug_print("empty_hash", &empty_hash);
453
454        let zero: [u8; 32] = [0u8; 32];
455        let early_secret = hkdf::hkdf_extract(&zero, &zero);
456
457        let derived_secret_bytes = hkdf_expand_label::<32>(&early_secret, "derived", &empty_hash);
458        debug_print("derived", &derived_secret_bytes);
459
460        let handshake_secret = hkdf::hkdf_extract(&derived_secret_bytes, &hs_shared_secret);
461        debug_print("handshake_secret", &handshake_secret);
462
463        let ch_sh_hash = digest_bytes(transcript);
464        debug_print("digest bytes", &ch_sh_hash);
465
466        let c_hs_ts = hkdf_expand_label(&handshake_secret, "c hs traffic", &ch_sh_hash);
467        let s_hs_ts = hkdf_expand_label(&handshake_secret, "s hs traffic", &ch_sh_hash);
468
469        debug_print("c hs traffic", &c_hs_ts);
470        debug_print("s hs traffic", &s_hs_ts);
471
472        // handshake AEAD keys/IVs
473        let client_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&c_hs_ts, "key", &[])
474            .as_slice()[..16]
475            .try_into()
476            .unwrap();
477        debug_print("client_handshake_key", &client_handshake_key);
478        let client_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&c_hs_ts, "iv", &[]).as_slice()
479            [..12]
480            .try_into()
481            .unwrap();
482        debug_print("client_handshake_iv", &client_handshake_iv);
483
484        let server_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&s_hs_ts, "key", &[])
485            .as_slice()[..16]
486            .try_into()
487            .unwrap();
488        debug_print("server_handshake_key", &server_handshake_key);
489        let server_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&s_hs_ts, "iv", &[]).as_slice()
490            [..12]
491            .try_into()
492            .unwrap();
493        debug_print("server_handshake_iv", &server_handshake_iv);
494
495        Ok(HandshakeState {
496            handshake_secret,
497            client_hs_ts: c_hs_ts,
498            server_hs_ts: s_hs_ts,
499            client_handshake_iv,
500            server_handshake_iv,
501            aead_enc_hs: client_handshake_key,
502            aead_dec_hs: server_handshake_key,
503            empty_hash,
504        })
505    }
506
507    fn derive_application_keys(
508        handshake_secret: &[u8; 32],
509        empty_hash: &[u8; 32],
510        transcript: &[u8],
511    ) -> ApplicationKeys {
512        let derived2_bytes = hkdf_expand_label::<32>(handshake_secret, "derived", empty_hash);
513        debug_print("derived2_bytes", &derived2_bytes);
514
515        let zero: [u8; 32] = [0u8; 32];
516        let master_secret = hkdf::hkdf_extract(&derived2_bytes, &zero);
517        let thash_srv_fin = digest_bytes(transcript);
518
519        let c_ap_ts = hkdf_expand_label::<32>(&master_secret, "c ap traffic", &thash_srv_fin);
520        let s_ap_ts = hkdf_expand_label::<32>(&master_secret, "s ap traffic", &thash_srv_fin);
521        debug_print("c_ap_ts", &c_ap_ts);
522        debug_print("s_ap_ts", &s_ap_ts);
523
524        let cak: [u8; 16] = hkdf_expand_label::<16>(&c_ap_ts, "key", &[]).as_slice()[..16]
525            .try_into()
526            .unwrap();
527        let caiv: [u8; 12] = hkdf_expand_label::<12>(&c_ap_ts, "iv", &[]).as_slice()[..12]
528            .try_into()
529            .unwrap();
530        debug_print("cak", &cak);
531        debug_print("caiv", &caiv);
532
533        let sak: [u8; 16] = hkdf_expand_label::<16>(&s_ap_ts, "key", &[]).as_slice()[..16]
534            .try_into()
535            .unwrap();
536        let saiv: [u8; 12] = hkdf_expand_label::<12>(&s_ap_ts, "iv", &[]).as_slice()[..12]
537            .try_into()
538            .unwrap();
539        debug_print("sak", &sak);
540        debug_print("saiv", &saiv);
541
542        ApplicationKeys {
543            aead_app_enc: cak,
544            aead_app_dec: sak,
545            iv_enc: caiv,
546            iv_dec: saiv,
547        }
548    }
549
550    fn send_client_finished<W: Write>(
551        io: &mut W,
552        handshake: &HandshakeState,
553        transcript: &mut Vec<u8>,
554        seq_enc_hs: &mut u64,
555    ) -> OrtResult<()> {
556        let c_finished_key = hkdf_expand_label::<32>(&handshake.client_hs_ts, "finished", &[]);
557        debug_print("c_finished", &c_finished_key);
558
559        let thash_client_fin = digest_bytes(transcript.as_slice());
560        let verify_data = hmac::sign(&c_finished_key, &thash_client_fin);
561        debug_print("verify_data", &verify_data);
562
563        let mut fin = Vec::with_capacity(4 + verify_data.as_ref().len());
564        fin.push(HS_FINISHED);
565        put_u24(&mut fin, verify_data.as_ref().len());
566        fin.extend_from_slice(verify_data.as_ref());
567
568        // append to transcript before switching keys
569        transcript.extend_from_slice(&fin);
570
571        write_record_cipher(
572            io,
573            REC_TYPE_HANDSHAKE,
574            &fin,
575            &handshake.aead_enc_hs,
576            &handshake.client_handshake_iv,
577            seq_enc_hs,
578        )
579        .map_err(|e| ort_from_err(ErrorKind::TlsRecordTooShort, "read_record_cipher", e))?;
580
581        Ok(())
582    }
583}
584
585impl<T: Read + Write> Write for TlsStream<T> {
586    fn write(&mut self, buf: &[u8]) -> OrtResult<usize> {
587        write_record_cipher(
588            &mut self.io,
589            REC_TYPE_APPDATA,
590            buf,
591            &self.aead_enc,
592            &self.iv_enc,
593            &mut self.seq_enc,
594        )
595        .map(|_| buf.len())
596    }
597    fn flush(&mut self) -> OrtResult<()> {
598        self.io.flush()
599    }
600}
601
602impl<T: Read + Write> Read for TlsStream<T> {
603    fn read(&mut self, out: &mut [u8]) -> OrtResult<usize> {
604        if self.rpos < self.rbuf.len() {
605            let n = cmp::min(out.len(), self.rbuf.len() - self.rpos);
606            out[..n].copy_from_slice(&self.rbuf[self.rpos..self.rpos + n]);
607            self.rpos += n;
608            if self.rpos == self.rbuf.len() {
609                self.rbuf.clear();
610                self.rpos = 0;
611            }
612            return Ok(n);
613        }
614        loop {
615            let (typ, plaintext, inner_type) = read_record_cipher(
616                &mut self.io,
617                &self.aead_dec,
618                &self.iv_dec,
619                &mut self.seq_dec,
620            )?;
621            if typ != REC_TYPE_APPDATA {
622                // Ignore unexpected (e.g., post-handshake Handshake like NewSessionTicket)
623                continue;
624            }
625            // plaintext ends with inner content type byte; for app data it is 0x17.
626            if plaintext.is_empty() {
627                continue;
628            }
629            if inner_type == REC_TYPE_HANDSHAKE {
630                // Drop post-handshake messages (tickets, etc.)
631                continue;
632            }
633            if inner_type == REC_TYPE_ALERT {
634                let level = match plaintext[0] {
635                    1 => "warning",
636                    2 => "fatal",
637                    _ => "unknown",
638                };
639                let err_level = CString::new(level.to_string() + " alert: ").unwrap();
640
641                // See https://www.rfc-editor.org/rfc/rfc8446#appendix-B search for
642                // "unexpected_message" for all types
643                let mut err_code_buf: [u8; 5] = [0u8; 5];
644                let len = to_ascii(plaintext[1] as usize, &mut err_code_buf);
645                let err_code = unsafe { CStr::from_bytes_with_nul_unchecked(&err_code_buf[..len]) };
646                unsafe {
647                    libc::write(2, err_level.as_ptr().cast(), err_level.count_bytes());
648                    libc::write(2, err_code.as_ptr().cast(), err_code.count_bytes());
649                }
650
651                return Err(ort_error(ErrorKind::TlsAlertReceived, ""));
652            }
653            if inner_type != REC_TYPE_APPDATA {
654                // Some servers pad with 0x00.. then type; we already consumed type.
655                // If not 0x17, treat preceding bytes (if any) as app anyway.
656            }
657            if plaintext.is_empty() {
658                continue;
659            }
660
661            self.rbuf.extend_from_slice(&plaintext);
662            self.rpos = 0;
663            // Now serve from buffer
664            let n = cmp::min(out.len(), self.rbuf.len());
665            out[..n].copy_from_slice(&self.rbuf[..n]);
666            self.rpos = n;
667            if n == self.rbuf.len() {
668                self.rbuf.clear();
669                self.rpos = 0;
670            }
671            return Ok(n);
672        }
673    }
674}
675
676// ---------------------- Record I/O helpers ----------------------------------
677
678fn write_record_plain<W: Write>(w: &mut W, typ: u8, body: &[u8]) -> OrtResult<()> {
679    let mut hdr = [0u8; 5];
680    hdr[0] = typ;
681    hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
682    hdr[3..5].copy_from_slice(&(body.len() as u16).to_be_bytes());
683    w.write_all(&hdr)?;
684    w.write_all(body)?;
685    Ok(())
686}
687
688fn read_exact_n<R: Read>(r: &mut R, n: usize) -> OrtResult<Vec<u8>> {
689    let mut buf = vec![0u8; n];
690    r.read_exact(&mut buf)?;
691    Ok(buf)
692}
693
694fn read_record_plain<R: Read>(r: &mut R) -> OrtResult<(u8, Vec<u8>)> {
695    let hdr = read_exact_n(r, 5)?; // Record Header, e.g. 16 03 03 len
696    let typ = hdr[0];
697    let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
698    let body = read_exact_n(r, len)?;
699    //let _ = write_bytes_to_file(&[&hdr[..], &body].concat(), debug_filename);
700    Ok((typ, body))
701}
702
703fn write_record_cipher<W: Write>(
704    w: &mut W,
705    outer_type: u8,
706    inner: &[u8],
707    key: &[u8; 16],
708    iv12: &[u8; 12],
709    seq: &mut u64,
710) -> OrtResult<()> {
711    // AES / GCM plaintext and ciphertext have the same length
712    let total_len = inner.len() + 1 + AEAD_TAG_LEN;
713    let mut plain = Vec::with_capacity(total_len);
714    plain.extend_from_slice(inner);
715    plain.push(outer_type);
716
717    debug_print("write_record_cipher plaintext", &plain);
718
719    let nonce = nonce_xor(iv12, *seq);
720    *seq = seq.wrapping_add(1);
721
722    let mut hdr = [0u8; 5];
723    hdr[0] = REC_TYPE_APPDATA;
724    hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
725    hdr[3..5].copy_from_slice(&(total_len as u16).to_be_bytes());
726
727    let out = aead::aes_128_gcm_encrypt(key, &nonce, &hdr, &plain).unwrap();
728
729    debug_print("write_record_cipher header", &hdr);
730    let final_label = format!("write_record_cipher final {total_len}");
731    debug_print(final_label.as_str(), &out);
732
733    w.write_all(&hdr)?;
734    w.write_all(&out)?;
735    Ok(())
736}
737
738fn read_record_cipher<R: Read>(
739    r: &mut R,
740    key: &[u8; 16],
741    iv12: &[u8; 12],
742    seq: &mut u64,
743) -> OrtResult<(u8, Vec<u8>, u8)> {
744    let hdr = read_exact_n(r, 5)?;
745    let typ = hdr[0];
746    let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
747    let ciphertext = read_exact_n(r, len)?;
748    if len < AEAD_TAG_LEN {
749        return Err(ort_error(ErrorKind::TlsRecordTooShort, "short record"));
750    }
751    debug_print("read_record_cipher hdr", &hdr);
752    debug_print("read_record_cipher ct", &ciphertext);
753
754    // Decrypt ciphertext
755
756    let nonce = nonce_xor(iv12, *seq);
757    *seq = seq.wrapping_add(1);
758
759    let mut out = aead::aes_128_gcm_decrypt(key, &nonce, &hdr, &ciphertext).unwrap();
760
761    debug_print("read_record_cipher plaintext hdr", &hdr);
762    debug_print("read_record_cipher plaintext", &out);
763
764    if out.is_empty() {
765        return Ok((typ, ciphertext, 0));
766    }
767    // Strip inner content-type byte
768    let inner_type = *out.last().unwrap();
769    out.truncate(out.len() - 1);
770    Ok((typ, out, inner_type))
771}
772
773// ---------------------- Handshake parsing helpers ---------------------------
774
775fn read_handshake_message<'a>(rd: &mut &'a [u8]) -> OrtResult<(u8, &'a [u8], &'a [u8])> {
776    if rd.len() < 4 {
777        return Err(ort_error(ErrorKind::TlsHandshakeHeaderTooShort, ""));
778    }
779    let typ = rd[0];
780    let len = ((rd[1] as usize) << 16) | ((rd[2] as usize) << 8) | rd[3] as usize;
781    if rd.len() < 4 + len {
782        return Err(ort_error(ErrorKind::TlsHandshakeBodyTooShort, ""));
783    }
784    let full = &rd[..4 + len];
785    let body = &rd[4..4 + len];
786    *rd = &rd[4 + len..];
787    Ok((typ, body, full))
788}
789
790fn parse_server_hello_for_keys(sh: &[u8]) -> OrtResult<(u16, [u8; 32])> {
791    // minimal parse: skip legacy_version(2), random(32), sid, cipher(2), comp(1), exts
792    if sh.len() < 2 + 32 + 1 + 2 + 1 + 2 {
793        return Err(ort_error(ErrorKind::TlsServerHelloTooShort, ""));
794    }
795    let mut p = sh;
796
797    p = &p[2..]; // legacy_version
798    p = &p[32..]; // random
799    let sid_len = p[0] as usize;
800    p = &p[1..];
801    if p.len() < sid_len + 2 + 1 + 2 {
802        return Err(ort_error(ErrorKind::TlsServerHelloSessionIdInvalid, ""));
803    }
804    p = &p[sid_len..];
805    let cipher = u16::from_be_bytes([p[0], p[1]]);
806    p = &p[2..];
807    let _comp = p[0];
808    p = &p[1..];
809    let ext_len = u16::from_be_bytes([p[0], p[1]]) as usize;
810    p = &p[2..];
811    if p.len() < ext_len {
812        return Err(ort_error(ErrorKind::TlsServerHelloExtTooShort, ""));
813    }
814    let mut ex = &p[..ext_len];
815
816    let mut server_pub = None;
817
818    while !ex.is_empty() {
819        if ex.len() < 4 {
820            return Err(ort_error(ErrorKind::TlsExtensionHeaderTooShort, ""));
821        }
822        let et = u16::from_be_bytes([ex[0], ex[1]]);
823        let el = u16::from_be_bytes([ex[2], ex[3]]) as usize;
824        ex = &ex[4..];
825        if ex.len() < el {
826            return Err(ort_error(ErrorKind::TlsExtensionLengthInvalid, ""));
827        }
828        let ed = &ex[..el];
829        ex = &ex[el..];
830
831        match et {
832            EXT_KEY_SHARE => {
833                // KeyShareServerHello: group(2) kx_len(2) kx
834                if ed.len() < 2 + 2 + 32 {
835                    return Err(ort_error(ErrorKind::TlsKeyShareServerHelloInvalid, ""));
836                }
837                let grp = u16::from_be_bytes([ed[0], ed[1]]);
838                if grp != GROUP_X25519 {
839                    return Err(ort_error(
840                        ErrorKind::TlsServerGroupUnsupported,
841                        "server group != x25519",
842                    ));
843                }
844                let kx_len = u16::from_be_bytes([ed[2], ed[3]]) as usize;
845                if ed.len() < 4 + kx_len || kx_len != 32 {
846                    return Err(ort_error(ErrorKind::TlsKeyShareLengthInvalid, ""));
847                }
848                let mut pk = [0u8; 32];
849                pk.copy_from_slice(&ed[4..4 + 32]);
850                server_pub = Some(pk);
851            }
852            EXT_SUPPORTED_VERSIONS => {
853                if ed.len() != 2 || u16::from_be_bytes([ed[0], ed[1]]) != TLS13 {
854                    return Err(ort_error(ErrorKind::TlsServerNotTls13, ""));
855                }
856            }
857            _ => {}
858        }
859    }
860
861    let sp = server_pub.ok_or_else(|| ort_error(ErrorKind::TlsMissingServerKey, ""))?;
862    Ok((cipher, sp))
863}
864
/// No-op stand-in for the debug logger, so call sites stay in place when
/// logging is compiled out. Both arguments are ignored.
fn debug_print(_name: &str, _value: &[u8]) {
    // Intentionally empty.
}
866
867/* Debug features, need STDLIB
868
869fn debug_print(name: &str, value: &[u8]) {
870    if !DEBUG_LOG {
871        return;
872    }
873    eprintln!("\n{name} {}:", value.len());
874    print_hex(value);
875}
876
877fn print_hex(v: &[u8]) {
878    let hex: String = v
879        .iter()
880        .map(|b| format!("{:02x}", b))
881        .collect::<Vec<_>>()
882        .join("");
883    eprintln!("{hex}");
884}
885*/
886
887/*
888#[allow(dead_code)]
889fn write_bytes_to_file(bytes: &[u8], file_path: &str) -> std::io::Result<()> {
890    let mut file = File::create(file_path)?;
891    file.write_all(bytes)?;
892    Ok(())
893}
894*/
895
#[cfg(test)]
pub mod tests {
    //! Hex-decoding helpers shared by the test suites.
    extern crate alloc;
    use alloc::vec::Vec;

    /// Decodes a 64-character hex string (optionally prefixed with `0x`/`0X`)
    /// into 32 bytes.
    ///
    /// # Panics
    /// Panics if the string is not exactly 64 hex characters after the
    /// optional prefix, or contains a non-hex character.
    pub fn string_to_bytes(s: &str) -> [u8; 32] {
        let digits = strip_hex_prefix(s.as_bytes());
        assert!(
            digits.len() == 64,
            "hex string must be exactly 64 hex chars (32 bytes)"
        );

        let mut out = [0u8; 32];
        for (dst, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
            *dst = byte_from_pair(pair);
        }
        out
    }

    /// Decodes an even-length hex string (optionally prefixed with `0x`/`0X`)
    /// into bytes.
    ///
    /// # Panics
    /// Panics on odd length or on a non-hex character.
    pub fn hex_to_vec(s: &str) -> Vec<u8> {
        let digits = strip_hex_prefix(s.as_bytes());
        assert_eq!(digits.len() % 2, 0, "hex string must have even length");
        digits.chunks_exact(2).map(byte_from_pair).collect()
    }

    /// Drops one leading `0x` or `0X` marker, if present.
    fn strip_hex_prefix(bytes: &[u8]) -> &[u8] {
        match bytes {
            [b'0', b'x' | b'X', rest @ ..] => rest,
            _ => bytes,
        }
    }

    /// Combines two hex digits (high nibble first) into one byte.
    fn byte_from_pair(pair: &[u8]) -> u8 {
        let hi = hex_val(pair[0]);
        let lo = hex_val(pair[1]);
        (hi << 4) | lo
    }

    /// Converts one ASCII hex digit to its value; panics on anything else.
    fn hex_val(b: u8) -> u8 {
        match (b as char).to_digit(16) {
            Some(v) => v as u8,
            None => panic!("invalid hex character"),
        }
    }
}