Skip to main content

corevpn_protocol/
control.rs

1//! Control Channel Message Types
2
3use std::net::Ipv4Addr;
4
5use bytes::Bytes;
6use serde::{Deserialize, Serialize};
7
8use crate::{ProtocolError, Result};
9
10/// Validate that a string is a valid IPv4 address
11fn validate_ipv4(s: &str) -> Result<()> {
12    s.parse::<Ipv4Addr>()
13        .map_err(|_| ProtocolError::InvalidPacket(format!("invalid IPv4 address: {}", s)))?;
14    Ok(())
15}
16
17/// Control channel message types
18#[derive(Debug, Clone)]
19pub enum ControlMessage {
20    /// TLS data (wrapped in control channel)
21    TlsData(Bytes),
22    /// Push request from client
23    PushRequest,
24    /// Push reply from server
25    PushReply(PushReply),
26    /// Authentication data
27    Auth(AuthMessage),
28    /// Info message (version, etc.)
29    Info(String),
30    /// Exit/shutdown
31    Exit,
32}
33
34/// Control packet for the reliable transport layer
35#[derive(Debug, Clone)]
36pub struct ControlPacket {
37    /// Packet ID for reliability
38    pub packet_id: u32,
39    /// Message content
40    pub message: ControlMessage,
41}
42
43impl ControlPacket {
44    /// Create a new control packet
45    pub fn new(packet_id: u32, message: ControlMessage) -> Self {
46        Self { packet_id, message }
47    }
48}
49
50/// Push reply containing VPN configuration
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct PushReply {
53    /// Routes to push
54    pub routes: Vec<PushRoute>,
55    /// IPv4 address and netmask
56    pub ifconfig: Option<(String, String)>,
57    /// IPv6 address
58    pub ifconfig_ipv6: Option<String>,
59    /// DNS servers
60    pub dns: Vec<String>,
61    /// Search domains
62    pub dns_search: Vec<String>,
63    /// Redirect gateway (full tunnel)
64    pub redirect_gateway: bool,
65    /// Topology type
66    pub topology: Topology,
67    /// Ping interval
68    pub ping: u32,
69    /// Ping restart timeout
70    pub ping_restart: u32,
71    /// Additional options
72    pub options: Vec<String>,
73}
74
75impl Default for PushReply {
76    fn default() -> Self {
77        Self {
78            routes: vec![],
79            ifconfig: None,
80            ifconfig_ipv6: None,
81            dns: vec![],
82            dns_search: vec![],
83            redirect_gateway: false,
84            topology: Topology::Subnet,
85            ping: 10,
86            ping_restart: 60,
87            options: vec![],
88        }
89    }
90}
91
92impl PushReply {
93    /// Encode as OpenVPN push reply string
94    pub fn encode(&self) -> String {
95        let mut parts = vec!["PUSH_REPLY".to_string()];
96
97        // Topology
98        parts.push(format!("topology {}", self.topology.as_str()));
99
100        // ifconfig
101        if let Some((ip, mask)) = &self.ifconfig {
102            parts.push(format!("ifconfig {} {}", ip, mask));
103        }
104
105        // ifconfig-ipv6
106        if let Some(ipv6) = &self.ifconfig_ipv6 {
107            parts.push(format!("ifconfig-ipv6 {}", ipv6));
108        }
109
110        // Routes
111        for route in &self.routes {
112            parts.push(route.encode());
113        }
114
115        // Redirect gateway
116        if self.redirect_gateway {
117            parts.push("redirect-gateway def1".to_string());
118        }
119
120        // DNS
121        for dns in &self.dns {
122            parts.push(format!("dhcp-option DNS {}", dns));
123        }
124
125        // DNS search domains
126        for domain in &self.dns_search {
127            parts.push(format!("dhcp-option DOMAIN {}", domain));
128        }
129
130        // Ping settings
131        parts.push(format!("ping {}", self.ping));
132        parts.push(format!("ping-restart {}", self.ping_restart));
133
134        // Additional options
135        for opt in &self.options {
136            parts.push(opt.clone());
137        }
138
139        parts.join(",")
140    }
141
142    /// Parse from OpenVPN push reply string
143    pub fn parse(s: &str) -> Result<Self> {
144        let mut reply = Self::default();
145
146        // Remove PUSH_REPLY prefix if present
147        let s = s.strip_prefix("PUSH_REPLY,").unwrap_or(s);
148
149        for part in s.split(',') {
150            let part = part.trim();
151            if part.is_empty() {
152                continue;
153            }
154
155            let mut tokens = part.split_whitespace();
156            match tokens.next() {
157                Some("topology") => {
158                    if let Some(topo) = tokens.next() {
159                        reply.topology = Topology::parse(topo);
160                    }
161                }
162                Some("ifconfig") => {
163                    let ip = tokens.next().unwrap_or("").to_string();
164                    let mask = tokens.next().unwrap_or("").to_string();
165                    reply.ifconfig = Some((ip, mask));
166                }
167                Some("ifconfig-ipv6") => {
168                    if let Some(ipv6) = tokens.next() {
169                        reply.ifconfig_ipv6 = Some(ipv6.to_string());
170                    }
171                }
172                Some("route") => {
173                    if let Ok(route) = PushRoute::parse(part) {
174                        reply.routes.push(route);
175                    }
176                }
177                Some("redirect-gateway") => {
178                    reply.redirect_gateway = true;
179                }
180                Some("dhcp-option") => {
181                    match tokens.next() {
182                        Some("DNS") => {
183                            if let Some(dns) = tokens.next() {
184                                reply.dns.push(dns.to_string());
185                            }
186                        }
187                        Some("DOMAIN") => {
188                            if let Some(domain) = tokens.next() {
189                                reply.dns_search.push(domain.to_string());
190                            }
191                        }
192                        _ => {}
193                    }
194                }
195                Some("ping") => {
196                    if let Some(Ok(p)) = tokens.next().map(|s| s.parse()) {
197                        reply.ping = p;
198                    }
199                }
200                Some("ping-restart") => {
201                    if let Some(Ok(p)) = tokens.next().map(|s| s.parse()) {
202                        reply.ping_restart = p;
203                    }
204                }
205                _ => {
206                    reply.options.push(part.to_string());
207                }
208            }
209        }
210
211        Ok(reply)
212    }
213}
214
215/// Route to push to client
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct PushRoute {
218    /// Network address
219    pub network: String,
220    /// Netmask
221    pub netmask: String,
222    /// Gateway (optional, vpn_gateway used if not set)
223    pub gateway: Option<String>,
224    /// Metric
225    pub metric: Option<u32>,
226}
227
228impl PushRoute {
229    /// Create a new route
230    pub fn new(network: &str, netmask: &str) -> Self {
231        Self {
232            network: network.to_string(),
233            netmask: netmask.to_string(),
234            gateway: None,
235            metric: None,
236        }
237    }
238
239    /// Encode as OpenVPN route directive
240    pub fn encode(&self) -> String {
241        let mut s = format!("route {} {}", self.network, self.netmask);
242        if let Some(gw) = &self.gateway {
243            s.push_str(&format!(" {}", gw));
244        } else {
245            s.push_str(" vpn_gateway");
246        }
247        if let Some(metric) = self.metric {
248            s.push_str(&format!(" {}", metric));
249        }
250        s
251    }
252
253    /// Parse from OpenVPN route directive
254    pub fn parse(s: &str) -> Result<Self> {
255        let mut tokens = s.split_whitespace();
256        tokens.next(); // skip "route"
257
258        let network_str = tokens
259            .next()
260            .ok_or_else(|| ProtocolError::InvalidPacket("missing network in route".into()))?;
261        
262        // Validate network address
263        validate_ipv4(network_str)?;
264        let network = network_str.to_string();
265
266        let netmask_str = tokens
267            .next()
268            .ok_or_else(|| ProtocolError::InvalidPacket("missing netmask in route".into()))?;
269        
270        // Validate netmask
271        validate_ipv4(netmask_str)?;
272        let netmask = netmask_str.to_string();
273
274        let gateway = tokens.next().and_then(|g| {
275            if g == "vpn_gateway" {
276                None
277            } else {
278                // Validate gateway IP address
279                validate_ipv4(g).ok().map(|_| g.to_string())
280            }
281        });
282
283        let metric = tokens.next().and_then(|m| {
284            m.parse::<u32>().ok().filter(|&m| m <= 9999) // Reasonable metric limit
285        });
286
287        Ok(Self {
288            network,
289            netmask,
290            gateway,
291            metric,
292        })
293    }
294}
295
296/// Network topology type
297#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
298pub enum Topology {
299    /// Point-to-point (net30)
300    Net30,
301    /// Point-to-point (p2p)
302    P2P,
303    /// Subnet mode (recommended)
304    #[default]
305    Subnet,
306}
307
308impl Topology {
309    /// Parse from string
310    pub fn parse(s: &str) -> Self {
311        match s.to_lowercase().as_str() {
312            "net30" => Topology::Net30,
313            "p2p" => Topology::P2P,
314            "subnet" => Topology::Subnet,
315            _ => Topology::Subnet,
316        }
317    }
318
319    /// Convert to string
320    pub fn as_str(&self) -> &'static str {
321        match self {
322            Topology::Net30 => "net30",
323            Topology::P2P => "p2p",
324            Topology::Subnet => "subnet",
325        }
326    }
327}
328
/// Authentication message from client
///
/// `Debug` is written by hand instead of derived so that formatting this
/// value (directly or through a type that contains it) never prints the
/// password.
#[derive(Clone)]
pub struct AuthMessage {
    /// Username
    pub username: String,
    /// Password
    pub password: String,
}

