nomad_protocol/transport/
migration.rs1use std::collections::HashMap;
6use std::net::{IpAddr, SocketAddr};
7use std::time::{Duration, Instant};

/// Tunables for connection migration.
pub mod constants {
    use std::time::Duration;

    /// Minimum time between two accepted migrations into the same subnet.
    ///
    /// `MigrationState::validate_address` rejects a cross-subnet migration when the
    /// target subnet last accepted a migration less than this long ago.
    pub const MIN_MIGRATION_INTERVAL: Duration = Duration::from_millis(1_000);

    /// Anti-amplification multiplier for unvalidated addresses.
    ///
    /// Until an address is validated, the total number of bytes sent to it may not
    /// exceed this factor times the bytes received from it.
    pub const AMPLIFICATION_FACTOR: usize = 3;
}

/// Traffic counters and validation status for a single peer address.
#[derive(Debug, Clone)]
struct AddressState {
    /// Moment the address was first observed; `MigrationState::cleanup` ages entries from here.
    first_seen: Instant,
    /// Cumulative bytes received from the address.
    bytes_received: usize,
    /// Cumulative bytes sent to the address.
    bytes_sent: usize,
    /// `true` once the address has been validated.
    validated: bool,
}

impl AddressState {
    /// Creates an unvalidated entry with zeroed counters, stamped with the current time.
    fn new() -> Self {
        let first_seen = Instant::now();
        AddressState {
            first_seen,
            bytes_received: 0,
            bytes_sent: 0,
            validated: false,
        }
    }
}

/// Returns the subnet prefix used to group addresses for migration rate limiting.
///
/// IPv4 addresses are grouped by their /24 prefix (first 3 octets) and IPv6
/// addresses by their /48 prefix (first 6 octets). IPv4-mapped IPv6 addresses
/// (`::ffff:a.b.c.d`, such as a dual-stack IPv6 socket reports for IPv4 peers) are
/// first converted to plain IPv4. Otherwise every IPv4 peer would share the single
/// `::/48` key, and cross-subnet rate limiting would never trigger between them.
fn subnet_key(addr: &IpAddr) -> Vec<u8> {
    match addr.to_canonical() {
        IpAddr::V4(v4) => v4.octets()[..3].to_vec(),
        // The first 3 big-endian segments are exactly the first 6 octets.
        IpAddr::V6(v6) => v6.octets()[..6].to_vec(),
    }
}

