1use std::borrow::Cow;
2
3use tosca::parameters::ParametersValues;
4
5use tokio::sync::mpsc::{self, Receiver};
6
7use tracing::{error, warn};
8
9use crate::device::{Device, Devices};
10use crate::discovery::Discovery;
11use crate::error::{Error, ErrorKind};
12use crate::events::{EventPayload, EventsRunner};
13use crate::policy::Policy;
14use crate::request::Request;
15use crate::response::Response;
16
17fn sender_error(error: impl Into<Cow<'static, str>>) -> Error {
20 Error::new(ErrorKind::Sender, error)
21}
22
/// A handle for sending a single device request, as built by
/// [`DeviceSender::request`].
///
/// `skip` records whether the controller privacy policy blocked this request;
/// it is forwarded to `Request::retrieve_response` on every send.
#[derive(Debug, PartialEq)]
pub struct RequestSender<'controller> {
    // NOTE(review): not read by any method visible in this file.
    controller: &'controller Controller,
    // The device request this sender targets.
    request: &'controller Request,
    // `true` when a global or device-local blocked hazard matched the request.
    skip: bool,
}
30
31impl RequestSender<'_> {
32 pub async fn send(&self) -> Result<Response, Error> {
39 self.request
40 .retrieve_response(self.skip, || async { self.request.plain_send().await })
41 .await
42 }
43
44 pub async fn send_with_parameters(
52 &self,
53 parameters: &ParametersValues<'_>,
54 ) -> Result<Response, Error> {
55 if self.request.parameters_data.is_empty() {
56 warn!("The request does not have input parameters.");
57 return self.send().await;
58 }
59
60 self.request
61 .retrieve_response(self.skip, || async {
62 self.request.create_response(parameters).await
63 })
64 .await
65 }
66}
67
/// A handle for addressing one device of a [`Controller`], as built by
/// [`Controller::device`].
#[derive(Debug, PartialEq)]
pub struct DeviceSender<'controller> {
    // Controller owning the device; its privacy policy is consulted per request.
    controller: &'controller Controller,
    // The addressed device.
    device: &'controller Device,
    // Index of `device` in the controller's device list; also used as the
    // device identifier when looking up local (per-device) blocked hazards.
    id: usize,
}
77
78impl DeviceSender<'_> {
79 pub fn request(&self, route: &str) -> Result<RequestSender<'_>, Error> {
88 let request = self.device.request(route).ok_or_else(|| {
89 sender_error(format!(
90 "Error in retrieving the request with route `{route}`."
91 ))
92 })?;
93
94 let skip = if request.hazards.is_empty() {
95 false
96 } else {
97 self.evaluate_privacy_policy(request, route)
98 };
99
100 Ok(RequestSender {
101 controller: self.controller,
102 request,
103 skip,
104 })
105 }
106
107 fn evaluate_privacy_policy(&self, request: &Request, route: &str) -> bool {
108 let mut skip = false;
109
110 let global_blocked_hazards = self
111 .controller
112 .privacy_policy
113 .global_blocked_hazards(&request.hazards);
114
115 let local_blocked_hazards = self
116 .controller
117 .privacy_policy
118 .local_blocked_hazards(self.id, &request.hazards);
119
120 if !global_blocked_hazards.is_empty() {
121 warn!(
122 "The {route} is skipped because it contains the global blocked hazards: {:?}",
123 global_blocked_hazards
124 );
125 skip = true;
126 }
127
128 if !local_blocked_hazards.is_empty() {
129 warn!(
130 "The {route} is skipped because the device contains the local blocked hazards: {:?}",
131 local_blocked_hazards
132 );
133 skip = true;
134 }
135
136 skip
137 }
138}
139
/// Manages a set of devices and applies a privacy policy to the requests sent
/// to them.
#[derive(Debug, PartialEq)]
pub struct Controller {
    // Used by `discover` to (re)populate `devices`.
    discovery: Discovery,
    // Device list; a device's index here is its identifier.
    devices: Devices,
    // Policy evaluated when a request sender is built for a hazardous request.
    privacy_policy: Policy,
}
155
156impl Controller {
157 #[must_use]
159 #[inline]
160 pub fn new(discovery: Discovery) -> Self {
161 Self {
162 discovery,
163 devices: Devices::new(),
164 privacy_policy: Policy::init(),
165 }
166 }
167
168 #[must_use]
173 #[inline]
174 pub fn from_devices(discovery: Discovery, devices: Devices) -> Self {
175 Self {
176 discovery,
177 devices,
178 privacy_policy: Policy::init(),
179 }
180 }
181
182 #[must_use]
184 #[inline]
185 pub fn policy(mut self, privacy_policy: Policy) -> Self {
186 self.privacy_policy = privacy_policy;
187 self
188 }
189
190 #[inline]
192 pub fn change_policy(&mut self, privacy_policy: Policy) {
193 self.privacy_policy = privacy_policy;
194 }
195
196 #[inline]
214 pub async fn discover(&mut self) -> Result<(), Error> {
215 self.devices = self.discovery.discover().await?;
216 Ok(())
217 }
218
219 pub async fn start_event_receivers(
240 &mut self,
241 buffer_size: usize,
242 ) -> Result<Receiver<EventPayload>, Error> {
243 let (tx, rx) = mpsc::channel(buffer_size);
244
245 let mut started_count = 0;
246 for (id, device) in self.devices.iter_mut().enumerate() {
247 if device.event_handle.is_some() {
248 warn!("Skip device with id `{id}`: event receiver already started");
249 continue;
250 }
251
252 let Some(ref events) = device.events else {
253 warn!("Skip device with id `{id}`: it does not support events");
254 continue;
255 };
256
257 EventsRunner::run_global_subscriber(events, id, tx.clone()).await?;
258
259 started_count += 1;
260 }
261
262 if started_count == 0 {
263 return Err(Error::new(
264 ErrorKind::Events,
265 "No event receiver tasks has started",
266 ));
267 }
268
269 Ok(rx)
270 }
271
272 #[must_use]
274 pub const fn devices(&self) -> &Devices {
275 &self.devices
276 }
277
278 #[must_use]
280 pub const fn devices_mut(&mut self) -> &mut Devices {
281 &mut self.devices
282 }
283
284 pub fn device(&self, id: usize) -> Result<DeviceSender<'_>, Error> {
291 if self.devices.is_empty() {
292 return Err(sender_error("No devices found."));
293 }
294
295 let device = self.devices.get(id).ok_or_else(|| {
296 sender_error(format!(
297 "Error in retrieving the device with identifier {id}."
298 ))
299 })?;
300 Ok(DeviceSender {
301 controller: self,
302 device,
303 id,
304 })
305 }
306
307 pub async fn shutdown(self) {
315 for device in self.devices {
317 if let Some(events) = device.events {
318 events.cancellation_token.cancel();
320 }
321
322 if let Some(event_handle) = device.event_handle {
323 if let Err(e) = event_handle.await {
325 error!("Failed to await the event task: {e}");
326 }
327 }
328 }
329 }
330}
331
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use tracing::warn;

    use tosca::hazards::{Hazard, Hazards};
    use tosca::parameters::ParametersValues;
    use tosca::response::{OkResponse, SerialResponse};

    use serde::{Serialize, de::DeserializeOwned};
    use serde_json::json;

    use serial_test::serial;

    use crate::device::Devices;
    use crate::error::Error;
    use crate::policy::Policy;
    use crate::response::Response;

    use crate::device::tests::{create_light, create_unknown};
    use crate::discovery::tests::configure_discovery;
    use crate::tests::{Brightness, check_function_with_device};

    use super::{Controller, DeviceSender, RequestSender, sender_error};

    /// A freshly created controller has no devices and the default policy,
    /// so every device lookup fails.
    #[test]
    fn empty_controller() {
        let controller = Controller::new(configure_discovery());

        assert_eq!(
            controller,
            Controller {
                discovery: configure_discovery(),
                devices: Devices::new(),
                privacy_policy: Policy::init(),
            }
        );

        assert_eq!(controller.device(0), Err(sender_error("No devices found.")));
    }

    /// `Controller::from_devices` keeps the given devices and uses the
    /// default policy.
    #[test]
    fn controller_from_devices() {
        let devices = Devices::from_devices(vec![create_light(), create_unknown()]);

        let controller = Controller::from_devices(configure_discovery(), devices);

        assert_eq!(
            controller,
            Controller {
                discovery: configure_discovery(),
                devices: Devices::from_devices(vec![create_light(), create_unknown()]),
                privacy_policy: Policy::init(),
            }
        );
    }

    /// Checks an OK response for `route` sent without parameters.
    async fn check_ok_response_plain(device_sender: &DeviceSender<'_>, route: &str) {
        check_ok_response(device_sender, route, async move |request_sender| {
            request_sender.send().await
        })
        .await;
    }

    /// Checks an OK response for `route` sent with `parameters`.
    async fn check_ok_response_with_parameters(
        device_sender: &DeviceSender<'_>,
        route: &str,
        parameters: &ParametersValues<'_>,
    ) {
        check_ok_response(device_sender, route, async move |request_sender| {
            request_sender.send_with_parameters(parameters).await
        })
        .await;
    }

    /// Builds a request sender for `route`, obtains a response through
    /// `get_response` and expects either an `OkBody` whose parsed body equals
    /// `OkResponse::ok()`, or `Response::Skipped`.
    ///
    /// NOTE(review): `Skipped` is accepted whatever the active policy is, so
    /// this helper cannot detect an unexpected (or missing) skip.
    async fn check_ok_response<'controller, 'a, F>(
        device_sender: &'a DeviceSender<'controller>,
        route: &'a str,
        get_response: F,
    ) where
        F: AsyncFnOnce(RequestSender<'controller>) -> Result<Response, Error>,
        'a: 'controller,
    {
        let request = device_sender.request(route).unwrap();

        let response = get_response(request).await.unwrap();
        if let Response::OkBody(response) = response {
            let ok_response = response.parse_body().await.unwrap();
            assert_eq!(ok_response, OkResponse::ok());
        } else {
            assert!(
                matches!(response, Response::Skipped),
                "Should be a blocked global `LogEnergyConsumption` for `/off` request"
            );
        }
    }

    /// Checks a serial response carrying `value` for `route` sent without
    /// parameters.
    async fn check_serial_response_plain<T: Serialize + DeserializeOwned + Debug + PartialEq>(
        device_sender: &DeviceSender<'_>,
        route: &str,
        value: T,
    ) {
        check_serial_response(
            device_sender,
            route,
            async move |request_sender| request_sender.send().await,
            value,
        )
        .await;
    }

    /// Checks a serial response carrying `value` for `route` sent with
    /// `parameters`.
    async fn check_serial_response_with_parameters<
        T: Serialize + DeserializeOwned + Debug + PartialEq,
    >(
        device_sender: &DeviceSender<'_>,
        route: &str,
        parameters: &ParametersValues<'_>,
        value: T,
    ) {
        check_serial_response(
            device_sender,
            route,
            async move |request| request.send_with_parameters(parameters).await,
            value,
        )
        .await;
    }

    /// Builds a request sender for `route`, obtains a response through
    /// `get_response` and expects either a `SerialBody` whose parsed body
    /// equals `SerialResponse::new(value)`, or `Response::Skipped`.
    ///
    /// NOTE(review): as with `check_ok_response`, `Skipped` is accepted
    /// regardless of the active policy.
    async fn check_serial_response<'controller, 'a, F, T>(
        device: &'a DeviceSender<'controller>,
        route: &'a str,
        get_response: F,
        value: T,
    ) where
        F: AsyncFnOnce(RequestSender<'controller>) -> Result<Response, Error>,
        T: Serialize + DeserializeOwned + Debug + PartialEq,
        'a: 'controller,
    {
        let request = device.request(route).unwrap();

        let response = get_response(request).await.unwrap();
        if let Response::SerialBody(response) = response {
            let serial_response = response.parse_body::<T>().await.unwrap();
            assert_eq!(serial_response, SerialResponse::new(value));
        } else {
            assert!(
                matches!(response, Response::Skipped),
                "Should be a blocked local `FireHazard` for `/toggle` request"
            );
        }
    }

    /// Shared assertions for a controller whose device list holds exactly one
    /// device (index 0) exposing `/on`, `/off` and `/toggle`.
    async fn controller_checks(controller: Controller) {
        // Only index 0 exists.
        assert_eq!(
            controller.device(1),
            Err(sender_error(
                "Error in retrieving the device with identifier 1."
            ))
        );

        let device_sender = controller.device(0).unwrap();

        // Unknown routes are reported as sender errors.
        assert_eq!(
            device_sender.request("/wrong"),
            Err(sender_error(
                "Error in retrieving the request with route `/wrong`."
            ))
        );

        check_ok_response_plain(&device_sender, "/on").await;

        check_ok_response_plain(&device_sender, "/off").await;

        // Without parameters `/toggle` reports a brightness of 0.
        check_serial_response_plain(
            &device_sender,
            "/toggle",
            json!({
                "brightness": 0,
            }),
        )
        .await;

        let mut parameters = ParametersValues::new();
        parameters.u64("brightness", 5);

        check_ok_response_with_parameters(&device_sender, "/on", &parameters).await;

        check_ok_response_with_parameters(&device_sender, "/off", &parameters).await;

        // With `brightness = 5`, `/toggle` echoes that brightness back.
        check_serial_response_with_parameters(
            &device_sender,
            "/toggle",
            &parameters,
            Brightness { brightness: 5 },
        )
        .await;
    }

    /// Discovers devices with the default policy and runs the shared checks.
    #[inline]
    async fn controller_without_policy() {
        let mut controller = Controller::new(configure_discovery());

        controller.discover().await.unwrap();

        controller_checks(controller).await;
    }

    /// Discovers devices with `LogEnergyConsumption` blocked globally and
    /// `FireHazard` blocked for device 0, then runs the shared checks.
    #[inline]
    async fn controller_with_policy() {
        let global_hazards = Hazards::new().insert(Hazard::LogEnergyConsumption);

        let local_hazards = Hazards::new().insert(Hazard::FireHazard);

        let policy = Policy::new(global_hazards).block_device_on_hazards(0, local_hazards);

        let mut controller = Controller::new(configure_discovery()).policy(policy);

        controller.discover().await.unwrap();

        controller_checks(controller).await;
    }

    /// Runs `function` through `check_function_with_device`, unless the crate
    /// was compiled with the `CI` environment variable set (`option_env!` is
    /// evaluated at compile time). In that case it only logs a warning,
    /// because discovery needs physical MAC addresses.
    #[inline]
    async fn run_controller_function<F, Fut>(name: &str, function: F)
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = ()>,
    {
        if option_env!("CI").is_some() {
            warn!(
                "Skipping test on CI: {} can run only on systems that expose physical MAC addresses.",
                name
            );
        } else {
            check_function_with_device(|| async {
                function().await;
            })
            .await;
        }
    }

    // `#[serial]` keeps the device-backed tests from running concurrently.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[serial]
    async fn test_without_policy_controller() {
        run_controller_function("controller_without_policy", || async {
            controller_without_policy().await;
        })
        .await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[serial]
    async fn test_with_policy_controller() {
        run_controller_function("controller_with_policy", || async {
            controller_with_policy().await;
        })
        .await;
    }
}