// macos_routing_table/routing_table.rs

1use crate::{Entity, Protocol, RouteEntry};
2use std::{collections::HashMap, net::IpAddr, process::ExitStatus, string::FromUtf8Error};
3use tokio::process::Command;
4
/// Absolute path of the `netstat` executable that `execute_netstat` runs.
const NETSTAT_PATH: &str = "/usr/sbin/netstat";
6
7/// A snapshot of the routing table
8#[derive(Debug)]
9pub struct RoutingTable {
10    routes: Vec<RouteEntry>,
11    /// Map of interfaces to their default routers
12    if_router: HashMap<String, Vec<IpAddr>>,
13}
14
15/// Various errors
16#[derive(Debug, thiserror::Error)]
17pub enum Error {
18    #[error("failed to execute {NETSTAT_PATH}: {0}")]
19    NetstatExec(std::io::Error),
20    #[error("failed to get routing table: {0}")]
21    NetstatFail(ExitStatus),
22    #[error("netstat output not non-UTF-8")]
23    NetstatUtf8(FromUtf8Error),
24    #[error("no headers follow {0:?} section marker")]
25    NetstatParseNoHeaders(String),
26    #[error("parsing route entry: {0}")]
27    RouteEntryParse(#[from] crate::route_entry::Error),
28    #[error("route entry found before protocol (Internet/Internet6) found.")]
29    EntryBeforeProto,
30}
31
32impl RoutingTable {
33    /// Query the routing table using the `netstat` command.
34    ///
35    /// # Errors
36    ///
37    /// Returns an error if the `netstat` command fails to execute, or returns
38    /// unparseable output.
39    pub async fn load_from_netstat() -> Result<Self, Error> {
40        let output = execute_netstat().await?;
41        Self::from_netstat_output(&output)
42    }
43
44    /// Generate a `RoutingTable` from complete netstat output.  The output should
45    /// conform to what would be returned from `netstat -rn` on macOS/Darwin.
46    ///
47    /// # Errors
48    ///
49    /// Returns an error
50    pub fn from_netstat_output(output: &str) -> Result<RoutingTable, Error> {
51        let mut lines = output.lines();
52        let mut headers = vec![];
53        let mut routes = vec![];
54        let mut proto = None;
55        let mut if_router = HashMap::new();
56
57        while let Some(line) = lines.next() {
58            if line.is_empty() || line.starts_with("Routing table") {
59                continue;
60            }
61            match line {
62                section @ ("Internet:" | "Internet6:") => {
63                    proto = match section {
64                        "Internet:" => Some(Protocol::V4),
65                        "Internet6:" => Some(Protocol::V6),
66                        _ => unreachable!(),
67                    };
68                    // Next line will contain the column headers
69                    if let Some(line) = lines.next() {
70                        headers = line.split_ascii_whitespace().collect();
71                    } else {
72                        return Err(Error::NetstatParseNoHeaders(section.into()));
73                    }
74                    continue;
75                }
76                entry => {
77                    if let Some(proto) = proto {
78                        let route = RouteEntry::parse(proto, entry, &headers)?;
79                        if let (Entity::Default, Entity::Cidr(cidr)) =
80                            (&route.dest.entity, &route.gateway.entity)
81                        {
82                            if cidr.is_host_address() {
83                                let route = route.clone();
84                                let gws = if_router.entry(route.net_if).or_insert_with(Vec::new);
85                                // The route parser doesn't produce `Any` CIDRs,
86                                // so there's always a first address.
87                                gws.push(cidr.first_address().unwrap_or_else(|| unreachable!()));
88                            }
89                        }
90                        routes.push(route);
91                    } else {
92                        return Err(Error::EntryBeforeProto);
93                    }
94                }
95            };
96        }
97        Ok(RoutingTable { routes, if_router })
98    }
99
100    /// Find the routing table entry that most-precisely matches the provided
101    /// address.
102    #[must_use]
103    pub fn find_route_entry(&self, addr: IpAddr) -> Option<&RouteEntry> {
104        // TODO: implement a proper lookup table and/or short-circuit on an
105        // exact match
106        self.routes
107            .iter()
108            .filter(|route| route.contains(addr))
109            .fold(None, |old, new| match old {
110                None => Some(new),
111                Some(old) => Some(old.most_precise(new)),
112            })
113    }
114
115    #[must_use]
116    pub fn default_gateways_for_netif(&self, net_if: &str) -> Option<&Vec<IpAddr>> {
117        self.if_router.get(net_if)
118    }
119}
120
121/// Execute `netstat -rn` and return the output
122///
123/// # Errors
124///
125/// Returns an error if command execution fails, or the output is not UTF-8
126pub async fn execute_netstat() -> Result<String, Error> {
127    let output = Command::new(NETSTAT_PATH)
128        .arg("-rn")
129        .stdin(std::process::Stdio::null())
130        .output()
131        .await
132        .map_err(Error::NetstatExec)?;
133    if !output.status.success() {
134        return Err(Error::NetstatFail(output.status));
135    }
136    String::from_utf8(output.stdout).map_err(Error::NetstatUtf8)
137}
138
#[cfg(test)]
mod tests {
    use super::Error;
    use crate::{Destination, Entity, RoutingTable};
    use std::process::ExitStatus;

    include!(concat!(env!("OUT_DIR"), "/sample_table.rs"));

    /// Exercise `Debug` formatting of a parsed table and of the error variants
    /// that canned netstat output cannot trigger.
    #[test]
    fn coverage() {
        let rt = RoutingTable::from_netstat_output(SAMPLE_TABLE).expect("parse routing table");
        let _ = format!("{rt:?}");
        let _ = format!(
            "{:?}",
            Error::NetstatExec(std::io::Error::from_raw_os_error(1))
        );
        let _ = format!("{:?}", Error::NetstatFail(ExitStatus::default()));
        // This error is reachable only if the netstat command outputs invalid
        // UTF-8.
        let from_utf8err = String::from_utf8([0xa0, 0xa1].to_vec()).unwrap_err();
        let _ = format!("{:?}", Error::NetstatUtf8(from_utf8err));
    }

    #[tokio::test]
    #[cfg(target_os = "macos")]
    async fn live_test() {
        let _routing_table = RoutingTable::load_from_netstat()
            .await
            .expect("parse live routing table");
    }

    #[test]
    fn good_table() {
        let rt = RoutingTable::from_netstat_output(SAMPLE_TABLE).expect("parse routing table");
        let entry = rt
            .find_route_entry("1.1.1.1".parse().unwrap())
            .expect("a route covering 1.1.1.1");
        assert!(
            matches!(
                entry.dest,
                Destination {
                    entity: Entity::Default,
                    zone: None
                }
            ),
            "unexpected route for 1.1.1.1: {entry:?}"
        );
        // Coverage of debug formatting
        let _ = format!("{rt:?}");
    }

    #[test]
    fn empty_output() {
        let rt = RoutingTable::from_netstat_output("").expect("empty output parses");
        assert!(rt.find_route_entry("1.1.1.1".parse().unwrap()).is_none());
    }

    #[test]
    fn unknown_interface_has_no_default_gateways() {
        let rt = RoutingTable::from_netstat_output(SAMPLE_TABLE).expect("parse routing table");
        assert!(rt.default_gateways_for_netif("no-such-if0").is_none());
    }

    #[test]
    fn missing_headers() {
        for section in ["", "6"] {
            let input = format!("{SAMPLE_TABLE}Internet{section}:\n");
            let result = RoutingTable::from_netstat_output(&input);
            assert!(
                matches!(result, Err(Error::NetstatParseNoHeaders(_))),
                "unexpected result: {result:?}"
            );
            // Coverage of debug formatting
            let _ = format!("{:?}", result.unwrap_err());
        }
    }

    #[test]
    fn stray_entry() {
        let input = format!("extra stuff\n{SAMPLE_TABLE}");
        let result = RoutingTable::from_netstat_output(&input);
        assert!(
            matches!(result, Err(Error::EntryBeforeProto)),
            "unexpected result: {result:?}"
        );
        // Coverage of debug formatting
        let _ = format!("{:?}", result.unwrap_err());
    }

    #[test]
    fn bad_entry() {
        let input = format!("{SAMPLE_TABLE}How now brown cow.\n");
        let result = RoutingTable::from_netstat_output(&input);
        assert!(
            matches!(
                result,
                Err(Error::RouteEntryParse(
                    crate::route_entry::Error::ParseIPv4AddrBadInt {
                        addr: _,
                        err: std::num::ParseIntError { .. },
                    }
                ))
            ),
            "unexpected result: {result:?}"
        );
        // Coverage of debug formatting
        let _ = format!("{:?}", result.unwrap_err());
    }
}