use std::borrow::Cow;
use tosca::parameters::ParametersValues;
use tokio::sync::mpsc::{self, Receiver};
use tracing::{error, warn};
use crate::device::{Device, Devices};
use crate::discovery::Discovery;
use crate::error::{Error, ErrorKind};
use crate::events::{EventPayload, EventsRunner};
use crate::policy::Policy;
use crate::request::Request;
use crate::response::Response;
fn sender_error(error: impl Into<Cow<'static, str>>) -> Error {
Error::new(ErrorKind::Sender, error)
}
/// A handle for sending one specific [`Request`] of a device.
///
/// Obtained from [`DeviceSender::request`]. `skip` is forwarded to
/// `Request::retrieve_response` on every send; presumably that yields
/// `Response::Skipped` instead of contacting the device (confirm in
/// `Request::retrieve_response`).
#[derive(Debug, PartialEq)]
pub struct RequestSender<'controller> {
/// The controller this sender was created from.
/// NOTE(review): not read by any method visible in this module.
controller: &'controller Controller,
/// The request that is sent.
request: &'controller Request,
/// `true` when the controller's privacy policy blocked this request.
skip: bool,
}
impl RequestSender<'_> {
    /// Sends the request without input parameters.
    ///
    /// The privacy-policy `skip` flag is forwarded to
    /// `Request::retrieve_response`, which decides whether the plain send is
    /// actually performed.
    ///
    /// # Errors
    ///
    /// Propagates any error produced while sending the request or building
    /// the [`Response`].
    pub async fn send(&self) -> Result<Response, Error> {
        let request = self.request;
        let send_plain = || async { request.plain_send().await };
        request.retrieve_response(self.skip, send_plain).await
    }

    /// Sends the request with the given input `parameters`.
    ///
    /// If the request declares no input parameters, a warning is logged and
    /// this falls back to [`RequestSender::send`], ignoring `parameters`.
    ///
    /// # Errors
    ///
    /// Same as [`RequestSender::send`].
    pub async fn send_with_parameters(
        &self,
        parameters: &ParametersValues<'_>,
    ) -> Result<Response, Error> {
        let request = self.request;
        if request.parameters_data.is_empty() {
            warn!("The request does not have input parameters.");
            self.send().await
        } else {
            let send_with_values = || async { request.create_response(parameters).await };
            request.retrieve_response(self.skip, send_with_values).await
        }
    }
}
/// A handle for building requests to one device owned by a [`Controller`].
///
/// Obtained from [`Controller::device`].
#[derive(Debug, PartialEq)]
pub struct DeviceSender<'controller> {
/// The controller that owns the device and its privacy policy.
controller: &'controller Controller,
/// The device whose requests are looked up by route.
device: &'controller Device,
/// Index of `device` in the controller's device list; also used as the key
/// for device-local blocked hazards in the privacy policy.
id: usize,
}
impl DeviceSender<'_> {
pub fn request(&self, route: &str) -> Result<RequestSender<'_>, Error> {
let request = self.device.request(route).ok_or_else(|| {
sender_error(format!(
"Error in retrieving the request with route `{route}`."
))
})?;
let skip = if request.hazards.is_empty() {
false
} else {
self.evaluate_privacy_policy(request, route)
};
Ok(RequestSender {
controller: self.controller,
request,
skip,
})
}
fn evaluate_privacy_policy(&self, request: &Request, route: &str) -> bool {
let mut skip = false;
let global_blocked_hazards = self
.controller
.privacy_policy
.global_blocked_hazards(&request.hazards);
let local_blocked_hazards = self
.controller
.privacy_policy
.local_blocked_hazards(self.id, &request.hazards);
if !global_blocked_hazards.is_empty() {
warn!(
"The {route} is skipped because it contains the global blocked hazards: {:?}",
global_blocked_hazards
);
skip = true;
}
if !local_blocked_hazards.is_empty() {
warn!(
"The {route} is skipped because the device contains the local blocked hazards: {:?}",
local_blocked_hazards
);
skip = true;
}
skip
}
}
/// Owns the discovery mechanism, the known devices, and the privacy policy
/// applied to their requests.
#[derive(Debug, PartialEq)]
pub struct Controller {
/// Mechanism used by [`Controller::discover`] to find devices.
discovery: Discovery,
/// Devices currently known to the controller, addressed by index.
devices: Devices,
/// Policy deciding which hazardous requests are skipped.
privacy_policy: Policy,
}
impl Controller {
    /// Creates a controller with no devices and the default privacy policy
    /// (`Policy::init()`).
    ///
    /// Call [`Controller::discover`] to populate the device list.
    #[must_use]
    #[inline]
    pub fn new(discovery: Discovery) -> Self {
        Self {
            discovery,
            devices: Devices::new(),
            privacy_policy: Policy::init(),
        }
    }

