1use afpacket::tokio::RawPacketStream;
2use pnet::{
3 packet::{
4 arp::{Arp, ArpHardwareTypes, ArpOperations, MutableArpPacket},
5 ethernet::{EtherTypes, MutableEthernetPacket},
6 Packet,
7 },
8 util::MacAddr,
9};
10
11use std::{future::Future, net::Ipv4Addr, sync::Arc, time::Duration};
12use tokio::task::JoinHandle;
13use tokio::{
14 io::AsyncWriteExt,
15 sync::{Mutex, Notify},
16};
17
18use tokio_util::sync::CancellationToken;
19
20use crate::{
21 caching::ArpCache, error::ResponseTimeout, probe::ProbeInput, request::RequestOutcome,
22};
23use crate::{constants::IP_V4_LEN, notification::NotificationHandler};
24use crate::{
25 constants::{ARP_PACK_LEN, ETH_PACK_LEN, MAC_ADDR_LEN},
26 request::RequestInput,
27};
28use crate::{
29 error::{Error, Result},
30 probe::ProbeOutcome,
31};
32use crate::{probe::ProbeStatus, response::Listener};
33
34#[derive(Debug)]
40pub struct ClientSpinner {
41 client: Client,
42 n_retries: usize,
43}
44
45impl ClientSpinner {
46 pub fn new(client: Client) -> Self {
51 Self {
52 client,
53 n_retries: 0,
54 }
55 }
56
57 pub fn with_retries(mut self, n_retires: usize) -> Self {
59 self.n_retries = n_retires;
60 self
61 }
62
63 pub async fn probe_batch(&self, inputs: &[ProbeInput]) -> Result<Vec<ProbeOutcome>> {
69 let futures_producer = || {
70 inputs
71 .iter()
72 .map(|input| async { self.client.probe(*input).await })
73 };
74 Self::handle_retries(self.n_retries, futures_producer).await
75 }
76
77 pub async fn request_batch(&self, inputs: &[RequestInput]) -> Result<Vec<RequestOutcome>> {
83 let futures_producer = || {
84 inputs
85 .iter()
86 .map(|input| async { self.client.request(*input).await })
87 };
88 Self::handle_retries(self.n_retries, futures_producer).await
89 }
90
91 async fn handle_retries<F, I, Fut, T>(n_retries: usize, futures_producer: F) -> Result<Vec<T>>
92 where
93 F: Fn() -> I,
94 Fut: Future<Output = Result<T>>,
95 I: Iterator<Item = Fut>,
96 {
97 for _ in 0..n_retries {
98 futures::future::try_join_all(futures_producer()).await?;
99 }
100 futures::future::try_join_all(futures_producer()).await
101 }
102}
103
/// Runtime configuration for a `Client`, usually produced by [`ClientConfigBuilder::build`].
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Name of the network interface the client's packet stream is bound to.
    pub interface_name: String,
    /// Maximum time a single request waits for a reply before it is reported as timed out.
    pub response_timeout: Duration,
    /// Timeout handed to the client's ARP cache (presumably the lifetime of a cached entry).
    pub cache_timeout: Duration,
}

/// Builder for [`ClientConfig`].
///
/// Defaults: 1 second response timeout and 60 second cache timeout.
#[derive(Debug, Clone)]
pub struct ClientConfigBuilder {
    interface_name: String,
    response_timeout: Duration,
    cache_timeout: Duration,
}

impl ClientConfigBuilder {
    const DEFAULT_RESPONSE_TIMEOUT: Duration = Duration::from_secs(1);
    const DEFAULT_CACHE_TIMEOUT: Duration = Duration::from_secs(60);

    /// Starts a configuration for `interface_name` using the default timeouts.
    #[must_use]
    pub fn new(interface_name: &str) -> Self {
        Self {
            interface_name: interface_name.to_owned(),
            response_timeout: Self::DEFAULT_RESPONSE_TIMEOUT,
            cache_timeout: Self::DEFAULT_CACHE_TIMEOUT,
        }
    }

