use std::sync::{Arc, Mutex};
use std::time::Duration;
use atomr_streams::Source;
use bytes::Bytes;
/// Capacity of the bounded channel that the `HttpPollSource` poll task sends
/// its results into.
///
/// Those results are then relayed into an unbounded channel backing the
/// returned stream, so this value does not cap how many results can be
/// buffered overall.
pub const POLL_CHANNEL_CAPACITY: usize = 8;
/// Description of the GET request issued on every poll.
#[derive(Clone, Debug)]
pub struct RequestSpec {
    /// Target URL, passed to the HTTP client's `get`.
    pub url: String,
    /// Extra request headers as `(name, value)` pairs, sent in insertion order.
    pub headers: Vec<(String, String)>,
}

impl RequestSpec {
    /// Creates a spec for `url` with no extra headers.
    pub fn new(url: impl Into<String>) -> Self {
        let url: String = url.into();
        Self { url, headers: vec![] }
    }

    /// Builder-style: appends one header and returns the updated spec.
    ///
    /// Repeated names are kept as separate entries; nothing is deduplicated.
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let entry: (String, String) = (name.into(), value.into());
        self.headers.push(entry);
        self
    }
}
/// One HTTP response observed by a single poll.
#[derive(Clone, Debug)]
pub struct HttpResponse {
    /// Numeric HTTP status code of the response.
    pub status: u16,
    /// Response headers as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
    /// Response body; empty when `not_modified` is `true`.
    pub body: Bytes,
    /// `true` when the server answered `304 Not Modified`.
    pub not_modified: bool,
}
/// Failure of a single poll.
#[derive(Debug, thiserror::Error)]
pub enum HttpError {
    /// Sending the request or reading the response body failed.
    #[error("http transport error: {0}")]
    Transport(String),
    /// The HTTP client reported the failure as a request-builder error.
    #[error("http request build error: {0}")]
    Build(String),
}
/// Entry point for streams that poll an HTTP endpoint with a fixed delay.
///
/// The type carries no data. Its constructors return a [`Source`] that emits
/// one `Result<HttpResponse, HttpError>` per completed poll.
pub struct HttpPollSource;

impl HttpPollSource {
    /// Polls `req` with an unconditional GET and emits every outcome.
    ///
    /// The task sleeps `interval` before each request, including the first.
    /// The delay is measured from the end of the previous poll, so the
    /// effective period is `interval` plus the time the poll itself took.
    ///
    /// # Panics
    ///
    /// Spawns Tokio tasks, so it panics when called outside a Tokio runtime.
    // `new` intentionally returns the stream rather than `Self`; the type is
    // only a namespace for constructors.
    #[allow(clippy::new_ret_no_self)]
    pub fn new(req: RequestSpec, interval: Duration) -> Source<Result<HttpResponse, HttpError>> {
        Self::spawn(req, interval, false)
    }

    /// Like [`HttpPollSource::new`], but requests are conditional.
    ///
    /// Once an `ETag` has been stored, it is sent as `If-None-Match`. A `304`
    /// reply is emitted as an [`HttpResponse`] with `not_modified == true` and
    /// an empty body.
    ///
    /// # Panics
    ///
    /// Spawns Tokio tasks, so it panics when called outside a Tokio runtime.
    pub fn with_etag(req: RequestSpec, interval: Duration) -> Source<Result<HttpResponse, HttpError>> {
        Self::spawn(req, interval, true)
    }