impl std::fmt::Debug for AuthMessage {
    /// Formats the message with the password replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthMessage")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
337
338impl AuthMessage {
339    /// Maximum username length
340    const MAX_USERNAME_LEN: usize = 256;
341    /// Maximum password length
342    const MAX_PASSWORD_LEN: usize = 1024;
343
344    /// Parse from OpenVPN auth data
345    pub fn parse(data: &[u8]) -> Result<Self> {
346        // Format: username\0password\0
347        // Security: Limit total input size to prevent DoS
348        if data.len() > Self::MAX_USERNAME_LEN + Self::MAX_PASSWORD_LEN + 2 {
349            return Err(ProtocolError::InvalidPacket("auth data too long".into()));
350        }
351
352        let s = std::str::from_utf8(data)
353            .map_err(|_| ProtocolError::InvalidPacket("invalid UTF-8 in auth".into()))?;
354
355        let parts: Vec<&str> = s.split('\0').collect();
356        if parts.len() < 2 {
357            return Err(ProtocolError::InvalidPacket("missing auth fields".into()));
358        }
359
360        let username = parts[0];
361        let password = parts[1];
362
363        // Validate lengths
364        if username.len() > Self::MAX_USERNAME_LEN {
365            return Err(ProtocolError::InvalidPacket(
366                format!("username too long (max {} bytes)", Self::MAX_USERNAME_LEN).into(),
367            ));
368        }
369        if password.len() > Self::MAX_PASSWORD_LEN {
370            return Err(ProtocolError::InvalidPacket(
371                format!("password too long (max {} bytes)", Self::MAX_PASSWORD_LEN).into(),
372            ));
373        }
374
375        Ok(Self {
376            username: username.to_string(),
377            password: password.to_string(),
378        })
379    }
380
381    /// Encode to OpenVPN auth format
382    pub fn encode(&self) -> Vec<u8> {
383        let mut data = Vec::new();
384        data.extend_from_slice(self.username.as_bytes());
385        data.push(0);
386        data.extend_from_slice(self.password.as_bytes());
387        data.push(0);
388        data
389    }
390}
391
/// Key method v2 data (exchanged during TLS handshake)
///
/// `Debug` is written by hand instead of derived so that formatting this
/// value never prints `pre_master`, `random` or the password.
#[derive(Clone)]
pub struct KeyMethodV2 {
    /// Pre-master secret (48 bytes)
    pub pre_master: [u8; 48],
    /// Random data (32 bytes)
    pub random: [u8; 32],
    /// Options string
    pub options: String,
    /// Username (if using auth)
    pub username: Option<String>,
    /// Password (if using auth)
    pub password: Option<String>,
    /// Peer info
    pub peer_info: Option<String>,
}

impl std::fmt::Debug for KeyMethodV2 {
    /// Formats the structure with key material and the password replaced by
    /// placeholders; whether a password is present is still visible.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyMethodV2")
            .field("pre_master", &"<redacted>")
            .field("random", &"<redacted>")
            .field("options", &self.options)
            .field("username", &self.username)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("peer_info", &self.peer_info)
            .finish()
    }
}
408
409impl KeyMethodV2 {
410    /// Encode to bytes
411    pub fn encode(&self) -> Vec<u8> {
412        let mut buf = Vec::new();
413
414        // Literal 0
415        buf.extend_from_slice(&[0u8; 4]);
416
417        // Key method (2)
418        buf.push(2);
419
420        // Pre-master secret
421        buf.extend_from_slice(&self.pre_master);
422
423        // Random
424        buf.extend_from_slice(&self.random);
425
426        // Options string length + string
427        let opts_bytes = self.options.as_bytes();
428        buf.extend_from_slice(&(opts_bytes.len() as u16).to_be_bytes());
429        buf.extend_from_slice(opts_bytes);
430
431        // Username (optional)
432        if let Some(username) = &self.username {
433            let username_bytes = username.as_bytes();
434            buf.extend_from_slice(&(username_bytes.len() as u16).to_be_bytes());
435            buf.extend_from_slice(username_bytes);
436        } else {
437            buf.extend_from_slice(&0u16.to_be_bytes());
438        }
439
440        // Password (optional)
441        if let Some(password) = &self.password {
442            let password_bytes = password.as_bytes();
443            buf.extend_from_slice(&(password_bytes.len() as u16).to_be_bytes());
444            buf.extend_from_slice(password_bytes);
445        } else {
446            buf.extend_from_slice(&0u16.to_be_bytes());
447        }
448
449        // Peer info (optional)
450        if let Some(peer_info) = &self.peer_info {
451            let peer_info_bytes = peer_info.as_bytes();
452            buf.extend_from_slice(&(peer_info_bytes.len() as u16).to_be_bytes());
453            buf.extend_from_slice(peer_info_bytes);
454        }
455
456        buf
457    }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding a reply and parsing the result must give back every field
    /// that was set, and must not invent extra options.
    #[test]
    fn test_push_reply_roundtrip() {
        let reply = PushReply {
            ifconfig: Some(("10.8.0.2".to_string(), "255.255.255.0".to_string())),
            dns: vec!["1.1.1.1".to_string()],
            routes: vec![PushRoute::new("192.168.1.0", "255.255.255.0")],
            redirect_gateway: true,
            ..PushReply::default()
        };

        let encoded = reply.encode();
        let parsed = PushReply::parse(&encoded).unwrap();

        assert_eq!(parsed.ifconfig, reply.ifconfig);
        assert_eq!(parsed.dns, reply.dns);
        assert!(parsed.redirect_gateway);
        assert_eq!(parsed.topology, reply.topology);
        assert_eq!(parsed.ping, reply.ping);
        assert_eq!(parsed.ping_restart, reply.ping_restart);
        assert!(parsed.dns_search.is_empty());
        assert!(parsed.options.is_empty());

        // `PushRoute` does not implement `PartialEq`, so compare field by field.
        assert_eq!(parsed.routes.len(), 1);
        let route = &parsed.routes[0];
        assert_eq!(route.network, "192.168.1.0");
        assert_eq!(route.netmask, "255.255.255.0");
        assert_eq!(route.gateway, None);
        assert_eq!(route.metric, None);
    }

    /// An auth message survives an encode/parse round trip unchanged.
    #[test]
    fn test_auth_message() {
        let auth = AuthMessage {
            username: "user".to_string(),
            password: "pass".to_string(),
        };

        let encoded = auth.encode();
        assert_eq!(encoded, b"user\0pass\0");

        let parsed = AuthMessage::parse(&encoded).unwrap();

        assert_eq!(parsed.username, "user");
        assert_eq!(parsed.password, "pass");
    }
}