Skip to main content

ort_openrouter_cli/net/
tls.rs

1//! ort: Open Router CLI
2//! https://github.com/grahamking/ort
3//!
4//! MIT License
5//! Copyright (c) 2025 Graham King
6//
7//! ---------------------- Minimal TLS 1.3 client (AES-128-GCM + X25519) -------
8
9use core::{cmp, ffi::CStr};
10
11extern crate alloc;
12use alloc::ffi::CString;
13use alloc::string::ToString;
14use alloc::vec;
15use alloc::vec::Vec;
16
17use crate::{
18    Context, ErrorKind, OrtResult, Read, Write, common::utils::to_ascii, net::AsFd, ort_error,
19    syscall,
20};
21
22mod aead;
23mod ecdh;
24mod hkdf;
25mod hmac;
26mod sha2;
27
/// Compile-time switch for the hex dumps emitted by `debug_print` (debug builds only).
#[allow(unused)] // only read inside a `#[cfg(debug_assertions)]` block
const DEBUG_LOG: bool = false;

// TLS record-layer content types (RFC 8446 §5.1).
const REC_TYPE_CHANGE_CIPHER_SPEC: u8 = 0x14;
const REC_TYPE_ALERT: u8 = 0x15;
const REC_TYPE_HANDSHAKE: u8 = 0x16;
const REC_TYPE_APPDATA: u8 = 0x17;
/// `legacy_record_version` written in every record header.
const LEGACY_REC_VER: u16 = 0x0303;

// Handshake message types (RFC 8446 §4). Types this client never inspects by value:
// new_session_ticket = 4, encrypted_extensions = 8, certificate = 11,
// certificate_verify = 15.
const HS_CLIENT_HELLO: u8 = 0x01;
const HS_SERVER_HELLO: u8 = 0x02;
const HS_FINISHED: u8 = 0x14;

/// The only cipher suite offered: TLS_AES_128_GCM_SHA256.
const CIPHER_TLS_AES_128_GCM_SHA256: u16 = 0x1301;
/// Protocol version carried in the supported_versions extension (TLS 1.3).
const TLS13: u16 = 0x0304;
/// The only key-exchange group offered: x25519.
const GROUP_X25519: u16 = 0x001d;

// Extension type codes. Not sent: ALPN (0x0010), psk_key_exchange_modes (0x002d).
const EXT_SERVER_NAME: u16 = 0x0000;
const EXT_SUPPORTED_GROUPS: u16 = 0x000a;
const EXT_SIGNATURE_ALGS: u16 = 0x000d;
const EXT_SUPPORTED_VERSIONS: u16 = 0x002b;
const EXT_KEY_SHARE: u16 = 0x0033;

