Skip to main content

ort_openrouter_cli/net/
tls.rs

1//! ort: Open Router CLI
2//! https://github.com/grahamking/ort
3//!
4//! MIT License
5//! Copyright (c) 2025 Graham King
6//
7//! ---------------------- Minimal TLS 1.3 client (AES-128-GCM + X25519) -------
8
9use core::{cmp, ffi::CStr};
10
11extern crate alloc;
12use alloc::ffi::CString;
13use alloc::string::ToString;
14use alloc::vec;
15use alloc::vec::Vec;
16
17use crate::{
18    Context, ErrorKind, OrtResult, Read, Write, common::utils::to_ascii, net::AsFd, ort_error,
19    syscall,
20};
21
22mod aead;
23mod ecdh;
24mod hkdf;
25mod hmac;
26mod sha2;
27
/// Set to `true` to dump handshake and record bytes through `debug_print`
/// (only has an effect in builds with `debug_assertions`).
#[allow(unused)]
const DEBUG_LOG: bool = false;

/// Largest plaintext fragment carried by one record. RFC 8446 section 5.1: the record layer
/// fragments data "into TLSPlaintext records carrying data in chunks of 2^14 bytes or less".
const MAX_PLAINTEXT_SIZE: usize = 16 * 1024;

// Record content types (RFC 8446 section 5.1, `ContentType`)
const REC_TYPE_CHANGE_CIPHER_SPEC: u8 = 20; // 0x14
const REC_TYPE_ALERT: u8 = 21; // 0x15
const REC_TYPE_HANDSHAKE: u8 = 22; // 0x16
const REC_TYPE_APPDATA: u8 = 23; // 0x17
/// `legacy_record_version` written in every record header (the TLS 1.2 value).
const LEGACY_REC_VER: u16 = 0x0303;

// Handshake message types (RFC 8446 section 4, `HandshakeType`)
const HS_CLIENT_HELLO: u8 = 1;
const HS_SERVER_HELLO: u8 = 2;
//const HS_NEW_SESSION_TICKET: u8 = 4;
//const HS_ENCRYPTED_EXTENSIONS: u8 = 8;
//const HS_CERTIFICATE: u8 = 11;
//const HS_CERT_VERIFY: u8 = 15;
const HS_FINISHED: u8 = 20; // 0x14

// TLS_AES_128_GCM_SHA256, the only cipher suite offered
const CIPHER_TLS_AES_128_GCM_SHA256: u16 = 0x1301;
// supported_versions value for TLS 1.3
const TLS13: u16 = 0x0304;
// supported group: x25519
const GROUP_X25519: u16 = 0x001d;

// Extension types
const EXT_SERVER_NAME: u16 = 0x0000;
const EXT_SUPPORTED_GROUPS: u16 = 0x000a;
const EXT_SIGNATURE_ALGS: u16 = 0x000d;
//const EXT_ALPN: u16 = 0x0010;
const EXT_SUPPORTED_VERSIONS: u16 = 0x002b;
//const EXT_PSK_MODES: u16 = 0x002d;
const EXT_KEY_SHARE: u16 = 0x0033;