    /// Overrides how long a request waits for a reply.
    #[must_use]
    pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
        self.response_timeout = timeout;
        self
    }

    /// Overrides the timeout passed to the ARP cache.
    #[must_use]
    pub fn with_cache_timeout(mut self, timeout: Duration) -> Self {
        self.cache_timeout = timeout;
        self
    }

    /// Finalizes the configuration. Infallible: every field always holds a value.
    #[must_use]
    pub fn build(self) -> ClientConfig {
        ClientConfig {
            interface_name: self.interface_name,
            response_timeout: self.response_timeout,
            cache_timeout: self.cache_timeout,
        }
    }
}
145
146#[derive(Debug)]
165pub struct Client {
166 response_timeout: Duration,
167 stream: Mutex<RawPacketStream>,
168 cache: Arc<ArpCache>,
169
170 notification_handler: Arc<NotificationHandler>,
171 _task_spawner: BackgroundTaskSpawner,
172}
173
174impl Client {
175 pub fn new(config: ClientConfig) -> Result<Self> {
185 let mut stream = RawPacketStream::new().map_err(|err| {
186 Error::Opaque(format!("failed to create packet stream, reason: {}", err).into())
187 })?;
188 stream.bind(&config.interface_name).map_err(|err| {
189 Error::Opaque(format!("failed to bind interface to stream, reason {}", err).into())
190 })?;
191
192 let notification_handler = Arc::new(NotificationHandler::new());
193 let cache = Arc::new(ArpCache::new(
194 config.cache_timeout,
195 Arc::clone(¬ification_handler),
196 ));
197
198 let mut task_spawner = BackgroundTaskSpawner::new();
199 task_spawner.spawn(Listener::new(stream.clone(), Arc::clone(&cache)));
200
201 Ok(Self {
202 response_timeout: config.response_timeout,
203 stream: Mutex::new(stream),
204 cache,
205 notification_handler,
206 _task_spawner: task_spawner,
207 })
208 }
209
210 pub async fn probe(&self, input: ProbeInput) -> Result<ProbeOutcome> {
240 let input = RequestInput {
241 sender_ip: Ipv4Addr::UNSPECIFIED,
242 sender_mac: input.sender_mac,
243 target_ip: input.target_ip,
244 target_mac: MacAddr::zero(),
245 };
246
247 match self.request(input).await {
248 Ok(response) => {
249 let status = match response.response_result {
250 Ok(_) => ProbeStatus::Occupied,
251 Err(_) => ProbeStatus::Free,
252 };
253 Ok(ProbeOutcome::new(status, input.target_ip))
254 }
255 Err(err) => Err(err),
256 }
257 }
258
259 pub async fn request(&self, input: RequestInput) -> Result<RequestOutcome> {
290 if let Some(cached) = self.cache.get(&input.target_ip) {
291 return Ok(RequestOutcome::new(input, Ok(cached)));
292 }
293 let mut eth_buf = [0; ETH_PACK_LEN];
294 Self::fill_packet_buf(&mut eth_buf, &input);
295 let notifier = self
296 .notification_handler
297 .register_notifier(input.target_ip)
298 .await;
299 self.stream
300 .lock()
301 .await
302 .write_all(ð_buf)
303 .await
304 .map_err(|err| {
305 Error::Opaque(format!("failed to send request, reason: {}", err).into())
306 })?;
307
308 let response_result = tokio::time::timeout(
309 self.response_timeout,
310 self.await_response(notifier, &input.target_ip),
311 )
312 .await
313 .map_err(|_| ResponseTimeout {});
314 Ok(RequestOutcome::new(input, response_result))
315 }
316
317 fn fill_packet_buf(eth_buf: &mut [u8], input: &RequestInput) {
318 let mut eth_packet = MutableEthernetPacket::new(eth_buf).unwrap();
319 eth_packet.set_destination(MacAddr::broadcast());
320 eth_packet.set_source(input.sender_mac);
321 eth_packet.set_ethertype(EtherTypes::Arp);
322
323 let mut arp_buf = [0; ARP_PACK_LEN];
324 let mut arp_packet = MutableArpPacket::new(&mut arp_buf).unwrap();
325 arp_packet.set_hardware_type(ArpHardwareTypes::Ethernet);
326 arp_packet.set_protocol_type(EtherTypes::Ipv4);
327 arp_packet.set_hw_addr_len(MAC_ADDR_LEN);
328 arp_packet.set_proto_addr_len(IP_V4_LEN);
329 arp_packet.set_operation(ArpOperations::Request);
330 arp_packet.set_sender_hw_addr(input.sender_mac);
331 arp_packet.set_sender_proto_addr(input.sender_ip);
332 arp_packet.set_target_hw_addr(input.target_mac);
333 arp_packet.set_target_proto_addr(input.target_ip);
334
335 eth_packet.set_payload(arp_packet.packet());
336 }
337
338 async fn await_response(&self, notifier: Arc<Notify>, target_ip: &Ipv4Addr) -> Arp {
339 loop {
340 notifier.notified().await;
341 {
342 if let Some(packet) = self.cache.get(target_ip) {
343 return packet;
344 }
345 }
346 }
347 }
348}
349
350#[derive(Debug)]
351struct BackgroundTaskSpawner {
352 token: CancellationToken,
353 handle: Option<JoinHandle<()>>,
354}
355
356impl BackgroundTaskSpawner {
357 fn new() -> Self {
358 Self {
359 token: CancellationToken::new(),
360 handle: None,
361 }
362 }
363
364 fn spawn(&mut self, mut listener: Listener) {
365 let token = self.token.clone();
366 let handle = tokio::task::spawn(async move {
367 tokio::select! {
368 _ = listener.listen() => {
369
370 },
371 _ = token.cancelled() => {
372 }
373 }
374 });
375 self.handle = Some(handle);
376 }
377}
378
379impl Drop for BackgroundTaskSpawner {
380 fn drop(&mut self) {
381 if self.handle.is_some() {
382 self.token.cancel();
383 }
384 }
385}
386
#[cfg(test)]
mod tests {
    use std::{net::Ipv4Addr, ops::RangeInclusive};

