Skip to main content

oxitls_core/
lib.rs

#![forbid(unsafe_code)]
#![warn(missing_docs)]
//! `oxitls-core` — Pure-Rust TLS transport primitives.
//!
//! This crate provides the foundational types used across the OxiTLS ecosystem:
//! error types, TLS version and cipher suite enumerations, connection
//! information, and trait definitions for TLS connectors/acceptors.
//!
//! # Overview
//!
//! - [`TlsVersion`] and [`CipherSuite`] name the negotiated protocol
//!   parameters; both implement `Display` and `FromStr`.
//! - [`ConnectionInfo`] (and its [`ConnectionInfoBuilder`]) carries
//!   post-handshake metadata such as ALPN, SNI and the peer certificate chain.
//! - [`TlsError`] is the shared error type. It converts to and from
//!   [`std::io::Error`] and from `rustls::Error`.
//! - [`TlsConnector`] / [`TlsAcceptor`] are object-safe handshake traits that
//!   operate on the type-erased [`TlsStream`].
//! - `GenericTlsConnector` / `GenericTlsAcceptor` keep the concrete transport
//!   type through a generic associated type. They are only compiled when the
//!   `generic-transport` feature is enabled.
//!
//! The crate forbids `unsafe` code.
8
9use std::fmt;
10use std::future::Future;
11use std::io;
12use std::pin::Pin;
13use tokio::io::{AsyncRead, AsyncWrite};
14
// ── Sub-modules ───────────────────────────────────────────────────────────────

/// OS-entropy CSPRNG adapter implementing the `rand_core` 0.6 traits required
/// by `rsa`/`ed25519-dalek`/`x25519-dalek` (decoupled from the workspace
/// `rand`/`rand_core` 0.10).
pub mod os_rng;

/// Key-logging policy for TLS session secret export.
pub mod keylog;

/// TLS alert description codes (RFC 8446 §6).
pub mod alert;

/// Generic TLS configuration introspection trait.
pub mod config;

/// Helpers for extracting [`ConnectionInfo`] from a rustls connection state.
pub mod stream_info;

// Re-export the most commonly used sub-module items at the crate root so that
// existing `oxitls_core::<Item>` paths keep working (backward compatibility).
pub use alert::AlertDescription;
pub use keylog::{KeyLog, KeyLogPolicy};
pub use os_rng::OsRng;
pub use stream_info::connection_info_from;
39
40// ── TLS Version ──────────────────────────────────────────────────────────────
41
/// TLS protocol version.
///
/// Variants are declared oldest-first, so the derived ordering follows protocol
/// chronology: `TlsVersion::Tls12 < TlsVersion::Tls13`. This allows
/// minimum-version checks such as `negotiated >= TlsVersion::Tls13`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TlsVersion {
    /// TLS 1.2 (RFC 5246)
    Tls12,
    /// TLS 1.3 (RFC 8446)
    Tls13,
}