// AEAD tag length (GCM)
const AEAD_TAG_LEN: usize = 16;
66
/// Append `v` to `buf` as two big-endian (network order) bytes, i.e. a TLS `uint16`.
fn put_u16(buf: &mut Vec<u8>, v: u16) {
    let [hi, lo] = v.to_be_bytes();
    buf.push(hi);
    buf.push(lo);
}
/// Append the low 24 bits of `v` to `buf` in big-endian order (a TLS `uint24`, as used by
/// handshake message lengths). Higher bits are silently dropped.
fn put_u24(buf: &mut Vec<u8>, v: usize) {
    let wide = (v as u32).to_be_bytes();
    // Skip the most significant byte of the 32-bit value.
    buf.extend_from_slice(&wide[1..]);
}
75
76fn hkdf_expand_label<const N: usize>(prk: &[u8], label: &str, data: &[u8]) -> [u8; N] {
77    let mut info = Vec::with_capacity(2 + 1 + 6 + label.len() + 1 + data.len());
78    put_u16(&mut info, N as u16);
79    info.push(("tls13 ".len() + label.len()) as u8);
80    info.extend_from_slice("tls13 ".as_bytes());
81    info.extend_from_slice(label.as_bytes());
82    info.push(data.len() as u8);
83    info.extend_from_slice(data);
84
85    hkdf::hkdf_expand(prk, &info, N).try_into().unwrap()
86}
87
88fn digest_bytes(data: &[u8]) -> [u8; 32] {
89    let d = sha2::sha256(data);
90    let mut out = [0u8; 32];
91    out.copy_from_slice(d.as_ref());
92    out
93}
94
/// Build the per-record AEAD nonce (RFC 8446 section 5.3).
///
/// The 64-bit record sequence number is encoded big-endian, left-padded with zeros to the
/// 12-byte IV length, and XORed with the static write IV. This runs once per record, so it is
/// computed on the stack without any allocation.
fn nonce_xor(iv12: &[u8; 12], seq: u64) -> [u8; 12] {
    let mut nonce = *iv12;
    // The padded sequence number is zero in the first 4 bytes, so only the last 8 change.
    for (n, s) in nonce[4..].iter_mut().zip(seq.to_be_bytes().iter()) {
        *n ^= *s;
    }
    nonce
}
105
/// A TLS 1.3 connection over `io`, produced by `TlsStream::connect` once the handshake is done.
///
/// This is a very small record writer/reader: it holds the application-traffic keys, IVs and
/// record sequence numbers for each direction, plus a buffer of decrypted application data
/// that `read` has not handed out yet.
pub struct TlsStream<T: Read + Write> {
    /// Underlying transport that TLS records are written to and read from.
    io: T,
    // Application traffic keys (derived from the "c ap traffic" / "s ap traffic" secrets)
    /// AES-128-GCM key for records we send.
    aead_enc: [u8; 16],
    /// AES-128-GCM key for records we receive.
    aead_dec: [u8; 16],
    /// Static IV for outgoing records; XORed with `seq_enc` to form each record's nonce.
    iv_enc: [u8; 12],
    /// Static IV for incoming records; XORed with `seq_dec` to form each record's nonce.
    iv_dec: [u8; 12],
    /// Sequence number of the next record to send.
    seq_enc: u64,
    /// Sequence number of the next record to receive.
    seq_dec: u64,
    /// Decrypted application data not yet returned by `read`.
    rbuf: Vec<u8>,
    /// Index in `rbuf` of the next byte `read` will return.
    rpos: usize,
}
120
121fn client_hello_body(sni_host: &str, client_pub: &[u8]) -> Vec<u8> {
122    let mut ch_body = Vec::with_capacity(512);
123
124    // X25519
125    let mut random = [0u8; 32];
126    syscall::getrandom(&mut random);
127
128    let mut session_id = [0u8; 32];
129    syscall::getrandom(&mut session_id);
130
131    // legacy_version
132    ch_body.extend_from_slice(&0x0303u16.to_be_bytes());
133    // random
134    ch_body.extend_from_slice(&random);
135    // legacy_session_id
136    ch_body.push(session_id.len() as u8);
137    ch_body.extend_from_slice(&session_id);
138    // cipher_suites: only TLS_AES_128_GCM_SHA256
139    put_u16(&mut ch_body, 2);
140    put_u16(&mut ch_body, CIPHER_TLS_AES_128_GCM_SHA256);
141    // legacy_compression_methods: null
142    ch_body.push(1);
143    ch_body.push(0);
144
145    // --- extensions ---
146    let mut exts = Vec::with_capacity(512);
147
148    // server_name
149    {
150        let host_bytes = sni_host.as_bytes();
151        let mut snl = Vec::with_capacity(3 + host_bytes.len());
152        snl.push(0); // host_name
153        put_u16(&mut snl, host_bytes.len() as u16);
154        snl.extend_from_slice(host_bytes);
155
156        let mut sni = Vec::with_capacity(2 + snl.len());
157        put_u16(&mut sni, snl.len() as u16);
158        sni.extend_from_slice(&snl);
159
160        put_u16(&mut exts, EXT_SERVER_NAME);
161        put_u16(&mut exts, sni.len() as u16);
162        exts.extend_from_slice(&sni);
163    }
164
165    // supported_versions: TLS 1.3
166    {
167        let mut sv = Vec::with_capacity(3);
168        sv.push(2); // length in bytes
169        sv.extend_from_slice(&TLS13.to_be_bytes());
170        put_u16(&mut exts, EXT_SUPPORTED_VERSIONS);
171        put_u16(&mut exts, sv.len() as u16);
172        exts.extend_from_slice(&sv);
173    }
174
175    // supported_groups: x25519
176    {
177        let mut sg = Vec::with_capacity(2 + 2);
178        put_u16(&mut sg, 2);
179        put_u16(&mut sg, GROUP_X25519);
180        put_u16(&mut exts, EXT_SUPPORTED_GROUPS);
181        put_u16(&mut exts, sg.len() as u16);
182        exts.extend_from_slice(&sg);
183    }
184
185    // signature_algorithms: minimal list
186    {
187        const ECDSA_SECP256R1_SHA256: u16 = 0x0403;
188        const RSA_PSS_RSAE_SHA256: u16 = 0x0804;
189        const RSA_PKCS1_SHA256: u16 = 0x0401;
190
191        let mut sa = Vec::with_capacity(2 + 6);
192        put_u16(&mut sa, 6);
193        put_u16(&mut sa, ECDSA_SECP256R1_SHA256);
194        put_u16(&mut sa, RSA_PSS_RSAE_SHA256);
195        put_u16(&mut sa, RSA_PKCS1_SHA256);
196
197        put_u16(&mut exts, EXT_SIGNATURE_ALGS);
198        put_u16(&mut exts, sa.len() as u16);
199        exts.extend_from_slice(&sa);
200    }
201
202    // key_share: x25519
203    {
204        let mut ks = Vec::with_capacity(2 + 2 + 2 + 32);
205        // client_shares vector
206        let mut entry = Vec::with_capacity(2 + 2 + 32);
207        put_u16(&mut entry, GROUP_X25519);
208        put_u16(&mut entry, 32);
209        entry.extend_from_slice(client_pub);
210        put_u16(&mut ks, entry.len() as u16);
211        ks.extend_from_slice(&entry);
212
213        put_u16(&mut exts, EXT_KEY_SHARE);
214        put_u16(&mut exts, ks.len() as u16);
215        exts.extend_from_slice(&ks);
216    }
217
218    // add extensions to CH
219    put_u16(&mut ch_body, exts.len() as u16);
220    ch_body.extend_from_slice(&exts);
221
222    ch_body
223}
224
225/// --- Build ClientHello (single cipher: TLS_AES_128_GCM_SHA256) ---
226fn client_hello_msg(sni_host: &str, client_private_key: &[u8]) -> OrtResult<Vec<u8>> {
227    let client_pub_key = ecdh::x25519_public_key(client_private_key);
228    let client_pub_ref = &client_pub_key;
229    debug_print("Client public key", client_pub_ref);
230
231    let ch_body = client_hello_body(sni_host, client_pub_ref);
232
233    // Handshake framing: ClientHello
234    let mut ch_msg = Vec::with_capacity(4 + ch_body.len());
235    ch_msg.push(HS_CLIENT_HELLO);
236    put_u24(&mut ch_msg, ch_body.len());
237    ch_msg.extend_from_slice(&ch_body);
238
239    Ok(ch_msg)
240}
241
242/// Read ServerHello (plaintext Handshake record)
243fn read_server_hello<R: Read>(io: &mut R) -> OrtResult<(Vec<u8>, Vec<u8>)> {
244    let (typ, payload) = read_record_plain(io).context("read_record_plain in read_server_hello")?;
245    if typ != REC_TYPE_HANDSHAKE {
246        return Err(ort_error(ErrorKind::TlsExpectedHandshakeRecord, ""));
247    }
248    let sh_buf = payload;
249
250    // There can be multiple handshake messages; we need the ServerHello bytes specifically
251    let mut rd = &sh_buf[..];
252    let (sh_typ, sh_body, sh_full) =
253        read_handshake_message(&mut rd).context("read_handshake_message")?;
254    if sh_typ != HS_SERVER_HELLO {
255        return Err(ort_error(ErrorKind::TlsExpectedServerHello, ""));
256    }
257
258    // TODO: later remove the copy. The slices are into sh_buf
259    Ok((sh_body.to_vec(), sh_full.to_vec()))
260}
261
/// Secrets and record keys derived right after ServerHello (RFC 8446 section 7.1). They are used
/// to decrypt the server's encrypted handshake flight and to send the client Finished.
struct HandshakeState {
    /// HKDF-Extract(derived, ECDHE shared secret); input to the application key schedule.
    handshake_secret: [u8; 32],
    /// Client handshake traffic secret ("c hs traffic"); source of the client finished key.
    client_hs_ts: [u8; 32],
    /// Server handshake traffic secret ("s hs traffic"); source of the server finished key.
    server_hs_ts: [u8; 32],
    /// IV for handshake records the client sends.
    client_handshake_iv: [u8; 12],
    /// IV for handshake records the server sends.
    server_handshake_iv: [u8; 12],
    /// AES-128-GCM key for client handshake records (derived from `client_hs_ts`).
    aead_enc_hs: [u8; 16],
    /// AES-128-GCM key for server handshake records (derived from `server_hs_ts`).
    aead_dec_hs: [u8; 16],
    /// SHA-256 of the empty string; the context for the "derived" secrets.
    empty_hash: [u8; 32],
}
272
/// AES-128-GCM keys and IVs for application data, derived from the master secret.
struct ApplicationKeys {
    /// Key for records the client sends (from the "c ap traffic" secret).
    aead_app_enc: [u8; 16],
    /// Key for records the server sends (from the "s ap traffic" secret).
    aead_app_dec: [u8; 16],
    /// IV for records the client sends.
    iv_enc: [u8; 12],
    /// IV for records the server sends.
    iv_dec: [u8; 12],
}
279
280impl<T: Read + Write> TlsStream<T> {
281    pub fn connect(mut io: T, sni_host: &str) -> OrtResult<Self> {
282        // transcript = full Handshake message encodings (headers + bodies)
283        // Feb 18 2026 full transcript is 5674 bytes
284        let mut transcript = Vec::with_capacity(8192);
285
286        // A private key is simply random bytes. /dev/urandom is cryptographically secure.
287        let mut client_private_key = [0u8; 32];
288        syscall::getrandom(&mut client_private_key);
289        debug_print("Client private key", &client_private_key);
290
291        debug_print("MSG -> ClientHello", &[]);
292        Self::send_client_hello(&mut io, sni_host, &mut transcript, &client_private_key)?;
293
294        debug_print("MSG <- ServerHello", &[]);
295        let sh_body = Self::receive_server_hello(&mut io, &mut transcript)?;
296
297        let handshake = Self::derive_handshake_keys(&client_private_key, &sh_body, &transcript)?;
298
299        debug_print("MSG <- ChangeCipherSpec (dummy)", &[]);
300        Self::receive_dummy_change_cipher_spec(&mut io)?;
301
302        let mut seq_dec_hs = 0u64;
303        let mut seq_enc_hs = 0u64;
304
305        let mut is_finished: bool = false;
306        while !is_finished {
307            debug_print("MSG <- Server flight", &[]);
308            is_finished = Self::receive_server_encrypted_flight(
309                &mut io,
310                &mut seq_dec_hs,
311                &handshake,
312                &mut transcript,
313            )?;
314        }
315
316        let ApplicationKeys {
317            aead_app_enc,
318            aead_app_dec,
319            iv_enc: caiv,
320            iv_dec: saiv,
321        } = Self::derive_application_keys(
322            &handshake.handshake_secret,
323            &handshake.empty_hash,
324            &transcript,
325        );
326
327        let seq_app_enc = 0u64;
328        let seq_app_dec = 0u64;
329
330        // Client Change Cipher Spec
331        // This is optional, to "confuse middleboxes" which expect TLS 1.2. Works without.
332        //write_record_plain(&mut io, REC_TYPE_CHANGE_CIPHER_SPEC, &[0x01])?;
333
334        debug_print("MSG -> ClientFinished", &[]);
335        Self::send_client_finished(&mut io, &handshake, &mut transcript, &mut seq_enc_hs)?;
336
337        debug_print("TLS connect done", &[]);
338        Ok(TlsStream {
339            io,
340            aead_enc: aead_app_enc,
341            aead_dec: aead_app_dec,
342            iv_enc: caiv,
343            iv_dec: saiv,
344            seq_enc: seq_app_enc,
345            seq_dec: seq_app_dec,
346            rbuf: Vec::with_capacity(16 * 1024),
347            rpos: 0,
348        })
349    }
350
351    pub fn has_buffered_data(&self) -> bool {
352        /*
353        let out = self.rpos < self.rbuf.len();
354        let msg = alloc::string::ToString::to_string(&"rpos = ")
355            + &utils::num_to_string(self.rpos)
356            + ", rbuf.len() = "
357            + &utils::num_to_string(self.rbuf.len())
358            + " . "
359            + if out { "true" } else { "false" };
360        utils::print_string(c"tls has_buffered_data: ", &msg);
361        */
362        self.rpos < self.rbuf.len()
363    }
364
365    fn send_client_hello<W: Write>(
366        io: &mut W,
367        sni_host: &str,
368        transcript: &mut Vec<u8>,
369        client_private_key: &[u8; 32],
370    ) -> OrtResult<()> {
371        let ch_msg = client_hello_msg(sni_host, client_private_key)?;
372        write_record_plain(io, REC_TYPE_HANDSHAKE, &ch_msg).context("write ClientHello")?;
373        transcript.extend_from_slice(&ch_msg);
374        Ok(())
375    }
376
377    fn receive_server_hello<R: Read>(io: &mut R, transcript: &mut Vec<u8>) -> OrtResult<Vec<u8>> {
378        let (sh_body, sh_full) = read_server_hello(io)?;
379        transcript.extend_from_slice(&sh_full);
380        Ok(sh_body)
381    }
382
383    fn receive_dummy_change_cipher_spec<R: Read>(io: &mut R) -> OrtResult<()> {
384        // Some servers send TLS 1.2-style ChangeCipherSpec for middlebox compatibility.
385        let (typ, _) =
386            read_record_plain(io).context("read_record_plain for dummy change cipher")?;
387        if typ != REC_TYPE_CHANGE_CIPHER_SPEC {
388            return Err(ort_error(ErrorKind::TlsExpectedChangeCipherSpec, ""));
389        }
390        Ok(())
391    }
392
393    /// Should be called multiple times until it returns true.
394    /// The TLS messages for this stage might come as separate packets, or all in one.
395    fn receive_server_encrypted_flight<R: Read>(
396        io: &mut R,
397        seq_dec_hs: &mut u64,
398        handshake: &HandshakeState,
399        transcript: &mut Vec<u8>,
400    ) -> OrtResult<bool> {
401        let (typ, ct, _inner_type) = read_record_cipher(
402            io,
403            &handshake.aead_dec_hs,
404            &handshake.server_handshake_iv,
405            seq_dec_hs,
406        )?;
407        if typ != REC_TYPE_APPDATA {
408            return Err(ort_error(ErrorKind::TlsExpectedEncryptedRecords, ""));
409        }
410
411        // Decrypted TLSInnerPlaintext: ... | content_type
412        // May contain multiple handshake messages; parse & append to transcript.
413        let mut p = &ct[..];
414        while !p.is_empty() {
415            let (mtyp, body, full) = match read_handshake_message(&mut p) {
416                Ok(x) => x,
417                Err(_) => {
418                    return Err(ort_error(ErrorKind::TlsBadHandshakeFragment, ""));
419                }
420            };
421            transcript.extend_from_slice(full);
422            debug_print("handshake message (type is first byte)", full);
423
424            if mtyp == HS_FINISHED {
425                // verify server Finished
426                let s_finished_key =
427                    hkdf_expand_label::<32>(&handshake.server_hs_ts, "finished", &[]);
428
429                let thash = digest_bytes(&transcript[..transcript.len() - full.len()]);
430                let expected = hmac::sign(&s_finished_key, &thash);
431                if expected.as_slice() != body {
432                    return Err(ort_error(ErrorKind::TlsFinishedVerifyFailed, ""));
433                }
434                // Done collecting server handshake.
435                return Ok(true);
436            }
437            // Ignore other handshake types’ contents (no cert validation).
438        }
439        Ok(false)
440    }
441
442    fn derive_handshake_keys(
443        client_private_key: &[u8; 32],
444        sh_body: &[u8],
445        transcript: &[u8],
446    ) -> OrtResult<HandshakeState> {
447        // Parse minimal ServerHello to get cipher & key_share
448        let (cipher, server_public_key_bytes) = parse_server_hello_for_keys(sh_body)?;
449        debug_print("Server public key", &server_public_key_bytes);
450        if cipher != CIPHER_TLS_AES_128_GCM_SHA256 {
451            return Err(ort_error(
452                ErrorKind::TlsUnsupportedCipher,
453                "server picked unsupported cipher",
454            ));
455        }
456
457        // ECDH(X25519) shared secret
458        let hs_shared_secret = ecdh::x25519_agreement(client_private_key, &server_public_key_bytes);
459        debug_print("hs shared secret", &hs_shared_secret);
460
461        // Same as: `echo -n "" | openssl sha256`
462        let empty_hash = digest_bytes(&[]);
463        debug_print("empty_hash", &empty_hash);
464
465        let zero: [u8; 32] = [0u8; 32];
466        let early_secret = hkdf::hkdf_extract(&zero, &zero);
467
468        let derived_secret_bytes = hkdf_expand_label::<32>(&early_secret, "derived", &empty_hash);
469        debug_print("derived", &derived_secret_bytes);
470
471        let handshake_secret = hkdf::hkdf_extract(&derived_secret_bytes, &hs_shared_secret);
472        debug_print("handshake_secret", &handshake_secret);
473
474        let ch_sh_hash = digest_bytes(transcript);
475        debug_print("digest bytes", &ch_sh_hash);
476
477        let c_hs_ts = hkdf_expand_label(&handshake_secret, "c hs traffic", &ch_sh_hash);
478        let s_hs_ts = hkdf_expand_label(&handshake_secret, "s hs traffic", &ch_sh_hash);
479
480        debug_print("c hs traffic", &c_hs_ts);
481        debug_print("s hs traffic", &s_hs_ts);
482
483        // handshake AEAD keys/IVs
484        let client_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&c_hs_ts, "key", &[])
485            .as_slice()[..16]
486            .try_into()
487            .unwrap();
488        debug_print("client_handshake_key", &client_handshake_key);
489        let client_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&c_hs_ts, "iv", &[]).as_slice()
490            [..12]
491            .try_into()
492            .unwrap();
493        debug_print("client_handshake_iv", &client_handshake_iv);
494
495        let server_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&s_hs_ts, "key", &[])
496            .as_slice()[..16]
497            .try_into()
498            .unwrap();
499        debug_print("server_handshake_key", &server_handshake_key);
500        let server_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&s_hs_ts, "iv", &[]).as_slice()
501            [..12]
502            .try_into()
503            .unwrap();
504        debug_print("server_handshake_iv", &server_handshake_iv);
505
506        Ok(HandshakeState {
507            handshake_secret,
508            client_hs_ts: c_hs_ts,
509            server_hs_ts: s_hs_ts,
510            client_handshake_iv,
511            server_handshake_iv,
512            aead_enc_hs: client_handshake_key,
513            aead_dec_hs: server_handshake_key,
514            empty_hash,
515        })
516    }
517
518    fn derive_application_keys(
519        handshake_secret: &[u8; 32],
520        empty_hash: &[u8; 32],
521        transcript: &[u8],
522    ) -> ApplicationKeys {
523        let derived2_bytes = hkdf_expand_label::<32>(handshake_secret, "derived", empty_hash);
524        debug_print("derived2_bytes", &derived2_bytes);
525
526        let zero: [u8; 32] = [0u8; 32];
527        let master_secret = hkdf::hkdf_extract(&derived2_bytes, &zero);
528        let thash_srv_fin = digest_bytes(transcript);
529
530        let c_ap_ts = hkdf_expand_label::<32>(&master_secret, "c ap traffic", &thash_srv_fin);
531        let s_ap_ts = hkdf_expand_label::<32>(&master_secret, "s ap traffic", &thash_srv_fin);
532        debug_print("c_ap_ts", &c_ap_ts);
533        debug_print("s_ap_ts", &s_ap_ts);
534
535        let cak: [u8; 16] = hkdf_expand_label::<16>(&c_ap_ts, "key", &[]).as_slice()[..16]
536            .try_into()
537            .unwrap();
538        let caiv: [u8; 12] = hkdf_expand_label::<12>(&c_ap_ts, "iv", &[]).as_slice()[..12]
539            .try_into()
540            .unwrap();
541        debug_print("cak", &cak);
542        debug_print("caiv", &caiv);
543
544        let sak: [u8; 16] = hkdf_expand_label::<16>(&s_ap_ts, "key", &[]).as_slice()[..16]
545            .try_into()
546            .unwrap();
547        let saiv: [u8; 12] = hkdf_expand_label::<12>(&s_ap_ts, "iv", &[]).as_slice()[..12]
548            .try_into()
549            .unwrap();
550        debug_print("sak", &sak);
551        debug_print("saiv", &saiv);
552
553        ApplicationKeys {
554            aead_app_enc: cak,
555            aead_app_dec: sak,
556            iv_enc: caiv,
557            iv_dec: saiv,
558        }
559    }
560
561    fn send_client_finished<W: Write>(
562        io: &mut W,
563        handshake: &HandshakeState,
564        transcript: &mut Vec<u8>,
565        seq_enc_hs: &mut u64,
566    ) -> OrtResult<()> {
567        let c_finished_key = hkdf_expand_label::<32>(&handshake.client_hs_ts, "finished", &[]);
568        debug_print("c_finished", &c_finished_key);
569
570        let thash_client_fin = digest_bytes(transcript.as_slice());
571        let verify_data = hmac::sign(&c_finished_key, &thash_client_fin);
572        debug_print("verify_data", &verify_data);
573
574        let mut fin = Vec::with_capacity(4 + verify_data.as_ref().len());
575        fin.push(HS_FINISHED);
576        put_u24(&mut fin, verify_data.as_ref().len());
577        fin.extend_from_slice(verify_data.as_ref());
578
579        // append to transcript before switching keys
580        transcript.extend_from_slice(&fin);
581
582        write_record_cipher(
583            io,
584            REC_TYPE_HANDSHAKE,
585            &fin,
586            &handshake.aead_enc_hs,
587            &handshake.client_handshake_iv,
588            seq_enc_hs,
589        )
590        .context("write_record_cipher write_all failed")?;
591
592        Ok(())
593    }
594}
595
596impl<T: Read + Write> Write for TlsStream<T> {
597    fn write(&mut self, buf: &[u8]) -> OrtResult<usize> {
598        let mut bytes_sent = 0;
599        for chunk in buf.chunks(MAX_PLAINTEXT_SIZE) {
600            write_record_cipher(
601                &mut self.io,
602                REC_TYPE_APPDATA,
603                chunk,
604                &self.aead_enc,
605                &self.iv_enc,
606                &mut self.seq_enc,
607            )?;
608            bytes_sent += chunk.len();
609        }
610        Ok(bytes_sent)
611    }
612    fn flush(&mut self) -> OrtResult<()> {
613        self.io.flush()
614    }
615}
616
617impl<T: Read + Write> Read for TlsStream<T> {
618    fn read(&mut self, out: &mut [u8]) -> OrtResult<usize> {
619        if self.rpos < self.rbuf.len() {
620            debug_print("TlsStream.read using buf", &[]);
621
622            let n = cmp::min(out.len(), self.rbuf.len() - self.rpos);
623            out[..n].copy_from_slice(&self.rbuf[self.rpos..self.rpos + n]);
624            self.rpos += n;
625            if self.rpos == self.rbuf.len() {
626                self.rbuf.clear();
627                self.rpos = 0;
628            }
629            return Ok(n);
630        }
631        loop {
632            let (typ, plaintext, inner_type) = read_record_cipher(
633                &mut self.io,
634                &self.aead_dec,
635                &self.iv_dec,
636                &mut self.seq_dec,
637            )?;
638            if typ != REC_TYPE_APPDATA {
639                // Ignore unexpected (e.g., post-handshake Handshake like NewSessionTicket)
640                continue;
641            }
642            // plaintext ends with inner content type byte; for app data it is 0x17.
643            if plaintext.is_empty() {
644                continue;
645            }
646            if inner_type == REC_TYPE_HANDSHAKE {
647                // Drop post-handshake messages (tickets, etc.)
648                continue;
649            }
650            if inner_type == REC_TYPE_ALERT {
651                let level = match plaintext[0] {
652                    1 => "warning",
653                    2 => "fatal",
654                    _ => "unknown",
655                };
656                let err_level = CString::new(level.to_string() + " alert: ").unwrap();
657
658                // See https://www.rfc-editor.org/rfc/rfc8446#appendix-B search for
659                // "unexpected_message" for all types
660                let mut err_code_buf: [u8; 5] = [0u8; 5];
661                let len = to_ascii(plaintext[1] as usize, &mut err_code_buf);
662                let err_code = unsafe { CStr::from_bytes_with_nul_unchecked(&err_code_buf[..len]) };
663                syscall::write(2, err_level.as_ptr().cast(), err_level.count_bytes());
664                syscall::write(2, err_code.as_ptr().cast(), err_code.count_bytes());
665
666                return Err(ort_error(ErrorKind::TlsAlertReceived, ""));
667            }
668            if inner_type != REC_TYPE_APPDATA {
669                // Some servers pad with 0x00.. then type; we already consumed type.
670                // If not 0x17, treat preceding bytes (if any) as app anyway.
671            }
672            if plaintext.is_empty() {
673                continue;
674            }
675
676            self.rbuf.extend_from_slice(&plaintext);
677            self.rpos = 0;
678            // Now serve from buffer
679            let n = cmp::min(out.len(), self.rbuf.len());
680            out[..n].copy_from_slice(&self.rbuf[..n]);
681            self.rpos = n;
682            if n == self.rbuf.len() {
683                self.rbuf.clear();
684                self.rpos = 0;
685            }
686            return Ok(n);
687        }
688    }
689}
690
691impl<T: Read + Write + AsFd> AsFd for TlsStream<T> {
692    fn as_fd(&self) -> i32 {
693        self.io.as_fd()
694    }
695}
696
697// ---------------------- Record I/O helpers ----------------------------------
698
699fn write_record_plain<W: Write>(w: &mut W, typ: u8, body: &[u8]) -> OrtResult<()> {
700    let mut hdr = [0u8; 5];
701    hdr[0] = typ;
702    hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
703    hdr[3..5].copy_from_slice(&(body.len() as u16).to_be_bytes());
704    w.write_all(&hdr)?;
705    w.write_all(body)?;
706    Ok(())
707}
708
709fn read_exact_n<R: Read>(r: &mut R, n: usize) -> OrtResult<Vec<u8>> {
710    let mut buf = vec![0u8; n];
711    r.read_exact(&mut buf)?;
712    Ok(buf)
713}
714
715fn read_record_plain<R: Read>(r: &mut R) -> OrtResult<(u8, Vec<u8>)> {
716    let hdr = read_exact_n(r, 5)?; // Record Header, e.g. 16 03 03 len
717    let typ = hdr[0];
718    let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
719    let body = read_exact_n(r, len)?;
720    debug_print("read_record_plain hdr", &hdr);
721    debug_print("read_record_plain body", &body);
722    //let _ = write_bytes_to_file(&[&hdr[..], &body].concat(), debug_filename);
723    Ok((typ, body))
724}
725
726fn write_record_cipher<W: Write>(
727    w: &mut W,
728    outer_type: u8,
729    inner: &[u8],
730    key: &[u8; 16],
731    iv12: &[u8; 12],
732    seq: &mut u64,
733) -> OrtResult<()> {
734    // AES / GCM plaintext and ciphertext have the same length
735    let total_len = inner.len() + 1 + AEAD_TAG_LEN;
736    let mut plain = Vec::with_capacity(total_len);
737    plain.extend_from_slice(inner);
738    plain.push(outer_type);
739
740    debug_print("write_record_cipher plaintext", &plain);
741
742    let nonce = nonce_xor(iv12, *seq);
743    *seq = seq.wrapping_add(1);
744
745    let mut hdr = [0u8; 5];
746    hdr[0] = REC_TYPE_APPDATA;
747    hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
748    hdr[3..5].copy_from_slice(&(total_len as u16).to_be_bytes());
749
750    let out = aead::aes_128_gcm_encrypt(key, &nonce, &hdr, &plain).unwrap();
751
752    debug_print("write_record_cipher header", &hdr);
753    //let final_label = format!("write_record_cipher final {total_len}");
754    //debug_print(final_label.as_str(), &out);
755
756    w.write_all(&hdr)?;
757    w.write_all(&out)?;
758    Ok(())
759}
760
761fn read_record_cipher<R: Read>(
762    r: &mut R,
763    key: &[u8; 16],
764    iv12: &[u8; 12],
765    seq: &mut u64,
766) -> OrtResult<(u8, Vec<u8>, u8)> {
767    let hdr = read_exact_n(r, 5)?;
768    let typ = hdr[0];
769    let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
770    let ciphertext = read_exact_n(r, len)?;
771    if len < AEAD_TAG_LEN {
772        return Err(ort_error(ErrorKind::TlsRecordTooShort, "short record"));
773    }
774    debug_print("read_record_cipher hdr", &hdr);
775    debug_print("read_record_cipher ct", &ciphertext);
776
777    //let size_expected = crate::utils::num_to_string(len);
778    //let size_read = crate::utils::num_to_string(ciphertext.len());
779    //crate::utils::print_string(c"size_expected ", &size_expected);
780    //crate::utils::print_string(c"size_read ", &size_read);
781
782    // Decrypt ciphertext
783
784    let nonce = nonce_xor(iv12, *seq);
785    *seq = seq.wrapping_add(1);
786
787    let mut out = match aead::aes_128_gcm_decrypt(key, &nonce, &hdr, &ciphertext) {
788        Ok(out) => out,
789        Err(s) => {
790            return Err(ort_error(ErrorKind::TlsAes128GcmDecryptFailed, s));
791        }
792    };
793
794    debug_print("read_record_cipher plaintext hdr", &hdr);
795    debug_print("read_record_cipher plaintext", &out);
796
797    if out.is_empty() {
798        return Ok((typ, ciphertext, 0));
799    }
800    // Strip inner content-type byte
801    let inner_type = *out.last().unwrap();
802    out.truncate(out.len() - 1);
803    Ok((typ, out, inner_type))
804}
805
806// ---------------------- Handshake parsing helpers ---------------------------
807
808fn read_handshake_message<'a>(rd: &mut &'a [u8]) -> OrtResult<(u8, &'a [u8], &'a [u8])> {
809    if rd.len() < 4 {
810        return Err(ort_error(ErrorKind::TlsHandshakeHeaderTooShort, ""));
811    }
812    let typ = rd[0];
813    let len = ((rd[1] as usize) << 16) | ((rd[2] as usize) << 8) | rd[3] as usize;
814    if rd.len() < 4 + len {
815        return Err(ort_error(ErrorKind::TlsHandshakeBodyTooShort, ""));
816    }
817    let full = &rd[..4 + len];
818    let body = &rd[4..4 + len];
819    *rd = &rd[4 + len..];
820    Ok((typ, body, full))
821}
822
823fn parse_server_hello_for_keys(sh: &[u8]) -> OrtResult<(u16, [u8; 32])> {
824    // minimal parse: skip legacy_version(2), random(32), sid, cipher(2), comp(1), exts
825    if sh.len() < 2 + 32 + 1 + 2 + 1 + 2 {
826        return Err(ort_error(ErrorKind::TlsServerHelloTooShort, ""));
827    }
828    let mut p = sh;
829
830    p = &p[2..]; // legacy_version
831    p = &p[32..]; // random
832    let sid_len = p[0] as usize;
833    p = &p[1..];
834    if p.len() < sid_len + 2 + 1 + 2 {
835        return Err(ort_error(ErrorKind::TlsServerHelloSessionIdInvalid, ""));
836    }
837    p = &p[sid_len..];
838    let cipher = u16::from_be_bytes([p[0], p[1]]);
839    p = &p[2..];
840    let _comp = p[0];
841    p = &p[1..];
842    let ext_len = u16::from_be_bytes([p[0], p[1]]) as usize;
843    p = &p[2..];
844    if p.len() < ext_len {
845        return Err(ort_error(ErrorKind::TlsServerHelloExtTooShort, ""));
846    }
847    let mut ex = &p[..ext_len];
848
849    let mut server_pub = None;
850
851    while !ex.is_empty() {
852        if ex.len() < 4 {
853            return Err(ort_error(ErrorKind::TlsExtensionHeaderTooShort, ""));
854        }
855        let et = u16::from_be_bytes([ex[0], ex[1]]);
856        let el = u16::from_be_bytes([ex[2], ex[3]]) as usize;
857        ex = &ex[4..];
858        if ex.len() < el {
859            return Err(ort_error(ErrorKind::TlsExtensionLengthInvalid, ""));
860        }
861        let ed = &ex[..el];
862        ex = &ex[el..];
863
864        match et {
865            EXT_KEY_SHARE => {
866                // KeyShareServerHello: group(2) kx_len(2) kx
867                if ed.len() < 2 + 2 + 32 {
868                    return Err(ort_error(ErrorKind::TlsKeyShareServerHelloInvalid, ""));
869                }
870                let grp = u16::from_be_bytes([ed[0], ed[1]]);
871                if grp != GROUP_X25519 {
872                    return Err(ort_error(
873                        ErrorKind::TlsServerGroupUnsupported,
874                        "server group != x25519",
875                    ));
876                }
877                let kx_len = u16::from_be_bytes([ed[2], ed[3]]) as usize;
878                if ed.len() < 4 + kx_len || kx_len != 32 {
879                    return Err(ort_error(ErrorKind::TlsKeyShareLengthInvalid, ""));
880                }
881                let mut pk = [0u8; 32];
882                pk.copy_from_slice(&ed[4..4 + 32]);
883                server_pub = Some(pk);
884            }
885            EXT_SUPPORTED_VERSIONS
886                if (ed.len() != 2 || u16::from_be_bytes([ed[0], ed[1]]) != TLS13) =>
887            {
888                return Err(ort_error(ErrorKind::TlsServerNotTls13, ""));
889            }
890            _ => {}
891        }
892    }
893
894    let sp = server_pub.ok_or_else(|| ort_error(ErrorKind::TlsMissingServerKey, ""))?;
895    Ok((cipher, sp))
896}
897
898#[allow(unused)]
899fn debug_print(name: &str, value: &[u8]) {
900    #[cfg(debug_assertions)]
901    {
902        if !DEBUG_LOG {
903            return;
904        }
905        let c_str = CString::new(name).unwrap();
906        if !value.is_empty() {
907            crate::utils::print_hex(c_str.as_c_str(), value);
908        } else {
909            crate::utils::print_string(c_str.as_c_str(), "");
910        }
911    }
912}
913
914/*
915#[allow(dead_code)]
916fn write_bytes_to_file(bytes: &[u8], file_path: &str) -> std::io::Result<()> {
917    let mut file = File::create(file_path)?;
918    file.write_all(bytes)?;
919    Ok(())
920}
921*/
922
#[cfg(test)]
pub mod tests {
    extern crate alloc;
    use alloc::vec::Vec;

