#![cfg(target_os = "linux")]
use std::fs;
use std::fs::Permissions;
use std::net::IpAddr;
use std::os::fd::{AsFd, AsRawFd, FromRawFd};
use std::os::unix::fs::PermissionsExt;
use std::str::FromStr;
use cidr::IpCidr;
use crate::{run_command, TproxyArgs, TproxyState, ETC_RESOLV_CONF_FILE};
/// Decodes `bytes` as UTF-8 text.
///
/// # Errors
/// Returns an `InvalidData` I/O error describing the UTF-8 decoding failure.
fn bytes_to_string(bytes: Vec<u8>) -> std::io::Result<String> {
    String::from_utf8(bytes).map_err(|utf8_err| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("error converting bytes to string: {}", utf8_err),
        )
    })
}
/// Decodes `bytes` as UTF-8 (via `bytes_to_string`) and splits the text into owned lines.
///
/// Line terminators (`\n` or `\r\n`) are not part of the returned strings.
fn bytes_to_lines(bytes: Vec<u8>) -> std::io::Result<Vec<String>> {
    let text = bytes_to_string(bytes)?;
    Ok(text.lines().map(str::to_owned).collect())
}
/// Reports whether `ip route show <route>` prints any non-whitespace output.
fn route_exists(route: &str) -> std::io::Result<bool> {
    let raw_output = run_command("ip", &["route", "show", route])?;
    let listing = bytes_to_string(raw_output)?;
    let has_entries = !listing.trim().is_empty();
    Ok(has_entries)
}
/// Builds an `IpCidr` from a network address and prefix length.
///
/// # Errors
/// Returns an `InvalidData` I/O error when `IpCidr::new` rejects the pair, e.g. a
/// prefix longer than the address or host bits set below the prefix.
fn create_cidr(addr: IpAddr, len: u8) -> std::io::Result<IpCidr> {
    IpCidr::new(addr, len).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("failed to convert {}/{} to CIDR", addr, len),
        )
    })
}
fn write_buffer_to_fd(fd: std::os::fd::BorrowedFd<'_>, data: &[u8]) -> std::io::Result<()> {
let mut written = 0;
loop {
if written >= data.len() {
break;
}
written += nix::unistd::write(fd, &data[written..])?;
}
Ok(())
}
/// Makes the file behind `fd` read-only (mode 0o444) and writes a single
/// `nameserver <addr>\n` line to it at the current file offset.
///
/// `tun_gateway` is the nameserver address. When it is `None`, `198.18.0.1` is used.
///
/// # Errors
/// Failures of `fchmod(2)` or of writing the line.
fn write_nameserver(fd: std::os::fd::BorrowedFd<'_>, tun_gateway: Option<IpAddr>) -> std::io::Result<()> {
    let nameserver = tun_gateway.unwrap_or(IpAddr::from([198u8, 18, 0, 1]));
    let line = format!("nameserver {}\n", nameserver);
    // r--r--r-- == 0o444
    let read_only = nix::sys::stat::Mode::S_IRUSR | nix::sys::stat::Mode::S_IRGRP | nix::sys::stat::Mode::S_IROTH;
    nix::sys::stat::fchmod(fd.as_raw_fd(), read_only)?;
    write_buffer_to_fd(fd, line.as_bytes())
}
/// Points `ETC_RESOLV_CONF_FILE` (`/etc/resolv.conf`) at the TUN gateway as its only nameserver.
///
/// Preferred path: write `nameserver <gateway>` into a private temporary file and
/// bind-mount it over `ETC_RESOLV_CONF_FILE`, then try to remount that bind read-only.
/// On success `restore.umount_resolvconf` is set so the removal code unmounts it later.
///
/// Fallback (the bind mount failed): save the current file content into
/// `restore.restore_resolvconf_content` (written back on removal), then truncate the
/// file in place and write the nameserver line into it directly.
///
/// The gateway comes from `restore.tproxy_args`. When that is `None`, `write_nameserver`
/// falls back to its built-in default address.
///
/// NOTE(review): `write_nameserver` fchmods its fd to 0o444, so in the fallback path the
/// real `/etc/resolv.conf` is left with mode 0o444. Restoration only rewrites the saved
/// content, not the original mode. Confirm this is intended.
///
/// # Errors
/// Creating or writing the temporary file; in the fallback path, reading, opening or
/// writing `/etc/resolv.conf`. A failed read-only remount is only logged.
fn setup_resolv_conf(restore: &mut TproxyState) -> std::io::Result<()> {
let tun_gateway = restore.tproxy_args.as_ref().map(|args| args.tun_gateway);
// Temporary file with a long random name. The 0o644 mode is changed to 0o444 by the
// fchmod inside `write_nameserver`.
let file = tempfile::Builder::new()
.permissions(Permissions::from_mode(0o644))
.rand_bytes(32)
.tempfile()?;
write_nameserver(file.as_fd(), tun_gateway)?;
// `/proc/self/fd/<n>` names the still-open temp file and serves as the bind-mount source.
let source = format!("/proc/self/fd/{}", file.as_raw_fd());
let flags = nix::mount::MsFlags::MS_BIND;
let mount1 = nix::mount::mount(source.as_str().into(), ETC_RESOLV_CONF_FILE, "".into(), flags, "".into());
if mount1.is_ok() {
restore.umount_resolvconf = true;
// A new bind mount ignores MS_RDONLY; making it read-only takes a second
// MS_REMOUNT | MS_BIND call (see mount(2)). Failure here is non-fatal.
let flags = nix::mount::MsFlags::MS_REMOUNT | nix::mount::MsFlags::MS_RDONLY | nix::mount::MsFlags::MS_BIND;
if nix::mount::mount("".into(), ETC_RESOLV_CONF_FILE, "".into(), flags, "".into()).is_err() {
#[cfg(feature = "log")]
log::warn!("failed to remount /etc/resolv.conf as readonly");
}
}
// Closes the fd and removes the temporary path. Presumably the bind mount keeps the
// file's contents visible at /etc/resolv.conf afterwards (Linux bind-mount semantics).
drop(file);
if mount1.is_err() {
#[cfg(feature = "log")]
log::warn!("failed to bind mount custom resolv.conf onto /etc/resolv.conf, resorting to direct write");
// Save the original bytes before truncating so they can be written back on removal.
restore.restore_resolvconf_content = Some(fs::read(ETC_RESOLV_CONF_FILE)?);
// O_TRUNC discards the old content. The mode argument below has no effect without O_CREAT.
let flags = nix::fcntl::OFlag::O_WRONLY | nix::fcntl::OFlag::O_CLOEXEC | nix::fcntl::OFlag::O_TRUNC;
let fd = nix::fcntl::open(ETC_RESOLV_CONF_FILE, flags, nix::sys::stat::Mode::from_bits(0o644).unwrap())?;
// SAFETY: `fd` was just returned by a successful `open` and is not owned by anything
// else, so `OwnedFd` may take ownership and close it on drop.
let fd = unsafe { std::os::unix::io::OwnedFd::from_raw_fd(fd) };
write_nameserver(fd.as_fd(), tun_gateway)?;
}
Ok(())
}
/// Lists the routes printed by `ip -4 route show` or `ip -6 route show`.
///
/// Each entry pairs the route's destination prefix with the remaining
/// whitespace-separated tokens of its line (e.g. `via 10.0.0.1 dev eth0 metric 100`).
/// That is the shape `ip route add <dst> ...` accepts. `default` is mapped to
/// `0.0.0.0/0` or `::/0`, and a destination without `/len` is treated as a host prefix.
///
/// Some lines are skipped instead of causing a panic:
/// - blank lines;
/// - indented continuation lines, such as multipath `nexthop` entries;
/// - lines that do not start with a parsable destination, such as typed routes
///   (`unreachable ...`, `blackhole ...`).
///
/// # Errors
/// Failure to run `ip`, non-UTF-8 output, or a destination that parses but is not a
/// valid CIDR.
fn route_show(is_ipv6: bool) -> std::io::Result<Vec<(IpCidr, Vec<String>)>> {
    let route_show_args = if is_ipv6 {
        ["-6", "route", "show"]
    } else {
        ["-4", "route", "show"]
    };
    let routes = bytes_to_lines(run_command("ip", &route_show_args)?)?;
    let (default_dst, host_prefix_len) = if is_ipv6 { ("::/0", "128") } else { ("0.0.0.0/0", "32") };

    let mut route_info = Vec::<(IpCidr, Vec<String>)>::with_capacity(routes.len());
    for line in routes {
        // Indented lines continue the previous route (e.g. multipath `nexthop` entries).
        if line.starts_with([' ', '\t']) {
            continue;
        }
        let mut split = line.split_whitespace();
        let dst_str = match split.next() {
            Some("default") => default_dst,
            Some(dst) => dst,
            // Blank line.
            None => continue,
        };
        let (addr_str, prefix_len_str) = dst_str.split_once('/').unwrap_or((dst_str, host_prefix_len));
        let (addr, prefix_len) = match (IpAddr::from_str(addr_str), u8::from_str(prefix_len_str)) {
            (Ok(addr), Ok(prefix_len)) => (addr, prefix_len),
            _ => {
                // Typed routes (`unreachable ...`, `blackhole ...`, ...) start with a
                // keyword instead of a destination and cannot be reused here.
                #[cfg(feature = "log")]
                log::debug!("skipping unparsable route line: {}", line);
                continue;
            }
        };
        let cidr = create_cidr(addr, prefix_len)?;
        let route_components: Vec<String> = split.map(String::from).collect();
        route_info.push((cidr, route_components));
    }
    Ok(route_info)
}
/// Pins a host route for `ip` through the current default route, so traffic to `ip`
/// keeps using that path.
///
/// The most specific route containing `ip` is found by sorting prefixes longest-first.
/// A route is added only when that best match is the default route (prefix length 0).
/// Its attributes (`via`, `dev`, `metric`, ...) are copied onto an
/// `ip route add <ip>/<host-len> ...` command.
///
/// Returns `Ok(true)` if a route was added. Returns `Ok(false)` if a more specific route
/// already covers `ip` or no route matches.
fn bypass_ip(ip: &IpAddr) -> std::io::Result<bool> {
    let mut routes = route_show(ip.is_ipv6())?;
    let host = create_cidr(*ip, if ip.is_ipv6() { 128 } else { 32 })?;
    // Longest prefix first (stable, so equal-length routes keep `ip` output order).
    routes.sort_by_key(|(prefix, _)| std::cmp::Reverse(prefix.network_length()));

    let best_match = routes
        .into_iter()
        .find(|(prefix, _)| prefix.contains(&host.first_address()) && prefix.contains(&host.last_address()));

    match best_match {
        Some((prefix, components)) if prefix.network_length() == 0 => {
            let mut command: Vec<String> = vec!["route".into(), "add".into(), host.to_string()];
            command.extend(components);
            run_command("ip", &command.iter().map(String::as_str).collect::<Vec<&str>>())?;
            Ok(true)
        }
        _ => Ok(false),
    }
}
/// Looks up the route whose destination is exactly `route` (e.g. `"0.0.0.0/0"` or
/// `"::/0"`) so it can be re-added later with [`restore_route`].
///
/// Returns `Ok(None)` when no such route exists. Otherwise returns `route` itself,
/// followed by the remaining tokens of its `ip route show` line (`via ... dev ... ...`).
///
/// # Errors
/// `InvalidInput` if `route` is not a valid CIDR string, or any error from listing the
/// routing table.
pub fn get_restore_components(route: &str) -> std::io::Result<Option<Vec<String>>> {
    // Public input: report a bad prefix as an error instead of panicking.
    let cidr_all = IpCidr::from_str(route).map_err(|err| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("invalid route {:?}: {}", route, err),
        )
    })?;
    let routes = route_show(cidr_all.is_ipv6())?;
    let components: Option<Vec<String>> = routes
        .into_iter()
        .find(|(cidr, _)| cidr == &cidr_all)
        .map(|(_, components)| std::iter::once(route.to_string()).chain(components).collect());
    Ok(components)
}
/// Re-adds a route by running `ip route add <route_components...>`.
///
/// `route_components` is expected in the shape returned by `get_restore_components`:
/// the destination first, then the remaining route attributes.
pub fn restore_route(route_components: &[String]) -> std::io::Result<()> {
    let ip_args: Vec<&str> = ["route", "add"]
        .iter()
        .copied()
        .chain(route_components.iter().map(String::as_str))
        .collect();
    run_command("ip", ip_args.as_slice())?;
    Ok(())
}
/// Configures the host to send traffic through the TUN interface described by `tproxy_args`.
///
/// Steps, in order:
/// 1. `ip link set <tun_name> up`.
/// 2. Add host routes for every bypass IP through the current default route (see
///    `bypass_ip`). When no bypass IPs are given and the proxy address is not private,
///    the proxy address itself is bypassed.
/// 3. IPv4: if `ipv4_default_route` is set, point traffic at the TUN device:
///    - if no `0.0.0.0/0` route exists, add it on the TUN device;
///    - otherwise add the two halves `0.0.0.0/1` and `128.0.0.0/1`, which are more
///      specific than the existing default route and leave it untouched.
///
///    If the flag is not set, the existing default route is saved into
///    `restore_ipv4_route` and deleted.
/// 4. IPv6: the same with `::/0`, `::/1` and `8000::/1`, saving into `restore_ipv6_route`.
/// 5. Override `/etc/resolv.conf` (see `setup_resolv_conf`) and persist the state with
///    `crate::store_intermediate_state`.
///
/// If any step fails, the partially filled `TproxyState` is dropped and its `Drop` impl
/// tries to undo the changes.
///
/// NOTE(review): deleting the system default route when `ipv*_default_route` is false
/// looks deliberate (presumably to keep that address family from bypassing the tunnel).
/// Confirm this is intended.
/// NOTE(review): `bypass_ip` may return `Ok(false)` without adding a route, yet removal
/// deletes routes for every bypass IP. That could remove a route that existed before setup.
pub fn tproxy_setup(tproxy_args: &TproxyArgs) -> std::io::Result<TproxyState> {
let tun_name = &tproxy_args.tun_name;
let targs = Some(tproxy_args.clone());
// Fields are filled in as changes are applied. If a later `?` returns early, dropping
// `state` runs its `Drop` impl, which undoes what was recorded so far.
let mut state = TproxyState {
tproxy_args: targs,
original_dns_servers: None,
gateway: None,
gw_scope: None,
umount_resolvconf: false,
restore_resolvconf_content: None,
tproxy_removed_done: false,
restore_ipv4_route: None,
restore_ipv6_route: None,
};
// Bring the TUN interface up.
let args = &["link", "set", tun_name, "up"];
run_command("ip", args)?;
// Pin host routes through the current default route for each bypass IP.
for ip in tproxy_args.bypass_ips.iter() {
bypass_ip(ip)?;
}
// With no explicit bypass list, pin the proxy address too (unless it is private),
// presumably so the proxy itself stays reachable outside the tunnel.
if tproxy_args.bypass_ips.is_empty() && !crate::is_private_ip(tproxy_args.proxy_addr.ip()) {
bypass_ip(&tproxy_args.proxy_addr.ip())?;
}
if tproxy_args.ipv4_default_route {
if !route_exists("0.0.0.0/0")? {
// No IPv4 default route yet: make the TUN device the default.
let args = &["route", "add", "0.0.0.0/0", "dev", tun_name];
run_command("ip", args)?;
} else {
// Keep the existing default route; the two /1 halves are more specific.
let args = &["route", "add", "128.0.0.0/1", "dev", tun_name];
run_command("ip", args)?;
let args = &["route", "add", "0.0.0.0/1", "dev", tun_name];
run_command("ip", args)?;
}
} else {
// Remember the current IPv4 default route for removal time, then delete it.
// A failed delete is only logged.
state.restore_ipv4_route = get_restore_components("0.0.0.0/0")?;
#[cfg(feature = "log")]
log::debug!("restore ipv4 route: {:?}", state.restore_ipv4_route);
if let Err(_err) = run_command("ip", &["route", "del", "0.0.0.0/0"]) {
#[cfg(feature = "log")]
log::debug!("command \"ip route del 0.0.0.0/0\" error: {}", _err);
}
}
if tproxy_args.ipv6_default_route {
if !route_exists("::/0")? {
// No IPv6 default route yet: make the TUN device the default.
let args = &["route", "add", "::/0", "dev", tun_name];
run_command("ip", args)?;
} else {
// Keep the existing default route; the two /1 halves are more specific.
let args = &["route", "add", "::/1", "dev", tun_name];
run_command("ip", args)?;
let args = &["route", "add", "8000::/1", "dev", tun_name];
run_command("ip", args)?;
}
} else {
// Remember the current IPv6 default route for removal time, then delete it.
// A failed delete is only logged.
state.restore_ipv6_route = get_restore_components("::/0")?;
#[cfg(feature = "log")]
log::debug!("restore ipv6 route: {:?}", state.restore_ipv6_route);
if let Err(_err) = run_command("ip", &["route", "del", "::/0"]) {
#[cfg(feature = "log")]
log::debug!("command \"ip route del ::/0\" error: {}", _err);
}
}
setup_resolv_conf(&mut state)?;
// Persist the state so a later `tproxy_remove(None)` can restore the system.
crate::store_intermediate_state(&state)?;
Ok(state)
}
impl Drop for TproxyState {
    /// Best-effort rollback of the network configuration when the state goes out of scope.
    ///
    /// `drop` cannot return errors, so a failure is only reported through the `log`
    /// feature. `_tproxy_remove` does nothing if it has already run for this state.
    fn drop(&mut self) {
        #[cfg(feature = "log")]
        log::debug!("restoring network settings");
        if let Some(_e) = _tproxy_remove(self).err() {
            #[cfg(feature = "log")]
            log::error!("failed to restore network settings: {}", _e);
        }
    }
}
/// Undoes the configuration applied by `tproxy_setup`.
///
/// When `state` is `None`, the previously stored state is loaded with
/// `crate::retrieve_intermediate_state` and used instead.
pub fn tproxy_remove(state: Option<TproxyState>) -> std::io::Result<()> {
    let mut state = if let Some(in_memory) = state {
        in_memory
    } else {
        crate::retrieve_intermediate_state()?
    };
    // When `state` is dropped afterwards, its Drop impl sees `tproxy_removed_done` and does nothing.
    _tproxy_remove(&mut state)
}
pub(crate) fn _tproxy_remove(state: &mut TproxyState) -> std::io::Result<()> {
if state.tproxy_removed_done {
return Ok(());
}
state.tproxy_removed_done = true;
let err = std::io::Error::new(std::io::ErrorKind::InvalidData, "tproxy_args is None");
let tproxy_args = state.tproxy_args.as_ref().ok_or(err)?;
for bypass_ip in tproxy_args.bypass_ips.iter() {
let args = &["route", "del", &bypass_ip.to_string()];
if let Err(_err) = run_command("ip", args) {
#[cfg(feature = "log")]
log::debug!("command \"ip route del {}\" error: {}", bypass_ip, _err);
}
}
if tproxy_args.bypass_ips.is_empty() && !crate::is_private_ip(tproxy_args.proxy_addr.ip()) {
let bypass_ip = tproxy_args.proxy_addr.ip();
let args = &["route", "del", &bypass_ip.to_string()];
if let Err(_err) = run_command("ip", args) {
#[cfg(feature = "log")]
log::debug!("command \"ip route del {}\" error: {}", bypass_ip, _err);
}
}
if let Some(components) = &state.restore_ipv4_route {
#[cfg(feature = "log")]
log::debug!("restore route: {:?}", components);
restore_route(components.as_slice())?;
}
if let Some(components) = &state.restore_ipv6_route {
#[cfg(feature = "log")]
log::debug!("restore route: {:?}", components);
restore_route(components.as_slice())?;
}
let args = &["link", "del", &tproxy_args.tun_name];
if let Err(_err) = run_command("ip", args) {
#[cfg(feature = "log")]
log::debug!("command \"ip {:?}\" error: {}", args, _err);
}
if state.umount_resolvconf {
nix::mount::umount(ETC_RESOLV_CONF_FILE)?;
}
if let Some(data) = &state.restore_resolvconf_content {
fs::write(ETC_RESOLV_CONF_FILE, data)?;
}
Ok(())
}
#[allow(dead_code)]
pub(crate) fn get_default_gateway() -> std::io::Result<(IpAddr, String)> {
let cmd = "ip route | grep default | awk '{print $3}'";
let out = run_command("sh", &["-c", cmd])?;
let stdout = String::from_utf8_lossy(&out).into_owned();
let addr = IpAddr::from_str(stdout.trim()).map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
let cmd = "ip route | grep default | awk '{print $5}'";
let out = run_command("sh", &["-c", cmd])?;
let stdout = String::from_utf8_lossy(&out).into_owned();
let iface = stdout.trim().to_string();
Ok((addr, iface))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test against the live routing table. It needs an IPv4 default route with a
    /// gateway, which many CI containers lack. Run with `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires an IPv4 default route on the host"]
    fn test_get_default_gateway() {
        let (addr, iface) = get_default_gateway().unwrap();
        println!("addr: {:?}, iface: {}", addr, iface);
    }

    /// Adds a real host route (needs root), so it must never run implicitly as part of
    /// `cargo test`. Any route it adds is removed again afterwards.
    #[test]
    #[ignore = "modifies the host routing table and requires root"]
    fn test_bypass_ip() {
        let ip: IpAddr = "123.45.67.89".parse().unwrap();
        let res = bypass_ip(&ip);
        println!("bypass_ip: {:?}", res);
        if matches!(res, Ok(true)) {
            // Best-effort cleanup so the test leaves no trace in the routing table.
            let _ = run_command("ip", &["route", "del", "123.45.67.89"]);
        }
    }
}