impl TlsVersion {
    /// All known TLS versions, in ascending order (consistent with the derived
    /// [`Ord`] implementation).
    pub const ALL: &'static [TlsVersion] = &[TlsVersion::Tls12, TlsVersion::Tls13];
}
55
56impl fmt::Display for TlsVersion {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        match self {
59            TlsVersion::Tls12 => write!(f, "TLS 1.2"),
60            TlsVersion::Tls13 => write!(f, "TLS 1.3"),
61        }
62    }
63}
64
65impl std::str::FromStr for TlsVersion {
66    type Err = TlsError;
67
68    fn from_str(s: &str) -> Result<Self, Self::Err> {
69        match s {
70            "TLS 1.2" | "tls1.2" | "TLSv1.2" | "1.2" => Ok(TlsVersion::Tls12),
71            "TLS 1.3" | "tls1.3" | "TLSv1.3" | "1.3" => Ok(TlsVersion::Tls13),
72            _ => Err(TlsError::Other(format!("unknown TLS version: {s}"))),
73        }
74    }
75}
76
77// ── Cipher Suite ─────────────────────────────────────────────────────────────
78
/// TLS cipher suite identifiers covering TLS 1.3 mandatory suites and
/// commonly-used TLS 1.2 AEAD suites.
///
/// Every named variant maps to one IANA two-byte code point, listed on each
/// variant below. Suites outside this list are represented by
/// [`CipherSuite::Unknown`], which does not carry the original code point.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum CipherSuite {
    // ── TLS 1.3 (RFC 8446, Section 9.1) ──
    /// TLS_AES_128_GCM_SHA256 (0x13,0x01)
    Tls13Aes128GcmSha256,
    /// TLS_AES_256_GCM_SHA384 (0x13,0x02)
    Tls13Aes256GcmSha384,
    /// TLS_CHACHA20_POLY1305_SHA256 (0x13,0x03)
    Tls13Chacha20Poly1305Sha256,

    // ── TLS 1.2 AEAD suites ──
    /// TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 (0xC0,0x2B)
    Tls12EcdheEcdsaAes128GcmSha256,
    /// TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 (0xC0,0x2C)
    Tls12EcdheEcdsaAes256GcmSha384,
    /// TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 (0xC0,0x2F)
    Tls12EcdheRsaAes128GcmSha256,
    /// TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 (0xC0,0x30)
    Tls12EcdheRsaAes256GcmSha384,
    /// TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 (0xCC,0xA9)
    Tls12EcdheEcdsaChacha20Poly1305Sha256,
    /// TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (0xCC,0xA8)
    Tls12EcdheRsaChacha20Poly1305Sha256,
    /// An unrecognised cipher suite not covered by the variants above.
    ///
    /// The code point of the unrecognised suite is not preserved.
    Unknown,
}
108
109impl CipherSuite {
110    /// Returns the IANA two-byte identifier for this cipher suite.
111    ///
112    /// Returns `[0x00, 0x00]` for [`CipherSuite::Unknown`].
113    pub fn iana_value(&self) -> [u8; 2] {
114        match self {
115            CipherSuite::Tls13Aes128GcmSha256 => [0x13, 0x01],
116            CipherSuite::Tls13Aes256GcmSha384 => [0x13, 0x02],
117            CipherSuite::Tls13Chacha20Poly1305Sha256 => [0x13, 0x03],
118            CipherSuite::Tls12EcdheEcdsaAes128GcmSha256 => [0xC0, 0x2B],
119            CipherSuite::Tls12EcdheEcdsaAes256GcmSha384 => [0xC0, 0x2C],
120            CipherSuite::Tls12EcdheRsaAes128GcmSha256 => [0xC0, 0x2F],
121            CipherSuite::Tls12EcdheRsaAes256GcmSha384 => [0xC0, 0x30],
122            CipherSuite::Tls12EcdheEcdsaChacha20Poly1305Sha256 => [0xCC, 0xA9],
123            CipherSuite::Tls12EcdheRsaChacha20Poly1305Sha256 => [0xCC, 0xA8],
124            CipherSuite::Unknown => [0xFF, 0xFF],
125        }
126    }
127
128    /// Try to look up a cipher suite from its IANA two-byte identifier.
129    pub fn from_iana(bytes: [u8; 2]) -> Option<Self> {
130        match bytes {
131            [0x13, 0x01] => Some(CipherSuite::Tls13Aes128GcmSha256),
132            [0x13, 0x02] => Some(CipherSuite::Tls13Aes256GcmSha384),
133            [0x13, 0x03] => Some(CipherSuite::Tls13Chacha20Poly1305Sha256),
134            [0xC0, 0x2B] => Some(CipherSuite::Tls12EcdheEcdsaAes128GcmSha256),
135            [0xC0, 0x2C] => Some(CipherSuite::Tls12EcdheEcdsaAes256GcmSha384),
136            [0xC0, 0x2F] => Some(CipherSuite::Tls12EcdheRsaAes128GcmSha256),
137            [0xC0, 0x30] => Some(CipherSuite::Tls12EcdheRsaAes256GcmSha384),
138            [0xCC, 0xA9] => Some(CipherSuite::Tls12EcdheEcdsaChacha20Poly1305Sha256),
139            [0xCC, 0xA8] => Some(CipherSuite::Tls12EcdheRsaChacha20Poly1305Sha256),
140            _ => None,
141        }
142    }
143
144    /// Whether this suite belongs to TLS 1.3.
145    pub fn is_tls13(&self) -> bool {
146        matches!(
147            self,
148            CipherSuite::Tls13Aes128GcmSha256
149                | CipherSuite::Tls13Aes256GcmSha384
150                | CipherSuite::Tls13Chacha20Poly1305Sha256
151        )
152    }
153
154    /// Whether this suite belongs to TLS 1.2.
155    pub fn is_tls12(&self) -> bool {
156        matches!(
157            self,
158            CipherSuite::Tls12EcdheEcdsaAes128GcmSha256
159                | CipherSuite::Tls12EcdheEcdsaAes256GcmSha384
160                | CipherSuite::Tls12EcdheRsaAes128GcmSha256
161                | CipherSuite::Tls12EcdheRsaAes256GcmSha384
162                | CipherSuite::Tls12EcdheEcdsaChacha20Poly1305Sha256
163                | CipherSuite::Tls12EcdheRsaChacha20Poly1305Sha256
164        )
165    }
166
167    /// Whether this is the `Unknown` catch-all variant.
168    pub fn is_unknown(&self) -> bool {
169        matches!(self, CipherSuite::Unknown)
170    }
171
172    /// All named cipher suites (excluding [`CipherSuite::Unknown`]).
173    pub const ALL: &'static [CipherSuite] = &[
174        CipherSuite::Tls13Aes128GcmSha256,
175        CipherSuite::Tls13Aes256GcmSha384,
176        CipherSuite::Tls13Chacha20Poly1305Sha256,
177        CipherSuite::Tls12EcdheEcdsaAes128GcmSha256,
178        CipherSuite::Tls12EcdheEcdsaAes256GcmSha384,
179        CipherSuite::Tls12EcdheRsaAes128GcmSha256,
180        CipherSuite::Tls12EcdheRsaAes256GcmSha384,
181        CipherSuite::Tls12EcdheEcdsaChacha20Poly1305Sha256,
182        CipherSuite::Tls12EcdheRsaChacha20Poly1305Sha256,
183    ];
184}
185
186impl fmt::Display for CipherSuite {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        let name = match self {
189            CipherSuite::Tls13Aes128GcmSha256 => "TLS_AES_128_GCM_SHA256",
190            CipherSuite::Tls13Aes256GcmSha384 => "TLS_AES_256_GCM_SHA384",
191            CipherSuite::Tls13Chacha20Poly1305Sha256 => "TLS_CHACHA20_POLY1305_SHA256",
192            CipherSuite::Tls12EcdheEcdsaAes128GcmSha256 => {
193                "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
194            }
195            CipherSuite::Tls12EcdheEcdsaAes256GcmSha384 => {
196                "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
197            }
198            CipherSuite::Tls12EcdheRsaAes128GcmSha256 => "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
199            CipherSuite::Tls12EcdheRsaAes256GcmSha384 => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
200            CipherSuite::Tls12EcdheEcdsaChacha20Poly1305Sha256 => {
201                "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"
202            }
203            CipherSuite::Tls12EcdheRsaChacha20Poly1305Sha256 => {
204                "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256"
205            }
206            CipherSuite::Unknown => "UNKNOWN",
207        };
208        write!(f, "{name}")
209    }
210}
211
212impl std::str::FromStr for CipherSuite {
213    type Err = TlsError;
214
215    fn from_str(s: &str) -> Result<Self, Self::Err> {
216        match s {
217            "TLS_AES_128_GCM_SHA256" => Ok(CipherSuite::Tls13Aes128GcmSha256),
218            "TLS_AES_256_GCM_SHA384" => Ok(CipherSuite::Tls13Aes256GcmSha384),
219            "TLS_CHACHA20_POLY1305_SHA256" => Ok(CipherSuite::Tls13Chacha20Poly1305Sha256),
220            "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => {
221                Ok(CipherSuite::Tls12EcdheEcdsaAes128GcmSha256)
222            }
223            "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => {
224                Ok(CipherSuite::Tls12EcdheEcdsaAes256GcmSha384)
225            }
226            "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => {
227                Ok(CipherSuite::Tls12EcdheRsaAes128GcmSha256)
228            }
229            "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => {
230                Ok(CipherSuite::Tls12EcdheRsaAes256GcmSha384)
231            }
232            "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => {
233                Ok(CipherSuite::Tls12EcdheEcdsaChacha20Poly1305Sha256)
234            }
235            "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => {
236                Ok(CipherSuite::Tls12EcdheRsaChacha20Poly1305Sha256)
237            }
238            "UNKNOWN" => Ok(CipherSuite::Unknown),
239            _ => Err(TlsError::Other(format!("unknown cipher suite: {s}"))),
240        }
241    }
242}
243
244// ── Connection Info ──────────────────────────────────────────────────────────
245
246/// Information about a completed TLS connection.
247///
248/// Constructed incrementally by adapter crates after the handshake completes.
249/// The builder pattern allows partial population (e.g. ALPN may be `None` if
250/// not negotiated).
251#[derive(Debug, Clone)]
252pub struct ConnectionInfo {
253    /// The negotiated TLS protocol version.
254    pub version: Option<TlsVersion>,
255    /// The negotiated cipher suite.
256    pub cipher_suite: Option<CipherSuite>,
257    /// The negotiated ALPN protocol (e.g. `b"h2"`, `b"http/1.1"`).
258    pub alpn_protocol: Option<Vec<u8>>,
259    /// The SNI (Server Name Indication) value sent by the client.
260    pub sni: Option<String>,
261    /// DER-encoded peer certificates (leaf first), if provided.
262    pub peer_certificates: Vec<Vec<u8>>,
263}
264
265impl ConnectionInfo {
266    /// Create a new empty `ConnectionInfo`.
267    pub fn new() -> Self {
268        Self {
269            version: None,
270            cipher_suite: None,
271            alpn_protocol: None,
272            sni: None,
273            peer_certificates: Vec::new(),
274        }
275    }
276
277    /// Set the negotiated TLS version.
278    pub fn with_version(mut self, version: TlsVersion) -> Self {
279        self.version = Some(version);
280        self
281    }
282
283    /// Set the negotiated cipher suite.
284    pub fn with_cipher_suite(mut self, suite: CipherSuite) -> Self {
285        self.cipher_suite = Some(suite);
286        self
287    }
288
289    /// Set the negotiated ALPN protocol.
290    pub fn with_alpn_protocol(mut self, proto: Vec<u8>) -> Self {
291        self.alpn_protocol = Some(proto);
292        self
293    }
294
295    /// Set the SNI name.
296    pub fn with_sni(mut self, sni: String) -> Self {
297        self.sni = Some(sni);
298        self
299    }
300
301    /// Set the peer certificate chain (DER-encoded, leaf first).
302    pub fn with_peer_certificates(mut self, certs: Vec<Vec<u8>>) -> Self {
303        self.peer_certificates = certs;
304        self
305    }
306
307    /// The negotiated ALPN protocol as a UTF-8 string, if it is valid UTF-8.
308    pub fn alpn_protocol_str(&self) -> Option<&str> {
309        self.alpn_protocol
310            .as_ref()
311            .and_then(|p| std::str::from_utf8(p).ok())
312    }
313}
314
315impl Default for ConnectionInfo {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
321// ── ConnectionInfo Builder ───────────────────────────────────────────────────
322
323/// Fluent builder for [`ConnectionInfo`].
324///
325/// Provides a separate builder type for constructing `ConnectionInfo` using
326/// snake_case setter methods (`version()`, `cipher_suite()`, etc.) rather than
327/// the `with_*` methods on `ConnectionInfo` itself.
328///
329/// # Example
330/// ```
331/// use oxitls_core::{CipherSuite, ConnectionInfoBuilder, TlsVersion};
332///
333/// let info = ConnectionInfoBuilder::new()
334///     .version(TlsVersion::Tls13)
335///     .cipher_suite(CipherSuite::Tls13Aes256GcmSha384)
336///     .alpn_protocol(b"h2".to_vec())
337///     .sni("example.com".to_string())
338///     .build();
339///
340/// assert_eq!(info.version, Some(TlsVersion::Tls13));
341/// ```
342#[derive(Debug, Default)]
343pub struct ConnectionInfoBuilder {
344    inner: ConnectionInfo,
345}
346
347impl ConnectionInfoBuilder {
348    /// Create a new builder with all fields set to `None`.
349    pub fn new() -> Self {
350        Self::default()
351    }
352
353    /// Set the negotiated TLS version.
354    pub fn version(mut self, version: TlsVersion) -> Self {
355        self.inner.version = Some(version);
356        self
357    }
358
359    /// Set the negotiated cipher suite.
360    pub fn cipher_suite(mut self, suite: CipherSuite) -> Self {
361        self.inner.cipher_suite = Some(suite);
362        self
363    }
364
365    /// Set the negotiated ALPN protocol bytes.
366    pub fn alpn_protocol(mut self, proto: Vec<u8>) -> Self {
367        self.inner.alpn_protocol = Some(proto);
368        self
369    }
370
371    /// Set the SNI server name.
372    pub fn sni(mut self, sni: String) -> Self {
373        self.inner.sni = Some(sni);
374        self
375    }
376
377    /// Set the peer certificate chain (DER-encoded, leaf first).
378    pub fn peer_certificates(mut self, certs: Vec<Vec<u8>>) -> Self {
379        self.inner.peer_certificates = certs;
380        self
381    }
382
383    /// Consume the builder and produce a [`ConnectionInfo`].
384    pub fn build(self) -> ConnectionInfo {
385        self.inner
386    }
387}
388
389// ── TLS Error ────────────────────────────────────────────────────────────────
390
/// Errors that can occur during TLS operations.
///
/// The type derives `Clone` and `PartialEq`. Because of this, the `Io` variant
/// stores only an [`io::ErrorKind`] rather than a full [`io::Error`]. More
/// variants may be added in future releases (`#[non_exhaustive]`).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum TlsError {
    /// An I/O error occurred, identified by its kind (the message is not kept).
    Io(io::ErrorKind),
    /// A TLS handshake error.
    Handshake(String),
    /// An invalid or unacceptable certificate.
    BadCert(String),
    /// The TLS configuration is invalid.
    InvalidConfig(String),
    /// A certificate has been revoked (CRL or OCSP).
    CertRevoked(String),
    /// A certificate is invalid (e.g. bad signature, malformed DER, expired).
    CertInvalid(String),
    /// The remote peer violated the TLS protocol.
    ProtocolViolation(String),
    /// A TLS alert was received from the peer.
    AlertReceived(AlertDescription),
    /// Any other TLS error.
    Other(String),
}
414
415impl TlsError {
416    /// Returns `true` if this is a handshake error.
417    pub fn is_handshake(&self) -> bool {
418        matches!(self, TlsError::Handshake(_))
419    }
420
421    /// Returns `true` if this is an I/O error.
422    pub fn is_io(&self) -> bool {
423        matches!(self, TlsError::Io(_))
424    }
425
426    /// Returns `true` if this is a certificate-related error.
427    pub fn is_cert(&self) -> bool {
428        matches!(
429            self,
430            TlsError::BadCert(_) | TlsError::CertRevoked(_) | TlsError::CertInvalid(_)
431        )
432    }
433
434    /// Returns `true` if this is a configuration error.
435    pub fn is_config(&self) -> bool {
436        matches!(self, TlsError::InvalidConfig(_))
437    }
438
439    /// Returns `true` if this is a protocol-violation error.
440    pub fn is_protocol_violation(&self) -> bool {
441        matches!(self, TlsError::ProtocolViolation(_))
442    }
443}
444
445impl fmt::Display for TlsError {
446    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
447        match self {
448            TlsError::Io(k) => write!(f, "I/O error: {k:?}"),
449            TlsError::Handshake(s) => write!(f, "handshake error: {s}"),
450            TlsError::BadCert(s) => write!(f, "bad certificate: {s}"),
451            TlsError::InvalidConfig(s) => write!(f, "invalid config: {s}"),
452            TlsError::CertRevoked(s) => write!(f, "certificate revoked: {s}"),
453            TlsError::CertInvalid(s) => write!(f, "invalid certificate: {s}"),
454            TlsError::ProtocolViolation(s) => write!(f, "protocol violation: {s}"),
455            TlsError::AlertReceived(d) => write!(f, "TLS alert received: {d}"),
456            TlsError::Other(s) => write!(f, "TLS error: {s}"),
457        }
458    }
459}
460
461impl std::error::Error for TlsError {}
462
463impl From<io::Error> for TlsError {
464    fn from(e: io::Error) -> Self {
465        TlsError::Io(e.kind())
466    }
467}
468
469impl From<TlsError> for io::Error {
470    fn from(e: TlsError) -> Self {
471        match e {
472            TlsError::Io(kind) => io::Error::new(kind, "TLS I/O error"),
473            TlsError::Handshake(s) => io::Error::new(io::ErrorKind::ConnectionAborted, s),
474            TlsError::BadCert(s) => io::Error::new(io::ErrorKind::InvalidData, s),
475            TlsError::InvalidConfig(s) => io::Error::new(io::ErrorKind::InvalidInput, s),
476            TlsError::CertRevoked(s) => io::Error::new(io::ErrorKind::PermissionDenied, s),
477            TlsError::CertInvalid(s) => io::Error::new(io::ErrorKind::InvalidData, s),
478            TlsError::ProtocolViolation(s) => io::Error::new(io::ErrorKind::InvalidData, s),
479            TlsError::AlertReceived(d) => {
480                io::Error::new(io::ErrorKind::ConnectionAborted, format!("TLS alert: {d}"))
481            }
482            TlsError::Other(s) => io::Error::other(s),
483        }
484    }
485}
486
487impl From<rustls::Error> for TlsError {
488    fn from(e: rustls::Error) -> Self {
489        match &e {
490            rustls::Error::NoCertificatesPresented => {
491                TlsError::CertInvalid("no certificates presented".to_string())
492            }
493            rustls::Error::UnsupportedNameType => {
494                TlsError::CertInvalid("unsupported name type".to_string())
495            }
496            rustls::Error::InvalidCertificate(reason) => {
497                TlsError::CertInvalid(format!("{reason:?}"))
498            }
499            rustls::Error::PeerIncompatible(reason) => {
500                TlsError::ProtocolViolation(format!("{reason:?}"))
501            }
502            rustls::Error::PeerMisbehaved(reason) => {
503                TlsError::ProtocolViolation(format!("{reason:?}"))
504            }
505            rustls::Error::AlertReceived(alert) => TlsError::Handshake(format!("alert: {alert:?}")),
506            rustls::Error::BadMaxFragmentSize => {
507                TlsError::InvalidConfig("bad max fragment size".to_string())
508            }
509            rustls::Error::General(s) => TlsError::Other(s.clone()),
510            _ => TlsError::Other(e.to_string()),
511        }
512    }
513}
514
515// ── TLS Stream ───────────────────────────────────────────────────────────────
516
517/// A boxed async stream that can be read from and written to.
518pub type TlsStream = Box<dyn TlsStreamTrait>;
519
520/// Trait alias for an async TLS stream.
521pub trait TlsStreamTrait: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
522impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin> TlsStreamTrait for T {}
523
/// Types that can establish outbound TLS connections.
///
/// Implementations wrap a transport-layer stream in a TLS client handshake,
/// producing a [`TlsStream`] on success.  The trait is object-safe: it can be
/// used through `Box<dyn TlsConnector>` or `Arc<dyn TlsConnector>`.
///
/// Object safety comes from the signature: the transport arrives already
/// type-erased as a [`TlsStream`], and the handshake future is returned
/// boxed instead of through a generic or `async fn`.
pub trait TlsConnector: Send + Sync + 'static {
    /// Perform the TLS client handshake over `stream`, using `server_name`
    /// for SNI and certificate verification.
    ///
    /// The returned future is `Send` and may borrow `self` (its lifetime is
    /// tied to `&self`).
    ///
    /// # Errors
    ///
    /// Resolves to the wrapped [`TlsStream`] on success, or a [`TlsError`] on
    /// failure.
    fn connect(
        &self,
        stream: TlsStream,
        server_name: rustls::pki_types::ServerName<'static>,
    ) -> Pin<Box<dyn Future<Output = Result<TlsStream, TlsError>> + Send + '_>>;
}
541
/// Types that can accept inbound TLS connections.
///
/// Implementations wrap a transport-layer stream in a TLS server handshake,
/// producing a [`TlsStream`] on success.  The trait is object-safe: it can be
/// used through `Box<dyn TlsAcceptor>` or `Arc<dyn TlsAcceptor>`.
///
/// As with [`TlsConnector`], object safety comes from taking a type-erased
/// [`TlsStream`] and returning a boxed future.
pub trait TlsAcceptor: Send + Sync + 'static {
    /// Perform the TLS server handshake over `stream`.
    ///
    /// The returned future is `Send` and may borrow `self` (its lifetime is
    /// tied to `&self`).
    ///
    /// # Errors
    ///
    /// Resolves to the wrapped [`TlsStream`] on success, or a [`TlsError`] on
    /// failure.
    fn accept(
        &self,
        stream: TlsStream,
    ) -> Pin<Box<dyn Future<Output = Result<TlsStream, TlsError>> + Send + '_>>;
}
557
/// Trait for TLS streams that can expose post-handshake connection metadata.
///
/// Implementors may override [`Self::connection_info`] to return a reference to the
/// [`ConnectionInfo`] populated after the handshake completes. The default
/// implementation returns `None`, which is appropriate for stream wrappers that
/// do not have access to connection metadata (e.g. transparent proxies).
///
/// Because the only method has a default body, an empty `impl TlsStreamInfo for
/// T {}` is enough to satisfy a `TlsStreamInfo` bound.
///
/// # Example
/// ```
/// use oxitls_core::{ConnectionInfo, TlsStreamInfo};
///
/// struct MyStream {
///     info: ConnectionInfo,
/// }
///
/// impl TlsStreamInfo for MyStream {
///     fn connection_info(&self) -> Option<&ConnectionInfo> {
///         Some(&self.info)
///     }
/// }
/// ```
pub trait TlsStreamInfo {
    /// Return the [`ConnectionInfo`] for this stream, if available.
    ///
    /// Returns `None` until the TLS handshake has completed, or for streams
    /// that do not expose connection metadata. The default implementation
    /// always returns `None`.
    fn connection_info(&self) -> Option<&ConnectionInfo> {
        None
    }
}
588
589// ── Generic Transport GAT Traits ─────────────────────────────────────────────
590
/// Boxed, pinned future returned by [`GenericTlsConnector`] and
/// [`GenericTlsAcceptor`] method implementations.
///
/// It resolves to `Result<T, TlsError>` and is `Send`, so it can be moved
/// across threads (e.g. awaited inside a task on a multi-threaded runtime).
///
/// The lifetime `'a` is tied to `&'a self` so that implementations may borrow
/// `self` inside the async block.
#[cfg(feature = "generic-transport")]
pub type GenericTlsFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, TlsError>> + Send + 'a>>;
598
/// Types that can establish outbound TLS connections while preserving the
/// concrete underlying transport type.
///
/// Unlike [`TlsConnector`], which erases the transport to `Box<dyn
/// TlsStreamTrait>`, this trait uses a generic associated type (GAT)
/// `Stream<S>` so callers retain the concrete `S` through the TLS layer.
/// This avoids heap allocation of the transport itself: only the returned
/// `Future` is boxed.
///
/// # Usage
///
/// Implementations are used through `<C: GenericTlsConnector>` bounds, not
/// through `dyn GenericTlsConnector`. Both the GAT and the generic `connect`
/// method make the trait non-object-safe.
///
/// # Example
///
/// ```ignore
/// async fn connect_plain<C: GenericTlsConnector, S>(
///     connector: &C,
///     stream: S,
///     name: rustls::pki_types::ServerName<'static>,
/// ) -> Result<C::Stream<S>, TlsError>
/// where
///     S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
/// {
///     connector.connect(stream, name).await
/// }
/// ```
#[cfg(feature = "generic-transport")]
pub trait GenericTlsConnector: Send + Sync + 'static {
    /// The TLS-wrapped stream type.  Preserves the concrete transport `S`.
    ///
    /// It also implements [`TlsStreamInfo`], so callers can query it for
    /// [`ConnectionInfo`] via [`TlsStreamInfo::connection_info`].
    type Stream<S>: AsyncRead + AsyncWrite + Unpin + Send + TlsStreamInfo
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static;

    /// Perform the TLS client handshake over `stream`, using `server_name`
    /// for SNI and certificate verification.
    ///
    /// Returns the wrapped `Self::Stream<S>` on success, or a [`TlsError`] on
    /// failure.  The returned future is a [`GenericTlsFuture`]: boxed, `Send`,
    /// and allowed to borrow `self`. Because this method is generic over `S`,
    /// it cannot be called through a trait object.
    fn connect<S>(
        &self,
        stream: S,
        server_name: rustls::pki_types::ServerName<'static>,
    ) -> GenericTlsFuture<'_, Self::Stream<S>>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static;
}
648
/// Types that can accept inbound TLS connections while preserving the concrete
/// underlying transport type.
///
/// The mirror of [`GenericTlsConnector`] for the server side.  The GAT
/// `Stream<S>` avoids erasing the transport to a boxed trait object; only the
/// returned `Future` is boxed.
///
/// # Usage
///
/// Implementations are used through `<A: GenericTlsAcceptor>` bounds, not
/// through `dyn GenericTlsAcceptor`. Both the GAT and the generic `accept`
/// method make the trait non-object-safe.
#[cfg(feature = "generic-transport")]
pub trait GenericTlsAcceptor: Send + Sync + 'static {
    /// The TLS-wrapped stream type.  Preserves the concrete transport `S`.
    ///
    /// It also implements [`TlsStreamInfo`], so callers can query it for
    /// [`ConnectionInfo`] via [`TlsStreamInfo::connection_info`].
    type Stream<S>: AsyncRead + AsyncWrite + Unpin + Send + TlsStreamInfo
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static;

    /// Perform the TLS server handshake over `stream`.
    ///
    /// Returns the wrapped `Self::Stream<S>` on success, or a [`TlsError`] on
    /// failure.  The returned future is a [`GenericTlsFuture`]: boxed, `Send`,
    /// and allowed to borrow `self`. Because this method is generic over `S`,
    /// it cannot be called through a trait object.
    fn accept<S>(&self, stream: S) -> GenericTlsFuture<'_, Self::Stream<S>>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static;
}
676
677// ── Tests ────────────────────────────────────────────────────────────────────
678
#[cfg(test)]
mod tests {
    use super::*;

