// corevpn_core/network.rs

//! Network types and IP address management
2
3use ipnet::{IpNet, Ipv4Net, Ipv6Net};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
7
8use crate::{CoreError, Result};
9
10/// VPN IP address assigned to a client
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub struct VpnAddress {
13    /// IPv4 address (if assigned)
14    pub ipv4: Option<Ipv4Addr>,
15    /// IPv6 address (if assigned)
16    pub ipv6: Option<Ipv6Addr>,
17}
18
19impl VpnAddress {
20    /// Create with only IPv4
21    pub fn v4(addr: Ipv4Addr) -> Self {
22        Self {
23            ipv4: Some(addr),
24            ipv6: None,
25        }
26    }
27
28    /// Create with only IPv6
29    pub fn v6(addr: Ipv6Addr) -> Self {
30        Self {
31            ipv4: None,
32            ipv6: Some(addr),
33        }
34    }
35
36    /// Create with both IPv4 and IPv6
37    pub fn dual(ipv4: Ipv4Addr, ipv6: Ipv6Addr) -> Self {
38        Self {
39            ipv4: Some(ipv4),
40            ipv6: Some(ipv6),
41        }
42    }
43
44    /// Get primary address (prefers IPv4)
45    pub fn primary(&self) -> Option<IpAddr> {
46        self.ipv4
47            .map(IpAddr::V4)
48            .or_else(|| self.ipv6.map(IpAddr::V6))
49    }
50}
51
52/// Route to be pushed to VPN clients
53#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
54pub struct Route {
55    /// Network/prefix to route
56    pub network: IpNet,
57    /// Gateway (None = use VPN gateway)
58    pub gateway: Option<IpAddr>,
59    /// Metric/priority
60    pub metric: u32,
61}
62
63impl Route {
64    /// Create a new route
65    pub fn new(network: IpNet) -> Result<Self> {
66        // Validate network
67        Self::validate_network(&network)?;
68        
69        Ok(Self {
70            network,
71            gateway: None,
72            metric: 0,
73        })
74    }
75
76    /// Validate network configuration
77    fn validate_network(network: &IpNet) -> Result<()> {
78        match network {
79            IpNet::V4(v4_net) => {
80                // Validate IPv4 network prefix length
81                let prefix_len = v4_net.prefix_len();
82                if prefix_len > 32 {
83                    return Err(CoreError::ConfigError(
84                        "Invalid IPv4 network prefix length".into(),
85                    ));
86                }
87            }
88            IpNet::V6(v6_net) => {
89                // Validate IPv6 network prefix length
90                let prefix_len = v6_net.prefix_len();
91                if prefix_len > 128 {
92                    return Err(CoreError::ConfigError(
93                        "Invalid IPv6 network prefix length".into(),
94                    ));
95                }
96            }
97        }
98        Ok(())
99    }
100
101    /// Create default route (0.0.0.0/0)
102    pub fn default_v4() -> Self {
103        Self {
104            network: "0.0.0.0/0".parse().unwrap(),
105            gateway: None,
106            metric: 0,
107        }
108    }
109
110    /// Create default IPv6 route (::/0)
111    pub fn default_v6() -> Self {
112        Self {
113            network: "::/0".parse().unwrap(),
114            gateway: None,
115            metric: 0,
116        }
117    }
118
119    /// Set gateway
120    pub fn with_gateway(mut self, gateway: IpAddr) -> Self {
121        self.gateway = Some(gateway);
122        self
123    }
124
125    /// Set metric
126    pub fn with_metric(mut self, metric: u32) -> Self {
127        self.metric = metric;
128        self
129    }
130}
131
132/// IP address pool for assigning addresses to VPN clients
133pub struct AddressPool {
134    /// IPv4 network range
135    ipv4_net: Option<Ipv4Net>,
136    /// IPv6 network range
137    ipv6_net: Option<Ipv6Net>,
138    /// Allocated IPv4 addresses
139    allocated_v4: parking_lot::RwLock<HashSet<Ipv4Addr>>,
140    /// Allocated IPv6 addresses
141    allocated_v6: parking_lot::RwLock<HashSet<Ipv6Addr>>,
142    /// Reserved addresses (e.g., gateway, broadcast)
143    reserved_v4: HashSet<Ipv4Addr>,
144    /// Reserved IPv6 addresses
145    reserved_v6: HashSet<Ipv6Addr>,
146}
147
148impl AddressPool {
149    /// Create a new address pool
150    ///
151    /// # Arguments
152    /// * `ipv4_net` - IPv4 network (e.g., "10.8.0.0/24")
153    /// * `ipv6_net` - IPv6 network (e.g., "fd00::/64")
154    ///
155    /// # Errors
156    /// Returns an error if the network configuration is invalid
157    pub fn new(ipv4_net: Option<Ipv4Net>, ipv6_net: Option<Ipv6Net>) -> Result<Self> {
158        // Validate IPv4 network if provided
159        if let Some(ref net) = ipv4_net {
160            Self::validate_ipv4_net(net)?;
161        }
162
163        // Validate IPv6 network if provided
164        if let Some(ref net) = ipv6_net {
165            Self::validate_ipv6_net(net)?;
166        }
167
168        // Ensure at least one network is provided
169        if ipv4_net.is_none() && ipv6_net.is_none() {
170            return Err(CoreError::ConfigError(
171                "At least one network (IPv4 or IPv6) must be provided".into(),
172            ));
173        }
174
175        Ok(Self::new_unchecked(ipv4_net, ipv6_net))
176    }
177
178    /// Create a new address pool without validation (internal use)
179    fn new_unchecked(ipv4_net: Option<Ipv4Net>, ipv6_net: Option<Ipv6Net>) -> Self {
180        let mut reserved_v4 = HashSet::new();
181        let mut reserved_v6 = HashSet::new();
182
183        // Reserve network and broadcast addresses for IPv4
184        if let Some(net) = &ipv4_net {
185            reserved_v4.insert(net.network());
186            reserved_v4.insert(net.broadcast());
187            // Reserve .1 for gateway
188            let gateway = Ipv4Addr::from(u32::from(net.network()) + 1);
189            reserved_v4.insert(gateway);
190        }
191
192        // Reserve first address for IPv6 gateway
193        if let Some(net) = &ipv6_net {
194            let gateway = Ipv6Addr::from(u128::from(net.network()) + 1);
195            reserved_v6.insert(gateway);
196        }
197
198        Self {
199            ipv4_net,
200            ipv6_net,
201            allocated_v4: parking_lot::RwLock::new(HashSet::new()),
202            allocated_v6: parking_lot::RwLock::new(HashSet::new()),
203            reserved_v4,
204            reserved_v6,
205        }
206    }
207
208    /// Validate IPv4 network configuration
209    fn validate_ipv4_net(net: &Ipv4Net) -> Result<()> {
210        // Check prefix length is reasonable (not too small, not too large)
211        let prefix_len = net.prefix_len();
212        if prefix_len < 8 {
213            return Err(CoreError::ConfigError(format!(
214                "IPv4 network prefix length {} is too small (minimum 8)",
215                prefix_len
216            )));
217        }
218        if prefix_len > 30 {
219            return Err(CoreError::ConfigError(format!(
220                "IPv4 network prefix length {} is too large (maximum 30)",
221                prefix_len
222            )));
223        }
224
225        // Ensure network is not a loopback or multicast address
226        let network_addr = net.network();
227        if network_addr.is_loopback() {
228            return Err(CoreError::ConfigError(
229                "IPv4 network cannot be a loopback address".into(),
230            ));
231        }
232        if network_addr.is_multicast() {
233            return Err(CoreError::ConfigError(
234                "IPv4 network cannot be a multicast address".into(),
235            ));
236        }
237
238        // Ensure network is not in the link-local range (169.254.0.0/16)
239        if network_addr.octets()[0] == 169 && network_addr.octets()[1] == 254 {
240            return Err(CoreError::ConfigError(
241                "IPv4 network cannot be in the link-local range (169.254.0.0/16)".into(),
242            ));
243        }
244
245        Ok(())
246    }
247
248    /// Validate IPv6 network configuration
249    fn validate_ipv6_net(net: &Ipv6Net) -> Result<()> {
250        // Check prefix length is reasonable
251        let prefix_len = net.prefix_len();
252        if prefix_len < 48 {
253            return Err(CoreError::ConfigError(format!(
254                "IPv6 network prefix length {} is too small (minimum 48)",
255                prefix_len
256            )));
257        }
258        if prefix_len > 120 {
259            return Err(CoreError::ConfigError(format!(
260                "IPv6 network prefix length {} is too large (maximum 120)",
261                prefix_len
262            )));
263        }
264
265        // Ensure network is not a loopback or multicast address
266        let network_addr = net.network();
267        if network_addr.is_loopback() {
268            return Err(CoreError::ConfigError(
269                "IPv6 network cannot be a loopback address".into(),
270            ));
271        }
272        if network_addr.is_multicast() {
273            return Err(CoreError::ConfigError(
274                "IPv6 network cannot be a multicast address".into(),
275            ));
276        }
277
278        // Ensure network is not in the link-local range (fe80::/10)
279        let segments = network_addr.segments();
280        if segments[0] & 0xffc0 == 0xfe80 {
281            return Err(CoreError::ConfigError(
282                "IPv6 network cannot be in the link-local range (fe80::/10)".into(),
283            ));
284        }
285
286        Ok(())
287    }
288
289    /// Get the gateway IPv4 address
290    pub fn gateway_v4(&self) -> Option<Ipv4Addr> {
291        self.ipv4_net.map(|net| {
292            Ipv4Addr::from(u32::from(net.network()) + 1)
293        })
294    }
295
296    /// Get the gateway IPv6 address
297    pub fn gateway_v6(&self) -> Option<Ipv6Addr> {
298        self.ipv6_net.map(|net| {
299            Ipv6Addr::from(u128::from(net.network()) + 1)
300        })
301    }
302
303    /// Allocate an address from the pool
304    pub fn allocate(&self) -> Result<VpnAddress> {
305        let ipv4 = if let Some(net) = &self.ipv4_net {
306            Some(self.allocate_v4(net)?)
307        } else {
308            None
309        };
310
311        let ipv6 = if let Some(net) = &self.ipv6_net {
312            Some(self.allocate_v6(net)?)
313        } else {
314            None
315        };
316
317        if ipv4.is_none() && ipv6.is_none() {
318            return Err(CoreError::ConfigError("No address pools configured".into()));
319        }
320
321        Ok(VpnAddress { ipv4, ipv6 })
322    }
323
324    fn allocate_v4(&self, net: &Ipv4Net) -> Result<Ipv4Addr> {
325        let mut allocated = self.allocated_v4.write();
326
327        // Start from .2 (after gateway)
328        let start = u32::from(net.network()) + 2;
329        let end = u32::from(net.broadcast());
330
331        for addr_u32 in start..end {
332            let addr = Ipv4Addr::from(addr_u32);
333            if !self.reserved_v4.contains(&addr) && !allocated.contains(&addr) {
334                allocated.insert(addr);
335                return Ok(addr);
336            }
337        }
338
339        Err(CoreError::AddressPoolExhausted)
340    }
341
342    fn allocate_v6(&self, net: &Ipv6Net) -> Result<Ipv6Addr> {
343        let mut allocated = self.allocated_v6.write();
344
345        // Start from ::2 (after gateway)
346        let start = u128::from(net.network()) + 2;
347        // Limit search to reasonable range
348        let end = start + 65534;
349
350        for addr_u128 in start..end {
351            let addr = Ipv6Addr::from(addr_u128);
352            if !self.reserved_v6.contains(&addr) && !allocated.contains(&addr) {
353                allocated.insert(addr);
354                return Ok(addr);
355            }
356        }
357
358        Err(CoreError::AddressPoolExhausted)
359    }
360
361    /// Allocate a specific address (for static assignment)
362    pub fn allocate_specific(&self, addr: VpnAddress) -> Result<VpnAddress> {
363        if let Some(v4) = addr.ipv4 {
364            let mut allocated = self.allocated_v4.write();
365            if let Some(net) = &self.ipv4_net {
366                if !net.contains(&v4) {
367                    return Err(CoreError::InvalidAddress(format!(
368                        "{} not in pool {}",
369                        v4, net
370                    )));
371                }
372            }
373            if self.reserved_v4.contains(&v4) || allocated.contains(&v4) {
374                return Err(CoreError::InvalidAddress(format!(
375                    "{} is reserved or already allocated",
376                    v4
377                )));
378            }
379            allocated.insert(v4);
380        }
381
382        if let Some(v6) = addr.ipv6 {
383            let mut allocated = self.allocated_v6.write();
384            if let Some(net) = &self.ipv6_net {
385                if !net.contains(&v6) {
386                    return Err(CoreError::InvalidAddress(format!(
387                        "{} not in pool {}",
388                        v6, net
389                    )));
390                }
391            }
392            if self.reserved_v6.contains(&v6) || allocated.contains(&v6) {
393                return Err(CoreError::InvalidAddress(format!(
394                    "{} is reserved or already allocated",
395                    v6
396                )));
397            }
398            allocated.insert(v6);
399        }
400
401        Ok(addr)
402    }
403
404    /// Release an address back to the pool
405    pub fn release(&self, addr: &VpnAddress) {
406        if let Some(v4) = addr.ipv4 {
407            self.allocated_v4.write().remove(&v4);
408        }
409        if let Some(v6) = addr.ipv6 {
410            self.allocated_v6.write().remove(&v6);
411        }
412    }
413
414    /// Get the number of available IPv4 addresses
415    pub fn available_v4(&self) -> usize {
416        if let Some(net) = &self.ipv4_net {
417            let total = net.hosts().count();
418            let reserved = self.reserved_v4.len();
419            let allocated = self.allocated_v4.read().len();
420            total.saturating_sub(reserved).saturating_sub(allocated)
421        } else {
422            0
423        }
424    }
425
426    /// Get the number of available IPv6 addresses
427    pub fn available_v6(&self) -> usize {
428        if self.ipv6_net.is_some() {
429            // IPv6 pools are effectively unlimited, return large number
430            let allocated = self.allocated_v6.read().len();
431            65534usize.saturating_sub(allocated)
432        } else {
433            0
434        }
435    }
436
437    /// Get pool statistics
438    pub fn stats(&self) -> PoolStats {
439        PoolStats {
440            ipv4_total: self.ipv4_net.map(|n| n.hosts().count()).unwrap_or(0),
441            ipv4_allocated: self.allocated_v4.read().len(),
442            ipv4_available: self.available_v4(),
443            ipv6_allocated: self.allocated_v6.read().len(),
444            ipv6_available: self.available_v6(),
445        }
446    }
447}
448
449/// Address pool statistics
450#[derive(Debug, Clone, Serialize, Deserialize)]
451pub struct PoolStats {
452    /// Total IPv4 addresses in pool
453    pub ipv4_total: usize,
454    /// Allocated IPv4 addresses
455    pub ipv4_allocated: usize,
456    /// Available IPv4 addresses
457    pub ipv4_available: usize,
458    /// Allocated IPv6 addresses
459    pub ipv6_allocated: usize,
460    /// Available IPv6 addresses
461    pub ipv6_available: usize,
462}
463
464/// DNS configuration to push to clients
465#[derive(Debug, Clone, Serialize, Deserialize)]
466pub struct DnsConfig {
467    /// DNS servers
468    pub servers: Vec<IpAddr>,
469    /// Search domains
470    pub search_domains: Vec<String>,
471}
472
473impl Default for DnsConfig {
474    fn default() -> Self {
475        Self {
476            servers: vec![
477                "1.1.1.1".parse().unwrap(),  // Cloudflare
478                "1.0.0.1".parse().unwrap(),
479            ],
480            search_domains: vec![],
481        }
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Dynamic allocation hands out the lowest free address and reuses
    /// released ones.
    #[test]
    fn test_address_pool() {
        let v4_net: Ipv4Net = "10.8.0.0/24".parse().unwrap();
        let pool = AddressPool::new(Some(v4_net), None).unwrap();

        // .0 is the network and .1 the gateway, so allocation starts at .2.
        let first = pool.allocate().unwrap();
        assert_eq!(first.ipv4, Some(Ipv4Addr::new(10, 8, 0, 2)));

        let second = pool.allocate().unwrap();
        assert_eq!(second.ipv4, Some(Ipv4Addr::new(10, 8, 0, 3)));

        // Once .2 is released it is the lowest free address again.
        pool.release(&first);
        let reused = pool.allocate().unwrap();
        assert_eq!(reused.ipv4, Some(Ipv4Addr::new(10, 8, 0, 2)));
    }

    /// The gateway is the first address after the network address.
    #[test]
    fn test_gateway_address() {
        let v4_net: Ipv4Net = "10.8.0.0/24".parse().unwrap();
        let v6_net: Ipv6Net = "fd00::/64".parse().unwrap();
        let pool = AddressPool::new(Some(v4_net), Some(v6_net)).unwrap();

        assert_eq!(pool.gateway_v4(), Some(Ipv4Addr::new(10, 8, 0, 1)));
        assert_eq!(
            pool.gateway_v6(),
            Some(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1))
        );
    }

    /// Builder methods set the gateway and metric.
    #[test]
    fn test_route() {
        let network: IpNet = "192.168.1.0/24".parse().unwrap();
        let gateway = IpAddr::V4(Ipv4Addr::new(10, 8, 0, 1));

        let route = Route::new(network)
            .unwrap()
            .with_gateway(gateway)
            .with_metric(100);

        assert_eq!(route.metric, 100);
        assert_eq!(route.gateway, Some(gateway));
    }

    /// Pool construction rejects out-of-range prefixes and loopback networks.
    #[test]
    fn test_address_pool_validation() {
        // Prefix too short, prefix too long, loopback.
        let rejected_v4 = ["10.8.0.0/7", "10.8.0.0/31", "127.0.0.0/24"];
        for cidr in rejected_v4.iter() {
            let net: Ipv4Net = cidr.parse().unwrap();
            assert!(
                AddressPool::new(Some(net), None).is_err(),
                "{} should be rejected",
                cidr
            );
        }

        // Prefix too short, prefix too long, loopback.
        let rejected_v6 = ["fd00::/47", "fd00::/121", "::1/64"];
        for cidr in rejected_v6.iter() {
            let net: Ipv6Net = cidr.parse().unwrap();
            assert!(
                AddressPool::new(None, Some(net)).is_err(),
                "{} should be rejected",
                cidr
            );
        }

        // Well-formed private networks are accepted.
        let good_v4: Ipv4Net = "10.8.0.0/24".parse().unwrap();
        assert!(AddressPool::new(Some(good_v4), None).is_ok());
        let good_v6: Ipv6Net = "fd00::/64".parse().unwrap();
        assert!(AddressPool::new(None, Some(good_v6)).is_ok());
    }
}