1use bytes::{self, Buf, Bytes};
7use core::pin::Pin;
8use futures_util::{Stream, StreamExt};
9use reqwest::{Client, StatusCode};
10use std::io::{self, BufRead};
11use std::task::{ready, Context, Poll};
12use tokio::sync::mpsc;
13use tokio_util::sync::ReusableBoxFuture;
14
15pub mod proto;
16
17#[derive(Debug)]
18pub enum Error {
19 Reqwest(reqwest::Error),
20 Io(io::Error),
21 JSON(serde_json::Error),
22 String(StatusCode),
23}
24
25impl From<serde_json::Error> for Error {
26 fn from(value: serde_json::Error) -> Self {
27 Self::JSON(value)
28 }
29}
30
31impl From<reqwest::Error> for Error {
32 fn from(value: reqwest::Error) -> Self {
33 Self::Reqwest(value)
34 }
35}
36
37impl From<io::Error> for Error {
38 fn from(value: io::Error) -> Self {
39 Self::Io(value)
40 }
41}
42
43pub struct ResourceStream {
46 inner: ReusableBoxFuture<'static, (Option<Bytes>, mpsc::Receiver<Bytes>)>,
47 buf: Vec<u8>,
48 partial: Option<bytes::buf::Reader<Bytes>>,
49}
50
51impl ResourceStream {
52 pub fn new(rx: mpsc::Receiver<Bytes>) -> ResourceStream {
53 ResourceStream {
54 inner: ReusableBoxFuture::new(make_future(rx)),
55 buf: vec![],
56 partial: None,
57 }
58 }
59}
60
61async fn make_future(mut rx: mpsc::Receiver<Bytes>) -> (Option<Bytes>, mpsc::Receiver<Bytes>) {
62 let result = rx.recv().await;
63 (result, rx)
64}
65
66impl Stream for ResourceStream {
67 type Item = proto::ResourceDiff;
68 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
69 let parse = |buffer: &mut bytes::buf::Reader<Bytes>,
70 buf: &mut Vec<u8>|
71 -> Result<Option<Self::Item>, Error> {
72 match buffer.read_until(b'\r', buf) {
73 Ok(_) => match buf.pop() {
74 Some(b'\r') => match serde_json::from_slice(buf) {
75 Ok(diff) => {
76 buf.clear();
77 Ok(Some(diff))
78 }
79 Err(e) => Err(Error::JSON(e)),
80 },
81 Some(n) => {
82 buf.push(n);
83 Ok(None)
84 }
85 None => Ok(None),
86 },
87 Err(e) => Err(Error::Io(e)),
88 }
89 };
90 let mut buf = self.buf.clone();
94 if let Some(p) = &mut self.partial {
95 match parse(p, &mut buf) {
96 Ok(Some(diff)) => return Poll::Ready(Some(diff)),
97 Ok(None) => self.partial = None,
98 Err(_) => return Poll::Ready(None),
99 }
100 }
101 self.buf = buf;
102 loop {
103 let (result, rx) = ready!(self.inner.poll(cx));
104 self.inner.set(make_future(rx));
105 match result {
106 Some(chunk) => {
107 let mut buffer = chunk.reader();
108 match parse(&mut buffer, &mut self.buf) {
109 Ok(Some(diff)) => {
110 self.partial = Some(buffer);
111 return Poll::Ready(Some(diff));
112 }
113 Ok(None) => continue,
114 Err(_) => return Poll::Ready(None),
115 }
116 }
117 None => return Poll::Ready(None),
118 }
119 }
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a stream whose channel already holds `chunks`, in order.
    async fn stream_of(chunks: &[&'static str]) -> ResourceStream {
        let (tx, rx) = mpsc::channel(100);
        for &chunk in chunks {
            tx.send(Bytes::from_static(chunk.as_bytes())).await.unwrap();
        }
        ResourceStream::new(rx)
    }

    /// Polls `stream` repeatedly until it stops returning `Poll::Pending`.
    fn poll_ready(
        stream: &mut ResourceStream,
        cx: &mut Context<'_>,
    ) -> Poll<Option<proto::ResourceDiff>> {
        loop {
            let res = Pin::new(&mut *stream).poll_next(cx);
            if res.is_ready() {
                return res;
            }
        }
    }

    /// Checks that `res` carries a diff with `new == None` and `full_update == true`.
    fn assert_full_update(res: Poll<Option<proto::ResourceDiff>>) {
        assert_ne!(res, Poll::Ready(None));
        assert_ne!(res, Poll::Pending);
        if let Poll::Ready(Some(diff)) = res {
            assert_eq!(diff.new, None);
            assert!(diff.full_update);
        }
    }

    #[tokio::test]
    async fn parse_resource() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let mut diffs = stream_of(&[
            "{\"new\": null,\"changed\": null,\"gone\": null,\"full_update\": true}\r",
        ])
        .await;
        // The frame is complete in the first chunk, so a single poll must already yield it.
        assert_full_update(Pin::new(&mut diffs).poll_next(&mut cx));
    }

    #[tokio::test]
    async fn parse_across_chunks() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let mut diffs = stream_of(&[
            "{\"new\": null,\"changed\": null,",
            "\"gone\": null,\"full_update\": true}\r",
        ])
        .await;
        assert_full_update(poll_ready(&mut diffs, &mut cx));
    }

    #[tokio::test]
    async fn parse_multi_diff_partial_chunks() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let mut diffs = stream_of(&[
            "{\"new\": null,\"changed\": null,",
            "\"gone\": null,\"full_update\": true}\r{\"new\": null,\"changed",
            "\": null,\"gone\": null,\"full_update\": true}",
            "\r",
        ])
        .await;
        assert_full_update(poll_ready(&mut diffs, &mut cx));
        assert_full_update(poll_ready(&mut diffs, &mut cx));
    }
}
202
203pub async fn start_stream(
226 api_endpoint: String,
227 name: String,
228 token: String,
229 resource_types: Vec<String>,
230) -> Result<ResourceStream, Error> {
231 let (tx, rx) = mpsc::channel(100);
232
233 let req = proto::ResourceRequest {
234 request_origin: name,
235 resource_types,
236 };
237 let json = serde_json::to_string(&req)?;
238
239 let auth_value = format!("Bearer {}", token);
240
241 let client = Client::new();
242
243 let mut stream = client
244 .get(api_endpoint)
245 .header("Authorization", &auth_value)
246 .body(json)
247 .send()
248 .await?
249 .bytes_stream();
250
251 tokio::spawn(async move {
252 while let Some(chunk) = stream.next().await {
253 let bytes = match chunk {
254 Ok(b) => b,
255 Err(_e) => {
256 return;
257 }
258 };
259 tx.send(bytes).await.unwrap();
260 }
261 });
262 Ok(ResourceStream::new(rx))
263}
264
265pub async fn request_resources(
266 api_endpoint: String,
267 name: String,
268 token: String,
269 resource_types: Vec<String>,
270) -> Result<proto::ResourceState, Error> {
271 let fetched_resources: Result<proto::ResourceState, Error>;
272 let req = proto::ResourceRequest {
273 request_origin: name,
274 resource_types,
275 };
276 let json = serde_json::to_string(&req)?;
277
278 let auth_value = format!("Bearer {}", token);
279
280 let client = Client::new();
281
282 let response = client
283 .get(api_endpoint)
284 .header("Authorization", &auth_value)
285 .body(json)
286 .send()
287 .await
288 .unwrap();
289 match response.status() {
290 reqwest::StatusCode::OK => {
291 fetched_resources = match response.json::<proto::ResourceState>().await {
292 Ok(fetched_resources) => Ok(fetched_resources),
293 Err(e) => Err(Error::Reqwest(e)),
294 };
295 }
296 other => fetched_resources = Err(Error::String(other)),
297 };
298 fetched_resources
299}