    // ── TlsVersion ──

    #[test]
    fn tls_version_display_roundtrip() {
        assert_eq!(TlsVersion::Tls12.to_string(), "TLS 1.2");
        assert_eq!(TlsVersion::Tls13.to_string(), "TLS 1.3");
        for &version in TlsVersion::ALL {
            let parsed: TlsVersion = version
                .to_string()
                .parse()
                .expect("Display output should parse");
            assert_eq!(parsed, version);
        }
    }

    #[test]
    fn tls_version_parse_variants() {
        let accepted = [
            ("TLS 1.2", TlsVersion::Tls12),
            ("tls1.2", TlsVersion::Tls12),
            ("TLSv1.2", TlsVersion::Tls12),
            ("1.2", TlsVersion::Tls12),
            ("TLS 1.3", TlsVersion::Tls13),
            ("tls1.3", TlsVersion::Tls13),
            ("TLSv1.3", TlsVersion::Tls13),
            ("1.3", TlsVersion::Tls13),
        ];
        for (input, expected) in accepted {
            assert_eq!(input.parse::<TlsVersion>().ok(), Some(expected), "{input}");
        }
        for input in ["TLS 1.0", "TLS 1.1", "SSLv3", ""] {
            let err = input
                .parse::<TlsVersion>()
                .expect_err("unsupported version should be rejected");
            assert!(matches!(err, TlsError::Other(_)), "{input}: {err}");
        }
    }

