roameo/
lib.rs

1use clap::{Arg, Command};
2use log::debug;
3use nix::ifaddrs;
4use nix::sys::socket::AddressFamily::Inet;
5use nix::sys::socket::{SockaddrIn, SockaddrLike, SockaddrStorage};
6use std::fmt;
7use std::io::{Error, ErrorKind};
8use std::str::FromStr;
9
10#[cfg(target_os = "linux")]
11mod linux;
12
13#[cfg(not(target_os = "linux"))]
14mod unsupported;
15
/// Interface name that means "do not restrict matching to one interface".
const ANYINTERFACE: &str = "any";
/// Value stored for a command-line option that was not supplied.
const EMPTYSTRING: &str = "";
/// Longest interface name accepted before matching is attempted.
// NOTE(review): Linux IFNAMSIZ is 16 *including* the trailing NUL, so real
// interface names are at most 15 bytes — confirm 16 is intended here.
const MAX_INTERFACE_LENGTH: usize = 16;
19
/// An IPv4 address together with its netmask and the derived network prefix.
///
/// All three values are plain `u32`s; `from_cidr` places the first dotted
/// octet in the most significant byte.
/// NOTE(review): `from_sockaddr` relies on `SockaddrIn::ip()` producing the
/// same layout — confirm against the nix version in use.
#[derive(Debug, Eq)]
pub struct IPv4NetworkAddress {
    /// The full address, host bits included.
    addr: u32,
    /// The netmask (e.g. `0xffffff00` for a /24).
    mask: u32,
    /// `addr & mask`.
    network: u32,
}

impl PartialEq for IPv4NetworkAddress {
    /// Two values are equal when they describe the same network: identical
    /// network prefix *and* identical netmask. Host bits of `addr` are ignored,
    /// so `10.0.0.5/24 == 10.0.0.99/24`, but `10.0.0.0/8 != 10.0.0.0/24`.
    fn eq(&self, other: &Self) -> bool {
        // Comparing only `network` would make prefixes of different lengths
        // that happen to share the same base address compare equal.
        self.network == other.network && self.mask == other.mask
    }
}

