1use crate::NodeAddr;
8use std::collections::{HashMap, VecDeque};
9use std::net::Ipv6Addr;
10use std::time::Instant;
11use tracing::{debug, info};
12
13#[derive(Debug, thiserror::Error)]
15pub enum PoolError {
16 #[error("invalid CIDR: {0}")]
17 InvalidCidr(String),
18 #[error("pool exhausted ({0} addresses in use)")]
19 Exhausted(usize),
20 #[error("prefix length must be between 1 and 128")]
21 InvalidPrefix,
22}
23
/// Lifecycle stage of a single virtual IP mapping, advanced by `VirtualIpPool::tick`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MappingState {
    /// A virtual IP has been assigned but no session has been observed yet.
    Allocated,
    /// At least one session towards the virtual IP has been observed.
    Active,
    /// The mapping outlived its TTL and is waiting to be reclaimed once it has
    /// no sessions and the grace period has passed.
    Draining,
}
34
35#[derive(Debug, Clone)]
37pub struct VirtualIpMapping {
38 pub node_addr: NodeAddr,
40 pub virtual_ip: Ipv6Addr,
42 pub mesh_addr: Ipv6Addr,
44 pub dns_name: String,
46 pub state: MappingState,
48 pub created: Instant,
50 pub last_referenced: Instant,
52 pub drain_start: Option<Instant>,
54 pub session_count: u32,
56}
57
/// Notification describing a change to the set of mappings.
#[derive(Debug)]
pub enum PoolEvent {
    /// A mapping came into existence.
    ///
    /// NOTE(review): nothing in the pool code visible here constructs this
    /// variant; presumably the caller raises it after `allocate` reports a new
    /// mapping. Confirm against the call site.
    MappingCreated {
        /// Virtual address that was handed out.
        virtual_ip: Ipv6Addr,
        /// Mesh address the virtual address stands for.
        mesh_addr: Ipv6Addr,
    },
    /// A mapping was reclaimed by `tick` and its address returned to the pool.
    MappingRemoved {
        /// Virtual address that was released.
        virtual_ip: Ipv6Addr,
        /// Mesh address the virtual address stood for.
        mesh_addr: Ipv6Addr,
    },
}
72
/// Point-in-time counters describing how the pool's addresses are being used.
#[derive(Debug, Clone)]
pub struct PoolStatus {
    /// Number of addresses the pool was created with.
    pub total: usize,
    /// Mappings currently in [`MappingState::Allocated`].
    pub allocated: usize,
    /// Mappings currently in [`MappingState::Active`].
    pub active: usize,
    /// Mappings currently in [`MappingState::Draining`].
    pub draining: usize,
    /// Addresses waiting in the free queue.
    pub free: usize,
}
82
83#[derive(Debug, Clone)]
85pub struct MappingInfo {
86 pub virtual_ip: Ipv6Addr,
87 pub mesh_addr: Ipv6Addr,
88 pub node_addr: NodeAddr,
89 pub dns_name: String,
90 pub state: MappingState,
91 pub session_count: u32,
92 pub age_secs: u64,
93 pub last_ref_secs: u64,
94}
95
/// Source of per-address session counts that the pool consults on every tick.
pub trait ConntrackQuerier: Send + Sync {
    /// Returns how many tracked sessions are currently directed at `virtual_ip`.
    ///
    /// # Errors
    ///
    /// Implementations report a failure to obtain the data as `std::io::Error`.
    fn active_sessions(&self, virtual_ip: Ipv6Addr) -> Result<u32, std::io::Error>;
}

/// [`ConntrackQuerier`] that reads the kernel table at `/proc/net/nf_conntrack`.
pub struct ProcConntrack;

impl ProcConntrack {
    /// Location of the kernel's connection-tracking table.
    const TABLE_PATH: &'static str = "/proc/net/nf_conntrack";

