kratanet/
icmp.rs

1use crate::raw_socket::{RawSocketHandle, RawSocketProtocol};
2use anyhow::{anyhow, Result};
3use etherparse::{
4    IcmpEchoHeader, Icmpv4Header, Icmpv4Slice, Icmpv4Type, Icmpv6Header, Icmpv6Slice, Icmpv6Type,
5    IpNumber, NetSlice, SlicedPacket,
6};
7use log::warn;
8use std::{
9    collections::HashMap,
10    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
11    os::fd::{FromRawFd, IntoRawFd},
12    sync::Arc,
13    time::Duration,
14};
15use tokio::{
16    net::UdpSocket,
17    sync::{oneshot, Mutex},
18    task::JoinHandle,
19    time::timeout,
20};
21
/// ICMP version an [`IcmpClient`] speaks; selects both the raw socket kind and how
/// incoming packets are parsed.
#[derive(Debug)]
pub enum IcmpProtocol {
    /// ICMP over IPv4.
    Icmpv4,
    /// ICMPv6 over IPv6.
    Icmpv6,
}
27
28impl IcmpProtocol {
29    pub fn to_socket_protocol(&self) -> RawSocketProtocol {
30        match self {
31            IcmpProtocol::Icmpv4 => RawSocketProtocol::Icmpv4,
32            IcmpProtocol::Icmpv6 => RawSocketProtocol::Icmpv6,
33        }
34    }
35}
36
/// Lookup key for an outstanding echo request: (peer address, echo identifier, echo sequence).
///
/// Replies are routed to waiting callers by building the same key from the received packet.
/// The identifier is an `Option`; every token constructed in this file uses `Some(id)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct IcmpHandlerToken(
    // Address the request was sent to / the reply came from.
    IpAddr,
    // Echo identifier.
    Option<u16>,
    // Echo sequence number.
    u16,
);
39
40#[derive(Debug)]
41pub enum IcmpReply {
42    Icmpv4 {
43        header: Icmpv4Header,
44        echo: IcmpEchoHeader,
45        payload: Vec<u8>,
46    },
47
48    Icmpv6 {
49        header: Icmpv6Header,
50        echo: IcmpEchoHeader,
51        payload: Vec<u8>,
52    },
53}
54
55type IcmpHandlerMap = Arc<Mutex<HashMap<IcmpHandlerToken, oneshot::Sender<IcmpReply>>>>;
56
57#[derive(Clone)]
58pub struct IcmpClient {
59    socket: Arc<UdpSocket>,
60    handlers: IcmpHandlerMap,
61    task: Arc<JoinHandle<Result<()>>>,
62}
63
64impl IcmpClient {
65    pub fn new(protocol: IcmpProtocol) -> Result<IcmpClient> {
66        let handle = RawSocketHandle::new(protocol.to_socket_protocol())?;
67        let socket = unsafe { std::net::UdpSocket::from_raw_fd(handle.into_raw_fd()) };
68        let socket: Arc<UdpSocket> = Arc::new(socket.try_into()?);
69        let handlers = Arc::new(Mutex::new(HashMap::new()));
70        let task = Arc::new(tokio::task::spawn(IcmpClient::process(
71            protocol,
72            socket.clone(),
73            handlers.clone(),
74        )));
75        Ok(IcmpClient {
76            socket,
77            handlers,
78            task,
79        })
80    }
81
82    async fn process(
83        protocol: IcmpProtocol,
84        socket: Arc<UdpSocket>,
85        handlers: IcmpHandlerMap,
86    ) -> Result<()> {
87        let mut buffer = vec![0u8; 2048];
88        loop {
89            let (size, addr) = socket.recv_from(&mut buffer).await?;
90            let packet = &buffer[0..size];
91
92            let (token, reply) = match protocol {
93                IcmpProtocol::Icmpv4 => {
94                    let sliced = match SlicedPacket::from_ip(packet) {
95                        Ok(sliced) => sliced,
96                        Err(error) => {
97                            warn!("received icmp packet but failed to parse it: {}", error);
98                            continue;
99                        }
100                    };
101
102                    let Some(NetSlice::Ipv4(ipv4)) = sliced.net else {
103                        continue;
104                    };
105
106                    if ipv4.header().protocol() != IpNumber::ICMP {
107                        continue;
108                    }
109
110                    let Ok(icmpv4) = Icmpv4Slice::from_slice(ipv4.payload().payload) else {
111                        continue;
112                    };
113
114                    let Icmpv4Type::EchoReply(echo) = icmpv4.header().icmp_type else {
115                        continue;
116                    };
117
118                    let token = IcmpHandlerToken(
119                        IpAddr::V4(ipv4.header().source_addr()),
120                        Some(echo.id),
121                        echo.seq,
122                    );
123                    let reply = IcmpReply::Icmpv4 {
124                        header: icmpv4.header(),
125                        echo,
126                        payload: icmpv4.payload().to_vec(),
127                    };
128                    (token, reply)
129                }
130
131                IcmpProtocol::Icmpv6 => {
132                    let Ok(icmpv6) = Icmpv6Slice::from_slice(packet) else {
133                        continue;
134                    };
135
136                    let Icmpv6Type::EchoReply(echo) = icmpv6.header().icmp_type else {
137                        continue;
138                    };
139
140                    let SocketAddr::V6(addr) = addr else {
141                        continue;
142                    };
143
144                    let token = IcmpHandlerToken(IpAddr::V6(*addr.ip()), Some(echo.id), echo.seq);
145
146                    let reply = IcmpReply::Icmpv6 {
147                        header: icmpv6.header(),
148                        echo,
149                        payload: icmpv6.payload().to_vec(),
150                    };
151                    (token, reply)
152                }
153            };
154
155            if let Some(sender) = handlers.lock().await.remove(&token) {
156                let _ = sender.send(reply);
157            }
158        }
159    }
160
161    async fn add_handler(&self, token: IcmpHandlerToken) -> Result<oneshot::Receiver<IcmpReply>> {
162        let (tx, rx) = oneshot::channel();
163        if self
164            .handlers
165            .lock()
166            .await
167            .insert(token.clone(), tx)
168            .is_some()
169        {
170            return Err(anyhow!("duplicate icmp request: {:?}", token));
171        }
172        Ok(rx)
173    }
174
175    async fn remove_handler(&self, token: IcmpHandlerToken) -> Result<()> {
176        self.handlers.lock().await.remove(&token);
177        Ok(())
178    }
179
180    pub async fn ping4(
181        &self,
182        addr: Ipv4Addr,
183        id: u16,
184        seq: u16,
185        payload: &[u8],
186        deadline: Duration,
187    ) -> Result<Option<IcmpReply>> {
188        let token = IcmpHandlerToken(IpAddr::V4(addr), Some(id), seq);
189        let rx = self.add_handler(token.clone()).await?;
190
191        let echo = IcmpEchoHeader { id, seq };
192        let mut header = Icmpv4Header::new(Icmpv4Type::EchoRequest(echo));
193        header.update_checksum(payload);
194        let mut buffer: Vec<u8> = Vec::new();
195        header.write(&mut buffer)?;
196        buffer.extend_from_slice(payload);
197
198        self.socket
199            .send_to(&buffer, SocketAddr::V4(SocketAddrV4::new(addr, 0)))
200            .await?;
201
202        let result = timeout(deadline, rx).await;
203        self.remove_handler(token).await?;
204        let reply = match result {
205            Ok(Ok(packet)) => Some(packet),
206            Ok(Err(err)) => return Err(anyhow!("failed to wait for icmp packet: {}", err)),
207            Err(_) => None,
208        };
209        Ok(reply)
210    }
211
212    pub async fn ping6(
213        &self,
214        addr: Ipv6Addr,
215        id: u16,
216        seq: u16,
217        payload: &[u8],
218        deadline: Duration,
219    ) -> Result<Option<IcmpReply>> {
220        let token = IcmpHandlerToken(IpAddr::V6(addr), Some(id), seq);
221        let rx = self.add_handler(token.clone()).await?;
222
223        let echo = IcmpEchoHeader { id, seq };
224        let header = Icmpv6Header::new(Icmpv6Type::EchoRequest(echo));
225        let mut buffer: Vec<u8> = Vec::new();
226        header.write(&mut buffer)?;
227        buffer.extend_from_slice(payload);
228
229        self.socket
230            .send_to(&buffer, SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)))
231            .await?;
232
233        let result = timeout(deadline, rx).await;
234        self.remove_handler(token).await?;
235        let reply = match result {
236            Ok(Ok(packet)) => Some(packet),
237            Ok(Err(err)) => return Err(anyhow!("failed to wait for icmp packet: {}", err)),
238            Err(_) => None,
239        };
240        Ok(reply)
241    }
242}
243
244impl Drop for IcmpClient {
245    fn drop(&mut self) {
246        if Arc::strong_count(&self.task) <= 1 {
247            self.task.abort();
248        }
249    }
250}