Skip to main content

trojan_client/
connector.rs

1//! TLS connection establishment to the remote trojan server.
2
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use rustls::pki_types::ServerName;
8use tokio::net::TcpStream;
9use tokio_rustls::TlsConnector;
10use tokio_rustls::client::TlsStream;
11use tracing::debug;
12use trojan_config::TcpConfig;
13
14use crate::error::ClientError;
15
/// Shared client state for establishing outbound connections.
///
/// Bundles everything `ClientState::connect` needs to open a TLS session to the
/// remote trojan server: the target address, the TLS connector and SNI name,
/// TCP socket options, and the handshake timeout.
pub struct ClientState {
    /// SHA-224 hex hash of the password (56 bytes).
    // NOTE(review): 56 is the length of the hex *string* (28-byte digest). This field
    // is not read in this file; presumably the caller writes it into the trojan
    // request header — confirm.
    pub hash_hex: String,
    /// Remote trojan server address string (host:port).
    ///
    /// Resolved via `tokio::net::lookup_host` on every `connect` call.
    pub remote_addr: String,
    /// TLS connector.
    pub tls_connector: TlsConnector,
    /// TLS SNI server name.
    ///
    /// Cloned into each handshake performed by `connect`.
    pub sni: ServerName<'static>,
    /// TCP socket options.
    ///
    /// Applied to each freshly connected TCP stream before the TLS handshake.
    pub tcp_config: TcpConfig,
    /// TLS handshake timeout.
    ///
    /// Bounds only the TLS handshake; DNS resolution and the TCP connect are not
    /// covered by it.
    pub tls_handshake_timeout: Duration,
}
31
32impl ClientState {
33    /// Establish a TLS connection to the remote trojan server.
34    pub async fn connect(&self) -> Result<TlsStream<TcpStream>, ClientError> {
35        // DNS resolve
36        let addr: SocketAddr = tokio::net::lookup_host(&self.remote_addr)
37            .await?
38            .next()
39            .ok_or_else(|| ClientError::Resolve(self.remote_addr.clone()))?;
40
41        debug!(remote = %addr, "connecting to trojan server");
42
43        // TCP connect
44        let tcp = TcpStream::connect(addr).await?;
45        apply_tcp_options(&tcp, &self.tcp_config)?;
46
47        // TLS handshake with timeout
48        let tls = tokio::time::timeout(
49            self.tls_handshake_timeout,
50            self.tls_connector.connect(self.sni.clone(), tcp),
51        )
52        .await
53        .map_err(|_| {
54            std::io::Error::new(std::io::ErrorKind::TimedOut, "TLS handshake timed out")
55        })??;
56
57        Ok(tls)
58    }
59}
60
61/// Build TLS client config from client TLS settings.
62pub fn build_tls_config(
63    tls: &crate::config::ClientTlsConfig,
64) -> Result<rustls::ClientConfig, ClientError> {
65    let mut root_store = rustls::RootCertStore::empty();
66
67    if let Some(ca_path) = &tls.ca {
68        let ca_data = std::fs::read(ca_path)
69            .map_err(|e| ClientError::Config(format!("failed to read CA cert: {e}")))?;
70
71        let certs = rustls_pemfile::certs(&mut std::io::Cursor::new(&ca_data))
72            .collect::<Result<Vec<_>, _>>()
73            .map_err(|e| ClientError::Config(format!("failed to parse CA cert: {e}")))?;
74
75        for cert in certs {
76            root_store
77                .add(cert)
78                .map_err(|e| ClientError::Config(format!("failed to add CA cert: {e}")))?;
79        }
80    } else {
81        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
82    }
83
84    let mut config = if tls.skip_verify {
85        rustls::ClientConfig::builder()
86            .dangerous()
87            .with_custom_certificate_verifier(Arc::new(NoVerifier))
88            .with_no_client_auth()
89    } else {
90        rustls::ClientConfig::builder()
91            .with_root_certificates(root_store)
92            .with_no_client_auth()
93    };
94
95    config.alpn_protocols = tls.alpn.iter().map(|s| s.as_bytes().to_vec()).collect();
96
97    Ok(config)
98}
99
100/// Extract the SNI hostname from config or the remote address.
101pub fn resolve_sni(
102    tls: &crate::config::ClientTlsConfig,
103    remote: &str,
104) -> Result<ServerName<'static>, ClientError> {
105    let host = if let Some(sni) = &tls.sni {
106        sni.clone()
107    } else {
108        extract_host(remote)
109    };
110
111    ServerName::try_from(host)
112        .map_err(|e| ClientError::Config(format!("invalid SNI hostname: {e}")))
113}
114
/// Return the host part of a `host:port`-style address.
///
/// - `[v6addr]:port` or `[v6addr]` yields the text between the brackets.
/// - A string with exactly one `:` is treated as `host:port` and yields `host`.
/// - Anything else is returned unchanged, including bare IPv6 literals (several
///   colons), plain hostnames (no colon) and an unterminated `[`.
fn extract_host(remote: &str) -> String {
    let bracketed = remote
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
        .map(|(inner, _)| inner);
    if let Some(inner) = bracketed {
        return inner.to_string();
    }

    // Several colons means an unbracketed IPv6 literal, so there is no port to strip.
    let host = match remote.matches(':').count() {
        1 => remote.split_once(':').map_or(remote, |(name, _)| name),
        _ => remote,
    };
    host.to_string()
}
131
#[cfg(test)]
mod tests {
    use super::{extract_host, resolve_sni};
    use rustls::pki_types::ServerName;