    /// Creates a controller from an already-known set of `devices` and the
    /// default privacy policy (`Policy::init()`).
    #[must_use]
    #[inline]
    pub fn from_devices(discovery: Discovery, devices: Devices) -> Self {
        Self {
            discovery,
            devices,
            privacy_policy: Policy::init(),
        }
    }

    /// Replaces the privacy policy, builder-style.
    #[must_use]
    #[inline]
    pub fn policy(mut self, privacy_policy: Policy) -> Self {
        self.privacy_policy = privacy_policy;
        self
    }

    /// Replaces the privacy policy in place.
    #[inline]
    pub fn change_policy(&mut self, privacy_policy: Policy) {
        self.privacy_policy = privacy_policy;
    }

    /// Runs discovery and replaces the current device list with the result.
    ///
    /// # Errors
    ///
    /// Propagates any discovery error. In that case the existing device list
    /// is left unchanged.
    #[inline]
    pub async fn discover(&mut self) -> Result<(), Error> {
        self.devices = self.discovery.discover().await?;
        Ok(())
    }

    /// Starts a global event subscriber for every device that supports events.
    /// Returns the receiving end of the channel they share.
    ///
    /// `buffer_size` is the channel capacity. Devices that already hold an
    /// event handle, or that expose no events, are skipped with a warning.
    ///
    /// # Errors
    ///
    /// - Returns an events error if `buffer_size` is zero.
    /// - Returns an events error if no subscriber was started.
    /// - Propagates the first error returned while starting a subscriber.
    pub async fn start_event_receivers(
        &mut self,
        buffer_size: usize,
    ) -> Result<Receiver<EventPayload>, Error> {
        // `mpsc::channel` panics on a zero capacity; report it as an error instead.
        if buffer_size == 0 {
            return Err(Error::new(
                ErrorKind::Events,
                "The events buffer size must be greater than zero",
            ));
        }

        let (tx, rx) = mpsc::channel(buffer_size);
        let mut started_count = 0;

        for (id, device) in self.devices.iter_mut().enumerate() {
            if device.event_handle.is_some() {
                warn!("Skip device with id `{id}`: event receiver already started");
                continue;
            }

            let Some(ref events) = device.events else {
                warn!("Skip device with id `{id}`: it does not support events");
                continue;
            };

            // NOTE(review): nothing in this loop assigns `device.event_handle`.
            // The check above therefore only sees handles set elsewhere, and
            // calling this method twice may start duplicate subscribers. If
            // `run_global_subscriber` returns the task handle, it should
            // presumably be stored here — confirm.
            // On error this returns early. Subscribers already started for
            // earlier devices are not stopped here.
            EventsRunner::run_global_subscriber(events, id, tx.clone()).await?;
            started_count += 1;
        }

        if started_count == 0 {
            return Err(Error::new(
                ErrorKind::Events,
                "No event receiver tasks has started",
            ));
        }

        Ok(rx)
    }

    /// Returns the devices known to the controller.
    #[must_use]
    pub const fn devices(&self) -> &Devices {
        &self.devices
    }

    /// Returns mutable access to the devices known to the controller.
    #[must_use]
    pub const fn devices_mut(&mut self) -> &mut Devices {
        &mut self.devices
    }

    /// Returns a [`DeviceSender`] for the device at index `id`.
    ///
    /// # Errors
    ///
    /// Returns a sender error if the controller has no devices or if `id` is
    /// out of range.
    pub fn device(&self, id: usize) -> Result<DeviceSender<'_>, Error> {
        if self.devices.is_empty() {
            return Err(sender_error("No devices found."));
        }

