1use anti_common::{calculate_checksum, icmp, PingConfig, PingError, PingReply, PingResult};
7use bytes::{BufMut, BytesMut};
8use socket2::{Domain, Protocol, Socket, Type};
9use std::io::Read;
10use std::net::{IpAddr, Ipv4Addr, SocketAddr};
11use std::os::unix::io::{AsRawFd, RawFd};
12use std::time::{Duration, Instant};
13
/// An ICMP message as handled by this module, without any IP header.
///
/// The first five fields form the fixed 8-byte ICMP header (multi-byte fields are
/// serialized big-endian by `to_bytes`); `data` holds every byte after that header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IcmpPacket {
    /// ICMP message type (e.g. `icmp::ECHO_REQUEST` or `icmp::ECHO_REPLY`).
    pub icmp_type: u8,
    /// ICMP code; `new_echo_request` always sets 0.
    pub code: u8,
    /// Message checksum, either parsed from the wire or set by `calculate_checksum`.
    pub checksum: u16,
    /// Echo identifier used to pair replies with requests.
    pub identifier: u16,
    /// Echo sequence number.
    pub sequence: u16,
    /// Payload following the 8-byte header.
    pub data: Vec<u8>,
}
30
31impl IcmpPacket {
32 pub fn new_echo_request(identifier: u16, sequence: u16, data_size: usize) -> Self {
34 let data = vec![0x08; data_size.max(8).min(1024)]; Self {
37 icmp_type: icmp::ECHO_REQUEST,
38 code: 0,
39 checksum: 0,
40 identifier,
41 sequence,
42 data,
43 }
44 }
45
46 pub fn from_bytes(data: &[u8]) -> PingResult<Self> {
48 if data.len() < 8 {
49 return Err(PingError::InvalidResponse {
50 reason: "ICMP packet too short".to_string(),
51 });
52 }
53
54 Ok(Self {
55 icmp_type: data[0],
56 code: data[1],
57 checksum: u16::from_be_bytes([data[2], data[3]]),
58 identifier: u16::from_be_bytes([data[4], data[5]]),
59 sequence: u16::from_be_bytes([data[6], data[7]]),
60 data: data[8..].to_vec(),
61 })
62 }
63
64 pub fn to_bytes(&self) -> Vec<u8> {
66 let mut buf = BytesMut::new();
67 buf.put_u8(self.icmp_type);
68 buf.put_u8(self.code);
69 buf.put_u16(self.checksum);
70 buf.put_u16(self.identifier);
71 buf.put_u16(self.sequence);
72 buf.extend_from_slice(&self.data);
73 buf.to_vec()
74 }
75
76 pub fn calculate_checksum(&mut self) {
78 self.checksum = 0;
79 let bytes = self.to_bytes();
80 self.checksum = calculate_checksum(&bytes);
81 }
82
83 pub fn is_echo_reply(&self) -> bool {
85 self.icmp_type == icmp::ECHO_REPLY
86 }
87
88 pub fn matches(&self, identifier: u16, sequence: u16) -> bool {
90 self.identifier == identifier && self.sequence == sequence
91 }
92}
93
94pub struct IcmpSocket {
96 socket: Socket,
97 is_raw: bool,
98}
99
100impl IcmpSocket {
101 pub fn new() -> PingResult<Self> {
106 match Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::ICMPV4)) {
108 Ok(socket) => {
109 socket
110 .set_nonblocking(false)
111 .map_err(|e| PingError::SocketCreation(e.to_string()))?;
112 socket
113 .set_read_timeout(Some(Duration::from_secs(5)))
114 .map_err(|e| PingError::SocketCreation(e.to_string()))?;
115 socket.set_broadcast(true).ok(); Ok(Self {
118 socket,
119 is_raw: false,
120 })
121 }
122 Err(_) => {
123 let socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4))
125 .map_err(|e| {
126 if e.kind() == std::io::ErrorKind::PermissionDenied {
127 PingError::PermissionDenied {
128 context: "ICMP ping requires root privileges. Try running with sudo or use UDP/TCP ping instead.".to_string(),
129 }
130 } else {
131 PingError::SocketCreation(e.to_string())
132 }
133 })?;
134
135 socket
136 .set_nonblocking(false)
137 .map_err(|e| PingError::SocketCreation(e.to_string()))?;
138 socket
139 .set_read_timeout(Some(Duration::from_secs(5)))
140 .map_err(|e| PingError::SocketCreation(e.to_string()))?;
141
142 Ok(Self {
143 socket,
144 is_raw: true,
145 })
146 }
147 }
148 }
149
150 pub fn connect(&self, target: Ipv4Addr) -> PingResult<()> {
152 if self.is_raw {
153 let addr = SocketAddr::new(IpAddr::V4(target), 0);
154 self.socket
155 .connect(&addr.into())
156 .map_err(|e| PingError::SocketCreation(e.to_string()))
157 } else {
158 Ok(())
160 }
161 }
162
163 pub fn send(&self, packet: &IcmpPacket, target: Option<Ipv4Addr>) -> PingResult<usize> {
165 let mut packet = packet.clone();
166 packet.calculate_checksum();
167 let bytes = packet.to_bytes();
168
169 let result = if self.is_raw {
170 self.socket.send(&bytes)
171 } else {
172 let target_addr = target.unwrap_or(Ipv4Addr::new(127, 0, 0, 1));
174 let addr = SocketAddr::new(IpAddr::V4(target_addr), 0);
175 self.socket.send_to(&bytes, &addr.into())
176 };
177
178 result.map_err(|e| PingError::SocketCreation(e.to_string()))
179 }
180
181 pub fn recv(&self, timeout: Duration) -> PingResult<(IcmpPacket, Ipv4Addr, Option<u8>)> {
183 self.socket
184 .set_read_timeout(Some(timeout))
185 .map_err(|e| PingError::SocketCreation(e.to_string()))?;
186
187 let mut buf = [0u8; 1024];
188 let start = Instant::now();
189
190 loop {
191 if start.elapsed() >= timeout {
192 return Err(PingError::Timeout { duration: timeout });
193 }
194
195 let size = if self.is_raw {
196 match (&self.socket).read(&mut buf) {
197 Ok(n) => n,
198 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
199 std::thread::sleep(Duration::from_millis(1));
200 continue;
201 }
202 Err(e) => return Err(PingError::SocketCreation(e.to_string())),
203 }
204 } else {
205 let mut uninit_buffer = [std::mem::MaybeUninit::<u8>::uninit(); 1024];
207 match self.socket.recv_from(&mut uninit_buffer) {
208 Ok((n, _from_addr)) => {
209 for i in 0..n {
211 buf[i] = unsafe { uninit_buffer[i].assume_init() };
212 }
213 n
214 }
215 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
216 std::thread::sleep(Duration::from_millis(1));
217 continue;
218 }
219 Err(e) => return Err(PingError::SocketCreation(e.to_string())),
220 }
221 };
222
223 let icmp_data = if size > 20 {
225 if (buf[0] >> 4) == 4 {
227 let ip_header_len = ((buf[0] & 0x0F) * 4) as usize;
228 if size > ip_header_len {
229 &buf[ip_header_len..size]
230 } else {
231 continue;
232 }
233 } else if size >= 8 {
234 &buf[..size]
236 } else {
237 continue;
238 }
239 } else if size >= 8 {
240 &buf[..size]
242 } else {
243 continue;
244 };
245
246 match IcmpPacket::from_bytes(icmp_data) {
247 Ok(packet) => {
248 let (source_ip, ttl) = if size >= 20 && (buf[0] >> 4) == 4 {
250 (
251 Ipv4Addr::new(buf[12], buf[13], buf[14], buf[15]),
252 Some(buf[8]),
253 )
254 } else {
255 (Ipv4Addr::new(0, 0, 0, 0), None)
256 };
257 return Ok((packet, source_ip, ttl));
258 }
259 Err(_) => continue, }
261 }
262 }
263
264 pub fn is_raw(&self) -> bool {
266 self.is_raw
267 }
268}
269
270impl AsRawFd for IcmpSocket {
271 fn as_raw_fd(&self) -> RawFd {
272 self.socket.as_raw_fd()
273 }
274}
275
276pub struct IcmpPinger {
278 socket: IcmpSocket,
279 config: PingConfig,
280 identifier: u16,
281}
282
283impl IcmpPinger {
284 pub fn new(config: PingConfig) -> PingResult<Self> {
286 let socket = IcmpSocket::new()?;
287 socket.connect(config.target)?;
288
289 let identifier = config.identifier.unwrap_or_else(|| rand::random::<u16>());
290
291 Ok(Self {
292 socket,
293 config,
294 identifier,
295 })
296 }
297
298 pub fn ping(&self, sequence: u16) -> PingResult<PingReply> {
300 let packet = IcmpPacket::new_echo_request(
301 self.identifier,
302 sequence,
303 self.config.packet_size.saturating_sub(8), );
305
306 let start = Instant::now();
307
308 self.socket.send(&packet, Some(self.config.target))?;
310
311 loop {
313 let elapsed = start.elapsed();
314 if elapsed >= self.config.timeout {
315 return Err(PingError::Timeout {
316 duration: self.config.timeout,
317 });
318 }
319
320 let remaining = self.config.timeout - elapsed;
321 match self.socket.recv(remaining) {
322 Ok((reply_packet, source, ttl)) => {
323 if reply_packet.is_echo_reply()
324 && reply_packet.matches(self.identifier, sequence)
325 {
326 let rtt = start.elapsed();
327 return Ok(PingReply {
328 sequence,
329 rtt,
330 bytes_received: reply_packet.to_bytes().len(),
331 from: if source.is_unspecified() {
332 self.config.target
333 } else {
334 source
335 },
336 ttl,
337 });
338 }
339 }
341 Err(PingError::Timeout { .. }) => {
342 return Err(PingError::Timeout {
343 duration: self.config.timeout,
344 });
345 }
346 Err(e) => return Err(e),
347 }
348 }
349 }
350
351 pub fn is_raw(&self) -> bool {
353 self.socket.is_raw()
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_icmp_packet_creation() {
        let packet = IcmpPacket::new_echo_request(12345, 1, 56);
        assert_eq!(packet.icmp_type, icmp::ECHO_REQUEST);
        assert_eq!(packet.code, 0);
        assert_eq!(packet.identifier, 12345);
        assert_eq!(packet.sequence, 1);
        assert_eq!(packet.data.len(), 56);
    }

    #[test]
    fn test_icmp_packet_serialization() {
        let mut packet = IcmpPacket::new_echo_request(12345, 1, 8);
        packet.calculate_checksum();

        let bytes = packet.to_bytes();
        assert!(bytes.len() >= 16);
        let parsed = IcmpPacket::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.icmp_type, packet.icmp_type);
        assert_eq!(parsed.identifier, packet.identifier);
        assert_eq!(parsed.sequence, packet.sequence);
    }

    #[test]
    fn test_packet_matching() {
        let packet = IcmpPacket {
            icmp_type: icmp::ECHO_REPLY,
            code: 0,
            checksum: 0,
            identifier: 12345,
            sequence: 42,
            data: vec![],
        };

        assert!(packet.is_echo_reply());
        assert!(packet.matches(12345, 42));
        assert!(!packet.matches(12345, 41));
        assert!(!packet.matches(12344, 42));
    }

    #[test]
    fn test_checksum_calculation() {
        let mut packet = IcmpPacket::new_echo_request(1, 1, 8);
        packet.calculate_checksum();
        assert_ne!(packet.checksum, 0);
    }

    /// Anything shorter than the 8-byte ICMP header must be rejected, including empty input.
    #[test]
    fn test_from_bytes_rejects_short_input() {
        for len in 0..8 {
            let bytes = vec![0u8; len];
            assert!(matches!(
                IcmpPacket::from_bytes(&bytes),
                Err(PingError::InvalidResponse { .. })
            ));
        }
    }

    /// A full serialize/parse round trip keeps every header field and the payload.
    #[test]
    fn test_roundtrip_preserves_every_field() {
        let mut packet = IcmpPacket::new_echo_request(0xBEEF, 0x1234, 16);
        packet.calculate_checksum();

        let bytes = packet.to_bytes();
        assert_eq!(bytes.len(), 8 + 16);

        let parsed = IcmpPacket::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.icmp_type, packet.icmp_type);
        assert_eq!(parsed.code, packet.code);
        assert_eq!(parsed.checksum, packet.checksum);
        assert_eq!(parsed.identifier, packet.identifier);
        assert_eq!(parsed.sequence, packet.sequence);
        assert_eq!(parsed.data, packet.data);
    }

    /// `new_echo_request` clamps the payload length to 8..=1024 bytes.
    #[test]
    fn test_payload_size_is_clamped() {
        assert_eq!(IcmpPacket::new_echo_request(1, 1, 0).data.len(), 8);
        assert_eq!(IcmpPacket::new_echo_request(1, 1, 4096).data.len(), 1024);
    }

    /// Header fields are written in network (big-endian) byte order.
    #[test]
    fn test_header_fields_are_big_endian() {
        let packet = IcmpPacket {
            icmp_type: 0,
            code: 0,
            checksum: 0x0102,
            identifier: 0x0304,
            sequence: 0x0506,
            data: vec![0xAA],
        };
        assert_eq!(packet.to_bytes(), vec![0, 0, 1, 2, 3, 4, 5, 6, 0xAA]);
    }
}