    use crate::{
        client::{Client, ClientConfigBuilder, ProbeStatus},
        constants::{ARP_PACK_LEN, ETH_PACK_LEN, IP_V4_LEN, MAC_ADDR_LEN},
        probe::ProbeInputBuilder,
        response::parse_arp_packet,
        ClientSpinner,
    };
    use afpacket::tokio::RawPacketStream;
    use ipnet::Ipv4Net;
    use pnet::{
        datalink,
        packet::{
            arp::{ArpHardwareTypes, ArpOperations, MutableArpPacket},
            ethernet::{EtherTypes, MutableEthernetPacket},
            Packet,
        },
        util::MacAddr,
    };
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
    type Result<T> = std::result::Result<T, Error>;

    /// Minimal ARP responder used by the tests. It answers every request whose target
    /// address lies inside `net` with the MAC address of the bound interface.
    struct Server {
        mac: MacAddr,
        stream: RawPacketStream,
        net: Ipv4Net,
    }

    impl Server {
        /// Binds a raw packet stream to `interface_name` and records that interface's MAC.
        fn new(interface_name: &str, net: Ipv4Net) -> Result<Self> {
            let interfaces = datalink::interfaces();
            let interface = interfaces
                .into_iter()
                .find(|iface| iface.name == interface_name)
                .ok_or_else(|| format!("interface {} not found", interface_name))?;
            let mut stream = RawPacketStream::new()?;
            stream.bind(interface_name)?;
            Ok(Self {
                mac: interface.mac.unwrap(),
                stream,
                net,
            })
        }

