async_arp/
client.rs

1use afpacket::tokio::RawPacketStream;
2use pnet::{
3    packet::{
4        arp::{Arp, ArpHardwareTypes, ArpOperations, MutableArpPacket},
5        ethernet::{EtherTypes, MutableEthernetPacket},
6        Packet,
7    },
8    util::MacAddr,
9};
10
11use std::{future::Future, net::Ipv4Addr, sync::Arc, time::Duration};
12use tokio::task::JoinHandle;
13use tokio::{
14    io::AsyncWriteExt,
15    sync::{Mutex, Notify},
16};
17
18use tokio_util::sync::CancellationToken;
19
20use crate::{
21    caching::ArpCache, error::ResponseTimeout, probe::ProbeInput, request::RequestOutcome,
22};
23use crate::{constants::IP_V4_LEN, notification::NotificationHandler};
24use crate::{
25    constants::{ARP_PACK_LEN, ETH_PACK_LEN, MAC_ADDR_LEN},
26    request::RequestInput,
27};
28use crate::{
29    error::{Error, Result},
30    probe::ProbeOutcome,
31};
32use crate::{probe::ProbeStatus, response::Listener};
33
34/// A struct responsible for performing batch requests and probes with retry logic.
35///
36/// This struct abstracts the client and provides methods to perform multiple
37/// probes or requests with retry capabilities. The number of retries can be
38/// configured during initialization.
39#[derive(Debug)]
40pub struct ClientSpinner {
41    client: Client,
42    n_retries: usize,
43}
44
45impl ClientSpinner {
46    /// Creates a new instance of `ClientSpinner` with the given `Client`.
47    ///
48    /// This constructor initializes a `ClientSpinner` with a client and sets the
49    /// number of retries to `0` (no retries).
50    pub fn new(client: Client) -> Self {
51        Self {
52            client,
53            n_retries: 0,
54        }
55    }
56
57    /// Sets the number of retries for subsequent probes and requests.
58    pub fn with_retries(mut self, n_retires: usize) -> Self {
59        self.n_retries = n_retires;
60        self
61    }
62
63    /// Performs a batch of probes asynchronously with retries.
64    ///
65    /// This method takes an array of `ProbeInput` and attempts to probe each one.
66    /// If a probe fails, it will retry up to `n_retries` times before returning the
67    /// results.
68    pub async fn probe_batch(&self, inputs: &[ProbeInput]) -> Result<Vec<ProbeOutcome>> {
69        let futures_producer = || {
70            inputs
71                .iter()
72                .map(|input| async { self.client.probe(*input).await })
73        };
74        Self::handle_retries(self.n_retries, futures_producer).await
75    }
76
77    /// Performs a batch of requests asynchronously with retries.
78    ///
79    /// This method takes an array of `RequestInput` and attempts to request each one.
80    /// If a request fails, it will retry up to `n_retries` times before returning the
81    /// results.
82    pub async fn request_batch(&self, inputs: &[RequestInput]) -> Result<Vec<RequestOutcome>> {
83        let futures_producer = || {
84            inputs
85                .iter()
86                .map(|input| async { self.client.request(*input).await })
87        };
88        Self::handle_retries(self.n_retries, futures_producer).await
89    }
90
91    async fn handle_retries<F, I, Fut, T>(n_retries: usize, futures_producer: F) -> Result<Vec<T>>
92    where
93        F: Fn() -> I,
94        Fut: Future<Output = Result<T>>,
95        I: Iterator<Item = Fut>,
96    {
97        for _ in 0..n_retries {
98            futures::future::try_join_all(futures_producer()).await?;
99        }
100        futures::future::try_join_all(futures_producer()).await
101    }
102}
103
/// Settings used by [`Client::new`] to set up an ARP client.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Name of the network interface the client's raw packet stream is bound to
    /// (e.g. `"eth0"`).
    pub interface_name: String,
    /// How long a request waits for an ARP reply before it is reported as timed out.
    pub response_timeout: Duration,
    /// Timeout handed to the client's ARP cache (presumably how long a cached reply
    /// stays valid — confirm against `ArpCache`).
    pub cache_timeout: Duration,
}