impl fmt::Display for IPv4NetworkAddress {
    /// Formats as `<addr>/<mask>`, each as eight zero-padded hex digits.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:08x}/{:08x}", self.addr, self.mask)
    }
}
39
40impl IPv4NetworkAddress {
41    pub fn from_sockaddr(
42        addr: &SockaddrStorage,
43        mask: &SockaddrStorage,
44    ) -> Result<IPv4NetworkAddress, Error> {
45        let mut ret_addr = 0;
46        let mut ret_mask = 0;
47        if let Some(in_addr) = addr.as_sockaddr_in() {
48            ret_addr = in_addr.ip();
49        }
50        if let Some(in_mask) = mask.as_sockaddr_in() {
51            ret_mask = in_mask.ip();
52        }
53        debug!("Parsed address to {}", ret_addr);
54        debug!("Parsed netmask to {}", ret_mask);
55        Ok(IPv4NetworkAddress {
56            addr: ret_addr,
57            mask: ret_mask,
58            network: ret_addr & ret_mask,
59        })
60    }
61    pub fn from_cidr(cidr: &str) -> Result<IPv4NetworkAddress, Error> {
62        let mut ret_addr = 0;
63        let mut ret_mask = 0;
64
65        let t = cidr.split_once('/');
66        if let Some((addr, mask)) = t {
67            let octets = addr.splitn(4, '.');
68            for (count, octet) in octets.enumerate() {
69                let bits = 8 * (3 - count);
70                let num: u32 = octet.parse().unwrap();
71                ret_addr += num << bits;
72            }
73
74            debug!("Parsed address {} to {}", addr, ret_addr);
75
76            // This isn't the most efficient, but will work.
77            let mask_bits: u32 = mask.parse().unwrap();
78            // Do 1s first
79            for i in 0..32 {
80                if i >= (32 - mask_bits) {
81                    ret_mask |= 1 << i;
82                }
83            }
84
85            debug!("Parsed netmask {} to {}", mask, ret_mask);
86        }
87        // split cidr once at '/'
88        // split addr by '.' and for each, shift left (ndots * 8) and add
89        Ok(IPv4NetworkAddress {
90            addr: ret_addr,
91            mask: ret_mask,
92            network: ret_addr & ret_mask,
93        })
94    }
95}
96
/// Match criteria gathered from the command line.
///
/// An empty string means the option was not given, except for `interface`,
/// which `Roameo::new` fills with `"any"` when it is absent.
pub struct Roameo {
    /// Interface to restrict matching to, or `"any"`.
    interface: String,
    /// IPv4 address to look for, e.g. `203.0.113.6`.
    address: String,
    /// Wireless ESSID to look for.
    essid: String,
    /// IPv4 subnet in CIDR notation, e.g. `203.0.113.0/24`.
    subnet: String,
}
104
105// Constructor for configuration struct
106impl Roameo {
107    pub fn new() -> Result<Roameo, &'static str> {
108        let args = Command::new(env!("CARGO_PKG_NAME"))
109            .version(env!("CARGO_PKG_VERSION"))
110            .author(env!("CARGO_PKG_AUTHORS"))
111            .about(env!("CARGO_PKG_DESCRIPTION"))
112            .arg(
113                Arg::new("essid")
114                    .short('e')
115                    .long("essid")
116                    .takes_value(true)
117                    .help("Match wireless ESSID. Eg, 'CorporateWiFi`"),
118            )
119            .arg(
120                Arg::new("address")
121                    .short('a')
122                    .long("address")
123                    .takes_value(true)
124                    .help("Match IP address. Eg, '203.0.113.6'."),
125            )
126            .arg(
127                Arg::new("subnet")
128                    .short('s')
129                    .long("subnet")
130                    .takes_value(true)
131                    .help("Match IP subnet. Eg, '203.0.113.0/24'."),
132            )
133            .arg(
134                Arg::new("interface")
135                    .short('i')
136                    .long("interface")
137                    .takes_value(true)
138                    .help("Network interface to limit match too. Eg, 'eno1'. Defaults to all."),
139            )
140            .get_matches();
141
142        // Return struct
143        Ok(Roameo {
144            interface: args
145                .value_of("interface")
146                .unwrap_or(ANYINTERFACE)
147                .to_string(),
148            address: args.value_of("address").unwrap_or(EMPTYSTRING).to_string(),
149            subnet: args.value_of("subnet").unwrap_or(EMPTYSTRING).to_string(),
150            essid: args.value_of("essid").unwrap_or(EMPTYSTRING).to_string(),
151        })
152    }
153
154    pub fn find_match(&self) -> Result<(), Error> {
155        if self.interface.len() > MAX_INTERFACE_LENGTH {
156            // Bail
157            panic!("Interface name longer than maximum allowed");
158        }
159
160        if self.essid != EMPTYSTRING {
161            return self.match_essid(&self.essid);
162        } else if self.subnet != EMPTYSTRING || self.address != EMPTYSTRING {
163            return self.get_inet_addrs();
164        }
165        Err(Error::new(
166            ErrorKind::InvalidInput,
167            "Invalid command line arguments",
168        ))
169    }
170
171    fn get_inet_addrs(&self) -> Result<(), Error> {
172        /*
173         * Iterate through (select) interface addresses and return inet4/inet6
174         * addresses.
175         */
176        let addrs = ifaddrs::getifaddrs()?;
177        let mut match_addr = self.address.clone();
178
179        if !match_addr.ends_with(":0") {
180            match_addr += ":0";
181        }
182
183        // Convert string IPv4 address if possible. TODO: make sure that the port
184        // is added to the provided address, or from_str() won't parse it.
185        let match_ip4 = SockaddrIn::from_str(&match_addr).unwrap_or_else(|_| {
186            // match against 0.0.0.0:0 instead
187            SockaddrIn::new(0, 0, 0, 0, 0)
188        });
189        // Try IPv6 to in case that's what we were given.
190        /*
191            let match_ip6 = SockaddrIn6::from_str(&match_addr).unwrap_or_else(|_| {
192                SockaddrIn6::from_str(&"::/0:0").unwrap()
193            });
194        */
195
196        // Loop through interface addresses
197        for ifaddr in addrs {
198            // Check to see if we're supposed to limit this to a particular interface
199            if ifaddr.interface_name != self.interface && self.interface != ANYINTERFACE {
200                continue;
201            }
202
203            if let Some(addr) = ifaddr.address {
204                debug!(
205                    "{}: {:?} {:?}",
206                    ifaddr.interface_name,
207                    addr.family(),
208                    addr.to_string()
209                );
210                if addr.family() == Some(Inet) {
211                    let is4 = addr.as_sockaddr_in();
212                    if let Some(a4) = is4 {
213                        debug!(
214                            "{}: IPv4 => {}: Match => {}",
215                            ifaddr.interface_name,
216                            a4.ip(),
217                            match_ip4.ip()
218                        );
219                        // Try to match only if an address was provided.
220                        if self.address != EMPTYSTRING && a4.ip() == match_ip4.ip() {
221                            debug!(
222                                "Found a match for {} on {}",
223                                self.address, ifaddr.interface_name
224                            );
225                            return Ok(());
226                        } else if self.subnet != EMPTYSTRING {
227                            debug!(
228                                "Trying to find a match for subnet {} on {}",
229                                self.subnet, ifaddr.interface_name,
230                            );
231                            if let Some(s) = ifaddr.netmask {
232                                let dest = IPv4NetworkAddress::from_sockaddr(&addr, &s)?;
233                                let src = IPv4NetworkAddress::from_cidr(&self.subnet)?;
234                                if src == dest {
235                                    debug!("Found subnet: {}", dest);
236                                    return Ok(());
237                                }
238                            }
239                        }
240                    }
241                    continue;
242                /*
243                                } else if addr.family() == Some(Inet6) {
244                                    let is6 = addr.as_sockaddr_in6();
245                                    if let Some(a6) = is6 {
246                                        debug!("{}: IPv6 => {}", ifaddr.interface_name, a6.ip());
247                                        if self.address != EMPTYSTRING {
248                                            if a6.ip() == match_ip6.ip() {
249                                                debug!("Found a match for {} on {}", config.address, ifaddr.interface_name);
250                                            }
251                                        }
252                                    }
253                                    continue;
254                */
255                } else {
256                    // Skip all other address types
257                    continue;
258                }
259            }
260
261            if let Some(addr) = ifaddr.netmask {
262                debug!(
263                    "{}: {:?} {:?}",
264                    ifaddr.interface_name,
265                    addr.family(),
266                    addr.to_string()
267                );
268            }
269        }
270
271        Err(Error::new(ErrorKind::Other, "Not found"))
272    }
273}