use std::{sync::Arc, time::Duration};
use tokio::sync::{Mutex, broadcast, mpsc, watch::Sender as WatchSender};
use tracing::{debug, instrument, warn};
use zbus::{Connection, zvariant::OwnedObjectPath};
use crate::{
Capabilities, Error, Status, WifiConfigurator,
bluez::{self, AppHandles},
rpc::{Command, Reassembler, Yielded, encode_response},
};
/// Whether clients must complete the Improv authorization step.
///
/// NOTE(review): consumed by `bluez::install` (not visible here); presumably it
/// decides the initial status (authorized vs. authorization required) — confirm.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AuthorizeMode {
    /// Clients must be authorized before privileged RPCs are accepted.
    Required,
    /// No authorization step; this is the default.
    #[default]
    NotRequired,
}

/// Options accepted by [`ImprovWifi::install`].
///
/// `Default` yields `AuthorizeMode::NotRequired` and `None` for both optional fields.
#[derive(Clone, Debug, Default)]
pub struct ImprovWifiConfig {
    /// Authorization policy for incoming clients.
    pub authorize: AuthorizeMode,
    /// Optional authorization timeout.
    /// NOTE(review): presumably how long an authorization stays valid — confirm in `bluez`.
    pub auth_timeout: Option<Duration>,
    /// Optional local name.
    /// NOTE(review): presumably the BLE advertised name — confirm in `bluez`.
    pub local_name: Option<String>,
}
/// Mutable protocol state guarded by `State::inner`.
#[derive(Debug)]
pub(crate) struct InnerState {
    /// Current Improv status.
    pub(crate) status: Status,
    /// Last reported error byte; `0` means "no error".
    pub(crate) last_error: u8,
    /// Encoded bytes of the most recent RPC response.
    pub(crate) rpc_result: Vec<u8>,
}
/// Shared Improv protocol state plus the channels used to publish changes.
pub(crate) struct State<T> {
    /// Status / error / RPC result, guarded by an async (tokio) mutex.
    pub(crate) inner: Mutex<InnerState>,
    /// Device capabilities. NOTE(review): not read in this file's visible code.
    pub(crate) capabilities: Capabilities,
    /// Backend that performs identify, scan, provisioning and name changes.
    pub(crate) configurator: T,
    /// Reassembles raw writes into commands or parse errors.
    pub(crate) reassembler: Mutex<Reassembler>,
    /// Broadcasts every status transition.
    pub(crate) status_tx: broadcast::Sender<Status>,
    /// Broadcasts every error-byte change.
    pub(crate) error_tx: broadcast::Sender<u8>,
    /// Broadcasts every newly encoded RPC response.
    pub(crate) rpc_result_tx: broadcast::Sender<Vec<u8>>,
    /// Signalled on entering `Authorized` and before privileged setters run.
    /// NOTE(review): presumably restarts the authorization timeout — confirm.
    pub(crate) auth_reset_tx: WatchSender<()>,
    /// Set to `true` once the status reaches `Provisioned`.
    pub(crate) provisioned_tx: WatchSender<bool>,
    /// Whether authorization is required. NOTE(review): not read in visible code.
    pub(crate) auth_required: bool,
}
impl<T> State<T>
where
    T: WifiConfigurator,
{
    /// Returns the current status encoded as a single byte.
    pub(crate) async fn current_state_byte(&self) -> u8 {
        self.inner.lock().await.status.as_byte()
    }

    /// Returns the last recorded error byte (`0` means no error).
    pub(crate) async fn error_byte(&self) -> u8 {
        self.inner.lock().await.last_error
    }

    /// Returns a copy of the most recently encoded RPC response.
    pub(crate) async fn rpc_result_bytes(&self) -> Vec<u8> {
        self.inner.lock().await.rpc_result.clone()
    }

    /// Moves to `new` and broadcasts the transition on `status_tx`.
    ///
    /// No-op (nothing is sent) when `new` equals the current status. Entering
    /// `Status::Authorized` also signals `auth_reset_tx`; entering
    /// `Status::Provisioned` publishes `true` on `provisioned_tx`.
    pub(crate) async fn set_status(&self, new: Status) {
        {
            let mut inner = self.inner.lock().await;
            if inner.status == new {
                return;
            }
            inner.status = new;
        }
        // Lock released before notifying; send errors (no receivers) are ignored.
        let _ = self.status_tx.send(new);
        if new == Status::Authorized {
            let _ = self.auth_reset_tx.send(());
        }
        if new == Status::Provisioned {
            let _ = self.provisioned_tx.send(true);
        }
    }

    /// Records `err` (`None` ⇒ byte `0`) and broadcasts it on `error_tx`.
    ///
    /// Deduplicates: nothing is sent when the byte equals the stored one.
    /// Callers that need every failure to be observable must therefore clear
    /// the error first (see `dispatch` / `handle_write`).
    async fn set_error(&self, err: Option<Error>) {
        let byte = err.map_or(0, Error::as_byte);
        {
            let mut inner = self.inner.lock().await;
            if inner.last_error == byte {
                return;
            }
            inner.last_error = byte;
        }
        let _ = self.error_tx.send(byte);
    }

    /// Stores `bytes` as the current RPC result and broadcasts it (always, no dedup).
    async fn set_rpc_result(&self, bytes: Vec<u8>) {
        {
            let mut inner = self.inner.lock().await;
            inner.rpc_result = bytes.clone();
        }
        let _ = self.rpc_result_tx.send(bytes);
    }

    /// Feeds one raw write into the reassembler and handles every item it yields, in order.
    ///
    /// The reassembler lock is held only while feeding, not while commands run.
    /// Parse errors map to `Error::UnknownRPC` for unknown commands and to
    /// `Error::InvalidRPC` otherwise.
    #[instrument(level = "debug", skip(self, write))]
    pub(crate) async fn handle_write(&self, write: Vec<u8>) {
        let yielded = self.reassembler.lock().await.feed(&write);
        for item in yielded {
            match item {
                Yielded::Command(cmd) => self.dispatch(cmd).await,
                Yielded::Error(parse_err) => {
                    warn!(?parse_err, "RPC parse error");
                    let mapped = match parse_err {
                        crate::rpc::ParseError::UnknownCommand(_) => Error::UnknownRPC,
                        _ => Error::InvalidRPC,
                    };
                    // Clear first so that a repeated, identical parse error still
                    // produces a notification (set_error dedups equal bytes).
                    self.set_error(None).await;
                    self.set_error(Some(mapped)).await;
                }
            }
        }
    }

    /// Executes one decoded RPC command.
    ///
    /// The error byte is reset to `0` before the command runs. Because
    /// `set_error` suppresses unchanged values, without this reset a client
    /// repeating a command that fails with the same error (e.g.
    /// `NotAuthorized`) would receive no error notification for the retry.
    ///
    /// Response command ids used here: `0x01` Wi-Fi settings, `0x03` device
    /// info, `0x04` scan, `0x05` hostname, `0x06` device name.
    async fn dispatch(&self, cmd: Command) {
        debug!(?cmd, "RPC command");
        self.set_error(None).await;
        match cmd {
            Command::Identify => {
                if let Err(err) = self.configurator.identify().await {
                    self.set_error(Some(err)).await;
                }
            }
            Command::DeviceInfo => {
                self.respond(
                    0x03,
                    self.configurator
                        .device_info()
                        .await
                        .map(|i| i.into_strings()),
                )
                .await
            }
            Command::Scan => {
                // Flattened as [ssid, rssi, auth] triples in a single response.
                let res = self.configurator.scan().await.map(|nets| {
                    let mut out = Vec::with_capacity(nets.len() * 3);
                    for n in nets {
                        out.push(n.ssid);
                        out.push(n.rssi.to_string());
                        out.push(n.auth);
                    }
                    out
                });
                self.respond(0x04, res).await;
            }
            Command::GetHostname => {
                let res = self.configurator.get_hostname().await.map(|h| vec![h]);
                self.respond(0x05, res).await;
            }
            Command::SetHostname(name) => {
                if !self.is_authorized().await {
                    self.set_error(Some(Error::NotAuthorized)).await;
                    return;
                }
                // Privileged action: signal the authorization-reset watch.
                let _ = self.auth_reset_tx.send(());
                match self.configurator.set_hostname(name.clone()).await {
                    Ok(()) => {
                        self.set_error(None).await;
                        self.set_rpc_result(encode_response(0x05, &[name])).await;
                    }
                    Err(err) => self.set_error(Some(err)).await,
                }
            }
            Command::GetDeviceName => {
                let res = self.configurator.get_device_name().await.map(|n| vec![n]);
                self.respond(0x06, res).await;
            }
            Command::SetDeviceName(name) => {
                if !self.is_authorized().await {
                    self.set_error(Some(Error::NotAuthorized)).await;
                    return;
                }
                let _ = self.auth_reset_tx.send(());
                match self.configurator.set_device_name(name.clone()).await {
                    Ok(()) => {
                        self.set_error(None).await;
                        self.set_rpc_result(encode_response(0x06, &[name])).await;
                    }
                    Err(err) => self.set_error(Some(err)).await,
                }
            }
            Command::SendWifiSettings { ssid, password } => {
                if !self.is_authorized().await {
                    self.set_error(Some(Error::NotAuthorized)).await;
                    return;
                }
                self.set_status(Status::Provisioning).await;
                match self.configurator.provision(ssid, password).await {
                    Ok(strings) => {
                        // Publish the result before announcing Provisioned.
                        self.set_rpc_result(encode_response(0x01, &strings)).await;
                        self.set_status(Status::Provisioned).await;
                    }
                    Err(err) => {
                        // Record the failure, then fall back so the client may retry.
                        self.set_error(Some(err)).await;
                        self.set_status(Status::Authorized).await;
                    }
                }
            }
        }
    }

    /// On success, clears the error and publishes `encode_response(command_id, strings)`.
    /// On failure, records the error and leaves the previous RPC result untouched.
    async fn respond(&self, command_id: u8, res: Result<Vec<String>, Error>) {
        match res {
            Ok(strings) => {
                self.set_error(None).await;
                self.set_rpc_result(encode_response(command_id, &strings))
                    .await;
            }
            Err(err) => self.set_error(Some(err)).await,
        }
    }

    /// True only while the status is exactly `Status::Authorized`
    /// (so not during `Provisioning` or after `Provisioned`).
    async fn is_authorized(&self) -> bool {
        matches!(self.inner.lock().await.status, Status::Authorized)
    }
}
/// Public handle to an Improv Wi-Fi service installed on a BlueZ adapter.
pub struct ImprovWifi<T>
where
    T: WifiConfigurator + 'static,
{
    /// Shared protocol state; the same `Arc` is also held inside `handles`.
    state: Arc<State<T>>,
    /// Handles produced by `bluez::install`, consumed by `run`.
    handles: AppHandles<T>,
}
impl<T> ImprovWifi<T>
where
T: WifiConfigurator + 'static,
{
pub async fn install(
connection: Connection,
adapter_path: OwnedObjectPath,
configurator: T,
config: ImprovWifiConfig,
) -> Result<Self, Error> {
let handles = bluez::install(connection, adapter_path, configurator, config).await?;
let state = handles.state.clone();
Ok(Self { state, handles })
}
pub async fn authorize(&self) {
self.state.set_status(Status::Authorized).await;
}
pub fn auth_handle(&self) -> AuthHandle {
AuthHandle {
tx: self.handles.auth_tx.clone(),
}
}
pub async fn run(self) -> Result<(), Error> {
bluez::run(self.handles).await
}
}
/// Cloneable handle for requesting authorization from synchronous code.
///
/// Obtained from `ImprovWifi::auth_handle`.
#[derive(Debug, Clone)]
pub struct AuthHandle {
    /// Sending side of the authorization-request channel.
    tx: mpsc::UnboundedSender<()>,
}

impl AuthHandle {
    /// Queues one authorization request without blocking.
    ///
    /// If the receiving side has been dropped, the request is silently discarded.
    pub fn authorize(&self) {
        // A send error only means nobody is listening any more; nothing to do.
        let _ = self.tx.send(());
    }
}
/// Resolves the D-Bus object path of a BlueZ adapter by delegating to `bluez::find_adapter`.
///
/// NOTE(review): `name` presumably selects a specific adapter (e.g. `"hci0"`),
/// with `None` letting `bluez` choose — confirm.
///
/// # Errors
/// Propagates any error returned by `bluez::find_adapter`.
pub async fn find_adapter(
    connection: &Connection,
    name: Option<&str>,
) -> Result<OwnedObjectPath, Error> {
    bluez::find_adapter(connection, name).await
}

/// Powers on the adapter at `adapter_path` by delegating to `bluez::power_on_adapter`.
///
/// # Errors
/// Propagates any error returned by `bluez::power_on_adapter`.
pub async fn power_on_adapter(
    connection: &Connection,
    adapter_path: &OwnedObjectPath,
) -> Result<(), Error> {
    bluez::power_on_adapter(connection, adapter_path).await
}