use core::cell::RefCell;
use std::collections::HashMap;
use embassy_futures::select::{select, Either};
use embassy_time::{Duration, Timer};
use futures_lite::StreamExt;
use zbus::zvariant::{ObjectPath, OwnedObjectPath, OwnedValue, Value};
use zbus::Connection;
use crate::dm::clusters::net_comm::{
NetCtl, NetCtlError, NetworkScanInfo, NetworkType, WiFiBandEnum, WiFiSecurityBitmap,
WirelessCreds,
};
use crate::dm::clusters::wifi_diag::{SecurityTypeEnum, WiFiVersionEnum, WifiDiag, WirelessDiag};
use crate::dm::networks::NetChangeNotif;
use crate::error::{Error, ErrorCode};
use crate::tlv::Nullable;
use crate::utils::sync::{blocking, DynBase, IfMutex};
use crate::utils::zbus_proxies::wpa_supp::bss::BSSProxy;
use crate::utils::zbus_proxies::wpa_supp::interface::InterfaceProxy;
use crate::utils::zbus_proxies::wpa_supp::wpa_supplicant::WPASupplicantProxy;
#[cfg(unix)]
pub mod unix;
/// Wi-Fi network controller that drives wpa_supplicant over D-Bus.
///
/// Association is performed through wpa_supplicant's `Interface` object for `ifname`;
/// IP-level bring-up is delegated to `ip_stack_ctl` once the link is associated.
pub struct WpaSuppCtl<'a, T: IpStackCtl> {
    /// D-Bus connection on which the wpa_supplicant service is reached.
    connection: &'a Connection,
    /// Network interface name, as passed to wpa_supplicant's `GetInterface`.
    ifname: &'a str,
    /// Controller for the IP stack running on top of the Wi-Fi link.
    ip_stack_ctl: T,
    /// Object path of the network added by the last `connect`, if any.
    ///
    /// Locking this async mutex also serialises `scan` and `connect`.
    network: IfMutex<Option<OwnedObjectPath>>,
    /// Cached description of the currently associated BSS (`None` when not associated).
    wifi_conn_info: blocking::Mutex<RefCell<Option<WifiConnInfo>>>,
}
impl<'a, T> WpaSuppCtl<'a, T>
where
T: IpStackCtl,
{
pub const fn new(connection: &'a Connection, ifname: &'a str, ip_stack_ctl: T) -> Self {
Self {
connection,
ifname,
ip_stack_ctl,
network: IfMutex::new(None),
wifi_conn_info: blocking::Mutex::new(RefCell::new(None)),
}
}
pub const fn connection(&self) -> &Connection {
self.connection
}
async fn interface(&self) -> Result<InterfaceProxy<'a>, zbus::Error> {
let wpas = WPASupplicantProxy::new(self.connection).await?;
let interface_path = wpas.get_interface(self.ifname).await?;
InterfaceProxy::builder(self.connection)
.path(interface_path.clone())?
.build()
.await
}
async fn wait(&self, for_connection: bool) -> Result<(), Error> {
let interface = self.interface().await?;
let mut iface_state_changed = interface.receive_state_changed().await;
loop {
let bss = interface.current_bss().await?;
let (changed, connected) = if bss.len() > 1 {
self.network_scan_info(&bss, |info| {
let info = info.map(WifiConnInfo::new);
Ok(self.update_wifi_conn_info(info))
})
.await?
} else {
self.update_wifi_conn_info(None)
};
if for_connection && connected || !for_connection && changed {
break Ok(());
}
let ip_stack_changed = self.ip_stack_ctl.wait_changed();
select(iface_state_changed.next(), ip_stack_changed).await;
}
}
fn update_wifi_conn_info(&self, new_wifi_conn_info: Option<WifiConnInfo>) -> (bool, bool) {
self.wifi_conn_info.lock(|wifi_conn_info| {
let mut wifi_conn_info = wifi_conn_info.borrow_mut();
let changed = if *wifi_conn_info != new_wifi_conn_info {
*wifi_conn_info = new_wifi_conn_info;
true
} else {
false
};
let connected = Self::connected(wifi_conn_info.as_ref());
(changed, connected)
})
}
fn connected(wifi_conn_info: Option<&WifiConnInfo>) -> bool {
wifi_conn_info.is_some()
}
async fn remove_network(&self, network: &mut Option<OwnedObjectPath>) -> zbus::Result<()> {
let interface = self.interface().await?;
if let Some(network_path) = network.clone() {
if interface.remove_network(&network_path).await.is_ok() {
network.take();
}
}
Ok(())
}
async fn network_scan_info<F, R>(&self, bss: &ObjectPath<'_>, f: F) -> Result<R, Error>
where
F: FnOnce(Option<&NetworkScanInfo>) -> Result<R, Error>,
{
let bss_info = BSSProxy::builder(self.connection)
.path(bss)?
.build()
.await?;
if bss_info.mode().await? == "infrastructure" {
let wpa = bss_info.wpa().await?;
let rsn = bss_info.rsn().await?;
let security = if wpa.is_empty() && rsn.is_empty() {
WiFiSecurityBitmap::UNENCRYPTED
} else {
let str_list_val = |key, map: &HashMap<String, OwnedValue>| {
let str_list: Vec<String> = map
.get(key)
.cloned()
.and_then(|w| w.clone().try_into().ok())
.unwrap_or_default();
str_list
};
let mut security = WiFiSecurityBitmap::empty();
let wpa_key_mgmt = str_list_val("KeyMgmt", &wpa);
if wpa_key_mgmt.contains(&"wpa-none".to_string()) {
security |= WiFiSecurityBitmap::WEP;
}
if wpa_key_mgmt.contains(&"wpa-psk".to_string()) {
security |= WiFiSecurityBitmap::WPA_PERSONAL;
}
let rsn_key_mgmt = str_list_val("KeyMgmt", &rsn);
if rsn_key_mgmt.contains(&"wpa-psk".to_string())
|| rsn_key_mgmt.contains(&"wpa-ft-psk".to_string())
|| rsn_key_mgmt.contains(&"wpa-psk-sha256".to_string())
{
security |= WiFiSecurityBitmap::WPA_2_PERSONAL;
}
if rsn_key_mgmt.contains(&"sae".to_string()) {
security |= WiFiSecurityBitmap::WPA_3_PERSONAL
}
security
};
let (band, channel) = super::band::band_and_channel(bss_info.frequency().await? as u32)
.unwrap_or((WiFiBandEnum::V2G4, 0));
let network_scan_info = NetworkScanInfo::Wifi {
security,
ssid: &bss_info.ssid().await?,
bssid: &bss_info.bssid().await?,
band,
channel,
rssi: bss_info.signal().await?.min(i8::MIN as _).max(i8::MAX as _) as i8,
};
f(Some(&network_scan_info))
} else {
f(None)
}
}
}
impl<T: IpStackCtl> Drop for WpaSuppCtl<'_, T> {
    /// Best-effort removal of the network added by `connect`, if any.
    ///
    /// The async cleanup is driven to completion on the current thread with `block_on`.
    /// Any D-Bus error is discarded because `drop` has no way to report it.
    ///
    /// NOTE(review): blocking here may stall if the value is dropped on an executor
    /// thread that is also needed to service the zbus connection — confirm.
    fn drop(&mut self) {
        let cleanup = async {
            let mut guard = self.network.lock().await;
            self.remove_network(&mut guard).await
        };

        let _ = futures_lite::future::block_on(cleanup);
    }
}
impl<T> NetCtl for WpaSuppCtl<'_, T>
where
T: IpStackCtl,
{
fn net_type(&self) -> NetworkType {
NetworkType::Wifi
}
async fn scan<F>(&self, network: Option<&[u8]>, mut f: F) -> Result<(), NetCtlError>
where
F: FnMut(&NetworkScanInfo) -> Result<(), Error>,
{
const SCAN_RETRIES: usize = 3;
const SCAN_RETRIES_SLEEP_SEC: u64 = 5;
const SCAN_DONE_TIMEOUT_SEC: u64 = 20;
let _guard = self.network.lock().await;
let mut args = HashMap::new();
let active = Value::from("active");
args.insert("Type", &active);
let ssids = network.map(|network| vec![network.to_vec()].into());
if ssids.is_some() {
#[allow(clippy::unnecessary_unwrap)]
args.insert("SSIDs", ssids.as_ref().unwrap());
}
let interface = self.interface().await?;
let mut scan_done = interface.receive_scan_done().await?;
for _ in 0..SCAN_RETRIES {
if interface.scan(args.clone()).await.is_ok() {
let scan_done = async {
loop {
if scan_done.next().await.is_some() {
break;
}
}
};
let timeout = Timer::after(Duration::from_secs(SCAN_DONE_TIMEOUT_SEC));
if let Either::First(_) = select(scan_done, timeout).await {
break;
}
}
Timer::after(Duration::from_secs(SCAN_RETRIES_SLEEP_SEC)).await;
}
let bsss = interface.bsss().await?;
for bss in bsss {
self.network_scan_info(&bss, |info| {
if let Some(info) = info {
f(info)?;
}
Ok(())
})
.await?;
}
Ok(())
}
async fn connect(&self, creds: &WirelessCreds<'_>) -> Result<(), NetCtlError> {
const CONNECT_TIMEOUT_SECS: u64 = 30;
let mut network = self.network.lock().await;
let WirelessCreds::Wifi { ssid, pass } = creds else {
return Err(NetCtlError::Other(ErrorCode::InvalidAction.into()));
};
let interface = self.interface().await?;
self.remove_network(&mut network).await?;
let mut args = HashMap::new();
let utf8_err = |_| NetCtlError::Other(ErrorCode::Invalid.into());
let arg_ssid = core::str::from_utf8(ssid).map_err(utf8_err)?.into();
args.insert("ssid", &arg_ssid);
let arg_pass = core::str::from_utf8(pass).map_err(utf8_err)?.into();
if !pass.is_empty() {
args.insert("psk", &arg_pass);
}
let network_path = interface.add_network(args).await?;
*network = Some(network_path.clone());
interface.select_network(&network_path).await?;
let connected = self.wait(true);
let timeout = Timer::after(Duration::from_secs(CONNECT_TIMEOUT_SECS));
match select(connected, timeout).await {
Either::First(_) => info!("Connected to Wifi network: {}", self.ifname),
Either::Second(_) => {
error!(
"Connection to Wifi network timed out: {}, assuming auth failure",
self.ifname
);
if let Err(e2) = self.remove_network(&mut network).await {
warn!(
"Failed to remove network after connection timeout: {:?}",
e2
);
}
return Err(NetCtlError::AuthFailure);
}
}
match self.ip_stack_ctl.connect().await {
Ok(()) => {
info!("IP stack connected for network: {}", self.ifname);
Ok(())
}
Err(e) => {
error!(
"Failed to connect IP stack for network {}: {:?}",
self.ifname, e
);
if let Err(e2) = self.remove_network(&mut network).await {
warn!(
"Failed to remove network after IP stack connection failure: {:?}",
e2
);
}
Err(e)
}
}
}
}
impl<T> WirelessDiag for WpaSuppCtl<'_, T>
where
    T: IpStackCtl,
{
    /// Report whether the interface is fully connected.
    ///
    /// Fully connected means wpa_supplicant has a current BSS *and* the IP stack reports
    /// connectivity.
    ///
    /// The IP stack is queried only when the Wi-Fi link is up, and an error from it counts
    /// as "not connected". The connection-info lock is released before calling into the
    /// IP stack, so foreign code never runs inside that critical section.
    fn connected(&self) -> Result<bool, Error> {
        // `Self::connected` resolves to the inherent `Option<&WifiConnInfo>` helper.
        let wifi_connected = self
            .wifi_conn_info
            .lock(|wifi_conn_info| Self::connected(wifi_conn_info.borrow().as_ref()));

        Ok(wifi_connected && self.ip_stack_ctl.is_connected().unwrap_or(false))
    }
}
/// Marker impl so `WpaSuppCtl` can be used behind the crate's `DynBase`-based trait objects.
impl<T: IpStackCtl> DynBase for WpaSuppCtl<'_, T> {}
impl<T: IpStackCtl> WifiDiag for WpaSuppCtl<'_, T> {
    /// Pass the BSSID of the associated access point to `f`, or `None` when not associated.
    fn bssid(&self, f: &mut dyn FnMut(Option<&[u8]>) -> Result<(), Error>) -> Result<(), Error> {
        self.wifi_conn_info.lock(|cell| {
            let current = cell.borrow();
            f(current.as_ref().map(|info| info.bssid.as_slice()))
        })
    }

    /// Not tracked by this controller; always null.
    fn security_type(&self) -> Result<Nullable<SecurityTypeEnum>, Error> {
        Ok(Nullable::none())
    }

    /// Not tracked by this controller; always null.
    fn wi_fi_version(&self) -> Result<Nullable<WiFiVersionEnum>, Error> {
        Ok(Nullable::none())
    }

    /// Channel of the associated BSS, or null when not associated.
    fn channel_number(&self) -> Result<Nullable<u16>, Error> {
        let channel = self
            .wifi_conn_info
            .lock(|cell| cell.borrow().as_ref().map(|info| info.channel));

        Ok(match channel {
            Some(channel) => Nullable::some(channel),
            None => Nullable::none(),
        })
    }

    /// RSSI of the associated BSS as cached from its scan info, or null when not associated.
    fn rssi(&self) -> Result<Nullable<i8>, Error> {
        let rssi = self
            .wifi_conn_info
            .lock(|cell| cell.borrow().as_ref().map(|info| info.rssi));

        Ok(match rssi {
            Some(rssi) => Nullable::some(rssi),
            None => Nullable::none(),
        })
    }
}
impl<T: IpStackCtl> NetChangeNotif for WpaSuppCtl<'_, T> {
    /// Resolve on the first Wi-Fi connection-info change or IP stack change notification.
    ///
    /// An error from the Wi-Fi watcher is discarded; it simply ends the wait.
    async fn wait_changed(&self) {
        let _ = select(self.wait(false), self.ip_stack_ctl.wait_changed()).await;
    }
}
/// Control hooks for the IP stack layered on top of the Wi-Fi link managed by [`WpaSuppCtl`].
pub trait IpStackCtl: DynBase {
    /// Bring up IP connectivity.
    ///
    /// `WpaSuppCtl::connect` calls this after Wi-Fi association has been observed. If it
    /// returns an error, that error is propagated after the just-added network is removed.
    async fn connect(&self) -> Result<(), NetCtlError>;
    /// Wait for a change in the IP stack state.
    ///
    /// NOTE(review): presumably resolves on any IP-level connectivity change; `WpaSuppCtl`
    /// uses it as a wake-up to re-check its own state — confirm against implementors.
    async fn wait_changed(&self);
    /// Report whether the IP stack currently considers itself connected.
    ///
    /// `WpaSuppCtl`'s `WirelessDiag::connected` treats an `Err` here as "not connected".
    fn is_connected(&self) -> Result<bool, NetCtlError>;
}
/// Forward [`IpStackCtl`] through shared references.
///
/// This lets `&T` be used wherever an `IpStackCtl` is expected, e.g. as the `T` of
/// `WpaSuppCtl`.
impl<T: IpStackCtl> IpStackCtl for &T {
    async fn connect(&self) -> Result<(), NetCtlError> {
        <T as IpStackCtl>::connect(*self).await
    }

