Skip to main content

atomr_streams_io/
http_poll.rs

1//! HTTP polling `Source` adapters.
2//!
3//! [`HttpPollSource`] turns a periodic HTTP `GET` into a
4//! [`Source<Result<HttpResponse, HttpError>>`]. Two flavours are provided:
5//!
6//! * [`HttpPollSource::new`] — fire a plain `GET` every `interval`.
7//! * [`HttpPollSource::with_etag`] — conditional `GET`: track the last
8//!   `ETag` response header and send it back as `If-None-Match` on the next
9//!   request, surfacing `304 Not Modified` as
10//!   [`HttpResponse { not_modified: true, .. }`].
11//!
//! ## Backpressure
//!
//! There is currently **no end-to-end backpressure** from the consumer onto
//! the polling loop. The stream is built with [`Source::from_receiver`], which
//! takes an *unbounded* receiver, and a send into an unbounded channel
//! completes immediately whether or not anything is reading from it. A
//! consumer that falls behind therefore does not pause polling: responses are
//! never dropped, but they queue up in memory until they are consumed, and
//! [`POLL_CHANNEL_CAPACITY`] does not limit that queue. Choose an `interval`
//! the consumer can keep up with, and terminate the stream downstream (for
//! example with `take` or a kill switch) once it is no longer needed.
23//!
24//! ## Rate limiting
25//!
26//! This adapter does **not** implement rate limiting itself. Compose it with
27//! the upstream limiter, e.g.
28//! `atomr_streams::rate::token_bucket(src, 10.0, 1)`, to cap request rate —
29//! see `examples/edgar_poller.rs`.
30
31use std::sync::{Arc, Mutex};
32use std::time::Duration;
33
34use atomr_streams::Source;
35use bytes::Bytes;
36
/// Intended bound on in-flight, un-consumed responses between the polling
/// task and the stream.
///
/// NOTE(review): this bound is not enforced end to end. The stream is fed
/// through the unbounded receiver accepted by `Source::from_receiver`, so a
/// slow consumer can still accumulate more than this many responses. Do not
/// rely on this constant for backpressure.
pub const POLL_CHANNEL_CAPACITY: usize = 8;
42
/// The request issued on every poll: a target URL and any extra headers.
///
/// Callers SHOULD always supply a `User-Agent` header; many APIs (the SEC
/// EDGAR system, for example) reject requests that lack one. See the crate
/// examples.
#[derive(Debug, Clone)]
pub struct RequestSpec {
    /// Absolute URL to `GET`.
    pub url: String,
    /// Extra request headers as `(name, value)` pairs, in insertion order.
    pub headers: Vec<(String, String)>,
}

