lava_api/
device.rs

//! Retrieve devices from a LAVA server.
2
3use futures::future::BoxFuture;
4use futures::FutureExt;
5use futures::{stream, stream::Stream, stream::StreamExt};
6use serde::Deserialize;
7use serde_with::DeserializeFromStr;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use strum::{Display, EnumString};
11
12use crate::paginator::{PaginationError, Paginator};
13use crate::tag::Tag;
14use crate::Lava;
15
16/// The current status of a [`Device`]
17#[derive(Clone, Copy, Debug, DeserializeFromStr, Display, EnumString, Eq, PartialEq)]
18pub enum Health {
19    Unknown,
20    Maintenance,
21    Good,
22    Bad,
23    Looping,
24    Retired,
25}
26
27#[derive(Clone, Deserialize, Debug)]
28struct LavaDevice {
29    hostname: String,
30    worker_host: String,
31    device_type: String,
32    description: Option<String>,
33    health: Health,
34    pub tags: Vec<u32>,
35}
36
37/// A subset of the data available for a device from the LAVA API.
38///
39/// Note that [`tags`](Device::tags) have been resolved into [`Tag`]
40/// objects, rather than tag ids.
41#[derive(Clone, Debug, PartialEq, Eq)]
42pub struct Device {
43    pub hostname: String,
44    pub worker_host: String,
45    pub device_type: String,
46    pub description: Option<String>,
47    pub health: Health,
48    pub tags: Vec<Tag>,
49}
50
51enum State<'a> {
52    Paging,
53    Transforming(BoxFuture<'a, Device>),
54}
55
56/// A [`Stream`] that yields all the [`Device`] instances on a LAVA
57/// server.
58pub struct Devices<'a> {
59    lava: &'a Lava,
60    paginator: Paginator<LavaDevice>,
61    state: State<'a>,
62}
63
64impl<'a> Devices<'a> {
65    /// Create a new stream, using the given [`Lava`] proxy.
66    ///
67    /// Note that due to pagination, the dataset returned is not
68    /// guaranteed to be self-consistent, and the odds of
69    /// self-consistency decrease the longer it takes to iterate over
70    /// the stream. It is therefore advisable to extract whatever data
71    /// is required immediately after the creation of this object.
72    pub fn new(lava: &'a Lava) -> Self {
73        let url = lava
74            .base
75            .join("devices/?ordering=hostname")
76            .expect("Failed to append to base url");
77        let paginator = Paginator::new(lava.client.clone(), url);
78        Self {
79            lava,
80            paginator,
81            state: State::Paging,
82        }
83    }
84}
85
86async fn transform_device(device: LavaDevice, lava: &Lava) -> Device {
87    let t = stream::iter(device.tags.iter());
88    let tags = t
89        .filter_map(|i| async move { lava.tag(*i).await })
90        .collect()
91        .await;
92
93    Device {
94        hostname: device.hostname,
95        worker_host: device.worker_host,
96        device_type: device.device_type,
97        description: device.description,
98        health: device.health,
99        tags,
100    }
101}
102
103impl<'a> Stream for Devices<'a> {
104    type Item = Result<Device, PaginationError>;
105
106    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
107        let me = self.get_mut();
108
109        loop {
110            return match &mut me.state {
111                State::Paging => {
112                    let p = Pin::new(&mut me.paginator);
113                    match p.poll_next(cx) {
114                        Poll::Ready(None) => Poll::Ready(None),
115                        Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
116                        Poll::Ready(Some(Ok(d))) => {
117                            me.state = State::Transforming(transform_device(d, me.lava).boxed());
118                            continue;
119                        }
120                        Poll::Pending => Poll::Pending,
121                    }
122                }
123                State::Transforming(fut) => match fut.as_mut().poll(cx) {
124                    Poll::Ready(d) => {
125                        me.state = State::Paging;
126                        Poll::Ready(Some(Ok(d)))
127                    }
128                    Poll::Pending => Poll::Pending,
129                },
130            };
131        }
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::{Device, Health, Tag};
    use crate::Lava;

    use boulder::{Buildable, Builder};
    use futures::TryStreamExt;
    use lava_api_mock::{
        Device as MockDevice, DeviceHealth as MockDeviceHealth, DeviceType as MockDeviceType,
        LavaMock, PaginationLimits, PopulationParams, SharedState, State, Tag as MockTag,
        Worker as MockWorker,
    };
    use persian_rug::{Accessor, Context};
    use std::collections::{BTreeMap, BTreeSet};
    use std::convert::{Infallible, TryFrom, TryInto};
    use test_log::test;

    /// Map the mock server's health values onto [`Health`]. Every mock
    /// variant has a counterpart, so the conversion cannot fail.
    impl TryFrom<MockDeviceHealth> for Health {
        type Error = Infallible;
        fn try_from(dev: MockDeviceHealth) -> Result<Health, Self::Error> {
            use Health::*;
            match dev {
                MockDeviceHealth::Unknown => Ok(Unknown),
                MockDeviceHealth::Maintenance => Ok(Maintenance),
                MockDeviceHealth::Good => Ok(Good),
                MockDeviceHealth::Bad => Ok(Bad),
                MockDeviceHealth::Looping => Ok(Looping),
                MockDeviceHealth::Retired => Ok(Retired),
            }
        }
    }

    impl Device {
        /// Build the [`Device`] the client is expected to produce for a
        /// given mock device, resolving its worker, device type and tags
        /// through `context`.
        #[persian_rug::constraints(context = C, access(MockTag<C>, MockDeviceType<C>, MockWorker<C>))]
        pub fn from_mock<'b, B, C>(dev: &MockDevice<C>, context: B) -> Device
        where
            B: 'b + Accessor<Context = C>,
            C: Context + 'static,
        {
            Self {
                hostname: dev.hostname.clone(),
                worker_host: context.get(&dev.worker_host).hostname.clone(),
                device_type: context.get(&dev.device_type).name.clone(),
                description: dev.description.clone(),
                health: dev.health.clone().try_into().unwrap(),
                tags: dev
                    .tags
                    .iter()
                    .map(|t| Tag::from_mock(context.get(t), context.clone()))
                    .collect::<Vec<_>>(),
            }
        }
    }

    /// Stream 50 devices from the server with a page limit of 5. Check
    /// that their tags are reconstructed correctly and that every device
    /// is yielded exactly once, i.e. that pagination is handled properly.
    #[test(tokio::test)]
    async fn test_basic() {
        let state =
            SharedState::new_populated(PopulationParams::builder().devices(50usize).build());
        let server = LavaMock::new(
            state.clone(),
            PaginationLimits::builder().devices(Some(5)).build(),
        )
        .await;

        let start = state.access();
        // Index the server's devices by hostname so each streamed device can
        // be matched with a single lookup.
        let map: BTreeMap<_, _> = start
            .get_iter::<MockDevice<State>>()
            .map(|device| (device.hostname.clone(), device))
            .collect();

        let lava = Lava::new(&server.uri(), None).expect("failed to make lava server");

        let mut ld = lava.devices();

        let mut seen = BTreeSet::new();
        while let Some(device) = ld.try_next().await.expect("failed to get device") {
            let first_sighting = seen.insert(device.hostname.clone());
            assert!(first_sighting, "device {} yielded twice", device.hostname);

            let dev = map
                .get(&device.hostname)
                .unwrap_or_else(|| panic!("unexpected device {}", device.hostname));

            assert_eq!(device.hostname, dev.hostname);
            assert_eq!(device.worker_host, start.get(&dev.worker_host).hostname);
            assert_eq!(device.device_type, start.get(&dev.device_type).name);
            assert_eq!(device.description, dev.description);
            assert_eq!(
                device.health,
                Health::try_from(dev.health.clone()).unwrap()
            );

            assert_eq!(device.tags.len(), dev.tags.len());
            for (tag, mock_ref) in device.tags.iter().zip(dev.tags.iter()) {
                let expected = start.get(mock_ref);
                assert_eq!(tag.id, expected.id);
                assert_eq!(tag.name, expected.name);
                assert_eq!(tag.description, expected.description);
            }
        }
        assert_eq!(seen.len(), 50);
    }
}