64#[derive(Debug)]
71pub struct MigrationState {
72 current_address: SocketAddr,
74 addresses: HashMap<SocketAddr, AddressState>,
76 subnet_last_migration: HashMap<Vec<u8>, Instant>,
78}
79
80impl MigrationState {
81 pub fn new(initial_address: SocketAddr) -> Self {
83 let mut addresses = HashMap::new();
84 let mut state = AddressState::new();
85 state.validated = true; addresses.insert(initial_address, state);
87
88 Self {
89 current_address: initial_address,
90 addresses,
91 subnet_last_migration: HashMap::new(),
92 }
93 }
94
95 pub fn current_address(&self) -> SocketAddr {
97 self.current_address
98 }
99
100 pub fn on_receive(&mut self, from: SocketAddr, bytes: usize) {
102 let state = self.addresses.entry(from).or_insert_with(AddressState::new);
103 state.bytes_received = state.bytes_received.saturating_add(bytes);
104 }
105
106 pub fn on_send(&mut self, to: SocketAddr, bytes: usize) {
108 if let Some(state) = self.addresses.get_mut(&to) {
109 state.bytes_sent = state.bytes_sent.saturating_add(bytes);
110 }
111 }
112
113 pub fn can_send(&self, to: SocketAddr, bytes: usize) -> bool {
115 if let Some(state) = self.addresses.get(&to) {
116 if state.validated {
117 return true;
118 }
119 let allowed = state.bytes_received.saturating_mul(constants::AMPLIFICATION_FACTOR);
121 state.bytes_sent.saturating_add(bytes) <= allowed
122 } else {
123 false
125 }
126 }
127
128 pub fn validate_address(&mut self, addr: SocketAddr) -> bool {
132 if addr != self.current_address {
134 let current_subnet = subnet_key(&self.current_address.ip());
135 let new_subnet = subnet_key(&addr.ip());
136
137 if current_subnet != new_subnet {
138 let now = Instant::now();
139 if let Some(&last) = self.subnet_last_migration.get(&new_subnet)
140 && now.duration_since(last) < constants::MIN_MIGRATION_INTERVAL
141 {
142 return false; }
144 self.subnet_last_migration.insert(new_subnet, now);
145 }
146 }
147
148 let state = self.addresses.entry(addr).or_insert_with(AddressState::new);
150 state.validated = true;
151
152 self.current_address = addr;
154
155 true
156 }
157
158 pub fn is_validated(&self, addr: SocketAddr) -> bool {
160 self.addresses
161 .get(&addr)
162 .is_some_and(|state| state.validated)
163 }
164
165 pub fn cleanup(&mut self, max_age: Duration) {
167 let now = Instant::now();
168 self.addresses.retain(|addr, state| {
169 *addr == self.current_address || now.duration_since(state.first_seen) < max_age
171 });
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    fn addr_v4(a: u8, b: u8, c: u8, d: u8, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(a, b, c, d)), port)
    }

    fn addr_v6(segments: [u16; 8], port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V6(Ipv6Addr::from(segments)), port)
    }

    #[test]
    fn test_migration_state_initial() {
        let addr = addr_v4(192, 168, 1, 100, 8080);
        let state = MigrationState::new(addr);

        assert_eq!(state.current_address(), addr);
        assert!(state.is_validated(addr));
    }

    #[test]
    fn test_anti_amplification() {
        let initial = addr_v4(192, 168, 1, 100, 8080);
        let mut state = MigrationState::new(initial);

        let new_addr = addr_v4(10, 0, 0, 50, 9090);

        // Nothing received yet: no budget at all.
        assert!(!state.can_send(new_addr, 100));

        state.on_receive(new_addr, 100);

        // Budget is AMPLIFICATION_FACTOR (3) times the bytes received.
        assert!(state.can_send(new_addr, 300));
        assert!(!state.can_send(new_addr, 301));

        assert!(state.validate_address(new_addr));
        assert!(state.can_send(new_addr, 10000));
    }

    #[test]
    fn test_on_send_consumes_budget() {
        let initial = addr_v4(192, 168, 1, 100, 8080);
        let mut state = MigrationState::new(initial);

        let new_addr = addr_v4(10, 0, 0, 50, 9090);
        state.on_receive(new_addr, 100);
        state.on_send(new_addr, 200);

        assert!(state.can_send(new_addr, 100));
        assert!(!state.can_send(new_addr, 101));
    }

    #[test]
    fn test_migration_same_subnet() {
        let initial = addr_v4(192, 168, 1, 100, 8080);
        let mut state = MigrationState::new(initial);

        let new_addr = addr_v4(192, 168, 1, 200, 9090);
        assert!(state.validate_address(new_addr));
        assert_eq!(state.current_address(), new_addr);
    }

    #[test]
    fn test_migration_same_subnet_v6() {
        let initial = addr_v6([0x2001, 0x0db8, 0x85a3, 0, 0, 0, 0, 1], 443);
        let mut state = MigrationState::new(initial);

        let new_addr = addr_v6([0x2001, 0x0db8, 0x85a3, 0x1, 0, 0, 0, 2], 4433);
        assert!(state.validate_address(new_addr));
        assert_eq!(state.current_address(), new_addr);
    }

    #[test]
    fn test_migration_distinct_subnets_accepted() {
        let initial = addr_v4(192, 168, 1, 100, 8080);
        let mut state = MigrationState::new(initial);

        let new_addr = addr_v4(10, 0, 0, 50, 9090);
        assert!(state.validate_address(new_addr));

        let another_addr = addr_v4(172, 16, 0, 1, 7070);
        assert!(state.validate_address(another_addr));
    }

    #[test]
    fn test_migration_rate_limit_rejects_quick_return() {
        let initial = addr_v4(192, 168, 1, 100, 8080);
        let mut state = MigrationState::new(initial);

        let first = addr_v4(10, 0, 0, 50, 9090);
        let second = addr_v4(172, 16, 0, 1, 7070);
        let back = addr_v4(10, 0, 0, 51, 9091);

        assert!(state.validate_address(first));
        assert!(state.validate_address(second));

        // 10.0.0.0/24 accepted a migration moments ago, so re-entering it is rejected
        // and the state is left untouched.
        assert!(!state.validate_address(back));
        assert_eq!(state.current_address(), second);
        assert!(!state.is_validated(back));
    }

    #[test]
    fn test_subnet_key_v4() {
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));
        let key = subnet_key(&ip);
        assert_eq!(key, vec![192, 168, 1]);
    }

    #[test]
    fn test_subnet_key_v6() {
        let ip = IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0, 0, 0, 0, 1));
        let key = subnet_key(&ip);
        assert_eq!(key.len(), 6);
        assert_eq!(&key[0..2], &[0x20, 0x01]);
        assert_eq!(&key[2..4], &[0x0d, 0xb8]);
        assert_eq!(&key[4..6], &[0x85, 0xa3]);
    }

    #[test]
    fn test_cleanup() {
        let initial = addr_v4(192, 168, 1, 100, 8080);
        let mut state = MigrationState::new(initial);

        let other = addr_v4(10, 0, 0, 50, 9090);
        state.on_receive(other, 100);
        assert!(state.can_send(other, 300));

        // A zero max_age expires every entry deterministically, independent of
        // the platform's clock resolution.
        state.cleanup(Duration::ZERO);

        assert!(!state.can_send(other, 1), "stale unvalidated entry should be dropped");
        assert!(state.is_validated(initial));
        assert_eq!(state.current_address(), initial);
    }
}