// wifi-manager 0.1.4
//
// A cross-platform Wi-Fi management library for Rust, supporting Linux and Windows.
// Documentation
use std::{
    collections::HashMap,
    hash::{BuildHasherDefault, DefaultHasher},
    num::{ParseFloatError, ParseIntError},
    ops::Deref,
    sync::Mutex,
    time::Instant,
};

use log::*;
use thiserror::Error;
use tokio::process::Command;

use crate::os::{ID, OS};

mod info;
mod utils;

#[cfg_attr(windows, path = "os/windows/mod.rs")]
#[cfg_attr(target_os = "linux", path = "os/linux/mod.rs")]
mod os;

pub use info::*;

// Process-wide cache of bandwidth limits: interface ID -> (channel -> widest
// bandwidth recorded for that channel). `interface_list` fills an entry the
// first time an interface ID is seen, and `downcast_channel_max_bandwidth`
// lowers individual channels to HT20.
//
// A fixed `BuildHasherDefault<DefaultHasher>` hasher is spelled out so the
// outer map can be constructed directly inside this `static` initializer.
static MAX_BANDWIDTH: Mutex<
    HashMap<ID, HashMap<usize, BandWidth>, BuildHasherDefault<DefaultHasher>>,
> = Mutex::new(HashMap::with_hasher(BuildHasherDefault::new()));

/// Error type returned by every fallible operation in this crate.
#[derive(Error, Debug)]
pub enum WiFiError {
    /// A failure carrying a free-form message. The `From` conversions in this
    /// file (I/O errors, integer/float parse errors) all produce this variant.
    #[error("System error: {0}")]
    System(String),
    /// The requested facility is not available; `check_command` returns this
    /// when an external command cannot be run.
    #[error("Not support: {0}")]
    NotSupport(String),
}

impl WiFiError {
    fn new_system<E: Deref<Target = str>>(e: E) -> Self {
        WiFiError::System(e.to_string())
    }
}

impl From<std::io::Error> for WiFiError {
    fn from(e: std::io::Error) -> Self {
        WiFiError::new_system(e.to_string())
    }
}

impl From<ParseIntError> for WiFiError {
    fn from(value: ParseIntError) -> Self {
        WiFiError::new_system(format!("ParseIntError: {}", value))
    }
}

impl From<ParseFloatError> for WiFiError {
    fn from(value: ParseFloatError) -> Self {
        WiFiError::new_system(format!("ParseFloatError: {}", value))
    }
}

/// Result alias used throughout the crate: the error is always [`WiFiError`],
/// and the success type defaults to `()` when omitted.
pub type WiFiResult<T = ()> = std::result::Result<T, WiFiError>;

/// Handle to one Wi-Fi interface, as returned by [`interface_list`].
#[derive(Debug, Clone)]
pub struct Interface {
    /// Platform-specific identifier handed to every backend (`OS`) call.
    pub id: ID,
    /// Operating modes listed for this interface.
    // NOTE(review): presumably filled in by the platform backend - confirm in `os`.
    pub support_mode: Vec<Mode>,
}

impl Interface {
    /// Switches this interface to `mode` through the platform backend and logs
    /// (at debug level) how long the call took.
    ///
    /// # Errors
    /// Propagates the error returned by the backend's `set_mode`.
    pub async fn set_mode(&self, mode: Mode) -> WiFiResult {
        let start = Instant::now();
        OS::set_mode(&self.id, mode).await?;
        debug!(
            "Set mode for interface [{}] to {:?} took {:?}",
            self.id,
            mode,
            start.elapsed()
        );
        Ok(())
    }

    /// Tunes the interface to `channel`.
    ///
    /// `band_width` is the requested channel width; it is clamped to the cached
    /// per-channel maximum before being sent to the backend. `second` selects
    /// the secondary channel position for HT40; when it is `None` a default is
    /// derived from the channel number.
    ///
    /// If the first attempt fails, the call is retried once without any
    /// bandwidth argument. When a bandwidth had been requested, the channel's
    /// cached maximum is also lowered to HT20 so later calls do not ask for the
    /// wide mode again.
    ///
    /// # Errors
    /// Returns the error of the second attempt if both attempts fail.
    pub async fn set_channel(
        &self,
        channel: usize,
        band_width: Option<BandWidth>,
        second: Option<SecondChannel>,
    ) -> WiFiResult {
        if let Err(e) = self.try_set_channel(channel, band_width, second).await {
            warn!(
                "interface `{}` set channel {channel} {band_width:?} fail, try downcast, err: {}",
                self.id, e
            );
            // Only blame the bandwidth when one was actually requested. A
            // failure without a bandwidth argument says nothing about what the
            // channel supports, and recording HT20 here would clamp every
            // later request on this channel for the lifetime of the process.
            if band_width.is_some() {
                downcast_channel_max_bandwidth(&self.id, channel);
            }
            self.try_set_channel(channel, None, None).await
        } else {
            Ok(())
        }
    }