    async fn wait_changed(&self) {
        <T as IpStackCtl>::wait_changed(*self).await
    }

    fn is_connected(&self) -> Result<bool, NetCtlError> {
        <T as IpStackCtl>::is_connected(*self)
    }
}
/// Owned snapshot of the BSS the interface is currently associated with.
///
/// Built from a `NetworkScanInfo::Wifi` by `WifiConnInfo::new` and cached by
/// `WpaSuppCtl`. Equality is used to detect connection changes.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct WifiConnInfo {
    /// Security modes derived from the BSS's advertised key management.
    security: WiFiSecurityBitmap,
    /// Raw SSID bytes.
    ssid: Vec<u8>,
    /// Raw BSSID bytes.
    bssid: Vec<u8>,
    /// Frequency band of the BSS.
    band: WiFiBandEnum,
    /// Channel number within `band`.
    channel: u16,
    /// Signal strength as carried in the scan info.
    rssi: i8,
}
impl WifiConnInfo {
    /// Copy a Wi-Fi scan result into an owned connection-info record.
    ///
    /// # Panics
    /// Panics if `scan_info` is not the `NetworkScanInfo::Wifi` variant.
    fn new(scan_info: &NetworkScanInfo) -> Self {
        match scan_info {
            NetworkScanInfo::Wifi {
                security,
                ssid,
                bssid,
                channel,
                band,
                rssi,
            } => Self {
                security: *security,
                ssid: ssid.to_vec(),
                bssid: bssid.to_vec(),
                band: *band,
                channel: *channel,
                rssi: *rssi,
            },
            #[allow(unreachable_patterns)]
            _ => unreachable!(),
        }
    }
}
impl From<zbus::Error> for NetCtlError {
fn from(value: zbus::Error) -> Self {
NetCtlError::Other(value.into())
}
}