Skip to main content

corevpn_protocol/
control.rs

1//! Control Channel Message Types
2
3use std::net::Ipv4Addr;
4
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7
8use crate::{ProtocolError, Result};
9
10/// Validate that a string is a valid IPv4 address
11fn validate_ipv4(s: &str) -> Result<()> {
12    s.parse::<Ipv4Addr>()
13        .map_err(|_| ProtocolError::InvalidPacket(format!("invalid IPv4 address: {}", s)))?;
14    Ok(())
15}
16
/// Control channel message types
///
/// One decoded payload carried on the control channel.
#[derive(Debug, Clone)]
pub enum ControlMessage {
    /// TLS data (wrapped in control channel)
    TlsData(Bytes),
    /// Push request from client (asks the server for its push reply)
    PushRequest,
    /// Push reply from server carrying the pushed VPN configuration
    PushReply(PushReply),
    /// Authentication data (username/password pair)
    Auth(AuthMessage),
    /// Info message (version, etc.)
    Info(String),
    /// Exit/shutdown
    Exit,
}
33
34/// Control packet for the reliable transport layer
35#[derive(Debug, Clone)]
36pub struct ControlPacket {
37    /// Packet ID for reliability
38    pub packet_id: u32,
39    /// Message content
40    pub message: ControlMessage,
41}
42
43impl ControlPacket {
44    /// Create a new control packet
45    pub fn new(packet_id: u32, message: ControlMessage) -> Self {
46        Self { packet_id, message }
47    }
48}
49
/// Push reply containing VPN configuration
///
/// The server's answer to a client push request. `PushReply::encode` and
/// `PushReply::parse` convert it to and from the comma-separated
/// `PUSH_REPLY,...` string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PushReply {
    /// Routes to push, each encoded as a `route` directive
    pub routes: Vec<PushRoute>,
    /// IPv4 address and netmask, encoded as `ifconfig <ip> <mask>`
    pub ifconfig: Option<(String, String)>,
    /// IPv6 address, encoded verbatim after `ifconfig-ipv6`
    pub ifconfig_ipv6: Option<String>,
    /// DNS servers, each encoded as `dhcp-option DNS <addr>`
    pub dns: Vec<String>,
    /// Search domains, each encoded as `dhcp-option DOMAIN <domain>`
    pub dns_search: Vec<String>,
    /// Redirect gateway (full tunnel); encoded as `redirect-gateway def1` when set
    pub redirect_gateway: bool,
    /// Route gateway (VPN gateway IP for redirect-gateway), encoded as `route-gateway <ip>`
    pub route_gateway: Option<String>,
    /// Topology type, encoded as `topology <name>`
    pub topology: Topology,
    /// Ping interval, encoded as `ping <n>` (presumably seconds, as in OpenVPN — confirm)
    pub ping: u32,
    /// Ping restart timeout, encoded as `ping-restart <n>` (presumably seconds — confirm)
    pub ping_restart: u32,
    /// Additional directives passed through verbatim (unrecognized ones end up here when parsing)
    pub options: Vec<String>,
}
76
77impl Default for PushReply {
78    fn default() -> Self {
79        Self {
80            routes: vec![],
81            ifconfig: None,
82            ifconfig_ipv6: None,
83            dns: vec![],
84            dns_search: vec![],
85            redirect_gateway: false,
86            route_gateway: None,
87            topology: Topology::Subnet,
88            ping: 10,
89            ping_restart: 60,
90            options: vec![],
91        }
92    }
93}
94
95impl PushReply {
96    /// Encode as OpenVPN push reply string
97    pub fn encode(&self) -> String {
98        let mut parts = vec!["PUSH_REPLY".to_string()];
99
100        // Topology
101        parts.push(format!("topology {}", self.topology.as_str()));
102
103        // ifconfig
104        if let Some((ip, mask)) = &self.ifconfig {
105            parts.push(format!("ifconfig {} {}", ip, mask));
106        }
107
108        // ifconfig-ipv6
109        if let Some(ipv6) = &self.ifconfig_ipv6 {
110            parts.push(format!("ifconfig-ipv6 {}", ipv6));
111        }
112
113        // Routes
114        for route in &self.routes {
115            parts.push(route.encode());
116        }
117
118        // Route gateway (must come before redirect-gateway)
119        if let Some(gw) = &self.route_gateway {
120            parts.push(format!("route-gateway {}", gw));
121        }
122
123        // Redirect gateway
124        if self.redirect_gateway {
125            parts.push("redirect-gateway def1".to_string());
126        }
127
128        // DNS
129        for dns in &self.dns {
130            parts.push(format!("dhcp-option DNS {}", dns));
131        }
132
133        // DNS search domains
134        for domain in &self.dns_search {
135            parts.push(format!("dhcp-option DOMAIN {}", domain));
136        }
137
138        // Ping settings
139        parts.push(format!("ping {}", self.ping));
140        parts.push(format!("ping-restart {}", self.ping_restart));
141
142        // Additional options
143        for opt in &self.options {
144            parts.push(opt.clone());
145        }
146
147        parts.join(",")
148    }
149
150    /// Parse from OpenVPN push reply string
151    pub fn parse(s: &str) -> Result<Self> {
152        let mut reply = Self::default();
153
154        // Remove PUSH_REPLY prefix if present
155        let s = s.strip_prefix("PUSH_REPLY,").unwrap_or(s);
156
157        for part in s.split(',') {
158            let part = part.trim();
159            if part.is_empty() {
160                continue;
161            }
162
163            let mut tokens = part.split_whitespace();
164            match tokens.next() {
165                Some("topology") => {
166                    if let Some(topo) = tokens.next() {
167                        reply.topology = Topology::parse(topo);
168                    }
169                }
170                Some("ifconfig") => {
171                    let ip = tokens.next().unwrap_or("").to_string();
172                    let mask = tokens.next().unwrap_or("").to_string();
173                    reply.ifconfig = Some((ip, mask));
174                }
175                Some("ifconfig-ipv6") => {
176                    if let Some(ipv6) = tokens.next() {
177                        reply.ifconfig_ipv6 = Some(ipv6.to_string());
178                    }
179                }
180                Some("route") => {
181                    if let Ok(route) = PushRoute::parse(part) {
182                        reply.routes.push(route);
183                    }
184                }
185                Some("route-gateway") => {
186                    if let Some(gw) = tokens.next() {
187                        reply.route_gateway = Some(gw.to_string());
188                    }
189                }
190                Some("redirect-gateway") => {
191                    reply.redirect_gateway = true;
192                }
193                Some("dhcp-option") => {
194                    match tokens.next() {
195                        Some("DNS") => {
196                            if let Some(dns) = tokens.next() {
197                                reply.dns.push(dns.to_string());
198                            }
199                        }
200                        Some("DOMAIN") => {
201                            if let Some(domain) = tokens.next() {
202                                reply.dns_search.push(domain.to_string());
203                            }
204                        }
205                        _ => {}
206                    }
207                }
208                Some("ping") => {
209                    if let Some(Ok(p)) = tokens.next().map(|s| s.parse()) {
210                        reply.ping = p;
211                    }
212                }
213                Some("ping-restart") => {
214                    if let Some(Ok(p)) = tokens.next().map(|s| s.parse()) {
215                        reply.ping_restart = p;
216                    }
217                }
218                _ => {
219                    reply.options.push(part.to_string());
220                }
221            }
222        }
223
224        Ok(reply)
225    }
226}
227
/// Route to push to client
///
/// Encoded by `PushRoute::encode` as
/// `route <network> <netmask> <gateway|vpn_gateway> [metric]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PushRoute {
    /// Network address
    pub network: String,
    /// Netmask
    pub netmask: String,
    /// Gateway (optional, vpn_gateway used if not set)
    pub gateway: Option<String>,
    /// Metric (appended to the directive only when set)
    pub metric: Option<u32>,
}
240
241impl PushRoute {
242    /// Create a new route
243    pub fn new(network: &str, netmask: &str) -> Self {
244        Self {
245            network: network.to_string(),
246            netmask: netmask.to_string(),
247            gateway: None,
248            metric: None,
249        }
250    }
251
252    /// Encode as OpenVPN route directive
253    pub fn encode(&self) -> String {
254        let mut s = format!("route {} {}", self.network, self.netmask);
255        if let Some(gw) = &self.gateway {
256            s.push_str(&format!(" {}", gw));
257        } else {
258            s.push_str(" vpn_gateway");
259        }
260        if let Some(metric) = self.metric {
261            s.push_str(&format!(" {}", metric));
262        }
263        s
264    }
265
266    /// Parse from OpenVPN route directive
267    pub fn parse(s: &str) -> Result<Self> {
268        let mut tokens = s.split_whitespace();
269        tokens.next(); // skip "route"
270
271        let network_str = tokens
272            .next()
273            .ok_or_else(|| ProtocolError::InvalidPacket("missing network in route".into()))?;
274        
275        // Validate network address
276        validate_ipv4(network_str)?;
277        let network = network_str.to_string();
278
279        let netmask_str = tokens
280            .next()
281            .ok_or_else(|| ProtocolError::InvalidPacket("missing netmask in route".into()))?;
282        
283        // Validate netmask
284        validate_ipv4(netmask_str)?;
285        let netmask = netmask_str.to_string();
286
287        let gateway = tokens.next().and_then(|g| {
288            if g == "vpn_gateway" {
289                None
290            } else {
291                // Validate gateway IP address
292                validate_ipv4(g).ok().map(|_| g.to_string())
293            }
294        });
295
296        let metric = tokens.next().and_then(|m| {
297            m.parse::<u32>().ok().filter(|&m| m <= 9999) // Reasonable metric limit
298        });
299
300        Ok(Self {
301            network,
302            netmask,
303            gateway,
304            metric,
305        })
306    }
307}
308
/// Network topology type
///
/// Carried in the push reply as `topology <name>`; `Topology::as_str` and
/// `Topology::parse` define the wire names (`net30`, `p2p`, `subnet`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum Topology {
    /// Point-to-point (net30)
    Net30,
    /// Point-to-point (p2p)
    P2P,
    /// Subnet mode (recommended); also the `Default` variant
    #[default]
    Subnet,
}
320
321impl Topology {
322    /// Parse from string
323    pub fn parse(s: &str) -> Self {
324        match s.to_lowercase().as_str() {
325            "net30" => Topology::Net30,
326            "p2p" => Topology::P2P,
327            "subnet" => Topology::Subnet,
328            _ => Topology::Subnet,
329        }
330    }
331
332    /// Convert to string
333    pub fn as_str(&self) -> &'static str {
334        match self {
335            Topology::Net30 => "net30",
336            Topology::P2P => "p2p",
337            Topology::Subnet => "subnet",
338        }
339    }
340}
341
/// Authentication message from client: a username/password pair.
///
/// `Debug` is implemented by hand so the password never ends up in logs or
/// panic messages; it is always shown as `<redacted>`.
#[derive(Clone)]
pub struct AuthMessage {
    /// Username
    pub username: String,
    /// Password (secret; redacted from `Debug` output)
    pub password: String,
}