    /// Same as [`Interface::set_channel`], but takes the centre frequency in
    /// MHz and converts it to a channel number with [`freq_mhz_to_channel`].
    pub async fn set_frequency(
        &self,
        freq_mhz: usize,
        band_width: Option<BandWidth>,
        second: Option<SecondChannel>,
    ) -> WiFiResult {
        let channel = freq_mhz_to_channel(freq_mhz);
        self.set_channel(channel, band_width, second).await
    }

    /// Performs a single attempt at setting the channel.
    ///
    /// The requested bandwidth is first adapted to the cached per-channel
    /// maximum (and turned into the backend argument form); the outcome is
    /// logged at debug level together with the elapsed time.
    async fn try_set_channel(
        &self,
        channel: usize,
        band_width: Option<BandWidth>,
        second: Option<SecondChannel>,
    ) -> WiFiResult {
        let start = Instant::now();
        let band_width = adapt_channel_max_bandwidth(&self.id, channel, band_width, second);
        OS::set_channel(&self.id, channel, band_width).await?;
        let band_width_str = band_width
            .map(|bw| format!(" bandwidth {}", bw))
            .unwrap_or_default();

        debug!(
            "Set interface [{}] to channel {channel} {band_width_str} took {:?}",
            self.id,
            start.elapsed()
        );
        Ok(())
    }

    /// Brings the interface up via the platform backend.
    pub async fn ifup(&self) -> WiFiResult {
        OS::ifup(&self.id).await
    }

    /// Brings the interface down via the platform backend.
    pub async fn ifdown(&self) -> WiFiResult {
        OS::ifdown(&self.id).await
    }

    /// Queries the backend for the interface's current [`Mode`].
    pub async fn get_mode(&self) -> WiFiResult<Mode> {
        OS::get_mode(&self.id).await
    }

    /// Queries the backend for whether the interface is currently up.
    pub async fn is_ifup(&self) -> WiFiResult<bool> {
        OS::is_ifup(&self.id).await
    }
}

/// Operating mode of a Wi-Fi interface.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    Managed,
    Monitor,
}

impl Mode {
    /// Lower-case keyword for this mode (`"managed"` / `"monitor"`).
    fn cmd(&self) -> &str {
        match *self {
            Mode::Managed => "managed",
            Mode::Monitor => "monitor",
        }
    }
}

impl TryFrom<&str> for Mode {
    type Error = ();

    /// Parses a mode keyword, ignoring surrounding whitespace.
    ///
    /// Only the exact lower-case keywords are accepted; anything else yields
    /// `Err(())`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let keyword = value.trim();
        [Mode::Managed, Mode::Monitor]
            .iter()
            .copied()
            .find(|mode| mode.cmd() == keyword)
            .ok_or(())
    }
}

impl TryFrom<String> for Mode {
    type Error = ();

    /// Owned-string convenience wrapper around the `&str` conversion.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        Mode::try_from(value.as_str())
    }
}

/// Operations the rest of this file invokes on the platform backend type `OS`.
// NOTE(review): presumably implemented by `OS` in the cfg-selected `os` module - confirm there.
trait Impl {
    /// Checks that the backend can operate; exposed through [`check_environment`].
    async fn check_environment() -> WiFiResult;
    /// Lists the available Wi-Fi interfaces.
    async fn interface_list() -> Result<Vec<Interface>, WiFiError>;
    /// Puts interface `id` into `mode`.
    async fn set_mode(id: &ID, mode: Mode) -> WiFiResult;
    /// Reads the current mode of interface `id`.
    async fn get_mode(id: &ID) -> WiFiResult<Mode>;
    /// Tunes interface `id` to `channel`, optionally with a bandwidth argument.
    async fn set_channel(id: &ID, channel: usize, band_width: Option<BandWidthArg>) -> WiFiResult;
    /// Brings interface `id` up.
    async fn ifup(id: &ID) -> WiFiResult;
    /// Brings interface `id` down.
    async fn ifdown(id: &ID) -> WiFiResult;
    /// Reports whether interface `id` is up.
    async fn is_ifup(id: &ID) -> WiFiResult<bool>;
    /// Returns the widest bandwidth per frequency for interface `id`.
    /// `interface_list` converts the keys to channel numbers with
    /// `freq_mhz_to_channel`, i.e. it treats them as frequencies in MHz.
    async fn freq_max_bandwidth(id: &ID) -> WiFiResult<HashMap<usize, BandWidth>>;
}

/// Delegates to the platform backend's `check_environment` and returns its
/// result unchanged; what exactly is checked is backend-specific.
pub async fn check_environment() -> WiFiResult {
    OS::check_environment().await
}

