rdsys_backend/
lib.rs

//! # Rdsys Backend Distributor API
//!
//! `rdsys_backend` is a client for the rdsys backend API, which is described at
//! <https://gitlab.torproject.org/tpo/anti-censorship/rdsys/-/blob/main/doc/backend-api.md>.
5
6use bytes::{self, Buf, Bytes};
7use core::pin::Pin;
8use futures_util::{Stream, StreamExt};
9use reqwest::{Client, StatusCode};
10use std::io::{self, BufRead};
11use std::task::{ready, Context, Poll};
12use tokio::sync::mpsc;
13use tokio_util::sync::ReusableBoxFuture;
14
15pub mod proto;
16
17#[derive(Debug)]
18pub enum Error {
19    Reqwest(reqwest::Error),
20    Io(io::Error),
21    JSON(serde_json::Error),
22    String(StatusCode),
23}
24
25impl From<serde_json::Error> for Error {
26    fn from(value: serde_json::Error) -> Self {
27        Self::JSON(value)
28    }
29}
30
31impl From<reqwest::Error> for Error {
32    fn from(value: reqwest::Error) -> Self {
33        Self::Reqwest(value)
34    }
35}
36
37impl From<io::Error> for Error {
38    fn from(value: io::Error) -> Self {
39        Self::Io(value)
40    }
41}
42
43/// An iterable wrapper of ResourceDiff items for the streamed chunks of Bytes
44/// received from the connection to the rdsys backend
45pub struct ResourceStream {
46    inner: ReusableBoxFuture<'static, (Option<Bytes>, mpsc::Receiver<Bytes>)>,
47    buf: Vec<u8>,
48    partial: Option<bytes::buf::Reader<Bytes>>,
49}
50
51impl ResourceStream {
52    pub fn new(rx: mpsc::Receiver<Bytes>) -> ResourceStream {
53        ResourceStream {
54            inner: ReusableBoxFuture::new(make_future(rx)),
55            buf: vec![],
56            partial: None,
57        }
58    }
59}
60
61async fn make_future(mut rx: mpsc::Receiver<Bytes>) -> (Option<Bytes>, mpsc::Receiver<Bytes>) {
62    let result = rx.recv().await;
63    (result, rx)
64}
65
66impl Stream for ResourceStream {
67    type Item = proto::ResourceDiff;
68    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
69        let parse = |buffer: &mut bytes::buf::Reader<Bytes>,
70                     buf: &mut Vec<u8>|
71         -> Result<Option<Self::Item>, Error> {
72            match buffer.read_until(b'\r', buf) {
73                Ok(_) => match buf.pop() {
74                    Some(b'\r') => match serde_json::from_slice(buf) {
75                        Ok(diff) => {
76                            buf.clear();
77                            Ok(Some(diff))
78                        }
79                        Err(e) => Err(Error::JSON(e)),
80                    },
81                    Some(n) => {
82                        buf.push(n);
83                        Ok(None)
84                    }
85                    None => Ok(None),
86                },
87                Err(e) => Err(Error::Io(e)),
88            }
89        };
90        // This clone is here to avoid having multiple mutable references to self
91        // it's not optimal performance-wise but given that these resource streams aren't large
92        // this feels like an acceptable trade-off to the complexity of interior mutability
93        let mut buf = self.buf.clone();
94        if let Some(p) = &mut self.partial {
95            match parse(p, &mut buf) {
96                Ok(Some(diff)) => return Poll::Ready(Some(diff)),
97                Ok(None) => self.partial = None,
98                Err(_) => return Poll::Ready(None),
99            }
100        }
101        self.buf = buf;
102        loop {
103            let (result, rx) = ready!(self.inner.poll(cx));
104            self.inner.set(make_future(rx));
105            match result {
106                Some(chunk) => {
107                    let mut buffer = chunk.reader();
108                    match parse(&mut buffer, &mut self.buf) {
109                        Ok(Some(diff)) => {
110                            self.partial = Some(buffer);
111                            return Poll::Ready(Some(diff));
112                        }
113                        Ok(None) => continue,
114                        Err(_) => return Poll::Ready(None),
115                    }
116                }
117                None => return Poll::Ready(None),
118            }
119        }
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// One complete, `\r`-terminated diff record: an empty full update.
    const FULL_UPDATE: &str =
        "{\"new\": null,\"changed\": null,\"gone\": null,\"full_update\": true}\r";

    /// Builds a `ResourceStream` fed with `chunks`. The sender is dropped
    /// afterwards, so the stream ends once every chunk has been consumed.
    async fn stream_from(chunks: &[&str]) -> ResourceStream {
        let (tx, rx) = mpsc::channel(chunks.len().max(1));
        for chunk in chunks {
            tx.send(Bytes::copy_from_slice(chunk.as_bytes()))
                .await
                .unwrap();
        }
        ResourceStream::new(rx)
    }

    /// Asserts that `diff` is the empty full update encoded by `FULL_UPDATE`.
    fn assert_empty_full_update(diff: Option<proto::ResourceDiff>) {
        let diff = diff.expect("stream ended before yielding a diff");
        assert_eq!(diff.new, None);
        assert!(diff.full_update);
    }

    #[tokio::test]
    async fn parse_resource() {
        let mut diffs = stream_from(&[FULL_UPDATE]).await;
        assert_empty_full_update(diffs.next().await);
    }

    #[tokio::test]
    async fn parse_across_chunks() {
        let mut diffs = stream_from(&[
            "{\"new\": null,\"changed\": null,",
            "\"gone\": null,\"full_update\": true}\r",
        ])
        .await;
        assert_empty_full_update(diffs.next().await);
    }

    #[tokio::test]
    async fn parse_multi_diff_partial_chunks() {
        let mut diffs = stream_from(&[
            "{\"new\": null,\"changed\": null,",
            "\"gone\": null,\"full_update\": true}\r{\"new\": null,\"changed",
            "\": null,\"gone\": null,\"full_update\": true}",
            "\r",
        ])
        .await;
        assert_empty_full_update(diffs.next().await);
        assert_empty_full_update(diffs.next().await);
        assert!(diffs.next().await.is_none());
    }

    #[tokio::test]
    async fn multiple_diffs_in_one_chunk() {
        let two = FULL_UPDATE.repeat(2);
        let mut diffs = stream_from(&[two.as_str()]).await;
        assert_empty_full_update(diffs.next().await);
        assert_empty_full_update(diffs.next().await);
        assert!(diffs.next().await.is_none());
    }

    #[tokio::test]
    async fn closed_channel_ends_stream() {
        let mut diffs = stream_from(&[]).await;
        assert!(diffs.next().await.is_none());
    }

    #[tokio::test]
    async fn malformed_record_ends_stream() {
        let mut diffs = stream_from(&["not json\r"]).await;
        assert!(diffs.next().await.is_none());
    }

    #[tokio::test]
    async fn incomplete_trailing_record_is_not_emitted() {
        let mut diffs = stream_from(&["{\"new\": null,\"changed\": null,"]).await;
        assert!(diffs.next().await.is_none());
    }
}
202
203/// Makes an http connection to the rdsys backend api endpoint and returns a ResourceStream
204/// if successful
205///
206/// # Examples
207///
208/// ```ignore
209/// use rdsys_backend::start_stream;
210///
211/// let endpoint = String::from("http://127.0.0.1:7100/resource-stream");
212/// let name = String::from("https");
213/// let token = String::from("HttpsApiTokenPlaceholder");
214/// let types = vec![String::from("obfs2"), String::from("scramblesuit")];
215/// let stream = start_stream(endpoint, name, token, types).await.unwrap();
216/// loop {
217///     match Pin::new(&mut stream).poll_next(&mut cx) {
218///         Poll::Ready(Some(diff)) => println!("Received diff: {:?}", diff),
219///         Poll::Ready(None) => break,
220///         Poll::Pending => continue,
221///     }
222/// }
223/// ```
224
225pub async fn start_stream(
226    api_endpoint: String,
227    name: String,
228    token: String,
229    resource_types: Vec<String>,
230) -> Result<ResourceStream, Error> {
231    let (tx, rx) = mpsc::channel(100);
232
233    let req = proto::ResourceRequest {
234        request_origin: name,
235        resource_types,
236    };
237    let json = serde_json::to_string(&req)?;
238
239    let auth_value = format!("Bearer {}", token);
240
241    let client = Client::new();
242
243    let mut stream = client
244        .get(api_endpoint)
245        .header("Authorization", &auth_value)
246        .body(json)
247        .send()
248        .await?
249        .bytes_stream();
250
251    tokio::spawn(async move {
252        while let Some(chunk) = stream.next().await {
253            let bytes = match chunk {
254                Ok(b) => b,
255                Err(_e) => {
256                    return;
257                }
258            };
259            tx.send(bytes).await.unwrap();
260        }
261    });
262    Ok(ResourceStream::new(rx))
263}
264
265pub async fn request_resources(
266    api_endpoint: String,
267    name: String,
268    token: String,
269    resource_types: Vec<String>,
270) -> Result<proto::ResourceState, Error> {
271    let fetched_resources: Result<proto::ResourceState, Error>;
272    let req = proto::ResourceRequest {
273        request_origin: name,
274        resource_types,
275    };
276    let json = serde_json::to_string(&req)?;
277
278    let auth_value = format!("Bearer {}", token);
279
280    let client = Client::new();
281
282    let response = client
283        .get(api_endpoint)
284        .header("Authorization", &auth_value)
285        .body(json)
286        .send()
287        .await
288        .unwrap();
289    match response.status() {
290        reqwest::StatusCode::OK => {
291            fetched_resources = match response.json::<proto::ResourceState>().await {
292                Ok(fetched_resources) => Ok(fetched_resources),
293                Err(e) => Err(Error::Reqwest(e)),
294            };
295        }
296        other => fetched_resources = Err(Error::String(other)),
297    };
298    fetched_resources
299}