    #[test]
    fn tls_version_all_lists_every_version() {
        assert_eq!(
            TlsVersion::ALL,
            [TlsVersion::Tls12, TlsVersion::Tls13].as_slice()
        );
    }

    // ── CipherSuite ──

    #[test]
    fn cipher_suite_display_roundtrip() {
        for &suite in CipherSuite::ALL {
            let parsed: CipherSuite = suite.to_string().parse().expect("should parse");
            assert_eq!(parsed, suite);
        }
        let unknown: CipherSuite = CipherSuite::Unknown
            .to_string()
            .parse()
            .expect("UNKNOWN should parse");
        assert_eq!(unknown, CipherSuite::Unknown);
        assert!("TLS_RSA_WITH_RC4_128_MD5".parse::<CipherSuite>().is_err());
    }

    #[test]
    fn cipher_suite_iana_roundtrip() {
        for &suite in CipherSuite::ALL {
            assert_eq!(
                CipherSuite::from_iana(suite.iana_value()),
                Some(suite),
                "{suite}"
            );
        }
        assert_eq!(CipherSuite::Unknown.iana_value(), [0xFF, 0xFF]);
        assert_eq!(CipherSuite::from_iana(CipherSuite::Unknown.iana_value()), None);
        assert_eq!(CipherSuite::from_iana([0x00, 0x00]), None);
    }