    /// Decode a hex string of exactly 64 digits into 32 bytes. An optional `0x`/`0X` prefix is
    /// allowed.
    ///
    /// # Panics
    /// Panics if the string is not exactly 64 hex digits after the prefix, or if it contains a
    /// non-hex character.
    pub fn string_to_bytes(s: &str) -> [u8; 32] {
        let digits = strip_hex_prefix(s);
        assert!(
            digits.len() == 64,
            "hex string must be exactly 64 hex chars (32 bytes)"
        );

        let mut out = [0u8; 32];
        for (dst, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
            *dst = decode_pair(pair);
        }
        out
    }

    /// Decode a hex string of even length into bytes. An optional `0x`/`0X` prefix is allowed.
    ///
    /// # Panics
    /// Panics on odd length or on a non-hex character.
    pub fn hex_to_vec(s: &str) -> Vec<u8> {
        let digits = strip_hex_prefix(s);
        assert_eq!(digits.len() % 2, 0, "hex string must have even length");
        digits.chunks_exact(2).map(decode_pair).collect()
    }

    /// The bytes of `s` without a leading `0x` or `0X`.
    fn strip_hex_prefix(s: &str) -> &[u8] {
        match s.as_bytes() {
            [b'0', b'x' | b'X', rest @ ..] => rest,
            all => all,
        }
    }

    /// Two hex digits -> one byte (high nibble first).
    fn decode_pair(pair: &[u8]) -> u8 {
        (hex_val(pair[0]) << 4) | hex_val(pair[1])
    }

    /// Value of one ASCII hex digit (either case).
    fn hex_val(b: u8) -> u8 {
        match (b as char).to_digit(16) {
            Some(v) => v as u8,
            None => panic!("invalid hex character"),
        }
    }
}