        /// Replies to matching ARP requests until reading from the stream fails.
        async fn serve(&mut self) -> Result<()> {
            let mut request_buf = [0; ETH_PACK_LEN];
            let mut arp_buf = [0; ARP_PACK_LEN];
            let mut response_buf = [0; ETH_PACK_LEN];
            while let Ok(read_bytes) = self.stream.read(&mut request_buf).await {
                if let Ok(request) = parse_arp_packet(&request_buf[..read_bytes]) {
                    if self.net.contains(&request.target_proto_addr) {
                        let mut arp_response = MutableArpPacket::new(&mut arp_buf).unwrap();
                        arp_response.set_hardware_type(ArpHardwareTypes::Ethernet);
                        arp_response.set_protocol_type(EtherTypes::Ipv4);
                        arp_response.set_hw_addr_len(MAC_ADDR_LEN);
                        arp_response.set_proto_addr_len(IP_V4_LEN);
                        arp_response.set_operation(ArpOperations::Reply);

                        arp_response.set_sender_proto_addr(request.target_proto_addr);
                        arp_response.set_sender_hw_addr(self.mac);
                        arp_response.set_target_proto_addr(request.sender_proto_addr);
                        arp_response.set_target_hw_addr(request.sender_hw_addr);

                        let mut eth_response = MutableEthernetPacket::new(&mut response_buf)
                            .ok_or("failed to create Ethernet frame")?;
                        eth_response.set_ethertype(EtherTypes::Arp);
                        eth_response.set_destination(request.sender_hw_addr);
                        eth_response.set_source(self.mac);
                        eth_response.set_payload(arp_response.packet());

                        self.stream.write_all(eth_response.packet()).await?;
                    }
                }
            }
            Ok(())
        }
    }

    /// Probes `10.1.1.<d>` for every `d` in `last_octets` concurrently and asserts that
    /// each probe reports `expected`.
    async fn assert_probe_statuses(
        client: &Client,
        sender_mac: MacAddr,
        last_octets: RangeInclusive<u8>,
        expected: ProbeStatus,
    ) {
        let probes = last_octets.map(|d| async move {
            let input = ProbeInputBuilder::new()
                .with_sender_mac(sender_mac)
                .with_target_ip(Ipv4Addr::new(10, 1, 1, d))
                .build()
                .unwrap();
            client.probe(input).await.unwrap()
        });
        for outcome in futures::future::join_all(probes).await {
            assert_eq!(outcome.status, expected);
        }
    }

    #[tokio::test]
    async fn test_spinner_down_interface() {
        const INTERFACE_NAME: &str = "down_dummy";
        let client = Client::new(ClientConfigBuilder::new(INTERFACE_NAME).build()).unwrap();
        let spinner = ClientSpinner::new(client);
        let result = spinner
            .probe_batch(&[ProbeInputBuilder::new()
                .with_sender_mac(MacAddr::broadcast())
                .with_target_ip(Ipv4Addr::new(10, 1, 1, 1))
                .build()
                .unwrap()])
            .await;
        assert!(result.is_err())
    }

    #[tokio::test]
    async fn test_invalid_interface() {
        const INTERFACE_NAME: &str = "invalid_dummy";
        assert!(Client::new(ClientConfigBuilder::new(INTERFACE_NAME).build()).is_err());
    }

    #[tokio::test]
    async fn test_client_detection() {
        const INTERFACE_NAME: &str = "dummy0";

        // Bind the responder before the client sends anything. Binding inside the spawned
        // task raced with the first probes, which could go out before the server listened.
        let net = Ipv4Net::new(Ipv4Addr::new(10, 1, 1, 0), 25).unwrap();
        let mut server = Server::new(INTERFACE_NAME, net).unwrap();
        tokio::spawn(async move {
            server.serve().await.unwrap();
        });

        let client = Client::new(ClientConfigBuilder::new(INTERFACE_NAME).build()).unwrap();
        let sender_mac = datalink::interfaces()
            .into_iter()
            .find(|iface| iface.name == INTERFACE_NAME)
            .ok_or_else(|| format!("interface {} not found", INTERFACE_NAME))
            .unwrap()
            .mac
            .ok_or("interface does not have mac address")
            .unwrap();

        // 10.1.1.0/25 spans .0 through .127; the responder answers exactly those.
        assert_probe_statuses(&client, sender_mac, 0..=127, ProbeStatus::Occupied).await;
        assert_probe_statuses(&client, sender_mac, 128..=255, ProbeStatus::Free).await;
    }
}