wifi_manager/
lib.rs

1use std::{
2    collections::HashMap,
3    hash::{BuildHasherDefault, DefaultHasher},
4    num::{ParseFloatError, ParseIntError},
5    ops::Deref,
6    sync::Mutex,
7    time::Instant,
8};
9
10use log::*;
11use thiserror::Error;
12use tokio::process::Command;
13
14use crate::os::{ID, OS};
15
16mod info;
17mod utils;
18
19#[cfg_attr(windows, path = "os/windows/mod.rs")]
20#[cfg_attr(target_os = "linux", path = "os/linux/mod.rs")]
21mod os;
22
23pub use info::*;
24
25static MAX_BANDWIDTH: Mutex<
26    HashMap<ID, HashMap<usize, BandWidth>, BuildHasherDefault<DefaultHasher>>,
27> = Mutex::new(HashMap::with_hasher(BuildHasherDefault::new()));
28
29#[derive(Error, Debug)]
30pub enum WiFiError {
31    #[error("System error: {0}")]
32    System(String),
33    #[error("Not support: {0}")]
34    NotSupport(String),
35}
36
37impl WiFiError {
38    fn new_system<E: Deref<Target = str>>(e: E) -> Self {
39        WiFiError::System(e.to_string())
40    }
41}
42
43impl From<std::io::Error> for WiFiError {
44    fn from(e: std::io::Error) -> Self {
45        WiFiError::new_system(e.to_string())
46    }
47}
48
49impl From<ParseIntError> for WiFiError {
50    fn from(value: ParseIntError) -> Self {
51        WiFiError::new_system(format!("ParseIntError: {}", value))
52    }
53}
54
55impl From<ParseFloatError> for WiFiError {
56    fn from(value: ParseFloatError) -> Self {
57        WiFiError::new_system(format!("ParseFloatError: {}", value))
58    }
59}
60
61pub type WiFiResult<T = ()> = std::result::Result<T, WiFiError>;
62
63#[derive(Debug, Clone)]
64pub struct Interface {
65    pub id: ID,
66    pub support_mode: Vec<Mode>,
67}
68
69impl Interface {
70    pub async fn set_mode(&self, mode: Mode) -> WiFiResult {
71        let start = Instant::now();
72        OS::set_mode(&self.id, mode).await?;
73        debug!(
74            "Set mode for interface [{}] to {:?} took {:?}",
75            self.id,
76            mode,
77            start.elapsed()
78        );
79        Ok(())
80    }
81
82    pub async fn set_channel(
83        &self,
84        channel: usize,
85        band_width: Option<BandWidth>,
86        second: Option<SecondChannel>,
87    ) -> WiFiResult {
88        if let Err(e) = self.try_set_chennel(channel, band_width, second).await {
89            warn!(
90                "interface `{}` set channel {channel} {band_width:?} fail, try downcast, err: {}",
91                self.id, e
92            );
93            downcast_channel_max_bandwidth(&self.id, channel);
94            self.try_set_chennel(channel, None, None).await
95        } else {
96            Ok(())
97        }
98    }
99    pub async fn set_frequency(
100        &self,
101        freq_mhz: usize,
102        band_width: Option<BandWidth>,
103        second: Option<SecondChannel>,
104    ) -> WiFiResult {
105        let channel = freq_mhz_to_channel(freq_mhz);
106        self.set_channel(channel, band_width, second).await
107    }
108
109    async fn try_set_chennel(
110        &self,
111        channel: usize,
112        chennel: Option<BandWidth>,
113        second: Option<SecondChannel>,
114    ) -> WiFiResult {
115        let start = Instant::now();
116        let band_width = adapt_channel_max_bandwidth(&self.id, channel, chennel, second);
117        OS::set_channel(&self.id, channel, band_width).await?;
118        let band_width_str = band_width
119            .map(|bw| format!(" bandwidth {}", bw))
120            .unwrap_or_default();
121
122        debug!(
123            "Set interface [{}] to channel {channel} {band_width_str} took {:?}",
124            self.id,
125            start.elapsed()
126        );
127        Ok(())
128    }
129
130    pub async fn ifup(&self) -> WiFiResult {
131        OS::ifup(&self.id).await
132    }
133    pub async fn ifdown(&self) -> WiFiResult {
134        OS::ifdown(&self.id).await
135    }
136
137    pub async fn get_mode(&self) -> WiFiResult<Mode> {
138        OS::get_mode(&self.id).await
139    }
140
141    pub async fn is_ifup(&self) -> WiFiResult<bool> {
142        OS::is_ifup(&self.id).await
143    }
144}
145
/// Operating mode of a wireless interface.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Mode {
    /// Managed mode; rendered as `"managed"` by `cmd()`.
    Managed,
    /// Monitor mode; rendered as `"monitor"` by `cmd()`.
    Monitor,
}
151impl Mode {
152    fn cmd(&self) -> &str {
153        match self {
154            Mode::Monitor => "monitor",
155            Mode::Managed => "managed",
156        }
157    }
158}
159
160impl TryFrom<&str> for Mode {
161    type Error = ();
162
163    fn try_from(value: &str) -> Result<Self, Self::Error> {
164        match value.trim() {
165            "managed" => Ok(Mode::Managed),
166            "monitor" => Ok(Mode::Monitor),
167            _ => Err(()),
168        }
169    }
170}
171
172impl TryFrom<String> for Mode {
173    type Error = ();
174
175    fn try_from(value: String) -> Result<Self, Self::Error> {
176        value.as_str().try_into()
177    }
178}
179
180trait Impl {
181    async fn check_environment() -> WiFiResult;
182    async fn interface_list() -> Result<Vec<Interface>, WiFiError>;
183    async fn set_mode(id: &ID, mode: Mode) -> WiFiResult;
184    async fn get_mode(id: &ID) -> WiFiResult<Mode>;
185    async fn set_channel(id: &ID, channel: usize, band_width: Option<BandWidthArg>) -> WiFiResult;
186    async fn ifup(id: &ID) -> WiFiResult;
187    async fn ifdown(id: &ID) -> WiFiResult;
188    async fn is_ifup(id: &ID) -> WiFiResult<bool>;
189    async fn freq_max_bandwidth(id: &ID) -> WiFiResult<HashMap<usize, BandWidth>>;
190}
191
192pub async fn check_environment() -> WiFiResult {
193    OS::check_environment().await
194}
195
196pub async fn interface_list() -> Result<Vec<Interface>, WiFiError> {
197    let mut out = vec![];
198    for one in OS::interface_list().await? {
199        let id = one.id.clone();
200        out.push(one);
201        #[allow(clippy::map_entry)]
202        if !MAX_BANDWIDTH.lock().unwrap().contains_key(&id) {
203            let mut map = HashMap::new();
204            let max_bandwidth = OS::freq_max_bandwidth(&id).await?;
205            for (freq, bandwidth) in max_bandwidth {
206                let channel = freq_mhz_to_channel(freq);
207                map.insert(channel, bandwidth);
208            }
209            MAX_BANDWIDTH.lock().unwrap().insert(id, map);
210        }
211    }
212
213    Ok(out)
214}
215
216#[allow(unused)]
217async fn check_command(cmd: &str) -> WiFiResult {
218    Command::new(cmd)
219        .arg("--help")
220        .output()
221        .await
222        .map_err(|e| WiFiError::NotSupport(format!("command [{}] fail: {:?}", cmd, e)))?;
223    Ok(())
224}
225
226#[allow(unused)]
227trait CommandExt {
228    async fn execute<T: AsRef<str>>(&mut self, expect: T) -> WiFiResult;
229}
230
231impl CommandExt for Command {
232    async fn execute<T: AsRef<str>>(&mut self, expect: T) -> WiFiResult {
233        let program = self.as_std().get_program().to_os_string();
234        let program = program.to_string_lossy();
235        let expect = expect.as_ref();
236
237        let status = self.status().await.map_err(|e| {
238            WiFiError::new_system(format!("{expect} failed, program `{program}`: {e}"))
239        })?;
240        if !status.success() {
241            return Err(WiFiError::new_system(format!(
242                "{expect} failed, program `{program}`"
243            )));
244        }
245        Ok(())
246    }
247}
248
/// Converts a Wi-Fi channel number to its centre frequency in MHz.
///
/// - Channels 0–13 (2.4 GHz): `2407 + 5 * channel`.
/// - Channel 14 (2.4 GHz, Japan): 2484 MHz; it does not follow the 5 MHz spacing.
/// - Any other channel is treated as a 5 GHz channel: `5000 + 5 * channel`.
pub fn channel_to_freq_mhz(channel: usize) -> usize {
    match channel {
        // Channel 14 sits 12 MHz above channel 13 instead of the usual 5 MHz.
        14 => 2484,
        0..=13 => 2407 + channel * 5,
        _ => 5000 + channel * 5,
    }
}
256
/// Converts a centre frequency in MHz to a Wi-Fi channel number.
///
/// - Above 5000 MHz: `(freq_mhz - 5000) / 5` (5 GHz numbering).
/// - Exactly 2484 MHz: channel 14 (the irregular 2.4 GHz channel used in Japan).
/// - Otherwise: `(freq_mhz - 2407) / 5` (2.4 GHz numbering).
///
/// Frequencies below 2407 MHz have no channel in either scheme and map to 0 instead
/// of underflowing the subtraction (which panics in debug builds).
pub fn freq_mhz_to_channel(freq_mhz: usize) -> usize {
    if freq_mhz > 5000 {
        return (freq_mhz - 5000) / 5;
    }
    if freq_mhz == 2484 {
        return 14;
    }
    freq_mhz.saturating_sub(2407) / 5
}
263
264fn adapt_channel_max_bandwidth(
265    id: &ID,
266    channel: usize,
267    bandwidth: Option<BandWidth>,
268    second: Option<SecondChannel>,
269) -> Option<BandWidthArg> {
270    let mut bandwidth = bandwidth?;
271    if let Some(max_bandwidth) = channel_max_bandwidth(id, channel) {
272        if bandwidth > max_bandwidth {
273            debug!(
274                "Channel {} supports max bandwidth: {:?}, using it",
275                channel, max_bandwidth
276            );
277            bandwidth = max_bandwidth;
278        }
279    } else {
280        debug!("channel {} not found in max bandwidth map", channel);
281    }
282
283    let out = match bandwidth {
284        BandWidth::HT40 => {
285            if let Some(second) = second {
286                match second {
287                    SecondChannel::Above => BandWidthArg::HT40Above,
288                    SecondChannel::Below => BandWidthArg::HT40Below,
289                }
290            } else {
291                match channel {
292                    1..=6 => BandWidthArg::HT40Above,
293                    7..=13 => BandWidthArg::HT40Below,
294                    _ => {
295                        warn!(
296                            "Channel {} is not in the range of 1-13, defaulting to HT40Above",
297                            channel
298                        );
299                        BandWidthArg::HT40Above
300                    }
301                }
302            }
303        }
304        BandWidth::HT20 => BandWidthArg::HT20,
305        BandWidth::MHz80 => BandWidthArg::MHz80,
306        BandWidth::MHz160 => BandWidthArg::MHz160,
307    };
308
309    Some(out)
310}
311
312fn downcast_channel_max_bandwidth(id: &ID, freq: usize) -> Option<()> {
313    let mut max_bandwidth = MAX_BANDWIDTH.lock().unwrap();
314    let map = max_bandwidth.get_mut(id)?;
315    map.insert(freq, BandWidth::HT20);
316    Some(())
317}
318
319fn channel_max_bandwidth(id: &ID, channel: usize) -> Option<BandWidth> {
320    let max_bandwidth = MAX_BANDWIDTH.lock().unwrap();
321    max_bandwidth.get(id).and_then(|m| m.get(&channel)).cloned()
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that drive real hardware are `#[ignore]`d so a plain `cargo test` (e.g. in
    // CI) does not fail on machines without a usable wireless interface. Run them with
    // `cargo test -- --ignored` on a suitable machine; changing mode or channel
    // typically also needs elevated privileges.

    #[tokio::test]
    async fn test_get_wifi_adapter_names() {
        for one in interface_list().await.unwrap() {
            println!("{one:?}");
        }
    }

    #[tokio::test]
    #[ignore = "needs a wireless interface that can be switched to monitor mode"]
    async fn test_set_mode() {
        let interface = interface_list().await.unwrap().remove(0);
        interface.set_mode(Mode::Monitor).await.unwrap();

        let mode = interface.get_mode().await.unwrap();

        assert_eq!(mode, Mode::Monitor);

        let is_up = interface.is_ifup().await.unwrap();
        println!("is up: {}", is_up);
        println!("mode: {:?}", mode);
    }

    #[test]
    fn test_channel_to_freq_mhz() {
        assert_eq!(channel_to_freq_mhz(1), 2412);
        assert_eq!(channel_to_freq_mhz(6), 2437);
        assert_eq!(channel_to_freq_mhz(13), 2472);

        assert_eq!(channel_to_freq_mhz(36), 5180);
    }

    #[test]
    fn test_freq_mhz_to_channel() {
        assert_eq!(freq_mhz_to_channel(2412), 1);
        assert_eq!(freq_mhz_to_channel(2437), 6);
        assert_eq!(freq_mhz_to_channel(2472), 13);
        assert_eq!(freq_mhz_to_channel(5180), 36);
    }

    #[tokio::test]
    #[ignore = "needs a wireless interface in monitor mode and permission to change its channel"]
    async fn test_set_channel() {
        // `try_init` rather than `init`: `init` panics if a logger is already installed
        // in this test process.
        let _ = env_logger::builder()
            .filter_level(log::LevelFilter::Debug)
            .is_test(true)
            .try_init();
        let interface = interface_list().await.unwrap().remove(0);
        interface.set_mode(Mode::Monitor).await.unwrap();

        interface
            .set_channel(13, Some(BandWidth::MHz160), Some(SecondChannel::Below))
            .await
            .unwrap();
        interface
            .set_channel(2, Some(BandWidth::MHz160), None)
            .await
            .unwrap();
        interface
            .set_channel(2, Some(BandWidth::MHz160), Some(SecondChannel::Above))
            .await
            .unwrap();
    }
}