1use futures::future::BoxFuture;
4use futures::FutureExt;
5use futures::{stream, stream::Stream, stream::StreamExt};
6use serde::Deserialize;
7use serde_with::DeserializeFromStr;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use strum::{Display, EnumString};
11
12use crate::paginator::{PaginationError, Paginator};
13use crate::tag::Tag;
14use crate::Lava;
15
16#[derive(Clone, Copy, Debug, DeserializeFromStr, Display, EnumString, Eq, PartialEq)]
18pub enum Health {
19 Unknown,
20 Maintenance,
21 Good,
22 Bad,
23 Looping,
24 Retired,
25}
26
27#[derive(Clone, Deserialize, Debug)]
28struct LavaDevice {
29 hostname: String,
30 worker_host: String,
31 device_type: String,
32 description: Option<String>,
33 health: Health,
34 pub tags: Vec<u32>,
35}
36
37#[derive(Clone, Debug, PartialEq, Eq)]
42pub struct Device {
43 pub hostname: String,
44 pub worker_host: String,
45 pub device_type: String,
46 pub description: Option<String>,
47 pub health: Health,
48 pub tags: Vec<Tag>,
49}
50
51enum State<'a> {
52 Paging,
53 Transforming(BoxFuture<'a, Device>),
54}
55
56pub struct Devices<'a> {
59 lava: &'a Lava,
60 paginator: Paginator<LavaDevice>,
61 state: State<'a>,
62}
63
64impl<'a> Devices<'a> {
65 pub fn new(lava: &'a Lava) -> Self {
73 let url = lava
74 .base
75 .join("devices/?ordering=hostname")
76 .expect("Failed to append to base url");
77 let paginator = Paginator::new(lava.client.clone(), url);
78 Self {
79 lava,
80 paginator,
81 state: State::Paging,
82 }
83 }
84}
85
86async fn transform_device(device: LavaDevice, lava: &Lava) -> Device {
87 let t = stream::iter(device.tags.iter());
88 let tags = t
89 .filter_map(|i| async move { lava.tag(*i).await })
90 .collect()
91 .await;
92
93 Device {
94 hostname: device.hostname,
95 worker_host: device.worker_host,
96 device_type: device.device_type,
97 description: device.description,
98 health: device.health,
99 tags,
100 }
101}
102
103impl<'a> Stream for Devices<'a> {
104 type Item = Result<Device, PaginationError>;
105
106 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
107 let me = self.get_mut();
108
109 loop {
110 return match &mut me.state {
111 State::Paging => {
112 let p = Pin::new(&mut me.paginator);
113 match p.poll_next(cx) {
114 Poll::Ready(None) => Poll::Ready(None),
115 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
116 Poll::Ready(Some(Ok(d))) => {
117 me.state = State::Transforming(transform_device(d, me.lava).boxed());
118 continue;
119 }
120 Poll::Pending => Poll::Pending,
121 }
122 }
123 State::Transforming(fut) => match fut.as_mut().poll(cx) {
124 Poll::Ready(d) => {
125 me.state = State::Paging;
126 Poll::Ready(Some(Ok(d)))
127 }
128 Poll::Pending => Poll::Pending,
129 },
130 };
131 }
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::{Device, Health, Tag};
    use crate::Lava;

    use boulder::{Buildable, Builder};
    use futures::TryStreamExt;
    use lava_api_mock::{
        Device as MockDevice, DeviceHealth as MockDeviceHealth, DeviceType as MockDeviceType,
        LavaMock, PaginationLimits, PopulationParams, SharedState, State, Tag as MockTag,
        Worker as MockWorker,
    };
    use persian_rug::{Accessor, Context};
    use std::collections::{BTreeMap, BTreeSet};
    use test_log::test;

    // This mapping cannot fail, so it is a `From` rather than a `TryFrom` with
    // an `Infallible` error. The standard library's blanket impl still
    // provides `TryFrom<MockDeviceHealth> for Health` (with
    // `Error = Infallible`), so existing `try_from`/`try_into` callers keep
    // working.
    impl From<MockDeviceHealth> for Health {
        fn from(dev: MockDeviceHealth) -> Health {
            use Health::*;
            match dev {
                MockDeviceHealth::Unknown => Unknown,
                MockDeviceHealth::Maintenance => Maintenance,
                MockDeviceHealth::Good => Good,
                MockDeviceHealth::Bad => Bad,
                MockDeviceHealth::Looping => Looping,
                MockDeviceHealth::Retired => Retired,
            }
        }
    }

    impl Device {
        /// Build the [`Device`] the client is expected to produce for a mock
        /// device, resolving its worker, device type and tags through
        /// `context`.
        #[persian_rug::constraints(context = C, access(MockTag<C>, MockDeviceType<C>, MockWorker<C>))]
        pub fn from_mock<'b, B, C>(dev: &MockDevice<C>, context: B) -> Device
        where
            B: 'b + Accessor<Context = C>,
            C: Context + 'static,
        {
            Self {
                hostname: dev.hostname.clone(),
                worker_host: context.get(&dev.worker_host).hostname.clone(),
                device_type: context.get(&dev.device_type).name.clone(),
                description: dev.description.clone(),
                health: dev.health.clone().into(),
                tags: dev
                    .tags
                    .iter()
                    .map(|t| Tag::from_mock(context.get(t), context.clone()))
                    .collect(),
            }
        }
    }

    #[test(tokio::test)]
    /// Stream all 50 devices from a mock server that pages 5 at a time, and
    /// check each one against the mock's own records.
    async fn test_basic() {
        let state =
            SharedState::new_populated(PopulationParams::builder().devices(50usize).build());
        let server = LavaMock::new(
            state.clone(),
            PaginationLimits::builder().devices(Some(5)).build(),
        )
        .await;

        // Index the mock devices by hostname so each streamed device can be
        // matched to its source record with a single lookup.
        let start = state.access();
        let map: BTreeMap<_, _> = start
            .get_iter::<MockDevice<State>>()
            .map(|device| (device.hostname.clone(), device))
            .collect();

        let lava = Lava::new(&server.uri(), None).expect("failed to make lava server");

        let mut ld = lava.devices();

        let mut seen = BTreeSet::new();
        while let Some(device) = ld.try_next().await.expect("failed to get device") {
            // Every hostname must be yielded exactly once...
            assert!(seen.insert(device.hostname.clone()));
            // ...and must correspond to a device in the mock state.
            let dev = map
                .get(&device.hostname)
                .expect("streamed device missing from mock state");

            assert_eq!(device.hostname, dev.hostname);
            assert_eq!(device.worker_host, start.get(&dev.worker_host).hostname);
            assert_eq!(device.device_type, start.get(&dev.device_type).name);
            assert_eq!(device.description, dev.description);
            assert_eq!(device.health.to_string(), dev.health.to_string());

            assert_eq!(device.tags.len(), dev.tags.len());
            for (tag, mock_tag) in device.tags.iter().zip(&dev.tags) {
                let mock_tag = start.get(mock_tag);
                assert_eq!(tag.id, mock_tag.id);
                assert_eq!(tag.name, mock_tag.name);
                assert_eq!(tag.description, mock_tag.description);
            }
        }
        assert_eq!(seen.len(), 50);
    }
}