pub async fn interface_list() -> Result<Vec<Interface>, WiFiError> {
    let mut out = vec![];
    for one in OS::interface_list().await? {
        let id = one.id.clone();
        out.push(one);
        #[allow(clippy::map_entry)]
        if !MAX_BANDWIDTH.lock().unwrap().contains_key(&id) {
            let mut map = HashMap::new();
            let max_bandwidth = OS::freq_max_bandwidth(&id).await?;
            for (freq, bandwidth) in max_bandwidth {
                let channel = freq_mhz_to_channel(freq);
                map.insert(channel, bandwidth);
            }
            MAX_BANDWIDTH.lock().unwrap().insert(id, map);
        }
    }

    Ok(out)
}

/// Verifies that the external program `cmd` can be launched by running
/// `cmd --help`.
///
/// Only a failure to run the process is treated as an error (reported as
/// [`WiFiError::NotSupport`]); the exit status and output are ignored.
#[allow(unused)]
async fn check_command(cmd: &str) -> WiFiResult {
    match Command::new(cmd).arg("--help").output().await {
        Ok(_) => Ok(()),
        Err(e) => Err(WiFiError::NotSupport(format!(
            "command [{}] fail: {:?}",
            cmd, e
        ))),
    }
}

/// Extension trait for running a command and turning an unsuccessful run into
/// a [`WiFiError`].
#[allow(unused)]
trait CommandExt {
    /// Runs the command; `expect` is a short description of the action that
    /// the implementation in this file places at the start of its error messages.
    async fn execute<T: AsRef<str>>(&mut self, expect: T) -> WiFiResult;
}

impl CommandExt for Command {
    /// Runs the command to completion and maps the outcome to a `WiFiResult`.
    ///
    /// Both a failure to run the process and a non-success exit status become
    /// [`WiFiError::System`]; the message starts with `expect` and names the
    /// program.
    async fn execute<T: AsRef<str>>(&mut self, expect: T) -> WiFiResult {
        let expect = expect.as_ref();
        // Take an owned copy of the program name so `self` is free to be
        // borrowed mutably by `status()`.
        let program = self
            .as_std()
            .get_program()
            .to_string_lossy()
            .into_owned();

        match self.status().await {
            Ok(status) if status.success() => Ok(()),
            Ok(_) => Err(WiFiError::new_system(format!(
                "{expect} failed, program `{program}`"
            ))),
            Err(e) => Err(WiFiError::new_system(format!(
                "{expect} failed, program `{program}`: {e}"
            ))),
        }
    }
}

/// Converts a Wi-Fi channel number to its centre frequency in MHz.
///
/// * channels below 14 map to `2407 + 5 * channel` (2.4 GHz band),
/// * channel 14 maps to 2484 MHz,
/// * every other channel maps to `5000 + 5 * channel` (5 GHz band).
pub fn channel_to_freq_mhz(channel: usize) -> usize {
    match channel {
        // Channel 14 is part of the 2.4 GHz band but does not follow the
        // 5 MHz spacing of channels 1-13.
        14 => 2484,
        c if c < 14 => 2407 + c * 5,
        c => 5000 + c * 5,
    }
}

/// Converts a centre frequency in MHz to a Wi-Fi channel number.
///
/// * 2484 MHz maps to channel 14,
/// * frequencies above 5000 MHz map to `(freq - 5000) / 5`,
/// * everything else maps to `(freq - 2407) / 5`.
///
/// Frequencies below 2407 MHz yield `0` instead of underflowing.
pub fn freq_mhz_to_channel(freq_mhz: usize) -> usize {
    match freq_mhz {
        // Channel 14 does not follow the 5 MHz spacing of channels 1-13.
        2484 => 14,
        f if f > 5000 => (f - 5000) / 5,
        // `saturating_sub` keeps out-of-band low inputs from underflowing
        // `usize` (a panic in debug builds, a wrapped value in release).
        f => f.saturating_sub(2407) / 5,
    }
}

