async_arp/
client.rs

1use afpacket::tokio::RawPacketStream;
2use pnet::{
3    packet::{
4        arp::{Arp, ArpHardwareTypes, ArpOperations, MutableArpPacket},
5        ethernet::{EtherTypes, MutableEthernetPacket},
6        Packet,
7    },
8    util::MacAddr,
9};
10
11use std::{future::Future, net::Ipv4Addr, sync::Arc, time::Duration};
12use tokio::task::JoinHandle;
13use tokio::{
14    io::AsyncWriteExt,
15    sync::{Mutex, Notify},
16};
17
18use tokio_util::sync::CancellationToken;
19
20use crate::{caching::ArpCache, probe::ProbeInput, request::RequestOutcome};
21use crate::{constants::IP_V4_LEN, notification::NotificationHandler};
22use crate::{
23    constants::{ARP_PACK_LEN, ETH_PACK_LEN, MAC_ADDR_LEN},
24    request::RequestInput,
25};
26use crate::{
27    error::{Error, Result},
28    probe::ProbeOutcome,
29};
30use crate::{probe::ProbeStatus, response::Listener};
31
32/// A struct responsible for performing batch requests and probes with retry logic.
33///
34/// This struct abstracts the client and provides methods to perform multiple
35/// probes or requests with retry capabilities. The number of retries can be
36/// configured during initialization.
37#[derive(Debug)]
38pub struct ClientSpinner {
39    client: Client,
40    n_retries: usize,
41}
42
43impl ClientSpinner {
44    /// Creates a new instance of `ClientSpinner` with the given `Client`.
45    ///
46    /// This constructor initializes a `ClientSpinner` with a client and sets the
47    /// number of retries to `0` (no retries).
48    pub fn new(client: Client) -> Self {
49        Self {
50            client,
51            n_retries: 0,
52        }
53    }
54
55    /// Sets the number of retries for subsequent probes and requests.
56    pub fn with_retries(mut self, n_retires: usize) -> Self {
57        self.n_retries = n_retires;
58        self
59    }
60
61    /// Performs a batch of probes asynchronously with retries.
62    ///
63    /// This method takes an array of `ProbeInput` and attempts to probe each one.
64    /// If a probe fails, it will retry up to `n_retries` times before returning the
65    /// results.
66    pub async fn probe_batch(&self, inputs: &[ProbeInput]) -> Vec<Result<ProbeOutcome>> {
67        let futures_producer = || {
68            inputs
69                .iter()
70                .map(|input| async { self.client.probe(*input).await })
71        };
72        Self::handle_retries(self.n_retries, futures_producer).await
73    }
74
75    /// Performs a batch of requests asynchronously with retries.
76    ///
77    /// This method takes an array of `RequestInput` and attempts to request each one.
78    /// If a request fails, it will retry up to `n_retries` times before returning the
79    /// results.
80    pub async fn request_batch(&self, inputs: &[RequestInput]) -> Vec<Result<RequestOutcome>> {
81        let futures_producer = || {
82            inputs
83                .iter()
84                .map(|input| async { self.client.request(*input).await })
85        };
86        Self::handle_retries(self.n_retries, futures_producer).await
87    }
88
89    async fn handle_retries<F, I, Fut, T>(n_retries: usize, futures_producer: F) -> Vec<Result<T>>
90    where
91        F: Fn() -> I,
92        Fut: Future<Output = Result<T>>,
93        I: Iterator<Item = Fut>,
94    {
95        for _ in 0..n_retries {
96            futures::future::join_all(futures_producer()).await;
97        }
98        futures::future::join_all(futures_producer()).await
99    }
100}
101
/// Settings consumed by [`Client::new`].
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Name of the network interface the client's packet stream is bound to.
    pub interface_name: String,
    /// How long a request waits for a response before it is reported as timed out.
    pub response_timeout: Duration,
    /// Timeout handed to the ARP cache when the client is created
    /// (presumably how long cached responses stay valid).
    pub cache_timeout: Duration,
}