    #[test]
    fn cipher_suite_all_is_complete_and_unique() {
        assert_eq!(CipherSuite::ALL.len(), 9);
        assert!(!CipherSuite::ALL.contains(&CipherSuite::Unknown));
        let codes: std::collections::HashSet<[u8; 2]> =
            CipherSuite::ALL.iter().map(|s| s.iana_value()).collect();
        assert_eq!(codes.len(), CipherSuite::ALL.len());
    }

    #[test]
    fn cipher_suite_version_classification() {
        for &suite in CipherSuite::ALL {
            assert_ne!(
                suite.is_tls12(),
                suite.is_tls13(),
                "{suite} must belong to exactly one version"
            );
            assert!(!suite.is_unknown());
            // The TLS 1.3 suites live in the 0x13 code-point block.
            assert_eq!(suite.is_tls13(), suite.iana_value()[0] == 0x13, "{suite}");
        }
        let unknown = CipherSuite::Unknown;
        assert!(unknown.is_unknown());
        assert!(!unknown.is_tls12() && !unknown.is_tls13());
    }

    // ── ConnectionInfo ──

    #[test]
    fn connection_info_builder() {
        let info = ConnectionInfo::new()
            .with_version(TlsVersion::Tls13)
            .with_cipher_suite(CipherSuite::Tls13Aes256GcmSha384)
            .with_alpn_protocol(b"h2".to_vec())
            .with_sni("example.com".to_string());

        assert_eq!(info.version, Some(TlsVersion::Tls13));
        assert_eq!(info.cipher_suite, Some(CipherSuite::Tls13Aes256GcmSha384));
        assert_eq!(info.alpn_protocol_str(), Some("h2"));
        assert_eq!(info.sni.as_deref(), Some("example.com"));
        assert!(info.peer_certificates.is_empty());
    }

