1use crate::raw_socket::{RawSocketHandle, RawSocketProtocol};
2use anyhow::{anyhow, Result};
3use etherparse::{
4 IcmpEchoHeader, Icmpv4Header, Icmpv4Slice, Icmpv4Type, Icmpv6Header, Icmpv6Slice, Icmpv6Type,
5 IpNumber, NetSlice, SlicedPacket,
6};
7use log::warn;
8use std::{
9 collections::HashMap,
10 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
11 os::fd::{FromRawFd, IntoRawFd},
12 sync::Arc,
13 time::Duration,
14};
15use tokio::{
16 net::UdpSocket,
17 sync::{oneshot, Mutex},
18 task::JoinHandle,
19 time::timeout,
20};
21
/// Which ICMP flavour an [`IcmpClient`] speaks: ICMPv4 (IPv4) or ICMPv6 (IPv6).
///
/// Derives the standard comparison/hash traits so callers can clone the value,
/// compare protocols, or use them as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum IcmpProtocol {
    /// ICMP over IPv4.
    Icmpv4,
    /// ICMPv6 over IPv6.
    Icmpv6,
}
27
28impl IcmpProtocol {
29 pub fn to_socket_protocol(&self) -> RawSocketProtocol {
30 match self {
31 IcmpProtocol::Icmpv4 => RawSocketProtocol::Icmpv4,
32 IcmpProtocol::Icmpv6 => RawSocketProtocol::Icmpv6,
33 }
34 }
35}
36
/// Key under which an in-flight echo request waits for its reply.
///
/// Fields, in order: the peer address, the echo identifier (always `Some` in
/// this file), and the echo sequence number. The receive task builds the same
/// key from each incoming echo reply to find the matching waiter.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct IcmpHandlerToken(IpAddr, Option<u16>, u16);
39
40#[derive(Debug)]
41pub enum IcmpReply {
42 Icmpv4 {
43 header: Icmpv4Header,
44 echo: IcmpEchoHeader,
45 payload: Vec<u8>,
46 },
47
48 Icmpv6 {
49 header: Icmpv6Header,
50 echo: IcmpEchoHeader,
51 payload: Vec<u8>,
52 },
53}
54
55type IcmpHandlerMap = Arc<Mutex<HashMap<IcmpHandlerToken, oneshot::Sender<IcmpReply>>>>;
56
57#[derive(Clone)]
58pub struct IcmpClient {
59 socket: Arc<UdpSocket>,
60 handlers: IcmpHandlerMap,
61 task: Arc<JoinHandle<Result<()>>>,
62}
63
64impl IcmpClient {
65 pub fn new(protocol: IcmpProtocol) -> Result<IcmpClient> {
66 let handle = RawSocketHandle::new(protocol.to_socket_protocol())?;
67 let socket = unsafe { std::net::UdpSocket::from_raw_fd(handle.into_raw_fd()) };
68 let socket: Arc<UdpSocket> = Arc::new(socket.try_into()?);
69 let handlers = Arc::new(Mutex::new(HashMap::new()));
70 let task = Arc::new(tokio::task::spawn(IcmpClient::process(
71 protocol,
72 socket.clone(),
73 handlers.clone(),
74 )));
75 Ok(IcmpClient {
76 socket,
77 handlers,
78 task,
79 })
80 }
81
82 async fn process(
83 protocol: IcmpProtocol,
84 socket: Arc<UdpSocket>,
85 handlers: IcmpHandlerMap,
86 ) -> Result<()> {
87 let mut buffer = vec![0u8; 2048];
88 loop {
89 let (size, addr) = socket.recv_from(&mut buffer).await?;
90 let packet = &buffer[0..size];
91
92 let (token, reply) = match protocol {
93 IcmpProtocol::Icmpv4 => {
94 let sliced = match SlicedPacket::from_ip(packet) {
95 Ok(sliced) => sliced,
96 Err(error) => {
97 warn!("received icmp packet but failed to parse it: {}", error);
98 continue;
99 }
100 };
101
102 let Some(NetSlice::Ipv4(ipv4)) = sliced.net else {
103 continue;
104 };
105
106 if ipv4.header().protocol() != IpNumber::ICMP {
107 continue;
108 }
109
110 let Ok(icmpv4) = Icmpv4Slice::from_slice(ipv4.payload().payload) else {
111 continue;
112 };
113
114 let Icmpv4Type::EchoReply(echo) = icmpv4.header().icmp_type else {
115 continue;
116 };
117
118 let token = IcmpHandlerToken(
119 IpAddr::V4(ipv4.header().source_addr()),
120 Some(echo.id),
121 echo.seq,
122 );
123 let reply = IcmpReply::Icmpv4 {
124 header: icmpv4.header(),
125 echo,
126 payload: icmpv4.payload().to_vec(),
127 };
128 (token, reply)
129 }
130
131 IcmpProtocol::Icmpv6 => {
132 let Ok(icmpv6) = Icmpv6Slice::from_slice(packet) else {
133 continue;
134 };
135
136 let Icmpv6Type::EchoReply(echo) = icmpv6.header().icmp_type else {
137 continue;
138 };
139
140 let SocketAddr::V6(addr) = addr else {
141 continue;
142 };
143
144 let token = IcmpHandlerToken(IpAddr::V6(*addr.ip()), Some(echo.id), echo.seq);
145
146 let reply = IcmpReply::Icmpv6 {
147 header: icmpv6.header(),
148 echo,
149 payload: icmpv6.payload().to_vec(),
150 };
151 (token, reply)
152 }
153 };
154
155 if let Some(sender) = handlers.lock().await.remove(&token) {
156 let _ = sender.send(reply);
157 }
158 }
159 }
160
161 async fn add_handler(&self, token: IcmpHandlerToken) -> Result<oneshot::Receiver<IcmpReply>> {
162 let (tx, rx) = oneshot::channel();
163 if self
164 .handlers
165 .lock()
166 .await
167 .insert(token.clone(), tx)
168 .is_some()
169 {
170 return Err(anyhow!("duplicate icmp request: {:?}", token));
171 }
172 Ok(rx)
173 }
174
175 async fn remove_handler(&self, token: IcmpHandlerToken) -> Result<()> {
176 self.handlers.lock().await.remove(&token);
177 Ok(())
178 }
179
180 pub async fn ping4(
181 &self,
182 addr: Ipv4Addr,
183 id: u16,
184 seq: u16,
185 payload: &[u8],
186 deadline: Duration,
187 ) -> Result<Option<IcmpReply>> {
188 let token = IcmpHandlerToken(IpAddr::V4(addr), Some(id), seq);
189 let rx = self.add_handler(token.clone()).await?;
190
191 let echo = IcmpEchoHeader { id, seq };
192 let mut header = Icmpv4Header::new(Icmpv4Type::EchoRequest(echo));
193 header.update_checksum(payload);
194 let mut buffer: Vec<u8> = Vec::new();
195 header.write(&mut buffer)?;
196 buffer.extend_from_slice(payload);
197
198 self.socket
199 .send_to(&buffer, SocketAddr::V4(SocketAddrV4::new(addr, 0)))
200 .await?;
201
202 let result = timeout(deadline, rx).await;
203 self.remove_handler(token).await?;
204 let reply = match result {
205 Ok(Ok(packet)) => Some(packet),
206 Ok(Err(err)) => return Err(anyhow!("failed to wait for icmp packet: {}", err)),
207 Err(_) => None,
208 };
209 Ok(reply)
210 }
211
212 pub async fn ping6(
213 &self,
214 addr: Ipv6Addr,
215 id: u16,
216 seq: u16,
217 payload: &[u8],
218 deadline: Duration,
219 ) -> Result<Option<IcmpReply>> {
220 let token = IcmpHandlerToken(IpAddr::V6(addr), Some(id), seq);
221 let rx = self.add_handler(token.clone()).await?;
222
223 let echo = IcmpEchoHeader { id, seq };
224 let header = Icmpv6Header::new(Icmpv6Type::EchoRequest(echo));
225 let mut buffer: Vec<u8> = Vec::new();
226 header.write(&mut buffer)?;
227 buffer.extend_from_slice(payload);
228
229 self.socket
230 .send_to(&buffer, SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)))
231 .await?;
232
233 let result = timeout(deadline, rx).await;
234 self.remove_handler(token).await?;
235 let reply = match result {
236 Ok(Ok(packet)) => Some(packet),
237 Ok(Err(err)) => return Err(anyhow!("failed to wait for icmp packet: {}", err)),
238 Err(_) => None,
239 };
240 Ok(reply)
241 }
242}
243
244impl Drop for IcmpClient {
245 fn drop(&mut self) {
246 if Arc::strong_count(&self.task) <= 1 {
247 self.task.abort();
248 }
249 }
250}