cross_socket/socket/mod.rs

1mod shared;
2pub(crate) use shared::*;
3
4#[cfg(not(target_os = "windows"))]
5mod unix;
6#[cfg(not(target_os = "windows"))]
7pub use unix::*;
8
9#[cfg(target_os = "windows")]
10mod windows;
11#[cfg(target_os = "windows")]
12pub use windows::*;
13
14use async_io::Async;
15use socket2::{Domain, SockAddr, Socket as SystemSocket, Type};
16use std::io;
17use std::mem::MaybeUninit;
18use std::net::{SocketAddr, Shutdown};
19use std::sync::Arc;
20use std::time::Duration;
21
22use crate::packet::builder::PacketBuildOption;
23use crate::packet::ip::IpNextLevelProtocol;
24
/// IP version: IPv4 or IPv6.
///
/// A plain fieldless tag, so it is cheap to copy and can be compared or hashed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum IpVersion {
    /// Internet Protocol version 4.
    V4,
    /// Internet Protocol version 6.
    V6,
}
31
32impl IpVersion {
33    /// IP Version number as u8
34    pub fn version_u8(&self) -> u8 {
35        match self {
36            IpVersion::V4 => 4,
37            IpVersion::V6 => 6,
38        }
39    }
40    /// Return true if IP version is IPv4
41    pub fn is_ipv4(&self) -> bool {
42        match self {
43            IpVersion::V4 => true,
44            IpVersion::V6 => false,
45        }
46    }
47    /// Return true if IP version is IPv6
48    pub fn is_ipv6(&self) -> bool {
49        match self {
50            IpVersion::V4 => false,
51            IpVersion::V6 => true,
52        }
53    }
54    pub(crate) fn to_domain(&self) -> Domain {
55        match self {
56            IpVersion::V4 => Domain::IPV4,
57            IpVersion::V6 => Domain::IPV6,
58        }
59    }
60}
61
/// Socket type.
///
/// A plain fieldless tag, so it is cheap to copy and can be compared or hashed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SocketType {
    /// Raw socket.
    Raw,
    /// Datagram socket. Usually used for UDP.
    Dgram,
    /// Stream socket. Used for TCP.
    Stream,
}
72
73impl SocketType {
74    pub(crate) fn to_type(&self) -> Type {
75        match self {
76            SocketType::Raw => Type::RAW,
77            SocketType::Dgram => Type::DGRAM,
78            SocketType::Stream => Type::STREAM,
79        }
80    }
81}
82
83/// Socket option
84#[derive(Clone, Debug)]
85pub struct SocketOption {
86    /// IP version
87    pub ip_version: IpVersion,
88    /// Socket type
89    pub socket_type: SocketType,
90    /// Protocol. TCP, UDP, ICMP, etc.
91    pub protocol: Option<IpNextLevelProtocol>,
92    /// Timeout
93    pub timeout: Option<u64>,
94    /// TTL or Hop Limit
95    pub ttl: Option<u32>,
96    /// Non-blocking mode
97    pub non_blocking: bool,
98}
99
100impl SocketOption {
101    /// Constructs a new SocketOption
102    pub fn new(
103        ip_version: IpVersion,
104        socket_type: SocketType,
105        protocol: Option<IpNextLevelProtocol>,
106    ) -> SocketOption {
107        SocketOption {
108            ip_version,
109            socket_type,
110            protocol,
111            timeout: None,
112            ttl: None,
113            non_blocking: false,
114        }
115    }
116}
117
118/// Async socket. Provides cross-platform async adapter for system’s socket.
119#[derive(Clone, Debug)]
120pub struct AsyncSocket {
121    inner: Arc<Async<SystemSocket>>,
122}
123
124impl AsyncSocket {
125    /// Constructs a new AsyncSocket
126    pub fn new(socket_option: SocketOption) -> io::Result<AsyncSocket> {
127        match check_socket_option(socket_option.clone()) {
128            Ok(_) => (),
129            Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
130        }
131        let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
132            SystemSocket::new(
133                socket_option.ip_version.to_domain(),
134                socket_option.socket_type.to_type(),
135                Some(protocol.to_socket_protocol()),
136            )?
137        } else {
138            SystemSocket::new(
139                socket_option.ip_version.to_domain(),
140                socket_option.socket_type.to_type(),
141                None,
142            )?
143        };
144        socket.set_nonblocking(true)?;
145        Ok(AsyncSocket {
146            inner: Arc::new(Async::new(socket)?),
147        })
148    }
149    /// Send packet
150    pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
151        loop {
152            self.inner.writable().await?;
153            match self.inner.write_with(|inner| inner.send(buf)).await {
154                Ok(n) => return Ok(n),
155                Err(_) => continue,
156            }
157        }
158    }
159    /// Send packet to target
160    pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
161        let target: SockAddr = SockAddr::from(target);
162        loop {
163            self.inner.writable().await?;
164            match self
165                .inner
166                .write_with(|inner| inner.send_to(buf, &target))
167                .await
168            {
169                Ok(n) => return Ok(n),
170                Err(_) => continue,
171            }
172        }
173    }
174    /// Receive packet
175    pub async fn receive(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
176        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
177        loop {
178            self.inner.readable().await?;
179            match self.inner.read_with(|inner| inner.recv(recv_buf)).await {
180                Ok(result) => return Ok(result),
181                Err(_) => continue,
182            }
183        }
184    }
185    /// Receive packet with sender address
186    pub async fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
187        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
188        loop {
189            self.inner.readable().await?;
190            match self
191                .inner
192                .read_with(|inner| inner.recv_from(recv_buf))
193                .await
194            {
195                Ok(result) => {
196                    let (n, addr) = result;
197                    match addr.as_socket() {
198                        Some(addr) => return Ok((n, addr)),
199                        None => continue,
200                    }
201                }
202                Err(_) => continue,
203            }
204        }
205    }
206    /// Bind socket to address
207    pub async fn bind(&self, addr: SocketAddr) -> io::Result<()> {
208        let addr: SockAddr = SockAddr::from(addr);
209        self.inner.writable().await?;
210        self.inner.write_with(|inner| inner.bind(&addr)).await
211    }
212    /// Set receive timeout
213    pub async fn set_receive_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
214        self.inner.writable().await?;
215        self.inner
216            .write_with(|inner| inner.set_read_timeout(timeout))
217            .await
218    }
219    /// Set TTL or Hop Limit
220    pub async fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> {
221        self.inner.writable().await?;
222        match ip_version {
223            IpVersion::V4 => self.inner.write_with(|inner| inner.set_ttl(ttl)).await,
224            IpVersion::V6 => {
225                self.inner
226                    .write_with(|inner| inner.set_unicast_hops_v6(ttl))
227                    .await
228            }
229        }
230    }
231    /// Initiate TCP connection
232    pub async fn connect(&self, addr: SocketAddr) -> io::Result<()> {
233        let addr: SockAddr = SockAddr::from(addr);
234        self.inner.writable().await?;
235        self.inner.write_with(|inner| inner.connect(&addr)).await
236    }
237    /// Shutdown TCP connection
238    pub async fn shutdown(&self, how: Shutdown) -> io::Result<()> {
239        self.inner.writable().await?;
240        self.inner.write_with(|inner| inner.shutdown(how)).await
241    }
242    /// Listen TCP connection
243    pub async fn listen(&self, backlog: i32) -> io::Result<()> {
244        self.inner.writable().await?;
245        self.inner.write_with(|inner| inner.listen(backlog)).await
246    }
247    /// Accept TCP connection
248    pub async fn accept(&self) -> io::Result<(AsyncSocket, SocketAddr)> {
249        self.inner.readable().await?;
250        match self.inner.read_with(|inner| inner.accept()).await {
251            Ok((socket, addr)) => {
252                let socket = AsyncSocket {
253                    inner: Arc::new(Async::new(socket)?),
254                };
255                Ok((socket, addr.as_socket().unwrap()))
256            }
257            Err(e) => Err(e),
258        }
259    }
260    /// Get peer address
261    pub async fn peer_addr(&self) -> io::Result<SocketAddr> {
262        self.inner.writable().await?;
263        match self.inner.read_with(|inner| inner.peer_addr()).await {
264            Ok(addr) => Ok(addr.as_socket().unwrap()),
265            Err(e) => Err(e),
266        }
267    }
268    /// Get local address
269    pub async fn local_addr(&self) -> io::Result<SocketAddr> {
270        self.inner.writable().await?;
271        match self.inner.read_with(|inner| inner.local_addr()).await {
272            Ok(addr) => Ok(addr.as_socket().unwrap()),
273            Err(e) => Err(e),
274        }
275    }
276}
277
278/// Socket. Provides cross-platform adapter for system’s socket.
279#[derive(Clone, Debug)]
280pub struct Socket {
281    inner: Arc<SystemSocket>,
282}
283
284impl Socket {
285    /// Constructs a new Socket
286    pub fn new(socket_option: SocketOption) -> io::Result<Socket> {
287        match check_socket_option(socket_option.clone()) {
288            Ok(_) => (),
289            Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
290        }
291        let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
292            SystemSocket::new(
293                socket_option.ip_version.to_domain(),
294                socket_option.socket_type.to_type(),
295                Some(protocol.to_socket_protocol()),
296            )?
297        } else {
298            SystemSocket::new(
299                socket_option.ip_version.to_domain(),
300                socket_option.socket_type.to_type(),
301                None,
302            )?
303        };
304        if socket_option.non_blocking {
305            socket.set_nonblocking(true)?;
306        }
307        Ok(Socket {
308            inner: Arc::new(socket),
309        })
310    }
311    /// Send packet to target
312    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
313        let target: SockAddr = SockAddr::from(target);
314        match self.inner.send_to(buf, &target) {
315            Ok(n) => Ok(n),
316            Err(e) => Err(e),
317        }
318    }
319    /// Receive packet
320    pub fn receive(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
321        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
322        match self.inner.recv(recv_buf) {
323            Ok(result) => Ok(result),
324            Err(e) => Err(e),
325        }
326    }
327    /// Receive packet with sender address
328    pub fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
329        let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
330        match self.inner.recv_from(recv_buf) {
331            Ok(result) => {
332                let (n, addr) = result;
333                match addr.as_socket() {
334                    Some(addr) => return Ok((n, addr)),
335                    None => {
336                        return Err(io::Error::new(
337                            io::ErrorKind::Other,
338                            "Invalid socket address",
339                        ))
340                    }
341                }
342            }
343            Err(e) => Err(e),
344        }
345    }
346    /// Bind socket to address
347    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
348        let addr: SockAddr = SockAddr::from(addr);
349        self.inner.bind(&addr)
350    }
351    /// Set receive timeout
352    pub fn set_receive_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
353        self.inner.set_read_timeout(timeout)
354    }
355    /// Set TTL or Hop Limit
356    pub fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> {
357        match ip_version {
358            IpVersion::V4 => self.inner.set_ttl(ttl),
359            IpVersion::V6 => self.inner.set_unicast_hops_v6(ttl),
360        }
361    }
362    /// Initiate TCP connection
363    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
364        let addr: SockAddr = SockAddr::from(addr);
365        self.inner.connect(&addr)
366    }
367    /// Shutdown TCP connection
368    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
369        self.inner.shutdown(how)
370    }
371    /// Listen TCP connection
372    pub fn listen(&self, backlog: i32) -> io::Result<()> {
373        self.inner.listen(backlog)
374    }
375    /// Accept TCP connection
376    pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> {
377        match self.inner.accept() {
378            Ok((socket, addr)) => Ok((Socket { inner: Arc::new(socket) }, addr.as_socket().unwrap())),
379            Err(e) => Err(e),
380        }
381    }
382    /// Get peer address
383    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
384        match self.inner.peer_addr() {
385            Ok(addr) => Ok(addr.as_socket().unwrap()),
386            Err(e) => Err(e),
387        }
388    }
389    /// Get local address
390    pub fn local_addr(&self) -> io::Result<SocketAddr> {
391        match self.inner.local_addr() {
392            Ok(addr) => Ok(addr.as_socket().unwrap()),
393            Err(e) => Err(e),
394        }
395    }
396}
397
398/// Cross-platform raw socket.
399/// Enables to send and receive packets with custom headers.
400pub struct DataLinkSocket {
401    pub interface: crate::datalink::interface::Interface,
402    sender: Box<dyn pnet::datalink::DataLinkSender>,
403    receiver: Box<dyn pnet::datalink::DataLinkReceiver>,
404}
405
406impl DataLinkSocket {
407    /// Constructs a new DataLinkSocket
408    pub fn new(
409        interface: crate::datalink::interface::Interface,
410        promiscuous: bool,
411    ) -> io::Result<DataLinkSocket> {
412        let interfaces = pnet::datalink::interfaces();
413        let network_interface = match interfaces
414            .into_iter()
415            .filter(|network_interface: &pnet::datalink::NetworkInterface| {
416                network_interface.index == interface.index
417            })
418            .next()
419        {
420            Some(network_interface) => network_interface,
421            None => {
422                return Err(io::Error::new(
423                    io::ErrorKind::Other,
424                    "Network Interface not found",
425                ))
426            }
427        };
428        let config = pnet::datalink::Config {
429            write_buffer_size: 4096,
430            read_buffer_size: 4096,
431            read_timeout: None,
432            write_timeout: None,
433            channel_type: pnet::datalink::ChannelType::Layer2,
434            bpf_fd_attempts: 1000,
435            linux_fanout: None,
436            promiscuous: promiscuous,
437        };
438        let (tx, rx) = match pnet::datalink::channel(&network_interface, config) {
439            Ok(pnet::datalink::Channel::Ethernet(sender, receiver)) => (sender, receiver),
440            Ok(_) => {
441                return Err(io::Error::new(
442                    io::ErrorKind::Other,
443                    "Not an Ethernet interface",
444                ))
445            }
446            Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
447        };
448        Ok(DataLinkSocket {
449            interface: interface,
450            sender: tx,
451            receiver: rx,
452        })
453    }
454    /// Build packet from PacketBuildOption and send it
455    pub fn send(&mut self, packet_builder: PacketBuildOption) -> io::Result<usize> {
456        build_and_send_packet(&mut self.sender, packet_builder)
457    }
458    /// Send packet
459    pub fn send_to(&mut self, buf: &[u8]) -> io::Result<usize> {
460        match self.sender.send_to(buf, None) {
461            Some(res) => match res {
462                Ok(_) => return Ok(buf.len()),
463                Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
464            },
465            None => Err(io::Error::new(
466                io::ErrorKind::Other,
467                "Failed to send packet",
468            )),
469        }
470    }
471    /// Build and send packet. This is useful when you want to send packet with custom build function.
472    pub fn build_and_send(
473        &mut self,
474        num_packets: usize,
475        packet_size: usize,
476        func: &mut dyn FnMut(&mut [u8]),
477    ) -> io::Result<()> {
478        match self.sender.build_and_send(num_packets, packet_size, func) {
479            Some(res) => match res {
480                Ok(_) => return Ok(()),
481                Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
482            },
483            None => Err(io::Error::new(
484                io::ErrorKind::Other,
485                "Failed to send packet",
486            )),
487        }
488    }
489    /// Receive packet
490    pub fn receive(&mut self) -> io::Result<&[u8]> {
491        match self.receiver.next() {
492            Ok(packet) => Ok(packet),
493            Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
494        }
495    }
496}