/// Builder for [`ClientConfig`] with sensible default timeouts.
///
/// Defaults: [`Self::DEFAULT_RESPONSE_TIMEOUT`] (1 s) and
/// [`Self::DEFAULT_CACHE_TIMEOUT`] (60 s).
#[derive(Debug, Clone)]
pub struct ClientConfigBuilder {
    interface_name: String,
    response_timeout: Duration,
    cache_timeout: Duration,
}

impl ClientConfigBuilder {
    /// Response timeout used unless overridden with [`Self::with_response_timeout`].
    pub const DEFAULT_RESPONSE_TIMEOUT: Duration = Duration::from_secs(1);
    /// Cache timeout used unless overridden with [`Self::with_cache_timeout`].
    pub const DEFAULT_CACHE_TIMEOUT: Duration = Duration::from_secs(60);

    /// Starts a configuration for `interface_name` with the default timeouts.
    pub fn new(interface_name: &str) -> Self {
        Self {
            interface_name: interface_name.into(),
            response_timeout: Self::DEFAULT_RESPONSE_TIMEOUT,
            cache_timeout: Self::DEFAULT_CACHE_TIMEOUT,
        }
    }

    /// Overrides how long a request waits for a reply.
    #[must_use]
    pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
        self.response_timeout = timeout;
        self
    }

    /// Overrides the timeout handed to the ARP cache.
    #[must_use]
    pub fn with_cache_timeout(mut self, timeout: Duration) -> Self {
        self.cache_timeout = timeout;
        self
    }

    /// Produces the finished [`ClientConfig`].
    #[must_use]
    pub fn build(self) -> ClientConfig {
        ClientConfig {
            interface_name: self.interface_name,
            cache_timeout: self.cache_timeout,
            response_timeout: self.response_timeout,
        }
    }
}
145
146/// A client for handling ARP (Address Resolution Protocol) requests and probes.
147///
148/// The `Client` is responsible for sending ARP requests, caching responses,
149/// and handling notifications. It uses a raw packet stream for network communication.
150///
151/// # Example
152/// ```no_run
153/// use async_arp::{Client, ClientConfig};
154/// use std::time::Duration;
155///
156/// let config = ClientConfig {
157///     interface_name: "eth0".to_string(),
158///     response_timeout: Duration::from_secs(2),
159///     cache_timeout: Duration::from_secs(60),
160/// };
161///
162/// let client = Client::new(config).expect("Failed to create ARP client");
163/// ```
164#[derive(Debug)]
165pub struct Client {
166    response_timeout: Duration,
167    stream: Mutex<RawPacketStream>,
168    cache: Arc<ArpCache>,
169
170    notification_handler: Arc<NotificationHandler>,
171    _task_spawner: BackgroundTaskSpawner,
172}
173
174impl Client {
175    /// Creates a new `Client` with the given configuration.
176    ///
177    /// This function initializes a raw packet stream, binds it to the specified
178    /// network interface, and sets up caching and background tasks for listening
179    /// to ARP responses.
180    ///
181    /// # Errors
182    /// Returns an error if the packet stream cannot be created or if binding to
183    /// the specified network interface fails.
184    pub fn new(config: ClientConfig) -> Result<Self> {
185        let mut stream = RawPacketStream::new().map_err(|err| {
186            Error::Opaque(format!("failed to create packet stream, reason: {}", err).into())
187        })?;
188        stream.bind(&config.interface_name).map_err(|err| {
189            Error::Opaque(format!("failed to bind interface to stream, reason {}", err).into())
190        })?;
191
192        let notification_handler = Arc::new(NotificationHandler::new());
193        let cache = Arc::new(ArpCache::new(
194            config.cache_timeout,
195            Arc::clone(&notification_handler),
196        ));
197
198        let mut task_spawner = BackgroundTaskSpawner::new();
199        task_spawner.spawn(Listener::new(stream.clone(), Arc::clone(&cache)));
200
201        Ok(Self {
202            response_timeout: config.response_timeout,
203            stream: Mutex::new(stream),
204            cache,
205            notification_handler,
206            _task_spawner: task_spawner,
207        })
208    }
209
210    /// Probes for the presence of a device at the given IP address.
211    ///
212    /// This function sends an ARP request to determine whether an IP address
213    /// is occupied. It returns a `ProbeOutcome`, indicating whether the address
214    /// is in use.
215    ///
216    /// # Example
217    /// ```no_run
218    /// use async_arp::{Client, ClientConfigBuilder, ProbeStatus, ProbeInputBuilder};
219    /// use pnet::util::MacAddr;
220    /// use std::net::Ipv4Addr;
221    ///
222    /// let probe_input = ProbeInputBuilder::new()
223    ///     .with_sender_mac(MacAddr::new(0x00, 0x1A, 0x2B, 0x3C, 0x4D, 0x5E))
224    ///     .with_target_ip(Ipv4Addr::new(192, 168, 1, 1))
225    ///     .build()
226    ///     .expect("Failed to build probe input");
227    /// tokio_test::block_on(async {
228    ///     let client = Client::new(ClientConfigBuilder::new("eth0").build()).unwrap();
229    ///     let outcome = client.probe(probe_input).await.unwrap();
230    ///     match outcome.status {
231    ///         ProbeStatus::Occupied => println!("IP is in use"),
232    ///         ProbeStatus::Free => println!("IP is available"),
233    /// }
234    /// })
235    /// ```
236    ///
237    /// # Errors
238    /// Returns an error if sending the ARP request fails.
239    pub async fn probe(&self, input: ProbeInput) -> Result<ProbeOutcome> {
240        let input = RequestInput {
241            sender_ip: Ipv4Addr::UNSPECIFIED,
242            sender_mac: input.sender_mac,
243            target_ip: input.target_ip,
244            target_mac: MacAddr::zero(),
245        };
246
247        match self.request(input).await {
248            Ok(response) => {
249                let status = match response.response_result {
250                    Ok(_) => ProbeStatus::Occupied,
251                    Err(_) => ProbeStatus::Free,
252                };
253                Ok(ProbeOutcome::new(status, input.target_ip))
254            }
255            Err(err) => Err(err),
256        }
257    }
258
259    /// Sends an ARP request and waits for a response.
260    ///
261    /// If the requested IP is already cached, the cached response is returned immediately.
262    /// Otherwise, a new ARP request is sent, and the client waits for a response within
263    /// the configured timeout period.
264    ///
265    /// # Example
266    /// ```no_run
267    /// use pnet::util::MacAddr;
268    /// use std::net::Ipv4Addr;
269    /// use async_arp::{Client, ClientConfigBuilder, RequestInputBuilder};
270    ///
271    /// let request_input = RequestInputBuilder::new()
272    ///     .with_sender_ip(Ipv4Addr::new(192, 168, 1, 100))
273    ///     .with_sender_mac(MacAddr::new(0x00, 0x1A, 0x2B, 0x3C, 0x4D, 0x5E))
274    ///     .with_target_ip(Ipv4Addr::new(192, 168, 1, 1))
275    ///     .with_target_mac(MacAddr::zero())
276    ///     .build()
277    ///     .expect("Failed to build request input");
278    /// tokio_test::block_on(async {
279    ///     let client = Client::new(ClientConfigBuilder::new("eth0").build()).unwrap();
280    ///     let outcome = client.request(request_input).await.unwrap();
281    ///
282    ///     println!("Received response: {:?}", outcome);
283    /// })
284    /// ```
285    ///
286    /// # Errors
287    /// Returns an error if sending the request fails or if no response is received
288    /// within the timeout period.
289    pub async fn request(&self, input: RequestInput) -> Result<RequestOutcome> {
290        if let Some(cached) = self.cache.get(&input.target_ip) {
291            return Ok(RequestOutcome::new(input, Ok(cached)));
292        }
293        let mut eth_buf = [0; ETH_PACK_LEN];
294        Self::fill_packet_buf(&mut eth_buf, &input);
295        let notifier = self
296            .notification_handler
297            .register_notifier(input.target_ip)
298            .await;
299        self.stream
300            .lock()
301            .await
302            .write_all(&eth_buf)
303            .await
304            .map_err(|err| {
305                Error::Opaque(format!("failed to send request, reason: {}", err).into())
306            })?;
307
308        let response_result = tokio::time::timeout(
309            self.response_timeout,
310            self.await_response(notifier, &input.target_ip),
311        )
312        .await
313        .map_err(|_| ResponseTimeout {});
314        Ok(RequestOutcome::new(input, response_result))
315    }
316
317    fn fill_packet_buf(eth_buf: &mut [u8], input: &RequestInput) {
318        let mut eth_packet = MutableEthernetPacket::new(eth_buf).unwrap();
319        eth_packet.set_destination(MacAddr::broadcast());
320        eth_packet.set_source(input.sender_mac);
321        eth_packet.set_ethertype(EtherTypes::Arp);
322
323        let mut arp_buf = [0; ARP_PACK_LEN];
324        let mut arp_packet = MutableArpPacket::new(&mut arp_buf).unwrap();
325        arp_packet.set_hardware_type(ArpHardwareTypes::Ethernet);
326        arp_packet.set_protocol_type(EtherTypes::Ipv4);
327        arp_packet.set_hw_addr_len(MAC_ADDR_LEN);
328        arp_packet.set_proto_addr_len(IP_V4_LEN);
329        arp_packet.set_operation(ArpOperations::Request);
330        arp_packet.set_sender_hw_addr(input.sender_mac);
331        arp_packet.set_sender_proto_addr(input.sender_ip);
332        arp_packet.set_target_hw_addr(input.target_mac);
333        arp_packet.set_target_proto_addr(input.target_ip);
334
335        eth_packet.set_payload(arp_packet.packet());
336    }
337
338    async fn await_response(&self, notifier: Arc<Notify>, target_ip: &Ipv4Addr) -> Arp {
339        loop {
340            notifier.notified().await;
341            {
342                if let Some(packet) = self.cache.get(target_ip) {
343                    return packet;
344                }
345            }
346        }
347    }
348}
349
350#[derive(Debug)]
351struct BackgroundTaskSpawner {
352    token: CancellationToken,
353    handle: Option<JoinHandle<()>>,
354}
355
356impl BackgroundTaskSpawner {
357    fn new() -> Self {
358        Self {
359            token: CancellationToken::new(),
360            handle: None,
361        }
362    }
363
364    fn spawn(&mut self, mut listener: Listener) {
365        let token = self.token.clone();
366        let handle = tokio::task::spawn(async move {
367            tokio::select! {
368                _ = listener.listen() => {
369
370                },
371                _ = token.cancelled() => {
372                }
373            }
374        });
375        self.handle = Some(handle);
376    }
377}
378
379impl Drop for BackgroundTaskSpawner {
380    fn drop(&mut self) {
381        if self.handle.is_some() {
382            self.token.cancel();
383        }
384    }
385}
386
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use crate::{
        client::{Client, ClientConfigBuilder, ProbeStatus},
        constants::{ARP_PACK_LEN, ETH_PACK_LEN, IP_V4_LEN, MAC_ADDR_LEN},
        probe::{ProbeInputBuilder, ProbeOutcome},
        response::parse_arp_packet,
        ClientSpinner,
    };
    use afpacket::tokio::RawPacketStream;
    use ipnet::Ipv4Net;
    use pnet::{
        datalink,
        packet::{
            arp::{ArpHardwareTypes, ArpOperations, MutableArpPacket},
            ethernet::{EtherTypes, MutableEthernetPacket},
            Packet,
        },
        util::MacAddr,
    };
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
    type Result<T> = std::result::Result<T, Error>;

    /// Looks up the hardware address of the interface named `interface_name`.
    fn interface_mac(interface_name: &str) -> Result<MacAddr> {
        let interface = datalink::interfaces()
            .into_iter()
            .find(|iface| iface.name == interface_name)
            .ok_or_else(|| format!("interface {} not found", interface_name))?;
        interface
            .mac
            .ok_or_else(|| format!("interface {} does not have mac address", interface_name).into())
    }

    /// Minimal ARP responder: answers every request whose target address lies in `net`
    /// with this interface's MAC address.
    struct Server {
        mac: MacAddr,
        stream: RawPacketStream,
        net: Ipv4Net,
    }

    impl Server {
        /// Binds a raw packet stream to `interface_name`.
        fn new(interface_name: &str, net: Ipv4Net) -> Result<Self> {
            let mac = interface_mac(interface_name)?;
            let mut stream = RawPacketStream::new()?;
            stream.bind(interface_name)?;
            Ok(Self { mac, stream, net })
        }

        /// Replies to ARP requests for addresses in `self.net` until a read fails.
        async fn serve(&mut self) -> Result<()> {
            let mut request_buf = [0; ETH_PACK_LEN];
            let mut arp_buf = [0; ARP_PACK_LEN];
            let mut response_buf = [0; ETH_PACK_LEN];
            while let Ok(read_bytes) = self.stream.read(&mut request_buf).await {
                let Ok(request) = parse_arp_packet(&request_buf[..read_bytes]) else {
                    continue;
                };
                if !self.net.contains(&request.target_proto_addr) {
                    continue;
                }

                let mut arp_response = MutableArpPacket::new(&mut arp_buf)
                    .ok_or("failed to create ARP packet")?;
                arp_response.set_hardware_type(ArpHardwareTypes::Ethernet);
                arp_response.set_protocol_type(EtherTypes::Ipv4);
                arp_response.set_hw_addr_len(MAC_ADDR_LEN);
                arp_response.set_proto_addr_len(IP_V4_LEN);
                arp_response.set_operation(ArpOperations::Reply);

                arp_response.set_sender_proto_addr(request.target_proto_addr);
                arp_response.set_sender_hw_addr(self.mac);
                arp_response.set_target_proto_addr(request.sender_proto_addr);
                arp_response.set_target_hw_addr(request.sender_hw_addr);

                let mut eth_response = MutableEthernetPacket::new(&mut response_buf)
                    .ok_or("failed to create Ethernet frame")?;
                eth_response.set_ethertype(EtherTypes::Arp);
                eth_response.set_destination(request.sender_hw_addr);
                eth_response.set_source(self.mac);
                eth_response.set_payload(arp_response.packet());

                self.stream.write_all(eth_response.packet()).await?;
            }
            Ok(())
        }
    }

    /// Concurrently probes `10.1.1.x` for every `x` in `last_octets`.
    ///
    /// Each outcome is paired with the last octet it was probed for.
    async fn probe_range(
        client: &Client,
        sender_mac: MacAddr,
        last_octets: impl IntoIterator<Item = u8>,
    ) -> Vec<(u8, ProbeOutcome)> {
        let probes = last_octets.into_iter().map(move |octet| async move {
            let input = ProbeInputBuilder::new()
                .with_sender_mac(sender_mac)
                .with_target_ip(Ipv4Addr::new(10, 1, 1, octet))
                .build()
                .unwrap();
            (octet, client.probe(input).await.unwrap())
        });
        futures::future::join_all(probes).await
    }

    #[tokio::test]
    async fn test_spinner_down_interface() {
        const INTERFACE_NAME: &str = "down_dummy";
        let client = Client::new(ClientConfigBuilder::new(INTERFACE_NAME).build()).unwrap();
        let spinner = ClientSpinner::new(client);
        let result = spinner
            .probe_batch(&[ProbeInputBuilder::new()
                .with_sender_mac(MacAddr::broadcast())
                .with_target_ip(Ipv4Addr::new(10, 1, 1, 1))
                .build()
                .unwrap()])
            .await;
        assert!(result.is_err())
    }

    // Even though no async functions are called directly, a tokio runtime must be running
    // because the packet-stream dependency relies on `AsyncFd`.
    #[tokio::test]
    async fn test_invalid_interface() {
        const INTERFACE_NAME: &str = "invalid_dummy";
        assert!(Client::new(ClientConfigBuilder::new(INTERFACE_NAME).build()).is_err());
    }

    #[tokio::test]
    async fn test_client_detection() {
        const INTERFACE_NAME: &str = "dummy0";

        // Bind the responder before the client sends anything, so no early probe can
        // go out before the server is listening.
        let net = Ipv4Net::new(Ipv4Addr::new(10, 1, 1, 0), 25).unwrap();
        let mut server = Server::new(INTERFACE_NAME, net).unwrap();
        tokio::spawn(async move {
            server.serve().await.unwrap();
        });

        let client = Client::new(ClientConfigBuilder::new(INTERFACE_NAME).build()).unwrap();
        let sender_mac = interface_mac(INTERFACE_NAME).unwrap();

        // 10.1.1.0/25 is answered by the server; the upper half of the /24 is not.
        for (octet, outcome) in probe_range(&client, sender_mac, 0..128).await {
            assert_eq!(
                outcome.status,
                ProbeStatus::Occupied,
                "10.1.1.{} should be occupied",
                octet
            );
        }
        for (octet, outcome) in probe_range(&client, sender_mac, 128..=255).await {
            assert_eq!(
                outcome.status,
                ProbeStatus::Free,
                "10.1.1.{} should be free",
                octet
            );
        }
    }
}