tosca_controller/
controller.rs

1use std::borrow::Cow;
2
3use tosca::parameters::ParametersValues;
4
5use tokio::sync::mpsc::{self, Receiver};
6
7use tracing::{error, warn};
8
9use crate::device::{Device, Devices};
10use crate::discovery::Discovery;
11use crate::error::{Error, ErrorKind};
12use crate::events::{EventPayload, EventsRunner};
13use crate::policy::Policy;
14use crate::request::Request;
15use crate::response::Response;
16
// TODO: Use the device MAC address as its identifier instead of its index in `Devices`.
18
19fn sender_error(error: impl Into<Cow<'static, str>>) -> Error {
20    Error::new(ErrorKind::Sender, error)
21}
22
/// A request sender.
///
/// It is built by [`DeviceSender::request`] for a single route of a device.
/// It borrows both the [`Controller`] and the [`Request`] it sends, so it
/// cannot outlive them.
#[derive(Debug, PartialEq)]
pub struct RequestSender<'controller> {
    /// Controller this sender was built from.
    // NOTE(review): not read by any `RequestSender` method in this file.
    controller: &'controller Controller,
    /// Description of the request to send, taken from the device.
    request: &'controller Request,
    /// `true` when the privacy policy blocks at least one hazard of the
    /// request; this flag is forwarded to `Request::retrieve_response`.
    skip: bool,
}
30
31impl RequestSender<'_> {
32    /// Sends a request to a device and returns a [`Response`].
33    ///
34    /// # Errors
35    ///
36    /// Network failures or timeouts may prevent the request from being sent
37    /// and affect the returned response as well.
38    pub async fn send(&self) -> Result<Response, Error> {
39        self.request
40            .retrieve_response(self.skip, || async { self.request.plain_send().await })
41            .await
42    }
43
44    /// Sends a request to a device with the given [`ParametersValues`]
45    /// and returns a [`Response`].
46    ///
47    /// # Errors
48    ///
49    /// Network failures or timeouts may prevent the request from being sent
50    /// and affect the returned response as well.
51    pub async fn send_with_parameters(
52        &self,
53        parameters: &ParametersValues<'_>,
54    ) -> Result<Response, Error> {
55        if self.request.parameters_data.is_empty() {
56            warn!("The request does not have input parameters.");
57            return self.send().await;
58        }
59
60        self.request
61            .retrieve_response(self.skip, || async {
62                self.request.create_response(parameters).await
63            })
64            .await
65    }
66}
67
/// A device sender.
///
/// It is bound to a single [`Device`] of a [`Controller`]. It builds a
/// [`RequestSender`] for each route requested through
/// [`DeviceSender::request`].
#[derive(Debug, PartialEq)]
pub struct DeviceSender<'controller> {
    /// Controller owning the device and the privacy policy.
    controller: &'controller Controller,
    /// Device the generated requests belong to.
    device: &'controller Device,
    /// Identifier of the device, i.e. its index in the controller [`Devices`].
    /// It is used to look up the device-specific (local) blocked hazards.
    id: usize,
}
77
78impl DeviceSender<'_> {
79    /// Builds a [`RequestSender`] for the given route.
80    ///
81    /// The generated request sender is tightly bound to the device sender and
82    /// cannot function independently.
83    ///
84    /// # Errors
85    ///
86    /// An error is returned if the given route **does** not exist.
87    pub fn request(&self, route: &str) -> Result<RequestSender<'_>, Error> {
88        let request = self.device.request(route).ok_or_else(|| {
89            sender_error(format!(
90                "Error in retrieving the request with route `{route}`."
91            ))
92        })?;
93
94        let skip = if request.hazards.is_empty() {
95            false
96        } else {
97            self.evaluate_privacy_policy(request, route)
98        };
99
100        Ok(RequestSender {
101            controller: self.controller,
102            request,
103            skip,
104        })
105    }
106
107    fn evaluate_privacy_policy(&self, request: &Request, route: &str) -> bool {
108        let mut skip = false;
109
110        let global_blocked_hazards = self
111            .controller
112            .privacy_policy
113            .global_blocked_hazards(&request.hazards);
114
115        let local_blocked_hazards = self
116            .controller
117            .privacy_policy
118            .local_blocked_hazards(self.id, &request.hazards);
119
120        if !global_blocked_hazards.is_empty() {
121            warn!(
122                "The {route} is skipped because it contains the global blocked hazards: {:?}",
123                global_blocked_hazards
124            );
125            skip = true;
126        }
127
128        if !local_blocked_hazards.is_empty() {
129            warn!(
130                "The {route} is skipped because the device contains the local blocked hazards: {:?}",
131                local_blocked_hazards
132            );
133            skip = true;
134        }
135
136        skip
137    }
138}
139
/// A controller for interacting with `tosca` devices.
///
/// The main functionalities include:
///
/// - Discovering `tosca` devices on the network and registering them in memory.
/// - Sending requests to a specific device identified by its ID, awaiting a
///   response, and forwarding it directly to the caller.
/// - Controlling request sending by allowing or blocking requests based on the
///   defined privacy policy.
#[derive(Debug, PartialEq)]
pub struct Controller {
    /// Discovery configuration used by [`Controller::discover`].
    discovery: Discovery,
    /// Registered devices; a device ID is its index in this collection.
    devices: Devices,
    /// Privacy policy consulted before sending requests that declare hazards.
    privacy_policy: Policy,
}
155
156impl Controller {
157    /// Creates a [`Controller`] from a [`Discovery`] configuration.
158    #[must_use]
159    #[inline]
160    pub fn new(discovery: Discovery) -> Self {
161        Self {
162            discovery,
163            devices: Devices::new(),
164            privacy_policy: Policy::init(),
165        }
166    }
167
168    /// Creates a [`Controller`] from a [`Discovery`] configuration and
169    /// an initial set of [`Devices`].
170    ///
171    /// This method is useful when [`Devices`] are retrieved from database.
172    #[must_use]
173    #[inline]
174    pub fn from_devices(discovery: Discovery, devices: Devices) -> Self {
175        Self {
176            discovery,
177            devices,
178            privacy_policy: Policy::init(),
179        }
180    }
181
182    /// Defines a [`Policy`] while constructing a [`Controller`].
183    #[must_use]
184    #[inline]
185    pub fn policy(mut self, privacy_policy: Policy) -> Self {
186        self.privacy_policy = privacy_policy;
187        self
188    }
189
190    /// Changes the [`Policy`].
191    #[inline]
192    pub fn change_policy(&mut self, privacy_policy: Policy) {
193        self.privacy_policy = privacy_policy;
194    }
195
196    /// Discovers all available [`Devices`] on the network.
197    ///
198    /// # Errors
199    ///
200    /// ## Discovery Errors
201    ///
202    /// During the discovery process, common errors include:
203    ///
204    /// - Inability to connect to the network
205    /// - Failure to disable a particular network interface
206    /// - Issues terminating the discovery process.
207    ///
208    /// ## Sending Requests Errors
209    ///
210    /// When sending a request to a device to retrieve its structure description
211    /// and routes, network failures or timeouts may prevent the request from
212    /// being sent and affect the returned response as well.
213    #[inline]
214    pub async fn discover(&mut self) -> Result<(), Error> {
215        self.devices = self.discovery.discover().await?;
216        Ok(())
217    }
218
219    /// Starts asynchronous event receiver tasks for all [`Device`]s that
220    /// support events.
221    ///
222    /// An event receiver task connects to the broker of a device
223    /// and subscribes to its topic.
224    /// When a device transmits an event to the broker, the task retrieves the
225    /// event payload from the broker, parses the data, and sends the relevant
226    /// content to the [`Receiver`] returned by this method.
227    ///
228    /// The `buffer_size` parameter specifies how many messages the event
229    /// receiver buffer can hold.
230    /// When the buffer is full, subsequent send attempts will wait until
231    /// a message is consumed from the channel.
232    ///
233    /// When the [`Receiver`] is dropped, all tasks terminate automatically.
234    ///
235    /// # Errors
236    ///
237    /// - No event receiver tasks has started
238    /// - An error occurred while subscribing to the broker topic of a device.
239    pub async fn start_event_receivers(
240        &mut self,
241        buffer_size: usize,
242    ) -> Result<Receiver<EventPayload>, Error> {
243        let (tx, rx) = mpsc::channel(buffer_size);
244
245        let mut started_count = 0;
246        for (id, device) in self.devices.iter_mut().enumerate() {
247            if device.event_handle.is_some() {
248                warn!("Skip device with id `{id}`: event receiver already started");
249                continue;
250            }
251
252            let Some(ref events) = device.events else {
253                warn!("Skip device with id `{id}`: it does not support events");
254                continue;
255            };
256
257            EventsRunner::run_global_subscriber(events, id, tx.clone()).await?;
258
259            started_count += 1;
260        }
261
262        if started_count == 0 {
263            return Err(Error::new(
264                ErrorKind::Events,
265                "No event receiver tasks has started",
266            ));
267        }
268
269        Ok(rx)
270    }
271
272    /// Returns an immutable reference to [`Devices`].
273    #[must_use]
274    pub const fn devices(&self) -> &Devices {
275        &self.devices
276    }
277
278    /// Returns a mutable reference to [`Devices`].
279    #[must_use]
280    pub const fn devices_mut(&mut self) -> &mut Devices {
281        &mut self.devices
282    }
283
284    /// Builds a [`DeviceSender`] for the [`Device`] with the given identifier.
285    ///
286    /// # Errors
287    ///
288    /// An error is returned if no devices are found or if the given index
289    /// **does** not exist.
290    pub fn device(&self, id: usize) -> Result<DeviceSender<'_>, Error> {
291        if self.devices.is_empty() {
292            return Err(sender_error("No devices found."));
293        }
294
295        let device = self.devices.get(id).ok_or_else(|| {
296            sender_error(format!(
297                "Error in retrieving the device with identifier {id}."
298            ))
299        })?;
300        Ok(DeviceSender {
301            controller: self,
302            device,
303            id,
304        })
305    }
306
307    /// Shuts down the [`Controller`], stopping all asynchronous tasks and
308    /// releasing all associated resources.
309    ///
310    /// # Note
311    ///
312    /// For a graceful shutdown, this method must be called before dropping
313    /// the [`Controller`].
314    pub async fn shutdown(self) {
315        // Stop all events tasks.
316        for device in self.devices {
317            if let Some(events) = device.events {
318                // Stop the infinite loop
319                events.cancellation_token.cancel();
320            }
321
322            if let Some(event_handle) = device.event_handle {
323                // Await the task.
324                if let Err(e) = event_handle.await {
325                    error!("Failed to await the event task: {e}");
326                }
327            }
328        }
329    }
330}
331
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use tracing::warn;

    use tosca::hazards::{Hazard, Hazards};
    use tosca::parameters::ParametersValues;
    use tosca::response::{OkResponse, SerialResponse};

    use serde::{Serialize, de::DeserializeOwned};
    use serde_json::json;

    use serial_test::serial;

    use crate::device::Devices;
    use crate::error::Error;
    use crate::policy::Policy;
    use crate::response::Response;

    use crate::device::tests::{create_light, create_unknown};
    use crate::discovery::tests::configure_discovery;
    use crate::tests::{Brightness, check_function_with_device};

    use super::{Controller, DeviceSender, RequestSender, sender_error};

    #[test]
    fn empty_controller() {
        let controller = Controller::new(configure_discovery());

        assert_eq!(
            controller,
            Controller {
                discovery: configure_discovery(),
                devices: Devices::new(),
                privacy_policy: Policy::init(),
            }
        );

        // No devices.
        assert_eq!(controller.device(0), Err(sender_error("No devices found.")));
    }

    #[test]
    fn controller_from_devices() {
        let devices = Devices::from_devices(vec![create_light(), create_unknown()]);

        let controller = Controller::from_devices(configure_discovery(), devices);

        assert_eq!(
            controller,
            Controller {
                discovery: configure_discovery(),
                devices: Devices::from_devices(vec![create_light(), create_unknown()]),
                privacy_policy: Policy::init(),
            }
        );
    }

    /// Sends `route` without parameters and checks the `Ok` (or skipped) response.
    async fn check_ok_response_plain(device_sender: &DeviceSender<'_>, route: &str) {
        check_ok_response(device_sender, route, async move |request_sender| {
            request_sender.send().await
        })
        .await;
    }

    /// Sends `route` with `parameters` and checks the `Ok` (or skipped) response.
    async fn check_ok_response_with_parameters(
        device_sender: &DeviceSender<'_>,
        route: &str,
        parameters: &ParametersValues<'_>,
    ) {
        check_ok_response(device_sender, route, async move |request_sender| {
            request_sender.send_with_parameters(parameters).await
        })
        .await;
    }

    /// Builds the request sender for `route`, obtains its response through
    /// `get_response`, and checks that it is either an `Ok` body or a
    /// response skipped by the privacy policy.
    async fn check_ok_response<'controller, 'a, F>(
        device_sender: &'a DeviceSender<'controller>,
        route: &'a str,
        get_response: F,
    ) where
        F: AsyncFnOnce(RequestSender<'controller>) -> Result<Response, Error>,
        // Together with the implied `'controller: 'a` bound this makes the two
        // lifetimes equal, so the sender built from `device_sender` can be
        // handed to `get_response`.
        'a: 'controller,
    {
        let request = device_sender.request(route).unwrap();

        let response = get_response(request).await.unwrap();
        if let Response::OkBody(response) = response {
            let ok_response = response.parse_body().await.unwrap();
            assert_eq!(ok_response, OkResponse::ok());
        } else {
            assert!(
                matches!(response, Response::Skipped),
                "`{route}` should return an `Ok` body or be skipped by the privacy policy (e.g. `/off` blocked by the global `LogEnergyConsumption` hazard)"
            );
        }
    }

    /// Sends `route` without parameters and checks the serial (or skipped) response.
    async fn check_serial_response_plain<T: Serialize + DeserializeOwned + Debug + PartialEq>(
        device_sender: &DeviceSender<'_>,
        route: &str,
        value: T,
    ) {
        check_serial_response(
            device_sender,
            route,
            async move |request_sender| request_sender.send().await,
            value,
        )
        .await;
    }

    /// Sends `route` with `parameters` and checks the serial (or skipped) response.
    async fn check_serial_response_with_parameters<
        T: Serialize + DeserializeOwned + Debug + PartialEq,
    >(
        device_sender: &DeviceSender<'_>,
        route: &str,
        parameters: &ParametersValues<'_>,
        value: T,
    ) {
        check_serial_response(
            device_sender,
            route,
            async move |request_sender| request_sender.send_with_parameters(parameters).await,
            value,
        )
        .await;
    }

    /// Builds the request sender for `route`, obtains its response through
    /// `get_response`, and checks that it is either a serial body equal to
    /// `value` or a response skipped by the privacy policy.
    async fn check_serial_response<'controller, 'a, F, T>(
        device_sender: &'a DeviceSender<'controller>,
        route: &'a str,
        get_response: F,
        value: T,
    ) where
        F: AsyncFnOnce(RequestSender<'controller>) -> Result<Response, Error>,
        T: Serialize + DeserializeOwned + Debug + PartialEq,
        // See `check_ok_response`: this forces `'a == 'controller`.
        'a: 'controller,
    {
        let request = device_sender.request(route).unwrap();

        let response = get_response(request).await.unwrap();
        if let Response::SerialBody(response) = response {
            let serial_response = response.parse_body::<T>().await.unwrap();
            assert_eq!(serial_response, SerialResponse::new(value));
        } else {
            assert!(
                matches!(response, Response::Skipped),
                "`{route}` should return a serial body or be skipped by the privacy policy (e.g. `/toggle` blocked by the local `FireHazard` hazard)"
            );
        }
    }

    /// Runs the common request checks against the first discovered device.
    async fn controller_checks(controller: Controller) {
        // Wrong device id.
        assert_eq!(
            controller.device(1),
            Err(sender_error(
                "Error in retrieving the device with identifier 1."
            ))
        );

        // Get device.
        let device_sender = controller.device(0).unwrap();

        // Wrong request.
        assert_eq!(
            device_sender.request("/wrong"),
            Err(sender_error(
                "Error in retrieving the request with route `/wrong`."
            ))
        );

        // Run "/on" request and get "Ok" response.
        check_ok_response_plain(&device_sender, "/on").await;

        // Run "/off" request and get "Ok" response.
        check_ok_response_plain(&device_sender, "/off").await;

        // Run "/toggle" request and get "Ok" response.
        check_serial_response_plain(
            &device_sender,
            "/toggle",
            json!({
                "brightness": 0,
            }),
        )
        .await;

        // With parameters
        let mut parameters = ParametersValues::new();
        parameters.u64("brightness", 5);

        // Run "/on" request and get an "Ok" response with parameters.
        check_ok_response_with_parameters(&device_sender, "/on", &parameters).await;

        // Run "/off" request and get an "Ok" response with parameters.
        check_ok_response_with_parameters(&device_sender, "/off", &parameters).await;

        // Run "/toggle" request and get an "Ok" response with parameters.
        check_serial_response_with_parameters(
            &device_sender,
            "/toggle",
            &parameters,
            Brightness { brightness: 5 },
        )
        .await;
    }

    /// Discovers devices with the default policy and runs the checks.
    async fn controller_without_policy() {
        // Create a controller.
        let mut controller = Controller::new(configure_discovery());

        // Run discovery process.
        controller.discover().await.unwrap();

        // Run controller checks.
        controller_checks(controller).await;
    }

    /// Discovers devices with global and local blocked hazards and runs the checks.
    async fn controller_with_policy() {
        // Global blocked hazards.
        let global_hazards = Hazards::new().insert(Hazard::LogEnergyConsumption);

        // Local blocked hazards for a specific device.
        let local_hazards = Hazards::new().insert(Hazard::FireHazard);

        // Create both a global policy and a local one.
        let policy = Policy::new(global_hazards).block_device_on_hazards(0, local_hazards);

        // Create a controller.
        let mut controller = Controller::new(configure_discovery()).policy(policy);

        // Run discovery process.
        controller.discover().await.unwrap();

        // Run controller checks.
        controller_checks(controller).await;
    }

    /// Runs `function` alongside a test device, unless running on CI, where
    /// the test is skipped with a warning.
    async fn run_controller_function<F, Fut>(name: &str, function: F)
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = ()>,
    {
        if option_env!("CI").is_some() {
            warn!(
                "Skipping test on CI: {name} can run only on systems that expose physical MAC addresses."
            );
        } else {
            check_function_with_device(|| async {
                function().await;
            })
            .await;
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[serial]
    async fn test_without_policy_controller() {
        run_controller_function("controller_without_policy", || async {
            controller_without_policy().await;
        })
        .await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[serial]
    async fn test_with_policy_controller() {
        run_controller_function("controller_with_policy", || async {
            controller_with_policy().await;
        })
        .await;
    }
}