/// Builder for [`ClientConfig`].
///
/// Only the interface name is mandatory. Both timeouts start at their defaults
/// (1 second for responses, 60 seconds for the cache) and can be overridden
/// individually before calling [`ClientConfigBuilder::build`].
#[derive(Debug, Clone)]
pub struct ClientConfigBuilder {
    interface_name: String,
    // The timeouts always hold a value, so they are stored directly instead of
    // as `Option`s that `build` would have to unwrap.
    response_timeout: Duration,
    cache_timeout: Duration,
}

impl ClientConfigBuilder {
    /// Response timeout used when none is set explicitly.
    const DEFAULT_RESPONSE_TIMEOUT: Duration = Duration::from_secs(1);
    /// Cache timeout used when none is set explicitly.
    const DEFAULT_CACHE_TIMEOUT: Duration = Duration::from_secs(60);

    /// Starts a builder for the interface called `interface_name`, with both
    /// timeouts at their defaults.
    pub fn new(interface_name: &str) -> Self {
        Self {
            interface_name: interface_name.into(),
            response_timeout: Self::DEFAULT_RESPONSE_TIMEOUT,
            cache_timeout: Self::DEFAULT_CACHE_TIMEOUT,
        }
    }

    /// Replaces the response timeout with `timeout`.
    pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
        self.response_timeout = timeout;
        self
    }

    /// Replaces the cache timeout with `timeout`.
    pub fn with_cache_timeout(mut self, timeout: Duration) -> Self {
        self.cache_timeout = timeout;
        self
    }

    /// Consumes the builder and produces the finished [`ClientConfig`].
    ///
    /// This cannot fail: every field has a value from construction onwards.
    pub fn build(self) -> ClientConfig {
        let Self {
            interface_name,
            response_timeout,
            cache_timeout,
        } = self;
        ClientConfig {
            interface_name,
            response_timeout,
            cache_timeout,
        }
    }
}
143
144/// A client for handling ARP (Address Resolution Protocol) requests and probes.
145///
146/// The `Client` is responsible for sending ARP requests, caching responses,
147/// and handling notifications. It uses a raw packet stream for network communication.
148///
149/// # Example
150/// ```no_run
151/// use async_arp::{Client, ClientConfig};
152/// use std::time::Duration;
153///
154/// let config = ClientConfig {
155///     interface_name: "eth0".to_string(),
156///     response_timeout: Duration::from_secs(2),
157///     cache_timeout: Duration::from_secs(60),
158/// };
159///
160/// let client = Client::new(config).expect("Failed to create ARP client");
161/// ```
162#[derive(Debug)]
163pub struct Client {
164    response_timeout: Duration,
165    stream: Mutex<RawPacketStream>,
166    cache: Arc<ArpCache>,
167
168    notification_handler: Arc<NotificationHandler>,
169    _task_spawner: BackgroundTaskSpawner,
170}
171
172impl Client {
173    /// Creates a new `Client` with the given configuration.
174    ///
175    /// This function initializes a raw packet stream, binds it to the specified
176    /// network interface, and sets up caching and background tasks for listening
177    /// to ARP responses.
178    ///
179    /// # Errors
180    /// Returns an error if the packet stream cannot be created or if binding to
181    /// the specified network interface fails.
182    pub fn new(config: ClientConfig) -> Result<Self> {
183        let mut stream = RawPacketStream::new().map_err(|err| {
184            Error::Opaque(format!("failed to create packet stream, reason: {}", err).into())
185        })?;
186        stream.bind(&config.interface_name).map_err(|err| {
187            Error::Opaque(format!("failed to bind interface to stream, reason {}", err).into())
188        })?;
189
190        let notification_handler = Arc::new(NotificationHandler::new());
191        let cache = Arc::new(ArpCache::new(
192            config.cache_timeout,
193            Arc::clone(&notification_handler),
194        ));
195
196        let mut task_spawner = BackgroundTaskSpawner::new();
197        task_spawner.spawn(Listener::new(stream.clone(), Arc::clone(&cache)));
198
199        Ok(Self {
200            response_timeout: config.response_timeout,
201            stream: Mutex::new(stream),
202            cache,
203            notification_handler,
204            _task_spawner: task_spawner,
205        })
206    }
207
208    /// Probes for the presence of a device at the given IP address.
209    ///
210    /// This function sends an ARP request to determine whether an IP address
211    /// is occupied. It returns a `ProbeOutcome`, indicating whether the address
212    /// is in use.
213    ///
214    /// # Example
215    /// ```no_run
216    /// use async_arp::{Client, ClientConfigBuilder, ProbeStatus, ProbeInputBuilder};
217    /// use pnet::util::MacAddr;
218    /// use std::net::Ipv4Addr;
219    ///
220    /// let probe_input = ProbeInputBuilder::new()
221    ///     .with_sender_mac(MacAddr::new(0x00, 0x1A, 0x2B, 0x3C, 0x4D, 0x5E))
222    ///     .with_target_ip(Ipv4Addr::new(192, 168, 1, 1))
223    ///     .build()
224    ///     .expect("Failed to build probe input");
225    /// tokio_test::block_on(async {
226    ///     let client = Client::new(ClientConfigBuilder::new("eth0").build()).unwrap();
227    ///     let outcome = client.probe(probe_input).await.unwrap();
228    ///     match outcome.status {
229    ///         ProbeStatus::Occupied => println!("IP is in use"),
230    ///         ProbeStatus::Free => println!("IP is available"),
231    /// }
232    /// })
233    /// ```
234    ///
235    /// # Errors
236    /// Returns an error if sending the ARP request fails.
237    pub async fn probe(&self, input: ProbeInput) -> Result<ProbeOutcome> {
238        let input = RequestInput {
239            sender_ip: Ipv4Addr::UNSPECIFIED,
240            sender_mac: input.sender_mac,
241            target_ip: input.target_ip,
242            target_mac: MacAddr::zero(),
243        };
244
245        match self.request(input).await {
246            Ok(_) => Ok(ProbeOutcome::new(ProbeStatus::Occupied, input.target_ip)),
247            Err(Error::ResponseTimeout) => {
248                Ok(ProbeOutcome::new(ProbeStatus::Free, input.target_ip))
249            }
250            Err(err) => Err(err),
251        }
252    }
253
254    /// Sends an ARP request and waits for a response.
255    ///
256    /// If the requested IP is already cached, the cached response is returned immediately.
257    /// Otherwise, a new ARP request is sent, and the client waits for a response within
258    /// the configured timeout period.
259    ///
260    /// # Example
261    /// ```no_run
262    /// use pnet::util::MacAddr;
263    /// use std::net::Ipv4Addr;
264    /// use async_arp::{Client, ClientConfigBuilder, RequestInputBuilder};
265    ///
266    /// let request_input = RequestInputBuilder::new()
267    ///     .with_sender_ip(Ipv4Addr::new(192, 168, 1, 100))
268    ///     .with_sender_mac(MacAddr::new(0x00, 0x1A, 0x2B, 0x3C, 0x4D, 0x5E))
269    ///     .with_target_ip(Ipv4Addr::new(192, 168, 1, 1))
270    ///     .with_target_mac(MacAddr::zero())
271    ///     .build()
272    ///     .expect("Failed to build request input");
273    /// tokio_test::block_on(async {
274    ///     let client = Client::new(ClientConfigBuilder::new("eth0").build()).unwrap();
275    ///     let outcome = client.request(request_input).await.unwrap();
276    ///
277    ///     println!("Received response: {:?}", outcome);
278    /// })
279    /// ```
280    ///
281    /// # Errors
282    /// Returns an error if sending the request fails or if no response is received
283    /// within the timeout period.
284    pub async fn request(&self, input: RequestInput) -> Result<RequestOutcome> {
285        if let Some(cached) = self.cache.get(&input.target_ip) {
286            return Ok(RequestOutcome::new(input, cached));
287        }
288        let mut eth_buf = [0; ETH_PACK_LEN];
289        Self::fill_packet_buf(&mut eth_buf, &input);
290        let notifier = self
291            .notification_handler
292            .register_notifier(input.target_ip)
293            .await;
294        self.stream
295            .lock()
296            .await
297            .write_all(&eth_buf)
298            .await
299            .map_err(|err| {
300                Error::Opaque(format!("failed to send request, reason: {}", err).into())
301            })?;
302
303        let response = tokio::time::timeout(
304            self.response_timeout,
305            self.await_response(notifier, &input.target_ip),
306        )
307        .await
308        .map_err(|_| Error::ResponseTimeout)?;
309        Ok(RequestOutcome::new(input, response))
310    }
311
312    fn fill_packet_buf(eth_buf: &mut [u8], input: &RequestInput) {
313        let mut eth_packet = MutableEthernetPacket::new(eth_buf).unwrap();
314        eth_packet.set_destination(MacAddr::broadcast());
315        eth_packet.set_source(input.sender_mac);
316        eth_packet.set_ethertype(EtherTypes::Arp);
317
318        let mut arp_buf = [0; ARP_PACK_LEN];
319        let mut arp_packet = MutableArpPacket::new(&mut arp_buf).unwrap();
320        arp_packet.set_hardware_type(ArpHardwareTypes::Ethernet);
321        arp_packet.set_protocol_type(EtherTypes::Ipv4);
322        arp_packet.set_hw_addr_len(MAC_ADDR_LEN);
323        arp_packet.set_proto_addr_len(IP_V4_LEN);
324        arp_packet.set_operation(ArpOperations::Request);
325        arp_packet.set_sender_hw_addr(input.sender_mac);
326        arp_packet.set_sender_proto_addr(input.sender_ip);
327        arp_packet.set_target_hw_addr(input.target_mac);
328        arp_packet.set_target_proto_addr(input.target_ip);
329
330        eth_packet.set_payload(arp_packet.packet());
331    }
332
333    async fn await_response(&self, notifier: Arc<Notify>, target_ip: &Ipv4Addr) -> Arp {
334        loop {
335            notifier.notified().await;
336            {
337                if let Some(packet) = self.cache.get(target_ip) {
338                    return packet;
339                }
340            }
341        }
342    }
343}
344
345#[derive(Debug)]
346struct BackgroundTaskSpawner {
347    token: CancellationToken,
348    handle: Option<JoinHandle<()>>,
349}
350
351impl BackgroundTaskSpawner {
352    fn new() -> Self {
353        Self {
354            token: CancellationToken::new(),
355            handle: None,
356        }
357    }
358
359    fn spawn(&mut self, mut listener: Listener) {
360        let token = self.token.clone();
361        let handle = tokio::task::spawn(async move {
362            tokio::select! {
363                _ = listener.listen() => {
364
365                },
366                _ = token.cancelled() => {
367                }
368            }
369        });
370        self.handle = Some(handle);
371    }
372}
373
374impl Drop for BackgroundTaskSpawner {
375    fn drop(&mut self) {
376        if self.handle.is_some() {
377            self.token.cancel();
378        }
379    }
380}
381
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use crate::{
        client::{Client, ClientConfigBuilder, ProbeStatus},
        constants::{ARP_PACK_LEN, ETH_PACK_LEN, IP_V4_LEN, MAC_ADDR_LEN},
        probe::ProbeInputBuilder,
        response::parse_arp_packet,
    };
    use afpacket::tokio::RawPacketStream;
    use ipnet::Ipv4Net;
    use pnet::{
        datalink,
        packet::{
            arp::{ArpHardwareTypes, ArpOperations, MutableArpPacket},
            ethernet::{EtherTypes, MutableEthernetPacket},
            Packet,
        },
        util::MacAddr,
    };
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
    type Result<T> = std::result::Result<T, Error>;

    /// Minimal ARP responder used by the tests: it answers every request whose
    /// target address lies inside `net` with the MAC of the bound interface.
    struct Server {
        mac: MacAddr,
        stream: RawPacketStream,
        net: Ipv4Net,
    }

    impl Server {
        /// Binds a raw packet stream to `interface_name` and remembers the
        /// interface's MAC address.
        ///
        /// Fails if the interface does not exist or the stream cannot be
        /// created or bound; panics if the interface has no MAC address.
        fn new(interface_name: &str, net: Ipv4Net) -> Result<Self> {
            let interface = datalink::interfaces()
                .into_iter()
                .find(|candidate| candidate.name == interface_name)
                .ok_or_else(|| format!("interface {} not found", interface_name))?;

            let mut stream = RawPacketStream::new()?;
            stream.bind(interface_name)?;

            let mac = interface.mac.unwrap();
            Ok(Self { mac, stream, net })
        }

        /// Reads frames until reading fails, replying to every ARP request that
        /// targets an address inside `net`.
        async fn serve(&mut self) -> Result<()> {
            let mut frame_in = [0; ETH_PACK_LEN];
            let mut arp_out = [0; ARP_PACK_LEN];
            let mut frame_out = [0; ETH_PACK_LEN];

            while let Ok(frame_len) = self.stream.read(&mut frame_in).await {
                // Anything that does not parse as an ARP packet is ignored.
                let request = match parse_arp_packet(&frame_in[..frame_len]) {
                    Ok(request) => request,
                    Err(_) => continue,
                };
                if !self.net.contains(&request.target_proto_addr) {
                    continue;
                }

                let mut arp_reply = MutableArpPacket::new(&mut arp_out).unwrap();
                arp_reply.set_hardware_type(ArpHardwareTypes::Ethernet);
                arp_reply.set_protocol_type(EtherTypes::Ipv4);
                arp_reply.set_hw_addr_len(MAC_ADDR_LEN);
                arp_reply.set_proto_addr_len(IP_V4_LEN);
                arp_reply.set_operation(ArpOperations::Reply);

                // The reply claims the requested address for this server's MAC
                // and is addressed back to whoever asked.
                arp_reply.set_sender_proto_addr(request.target_proto_addr);
                arp_reply.set_sender_hw_addr(self.mac);
                arp_reply.set_target_proto_addr(request.sender_proto_addr);
                arp_reply.set_target_hw_addr(request.sender_hw_addr);

                let mut eth_reply = MutableEthernetPacket::new(&mut frame_out)
                    .ok_or("failed to create Ethernet frame")?;
                eth_reply.set_ethertype(EtherTypes::Arp);
                eth_reply.set_destination(request.sender_hw_addr);
                eth_reply.set_source(self.mac);
                eth_reply.set_payload(arp_reply.packet());

                self.stream.write_all(eth_reply.packet()).await?;
            }
            Ok(())
        }
    }

    /// Probes `10.1.1.<octet>` for every octet in `last_octets` concurrently and
    /// asserts that each probe reports `expected`.
    async fn assert_probe_status(
        client: &Client,
        sender_mac: MacAddr,
        last_octets: impl Iterator<Item = u8>,
        expected: ProbeStatus,
    ) {
        let probes = last_octets.map(|octet| async move {
            let input = ProbeInputBuilder::new()
                .with_sender_mac(sender_mac)
                .with_target_ip(Ipv4Addr::new(10, 1, 1, octet))
                .build()
                .unwrap();
            client.probe(input).await.unwrap()
        });

        for outcome in futures::future::join_all(probes).await {
            assert_eq!(outcome.status, expected);
        }
    }

    /// The server answers for `10.1.1.0/25`, so the lower half of `10.1.1.0/24`
    /// must be reported as occupied and the upper half as free.
    #[tokio::test]
    async fn test_detection() {
        const INTERFACE_NAME: &str = "dummy0";

        tokio::spawn(async move {
            let net = Ipv4Net::new(Ipv4Addr::new(10, 1, 1, 0), 25).unwrap();
            let mut server = Server::new(INTERFACE_NAME, net).unwrap();
            server.serve().await.unwrap();
        });

        let client = Client::new(ClientConfigBuilder::new(INTERFACE_NAME).build()).unwrap();

        let interface = datalink::interfaces()
            .into_iter()
            .find(|candidate| candidate.name == INTERFACE_NAME)
            .ok_or_else(|| format!("interface {} not found", INTERFACE_NAME))
            .unwrap();
        let sender_mac = interface
            .mac
            .ok_or("interface does not have mac address")
            .unwrap();

        assert_probe_status(&client, sender_mac, 0..=127u8, ProbeStatus::Occupied).await;
        assert_probe_status(&client, sender_mac, 128..=255u8, ProbeStatus::Free).await;
    }
}