1use afpacket::tokio::RawPacketStream;
2use pnet::{
3 packet::{
4 arp::{Arp, ArpHardwareTypes, ArpOperations, MutableArpPacket},
5 ethernet::{EtherTypes, MutableEthernetPacket},
6 Packet,
7 },
8 util::MacAddr,
9};
10
11use std::{future::Future, net::Ipv4Addr, sync::Arc, time::Duration};
12use tokio::task::JoinHandle;
13use tokio::{
14 io::AsyncWriteExt,
15 sync::{Mutex, Notify},
16};
17
18use tokio_util::sync::CancellationToken;
19
20use crate::{caching::ArpCache, probe::ProbeInput, request::RequestOutcome};
21use crate::{constants::IP_V4_LEN, notification::NotificationHandler};
22use crate::{
23 constants::{ARP_PACK_LEN, ETH_PACK_LEN, MAC_ADDR_LEN},
24 request::RequestInput,
25};
26use crate::{
27 error::{Error, Result},
28 probe::ProbeOutcome,
29};
30use crate::{probe::ProbeStatus, response::Listener};
31
32pub struct ClientSpinner {
38 client: Client,
39 n_retries: usize,
40}
41
42impl ClientSpinner {
43 pub fn new(client: Client) -> Self {
48 Self {
49 client,
50 n_retries: 0,
51 }
52 }
53
54 pub fn with_retries(mut self, n_retires: usize) -> Self {
56 self.n_retries = n_retires;
57 self
58 }
59
60 pub async fn probe_batch(&self, inputs: &[ProbeInput]) -> Vec<Result<ProbeOutcome>> {
66 let futures_producer = || {
67 inputs
68 .iter()
69 .map(|input| async { self.client.probe(*input).await })
70 };
71 Self::handle_retries(self.n_retries, futures_producer).await
72 }
73
74 pub async fn request_batch(&self, inputs: &[RequestInput]) -> Vec<Result<RequestOutcome>> {
80 let futures_producer = || {
81 inputs
82 .iter()
83 .map(|input| async { self.client.request(*input).await })
84 };
85 Self::handle_retries(self.n_retries, futures_producer).await
86 }
87
88 async fn handle_retries<F, I, Fut, T>(n_retries: usize, futures_producer: F) -> Vec<Result<T>>
89 where
90 F: Fn() -> I,
91 Fut: Future<Output = Result<T>>,
92 I: Iterator<Item = Fut>,
93 {
94 for _ in 0..n_retries {
95 futures::future::join_all(futures_producer()).await;
96 }
97 futures::future::join_all(futures_producer()).await
98 }
99}
100
/// Settings consumed by `Client::new`.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Name of the network interface the raw packet stream is bound to.
    pub interface_name: String,
    /// How long a request waits for a reply before failing with a timeout error.
    pub response_timeout: Duration,
    /// Timeout handed to the ARP cache (presumably the lifetime of cached entries).
    pub cache_timeout: Duration,
}

/// Builder for [`ClientConfig`].
///
/// Defaults: 1 second response timeout, 60 second cache timeout.
#[derive(Debug, Clone)]
pub struct ClientConfigBuilder {
    interface_name: String,
    response_timeout: Duration,
    cache_timeout: Duration,
}

impl ClientConfigBuilder {
    /// Response timeout used when `with_response_timeout` is not called.
    const DEFAULT_RESPONSE_TIMEOUT: Duration = Duration::from_secs(1);
    /// Cache timeout used when `with_cache_timeout` is not called.
    const DEFAULT_CACHE_TIMEOUT: Duration = Duration::from_secs(60);

    /// Starts a configuration for `interface_name` with the default timeouts.
    pub fn new(interface_name: &str) -> Self {
        Self {
            interface_name: interface_name.to_owned(),
            response_timeout: Self::DEFAULT_RESPONSE_TIMEOUT,
            cache_timeout: Self::DEFAULT_CACHE_TIMEOUT,
        }
    }