    #[test]
    fn connection_info_builder_type() {
        let info = ConnectionInfoBuilder::new()
            .version(TlsVersion::Tls12)
            .cipher_suite(CipherSuite::Tls12EcdheRsaAes128GcmSha256)
            .alpn_protocol(b"http/1.1".to_vec())
            .sni("example.org".to_string())
            .peer_certificates(vec![vec![0x30, 0x82]])
            .build();

        assert_eq!(info.version, Some(TlsVersion::Tls12));
        assert_eq!(
            info.cipher_suite,
            Some(CipherSuite::Tls12EcdheRsaAes128GcmSha256)
        );
        assert_eq!(info.alpn_protocol_str(), Some("http/1.1"));
        assert_eq!(info.sni.as_deref(), Some("example.org"));
        assert_eq!(info.peer_certificates, vec![vec![0x30, 0x82]]);
    }

    #[test]
    fn connection_info_default() {
        let info = ConnectionInfo::default();
        assert_eq!(info.version, None);
        assert_eq!(info.cipher_suite, None);
        assert_eq!(info.alpn_protocol, None);
        assert_eq!(info.sni, None);
        assert!(info.peer_certificates.is_empty());
    }

    #[test]
    fn alpn_protocol_str_rejects_invalid_utf8() {
        let info = ConnectionInfo::new().with_alpn_protocol(vec![0xFF, 0xFE]);
        assert_eq!(info.alpn_protocol_str(), None);
        assert_eq!(ConnectionInfo::new().alpn_protocol_str(), None);
    }

