Skip to main content

lava_api/
device.rs

1//! Retrieve devices
2
3use futures::FutureExt;
4use futures::future::BoxFuture;
5use futures::{stream, stream::Stream, stream::StreamExt};
6use serde::Deserialize;
7use serde_with::DeserializeFromStr;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use strum::{Display, EnumString};
11
12use crate::Lava;
13use crate::paginator::{PaginationError, Paginator};
14use crate::tag::Tag;
15
16/// The current status of a [`Device`]
17#[derive(Clone, Copy, Debug, DeserializeFromStr, Display, EnumString, Eq, PartialEq)]
18pub enum Health {
19    Unknown,
20    Maintenance,
21    Good,
22    Bad,
23    Looping,
24    Retired,
25}
26
27#[derive(Clone, Deserialize, Debug)]
28struct LavaDevice {
29    hostname: String,
30    worker_host: String,
31    device_type: String,
32    description: Option<String>,
33    health: Health,
34    pub tags: Vec<u32>,
35}
36
37/// A subset of the data available for a device from the LAVA API.
38///
39/// Note that [`tags`](Device::tags) have been resolved into [`Tag`]
40/// objects, rather than tag ids.
41#[derive(Clone, Debug, PartialEq, Eq)]
42pub struct Device {
43    pub hostname: String,
44    pub worker_host: String,
45    pub device_type: String,
46    pub description: Option<String>,
47    pub health: Health,
48    pub tags: Vec<Tag>,
49}
50
51enum State<'a> {
52    Paging,
53    Transforming(BoxFuture<'a, Device>),
54}
55
56/// A [`Stream`] that yields all the [`Device`] instances on a LAVA
57/// server.
58pub struct Devices<'a> {
59    lava: &'a Lava,
60    paginator: Paginator<LavaDevice>,
61    state: State<'a>,
62}
63
64impl<'a> Devices<'a> {
65    /// Create a new stream, using the given [`Lava`] proxy.
66    ///
67    /// Note that due to pagination, the dataset returned is not
68    /// guaranteed to be self-consistent, and the odds of
69    /// self-consistency decrease the longer it takes to iterate over
70    /// the stream. It is therefore advisable to extract whatever data
71    /// is required immediately after the creation of this object.
72    pub fn new(lava: &'a Lava) -> Self {
73        let url = lava
74            .base
75            .join("devices/?ordering=hostname")
76            .expect("Failed to append to base url");
77        let paginator = Paginator::new(lava.client.clone(), url);
78        Self {
79            lava,
80            paginator,
81            state: State::Paging,
82        }
83    }
84}
85
86async fn transform_device(device: LavaDevice, lava: &Lava) -> Device {
87    let t = stream::iter(device.tags.iter());
88    let tags = t
89        .filter_map(|i| async move { lava.tag(*i).await })
90        .collect()
91        .await;
92
93    Device {
94        hostname: device.hostname,
95        worker_host: device.worker_host,
96        device_type: device.device_type,
97        description: device.description,
98        health: device.health,
99        tags,
100    }
101}
102
103impl Stream for Devices<'_> {
104    type Item = Result<Device, PaginationError>;
105
106    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
107        let me = self.get_mut();
108
109        loop {
110            return match &mut me.state {
111                State::Paging => {
112                    let p = Pin::new(&mut me.paginator);
113                    match p.poll_next(cx) {
114                        Poll::Ready(None) => Poll::Ready(None),
115                        Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
116                        Poll::Ready(Some(Ok(d))) => {
117                            me.state = State::Transforming(transform_device(d, me.lava).boxed());
118                            continue;
119                        }
120                        Poll::Pending => Poll::Pending,
121                    }
122                }
123                State::Transforming(fut) => match fut.as_mut().poll(cx) {
124                    Poll::Ready(d) => {
125                        me.state = State::Paging;
126                        Poll::Ready(Some(Ok(d)))
127                    }
128                    Poll::Pending => Poll::Pending,
129                },
130            };
131        }
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::{Device, Health, Tag};
    use crate::Lava;

    use boulder::{Buildable, Builder};
    use futures::TryStreamExt;
    use lava_api_mock::{
        Device as MockDevice, DeviceHealth as MockDeviceHealth, DeviceType as MockDeviceType,
        LavaMock, PaginationLimits, PopulationParams, SharedState, State, Tag as MockTag,
        Worker as MockWorker,
    };
    use persian_rug::{Accessor, Context};
    use std::collections::{BTreeMap, BTreeSet};
    use test_log::test;

    impl From<MockDeviceHealth> for Health {
        fn from(dev: MockDeviceHealth) -> Health {
            use Health::*;
            match dev {
                MockDeviceHealth::Unknown => Unknown,
                MockDeviceHealth::Maintenance => Maintenance,
                MockDeviceHealth::Good => Good,
                MockDeviceHealth::Bad => Bad,
                MockDeviceHealth::Looping => Looping,
                MockDeviceHealth::Retired => Retired,
            }
        }
    }

    impl Device {
        #[persian_rug::constraints(context = C, access(MockTag<C>, MockDeviceType<C>, MockWorker<C>))]
        pub fn from_mock<'b, B, C>(dev: &MockDevice<C>, context: B) -> Device
        where
            B: 'b + Accessor<Context = C>,
            C: Context + 'static,
        {
            Self {
                hostname: dev.hostname.clone(),
                worker_host: context.get(&dev.worker_host).hostname.clone(),
                device_type: context.get(&dev.device_type).name.clone(),
                description: dev.description.clone(),
                health: dev.health.clone().into(),
                tags: dev
                    .tags
                    .iter()
                    .map(|t| Tag::from_mock(context.get(t), context.clone()))
                    .collect::<Vec<_>>(),
            }
        }
    }

    /// Stream 50 devices from a mock server that serves at most 5 per
    /// page. Check that every device is reported exactly once (i.e. that
    /// pagination is handled properly) and that its fields, including its
    /// resolved tags, match the mock's data.
    #[test(tokio::test)]
    async fn test_basic() {
        let state =
            SharedState::new_populated(PopulationParams::builder().devices(50usize).build());
        let server = LavaMock::new(
            state.clone(),
            PaginationLimits::builder().devices(Some(5)).build(),
        )
        .await;

        // Index the mock's devices by hostname so each streamed device can
        // be looked up directly.
        let mut map = BTreeMap::new();
        let start = state.access();
        for device in start.get_iter::<lava_api_mock::Device<State>>() {
            map.insert(device.hostname.clone(), device);
        }

        let lava = Lava::new(&server.uri(), None).expect("failed to make lava server");

        let mut ld = lava.devices();

        // Hostnames streamed so far; used to detect duplicates across pages.
        let mut seen: BTreeSet<String> = BTreeSet::new();
        while let Some(device) = ld.try_next().await.expect("failed to get device") {
            assert!(
                !seen.contains(&device.hostname),
                "device {} streamed twice",
                device.hostname
            );
            let dev = map
                .get(&device.hostname)
                .expect("streamed device is unknown to the mock server");
            assert_eq!(device.hostname, dev.hostname);
            assert_eq!(device.worker_host, start.get(&dev.worker_host).hostname);
            assert_eq!(device.device_type, start.get(&dev.device_type).name);
            assert_eq!(device.description, dev.description);
            assert_eq!(device.health.to_string(), dev.health.to_string());

            assert_eq!(device.tags.len(), dev.tags.len());
            for (tag, mock_tag) in device.tags.iter().zip(dev.tags.iter()) {
                let expected = start.get(mock_tag);
                assert_eq!(tag.id, expected.id);
                assert_eq!(tag.name, expected.name);
                assert_eq!(tag.description, expected.description);
            }

            seen.insert(device.hostname);
        }
        assert_eq!(seen.len(), 50);
    }
}