    /// Counts the lines of `table` that carry a `dst=` field equal to `virtual_ip`.
    ///
    /// Each `dst=` value is parsed and compared as an address rather than as
    /// text. A textual substring test is wrong in two ways: the kernel prints
    /// IPv6 addresses fully expanded (`fd01:0000:...:0001`) while
    /// `Ipv6Addr::to_string` yields the compressed form (`fd01::1`), and a
    /// compressed needle such as `dst=fd01::1` is also a prefix of
    /// `dst=fd01::10`. Fields that do not parse as IPv6 (for example IPv4
    /// entries) are ignored. A line is counted once even if several of its
    /// `dst=` fields match. The result saturates at `u32::MAX`.
    fn count_sessions(table: &str, virtual_ip: Ipv6Addr) -> u32 {
        table
            .lines()
            .filter(|line| {
                line.split_whitespace()
                    .filter_map(|field| field.strip_prefix("dst="))
                    .filter_map(|value| value.parse::<Ipv6Addr>().ok())
                    .any(|dst| dst == virtual_ip)
            })
            .fold(0u32, |seen, _| seen.saturating_add(1))
    }
}

impl ConntrackQuerier for ProcConntrack {
    /// Reads the whole conntrack table and counts the entries destined for
    /// `virtual_ip`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from reading `/proc/net/nf_conntrack`.
    fn active_sessions(&self, virtual_ip: Ipv6Addr) -> Result<u32, std::io::Error> {
        let table = std::fs::read_to_string(Self::TABLE_PATH)?;
        Ok(Self::count_sessions(&table, virtual_ip))
    }
}
117
118pub struct VirtualIpPool {
120 available: VecDeque<Ipv6Addr>,
122 mappings: HashMap<NodeAddr, VirtualIpMapping>,
124 reverse: HashMap<Ipv6Addr, NodeAddr>,
126 ttl_secs: u64,
128 grace_secs: u64,
130 total: usize,
132}
133
134impl VirtualIpPool {
135 pub fn new(cidr: &str, ttl_secs: u64, grace_secs: u64) -> Result<Self, PoolError> {
137 let (base, prefix_len) = parse_ipv6_cidr(cidr)?;
138 if prefix_len == 0 || prefix_len > 128 {
139 return Err(PoolError::InvalidPrefix);
140 }
141
142 let mut available = VecDeque::new();
143 let host_bits = 128 - prefix_len;
144
145 let max_addrs: u128 = if host_bits > 16 {
147 1u128 << 16
148 } else {
149 1u128 << host_bits
150 };
151
152 let base_int = u128::from(base);
153 for i in 1..max_addrs {
155 available.push_back(Ipv6Addr::from(base_int + i));
156 }
157
158 let total = available.len();
159 info!(cidr = %cidr, addresses = total, "Virtual IP pool initialized");
160
161 Ok(Self {
162 available,
163 mappings: HashMap::new(),
164 reverse: HashMap::new(),
165 ttl_secs,
166 grace_secs,
167 total,
168 })
169 }
170
171 pub fn allocate(
174 &mut self,
175 node_addr: NodeAddr,
176 mesh_addr: Ipv6Addr,
177 dns_name: &str,
178 ) -> Result<(Ipv6Addr, bool), PoolError> {
179 if let Some(mapping) = self.mappings.get_mut(&node_addr) {
181 mapping.last_referenced = Instant::now();
182 return Ok((mapping.virtual_ip, false));
183 }
184
185 let virtual_ip = self
186 .available
187 .pop_front()
188 .ok_or(PoolError::Exhausted(self.mappings.len()))?;
189
190 let now = Instant::now();
191 let mapping = VirtualIpMapping {
192 node_addr,
193 virtual_ip,
194 mesh_addr,
195 dns_name: dns_name.to_string(),
196 state: MappingState::Allocated,
197 created: now,
198 last_referenced: now,
199 drain_start: None,
200 session_count: 0,
201 };
202
203 self.mappings.insert(node_addr, mapping);
204 self.reverse.insert(virtual_ip, node_addr);
205
206 info!(
207 virtual_ip = %virtual_ip,
208 mesh_addr = %mesh_addr,
209 dns_name = %dns_name,
210 "Allocated virtual IP"
211 );
212
213 Ok((virtual_ip, true))
214 }
215
216 pub fn tick(&mut self, now: Instant, conntrack: &dyn ConntrackQuerier) -> Vec<PoolEvent> {
219 let mut events = Vec::new();
220 let mut to_free = Vec::new();
221 let ttl = std::time::Duration::from_secs(self.ttl_secs);
222 let grace = std::time::Duration::from_secs(self.grace_secs);
223
224 for (node_addr, mapping) in &mut self.mappings {
225 let sessions = conntrack.active_sessions(mapping.virtual_ip).unwrap_or(0);
227 mapping.session_count = sessions;
228
229 match mapping.state {
230 MappingState::Allocated => {
231 if sessions > 0 {
232 mapping.state = MappingState::Active;
233 debug!(
234 virtual_ip = %mapping.virtual_ip,
235 sessions,
236 "Mapping activated"
237 );
238 } else if now.duration_since(mapping.last_referenced) > ttl {
239 mapping.state = MappingState::Draining;
244 mapping.drain_start = Some(now);
245 debug!(
246 virtual_ip = %mapping.virtual_ip,
247 "Allocated mapping TTL expired, draining"
248 );
249 }
250 }
251 MappingState::Active => {
252 if now.duration_since(mapping.last_referenced) > ttl {
253 if sessions > 0 {
254 mapping.state = MappingState::Draining;
255 mapping.drain_start = Some(now);
256 debug!(
257 virtual_ip = %mapping.virtual_ip,
258 sessions,
259 "Mapping draining (TTL expired, sessions active)"
260 );
261 } else {
262 mapping.state = MappingState::Draining;
264 mapping.drain_start = Some(now);
265 mapping.session_count = 0;
266 }
267 }
268 }
269 MappingState::Draining => {
270 if sessions == 0
271 && let Some(drain_start) = mapping.drain_start
272 && now.duration_since(drain_start) > grace
273 {
274 to_free.push(*node_addr);
275 }
276 }
277 }
278 }
279
280 for node_addr in to_free {
282 if let Some(mapping) = self.mappings.remove(&node_addr) {
283 self.reverse.remove(&mapping.virtual_ip);
284 self.available.push_back(mapping.virtual_ip);
285 info!(
286 virtual_ip = %mapping.virtual_ip,
287 mesh_addr = %mapping.mesh_addr,
288 "Reclaimed virtual IP"
289 );
290 events.push(PoolEvent::MappingRemoved {
291 virtual_ip: mapping.virtual_ip,
292 mesh_addr: mapping.mesh_addr,
293 });
294 }
295 }
296
297 events
298 }
299
300 pub fn status(&self) -> PoolStatus {
302 let mut allocated = 0;
303 let mut active = 0;
304 let mut draining = 0;
305 for mapping in self.mappings.values() {
306 match mapping.state {
307 MappingState::Allocated => allocated += 1,
308 MappingState::Active => active += 1,
309 MappingState::Draining => draining += 1,
310 }
311 }
312 PoolStatus {
313 total: self.total,
314 allocated,
315 active,
316 draining,
317 free: self.available.len(),
318 }
319 }
320
321 pub fn mapping_info(&self, now: Instant) -> Vec<MappingInfo> {
323 self.mappings
324 .values()
325 .map(|m| MappingInfo {
326 virtual_ip: m.virtual_ip,
327 mesh_addr: m.mesh_addr,
328 node_addr: m.node_addr,
329 dns_name: m.dns_name.clone(),
330 state: m.state,
331 session_count: m.session_count,
332 age_secs: now.duration_since(m.created).as_secs(),
333 last_ref_secs: now.duration_since(m.last_referenced).as_secs(),
334 })
335 .collect()
336 }
337
338 pub fn lookup_virtual_ip(&self, virtual_ip: &Ipv6Addr) -> Option<&VirtualIpMapping> {
340 self.reverse
341 .get(virtual_ip)
342 .and_then(|addr| self.mappings.get(addr))
343 }
344}
345
346fn parse_ipv6_cidr(cidr: &str) -> Result<(Ipv6Addr, u32), PoolError> {
348 let parts: Vec<&str> = cidr.split('/').collect();
349 if parts.len() != 2 {
350 return Err(PoolError::InvalidCidr(cidr.to_string()));
351 }
352 let addr: Ipv6Addr = parts[0]
353 .parse()
354 .map_err(|_| PoolError::InvalidCidr(cidr.to_string()))?;
355 let prefix: u32 = parts[1]
356 .parse()
357 .map_err(|_| PoolError::InvalidCidr(cidr.to_string()))?;
358 Ok((addr, prefix))
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Conntrack stand-in that answers with whatever count was last stored for
    /// an address, or zero if none was stored.
    struct MockConntrack {
        counts: HashMap<Ipv6Addr, u32>,
    }

    impl MockConntrack {
        fn new() -> Self {
            Self {
                counts: HashMap::new(),
            }
        }

        fn set(&mut self, addr: Ipv6Addr, count: u32) {
            self.counts.insert(addr, count);
        }
    }

    impl ConntrackQuerier for MockConntrack {
        fn active_sessions(&self, virtual_ip: Ipv6Addr) -> Result<u32, std::io::Error> {
            Ok(self.counts.get(&virtual_ip).copied().unwrap_or(0))
        }
    }

    /// Node address whose first byte is `byte` and whose remaining bytes are zero.
    fn make_node_addr(byte: u8) -> NodeAddr {
        let mut raw = [0u8; 16];
        raw[0] = byte;
        NodeAddr::from_bytes(raw)
    }

    /// Mesh address of the form `fd00::<byte>`.
    fn make_mesh_addr(byte: u8) -> Ipv6Addr {
        let mut octets = [0u8; 16];
        octets[0] = 0xfd;
        octets[15] = byte;
        Ipv6Addr::from(octets)
    }

    /// Parses an IPv6 literal used in assertions.
    fn ip(text: &str) -> Ipv6Addr {
        text.parse().unwrap()
    }

    #[test]
    fn test_parse_cidr() {
        let (addr, prefix) = parse_ipv6_cidr("fd01::/112").unwrap();
        assert_eq!(addr, ip("fd01::"));
        assert_eq!(prefix, 112);
    }

    #[test]
    fn test_parse_cidr_invalid() {
        for bad in ["not-a-cidr", "fd01::", "fd01::/abc"] {
            assert!(parse_ipv6_cidr(bad).is_err());
        }
    }

    #[test]
    fn test_pool_creation() {
        // A /120 has 256 host values; host zero is skipped.
        let pool = VirtualIpPool::new("fd01::/120", 60, 60).unwrap();
        assert_eq!(pool.total, 255);
        assert_eq!(pool.available.len(), 255);
    }

    #[test]
    fn test_pool_allocation() {
        let mut pool = VirtualIpPool::new("fd01::/120", 60, 60).unwrap();

        let (vip, is_new) = pool
            .allocate(make_node_addr(1), make_mesh_addr(1), "test.fips")
            .unwrap();

        assert!(is_new);
        assert_eq!(vip, ip("fd01::1"));
        assert_eq!(pool.available.len(), 254);
    }

    #[test]
    fn test_pool_idempotent() {
        let mut pool = VirtualIpPool::new("fd01::/120", 60, 60).unwrap();
        let node = make_node_addr(1);
        let mesh = make_mesh_addr(1);

        let (first_vip, first_new) = pool.allocate(node, mesh, "test.fips").unwrap();
        let (second_vip, second_new) = pool.allocate(node, mesh, "test.fips").unwrap();

        assert!(first_new);
        assert!(!second_new);
        assert_eq!(first_vip, second_vip);
        // The repeated request must not consume a second address.
        assert_eq!(pool.available.len(), 254);
    }

    #[test]
    fn test_pool_exhaustion() {
        // A /126 has four host values, three of which are usable.
        let mut pool = VirtualIpPool::new("fd01::/126", 60, 60).unwrap();
        assert_eq!(pool.total, 3);

        for n in 1..=3u8 {
            pool.allocate(make_node_addr(n), make_mesh_addr(n), "test.fips")
                .unwrap();
        }

        let overflow = pool.allocate(make_node_addr(4), make_mesh_addr(4), "test.fips");
        assert!(overflow.is_err());
    }

    #[test]
    fn test_mapping_lifecycle_allocated_to_free() {
        let mut pool = VirtualIpPool::new("fd01::/120", 1, 1).unwrap();
        let conntrack = MockConntrack::new();

        pool.allocate(make_node_addr(1), make_mesh_addr(1), "test.fips")
            .unwrap();

        // Immediately after allocation nothing has expired.
        let start = Instant::now();
        let events = pool.tick(start, &conntrack);
        assert!(events.is_empty());
        assert_eq!(pool.mappings.len(), 1);

        // Past the TTL with no sessions: the mapping starts draining.
        let past_ttl = start + Duration::from_secs(2);
        let events = pool.tick(past_ttl, &conntrack);
        assert!(events.is_empty());
        assert_eq!(pool.mappings.len(), 1);
        let only_mapping = pool.mappings.values().next().unwrap();
        assert_eq!(only_mapping.state, MappingState::Draining);

        // Past the grace period: the mapping is reclaimed.
        let past_grace = past_ttl + Duration::from_secs(2);
        let events = pool.tick(past_grace, &conntrack);
        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], PoolEvent::MappingRemoved { .. }));
        assert_eq!(pool.mappings.len(), 0);
        assert_eq!(pool.available.len(), 255);
    }

    #[test]
    fn test_mapping_lifecycle_active_draining_free() {
        let mut pool = VirtualIpPool::new("fd01::/120", 1, 1).unwrap();
        let mut conntrack = MockConntrack::new();
        let node = make_node_addr(1);

        let (vip, _) = pool
            .allocate(node, make_mesh_addr(1), "test.fips")
            .unwrap();

        // Sessions appear: the mapping becomes active.
        conntrack.set(vip, 3);
        let start = Instant::now();
        let events = pool.tick(start, &conntrack);
        assert!(events.is_empty());
        assert_eq!(pool.mappings[&node].state, MappingState::Active);

        // TTL passes while a session is still open: draining begins.
        let past_ttl = start + Duration::from_secs(2);
        conntrack.set(vip, 1);
        let events = pool.tick(past_ttl, &conntrack);
        assert!(events.is_empty());
        assert_eq!(pool.mappings[&node].state, MappingState::Draining);

        // Sessions end, but the grace period has not elapsed yet.
        conntrack.set(vip, 0);
        let events = pool.tick(past_ttl, &conntrack);
        assert!(events.is_empty());
        assert_eq!(pool.mappings[&node].state, MappingState::Draining);

        // Grace period over: the mapping is reclaimed.
        let past_grace = past_ttl + Duration::from_secs(2);
        let events = pool.tick(past_grace, &conntrack);
        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], PoolEvent::MappingRemoved { .. }));
        assert_eq!(pool.mappings.len(), 0);
    }

    #[test]
    fn test_pool_status() {
        let mut pool = VirtualIpPool::new("fd01::/120", 60, 60).unwrap();

        let before = pool.status();
        assert_eq!(before.total, 255);
        assert_eq!(before.free, 255);
        assert_eq!(before.allocated, 0);

        pool.allocate(make_node_addr(1), make_mesh_addr(1), "test.fips")
            .unwrap();

        let after = pool.status();
        assert_eq!(after.allocated, 1);
        assert_eq!(after.free, 254);
    }

    #[test]
    fn test_lookup_virtual_ip() {
        let mut pool = VirtualIpPool::new("fd01::/120", 60, 60).unwrap();
        let node = make_node_addr(1);
        let mesh = make_mesh_addr(1);

        let (vip, _) = pool.allocate(node, mesh, "test.fips").unwrap();

        let found = pool.lookup_virtual_ip(&vip).unwrap();
        assert_eq!(found.node_addr, node);
        assert_eq!(found.mesh_addr, mesh);

        // An address that was never handed out has no mapping.
        let unassigned = ip("fd01::ff");
        assert!(pool.lookup_virtual_ip(&unassigned).is_none());
    }

    #[test]
    fn test_large_prefix_capped() {
        // 32 host bits are available, but enumeration stops at 2^16 - 1.
        let pool = VirtualIpPool::new("fd01::/96", 60, 60).unwrap();
        assert_eq!(pool.total, 65535);
    }
}