Skip to main content

corevpn_cli/
client.rs

1//! VPN Client Connection Logic
2//!
3//! Implements the full OpenVPN-compatible client handshake and data channel,
4//! including tls-auth, TLS 1.3, KeyMethodV2 exchange, and TUN forwarding.
5
6use std::process::Command;
7
8use anyhow::{Context, Result, bail};
9use tokio::net::UdpSocket;
10use tracing::{debug, info, warn, error};
11
12use corevpn_crypto::{CipherSuite, HmacAuth, KeyMaterial};
13use corevpn_protocol::{
14    KeyMethodV2, ProcessedPacket, ProtocolSession, ProtocolState,
15    PushReply, TlsClientHandler, create_client_config, load_certs_from_pem, load_key_from_pem,
16};
17
18use crate::ovpn::OvpnConfig;
19
/// Lifecycle notifications produced while a VPN connection is being brought up.
///
/// A sender for these events can be handed to [`VpnClient::connect_with_info`];
/// the NetworkManager plugin uses it to learn the pushed network configuration
/// (address, routes, DNS) before packet forwarding begins.
#[derive(Clone, Debug)]
pub enum ConnectionEvent {
    /// Network configuration delivered by the server in its PUSH_REPLY.
    PushReply {
        /// Tunnel address and netmask assigned to this client, when pushed.
        ifconfig: Option<(String, String)>,
        /// Pushed routes as `(network, netmask)` pairs.
        routes: Vec<(String, String)>,
        /// Pushed DNS server addresses.
        dns: Vec<String>,
        /// Pushed route gateway, when present.
        gateway: Option<String>,
        /// `true` when the server asked for redirect-gateway (send all traffic through the tunnel).
        redirect_gateway: bool,
    },
    /// The TUN device exists and packet forwarding is about to start.
    Connected {
        /// Name of the TUN device.
        tun_name: String,
    },
    /// A connection failure, described as text.
    Error(String),
}
47
48/// VPN client that manages the connection lifecycle
49pub struct VpnClient {
50    config: OvpnConfig,
51}
52
53impl VpnClient {
54    /// Create a new VPN client from parsed .ovpn config
55    pub fn new(config: OvpnConfig) -> Self {
56        Self { config }
57    }
58
59    /// Connect to the VPN server and run the tunnel.
60    pub async fn connect(&self) -> Result<()> {
61        self.connect_with_info(None).await
62    }
63
64    /// Connect to the VPN server, sending lifecycle events through the optional channel.
65    ///
66    /// This is used by the NetworkManager plugin to receive push reply data
67    /// (IP, routes, DNS) before the data plane starts, so it can report
68    /// the configuration back to NetworkManager.
69    pub async fn connect_with_info(
70        &self,
71        event_tx: Option<tokio::sync::mpsc::UnboundedSender<ConnectionEvent>>,
72    ) -> Result<()> {
73        info!("Connecting to {} via {}...", self.config.remote, self.config.protocol);
74
75        // Determine cipher suite
76        let cipher_suite = match self.config.cipher.to_uppercase().as_str() {
77            "CHACHA20-POLY1305" => CipherSuite::ChaCha20Poly1305,
78            "AES-256-GCM" => CipherSuite::Aes256Gcm,
79            other => bail!("Unsupported cipher: {}", other),
80        };
81
82        // Bind UDP socket
83        let socket = UdpSocket::bind("0.0.0.0:0").await
84            .context("Failed to bind UDP socket")?;
85        socket.connect(self.config.remote).await
86            .context("Failed to connect UDP socket")?;
87        info!("Bound local socket to {}", socket.local_addr()?);
88
89        // Create protocol session (client side)
90        let mut session = ProtocolSession::new_client(cipher_suite);
91
92        // Set up tls-auth if configured
93        if let Some(ref ta_key_bytes) = self.config.tls_auth_key {
94            let ta_key: [u8; 256] = ta_key_bytes[..256].try_into()
95                .context("tls-auth key must be 256 bytes")?;
96            let key_dir = self.config.key_direction;
97            let hmac_auth = HmacAuth::from_ta_key(&ta_key, false, key_dir)
98                .map_err(|e| anyhow::anyhow!("Failed to create HMAC auth: {}", e))?;
99            session.set_tls_auth(hmac_auth);
100            info!("tls-auth enabled (key-direction: {:?})", key_dir);
101        }
102
103        // Phase 1: Send hard reset to server
104        info!("Sending HARD_RESET_CLIENT_V2...");
105        let hard_reset = session.create_hard_reset_client()
106            .map_err(|e| anyhow::anyhow!("Failed to create hard reset: {}", e))?;
107        socket.send(&hard_reset).await?;
108
109        // Phase 2: Wait for server's hard reset response
110        let mut buf = vec![0u8; 4096];
111        let _server_response = tokio::time::timeout(
112            std::time::Duration::from_secs(10),
113            self.receive_until(&socket, &mut session, &mut buf, |pkt| {
114                matches!(pkt, ProcessedPacket::HardResetAck)
115            }),
116        )
117        .await
118        .context("Timeout waiting for server hard reset response")??;
119        info!("Received server hard reset response");
120
121        // Send ACK for server's hard reset
122        if let Some(ack) = session.create_ack_packet() {
123            socket.send(&ack).await?;
124            debug!("Sent ACK for server hard reset");
125        }
126
127        // Phase 3: TLS handshake
128        info!("Starting TLS handshake...");
129        let tls = self.setup_tls_client()?;
130        let mut tls = tls;
131
132        // Get initial TLS ClientHello
133        let client_hello = tls.get_outgoing()
134            .map_err(|e| anyhow::anyhow!("Failed to get ClientHello: {}", e))?
135            .context("No ClientHello data")?;
136
137        // Send ClientHello via control channel
138        let ctrl_packets = session.create_control_packets(client_hello)
139            .map_err(|e| anyhow::anyhow!("Failed to create control packets: {}", e))?;
140        for pkt in &ctrl_packets {
141            socket.send(pkt).await?;
142        }
143        debug!("Sent ClientHello ({} control packets)", ctrl_packets.len());
144
145        // Phase 4: Prepare KeyMethodV2 (generated before handshake loop so it
146        // can be sent together with the TLS Finished message)
147        let pre_master: [u8; 48] = corevpn_crypto::random_bytes();
148        let client_random1: [u8; 32] = corevpn_crypto::random_bytes();
149        let client_random2: [u8; 32] = corevpn_crypto::random_bytes();
150
151        let client_km = KeyMethodV2 {
152            pre_master,
153            random1: client_random1,
154            random2: client_random2,
155            options: format!(
156                "V4,dev-type tun,link-mtu 1560,tun-mtu 1500,proto UDPv4,cipher {},auth [null-digest],keysize 256,key-method 2,tls-client",
157                self.config.cipher
158            ),
159            username: None,
160            password: None,
161            peer_info: Some(format!(
162                "IV_VER=corevpn-0.4.0\nIV_PLAT=linux\nIV_NCP=2\nIV_TCPNL=1\nIV_PROTO=30\nIV_CIPHERS=CHACHA20-POLY1305:AES-256-GCM:AES-128-GCM\n"
163            )),
164        };
165
166        let km_bytes = client_km.encode(false); // false = client
167
168        // TLS handshake loop - exchange TLS records until handshake completes
169        let mut handshake_complete = false;
170        let mut handshake_attempts = 0;
171        const MAX_HANDSHAKE_ATTEMPTS: usize = 50;
172
173        while !handshake_complete {
174            handshake_attempts += 1;
175            if handshake_attempts > MAX_HANDSHAKE_ATTEMPTS {
176                bail!("TLS handshake failed: too many iterations");
177            }
178
179            // Receive data from server
180            let n = tokio::time::timeout(
181                std::time::Duration::from_secs(10),
182                socket.recv(&mut buf),
183            )
184            .await
185            .context("Timeout during TLS handshake")?
186            .context("Failed to receive during TLS handshake")?;
187
188            let result = session.process_packet(&buf[..n])
189                .map_err(|e| anyhow::anyhow!("Failed to process packet during handshake: {}", e))?;
190
191            match result {
192                ProcessedPacket::TlsData(records) => {
193                    // Feed TLS records to the TLS handler
194                    tls.process_tls_records(records)
195                        .map_err(|e| anyhow::anyhow!("TLS processing failed: {}", e))?;
196
197                    // Send any pending ACKs
198                    if session.should_send_ack() {
199                        if let Some(ack) = session.create_ack_packet() {
200                            socket.send(&ack).await?;
201                        }
202                    }
203
204                    if tls.is_handshake_complete() {
205                        // Handshake just completed. Write the KM2 plaintext
206                        // BEFORE flushing TLS so the Finished message and KM2
207                        // application data are sent together in the same batch.
208                        // This is critical: the corevpn server reads plaintext
209                        // immediately when the handshake completes, so the KM2
210                        // must be available in the same TLS processing round.
211                        info!("Performing key exchange...");
212                        session.set_state(ProtocolState::KeyExchange);
213                        debug!("Writing {} bytes of key_method_v2 to TLS", km_bytes.len());
214                        tls.write_plaintext(&km_bytes)
215                            .map_err(|e| anyhow::anyhow!("Failed to write key_method_v2: {}", e))?;
216
217                        // Now flush everything: TLS Finished + KM2 together
218                        self.flush_tls_to_socket(&mut tls, &mut session, &socket).await?;
219                        handshake_complete = true;
220                    } else {
221                        // Normal handshake: flush TLS response data
222                        while tls.wants_write() {
223                            if let Some(tls_out) = tls.get_outgoing()
224                                .map_err(|e| anyhow::anyhow!("TLS outgoing failed: {}", e))?
225                            {
226                                let ctrl_packets = session.create_control_packets(tls_out)
227                                    .map_err(|e| anyhow::anyhow!("Failed to create control packets: {}", e))?;
228                                for pkt in &ctrl_packets {
229                                    socket.send(pkt).await?;
230                                }
231                            } else {
232                                break;
233                            }
234                        }
235                    }
236                }
237                ProcessedPacket::None => {
238                    // ACK or other non-data packet, continue
239                }
240                other => {
241                    debug!("Unexpected packet during handshake: {:?}", other);
242                }
243            }
244
245            // Check retransmits
246            let retransmits = session.get_retransmits();
247            for pkt in retransmits {
248                socket.send(&pkt).await?;
249            }
250        }
251
252        info!(
253            "TLS handshake complete (cipher: {:?}, version: {:?})",
254            tls.cipher_suite(),
255            tls.protocol_version()
256        );
257
258        // Receive server's key_method_v2 and PUSH_REPLY from TLS stream
259        let mut plaintext_buf = vec![0u8; 8192];
260        let mut total_plaintext = Vec::new();
261        let mut server_km: Option<KeyMethodV2> = None;
262        let mut push_reply: Option<PushReply> = None;
263        let mut km_attempts = 0;
264        const MAX_KM_ATTEMPTS: usize = 100;
265
266        while push_reply.is_none() {
267            km_attempts += 1;
268            if km_attempts > MAX_KM_ATTEMPTS {
269                bail!("Timeout waiting for server key exchange and push reply");
270            }
271
272            // Try to read plaintext first (data may already be buffered)
273            let n = tls.read_plaintext(&mut plaintext_buf)
274                .map_err(|e| anyhow::anyhow!("TLS read failed: {}", e))?;
275            if n > 0 {
276                total_plaintext.extend_from_slice(&plaintext_buf[..n]);
277                debug!("Read {} bytes of TLS plaintext (total: {})", n, total_plaintext.len());
278
279                // Try to parse messages from plaintext
280                self.try_parse_server_messages(&mut total_plaintext, &mut server_km, &mut push_reply)?;
281
282                if push_reply.is_some() {
283                    break;
284                }
285                continue;
286            }
287
288            // Need more data from network (longer timeout to allow OAuth auth)
289            let recv_result = tokio::time::timeout(
290                std::time::Duration::from_secs(120),
291                socket.recv(&mut buf),
292            )
293            .await
294            .context("Timeout waiting for server key exchange (if OAuth, did you complete auth?)")?
295            .context("Failed to receive during key exchange")?;
296
297            debug!("Received {} bytes from server during KM exchange", recv_result);
298
299            let result = session.process_packet(&buf[..recv_result])
300                .map_err(|e| anyhow::anyhow!("Failed to process packet during key exchange: {}", e))?;
301
302            match result {
303                ProcessedPacket::TlsData(records) => {
304                    debug!("Got {} TLS records during KM exchange", records.len());
305                    tls.process_tls_records(records)
306                        .map_err(|e| anyhow::anyhow!("TLS processing failed: {}", e))?;
307
308                    // Send ACKs
309                    if session.should_send_ack() {
310                        if let Some(ack) = session.create_ack_packet() {
311                            socket.send(&ack).await?;
312                            debug!("Sent ACK during KM exchange");
313                        }
314                    }
315
316                    // Flush any TLS responses
317                    self.flush_tls_to_socket(&mut tls, &mut session, &socket).await?;
318                }
319                ProcessedPacket::None => {
320                    debug!("Got None (ACK) during KM exchange");
321                    // May need to send ACKs back
322                    if session.should_send_ack() {
323                        if let Some(ack) = session.create_ack_packet() {
324                            socket.send(&ack).await?;
325                        }
326                    }
327                }
328                other => {
329                    debug!("Unexpected packet during key exchange: {:?}", other);
330                }
331            }
332
333            // Check for retransmissions
334            let retransmits = session.get_retransmits();
335            for pkt in retransmits {
336                socket.send(&pkt).await?;
337                debug!("Retransmitted control packet during KM exchange");
338            }
339        }
340
341        let server_km = server_km.context("Never received server KeyMethodV2")?;
342        let push_reply = push_reply.unwrap();
343
344        // Notify listeners of push reply data (for NM plugin integration)
345        if let Some(ref tx) = event_tx {
346            let _ = tx.send(ConnectionEvent::PushReply {
347                ifconfig: push_reply.ifconfig.clone(),
348                routes: push_reply.routes.iter()
349                    .map(|r| (r.network.clone(), r.netmask.clone()))
350                    .collect(),
351                dns: push_reply.dns.clone(),
352                gateway: push_reply.route_gateway.clone(),
353                redirect_gateway: push_reply.redirect_gateway,
354            });
355        }
356
357        // Phase 5: Key derivation via OpenVPN PRF
358        info!("Deriving data channel keys...");
359        let client_sid = *session.local_session_id();
360        let server_sid = session.remote_session_id()
361            .copied()
362            .context("Missing remote session ID")?;
363
364        // Step 1: master_secret = PRF(pre_master, "OpenVPN master secret", client_r1 || server_r1, 48)
365        let mut seed1 = Vec::with_capacity(64);
366        seed1.extend_from_slice(&client_random1);
367        seed1.extend_from_slice(&server_km.random1);
368        let master_secret = corevpn_crypto::openvpn_prf(
369            &pre_master,
370            b"OpenVPN master secret",
371            &seed1,
372            48,
373        ).map_err(|e| anyhow::anyhow!("PRF master secret failed: {}", e))?;
374
375        debug!("PRF master_secret[..8]={:02x?}", &master_secret[..8]);
376
377        // Step 2: key_block = PRF(master_secret, "OpenVPN key expansion",
378        //                         client_r2 || server_r2 || client_sid || server_sid, 256)
379        let mut seed2 = Vec::with_capacity(64 + 16);
380        seed2.extend_from_slice(&client_random2);
381        seed2.extend_from_slice(&server_km.random2);
382        seed2.extend_from_slice(&client_sid);
383        seed2.extend_from_slice(&server_sid);
384        let key_block = corevpn_crypto::openvpn_prf(
385            &master_secret,
386            b"OpenVPN key expansion",
387            &seed2,
388            256,
389        ).map_err(|e| anyhow::anyhow!("PRF key expansion failed: {}", e))?;
390
391        let km = KeyMaterial::from_openvpn_key2_block(&key_block);
392
393        // Install keys (false = client side: encrypt with client key, decrypt with server key)
394        session.install_keys(&km, false);
395        info!("Installed data channel keys");
396
397        // Phase 6: Configure TUN device
398        let (ifconfig_ip, ifconfig_mask) = push_reply.ifconfig.as_ref()
399            .context("PUSH_REPLY missing ifconfig")?
400            .clone();
401
402        info!("VPN IP: {} / {}", ifconfig_ip, ifconfig_mask);
403        if let Some(ref gw) = push_reply.route_gateway {
404            info!("Gateway: {}", gw);
405        }
406        for dns in &push_reply.dns {
407            info!("DNS: {}", dns);
408        }
409
410        // Set state to established
411        session.set_state(ProtocolState::Established);
412        info!("VPN session established!");
413
414        // Create TUN device
415        let tun_dev = self.create_tun_device(&ifconfig_ip, &ifconfig_mask, &push_reply)?;
416
417        // Notify listeners that the TUN device is up
418        if let Some(ref tx) = event_tx {
419            let _ = tx.send(ConnectionEvent::Connected {
420                tun_name: "tun0".to_string(), // TUN crate doesn't expose the name easily
421            });
422        }
423
424        // Phase 7: Data plane - forward packets between TUN and UDP
425        info!("Starting data plane forwarding...");
426        self.run_data_plane(socket, session, tun_dev, &push_reply).await
427    }
428
429    /// Set up TLS client handler
430    fn setup_tls_client(&self) -> Result<TlsClientHandler> {
431        // Install ring as the crypto provider for rustls
432        let _ = rustls::crypto::ring::default_provider().install_default();
433
434        // Parse CA certificate
435        let ca_certs = load_certs_from_pem(&self.config.ca_pem)
436            .map_err(|e| anyhow::anyhow!("Failed to load CA cert: {}", e))?;
437
438        // Parse client certificate and key for mTLS
439        let client_certs = load_certs_from_pem(&self.config.cert_pem)
440            .map_err(|e| anyhow::anyhow!("Failed to load client cert: {}", e))?;
441        let client_key = load_key_from_pem(&self.config.key_pem)
442            .map_err(|e| anyhow::anyhow!("Failed to load client key: {}", e))?;
443
444        // Create TLS config
445        let tls_config = create_client_config(
446            ca_certs,
447            Some((client_certs, client_key)),
448        ).map_err(|e| anyhow::anyhow!("Failed to create TLS config: {}", e))?;
449
450        // Create TLS handler - use "corevpn" as server name (will be validated by our custom verifier)
451        let server_name = rustls::pki_types::ServerName::try_from("corevpn")
452            .map_err(|e| anyhow::anyhow!("Invalid server name: {}", e))?
453            .to_owned();
454
455        TlsClientHandler::new(tls_config, server_name)
456            .map_err(|e| anyhow::anyhow!("Failed to create TLS handler: {}", e))
457    }
458
459    /// Create and configure the TUN device
460    fn create_tun_device(
461        &self,
462        ip: &str,
463        mask: &str,
464        push_reply: &PushReply,
465    ) -> Result<tun::AsyncDevice> {
466        let mut config = tun::Configuration::default();
467        config.address(ip.parse::<std::net::Ipv4Addr>()?)
468            .netmask(mask.parse::<std::net::Ipv4Addr>()?)
469            .mtu(1500)
470            .up();
471
472        let dev = tun::create_as_async(&config)
473            .context("Failed to create TUN device. Are you running as root/with CAP_NET_ADMIN?")?;
474
475        info!("TUN device created");
476
477        // Configure routes
478        if push_reply.redirect_gateway {
479            if let Some(ref gw) = push_reply.route_gateway {
480                info!("Setting up full tunnel (redirect-gateway) via {}", gw);
481                // Note: Full gateway redirect requires careful route manipulation
482                // to avoid routing loops. For now, just add the VPN subnet route.
483            }
484        }
485
486        for route in &push_reply.routes {
487            info!("Adding route: {} {} via VPN", route.network, route.netmask);
488            // Routes will be configured via the OS
489            if let Err(e) = add_route(&route.network, &route.netmask, ip) {
490                warn!("Failed to add route {} {}: {}", route.network, route.netmask, e);
491            }
492        }
493
494        Ok(dev)
495    }
496
497    /// Flush TLS outgoing data to the UDP socket via control packets.
498    ///
499    /// Collects ALL pending TLS output into a single buffer before wrapping
500    /// in control packets. This is critical: the TLS Finished message and
501    /// any immediately-following application data (like KeyMethodV2) must
502    /// arrive in the same control packet batch so the server can process
503    /// them together in a single round.
504    async fn flush_tls_to_socket(
505        &self,
506        tls: &mut TlsClientHandler,
507        session: &mut ProtocolSession,
508        socket: &UdpSocket,
509    ) -> Result<()> {
510        // Collect all pending TLS output first
511        let mut all_tls_data = Vec::new();
512        while tls.wants_write() {
513            if let Some(tls_out) = tls.get_outgoing()
514                .map_err(|e| anyhow::anyhow!("TLS outgoing failed: {}", e))?
515            {
516                all_tls_data.extend_from_slice(&tls_out);
517            } else {
518                break;
519            }
520        }
521        if !all_tls_data.is_empty() {
522            debug!("Flushing {} bytes of TLS data to control channel", all_tls_data.len());
523            let tls_data = bytes::Bytes::from(all_tls_data);
524            let ctrl_packets = session.create_control_packets(tls_data)
525                .map_err(|e| anyhow::anyhow!("Failed to create control packets: {}", e))?;
526            for pkt in &ctrl_packets {
527                socket.send(pkt).await?;
528            }
529        }
530        Ok(())
531    }
532
533    /// Try to parse server KeyMethodV2 and PUSH_REPLY from accumulated plaintext
534    fn try_parse_server_messages(
535        &self,
536        total_plaintext: &mut Vec<u8>,
537        server_km: &mut Option<KeyMethodV2>,
538        push_reply: &mut Option<PushReply>,
539    ) -> Result<()> {
540        // Try to parse server's KeyMethodV2 if we haven't yet
541        if server_km.is_none() && total_plaintext.len() >= 71 {
542            match KeyMethodV2::parse_from_server(total_plaintext) {
543                Ok(km) => {
544                    debug!("Parsed server KeyMethodV2 (options: {})", km.options);
545
546                    // Determine how many bytes the KM consumed
547                    let km_size = calculate_km_size(total_plaintext)?;
548                    let remaining = total_plaintext.split_off(km_size);
549                    *total_plaintext = remaining;
550
551                    *server_km = Some(km);
552                    debug!("Remaining plaintext after KM: {} bytes", total_plaintext.len());
553                }
554                Err(e) => {
555                    debug!("Not enough data for KeyMethodV2 yet: {}", e);
556                }
557            }
558        }
559
560        // Try to parse control messages if we have server_km
561        if server_km.is_some() && !total_plaintext.is_empty() {
562            // Strip leading null bytes (separators between messages)
563            while total_plaintext.first() == Some(&0) {
564                total_plaintext.remove(0);
565            }
566            if total_plaintext.is_empty() {
567                return Ok(());
568            }
569
570            let msg = String::from_utf8_lossy(total_plaintext);
571            let msg_str = msg.trim_end_matches('\0');
572            debug!("Checking plaintext for control message: {:?}", msg_str);
573
574            if msg_str.starts_with("PUSH_REPLY") {
575                *push_reply = Some(PushReply::parse(msg_str)
576                    .map_err(|e| anyhow::anyhow!("Failed to parse PUSH_REPLY: {}", e))?);
577                info!("Received PUSH_REPLY");
578            } else if msg_str.contains("AUTH_PENDING") {
579                info!("Server requires authentication (pending)");
580                total_plaintext.clear();
581            } else if msg_str.starts_with("INFO_PRE,WEB_AUTH:") {
582                // WEB_AUTH format: INFO_PRE,WEB_AUTH:flags:url
583                // flags may be empty, so the URL follows the second ':'
584                if let Some(rest) = msg_str.strip_prefix("INFO_PRE,WEB_AUTH:") {
585                    // Skip flags (everything up to the next ':')
586                    let url = rest.split_once(':').map(|(_, u)| u).unwrap_or(rest).trim();
587                    info!("Server requires OAuth authentication.");
588                    info!("Please open this URL in your browser: {}", url);
589                    eprintln!("\n  Open this URL to authenticate:\n  {}\n", url);
590                }
591                total_plaintext.clear();
592            } else if msg_str.starts_with("INFO_PRE,OPEN_URL:") {
593                // Deprecated OPEN_URL format
594                let url = msg_str.strip_prefix("INFO_PRE,OPEN_URL:").unwrap_or("").trim();
595                info!("Server requires OAuth authentication.");
596                info!("Please open this URL in your browser: {}", url);
597                eprintln!("\n  Open this URL to authenticate:\n  {}\n", url);
598                total_plaintext.clear();
599            }
600        }
601
602        Ok(())
603    }
604
605    /// Receive packets until a condition is met
606    async fn receive_until(
607        &self,
608        socket: &UdpSocket,
609        session: &mut ProtocolSession,
610        buf: &mut [u8],
611        condition: impl Fn(&ProcessedPacket) -> bool,
612    ) -> Result<ProcessedPacket> {
613        loop {
614            let n = socket.recv(buf).await
615                .context("Failed to receive packet")?;
616
617            let result = session.process_packet(&buf[..n])
618                .map_err(|e| anyhow::anyhow!("Failed to process packet: {}", e))?;
619
620            if condition(&result) {
621                return Ok(result);
622            }
623        }
624    }
625
626    /// Main data plane forwarding loop
627    async fn run_data_plane(
628        &self,
629        socket: UdpSocket,
630        mut session: ProtocolSession,
631        tun_dev: tun::AsyncDevice,
632        push_reply: &PushReply,
633    ) -> Result<()> {
634        use tokio::io::{AsyncReadExt, AsyncWriteExt};
635
636        let (mut tun_reader, mut tun_writer) = tokio::io::split(tun_dev);
637
638        let mut udp_buf = vec![0u8; 4096];
639        let mut tun_buf = vec![0u8; 2048];
640
641        // Ping timer
642        let ping_interval = std::time::Duration::from_secs(push_reply.ping as u64);
643        let ping_restart = std::time::Duration::from_secs(push_reply.ping_restart as u64);
644        let mut last_recv = tokio::time::Instant::now();
645        let mut ping_timer = tokio::time::interval(ping_interval);
646        ping_timer.tick().await; // First tick is immediate
647
648        info!("Data plane active (ping: {}s, restart: {}s)", push_reply.ping, push_reply.ping_restart);
649
650        loop {
651            tokio::select! {
652                // Read from TUN -> encrypt -> send to server
653                result = tun_reader.read(&mut tun_buf) => {
654                    match result {
655                        Ok(0) => {
656                            info!("TUN device closed");
657                            break;
658                        }
659                        Ok(n) => {
660                            match session.encrypt_data(&tun_buf[..n]) {
661                                Ok(encrypted) => {
662                                    if let Err(e) = socket.send(&encrypted).await {
663                                        warn!("Failed to send encrypted data: {}", e);
664                                    }
665                                }
666                                Err(e) => {
667                                    warn!("Failed to encrypt data: {}", e);
668                                }
669                            }
670                        }
671                        Err(e) => {
672                            error!("TUN read error: {}", e);
673                            break;
674                        }
675                    }
676                }
677
678                // Read from server -> decrypt -> write to TUN
679                result = socket.recv(&mut udp_buf) => {
680                    match result {
681                        Ok(n) => {
682                            last_recv = tokio::time::Instant::now();
683
684                            match session.process_packet(&udp_buf[..n]) {
685                                Ok(ProcessedPacket::Data(decrypted)) => {
686                                    if let Err(e) = tun_writer.write_all(&decrypted).await {
687                                        warn!("Failed to write to TUN: {}", e);
688                                    }
689                                }
690                                Ok(ProcessedPacket::None) => {
691                                    // ACK or keepalive - send ACK if needed
692                                    if session.should_send_ack() {
693                                        if let Some(ack) = session.create_ack_packet() {
694                                            let _ = socket.send(&ack).await;
695                                        }
696                                    }
697                                }
698                                Ok(ProcessedPacket::TlsData(_)) => {
699                                    debug!("Received post-handshake TLS data (ignored)");
700                                    if session.should_send_ack() {
701                                        if let Some(ack) = session.create_ack_packet() {
702                                            let _ = socket.send(&ack).await;
703                                        }
704                                    }
705                                }
706                                Ok(other) => {
707                                    debug!("Unexpected packet in data plane: {:?}", other);
708                                }
709                                Err(e) => {
710                                    debug!("Failed to process packet: {}", e);
711                                }
712                            }
713                        }
714                        Err(e) => {
715                            error!("UDP recv error: {}", e);
716                            break;
717                        }
718                    }
719                }
720
721                // Send OpenVPN ping (keepalive)
722                _ = ping_timer.tick() => {
723                    // Check if we've timed out
724                    if last_recv.elapsed() > ping_restart {
725                        error!("Connection timed out (no data for {}s)", ping_restart.as_secs());
726                        break;
727                    }
728
729                    // Send OpenVPN ping packet (0x2a 0x18 0x7b 0xf3 0x64 0x1e 0xb4 0xcb 0x07 0xed 0x2d 0x0a 0x98 0x1f 0xc7 0x48)
730                    // This is the well-known OpenVPN ping payload
731                    let ping_payload: [u8; 16] = [
732                        0x2a, 0x18, 0x7b, 0xf3, 0x64, 0x1e, 0xb4, 0xcb,
733                        0x07, 0xed, 0x2d, 0x0a, 0x98, 0x1f, 0xc7, 0x48,
734                    ];
735                    match session.encrypt_data(&ping_payload) {
736                        Ok(encrypted) => {
737                            if let Err(e) = socket.send(&encrypted).await {
738                                warn!("Failed to send ping: {}", e);
739                            }
740                        }
741                        Err(e) => {
742                            debug!("Failed to encrypt ping: {}", e);
743                        }
744                    }
745                }
746            }
747        }
748
749        info!("VPN connection closed");
750        Ok(())
751    }
752}
753
754/// Calculate the byte size of a server KeyMethodV2 message
755fn calculate_km_size(data: &[u8]) -> Result<usize> {
756    // Server KM2 format (no pre_master):
757    // 4 (literal zero) + 1 (key method) + 32 (random1) + 32 (random2) = 69
758    // Then length-prefixed strings: options, username, password
759    if data.len() < 71 {
760        bail!("KeyMethodV2 data too short");
761    }
762
763    let mut offset = 69; // After fixed header
764
765    // Skip length-prefixed strings (options, then optional username, password)
766    // The server may send empty username/password as length-prefixed empty strings
767    for _ in 0..3 {
768        if offset + 2 > data.len() {
769            break;
770        }
771        let str_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
772        if offset + 2 + str_len > data.len() {
773            break;
774        }
775        offset += 2 + str_len;
776    }
777
778    Ok(offset)
779}
780
781/// Add a route via the OS
782fn add_route(network: &str, netmask: &str, gateway: &str) -> Result<()> {
783    #[cfg(target_os = "linux")]
784    {
785        // Convert netmask to prefix length
786        let mask: std::net::Ipv4Addr = netmask.parse()?;
787        let mask_bits: u32 = u32::from(mask);
788        let prefix_len = mask_bits.count_ones();
789
790        let output = Command::new("ip")
791            .args(["route", "add", &format!("{}/{}", network, prefix_len), "via", gateway])
792            .output()
793            .context("Failed to execute ip route add")?;
794
795        if !output.status.success() {
796            let stderr = String::from_utf8_lossy(&output.stderr);
797            // Ignore "File exists" errors (route already exists)
798            if !stderr.contains("File exists") {
799                warn!("ip route add failed: {}", stderr.trim());
800            }
801        }
802    }
803
804    #[cfg(not(target_os = "linux"))]
805    {
806        warn!("Route configuration not implemented for this platform");
807    }
808
809    Ok(())
810}