    // ── TlsError ──

    #[test]
    fn tls_error_display_all_variants() {
        let cases = [
            (TlsError::Io(io::ErrorKind::BrokenPipe), "I/O error: BrokenPipe"),
            (TlsError::Handshake("test".into()), "handshake error: test"),
            (TlsError::BadCert("test".into()), "bad certificate: test"),
            (TlsError::InvalidConfig("test".into()), "invalid config: test"),
            (TlsError::CertRevoked("test".into()), "certificate revoked: test"),
            (TlsError::CertInvalid("test".into()), "invalid certificate: test"),
            (TlsError::ProtocolViolation("test".into()), "protocol violation: test"),
            (TlsError::Other("test".into()), "TLS error: test"),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
        }
    }

    #[test]
    fn tls_error_predicates() {
        assert!(TlsError::Handshake("x".into()).is_handshake());
        assert!(!TlsError::Handshake("x".into()).is_io());
        assert!(TlsError::Io(io::ErrorKind::Other).is_io());
        for cert_err in [
            TlsError::BadCert("x".into()),
            TlsError::CertRevoked("x".into()),
            TlsError::CertInvalid("x".into()),
        ] {
            assert!(cert_err.is_cert(), "{cert_err}");
        }
        assert!(!TlsError::Other("x".into()).is_cert());
        assert!(TlsError::InvalidConfig("x".into()).is_config());
        assert!(!TlsError::InvalidConfig("x".into()).is_cert());
        assert!(TlsError::ProtocolViolation("x".into()).is_protocol_violation());
        assert!(!TlsError::Handshake("x".into()).is_protocol_violation());
    }