impl std::fmt::Debug for AuthMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthMessage")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
350
351impl AuthMessage {
352    /// Maximum username length
353    const MAX_USERNAME_LEN: usize = 256;
354    /// Maximum password length
355    const MAX_PASSWORD_LEN: usize = 1024;
356
357    /// Parse from OpenVPN auth data
358    pub fn parse(data: &[u8]) -> Result<Self> {
359        // Format: username\0password\0
360        // Security: Limit total input size to prevent DoS
361        if data.len() > Self::MAX_USERNAME_LEN + Self::MAX_PASSWORD_LEN + 2 {
362            return Err(ProtocolError::InvalidPacket("auth data too long".into()));
363        }
364
365        let s = std::str::from_utf8(data)
366            .map_err(|_| ProtocolError::InvalidPacket("invalid UTF-8 in auth".into()))?;
367
368        let parts: Vec<&str> = s.split('\0').collect();
369        if parts.len() < 2 {
370            return Err(ProtocolError::InvalidPacket("missing auth fields".into()));
371        }
372
373        let username = parts[0];
374        let password = parts[1];
375
376        // Validate lengths
377        if username.len() > Self::MAX_USERNAME_LEN {
378            return Err(ProtocolError::InvalidPacket(
379                format!("username too long (max {} bytes)", Self::MAX_USERNAME_LEN).into(),
380            ));
381        }
382        if password.len() > Self::MAX_PASSWORD_LEN {
383            return Err(ProtocolError::InvalidPacket(
384                format!("password too long (max {} bytes)", Self::MAX_PASSWORD_LEN).into(),
385            ));
386        }
387
388        Ok(Self {
389            username: username.to_string(),
390            password: password.to_string(),
391        })
392    }
393
394    /// Encode to OpenVPN auth format
395    pub fn encode(&self) -> Vec<u8> {
396        let mut data = Vec::new();
397        data.extend_from_slice(self.username.as_bytes());
398        data.push(0);
399        data.extend_from_slice(self.password.as_bytes());
400        data.push(0);
401        data
402    }
403}
404
/// Key method v2 data (exchanged during TLS handshake)
///
/// `Debug` is implemented by hand so that secret material (the pre-master
/// secret and the password) is printed as `<redacted>` instead of leaking
/// into logs or panic messages.
#[derive(Clone)]
pub struct KeyMethodV2 {
    /// Pre-master secret (48 bytes); secret, redacted from `Debug` output
    pub pre_master: [u8; 48],
    /// Random data 1 (32 bytes) - used as EKM context and PRF seed
    pub random1: [u8; 32],
    /// Random data 2 (32 bytes) - used as additional PRF seed
    pub random2: [u8; 32],
    /// Options string
    pub options: String,
    /// Username (if using auth)
    pub username: Option<String>,
    /// Password (if using auth); secret, redacted from `Debug` output
    pub password: Option<String>,
    /// Peer info
    pub peer_info: Option<String>,
}