/// Turns a requested bandwidth into the argument passed to the backend.
///
/// Returns `None` when no bandwidth was requested. Otherwise the request is
/// clamped to the cached maximum for `channel` on interface `id` (if one is
/// cached) and converted to a `BandWidthArg`. For HT40 the secondary channel
/// comes from `second`; when that is `None`, channels 1-6 use "above",
/// channels 7-13 use "below", and any other channel falls back to "above"
/// with a warning.
fn adapt_channel_max_bandwidth(
    id: &ID,
    channel: usize,
    bandwidth: Option<BandWidth>,
    second: Option<SecondChannel>,
) -> Option<BandWidthArg> {
    let requested = bandwidth?;

    let effective = match channel_max_bandwidth(id, channel) {
        Some(limit) if requested > limit => {
            debug!(
                "Channel {} supports max bandwidth: {:?}, using it",
                channel, limit
            );
            limit
        }
        Some(_) => requested,
        None => {
            debug!("channel {} not found in max bandwidth map", channel);
            requested
        }
    };

    let arg = match (effective, second) {
        (BandWidth::HT20, _) => BandWidthArg::HT20,
        (BandWidth::MHz80, _) => BandWidthArg::MHz80,
        (BandWidth::MHz160, _) => BandWidthArg::MHz160,
        (BandWidth::HT40, Some(SecondChannel::Above)) => BandWidthArg::HT40Above,
        (BandWidth::HT40, Some(SecondChannel::Below)) => BandWidthArg::HT40Below,
        // No explicit secondary channel: derive one from the channel number.
        (BandWidth::HT40, None) => match channel {
            1..=6 => BandWidthArg::HT40Above,
            7..=13 => BandWidthArg::HT40Below,
            _ => {
                warn!(
                    "Channel {} is not in the range of 1-13, defaulting to HT40Above",
                    channel
                );
                BandWidthArg::HT40Above
            }
        },
    };

    Some(arg)
}

/// Lowers the cached maximum bandwidth of `channel` on interface `id` to HT20.
///
/// The cache is keyed by channel number (not frequency), matching how
/// `interface_list` fills it and how `channel_max_bandwidth` reads it.
///
/// Returns `None` without changing anything when `id` has no cache entry,
/// `Some(())` otherwise.
fn downcast_channel_max_bandwidth(id: &ID, channel: usize) -> Option<()> {
    let mut cache = MAX_BANDWIDTH.lock().unwrap();
    cache.get_mut(id)?.insert(channel, BandWidth::HT20);
    Some(())
}

/// Looks up the cached maximum bandwidth for `channel` on interface `id`.
///
/// Returns `None` when either the interface or the channel has no entry in
/// the cache.
fn channel_max_bandwidth(id: &ID, channel: usize) -> Option<BandWidth> {
    let cache = MAX_BANDWIDTH.lock().unwrap();
    let per_channel = cache.get(id)?;
    per_channel.get(&channel).cloned()
}

// Unit tests. The async tests go through the real platform backend: they call
// `interface_list()` and, apart from the listing test, take the first entry
// with `remove(0)`, which panics when the list is empty. The two conversion
// tests are pure and need no hardware.
#[cfg(test)]
mod tests {
    use super::*;

    // Prints every interface the backend reports.
    #[tokio::test]
    async fn test_get_wifi_adapter_names() {
        for one in interface_list().await.unwrap() {
            println!("{one:?}");
        }
    }

    // Switches the first interface to monitor mode and reads the mode back.
    #[tokio::test]
    async fn test_set_mode() {
        let interface = interface_list().await.unwrap().remove(0);
        interface.set_mode(Mode::Monitor).await.unwrap();

        let mode = interface.get_mode().await.unwrap();

        assert_eq!(mode, Mode::Monitor);

        let is_up = interface.is_ifup().await.unwrap();
        println!("is up: {}", is_up);
        println!("mode: {:?}", mode);
    }

    // Spot-checks channel -> frequency for 2.4 GHz channels and one 5 GHz channel.
    #[test]
    fn test_channel_to_freq_mhz() {
        assert_eq!(channel_to_freq_mhz(1), 2412);
        assert_eq!(channel_to_freq_mhz(6), 2437);
        assert_eq!(channel_to_freq_mhz(13), 2472);

        assert_eq!(channel_to_freq_mhz(36), 5180);
    }

    // Inverse of the test above: frequency -> channel.
    #[test]
    fn test_freq_mhz_to_channel() {
        assert_eq!(freq_mhz_to_channel(2412), 1);
        assert_eq!(freq_mhz_to_channel(2437), 6);
        assert_eq!(freq_mhz_to_channel(2472), 13);
        assert_eq!(freq_mhz_to_channel(5180), 36);
    }

    // Requests 160 MHz on 2.4 GHz channels with debug logging enabled, so the
    // clamping and fallback paths of `set_channel` show up in the log output.
    #[tokio::test]
    async fn test_set_channel() {
        env_logger::builder()
            .filter_level(log::LevelFilter::Debug)
            .is_test(true)
            .init();
        let interface = interface_list().await.unwrap().remove(0);
        interface.set_mode(Mode::Monitor).await.unwrap();

        interface
            .set_channel(13, Some(BandWidth::MHz160), Some(SecondChannel::Below))
            .await
            .unwrap();
        interface
            .set_channel(2, Some(BandWidth::MHz160), None)
            .await
            .unwrap();
        interface
            .set_channel(2, Some(BandWidth::MHz160), Some(SecondChannel::Above))
            .await
            .unwrap();

        // Disabled 5 GHz variant of the calls above:
        // interface
        //     .set_frequency(5180, Some(BandWidth::MHz160), None)
        //     .await
        //     .unwrap();
    }
}