    #[test]
    fn tls_error_from_io_error() {
        let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "refused");
        assert_eq!(
            TlsError::from(io_err),
            TlsError::Io(io::ErrorKind::ConnectionRefused)
        );
    }

    #[test]
    fn tls_error_into_io_error() {
        let cases = [
            (TlsError::Io(io::ErrorKind::BrokenPipe), io::ErrorKind::BrokenPipe),
            (TlsError::Handshake("hs".into()), io::ErrorKind::ConnectionAborted),
            (TlsError::BadCert("bc".into()), io::ErrorKind::InvalidData),
            (TlsError::InvalidConfig("ic".into()), io::ErrorKind::InvalidInput),
            (TlsError::CertRevoked("cr".into()), io::ErrorKind::PermissionDenied),
            (TlsError::CertInvalid("ci".into()), io::ErrorKind::InvalidData),
            (TlsError::ProtocolViolation("pv".into()), io::ErrorKind::InvalidData),
            (TlsError::Other("ot".into()), io::ErrorKind::Other),
        ];
        for (tls_err, expected) in cases {
            let io_err: io::Error = tls_err.into();
            assert_eq!(io_err.kind(), expected);
        }
    }

    #[test]
    fn tls_error_into_io_error_keeps_detail() {
        let cases = [
            (TlsError::Handshake("hs-detail".into()), "hs-detail"),
            (TlsError::BadCert("bc-detail".into()), "bc-detail"),
            (TlsError::InvalidConfig("ic-detail".into()), "ic-detail"),
            (TlsError::CertRevoked("cr-detail".into()), "cr-detail"),
            (TlsError::CertInvalid("ci-detail".into()), "ci-detail"),
            (TlsError::ProtocolViolation("pv-detail".into()), "pv-detail"),
            (TlsError::Other("ot-detail".into()), "ot-detail"),
        ];
        for (tls_err, detail) in cases {
            let io_err: io::Error = tls_err.into();
            assert!(io_err.to_string().contains(detail), "{io_err}");
        }
    }

    #[test]
    fn tls_error_from_rustls_error() {
        assert_eq!(
            TlsError::from(rustls::Error::General("boom".to_string())),
            TlsError::Other("boom".to_string())
        );
        assert!(TlsError::from(rustls::Error::NoCertificatesPresented).is_cert());
        assert!(TlsError::from(rustls::Error::BadMaxFragmentSize).is_config());
    }
}