impl RequestSpec {
    /// Create a spec for `url` with an empty header list.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        let headers = Vec::new();
        Self { url, headers }
    }

    /// Builder-style: add one `(name, value)` header after any already
    /// present and hand the spec back for further chaining.
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let pair = (name.into(), value.into());
        self.headers.push(pair);
        self
    }
}
67
/// A captured HTTP response.
///
/// Any response the server actually sent is delivered as a value of this
/// type, whatever its status code; a `4xx`/`5xx` status is *not* turned into
/// an [`HttpError`]. Inspect [`status`](Self::status) to tell them apart.
#[derive(Debug, Clone)]
pub struct HttpResponse {
    /// HTTP status code (e.g. `200`, `304`).
    pub status: u16,
    /// Response headers as `(name, value)` pairs (lower-cased names).
    pub headers: Vec<(String, String)>,
    /// Response body bytes (empty for `304 Not Modified`).
    pub body: Bytes,
    /// `true` exactly when `status` is `304`, i.e. the server answered
    /// `304 Not Modified` to a conditional `GET` — the body is empty and the
    /// previously fetched representation is still current.
    pub not_modified: bool,
}
82
/// Errors raised while polling.
///
/// Both variants carry the underlying error rendered as a string. An error is
/// emitted as one stream item; it does not by itself stop the polling loop.
#[derive(Debug, thiserror::Error)]
pub enum HttpError {
    /// A transport / network / I/O error performing the request. Also used
    /// when the response arrived but reading its body failed.
    #[error("http transport error: {0}")]
    Transport(String),
    /// The request could not be constructed (bad URL, invalid header, …),
    /// i.e. the client reported a builder error.
    #[error("http request build error: {0}")]
    Build(String),
}
93
94/// HTTP polling source factory.
95pub struct HttpPollSource;
96
97impl HttpPollSource {
98    /// Poll `req` with a plain `GET` every `interval`, emitting each
99    /// `Result<HttpResponse, HttpError>` in order.
100    ///
101    /// The first poll happens after the first `interval` has elapsed. The
102    /// returned source is effectively infinite — terminate it downstream with
103    /// `take`, a [`atomr_streams::KillSwitch`], etc.
104    // This is a source *factory*; returning the constructed `Source` rather
105    // than `Self` is the whole point of the API.
106    #[allow(clippy::new_ret_no_self)]
107    pub fn new(req: RequestSpec, interval: Duration) -> Source<Result<HttpResponse, HttpError>> {
108        Self::spawn(req, interval, false)
109    }
110
111    /// Like [`new`](Self::new) but performs a *conditional* `GET`: after the
112    /// first successful response carrying an `ETag`, subsequent requests send
113    /// `If-None-Match: <etag>`. A `304 Not Modified` is surfaced as
114    /// [`HttpResponse`] with `not_modified == true` and an empty body.
115    pub fn with_etag(req: RequestSpec, interval: Duration) -> Source<Result<HttpResponse, HttpError>> {
116        Self::spawn(req, interval, true)
117    }
118
119    fn spawn(
120        req: RequestSpec,
121        interval: Duration,
122        use_etag: bool,
123    ) -> Source<Result<HttpResponse, HttpError>> {
124        // Bounded channel => backpressure onto the polling loop.
125        let (tx, mut rx) =
126            tokio::sync::mpsc::channel::<Result<HttpResponse, HttpError>>(POLL_CHANNEL_CAPACITY.max(1));
127
128        let client = reqwest::Client::new();
129        // Shared last-seen ETag (only mutated by the single polling task, but
130        // kept behind a Mutex so the type is simple & Send).
131        let etag: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
132
133        tokio::spawn(async move {
134            loop {
135                tokio::time::sleep(interval).await;
136
137                let conditional =
138                    if use_etag { etag.lock().unwrap_or_else(|p| p.into_inner()).clone() } else { None };
139                let result = perform_get(&client, &req, conditional, &etag, use_etag).await;
140
141                // Bounded send: awaits if the consumer is behind (backpressure).
142                if tx.send(result).await.is_err() {
143                    // Consumer dropped; stop polling.
144                    return;
145                }
146            }
147        });
148
149        // Bridge the bounded receiver onto an unbounded one for
150        // `Source::from_receiver`, preserving the upstream bound: this
151        // forwarder only pulls from the bounded `rx` when it has handed the
152        // previous item off, so the bounded channel remains the limiting
153        // buffer.
154        let (utx, urx) = tokio::sync::mpsc::unbounded_channel();
155        tokio::spawn(async move {
156            while let Some(item) = rx.recv().await {
157                if utx.send(item).is_err() {
158                    return;
159                }
160            }
161        });
162
163        Source::from_receiver(urx)
164    }
165}
166
167/// Perform one `GET`, optionally conditional, mapping the outcome to a
168/// `Result<HttpResponse, HttpError>` and updating the stored ETag.
169async fn perform_get(
170    client: &reqwest::Client,
171    req: &RequestSpec,
172    conditional_etag: Option<String>,
173    etag_store: &Arc<Mutex<Option<String>>>,
174    use_etag: bool,
175) -> Result<HttpResponse, HttpError> {
176    let mut builder = client.get(&req.url);
177    for (name, value) in &req.headers {
178        builder = builder.header(name.as_str(), value.as_str());
179    }
180    if let Some(tag) = conditional_etag {
181        builder = builder.header(reqwest::header::IF_NONE_MATCH, tag);
182    }
183
184    let resp = builder.send().await.map_err(|e| {
185        if e.is_builder() {
186            HttpError::Build(e.to_string())
187        } else {
188            HttpError::Transport(e.to_string())
189        }
190    })?;
191
192    let status = resp.status().as_u16();
193    let not_modified = status == 304;
194
195    let headers: Vec<(String, String)> = resp
196        .headers()
197        .iter()
198        .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or_default().to_string()))
199        .collect();
200
201    // Track the freshest ETag we've seen for the next conditional request.
202    if use_etag && !not_modified {
203        if let Some(tag) = resp.headers().get(reqwest::header::ETAG) {
204            if let Ok(s) = tag.to_str() {
205                *etag_store.lock().unwrap_or_else(|p| p.into_inner()) = Some(s.to_string());
206            }
207        }
208    }
209
210    let body = if not_modified {
211        Bytes::new()
212    } else {
213        resp.bytes().await.map_err(|e| HttpError::Transport(e.to_string()))?
214    };
215
216    Ok(HttpResponse { status, headers, body, not_modified })
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_streams::Sink;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Spawn a one-shot raw-HTTP responder that accepts a single connection,
    /// drains the request headers, writes a canned `200 OK` with an ETag, and
    /// closes. Returns the bound `http://127.0.0.1:<port>/` URL.
    async fn canned_ok_server() -> String {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                // Keep reading until the blank line that ends the request
                // headers. A single `read` may return only part of the
                // request, and answering before the client has finished
                // sending can look like a reset to it. The bytes are not
                // parsed.
                let mut request = Vec::new();
                let mut chunk = [0u8; 1024];
                loop {
                    match sock.read(&mut chunk).await {
                        // Peer closed or errored: answer with what we have.
                        Ok(0) | Err(_) => break,
                        Ok(n) => {
                            request.extend_from_slice(&chunk[..n]);
                            if request.windows(4).any(|w| w == b"\r\n\r\n") {
                                break;
                            }
                        }
                    }
                }

                let body = b"hi";
                let resp = format!(
                    "HTTP/1.1 200 OK\r\nETag: \"abc\"\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
                    body.len()
                );
                let _ = sock.write_all(resp.as_bytes()).await;
                let _ = sock.write_all(body).await;
                let _ = sock.flush().await;
                let _ = sock.shutdown().await;
            }
        });
        format!("http://{addr}/")
    }

    /// The first item from a plain poll against the canned server is an
    /// `Ok` carrying status 200, the body, and the ETag header.
    #[tokio::test]
    async fn first_emission_is_ok_200_against_canned_server() {
        let url = canned_ok_server().await;
        let req = RequestSpec::new(url).header("User-Agent", "atomr-test/0.1");
        let src = HttpPollSource::new(req, Duration::from_millis(5));

        let first = Sink::first(src).await.expect("expected one emission");
        match first {
            Ok(resp) => {
                assert_eq!(resp.status, 200);
                assert!(!resp.not_modified);
                assert_eq!(resp.body.as_ref(), b"hi");
                assert!(resp.headers.iter().any(|(k, v)| k.eq_ignore_ascii_case("etag") && v == "\"abc\""));
            }
            Err(e) => panic!("expected Ok(200), got Err: {e}"),
        }
    }

    /// A failed connection is reported as `HttpError::Transport`.
    #[tokio::test]
    async fn connection_refused_surfaces_transport_err() {
        // Port 1 is reserved/unused; connect should be refused -> Transport err.
        let req = RequestSpec::new("http://127.0.0.1:1/").header("User-Agent", "atomr-test/0.1");
        let src = HttpPollSource::new(req, Duration::from_millis(5));

        let first = Sink::first(src).await.expect("expected one emission");
        match first {
            Err(HttpError::Transport(_)) => {}
            other => panic!("expected Err(Transport), got {other:?}"),
        }
    }
}
283}