use core::cmp::{max, min};
use core::iter::zip;
use core::sync::atomic::AtomicBool;
use core::sync::atomic::Ordering::Relaxed;
use embassy_net_driver_channel as ch;
use embassy_net_driver_channel::driver::HardwareAddress;
use embassy_time::{Duration, Timer};
use crate::consts::*;
use crate::events::{Event, EventSubscriber, Events};
use crate::fmt::Bytes;
use crate::ioctl::{IoctlState, IoctlType};
use crate::structs::*;
use crate::{PowerManagementMode, countries, events};
/// Errors reported by [`Control::join`].
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum JoinError {
/// The firmware answered `SET_SSID` with `NO_NETWORKS`: no network with that SSID was found.
NetworkNotFound,
/// The firmware answered `SET_SSID` with another non-success status (raw status code attached).
JoinFailure(u8),
/// On a secured network, the firmware sent a `PSK_SUP` event with an unexpected status,
/// or an `AUTH` event with status `FAIL`.
AuthenticationFailure,
}
/// Errors reported by [`Control::add_multicast_address`].
///
/// Derives the same `defmt::Format` support as the crate's other public types, so it
/// can be logged under the `defmt` feature. It is also comparable and `Copy` so callers
/// can match and assert on it directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum AddMulticastAddressError {
    /// Bit 0 of the first octet (the group bit) of the address is clear.
    NotMulticast,
    /// The firmware multicast list has no room for another entry.
    NoFreeSlots,
}
/// Control handle for the chip: initialization, joining networks, access-point mode,
/// multicast filtering and scanning.
pub struct Control<'a> {
// Network-channel state; `init` publishes the MAC address through it.
state_ch: ch::StateRunner<'a>,
// Shared event mask and event queue.
events: &'a Events,
// Shared ioctl state; every ioctl goes through it.
ioctl_state: &'a IoctlState,
// Set by `join` to whether the network being joined is secured (auth != Open).
secure_network: &'a AtomicBool,
}
/// Scan type selected by [`ScanOptions::scan_type`].
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ScanType {
/// Active scan; `ScanOptions::dwell_time` is sent as the active time.
Active,
/// Passive scan; `ScanOptions::dwell_time` is sent as the passive time.
Passive,
}
/// Options for [`Control::scan`].
///
/// `Default` gives an unfiltered passive scan. Fields left `None` are replaced by a
/// placeholder value when the scan request is built (see `Control::scan`).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[non_exhaustive]
pub struct ScanOptions {
/// Only scan for this SSID. `None` sends an SSID of length 0 (no SSID filter).
pub ssid: Option<heapless::String<32>>,
/// Only scan for this BSSID. `None` sends `ff:ff:ff:ff:ff:ff`.
pub bssid: Option<[u8; 6]>,
/// Number of probes. `None` leaves it at the placeholder value.
pub nprobes: Option<u16>,
/// Home-channel time. `None` leaves it at the placeholder value.
pub home_time: Option<Duration>,
/// Active or passive scan.
pub scan_type: ScanType,
/// Per-channel dwell time. It is used as the active or passive time depending on `scan_type`.
pub dwell_time: Option<Duration>,
}
impl Default for ScanOptions {
fn default() -> Self {
Self {
ssid: None,
bssid: None,
nprobes: None,
home_time: None,
scan_type: ScanType::Passive,
dwell_time: None,
}
}
}
/// Authentication mode used by [`Control::join`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum JoinAuth {
/// Unsecured network. Ciphers, the supplicant and WPA auth are all disabled.
Open,
/// WPA-PSK. The passphrase is sent as a WSEC PMK; MFP is set to `MFP_NONE`.
Wpa,
/// WPA2-PSK. The passphrase is sent as a WSEC PMK; MFP is set to `MFP_CAPABLE`.
Wpa2,
/// WPA3-SAE. The passphrase is sent as the SAE password; MFP is set to `MFP_REQUIRED`.
Wpa3,
/// WPA2/WPA3 mixed mode. The passphrase is sent both as a PMK and as the SAE password,
/// using SAE auth with MFP set to `MFP_CAPABLE`.
Wpa2Wpa3,
}
/// Options for [`Control::join`].
///
/// The struct is `#[non_exhaustive]`. Start from [`JoinOptions::new`],
/// [`JoinOptions::new_open`] or `Default::default()`, then adjust the public fields.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[non_exhaustive]
pub struct JoinOptions<'a> {
/// Authentication mode.
pub auth: JoinAuth,
/// Enable the TKIP cipher. Not used for [`JoinAuth::Open`].
pub cipher_tkip: bool,
/// Enable the AES cipher. Not used for [`JoinAuth::Open`].
pub cipher_aes: bool,
/// Passphrase. Not used for [`JoinAuth::Open`]. `join` panics if it does not fit the
/// 64-byte (WPA/WPA2) or 128-byte (WPA3) firmware buffer used by the selected mode.
pub passphrase: &'a [u8],
/// Set when `passphrase` is already hashed. Only affects the WPA/WPA2 passphrase flags.
pub passphrase_is_prehashed: bool,
}
impl<'a> JoinOptions<'a> {
    /// Options for an open (unsecured) network: no cipher and no passphrase.
    pub fn new_open() -> Self {
        Self {
            auth: JoinAuth::Open,
            cipher_aes: false,
            ..Self::default()
        }
    }