        let device = self.devices.get(id).ok_or_else(|| {
            sender_error(format!(
                "Error in retrieving the device with identifier {id}."
            ))
        })?;

        Ok(DeviceSender {
            controller: self,
            device,
            id,
        })
    }

    /// Consumes the controller, cancels every device event task, and waits
    /// for those tasks to finish.
    ///
    /// All cancellation tokens are triggered before any task is awaited, so
    /// the tasks wind down concurrently rather than one device at a time.
    /// A task whose await returns an error is logged and does not interrupt
    /// the shutdown of the others.
    pub async fn shutdown(self) {
        let mut handles = Vec::new();

        for device in self.devices {
            if let Some(events) = device.events {
                events.cancellation_token.cancel();
            }
            if let Some(event_handle) = device.event_handle {
                handles.push(event_handle);
            }
        }

        for event_handle in handles {
            if let Err(e) = event_handle.await {
                error!("Failed to await the event task: {e}");
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use tracing::warn;

    use tosca::hazards::{Hazard, Hazards};
    use tosca::parameters::ParametersValues;
    use tosca::response::{OkResponse, SerialResponse};

    use serde::{Serialize, de::DeserializeOwned};
    use serde_json::json;

    use serial_test::serial;

    use crate::device::Devices;
    use crate::error::Error;
    use crate::policy::Policy;
    use crate::response::Response;

    use crate::device::tests::{create_light, create_unknown};
    use crate::discovery::tests::configure_discovery;
    use crate::tests::{Brightness, check_function_with_device};

    use super::{Controller, DeviceSender, RequestSender, sender_error};

    #[test]
    fn empty_controller() {
        let controller = Controller::new(configure_discovery());

        assert_eq!(
            controller,
            Controller {
                discovery: configure_discovery(),
                devices: Devices::new(),
                privacy_policy: Policy::init(),
            }
        );

        assert_eq!(controller.device(0), Err(sender_error("No devices found.")));
    }

    #[test]
    fn controller_from_devices() {
        let devices = Devices::from_devices(vec![create_light(), create_unknown()]);
        let controller = Controller::from_devices(configure_discovery(), devices);

        assert_eq!(
            controller,
            Controller {
                discovery: configure_discovery(),
                devices: Devices::from_devices(vec![create_light(), create_unknown()]),
                privacy_policy: Policy::init(),
            }
        );
    }

    // Sends `route` without parameters and expects an `OkResponse` (or a policy skip).
    async fn check_ok_response_plain(device_sender: &DeviceSender<'_>, route: &str) {
        check_ok_response(device_sender, route, async move |request_sender| {
            request_sender.send().await
        })
        .await;
    }

    // Sends `route` with `parameters` and expects an `OkResponse` (or a policy skip).
    async fn check_ok_response_with_parameters(
        device_sender: &DeviceSender<'_>,
        route: &str,
        parameters: &ParametersValues<'_>,
    ) {
        check_ok_response(device_sender, route, async move |request_sender| {
            request_sender.send_with_parameters(parameters).await
        })
        .await;
    }

    // Shared driver: builds the request for `route`, sends it via `get_response`,
    // and checks the body is `OkResponse::ok()` unless the request was skipped.
    async fn check_ok_response<'controller, 'a, F>(
        device_sender: &'a DeviceSender<'controller>,
        route: &'a str,
        get_response: F,
    ) where
        F: AsyncFnOnce(RequestSender<'controller>) -> Result<Response, Error>,
        'a: 'controller,
    {
        let request = device_sender.request(route).unwrap();
        let response = get_response(request).await.unwrap();

        if let Response::OkBody(response) = response {
            let ok_response = response.parse_body().await.unwrap();
            assert_eq!(ok_response, OkResponse::ok());
        } else {
            assert!(
                matches!(response, Response::Skipped),
                "Should be a blocked global `LogEnergyConsumption` for `/off` request"
            );
        }
    }

    // Sends `route` without parameters and expects a serial body equal to `value`.
    async fn check_serial_response_plain<T: Serialize + DeserializeOwned + Debug + PartialEq>(
        device_sender: &DeviceSender<'_>,
        route: &str,
        value: T,
    ) {
        check_serial_response(
            device_sender,
            route,
            async move |request_sender| request_sender.send().await,
            value,
        )
        .await;
    }

    // Sends `route` with `parameters` and expects a serial body equal to `value`.
    async fn check_serial_response_with_parameters<
        T: Serialize + DeserializeOwned + Debug + PartialEq,
    >(
        device_sender: &DeviceSender<'_>,
        route: &str,
        parameters: &ParametersValues<'_>,
        value: T,
    ) {
        check_serial_response(
            device_sender,
            route,
            async move |request| request.send_with_parameters(parameters).await,
            value,
        )
        .await;
    }

    // Shared driver: builds the request for `route`, sends it via `get_response`,
    // and checks the serial body equals `value` unless the request was skipped.
    async fn check_serial_response<'controller, 'a, F, T>(
        device: &'a DeviceSender<'controller>,
        route: &'a str,
        get_response: F,
        value: T,
    ) where
        F: AsyncFnOnce(RequestSender<'controller>) -> Result<Response, Error>,
        T: Serialize + DeserializeOwned + Debug + PartialEq,
        'a: 'controller,
    {
        let request = device.request(route).unwrap();
        let response = get_response(request).await.unwrap();

        if let Response::SerialBody(response) = response {
            let serial_response = response.parse_body::<T>().await.unwrap();
            assert_eq!(serial_response, SerialResponse::new(value));
        } else {
            assert!(
                matches!(response, Response::Skipped),
                "Should be a blocked local `FireHazard` for `/toggle` request"
            );
        }
    }

    async fn controller_checks(controller: Controller) {
        assert_eq!(
            controller.device(1),
            Err(sender_error(
                "Error in retrieving the device with identifier 1."
            ))
        );

        let device_sender = controller.device(0).unwrap();

        assert_eq!(
            device_sender.request("/wrong"),
            Err(sender_error(
                "Error in retrieving the request with route `/wrong`."
            ))
        );

        check_ok_response_plain(&device_sender, "/on").await;
        check_ok_response_plain(&device_sender, "/off").await;
        check_serial_response_plain(
            &device_sender,
            "/toggle",
            json!({
                "brightness": 0,
            }),
        )
        .await;

        let mut parameters = ParametersValues::new();
        parameters.u64("brightness", 5);

        check_ok_response_with_parameters(&device_sender, "/on", &parameters).await;
        check_ok_response_with_parameters(&device_sender, "/off", &parameters).await;
        check_serial_response_with_parameters(
            &device_sender,
            "/toggle",
            &parameters,
            Brightness { brightness: 5 },
        )
        .await;
    }

    async fn controller_without_policy() {
        let mut controller = Controller::new(configure_discovery());
        controller.discover().await.unwrap();
        controller_checks(controller).await;
    }

    async fn controller_with_policy() {
        let global_hazards = Hazards::new().insert(Hazard::LogEnergyConsumption);
        let local_hazards = Hazards::new().insert(Hazard::FireHazard);
        let policy = Policy::new(global_hazards).block_device_on_hazards(0, local_hazards);

        let mut controller = Controller::new(configure_discovery()).policy(policy);
        controller.discover().await.unwrap();
        controller_checks(controller).await;
    }

    // Runs `function` against a real device, except on CI where physical MAC
    // addresses are not available.
    async fn run_controller_function<F, Fut>(name: &str, function: F)
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = ()>,
    {
        if option_env!("CI").is_some() {
            warn!(
                "Skipping test on CI: {} can run only on systems that expose physical MAC addresses.",
                name
            );
        } else {
            check_function_with_device(|| async {
                function().await;
            })
            .await;
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[serial]
    async fn test_without_policy_controller() {
        run_controller_function("controller_without_policy", || async {
            controller_without_policy().await;
        })
        .await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[serial]
    async fn test_with_policy_controller() {
        run_controller_function("controller_with_policy", || async {
            controller_with_policy().await;
        })
        .await;
    }
}