    /// Overrides how long a request waits for its reply.
    pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
        self.response_timeout = timeout;
        self
    }

    /// Overrides the timeout passed to the ARP cache.
    pub fn with_cache_timeout(mut self, timeout: Duration) -> Self {
        self.cache_timeout = timeout;
        self
    }

    /// Finalizes the configuration. Every field always has a value, so this cannot fail.
    pub fn build(self) -> ClientConfig {
        ClientConfig {
            interface_name: self.interface_name,
            cache_timeout: self.cache_timeout,
            response_timeout: self.response_timeout,
        }
    }
}
140
/// ARP client bound to a single interface.
///
/// It writes ARP request frames to a raw packet stream and waits for the matching
/// reply to appear in a shared cache. The cache is handed to a background
/// `Listener` task, which presumably fills it from incoming frames.
///
/// NOTE: fields are dropped in declaration order; `_task_spawner` is last.
pub struct Client {
    // Maximum time `request` waits for a reply after sending its frame.
    response_timeout: Duration,
    // Raw packet stream used for sending; the async mutex serializes concurrent writers.
    stream: Mutex<RawPacketStream>,
    // ARP packets keyed by target IPv4 address; shared with the background listener.
    cache: Arc<ArpCache>,

    // Hands out per-IP `Notify` handles that `request` waits on; also given to the cache.
    notification_handler: Arc<NotificationHandler>,
    // Owns the background listener task; dropping it cancels that task.
    _task_spawner: BackgroundTaskSpawner,
}
167
168impl Client {
169 pub fn new(config: ClientConfig) -> Result<Self> {
179 let mut stream = RawPacketStream::new().map_err(|err| {
180 Error::Opaque(format!("failed to create packet stream, reason: {}", err).into())
181 })?;
182 stream.bind(&config.interface_name).map_err(|err| {
183 Error::Opaque(format!("failed to bind interface to stream, reason {}", err).into())
184 })?;
185
186 let notification_handler = Arc::new(NotificationHandler::new());
187 let cache = Arc::new(ArpCache::new(
188 config.cache_timeout,
189 Arc::clone(¬ification_handler),
190 ));
191
192 let mut task_spawner = BackgroundTaskSpawner::new();
193 task_spawner.spawn(Listener::new(stream.clone(), Arc::clone(&cache)));
194
195 Ok(Self {
196 response_timeout: config.response_timeout,
197 stream: Mutex::new(stream),
198 cache,
199 notification_handler,
200 _task_spawner: task_spawner,
201 })
202 }
203
204 pub async fn probe(&self, input: ProbeInput) -> Result<ProbeOutcome> {
234 let input = RequestInput {
235 sender_ip: Ipv4Addr::UNSPECIFIED,
236 sender_mac: input.sender_mac,
237 target_ip: input.target_ip,
238 target_mac: MacAddr::zero(),
239 };
240
241 match self.request(input).await {
242 Ok(_) => Ok(ProbeOutcome::new(ProbeStatus::Occupied, input.target_ip)),
243 Err(Error::ResponseTimeout) => {
244 Ok(ProbeOutcome::new(ProbeStatus::Free, input.target_ip))
245 }
246 Err(err) => Err(err),
247 }
248 }
249
250 pub async fn request(&self, input: RequestInput) -> Result<RequestOutcome> {
281 if let Some(cached) = self.cache.get(&input.target_ip) {
282 return Ok(RequestOutcome::new(input, cached));
283 }
284 let mut eth_buf = [0; ETH_PACK_LEN];
285 Self::fill_packet_buf(&mut eth_buf, &input);
286 let notifier = self
287 .notification_handler
288 .register_notifier(input.target_ip)
289 .await;
290 self.stream
291 .lock()
292 .await
293 .write_all(ð_buf)
294 .await
295 .map_err(|err| {
296 Error::Opaque(format!("failed to send request, reason: {}", err).into())
297 })?;
298
299 let response = tokio::time::timeout(
300 self.response_timeout,
301 self.await_response(notifier, &input.target_ip),
302 )
303 .await
304 .map_err(|_| Error::ResponseTimeout)?;
305 Ok(RequestOutcome::new(input, response))
306 }
307
308 fn fill_packet_buf(eth_buf: &mut [u8], input: &RequestInput) {
309 let mut eth_packet = MutableEthernetPacket::new(eth_buf).unwrap();
310 eth_packet.set_destination(MacAddr::broadcast());
311 eth_packet.set_source(input.sender_mac);
312 eth_packet.set_ethertype(EtherTypes::Arp);
313
314 let mut arp_buf = [0; ARP_PACK_LEN];
315 let mut arp_packet = MutableArpPacket::new(&mut arp_buf).unwrap();
316 arp_packet.set_hardware_type(ArpHardwareTypes::Ethernet);
317 arp_packet.set_protocol_type(EtherTypes::Ipv4);
318 arp_packet.set_hw_addr_len(MAC_ADDR_LEN);
319 arp_packet.set_proto_addr_len(IP_V4_LEN);
320 arp_packet.set_operation(ArpOperations::Request);
321 arp_packet.set_sender_hw_addr(input.sender_mac);
322 arp_packet.set_sender_proto_addr(input.sender_ip);
323 arp_packet.set_target_hw_addr(input.target_mac);
324 arp_packet.set_target_proto_addr(input.target_ip);
325
326 eth_packet.set_payload(arp_packet.packet());
327 }
328
329 async fn await_response(&self, notifier: Arc<Notify>, target_ip: &Ipv4Addr) -> Arp {
330 loop {
331 notifier.notified().await;
332 {
333 if let Some(packet) = self.cache.get(target_ip) {
334 return packet;
335 }
336 }
337 }
338 }
339}
340
/// Owns the background listener task together with the token used to stop it.
struct BackgroundTaskSpawner {
    // Shared with the spawned task; cancelled on drop once a task has been spawned.
    token: CancellationToken,
    // Handle of the most recently spawned task; `None` until `spawn` is called.
    handle: Option<JoinHandle<()>>,
}
345
346impl BackgroundTaskSpawner {
347 fn new() -> Self {
348 Self {
349 token: CancellationToken::new(),
350 handle: None,
351 }
352 }
353
354 fn spawn(&mut self, mut listener: Listener) {
355 let token = self.token.clone();
356 let handle = tokio::task::spawn(async move {
357 tokio::select! {
358 _ = listener.listen() => {
359
360 },
361 _ = token.cancelled() => {
362 }
363 }
364 });
365 self.handle = Some(handle);
366 }
367}
368
369impl Drop for BackgroundTaskSpawner {
370 fn drop(&mut self) {
371 if self.handle.is_some() {
372 self.token.cancel();
373 }
374 }
375}
376
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use crate::{
        client::{Client, ClientConfigBuilder, ProbeStatus},
        constants::{ARP_PACK_LEN, ETH_PACK_LEN, IP_V4_LEN, MAC_ADDR_LEN},
        probe::ProbeInputBuilder,
        response::parse_arp_packet,
    };
    use afpacket::tokio::RawPacketStream;
    use ipnet::Ipv4Net;
    use pnet::{
        datalink,
        packet::{
            arp::{ArpHardwareTypes, ArpOperations, MutableArpPacket},
            ethernet::{EtherTypes, MutableEthernetPacket},
            Packet,
        },
        util::MacAddr,
    };
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
    type Result<T> = std::result::Result<T, Error>;

    /// Minimal ARP responder for tests.
    ///
    /// It answers every ARP *request* whose target IP lies in `net` with the bound
    /// interface's own MAC address.
    struct Server {
        mac: MacAddr,
        stream: RawPacketStream,
        net: Ipv4Net,
    }

    impl Server {
        /// Binds a raw packet stream to `interface_name` and records the interface's MAC.
        fn new(interface_name: &str, net: Ipv4Net) -> Result<Self> {
            let interface = datalink::interfaces()
                .into_iter()
                .find(|iface| iface.name == interface_name)
                .ok_or_else(|| format!("interface {} not found", interface_name))?;
            let mac = interface
                .mac
                .ok_or_else(|| format!("interface {} has no mac address", interface_name))?;
            let mut stream = RawPacketStream::new()?;
            stream.bind(interface_name)?;
            Ok(Self { mac, stream, net })
        }

        /// Reads frames until the stream errors, replying to matching ARP requests.
        async fn serve(&mut self) -> Result<()> {
            let mut request_buf = [0; ETH_PACK_LEN];
            let mut arp_buf = [0; ARP_PACK_LEN];
            let mut response_buf = [0; ETH_PACK_LEN];
            while let Ok(read_bytes) = self.stream.read(&mut request_buf).await {
                let request = match parse_arp_packet(&request_buf[..read_bytes]) {
                    Ok(request) => request,
                    Err(_) => continue,
                };
                // Only answer requests. A packet socket also receives frames sent on
                // the interface, including this server's own replies. Those must
                // never trigger further replies.
                if request.operation != ArpOperations::Request
                    || !self.net.contains(&request.target_proto_addr)
                {
                    continue;
                }

                let mut arp_response =
                    MutableArpPacket::new(&mut arp_buf).ok_or("failed to create ARP packet")?;
                arp_response.set_hardware_type(ArpHardwareTypes::Ethernet);
                arp_response.set_protocol_type(EtherTypes::Ipv4);
                arp_response.set_hw_addr_len(MAC_ADDR_LEN);
                arp_response.set_proto_addr_len(IP_V4_LEN);
                arp_response.set_operation(ArpOperations::Reply);

                arp_response.set_sender_proto_addr(request.target_proto_addr);
                arp_response.set_sender_hw_addr(self.mac);
                arp_response.set_target_proto_addr(request.sender_proto_addr);
                arp_response.set_target_hw_addr(request.sender_hw_addr);

                let mut eth_response = MutableEthernetPacket::new(&mut response_buf)
                    .ok_or("failed to create Ethernet frame")?;
                eth_response.set_ethertype(EtherTypes::Arp);
                eth_response.set_destination(request.sender_hw_addr);
                eth_response.set_source(self.mac);
                eth_response.set_payload(arp_response.packet());

                self.stream.write_all(eth_response.packet()).await?;
            }
            Ok(())
        }
    }

    /// Probes `10.1.1.<host>` for every host in `hosts` concurrently and asserts
    /// that each probe reports `expected`.
    async fn assert_statuses(
        client: &Client,
        sender_mac: MacAddr,
        hosts: impl Iterator<Item = u8>,
        expected: ProbeStatus,
    ) {
        let probes = hosts.map(move |host| async move {
            let input = ProbeInputBuilder::new()
                .with_sender_mac(sender_mac)
                .with_target_ip(Ipv4Addr::new(10, 1, 1, host))
                .build()
                .unwrap();
            (host, client.probe(input).await.unwrap())
        });
        for (host, outcome) in futures::future::join_all(probes).await {
            assert_eq!(outcome.status, expected, "unexpected status for 10.1.1.{}", host);
        }
    }

    #[tokio::test]
    async fn test_detection() {
        const INTERFACE_NAME: &str = "dummy0";

        // Bind the responder before any probe is sent. Creating it inside the
        // spawned task let the first probes go out before it was listening.
        let net = Ipv4Net::new(Ipv4Addr::new(10, 1, 1, 0), 25).unwrap();
        let mut server = Server::new(INTERFACE_NAME, net).unwrap();
        tokio::spawn(async move {
            server.serve().await.unwrap();
        });

        let client = Client::new(ClientConfigBuilder::new(INTERFACE_NAME).build()).unwrap();

        let sender_mac = datalink::interfaces()
            .into_iter()
            .find(|iface| iface.name == INTERFACE_NAME)
            .ok_or_else(|| format!("interface {} not found", INTERFACE_NAME))
            .unwrap()
            .mac
            .ok_or("interface does not have mac address")
            .unwrap();

        // The server answers for 10.1.1.0/25, i.e. hosts 0..=127; the rest must be free.
        assert_statuses(&client, sender_mac, 0..=127, ProbeStatus::Occupied).await;
        assert_statuses(&client, sender_mac, 128..=255, ProbeStatus::Free).await;
    }
}