    #[test]
    fn extract_host_parses_bracketed_ipv6() {
        assert_eq!(extract_host("[::1]:443"), "::1");
        assert_eq!(extract_host("[2001:db8::1]:8443"), "2001:db8::1");
    }

    #[test]
    fn extract_host_parses_hostname_and_port() {
        assert_eq!(extract_host("example.com:443"), "example.com");
        assert_eq!(extract_host("example.com"), "example.com");
    }

    /// Unbracketed IPv6 literals contain several colons and must not be split.
    #[test]
    fn extract_host_leaves_bare_ipv6_untouched() {
        assert_eq!(extract_host("::1"), "::1");
        assert_eq!(extract_host("2001:db8::1"), "2001:db8::1");
    }

    #[test]
    fn resolve_sni_accepts_ipv6_literal() {
        let tls = crate::config::ClientTlsConfig::default();
        let sni = resolve_sni(&tls, "[::1]:443");
        assert!(sni.is_ok());
    }

    /// An explicitly configured SNI wins over the host in the remote address.
    #[test]
    fn resolve_sni_prefers_configured_name() {
        let mut tls = crate::config::ClientTlsConfig::default();
        tls.sni = Some("sni.example.org".to_string());
        let sni = resolve_sni(&tls, "203.0.113.7:443").expect("configured SNI is valid");
        assert_eq!(sni, ServerName::try_from("sni.example.org").unwrap());
    }
}
155
156/// Apply TCP socket options.
157fn apply_tcp_options(stream: &TcpStream, config: &TcpConfig) -> Result<(), ClientError> {
158    stream.set_nodelay(config.no_delay)?;
159
160    if config.keepalive_secs > 0 {
161        let sock = socket2::SockRef::from(stream);
162        let keepalive =
163            socket2::TcpKeepalive::new().with_time(Duration::from_secs(config.keepalive_secs));
164        sock.set_tcp_keepalive(&keepalive)?;
165    }
166
167    Ok(())
168}
169
170/// Certificate verifier that accepts any certificate (for skip_verify mode).
171#[derive(Debug)]
172struct NoVerifier;
173
174impl rustls::client::danger::ServerCertVerifier for NoVerifier {
175    fn verify_server_cert(
176        &self,
177        _end_entity: &rustls::pki_types::CertificateDer<'_>,
178        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
179        _server_name: &ServerName<'_>,
180        _ocsp_response: &[u8],
181        _now: rustls::pki_types::UnixTime,
182    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
183        Ok(rustls::client::danger::ServerCertVerified::assertion())
184    }
185
186    fn verify_tls12_signature(
187        &self,
188        _message: &[u8],
189        _cert: &rustls::pki_types::CertificateDer<'_>,
190        _dss: &rustls::DigitallySignedStruct,
191    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
192        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
193    }
194
195    fn verify_tls13_signature(
196        &self,
197        _message: &[u8],
198        _cert: &rustls::pki_types::CertificateDer<'_>,
199        _dss: &rustls::DigitallySignedStruct,
200    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
201        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
202    }
203
204    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
205        rustls::crypto::CryptoProvider::get_default()
206            .map(|provider| provider.signature_verification_algorithms.supported_schemes())
207            .unwrap_or_default()
208    }
209}