    /// Starts the poll task and the relay task, and returns the stream they feed.
    ///
    /// Both tasks stop once the consumer has dropped the returned stream.
    fn spawn(
        req: RequestSpec,
        interval: Duration,
        use_etag: bool,
    ) -> Source<Result<HttpResponse, HttpError>> {
        // First hop: a bounded channel fed by the poll task.
        let (tx, mut rx) =
            tokio::sync::mpsc::channel::<Result<HttpResponse, HttpError>>(POLL_CHANNEL_CAPACITY.max(1));
        // Second hop: an unbounded channel handed to the returned `Source`.
        // NOTE(review): the relay drains the bounded hop into this channel
        // immediately, so a consumer slower than the poll rate lets results
        // accumulate without bound.
        let (utx, urx) = tokio::sync::mpsc::unbounded_channel();
        // The poll task keeps this clone only to detect that the consumer has
        // dropped its end. It then stops before issuing a request whose result
        // nobody will read, instead of discovering the closure through failed
        // sends, which costs two extra requests.
        let consumer_probe = utx.clone();
        let client = reqwest::Client::new();
        let etag: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));

        tokio::spawn(async move {
            loop {
                tokio::time::sleep(interval).await;
                if consumer_probe.is_closed() {
                    return;
                }
                // The lock guard is a temporary of this statement, so it is
                // released before any `.await`. A poisoned lock is recovered
                // rather than propagated as a panic.
                let conditional =
                    if use_etag { etag.lock().unwrap_or_else(|p| p.into_inner()).clone() } else { None };
                let result = perform_get(&client, &req, conditional, &etag, use_etag).await;
                if tx.send(result).await.is_err() {
                    return;
                }
            }
        });

        // The relay ends when the poll task exits (`rx` yields `None`) or when
        // the consumer has gone (`send` fails).
        tokio::spawn(async move {
            while let Some(item) = rx.recv().await {
                if utx.send(item).is_err() {
                    return;
                }
            }
        });

        Source::from_receiver(urx)
    }
}
/// Performs one GET for `req` and converts the outcome into an [`HttpResponse`].
///
/// * `conditional_etag`: when `Some`, sent as the `If-None-Match` header.
/// * `etag_store`: when `use_etag` is set, receives the `ETag` of a non-304
///   response. It is written only after the body has been read successfully.
///   Committing earlier would cause a failed body read to suppress the changed
///   content on later polls, because they would be answered with `304`.
/// * `use_etag`: enables updating `etag_store`.
///
/// A `304` response yields `not_modified == true` and an empty body; its body
/// is not read.
///
/// # Errors
///
/// Returns [`HttpError::Build`] when the client classifies the send failure
/// as a builder error. Returns [`HttpError::Transport`] for any other send
/// failure or for a failure while reading the body.
async fn perform_get(
    client: &reqwest::Client,
    req: &RequestSpec,
    conditional_etag: Option<String>,
    etag_store: &Arc<Mutex<Option<String>>>,
    use_etag: bool,
) -> Result<HttpResponse, HttpError> {
    let mut builder = client.get(&req.url);
    for (name, value) in &req.headers {
        builder = builder.header(name.as_str(), value.as_str());
    }
    if let Some(tag) = conditional_etag {
        builder = builder.header(reqwest::header::IF_NONE_MATCH, tag);
    }
    let resp = builder.send().await.map_err(|e| {
        if e.is_builder() {
            HttpError::Build(e.to_string())
        } else {
            HttpError::Transport(e.to_string())
        }
    })?;

    let status = resp.status().as_u16();
    let not_modified = status == 304;
    // `to_str` rejects any byte outside visible ASCII. A lossy UTF-8 decode
    // keeps such values readable (invalid bytes become U+FFFD) instead of
    // silently replacing the whole value with "".
    let headers: Vec<(String, String)> = resp
        .headers()
        .iter()
        .map(|(k, v)| (k.as_str().to_owned(), String::from_utf8_lossy(v.as_bytes()).into_owned()))
        .collect();

    // Capture the validator now, because reading the body consumes `resp`.
    // It is committed only after the body has been read successfully.
    let fresh_etag = if use_etag && !not_modified {
        resp.headers()
            .get(reqwest::header::ETAG)
            .and_then(|tag| tag.to_str().ok())
            .map(str::to_owned)
    } else {
        None
    };

    let body = if not_modified {
        Bytes::new()
    } else {
        resp.bytes().await.map_err(|e| HttpError::Transport(e.to_string()))?
    };

    if let Some(tag) = fresh_etag {
        // Recover from a poisoned lock rather than panicking.
        *etag_store.lock().unwrap_or_else(|p| p.into_inner()) = Some(tag);
    }

    Ok(HttpResponse { status, headers, body, not_modified })
}
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_streams::Sink;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Binds an ephemeral localhost port and answers exactly one connection
    /// with a fixed `200 OK` (body `hi`, `ETag: "abc"`). Returns the URL.
    async fn canned_ok_server() -> String {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let mut sock = match listener.accept().await {
                Ok((sock, _peer)) => sock,
                Err(_) => return,
            };
            // Consume (part of) the request; its contents are not inspected.
            let mut scratch = [0u8; 1024];
            let _ = sock.read(&mut scratch).await;

            let payload: &[u8] = b"hi";
            let head = format!(
                "HTTP/1.1 200 OK\r\nETag: \"abc\"\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
                payload.len()
            );
            let _ = sock.write_all(head.as_bytes()).await;
            let _ = sock.write_all(payload).await;
            let _ = sock.flush().await;
            let _ = sock.shutdown().await;
        });
        format!("http://{addr}/")
    }

    #[tokio::test]
    async fn first_emission_is_ok_200_against_canned_server() {
        let url = canned_ok_server().await;
        let src = HttpPollSource::new(
            RequestSpec::new(url).header("User-Agent", "atomr-test/0.1"),
            Duration::from_millis(5),
        );
        let resp = match Sink::first(src).await.expect("expected one emission") {
            Ok(resp) => resp,
            Err(e) => panic!("expected Ok(200), got Err: {e}"),
        };
        assert_eq!(resp.status, 200);
        assert!(!resp.not_modified);
        assert_eq!(resp.body.as_ref(), b"hi");
        assert!(resp.headers.iter().any(|(k, v)| k.eq_ignore_ascii_case("etag") && v == "\"abc\""));
    }

    #[tokio::test]
    async fn connection_refused_surfaces_transport_err() {
        let req = RequestSpec::new("http://127.0.0.1:1/").header("User-Agent", "atomr-test/0.1");
        let first = Sink::first(HttpPollSource::new(req, Duration::from_millis(5)))
            .await
            .expect("expected one emission");
        if !matches!(first, Err(HttpError::Transport(_))) {
            panic!("expected Err(Transport), got {first:?}");
        }
    }
}