    /// Options for a secured network using `passphrase` together with the default
    /// security settings (WPA2/WPA3 mixed mode, AES only).
    pub fn new(passphrase: &'a [u8]) -> Self {
        Self {
            passphrase,
            ..Self::default()
        }
    }
}
impl<'a> Default for JoinOptions<'a> {
fn default() -> Self {
Self {
auth: JoinAuth::Wpa2Wpa3,
cipher_tkip: false,
cipher_aes: true,
passphrase: &[],
passphrase_is_prehashed: false,
}
}
}
impl<'a> Control<'a> {
pub(crate) fn new(
state_ch: ch::StateRunner<'a>,
event_sub: &'a Events,
ioctl_state: &'a IoctlState,
secure_network: &'a AtomicBool,
) -> Self {
Self {
state_ch,
events: event_sub,
ioctl_state,
secure_network,
}
}
/// Downloads the CLM blob `clm` to the chip through the `clmload` iovar.
///
/// The blob is sent in chunks of at most 1024 bytes. Each chunk is preceded by a
/// `DownloadHeader` whose flags mark the first (`DOWNLOAD_FLAG_BEGIN`) and the last
/// (`DOWNLOAD_FLAG_END`) chunk.
///
/// # Panics
///
/// Panics if `clmload_status` does not read back as 0 after the download.
async fn load_clm(&mut self, clm: &[u8]) {
    const CHUNK_SIZE: usize = 1024;
    // NUL-terminated iovar name that prefixes every chunk request (8 bytes).
    const IOVAR_NAME: &[u8; 8] = b"clmload\x00";
    // Byte layout of one request: name, then 12-byte serialized header, then payload.
    const HEADER_START: usize = 8;
    const HEADER_LEN: usize = 12;
    const PAYLOAD_START: usize = HEADER_START + HEADER_LEN;

    debug!("Downloading CLM...");

    let mut offs = 0;
    for chunk in clm.chunks(CHUNK_SIZE) {
        let mut flag = DOWNLOAD_FLAG_HANDLER_VER;
        if offs == 0 {
            flag |= DOWNLOAD_FLAG_BEGIN;
        }
        offs += chunk.len();
        if offs == clm.len() {
            flag |= DOWNLOAD_FLAG_END;
        }

        let header = DownloadHeader {
            flag,
            dload_type: DOWNLOAD_TYPE_CLM,
            len: chunk.len() as _,
            crc: 0,
        };

        // Build a fresh buffer for every chunk. The ioctl path can write its response
        // back into the request buffer (get_iovar relies on that for Get requests), so
        // a reused buffer could have its iovar name clobbered before the next chunk.
        let mut buf = [0; PAYLOAD_START + CHUNK_SIZE];
        buf[..HEADER_START].copy_from_slice(IOVAR_NAME);
        buf[HEADER_START..PAYLOAD_START].copy_from_slice(&header.to_bytes());
        buf[PAYLOAD_START..][..chunk.len()].copy_from_slice(chunk);
        self.ioctl(IoctlType::Set, Ioctl::SetVar, 0, &mut buf[..PAYLOAD_START + chunk.len()])
            .await;
    }

    assert_eq!(self.get_iovar_u32("clmload_status").await, 0);
}
/// Initializes the chip and publishes its MAC address to the network channel.
///
/// Steps, in order:
/// 1. Loads `clm`.
/// 2. Applies the default configuration: world-wide country code, antenna, tx glom,
///    A-MPDU settings and the event mask.
/// 3. Brings the interface up.
/// 4. Sets the hardware address on the network channel.
///
/// # Panics
///
/// Panics if the CLM load status is non-zero, or if reading the MAC address does not
/// return exactly 6 bytes.
pub async fn init(&mut self, clm: &[u8]) {
    self.load_clm(clm).await;

    debug!("Configuring misc stuff...");
    self.set_iovar_u32("bus:txglom", 0).await;
    self.set_iovar_u32("apsta", 1).await;

    let mac_addr = self.address().await;
    debug!("mac addr: {:02x}", Bytes(&mac_addr));

    let country = countries::WORLD_WIDE_XX;
    let country_info = CountryInfo {
        country_abbrev: [country.code[0], country.code[1], 0, 0],
        country_code: [country.code[0], country.code[1], 0, 0],
        // A zero revision is sent as -1.
        rev: if country.rev == 0 { -1 } else { country.rev as _ },
    };
    self.set_iovar("country", &country_info.to_bytes()).await;
    // The 100 ms pauses between steps are part of the bring-up sequence. Presumably the
    // firmware needs them to apply each setting, so do not remove them without hardware testing.
    Timer::after_millis(100).await;

    self.ioctl_set_u32(Ioctl::SetAntdiv, 0, 0).await;
    // Tx glom is set again after the country change (first set above).
    self.set_iovar_u32("bus:txglom", 0).await;
    Timer::after_millis(100).await;
    self.set_iovar_u32("ampdu_ba_wsize", 8).await;
    Timer::after_millis(100).await;
    self.set_iovar_u32("ampdu_mpdu", 4).await;
    Timer::after_millis(100).await;

    // Start with every mask bit set, then clear the events listed below.
    let mut evts = EventMask {
        iface: 0,
        events: [0xFF; 24],
    };
    evts.unset(Event::RADIO);
    evts.unset(Event::IF);
    evts.unset(Event::PROBREQ_MSG);
    evts.unset(Event::PROBREQ_MSG_RX);
    evts.unset(Event::PROBRESP_MSG);
    evts.unset(Event::ROAM);
    self.set_iovar("bsscfg:event_msgs", &evts.to_bytes()).await;
    Timer::after_millis(100).await;

    self.up().await;
    Timer::after_millis(100).await;

    self.ioctl_set_u32(Ioctl::SetGmode, 0, 1).await;
    self.ioctl_set_u32(Ioctl::SetBand, 0, 0).await;
    Timer::after_millis(100).await;

    self.state_ch.set_hardware_address(HardwareAddress::Ethernet(mac_addr));
    debug!("cyw43 control init done");
}
/// Sends the `Up` ioctl (no payload) on interface 0.
async fn up(&mut self) {
    let no_payload: &mut [u8] = &mut [];
    self.ioctl(IoctlType::Set, Ioctl::Up, 0, no_payload).await;
}
/// Sends the `Down` ioctl (no payload) on interface 0.
async fn down(&mut self) {
    let no_payload: &mut [u8] = &mut [];
    self.ioctl(IoctlType::Set, Ioctl::Down, 0, no_payload).await;
}
/// Applies the power-management mode `mode`.
///
/// When `mode.mode()` is 2, the sleep-return, beacon, DTIM and association-listen
/// parameters from `mode` are sent first. In every case the mode number is then
/// sent with the `SetPm` ioctl.
pub async fn set_power_management(&mut self, mode: PowerManagementMode) {
    let pm = mode.mode();
    if pm == 2 {
        let tuning = [
            ("pm2_sleep_ret", mode.sleep_ret_ms() as u32),
            ("bcn_li_bcn", mode.beacon_period() as u32),
            ("bcn_li_dtim", mode.dtim_period() as u32),
            ("assoc_listen", mode.assoc() as u32),
        ];
        for (iovar, value) in tuning {
            self.set_iovar_u32(iovar, value).await;
        }
    }
    self.ioctl_set_u32(Ioctl::SetPm, 0, pm).await;
}
/// Joins the network `ssid` with the security settings in `options`, then waits
/// until the firmware reports success or failure.
///
/// For secured networks, the passphrase is sent as a WSEC PMK (WPA/WPA2 modes) and/or
/// as the SAE password (WPA3 modes) before the SSID is set.
///
/// # Errors
///
/// Returns a [`JoinError`] chosen from the firmware events received after the SSID
/// is set (see `wait_for_join`).
///
/// # Panics
///
/// Panics if `ssid` is longer than 32 bytes. Also panics if the passphrase is longer
/// than the 64-byte (WPA/WPA2) or 128-byte (WPA3) buffer it is copied into.
pub async fn join(&mut self, ssid: &str, options: JoinOptions<'_>) -> Result<(), JoinError> {
// Same block-ack window size that `init` sets.
self.set_iovar_u32("ampdu_ba_wsize", 8).await;
if options.auth == JoinAuth::Open {
// Open network: no cipher, `sup_wpa` = 0, infrastructure mode, auth 0, WPA auth disabled.
self.ioctl_set_u32(Ioctl::SetWsec, 0, 0).await;
self.set_iovar_u32x2("bsscfg:sup_wpa", 0, 0).await;
self.ioctl_set_u32(Ioctl::SetInfra, 0, 1).await;
self.ioctl_set_u32(Ioctl::SetAuth, 0, 0).await;
self.ioctl_set_u32(Ioctl::SetWpaAuth, 0, WPA_AUTH_DISABLED).await;
} else {
// Cipher bitmask from the requested ciphers.
let mut wsec = 0;
if options.cipher_aes {
wsec |= WSEC_AES;
}
if options.cipher_tkip {
wsec |= WSEC_TKIP;
}
self.ioctl_set_u32(Ioctl::SetWsec, 0, wsec).await;
// Supplicant settings; presumably enables the in-firmware supplicant with a 2500 (ms?) timeout.
self.set_iovar_u32x2("bsscfg:sup_wpa", 0, 1).await;
self.set_iovar_u32x2("bsscfg:sup_wpa2_eapver", 0, 0xFFFF_FFFF).await;
self.set_iovar_u32x2("bsscfg:sup_wpa_tmo", 0, 2500).await;
Timer::after_millis(100).await;
// (send PMK?, send SAE password?, auth, mfp, wpa_auth) for each secured mode.
let (wpa12, wpa3, auth, mfp, wpa_auth) = match options.auth {
JoinAuth::Open => unreachable!(),
JoinAuth::Wpa => (true, false, AUTH_OPEN, MFP_NONE, WPA_AUTH_WPA_PSK),
JoinAuth::Wpa2 => (true, false, AUTH_OPEN, MFP_CAPABLE, WPA_AUTH_WPA2_PSK),
JoinAuth::Wpa3 => (false, true, AUTH_SAE, MFP_REQUIRED, WPA_AUTH_WPA3_SAE_PSK),
JoinAuth::Wpa2Wpa3 => (true, true, AUTH_SAE, MFP_CAPABLE, WPA_AUTH_WPA3_SAE_PSK),
};
if wpa12 {
// Flag bit 0 is set for a plain-text passphrase and cleared when it is prehashed.
let mut flags = 0;
if !options.passphrase_is_prehashed {
flags |= 1;
}
let mut pfi = PassphraseInfo {
len: options.passphrase.len() as _,
flags,
passphrase: [0; 64],
};
// Panics if the passphrase exceeds 64 bytes.
pfi.passphrase[..options.passphrase.len()].copy_from_slice(options.passphrase);
Timer::after_millis(3).await;
self.ioctl(IoctlType::Set, Ioctl::SetWsecPmk, 0, &mut pfi.to_bytes())
.await;
}
if wpa3 {
let mut pfi = SaePassphraseInfo {
len: options.passphrase.len() as _,
passphrase: [0; 128],
};
// Panics if the passphrase exceeds 128 bytes.
pfi.passphrase[..options.passphrase.len()].copy_from_slice(options.passphrase);
Timer::after_millis(3).await;
self.set_iovar("sae_password", &pfi.to_bytes()).await;
}
self.ioctl_set_u32(Ioctl::SetInfra, 0, 1).await;
self.ioctl_set_u32(Ioctl::SetAuth, 0, auth).await;
self.set_iovar_u32("mfp", mfp).await;
self.ioctl_set_u32(Ioctl::SetWpaAuth, 0, wpa_auth).await;
}
// Panics if the SSID exceeds 32 bytes.
let mut i = SsidInfo {
len: ssid.len() as _,
ssid: [0; 32],
};
i.ssid[..ssid.len()].copy_from_slice(ssid.as_bytes());
// Store in the shared flag; wait_for_join also uses it to decide which events mean success.
let secure_network = options.auth != JoinAuth::Open;
self.secure_network.store(secure_network, Relaxed);
self.wait_for_join(i, secure_network).await
}
/// Sends `SetSsid` with `i`, then waits for the firmware events that decide the
/// join outcome.
///
/// An open network succeeds on `SET_SSID` with `SUCCESS`. A secured network
/// (`secure_network`) instead succeeds on `PSK_SUP` with `UNSOLICITED`.
///
/// # Panics
///
/// Panics if `subscriber()` on the event queue fails (presumably: no free
/// subscriber slot).
async fn wait_for_join(&mut self, i: SsidInfo, secure_network: bool) -> Result<(), JoinError> {
// Calls `disable_all()` on the event mask when dropped: on every return, and also when
// this future is dropped mid-wait. Note that it disables all events, not only the three enabled below.
struct UnsubscribeOnDrop<'a>(&'a Events);
impl Drop for UnsubscribeOnDrop<'_> {
fn drop(&mut self) {
self.0.mask.disable_all();
}
}
let _uod = UnsubscribeOnDrop(&self.events);
self.events.mask.enable(&[Event::SET_SSID, Event::AUTH, Event::PSK_SUP]);
// Subscribe before issuing SetSsid, presumably so events it triggers are not missed.
let mut subscriber = self.events.queue.subscriber().unwrap();
self.ioctl(IoctlType::Set, Ioctl::SetSsid, 0, &mut i.to_bytes()).await;
let result = loop {
let msg = subscriber.next_message_pure().await;
let status = EStatus::from(msg.header.status as u8);
// Arm order matters: NO_NETWORKS must be matched before the generic SET_SSID failure arm.
match (msg.header.event_type, status, secure_network) {
// Open network: a successful SET_SSID completes the join.
(Event::SET_SSID, EStatus::SUCCESS, false) => break Ok(()),
(Event::SET_SSID, EStatus::NO_NETWORKS, _) => break Err(JoinError::NetworkNotFound),
(Event::SET_SSID, status, _) if status != EStatus::SUCCESS => {
break Err(JoinError::JoinFailure(status as u8));
}
// Secured network: PSK_SUP ABORT is ignored and the wait continues.
(Event::PSK_SUP, EStatus::ABORT, true) => {}
// Secured network: PSK_SUP UNSOLICITED completes the join.
(Event::PSK_SUP, EStatus::UNSOLICITED, true) => break Ok(()),
// Secured network: any other PSK_SUP status, or AUTH FAIL, is an authentication failure.
(Event::PSK_SUP, _, true) | (Event::AUTH, EStatus::FAIL, true) => {
break Err(JoinError::AuthenticationFailure);
}
// Everything else (including SET_SSID SUCCESS on a secured network) keeps waiting.
_ => {}
};
};
match result {
Ok(()) => debug!("JOINED"),
Err(JoinError::JoinFailure(status)) => debug!("JOIN failed: status={}", status),
Err(JoinError::NetworkNotFound) => debug!("JOIN failed: network not found"),
Err(JoinError::AuthenticationFailure) => debug!("JOIN failed: authentication failure"),
};
result
}
/// Sets chip GPIO `gpio_n` to `gpio_en` through the `gpioout` iovar (mask, value).
///
/// # Panics
///
/// Panics if `gpio_n` is 3 or greater.
pub async fn gpio_set(&mut self, gpio_n: u8, gpio_en: bool) {
    assert!(gpio_n < 3);
    let mask = 1 << gpio_n;
    let value = if gpio_en { mask } else { 0 };
    self.set_iovar_u32x2("gpioout", mask, value).await
}
/// Starts an unsecured access point named `ssid` on `channel`.
pub async fn start_ap_open(&mut self, ssid: &str, channel: u8) {
    let no_passphrase = "";
    self.start_ap(ssid, no_passphrase, Security::OPEN, channel).await;
}
/// Starts a WPA2-AES-PSK access point named `ssid` on `channel`, secured with `passphrase`.
///
/// # Panics
///
/// Panics if `passphrase` is shorter than `MIN_PSK_LEN` or longer than `MAX_PSK_LEN` bytes.
pub async fn start_ap_wpa2(&mut self, ssid: &str, passphrase: &str, channel: u8) {
    let security = Security::WPA2_AES_PSK;
    self.start_ap(ssid, passphrase, security, channel).await;
}
/// Switches the chip into access-point mode and starts a BSS named `ssid` on `channel`.
///
/// For any `security` other than `Security::OPEN`, a WPA auth value of 0x0084 is set
/// and `passphrase` is installed as the PMK.
///
/// # Panics
///
/// Panics if the network is secured and `passphrase` is shorter than `MIN_PSK_LEN`
/// or longer than `MAX_PSK_LEN` bytes. Also panics if `ssid` exceeds 32 bytes.
async fn start_ap(&mut self, ssid: &str, passphrase: &str, security: Security, channel: u8) {
    let secured = security != Security::OPEN;
    let pass = passphrase.as_bytes();
    let name = ssid.as_bytes();

    if secured && !(MIN_PSK_LEN..=MAX_PSK_LEN).contains(&pass.len()) {
        panic!("Passphrase is too short or too long");
    }

    // Restart the interface with `apsta` = 0, then enable AP mode with open auth.
    self.down().await;
    self.set_iovar_u32("apsta", 0).await;
    self.up().await;
    self.ioctl_set_u32(Ioctl::SetAuth, 0, AUTH_OPEN).await;
    self.ioctl_set_u32(Ioctl::SetAp, 0, 1).await;

    let mut ssid_cfg = SsidInfoWithIndex {
        index: 0,
        ssid_info: SsidInfo {
            len: name.len() as _,
            ssid: [0; 32],
        },
    };
    ssid_cfg.ssid_info.ssid[..name.len()].copy_from_slice(name);
    self.set_iovar("bsscfg:ssid", &ssid_cfg.to_bytes()).await;

    self.ioctl_set_u32(Ioctl::SetChannel, 0, u32::from(channel)).await;
    self.set_iovar_u32x2("bsscfg:wsec", 0, (security as u32) & 0xFF).await;

    if secured {
        self.set_iovar_u32x2("bsscfg:wpa_auth", 0, 0x0084).await;
        Timer::after_millis(100).await;

        let mut pmk = PassphraseInfo {
            len: pass.len() as _,
            flags: 1,
            passphrase: [0; 64],
        };
        pmk.passphrase[..pass.len()].copy_from_slice(pass);
        self.ioctl(IoctlType::Set, Ioctl::SetWsecPmk, 0, &mut pmk.to_bytes())
            .await;
    }

    // Presumably 11 Mbit/s expressed in 500 kbit/s units (= 22).
    self.set_iovar_u32("2g_mrate", 11000000 / 500000).await;
    self.set_iovar_u32x2("bss", 0, 1).await;
}
/// Stops the access point started by `start_ap_open`/`start_ap_wpa2`. It then restarts
/// the interface with `apsta` = 1, the value `init` sets.
pub async fn close_ap(&mut self) {
// Clear the `bss` setting that start_ap set to 1, then leave AP mode.
self.set_iovar_u32x2("bss", 0, 0).await;
self.ioctl_set_u32(Ioctl::SetAp, 0, 0).await;
// Restart the interface with apsta re-enabled.
self.down().await;
self.set_iovar_u32("apsta", 1).await;
self.up().await;
}
/// Adds `address` to the firmware multicast list (`mcast_list` iovar).
///
/// Returns the number of entries in the list after the call. If `address` is already
/// present, the list is left unchanged and the current count is returned.
///
/// # Errors
///
/// - [`AddMulticastAddressError::NotMulticast`] if bit 0 of the first octet of
///   `address` (the group bit) is clear.
/// - [`AddMulticastAddressError::NoFreeSlots`] if the 64-byte list buffer has no room
///   for another 6-byte entry. This includes the case where the firmware reports more
///   entries than the buffer can hold.
pub async fn add_multicast_address(
    &mut self,
    address: [u8; 6],
) -> Result<usize, AddMulticastAddressError> {
    const ADDR_LEN: usize = 6;

    if address[0] & 0x01 == 0 {
        return Err(AddMulticastAddressError::NotMulticast);
    }

    // Layout: little-endian u32 entry count, then 6-byte entries packed back to back.
    let mut buf = [0; 64];
    self.get_iovar("mcast_list", &mut buf).await;
    let n = u32::from_le_bytes(buf[..4].try_into().unwrap()) as usize;

    // The count comes from the firmware. Never let it index past (or overflow) the buffer:
    // such a count previously panicked in `split_at_mut`.
    let used_len = match n.checked_mul(ADDR_LEN) {
        Some(len) if len <= buf.len() - 4 => len,
        _ => return Err(AddMulticastAddressError::NoFreeSlots),
    };
    let (used, free) = buf[4..].split_at_mut(used_len);

    if used.chunks_exact(ADDR_LEN).any(|a| a == address) {
        return Ok(n);
    }
    if free.len() < ADDR_LEN {
        return Err(AddMulticastAddressError::NoFreeSlots);
    }
    free[..ADDR_LEN].copy_from_slice(&address);

    let n = n + 1;
    buf[..4].copy_from_slice(&(n as u32).to_le_bytes());
    // 80 bytes holds the NUL-terminated name (11 bytes) plus the 64-byte list.
    self.set_iovar_v::<80>("mcast_list", &buf).await;
    Ok(n)
}
/// Reads the firmware multicast list into `result`.
///
/// Returns the number of addresses written to `result`. This is at most 10, which is
/// the capacity of both `result` and the 64-byte read buffer. A larger count reported
/// by the firmware is clamped rather than causing an out-of-bounds panic.
pub async fn list_multicast_addresses(&mut self, result: &mut [[u8; 6]; 10]) -> usize {
    // Layout: little-endian u32 entry count, then 6-byte entries packed back to back.
    let mut buf = [0; 64];
    self.get_iovar("mcast_list", &mut buf).await;
    let reported = u32::from_le_bytes(buf[..4].try_into().unwrap()) as usize;

    // 10 entries * 6 bytes == the 60 bytes after the count, so this clamp keeps the
    // slice below in bounds.
    let n = reported.min(result.len());
    let used = &buf[4..][..n * 6];
    for (addr, output) in zip(used.chunks_exact(6), result.iter_mut()) {
        output.copy_from_slice(addr);
    }
    n
}
/// Returns the RSSI value reported by the firmware's `GetRssi` ioctl.
///
/// # Panics
///
/// Panics if the firmware response is not exactly 4 bytes.
pub async fn get_rssi(&mut self) -> i32 {
    let mut rssi_buf = [0u8; 4];
    let n = self.ioctl(IoctlType::Get, Ioctl::GetRssi, 0, &mut rssi_buf).await;
    assert_eq!(n, 4);
    // Firmware integers are little-endian; every other decode in this driver uses
    // `from_le_bytes`. `from_ne_bytes` would only be correct on little-endian hosts.
    i32::from_le_bytes(rssi_buf)
}
/// Sets iovar `name` to two little-endian `u32` values, `val1` followed by `val2`.
async fn set_iovar_u32x2(&mut self, name: &str, val1: u32, val2: u32) {
    let [a0, a1, a2, a3] = val1.to_le_bytes();
    let [b0, b1, b2, b3] = val2.to_le_bytes();
    self.set_iovar(name, &[a0, a1, a2, a3, b0, b1, b2, b3]).await
}
/// Sets iovar `name` to a single little-endian `u32`.
async fn set_iovar_u32(&mut self, name: &str, val: u32) {
    let bytes = val.to_le_bytes();
    self.set_iovar(name, &bytes).await
}
/// Reads iovar `name` as a little-endian `u32`.
///
/// # Panics
///
/// Panics if the response is not exactly 4 bytes.
async fn get_iovar_u32(&mut self, name: &str) -> u32 {
    let mut raw = [0u8; 4];
    let got = self.get_iovar(name, &mut raw).await;
    assert_eq!(got, 4);
    u32::from_le_bytes(raw)
}
/// Sets iovar `name` to `val` using the default 196-byte request buffer.
async fn set_iovar(&mut self, name: &str, val: &[u8]) {
    const DEFAULT_BUFSIZE: usize = 196;
    self.set_iovar_v::<DEFAULT_BUFSIZE>(name, val).await
}
/// Sets iovar `name` to `val` using a `BUFSIZE`-byte stack buffer.
///
/// The request is the NUL-terminated name followed directly by `val`.
///
/// # Panics
///
/// Panics if the name, its terminator and `val` do not fit in `BUFSIZE` bytes.
async fn set_iovar_v<const BUFSIZE: usize>(&mut self, name: &str, val: &[u8]) {
    debug!("iovar set {} = {:02x}", name, Bytes(val));

    let name_bytes = name.as_bytes();
    let value_start = name_bytes.len() + 1;
    let total_len = value_start + val.len();

    let mut buf = [0; BUFSIZE];
    buf[..name_bytes.len()].copy_from_slice(name_bytes);
    buf[name_bytes.len()] = 0;
    buf[value_start..][..val.len()].copy_from_slice(val);

    self.ioctl_inner(IoctlType::Set, Ioctl::SetVar, 0, &mut buf[..total_len])
        .await;
}
/// Reads iovar `name` into `res` and returns the number of bytes copied.
///
/// The request carries the NUL-terminated name and is sized to at least `res.len()`.
/// The ioctl response overwrites the same 64-byte buffer and is then copied out.
///
/// # Panics
///
/// Panics if the name plus terminator, or `res`, is larger than 64 bytes.
async fn get_iovar(&mut self, name: &str, res: &mut [u8]) -> usize {
    debug!("iovar get {}", name);

    let name_len = name.len();
    let mut buf = [0; 64];
    buf[..name_len].copy_from_slice(name.as_bytes());
    buf[name_len] = 0;

    let request_len = max(res.len(), name_len + 1);
    let response_len = self
        .ioctl_inner(IoctlType::Get, Ioctl::GetVar, 0, &mut buf[..request_len])
        .await;

    let copied = min(response_len, res.len());
    res[..copied].copy_from_slice(&buf[..copied]);
    copied
}
/// Issues `Set` ioctl `cmd` on `iface` with a single little-endian `u32` payload.
async fn ioctl_set_u32(&mut self, cmd: Ioctl, iface: u32, val: u32) {
    self.ioctl(IoctlType::Set, cmd, iface, &mut val.to_le_bytes()).await;
}
/// Performs an ioctl, logging the payload of `Set` requests at debug level.
///
/// Returns the response length from the ioctl layer. Any response is written back
/// into `buf`; `get_iovar` relies on this for `Get` requests.
async fn ioctl(&mut self, kind: IoctlType, cmd: Ioctl, iface: u32, buf: &mut [u8]) -> usize {
    if kind == IoctlType::Set {
        debug!("ioctl set {:?} iface {} = {:02x}", cmd, iface, Bytes(buf));
    }
    self.ioctl_inner(kind, cmd, iface, buf).await
}
/// Runs an ioctl through the shared `IoctlState` and returns the response length.
///
/// If this future is dropped before `do_ioctl` completes, `cancel_ioctl()` is called
/// so the pending request is cancelled.
async fn ioctl_inner(&mut self, kind: IoctlType, cmd: Ioctl, iface: u32, buf: &mut [u8]) -> usize {
// Guard that calls `cancel_ioctl()` when dropped.
struct CancelOnDrop<'a>(&'a IoctlState);
impl CancelOnDrop<'_> {
// Consumes the guard without running its Drop, i.e. without cancelling.
fn defuse(self) {
core::mem::forget(self);
}
}
impl Drop for CancelOnDrop<'_> {
fn drop(&mut self) {
self.0.cancel_ioctl();
}
}
let ioctl = CancelOnDrop(self.ioctl_state);
let resp_len = ioctl.0.do_ioctl(kind, cmd, iface, buf).await;
// Completed normally: nothing to cancel.
ioctl.defuse();
resp_len
}
/// Starts an escan configured by `scan_opts` and returns a [`Scanner`] that yields results.
///
/// Option fields left `None` are sent as the all-ones value of the corresponding
/// request field. That is the "use the default" sentinel (-1 in the firmware's signed
/// fields, per the Broadcom scan-parameter convention). Explicit durations are sent
/// in whole milliseconds and saturate at `!0 - 1`, so a long duration can neither
/// wrap around nor be mistaken for the default sentinel.
///
/// # Panics
///
/// Panics if `subscriber()` on the event queue fails (presumably: no free subscriber slot).
pub async fn scan(&mut self, scan_opts: ScanOptions) -> Scanner<'_> {
    const SCANTYPE_ACTIVE: u8 = 0;
    const SCANTYPE_PASSIVE: u8 = 1;

    /// Converts a millisecond count to a firmware time field, saturating at `!0 - 1`.
    fn ms_to_field(ms: u64) -> u32 {
        u32::try_from(ms).unwrap_or(u32::MAX).min(u32::MAX - 1)
    }

    let dwell_time = scan_opts.dwell_time.map_or(!0, |t| ms_to_field(t.as_millis()));

    // The dwell time goes into the field that matches the scan type. The other field stays at the default.
    let (scan_type, active_time, passive_time) = match scan_opts.scan_type {
        ScanType::Active => (SCANTYPE_ACTIVE, dwell_time, !0),
        ScanType::Passive => (SCANTYPE_PASSIVE, !0, dwell_time),
    };

    let ssid_len = scan_opts.ssid.as_ref().map_or(0, |s| s.as_bytes().len() as u32);
    let ssid = scan_opts.ssid.map_or([0; 32], |s| {
        let mut raw = [0; 32];
        raw[..s.as_bytes().len()].copy_from_slice(s.as_bytes());
        raw
    });

    let scan_params = ScanParams {
        version: 1,
        action: 1,
        sync_id: 1,
        ssid_len,
        ssid,
        bssid: scan_opts.bssid.unwrap_or([0xff; 6]),
        bss_type: 2,
        scan_type,
        // `!0` at the field's own width. Previously a `u16` `!0` was widened, which sent
        // 65535 probes instead of the all-ones "default" sentinel.
        nprobes: scan_opts.nprobes.map_or(!0, Into::into),
        active_time,
        passive_time,
        home_time: scan_opts.home_time.map_or(!0, |t| ms_to_field(t.as_millis())),
        channel_num: 0,
        channel_list: [0; 1],
    };

    self.events.mask.enable(&[Event::ESCAN_RESULT]);
    let subscriber = self.events.queue.subscriber().unwrap();
    self.set_iovar_v::<256>("escan", &scan_params.to_bytes()).await;

    Scanner {
        subscriber,
        events: self.events,
    }
}
/// Disassociates from the current network (`Disassoc` ioctl, no payload).
pub async fn leave(&mut self) {
    let no_payload: &mut [u8] = &mut [];
    self.ioctl(IoctlType::Set, Ioctl::Disassoc, 0, no_payload).await;
    info!("Disassociated")
}
/// Reads the chip's MAC address from the `cur_etheraddr` iovar.
///
/// # Panics
///
/// Panics if the firmware does not return exactly 6 bytes.
pub async fn address(&mut self) -> [u8; 6] {
    let mut mac = [0u8; 6];
    let len = self.get_iovar("cur_etheraddr", &mut mac).await;
    assert_eq!(len, 6);
    mac
}
}
/// An in-progress scan started by [`Control::scan`]. Call [`Scanner::next`] to get results.
///
/// Dropping it calls `disable_all()` on the event mask.
pub struct Scanner<'a> {
// Subscription to the event queue created when the scan was started.
subscriber: EventSubscriber<'a>,
// Shared events, used to disable the mask once scanning ends.
events: &'a Events,
}
impl Scanner<'_> {
    /// Waits for the next scan result.
    ///
    /// Returns `None` when an event arrives whose status is not `PARTIAL`; in that case
    /// the event mask is disabled first. Also returns `None` for a `PARTIAL` event whose
    /// payload is not a BSS info record.
    pub async fn next(&mut self) -> Option<BssInfo> {
        let event = self.subscriber.next_message_pure().await;
        let scan_done = event.header.status != EStatus::PARTIAL;
        if scan_done {
            self.events.mask.disable_all();
            return None;
        }
        match event.payload {
            events::Payload::BssInfo(bss) => Some(bss),
            _ => None,
        }
    }
}
impl Drop for Scanner<'_> {
    /// Disables every event in the shared event mask, which `Control::scan` had
    /// enabled for `ESCAN_RESULT`.
    fn drop(&mut self) {
        let mask = &self.events.mask;
        mask.disable_all();
    }
}