impl std::fmt::Debug for KeyMethodV2 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyMethodV2")
            .field("pre_master", &"<redacted>")
            .field("random1", &self.random1)
            .field("random2", &self.random2)
            .field("options", &self.options)
            .field("username", &self.username)
            // Keep Some/None visible (useful for debugging) but hide the value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("peer_info", &self.peer_info)
            .finish()
    }
}
423
424impl KeyMethodV2 {
425    /// Parse key method v2 data from bytes (received from TLS plaintext)
426    ///
427    /// Format (OpenVPN key_source + metadata):
428    /// - 4 bytes: literal 0
429    /// - 1 byte: key method (must be 2)
430    /// - 48 bytes: pre-master secret
431    /// - 32 bytes: random1
432    /// - 32 bytes: random2
433    /// - 2 bytes + N bytes: options string (length-prefixed, null-terminated)
434    /// - 2 bytes + N bytes: username (length-prefixed, optional)
435    /// - 2 bytes + N bytes: password (length-prefixed, optional)
436    /// - 2 bytes + N bytes: peer_info (length-prefixed, optional)
437    pub fn parse(data: &[u8]) -> Result<Self> {
438        // Minimum: 4 + 1 + 48 + 32 + 32 + 2 = 119 bytes
439        if data.len() < 119 {
440            return Err(ProtocolError::PacketTooShort {
441                expected: 119,
442                got: data.len(),
443            });
444        }
445
446        let mut offset = 0;
447
448        // Skip 4 bytes literal zero
449        offset += 4;
450
451        // Key method byte (must be 2)
452        let key_method = data[offset];
453        offset += 1;
454        if key_method != 2 {
455            return Err(ProtocolError::InvalidPacket(
456                format!("unsupported key method: {}", key_method),
457            ));
458        }
459
460        // Pre-master secret (48 bytes)
461        let mut pre_master = [0u8; 48];
462        pre_master.copy_from_slice(&data[offset..offset + 48]);
463        offset += 48;
464
465        // Random1 (32 bytes)
466        let mut random1 = [0u8; 32];
467        random1.copy_from_slice(&data[offset..offset + 32]);
468        offset += 32;
469
470        // Random2 (32 bytes)
471        let mut random2 = [0u8; 32];
472        random2.copy_from_slice(&data[offset..offset + 32]);
473        offset += 32;
474
475        // Options string (length-prefixed)
476        let options = Self::read_length_prefixed_string(data, &mut offset)?;
477
478        // Username (optional, length-prefixed)
479        let username = if offset + 2 <= data.len() {
480            let s = Self::read_length_prefixed_string(data, &mut offset)?;
481            if s.is_empty() { None } else { Some(s) }
482        } else {
483            None
484        };
485
486        // Password (optional, length-prefixed)
487        let password = if offset + 2 <= data.len() {
488            let s = Self::read_length_prefixed_string(data, &mut offset)?;
489            if s.is_empty() { None } else { Some(s) }
490        } else {
491            None
492        };
493
494        // Peer info (optional, length-prefixed)
495        let peer_info = if offset + 2 <= data.len() {
496            let s = Self::read_length_prefixed_string(data, &mut offset)?;
497            if s.is_empty() { None } else { Some(s) }
498        } else {
499            None
500        };
501
502        Ok(Self {
503            pre_master,
504            random1,
505            random2,
506            options,
507            username,
508            password,
509            peer_info,
510        })
511    }
512
513    /// Parse key method v2 data from server (received from TLS plaintext)
514    ///
515    /// The server format omits the pre_master secret, only sending:
516    /// - 4 bytes: literal 0
517    /// - 1 byte: key method (must be 2)
518    /// - 32 bytes: random1
519    /// - 32 bytes: random2
520    /// - 2 bytes + N bytes: options string (length-prefixed, null-terminated)
521    pub fn parse_from_server(data: &[u8]) -> Result<Self> {
522        // Minimum: 4 + 1 + 32 + 32 + 2 = 71 bytes
523        if data.len() < 71 {
524            return Err(ProtocolError::PacketTooShort {
525                expected: 71,
526                got: data.len(),
527            });
528        }
529
530        let mut offset = 0;
531
532        // Skip 4 bytes literal zero
533        offset += 4;
534
535        // Key method byte (must be 2)
536        let key_method = data[offset];
537        offset += 1;
538        if key_method != 2 {
539            return Err(ProtocolError::InvalidPacket(
540                format!("unsupported key method: {}", key_method),
541            ));
542        }
543
544        // Server does NOT send pre_master - only random1 and random2
545
546        // Random1 (32 bytes)
547        let mut random1 = [0u8; 32];
548        random1.copy_from_slice(&data[offset..offset + 32]);
549        offset += 32;
550
551        // Random2 (32 bytes)
552        let mut random2 = [0u8; 32];
553        random2.copy_from_slice(&data[offset..offset + 32]);
554        offset += 32;
555
556        // Options string (length-prefixed)
557        let options = Self::read_length_prefixed_string(data, &mut offset)?;
558
559        Ok(Self {
560            pre_master: [0u8; 48], // Not sent by server
561            random1,
562            random2,
563            options,
564            username: None,
565            password: None,
566            peer_info: None,
567        })
568    }
569
570    /// Read a length-prefixed string from the buffer
571    fn read_length_prefixed_string(data: &[u8], offset: &mut usize) -> Result<String> {
572        if *offset + 2 > data.len() {
573            return Err(ProtocolError::PacketTooShort {
574                expected: *offset + 2,
575                got: data.len(),
576            });
577        }
578        let len = u16::from_be_bytes([data[*offset], data[*offset + 1]]) as usize;
579        *offset += 2;
580        if *offset + len > data.len() {
581            return Err(ProtocolError::PacketTooShort {
582                expected: *offset + len,
583                got: data.len(),
584            });
585        }
586        let s = std::str::from_utf8(&data[*offset..*offset + len])
587            .map_err(|_| ProtocolError::InvalidPacket("invalid UTF-8 in key method v2".into()))?;
588        *offset += len;
589        // Trim trailing null bytes (OpenVPN often null-terminates strings)
590        Ok(s.trim_end_matches('\0').to_string())
591    }
592
593    /// Write a null-terminated string in OpenVPN's wire format:
594    /// u16 length (including null terminator) + string bytes + null byte
595    fn write_string(buf: &mut Vec<u8>, s: &str) {
596        let len = s.len() + 1; // include null terminator
597        buf.extend_from_slice(&(len as u16).to_be_bytes());
598        buf.extend_from_slice(s.as_bytes());
599        buf.push(0); // null terminator
600    }
601
602    /// Encode to bytes (OpenVPN key_method_v2 wire format)
603    ///
604    /// When `is_server` is true (server writing its response), pre_master is
605    /// NOT included in the key source material -- only random1 and random2.
606    /// When `is_server` is false (client writing), pre_master IS included.
607    /// This matches the OpenVPN key_source2_randomize_write asymmetry.
608    pub fn encode(&self, is_server: bool) -> Vec<u8> {
609        let mut buf = Vec::new();
610
611        // Literal 0
612        buf.extend_from_slice(&[0u8; 4]);
613
614        // Key method (2)
615        buf.push(2);
616
617        // Key source material:
618        // Client writes: pre_master(48) + random1(32) + random2(32) = 112 bytes
619        // Server writes: random1(32) + random2(32) = 64 bytes
620        if !is_server {
621            buf.extend_from_slice(&self.pre_master);
622        }
623
624        // Random1 (32 bytes)
625        buf.extend_from_slice(&self.random1);
626
627        // Random2 (32 bytes)
628        buf.extend_from_slice(&self.random2);
629
630        // Options string (null-terminated, length includes null)
631        Self::write_string(&mut buf, &self.options);
632
633        // Username (optional, null-terminated)
634        if let Some(username) = &self.username {
635            Self::write_string(&mut buf, username);
636        } else {
637            Self::write_string(&mut buf, "");
638        }
639
640        // Password (optional, null-terminated)
641        if let Some(password) = &self.password {
642            Self::write_string(&mut buf, password);
643        } else {
644            Self::write_string(&mut buf, "");
645        }
646
647        // Peer info (optional, null-terminated)
648        if let Some(peer_info) = &self.peer_info {
649            Self::write_string(&mut buf, peer_info);
650        }
651
652        buf
653    }
654}
655
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_push_reply_roundtrip() {
        let mut reply = PushReply::default();
        reply.ifconfig = Some(("10.8.0.2".to_string(), "255.255.255.0".to_string()));
        reply.dns.push("1.1.1.1".to_string());
        reply.routes.push(PushRoute::new("192.168.1.0", "255.255.255.0"));
        reply.redirect_gateway = true;

        let encoded = reply.encode();
        let parsed = PushReply::parse(&encoded).unwrap();

        assert_eq!(parsed.ifconfig, reply.ifconfig);
        assert_eq!(parsed.dns, reply.dns);
        assert!(parsed.redirect_gateway);
        assert_eq!(parsed.topology, Topology::Subnet);
        assert_eq!(parsed.ping, reply.ping);
        assert_eq!(parsed.ping_restart, reply.ping_restart);
        // The pushed route must survive the round trip too.
        assert_eq!(parsed.routes.len(), 1);
        assert_eq!(parsed.routes[0].network, "192.168.1.0");
        assert_eq!(parsed.routes[0].netmask, "255.255.255.0");
        assert_eq!(parsed.routes[0].gateway, None);
        // Every directive emitted by `encode` is recognized by `parse`.
        assert!(parsed.options.is_empty());
    }

    #[test]
    fn test_push_route_encode_parse() {
        let mut route = PushRoute::new("10.0.0.0", "255.0.0.0");
        route.gateway = Some("10.8.0.1".to_string());
        route.metric = Some(100);

        let encoded = route.encode();
        assert_eq!(encoded, "route 10.0.0.0 255.0.0.0 10.8.0.1 100");

        let parsed = PushRoute::parse(&encoded).unwrap();
        assert_eq!(parsed.network, "10.0.0.0");
        assert_eq!(parsed.netmask, "255.0.0.0");
        assert_eq!(parsed.gateway.as_deref(), Some("10.8.0.1"));
        assert_eq!(parsed.metric, Some(100));

        assert!(PushRoute::parse("route 10.0.0.300 255.0.0.0").is_err());
    }

    #[test]
    fn test_auth_message() {
        let auth = AuthMessage {
            username: "user".to_string(),
            password: "pass".to_string(),
        };

        let encoded = auth.encode();
        let parsed = AuthMessage::parse(&encoded).unwrap();

        assert_eq!(parsed.username, "user");
        assert_eq!(parsed.password, "pass");
    }

    #[test]
    fn test_auth_message_missing_separator() {
        assert!(AuthMessage::parse(b"user").is_err());
    }

    #[test]
    fn test_key_method_v2_client_roundtrip() {
        let key = KeyMethodV2 {
            pre_master: [0x11; 48],
            random1: [0x22; 32],
            random2: [0x33; 32],
            options: "V4,dev-type tun".to_string(),
            username: Some("user".to_string()),
            password: Some("pass".to_string()),
            peer_info: Some("IV_VER=2.6".to_string()),
        };

        let parsed = KeyMethodV2::parse(&key.encode(false)).unwrap();

        assert_eq!(parsed.pre_master, key.pre_master);
        assert_eq!(parsed.random1, key.random1);
        assert_eq!(parsed.random2, key.random2);
        assert_eq!(parsed.options, key.options);
        assert_eq!(parsed.username, key.username);
        assert_eq!(parsed.password, key.password);
        assert_eq!(parsed.peer_info, key.peer_info);
    }

    #[test]
    fn test_key_method_v2_server_roundtrip() {
        let key = KeyMethodV2 {
            pre_master: [0x44; 48],
            random1: [0x55; 32],
            random2: [0x66; 32],
            options: "V4".to_string(),
            username: None,
            password: None,
            peer_info: None,
        };

        let parsed = KeyMethodV2::parse_from_server(&key.encode(true)).unwrap();

        // The server never sends pre_master, so it must come back zeroed.
        assert_eq!(parsed.pre_master, [0u8; 48]);
        assert_eq!(parsed.random1, key.random1);
        assert_eq!(parsed.random2, key.random2);
        assert_eq!(parsed.options, "V4");
    }

    #[test]
    fn test_key_method_v2_rejects_bad_key_method() {
        let mut data = vec![0u8; 119];
        data[4] = 1;
        assert!(KeyMethodV2::parse(&data).is_err());
    }
}