/// Length in bytes of the AES-GCM authentication tag appended to every ciphertext.
const AEAD_TAG_LEN: usize = 16;
58//const EXT_PSK_MODES: u16 = 0x002d;
59const EXT_KEY_SHARE: u16 = 0x0033;
60
61// AEAD tag length (GCM)
62const AEAD_TAG_LEN: usize = 16;
63
64// Tiny helper to write BE ints
65fn put_u16(buf: &mut Vec<u8>, v: u16) {
66    buf.extend_from_slice(&v.to_be_bytes());
67}
68fn put_u24(buf: &mut Vec<u8>, v: usize) {
69    let v = v as u32;
70    buf.extend_from_slice(&[(v >> 16) as u8, (v >> 8) as u8, v as u8]);
71}
72
73fn hkdf_expand_label<const N: usize>(prk: &[u8], label: &str, data: &[u8]) -> [u8; N] {
74    let mut info = Vec::with_capacity(2 + 1 + 6 + label.len() + 1 + data.len());
75    put_u16(&mut info, N as u16);
76    info.push(("tls13 ".len() + label.len()) as u8);
77    info.extend_from_slice("tls13 ".as_bytes());
78    info.extend_from_slice(label.as_bytes());
79    info.push(data.len() as u8);
80    info.extend_from_slice(data);
81
82    hkdf::hkdf_expand(prk, &info, N).try_into().unwrap()
83}
84
85fn digest_bytes(data: &[u8]) -> [u8; 32] {
86    let d = sha2::sha256(data);
87    let mut out = [0u8; 32];
88    out.copy_from_slice(d.as_ref());
89    out
90}
91
92// AEAD nonce = iv XOR seq (seq in BE on the rightmost 8 bytes)
93fn nonce_xor(iv12: &[u8; 12], seq: u64) -> [u8; 12] {
94    // seq number in big endian on rightmost 8 bytes
95    let mut nonce_bytes = [[0, 0, 0, 0].as_ref(), &u64::to_be_bytes(seq)].concat();
96    // xor them
97    nonce_bytes.iter_mut().zip(iv12.iter()).for_each(|(s, iv)| {
98        *s ^= *iv;
99    });
100    nonce_bytes[..12].try_into().unwrap()
101}
102
103// Very small record writer/reader after handshake
104pub struct TlsStream<T: Read + Write> {
105    io: T,
106    // Application traffic
107    aead_enc: [u8; 16],
108    aead_dec: [u8; 16],
109    iv_enc: [u8; 12],
110    iv_dec: [u8; 12],
111    seq_enc: u64,
112    seq_dec: u64,
113    // read buffer for decrypted application data
114    rbuf: Vec<u8>,
115    rpos: usize,
116}
117
118fn client_hello_body(sni_host: &str, client_pub: &[u8]) -> Vec<u8> {
119    let mut ch_body = Vec::with_capacity(512);
120
121    // X25519
122    let mut random = [0u8; 32];
123    syscall::getrandom(&mut random);
124
125    let mut session_id = [0u8; 32];
126    syscall::getrandom(&mut session_id);
127
128    // legacy_version
129    ch_body.extend_from_slice(&0x0303u16.to_be_bytes());
130    // random
131    ch_body.extend_from_slice(&random);
132    // legacy_session_id
133    ch_body.push(session_id.len() as u8);
134    ch_body.extend_from_slice(&session_id);
135    // cipher_suites: only TLS_AES_128_GCM_SHA256
136    put_u16(&mut ch_body, 2);
137    put_u16(&mut ch_body, CIPHER_TLS_AES_128_GCM_SHA256);
138    // legacy_compression_methods: null
139    ch_body.push(1);
140    ch_body.push(0);
141
142    // --- extensions ---
143    let mut exts = Vec::with_capacity(512);
144
145    // server_name
146    {
147        let host_bytes = sni_host.as_bytes();
148        let mut snl = Vec::with_capacity(3 + host_bytes.len());
149        snl.push(0); // host_name
150        put_u16(&mut snl, host_bytes.len() as u16);
151        snl.extend_from_slice(host_bytes);
152
153        let mut sni = Vec::with_capacity(2 + snl.len());
154        put_u16(&mut sni, snl.len() as u16);
155        sni.extend_from_slice(&snl);
156
157        put_u16(&mut exts, EXT_SERVER_NAME);
158        put_u16(&mut exts, sni.len() as u16);
159        exts.extend_from_slice(&sni);
160    }
161
162    // supported_versions: TLS 1.3
163    {
164        let mut sv = Vec::with_capacity(3);
165        sv.push(2); // length in bytes
166        sv.extend_from_slice(&TLS13.to_be_bytes());
167        put_u16(&mut exts, EXT_SUPPORTED_VERSIONS);
168        put_u16(&mut exts, sv.len() as u16);
169        exts.extend_from_slice(&sv);
170    }
171
172    // supported_groups: x25519
173    {
174        let mut sg = Vec::with_capacity(2 + 2);
175        put_u16(&mut sg, 2);
176        put_u16(&mut sg, GROUP_X25519);
177        put_u16(&mut exts, EXT_SUPPORTED_GROUPS);
178        put_u16(&mut exts, sg.len() as u16);
179        exts.extend_from_slice(&sg);
180    }
181
182    // signature_algorithms: minimal list
183    {
184        const ECDSA_SECP256R1_SHA256: u16 = 0x0403;
185        const RSA_PSS_RSAE_SHA256: u16 = 0x0804;
186        const RSA_PKCS1_SHA256: u16 = 0x0401;
187
188        let mut sa = Vec::with_capacity(2 + 6);
189        put_u16(&mut sa, 6);
190        put_u16(&mut sa, ECDSA_SECP256R1_SHA256);
191        put_u16(&mut sa, RSA_PSS_RSAE_SHA256);
192        put_u16(&mut sa, RSA_PKCS1_SHA256);
193
194        put_u16(&mut exts, EXT_SIGNATURE_ALGS);
195        put_u16(&mut exts, sa.len() as u16);
196        exts.extend_from_slice(&sa);
197    }
198
199    // key_share: x25519
200    {
201        let mut ks = Vec::with_capacity(2 + 2 + 2 + 32);
202        // client_shares vector
203        let mut entry = Vec::with_capacity(2 + 2 + 32);
204        put_u16(&mut entry, GROUP_X25519);
205        put_u16(&mut entry, 32);
206        entry.extend_from_slice(client_pub);
207        put_u16(&mut ks, entry.len() as u16);
208        ks.extend_from_slice(&entry);
209
210        put_u16(&mut exts, EXT_KEY_SHARE);
211        put_u16(&mut exts, ks.len() as u16);
212        exts.extend_from_slice(&ks);
213    }
214
215    // add extensions to CH
216    put_u16(&mut ch_body, exts.len() as u16);
217    ch_body.extend_from_slice(&exts);
218
219    ch_body
220}
221
222/// --- Build ClientHello (single cipher: TLS_AES_128_GCM_SHA256) ---
223fn client_hello_msg(sni_host: &str, client_private_key: &[u8]) -> OrtResult<Vec<u8>> {
224    let client_pub_key = ecdh::x25519_public_key(client_private_key);
225    let client_pub_ref = &client_pub_key;
226    debug_print("Client public key", client_pub_ref);
227
228    let ch_body = client_hello_body(sni_host, client_pub_ref);
229
230    // Handshake framing: ClientHello
231    let mut ch_msg = Vec::with_capacity(4 + ch_body.len());
232    ch_msg.push(HS_CLIENT_HELLO);
233    put_u24(&mut ch_msg, ch_body.len());
234    ch_msg.extend_from_slice(&ch_body);
235
236    Ok(ch_msg)
237}
238
239/// Read ServerHello (plaintext Handshake record)
240fn read_server_hello<R: Read>(io: &mut R) -> OrtResult<(Vec<u8>, Vec<u8>)> {
241    let (typ, payload) = read_record_plain(io).context("read_record_plain in read_server_hello")?;
242    if typ != REC_TYPE_HANDSHAKE {
243        return Err(ort_error(ErrorKind::TlsExpectedHandshakeRecord, ""));
244    }
245    let sh_buf = payload;
246
247    // There can be multiple handshake messages; we need the ServerHello bytes specifically
248    let mut rd = &sh_buf[..];
249    let (sh_typ, sh_body, sh_full) =
250        read_handshake_message(&mut rd).context("read_handshake_message")?;
251    if sh_typ != HS_SERVER_HELLO {
252        return Err(ort_error(ErrorKind::TlsExpectedServerHello, ""));
253    }
254
255    // TODO: later remove the copy. The slices are into sh_buf
256    Ok((sh_body.to_vec(), sh_full.to_vec()))
257}
258
259struct HandshakeState {
260    handshake_secret: [u8; 32],
261    client_hs_ts: [u8; 32],
262    server_hs_ts: [u8; 32],
263    client_handshake_iv: [u8; 12],
264    server_handshake_iv: [u8; 12],
265    aead_enc_hs: [u8; 16],
266    aead_dec_hs: [u8; 16],
267    empty_hash: [u8; 32],
268}
269
270struct ApplicationKeys {
271    aead_app_enc: [u8; 16],
272    aead_app_dec: [u8; 16],
273    iv_enc: [u8; 12],
274    iv_dec: [u8; 12],
275}
276
277impl<T: Read + Write> TlsStream<T> {
278    pub fn connect(mut io: T, sni_host: &str) -> OrtResult<Self> {
279        // transcript = full Handshake message encodings (headers + bodies)
280        // Feb 18 2026 full transcript is 5674 bytes
281        let mut transcript = Vec::with_capacity(8192);
282
283        // A private key is simply random bytes. /dev/urandom is cryptographically secure.
284        let mut client_private_key = [0u8; 32];
285        syscall::getrandom(&mut client_private_key);
286        debug_print("Client private key", &client_private_key);
287
288        debug_print("MSG -> ClientHello", &[]);
289        Self::send_client_hello(&mut io, sni_host, &mut transcript, &client_private_key)?;
290
291        debug_print("MSG <- ServerHello", &[]);
292        let sh_body = Self::receive_server_hello(&mut io, &mut transcript)?;
293
294        let handshake = Self::derive_handshake_keys(&client_private_key, &sh_body, &transcript)?;
295
296        debug_print("MSG <- ChangeCipherSpec (dummy)", &[]);
297        Self::receive_dummy_change_cipher_spec(&mut io)?;
298
299        let mut seq_dec_hs = 0u64;
300        let mut seq_enc_hs = 0u64;
301
302        let mut is_finished: bool = false;
303        while !is_finished {
304            debug_print("MSG <- Server flight", &[]);
305            is_finished = Self::receive_server_encrypted_flight(
306                &mut io,
307                &mut seq_dec_hs,
308                &handshake,
309                &mut transcript,
310            )?;
311        }
312
313        let ApplicationKeys {
314            aead_app_enc,
315            aead_app_dec,
316            iv_enc: caiv,
317            iv_dec: saiv,
318        } = Self::derive_application_keys(
319            &handshake.handshake_secret,
320            &handshake.empty_hash,
321            &transcript,
322        );
323
324        let seq_app_enc = 0u64;
325        let seq_app_dec = 0u64;
326
327        // Client Change Cipher Spec
328        // This is optional, to "confuse middleboxes" which expect TLS 1.2. Works without.
329        //write_record_plain(&mut io, REC_TYPE_CHANGE_CIPHER_SPEC, &[0x01])?;
330
331        debug_print("MSG -> ClientFinished", &[]);
332        Self::send_client_finished(&mut io, &handshake, &mut transcript, &mut seq_enc_hs)?;
333
334        debug_print("TLS connect done", &[]);
335        Ok(TlsStream {
336            io,
337            aead_enc: aead_app_enc,
338            aead_dec: aead_app_dec,
339            iv_enc: caiv,
340            iv_dec: saiv,
341            seq_enc: seq_app_enc,
342            seq_dec: seq_app_dec,
343            rbuf: Vec::with_capacity(16 * 1024),
344            rpos: 0,
345        })
346    }
347
348    pub fn has_buffered_data(&self) -> bool {
349        /*
350        let out = self.rpos < self.rbuf.len();
351        let msg = alloc::string::ToString::to_string(&"rpos = ")
352            + &utils::num_to_string(self.rpos)
353            + ", rbuf.len() = "
354            + &utils::num_to_string(self.rbuf.len())
355            + " . "
356            + if out { "true" } else { "false" };
357        utils::print_string(c"tls has_buffered_data: ", &msg);
358        */
359        self.rpos < self.rbuf.len()
360    }
361
362    fn send_client_hello<W: Write>(
363        io: &mut W,
364        sni_host: &str,
365        transcript: &mut Vec<u8>,
366        client_private_key: &[u8; 32],
367    ) -> OrtResult<()> {
368        let ch_msg = client_hello_msg(sni_host, client_private_key)?;
369        write_record_plain(io, REC_TYPE_HANDSHAKE, &ch_msg).context("write ClientHello")?;
370        transcript.extend_from_slice(&ch_msg);
371        Ok(())
372    }
373
374    fn receive_server_hello<R: Read>(io: &mut R, transcript: &mut Vec<u8>) -> OrtResult<Vec<u8>> {
375        let (sh_body, sh_full) = read_server_hello(io)?;
376        transcript.extend_from_slice(&sh_full);
377        Ok(sh_body)
378    }
379
380    fn receive_dummy_change_cipher_spec<R: Read>(io: &mut R) -> OrtResult<()> {
381        // Some servers send TLS 1.2-style ChangeCipherSpec for middlebox compatibility.
382        let (typ, _) =
383            read_record_plain(io).context("read_record_plain for dummy change cipher")?;
384        if typ != REC_TYPE_CHANGE_CIPHER_SPEC {
385            return Err(ort_error(ErrorKind::TlsExpectedChangeCipherSpec, ""));
386        }
387        Ok(())
388    }
389
390    /// Should be called multiple times until it returns true.
391    /// The TLS messages for this stage might come as separate packets, or all in one.
392    fn receive_server_encrypted_flight<R: Read>(
393        io: &mut R,
394        seq_dec_hs: &mut u64,
395        handshake: &HandshakeState,
396        transcript: &mut Vec<u8>,
397    ) -> OrtResult<bool> {
398        let (typ, ct, _inner_type) = read_record_cipher(
399            io,
400            &handshake.aead_dec_hs,
401            &handshake.server_handshake_iv,
402            seq_dec_hs,
403        )?;
404        if typ != REC_TYPE_APPDATA {
405            return Err(ort_error(ErrorKind::TlsExpectedEncryptedRecords, ""));
406        }
407
408        // Decrypted TLSInnerPlaintext: ... | content_type
409        // May contain multiple handshake messages; parse & append to transcript.
410        let mut p = &ct[..];
411        while !p.is_empty() {
412            let (mtyp, body, full) = match read_handshake_message(&mut p) {
413                Ok(x) => x,
414                Err(_) => {
415                    return Err(ort_error(ErrorKind::TlsBadHandshakeFragment, ""));
416                }
417            };
418            transcript.extend_from_slice(full);
419            debug_print("handshake message (type is first byte)", full);
420
421            if mtyp == HS_FINISHED {
422                // verify server Finished
423                let s_finished_key =
424                    hkdf_expand_label::<32>(&handshake.server_hs_ts, "finished", &[]);
425
426                let thash = digest_bytes(&transcript[..transcript.len() - full.len()]);
427                let expected = hmac::sign(&s_finished_key, &thash);
428                if expected.as_slice() != body {
429                    return Err(ort_error(ErrorKind::TlsFinishedVerifyFailed, ""));
430                }
431                // Done collecting server handshake.
432                return Ok(true);
433            }
434            // Ignore other handshake types’ contents (no cert validation).
435        }
436        Ok(false)
437    }
438
439    fn derive_handshake_keys(
440        client_private_key: &[u8; 32],
441        sh_body: &[u8],
442        transcript: &[u8],
443    ) -> OrtResult<HandshakeState> {
444        // Parse minimal ServerHello to get cipher & key_share
445        let (cipher, server_public_key_bytes) = parse_server_hello_for_keys(sh_body)?;
446        debug_print("Server public key", &server_public_key_bytes);
447        if cipher != CIPHER_TLS_AES_128_GCM_SHA256 {
448            return Err(ort_error(
449                ErrorKind::TlsUnsupportedCipher,
450                "server picked unsupported cipher",
451            ));
452        }
453
454        // ECDH(X25519) shared secret
455        let hs_shared_secret = ecdh::x25519_agreement(client_private_key, &server_public_key_bytes);
456        debug_print("hs shared secret", &hs_shared_secret);
457
458        // Same as: `echo -n "" | openssl sha256`
459        let empty_hash = digest_bytes(&[]);
460        debug_print("empty_hash", &empty_hash);
461
462        let zero: [u8; 32] = [0u8; 32];
463        let early_secret = hkdf::hkdf_extract(&zero, &zero);
464
465        let derived_secret_bytes = hkdf_expand_label::<32>(&early_secret, "derived", &empty_hash);
466        debug_print("derived", &derived_secret_bytes);
467
468        let handshake_secret = hkdf::hkdf_extract(&derived_secret_bytes, &hs_shared_secret);
469        debug_print("handshake_secret", &handshake_secret);
470
471        let ch_sh_hash = digest_bytes(transcript);
472        debug_print("digest bytes", &ch_sh_hash);
473
474        let c_hs_ts = hkdf_expand_label(&handshake_secret, "c hs traffic", &ch_sh_hash);
475        let s_hs_ts = hkdf_expand_label(&handshake_secret, "s hs traffic", &ch_sh_hash);
476
477        debug_print("c hs traffic", &c_hs_ts);
478        debug_print("s hs traffic", &s_hs_ts);
479
480        // handshake AEAD keys/IVs
481        let client_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&c_hs_ts, "key", &[])
482            .as_slice()[..16]
483            .try_into()
484            .unwrap();
485        debug_print("client_handshake_key", &client_handshake_key);
486        let client_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&c_hs_ts, "iv", &[]).as_slice()
487            [..12]
488            .try_into()
489            .unwrap();
490        debug_print("client_handshake_iv", &client_handshake_iv);
491
492        let server_handshake_key: [u8; 16] = hkdf_expand_label::<16>(&s_hs_ts, "key", &[])
493            .as_slice()[..16]
494            .try_into()
495            .unwrap();
496        debug_print("server_handshake_key", &server_handshake_key);
497        let server_handshake_iv: [u8; 12] = hkdf_expand_label::<12>(&s_hs_ts, "iv", &[]).as_slice()
498            [..12]
499            .try_into()
500            .unwrap();
501        debug_print("server_handshake_iv", &server_handshake_iv);
502
503        Ok(HandshakeState {
504            handshake_secret,
505            client_hs_ts: c_hs_ts,
506            server_hs_ts: s_hs_ts,
507            client_handshake_iv,
508            server_handshake_iv,
509            aead_enc_hs: client_handshake_key,
510            aead_dec_hs: server_handshake_key,
511            empty_hash,
512        })
513    }
514
515    fn derive_application_keys(
516        handshake_secret: &[u8; 32],
517        empty_hash: &[u8; 32],
518        transcript: &[u8],
519    ) -> ApplicationKeys {
520        let derived2_bytes = hkdf_expand_label::<32>(handshake_secret, "derived", empty_hash);
521        debug_print("derived2_bytes", &derived2_bytes);
522
523        let zero: [u8; 32] = [0u8; 32];
524        let master_secret = hkdf::hkdf_extract(&derived2_bytes, &zero);
525        let thash_srv_fin = digest_bytes(transcript);
526
527        let c_ap_ts = hkdf_expand_label::<32>(&master_secret, "c ap traffic", &thash_srv_fin);
528        let s_ap_ts = hkdf_expand_label::<32>(&master_secret, "s ap traffic", &thash_srv_fin);
529        debug_print("c_ap_ts", &c_ap_ts);
530        debug_print("s_ap_ts", &s_ap_ts);
531
532        let cak: [u8; 16] = hkdf_expand_label::<16>(&c_ap_ts, "key", &[]).as_slice()[..16]
533            .try_into()
534            .unwrap();
535        let caiv: [u8; 12] = hkdf_expand_label::<12>(&c_ap_ts, "iv", &[]).as_slice()[..12]
536            .try_into()
537            .unwrap();
538        debug_print("cak", &cak);
539        debug_print("caiv", &caiv);
540
541        let sak: [u8; 16] = hkdf_expand_label::<16>(&s_ap_ts, "key", &[]).as_slice()[..16]
542            .try_into()
543            .unwrap();
544        let saiv: [u8; 12] = hkdf_expand_label::<12>(&s_ap_ts, "iv", &[]).as_slice()[..12]
545            .try_into()
546            .unwrap();
547        debug_print("sak", &sak);
548        debug_print("saiv", &saiv);
549
550        ApplicationKeys {
551            aead_app_enc: cak,
552            aead_app_dec: sak,
553            iv_enc: caiv,
554            iv_dec: saiv,
555        }
556    }
557
558    fn send_client_finished<W: Write>(
559        io: &mut W,
560        handshake: &HandshakeState,
561        transcript: &mut Vec<u8>,
562        seq_enc_hs: &mut u64,
563    ) -> OrtResult<()> {
564        let c_finished_key = hkdf_expand_label::<32>(&handshake.client_hs_ts, "finished", &[]);
565        debug_print("c_finished", &c_finished_key);
566
567        let thash_client_fin = digest_bytes(transcript.as_slice());
568        let verify_data = hmac::sign(&c_finished_key, &thash_client_fin);
569        debug_print("verify_data", &verify_data);
570
571        let mut fin = Vec::with_capacity(4 + verify_data.as_ref().len());
572        fin.push(HS_FINISHED);
573        put_u24(&mut fin, verify_data.as_ref().len());
574        fin.extend_from_slice(verify_data.as_ref());
575
576        // append to transcript before switching keys
577        transcript.extend_from_slice(&fin);
578
579        write_record_cipher(
580            io,
581            REC_TYPE_HANDSHAKE,
582            &fin,
583            &handshake.aead_enc_hs,
584            &handshake.client_handshake_iv,
585            seq_enc_hs,
586        )
587        .context("write_record_cipher write_all failed")?;
588
589        Ok(())
590    }
591}
592
593impl<T: Read + Write> Write for TlsStream<T> {
594    fn write(&mut self, buf: &[u8]) -> OrtResult<usize> {
595        write_record_cipher(
596            &mut self.io,
597            REC_TYPE_APPDATA,
598            buf,
599            &self.aead_enc,
600            &self.iv_enc,
601            &mut self.seq_enc,
602        )
603        .map(|_| buf.len())
604    }
605    fn flush(&mut self) -> OrtResult<()> {
606        self.io.flush()
607    }
608}
609
610impl<T: Read + Write> Read for TlsStream<T> {
611    fn read(&mut self, out: &mut [u8]) -> OrtResult<usize> {
612        if self.rpos < self.rbuf.len() {
613            debug_print("TlsStream.read using buf", &[]);
614
615            let n = cmp::min(out.len(), self.rbuf.len() - self.rpos);
616            out[..n].copy_from_slice(&self.rbuf[self.rpos..self.rpos + n]);
617            self.rpos += n;
618            if self.rpos == self.rbuf.len() {
619                self.rbuf.clear();
620                self.rpos = 0;
621            }
622            return Ok(n);
623        }
624        loop {
625            let (typ, plaintext, inner_type) = read_record_cipher(
626                &mut self.io,
627                &self.aead_dec,
628                &self.iv_dec,
629                &mut self.seq_dec,
630            )?;
631            if typ != REC_TYPE_APPDATA {
632                // Ignore unexpected (e.g., post-handshake Handshake like NewSessionTicket)
633                continue;
634            }
635            // plaintext ends with inner content type byte; for app data it is 0x17.
636            if plaintext.is_empty() {
637                continue;
638            }
639            if inner_type == REC_TYPE_HANDSHAKE {
640                // Drop post-handshake messages (tickets, etc.)
641                continue;
642            }
643            if inner_type == REC_TYPE_ALERT {
644                let level = match plaintext[0] {
645                    1 => "warning",
646                    2 => "fatal",
647                    _ => "unknown",
648                };
649                let err_level = CString::new(level.to_string() + " alert: ").unwrap();
650
651                // See https://www.rfc-editor.org/rfc/rfc8446#appendix-B search for
652                // "unexpected_message" for all types
653                let mut err_code_buf: [u8; 5] = [0u8; 5];
654                let len = to_ascii(plaintext[1] as usize, &mut err_code_buf);
655                let err_code = unsafe { CStr::from_bytes_with_nul_unchecked(&err_code_buf[..len]) };
656                syscall::write(2, err_level.as_ptr().cast(), err_level.count_bytes());
657                syscall::write(2, err_code.as_ptr().cast(), err_code.count_bytes());
658
659                return Err(ort_error(ErrorKind::TlsAlertReceived, ""));
660            }
661            if inner_type != REC_TYPE_APPDATA {
662                // Some servers pad with 0x00.. then type; we already consumed type.
663                // If not 0x17, treat preceding bytes (if any) as app anyway.
664            }
665            if plaintext.is_empty() {
666                continue;
667            }
668
669            self.rbuf.extend_from_slice(&plaintext);
670            self.rpos = 0;
671            // Now serve from buffer
672            let n = cmp::min(out.len(), self.rbuf.len());
673            out[..n].copy_from_slice(&self.rbuf[..n]);
674            self.rpos = n;
675            if n == self.rbuf.len() {
676                self.rbuf.clear();
677                self.rpos = 0;
678            }
679            return Ok(n);
680        }
681    }
682}
683
684impl<T: Read + Write + AsFd> AsFd for TlsStream<T> {
685    fn as_fd(&self) -> i32 {
686        self.io.as_fd()
687    }
688}
689
690// ---------------------- Record I/O helpers ----------------------------------
691
692fn write_record_plain<W: Write>(w: &mut W, typ: u8, body: &[u8]) -> OrtResult<()> {
693    let mut hdr = [0u8; 5];
694    hdr[0] = typ;
695    hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
696    hdr[3..5].copy_from_slice(&(body.len() as u16).to_be_bytes());
697    w.write_all(&hdr)?;
698    w.write_all(body)?;
699    Ok(())
700}
701
702fn read_exact_n<R: Read>(r: &mut R, n: usize) -> OrtResult<Vec<u8>> {
703    let mut buf = vec![0u8; n];
704    r.read_exact(&mut buf)?;
705    Ok(buf)
706}
707
708fn read_record_plain<R: Read>(r: &mut R) -> OrtResult<(u8, Vec<u8>)> {
709    let hdr = read_exact_n(r, 5)?; // Record Header, e.g. 16 03 03 len
710    let typ = hdr[0];
711    let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
712    let body = read_exact_n(r, len)?;
713    debug_print("read_record_plain hdr", &hdr);
714    debug_print("read_record_plain body", &body);
715    //let _ = write_bytes_to_file(&[&hdr[..], &body].concat(), debug_filename);
716    Ok((typ, body))
717}
718
719fn write_record_cipher<W: Write>(
720    w: &mut W,
721    outer_type: u8,
722    inner: &[u8],
723    key: &[u8; 16],
724    iv12: &[u8; 12],
725    seq: &mut u64,
726) -> OrtResult<()> {
727    // AES / GCM plaintext and ciphertext have the same length
728    let total_len = inner.len() + 1 + AEAD_TAG_LEN;
729    let mut plain = Vec::with_capacity(total_len);
730    plain.extend_from_slice(inner);
731    plain.push(outer_type);
732
733    debug_print("write_record_cipher plaintext", &plain);
734
735    let nonce = nonce_xor(iv12, *seq);
736    *seq = seq.wrapping_add(1);
737
738    let mut hdr = [0u8; 5];
739    hdr[0] = REC_TYPE_APPDATA;
740    hdr[1..3].copy_from_slice(&LEGACY_REC_VER.to_be_bytes());
741    hdr[3..5].copy_from_slice(&(total_len as u16).to_be_bytes());
742
743    let out = aead::aes_128_gcm_encrypt(key, &nonce, &hdr, &plain).unwrap();
744
745    debug_print("write_record_cipher header", &hdr);
746    //let final_label = format!("write_record_cipher final {total_len}");
747    //debug_print(final_label.as_str(), &out);
748
749    w.write_all(&hdr)?;
750    w.write_all(&out)?;
751    Ok(())
752}
753
754fn read_record_cipher<R: Read>(
755    r: &mut R,
756    key: &[u8; 16],
757    iv12: &[u8; 12],
758    seq: &mut u64,
759) -> OrtResult<(u8, Vec<u8>, u8)> {
760    let hdr = read_exact_n(r, 5)?;
761    let typ = hdr[0];
762    let len = u16::from_be_bytes([hdr[3], hdr[4]]) as usize;
763    let ciphertext = read_exact_n(r, len)?;
764    if len < AEAD_TAG_LEN {
765        return Err(ort_error(ErrorKind::TlsRecordTooShort, "short record"));
766    }
767    debug_print("read_record_cipher hdr", &hdr);
768    debug_print("read_record_cipher ct", &ciphertext);
769
770    //let size_expected = crate::utils::num_to_string(len);
771    //let size_read = crate::utils::num_to_string(ciphertext.len());
772    //crate::utils::print_string(c"size_expected ", &size_expected);
773    //crate::utils::print_string(c"size_read ", &size_read);
774
775    // Decrypt ciphertext
776
777    let nonce = nonce_xor(iv12, *seq);
778    *seq = seq.wrapping_add(1);
779
780    let mut out = match aead::aes_128_gcm_decrypt(key, &nonce, &hdr, &ciphertext) {
781        Ok(out) => out,
782        Err(s) => {
783            return Err(ort_error(ErrorKind::TlsAes128GcmDecryptFailed, s));
784        }
785    };
786
787    debug_print("read_record_cipher plaintext hdr", &hdr);
788    debug_print("read_record_cipher plaintext", &out);
789
790    if out.is_empty() {
791        return Ok((typ, ciphertext, 0));
792    }
793    // Strip inner content-type byte
794    let inner_type = *out.last().unwrap();
795    out.truncate(out.len() - 1);
796    Ok((typ, out, inner_type))
797}
798
799// ---------------------- Handshake parsing helpers ---------------------------
800
801fn read_handshake_message<'a>(rd: &mut &'a [u8]) -> OrtResult<(u8, &'a [u8], &'a [u8])> {
802    if rd.len() < 4 {
803        return Err(ort_error(ErrorKind::TlsHandshakeHeaderTooShort, ""));
804    }
805    let typ = rd[0];
806    let len = ((rd[1] as usize) << 16) | ((rd[2] as usize) << 8) | rd[3] as usize;
807    if rd.len() < 4 + len {
808        return Err(ort_error(ErrorKind::TlsHandshakeBodyTooShort, ""));
809    }
810    let full = &rd[..4 + len];
811    let body = &rd[4..4 + len];
812    *rd = &rd[4 + len..];
813    Ok((typ, body, full))
814}
815
816fn parse_server_hello_for_keys(sh: &[u8]) -> OrtResult<(u16, [u8; 32])> {
817    // minimal parse: skip legacy_version(2), random(32), sid, cipher(2), comp(1), exts
818    if sh.len() < 2 + 32 + 1 + 2 + 1 + 2 {
819        return Err(ort_error(ErrorKind::TlsServerHelloTooShort, ""));
820    }
821    let mut p = sh;
822
823    p = &p[2..]; // legacy_version
824    p = &p[32..]; // random
825    let sid_len = p[0] as usize;
826    p = &p[1..];
827    if p.len() < sid_len + 2 + 1 + 2 {
828        return Err(ort_error(ErrorKind::TlsServerHelloSessionIdInvalid, ""));
829    }
830    p = &p[sid_len..];
831    let cipher = u16::from_be_bytes([p[0], p[1]]);
832    p = &p[2..];
833    let _comp = p[0];
834    p = &p[1..];
835    let ext_len = u16::from_be_bytes([p[0], p[1]]) as usize;
836    p = &p[2..];
837    if p.len() < ext_len {
838        return Err(ort_error(ErrorKind::TlsServerHelloExtTooShort, ""));
839    }
840    let mut ex = &p[..ext_len];
841
842    let mut server_pub = None;
843
844    while !ex.is_empty() {
845        if ex.len() < 4 {
846            return Err(ort_error(ErrorKind::TlsExtensionHeaderTooShort, ""));
847        }
848        let et = u16::from_be_bytes([ex[0], ex[1]]);
849        let el = u16::from_be_bytes([ex[2], ex[3]]) as usize;
850        ex = &ex[4..];
851        if ex.len() < el {
852            return Err(ort_error(ErrorKind::TlsExtensionLengthInvalid, ""));
853        }
854        let ed = &ex[..el];
855        ex = &ex[el..];
856
857        match et {
858            EXT_KEY_SHARE => {
859                // KeyShareServerHello: group(2) kx_len(2) kx
860                if ed.len() < 2 + 2 + 32 {
861                    return Err(ort_error(ErrorKind::TlsKeyShareServerHelloInvalid, ""));
862                }
863                let grp = u16::from_be_bytes([ed[0], ed[1]]);
864                if grp != GROUP_X25519 {
865                    return Err(ort_error(
866                        ErrorKind::TlsServerGroupUnsupported,
867                        "server group != x25519",
868                    ));
869                }
870                let kx_len = u16::from_be_bytes([ed[2], ed[3]]) as usize;
871                if ed.len() < 4 + kx_len || kx_len != 32 {
872                    return Err(ort_error(ErrorKind::TlsKeyShareLengthInvalid, ""));
873                }
874                let mut pk = [0u8; 32];
875                pk.copy_from_slice(&ed[4..4 + 32]);
876                server_pub = Some(pk);
877            }
878            EXT_SUPPORTED_VERSIONS
879                if (ed.len() != 2 || u16::from_be_bytes([ed[0], ed[1]]) != TLS13) =>
880            {
881                return Err(ort_error(ErrorKind::TlsServerNotTls13, ""));
882            }
883            _ => {}
884        }
885    }
886
887    let sp = server_pub.ok_or_else(|| ort_error(ErrorKind::TlsMissingServerKey, ""))?;
888    Ok((cipher, sp))
889}
890
891#[allow(unused)]
892fn debug_print(name: &str, value: &[u8]) {
893    #[cfg(debug_assertions)]
894    {
895        if !DEBUG_LOG {
896            return;
897        }
898        let c_str = CString::new(name).unwrap();
899        if !value.is_empty() {
900            crate::utils::print_hex(c_str.as_c_str(), value);
901        } else {
902            crate::utils::print_string(c_str.as_c_str(), "");
903        }
904    }
905}
906
907/*
908#[allow(dead_code)]
909fn write_bytes_to_file(bytes: &[u8], file_path: &str) -> std::io::Result<()> {
910    let mut file = File::create(file_path)?;
911    file.write_all(bytes)?;
912    Ok(())
913}
914*/
915
916#[cfg(test)]
917pub mod tests {
918    extern crate alloc;
919    use alloc::vec::Vec;
920
921    pub fn string_to_bytes(s: &str) -> [u8; 32] {
922        let mut bytes = s.as_bytes();
923        if bytes.len() >= 2 && bytes[0] == b'0' && (bytes[1] == b'x' || bytes[1] == b'X') {
924            bytes = &bytes[2..];
925        }
926        assert!(
927            bytes.len() == 64,
928            "hex string must be exactly 64 hex chars (32 bytes)"
929        );
930
931        let mut out = [0u8; 32];
932        for i in 0..32 {
933            let hi = hex_val(bytes[2 * i]);
934            let lo = hex_val(bytes[2 * i + 1]);
935            out[i] = (hi << 4) | lo;
936        }
937        out
938    }
939
940    pub fn hex_to_vec(s: &str) -> Vec<u8> {
941        let mut bytes = s.as_bytes();
942        if bytes.len() >= 2 && bytes[0] == b'0' && (bytes[1] == b'X' || bytes[1] == b'x') {
943            bytes = &bytes[2..];
944        }
945        assert_eq!(bytes.len() % 2, 0, "hex string must have even length");
946        let mut out = Vec::with_capacity(bytes.len() / 2);
947        for chunk in bytes.chunks_exact(2) {
948            let hi = hex_val(chunk[0]);
949            let lo = hex_val(chunk[1]);
950            out.push((hi << 4) | lo);
951        }
952        out
953    }
954
955    fn hex_val(b: u8) -> u8 {
956        match b {
957            b'0'..=b'9' => b - b'0',
958            b'a'..=b'f' => b - b'a' + 10,
959            b'A'..=b'F' => b - b'A' + 10,
960            _ => panic!("invalid hex character"),
961        }
962    }
963}