use crate::pinboard;
use futures::future::{BoxFuture, FutureExt};
use pin_project::pin_project;
use reqwest::Url;
use serde::Deserialize;
use serde_json::Deserializer;
use snafu::{prelude::*, Backtrace, IntoError};
use std::{
cmp::min,
collections::VecDeque,
future::Future,
pin::Pin,
str::FromStr,
sync::Arc,
task::{Context, Poll},
};
use tokio::sync::Mutex;
use tracing::{debug, error, trace, warn};
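/// Errors that can arise while streaming URLs out of Pinboard.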
#[non_exhaustive]
#[derive(Debug, Snafu)]
pub enum Error {
#[snafu(display("HTPP error {source}"))]
Http {
source: reqwest::Error,
backtrace: Backtrace,
},
#[snafu(display("Pinboard error {source}"))]
Pinboard {
source: pinboard::Error,
backtrace: Backtrace,
},
#[snafu(display("Bad URL {source}"))]
Url {
source: url::ParseError,
backtrace: Backtrace,
},
}
type Result<T> = std::result::Result<T, Error>;
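/// State of the HTTP request currently in flight, if any.
///
/// `AwaitingResponse` holds the future for the initial request.
/// `ProcessingResponseHeaders` and `ProcessingResponseBody` pair the response
/// (behind `Arc<Mutex<_>>` so the owned next-chunk future can hold a handle to
/// it) with the future for the next body chunk; the body variant also carries
/// the unparsed tail of the previous chunk and the count of separator bytes
/// still to be skipped.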
enum Pending {
None,
AwaitingResponse(BoxFuture<'static, pinboard::Result<reqwest::Response>>),
ProcessingResponseHeaders(
(
Arc<Mutex<reqwest::Response>>,
BoxFuture<'static, reqwest::Result<Option<bytes::Bytes>>>,
),
),
#[allow(clippy::type_complexity)]
ProcessingResponseBody(
(
Arc<Mutex<reqwest::Response>>,
BoxFuture<'static, reqwest::Result<Option<bytes::Bytes>>>,
Vec<u8>,
usize,
),
),
}
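/// Minimal projection of a Pinboard post: only `href` is needed here.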
#[derive(Debug, Deserialize)]
struct Entry {
href: String,
}
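/// A `Stream` of `Url`s drawn from a mix of literal URLs and Pinboard tag
/// expressions.
///
/// "Greedy" because literal URLs are banked eagerly: the input is consumed up
/// to (and including) the first tag expression, whose matching posts are then
/// fetched and parsed incrementally as response chunks arrive.
///
/// A sketch of intended usage (hypothetical client setup; cf. the tests
/// below):
///
/// ```ignore
/// use futures::stream::StreamExt;
/// let mut urls = GreedyUrlStream::new(
///     client, // a pinboard::Client
///     vec!["https://example.com".to_string(), "rust".to_string()].into_iter(),
/// )?;
/// while let Some(url) = urls.next().await {
///     println!("{}", url?);
/// }
/// ```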
#[pin_project]
pub struct GreedyUrlStream<I>
where
I: Iterator<Item = String>,
{
client: pinboard::Client,
url_or_tags: I,
banked_urls: VecDeque<Result<Url>>,
current_request: Pending,
}
impl<I> GreedyUrlStream<I>
where
I: Iterator<Item = String>,
{
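/// Drain `url_or_tags` up to (and including) the first tag expression.
///
/// Items that parse as URLs are banked immediately; the first item that
/// instead parses as a `+`-separated list of tags starts an `all_posts`
/// request and ends the scan.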
fn parse_urls_to_tag(
client: &pinboard::Client,
url_or_tags: &mut I,
) -> (VecDeque<Result<Url>>, Pending) {
let mut banked_urls = VecDeque::new();
let mut current_request = Pending::None;
for item in url_or_tags {
match Url::parse(&item) {
Ok(url) => banked_urls.push_back(Ok(url)),
Err(_) => {
match item
.split("+")
.map(pinboard::Tag::from_str)
.collect::<pinboard::Result<Vec<pinboard::Tag>>>()
{
Ok(tags) => {
let client_clone = client.clone();
current_request = Pending::AwaitingResponse(
async move { client_clone.all_posts(tags.iter().cloned()).await }
.boxed(),
);
break;
}
Err(err) => banked_urls.push_back(Err(PinboardSnafu.into_error(err))),
}
}
}
}
(banked_urls, current_request)
}
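/// Build a stream over `url_or_tags`, banking any leading literal URLs and
/// kicking off the first tag request, if there is one.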
pub fn new(client: pinboard::Client, mut url_or_tags: I) -> Result<GreedyUrlStream<I>> {
let (banked_urls, current_request) =
GreedyUrlStream::parse_urls_to_tag(&client, &mut url_or_tags);
Ok(GreedyUrlStream {
client,
url_or_tags,
banked_urls,
current_request,
})
}
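/// Pick up where the input iterator left off once the current request has
/// been fully processed.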
fn consume_more_urls_or_tags(&mut self) {
let (banked_urls, current_request) =
GreedyUrlStream::parse_urls_to_tag(&self.client, &mut self.url_or_tags);
self.banked_urls.extend(banked_urls);
self.current_request = current_request;
}
}
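/// Render at most the first 32 bytes of `b` for trace output.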
fn short_string(b: &[u8]) -> std::borrow::Cow<'_, str> {
// Truncating at a fixed byte offset can split a multi-byte UTF-8 sequence,
// so decode lossily rather than unwrapping.
String::from_utf8_lossy(&b[0..min(b.len(), 32)])
}
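/// Parse as many complete entries as possible out of `buf`.
///
/// `buf` holds JSON array content with the opening `[` already stripped;
/// entries are separated by ",\n". Returns the unparsed tail, the parsed
/// URLs, and the count of separator bytes that fell past the end of this
/// chunk and must be skipped at the start of the next.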
fn parse_urls_from_chunk(buf: &mut Vec<u8>) -> Result<(Vec<u8>, VecDeque<Url>, usize)> {
trace!(
"Starting to parse URLs: {}:{}...",
buf.len(),
short_string(buf)
);
let mut urls = VecDeque::new();
let mut bytes_to_consume = 0;
loop {
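// A lone ']' is the tail of the JSON array: the stream is complete.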
if buf.len() == 1 && buf[0] == b']' {
break;
}
let mut deser = Deserializer::from_slice(buf).into_iter::<Entry>();
match deser.next() {
None => {
break;
}
Some(result) => match result {
Ok(entry) => {
let url = Url::parse(&entry.href).map_err(|err| UrlSnafu.into_error(err))?;
trace!("Parsed URL: {}", url.as_str());
urls.push_back(url);
let offset = deser.byte_offset();
drop(deser);
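// Entries are separated by ",\n". Normally both separator bytes follow the
// entry in this chunk; if the chunk ends at (or one byte past) the entry,
// carry the outstanding separator bytes over to the next chunk.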
let mut split_at = offset + 2;
if offset == buf.len() {
bytes_to_consume = 2;
split_at = offset;
} else if offset == buf.len() - 1 {
bytes_to_consume = 1;
split_at = offset + 1;
}
*buf = buf.split_off(split_at);
}
Err(err) => {
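// A parse error here is usually just an entry split across chunk
// boundaries; leave the partial data in the buffer and wait for more.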
warn!("While deserializing URLs: {:#?}", err);
break;
}
},
}
}
trace!("The buffer is now: {}:{}...", buf.len(), short_string(buf));
Ok((buf.to_vec(), urls, bytes_to_consume))
}
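/// Yield a banked URL, if any.
///
/// Returning `Poll::Pending` here is sound only because callers poll the
/// in-flight future first, which registers the waker.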
fn poll_for_banked(banked_urls: &mut VecDeque<Result<Url>>) -> Poll<Option<Result<Url>>> {
match banked_urls.pop_front() {
Some(url) => Poll::Ready(Some(url)),
None => Poll::Pending,
}
}
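/// Transition out of `AwaitingResponse`: share the response behind an
/// `Arc<Mutex>` and start the future for the first body chunk.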
fn handle_awaiting_response(res: pinboard::Result<reqwest::Response>) -> Result<Pending> {
let rsp = res.map_err(|err| PinboardSnafu.into_error(err))?;
let rsp = Arc::new(Mutex::new(rsp));
let rsp_clone = rsp.clone();
Ok(Pending::ProcessingResponseHeaders((
rsp,
async move { rsp_clone.lock().await.chunk().await }.boxed(),
)))
}
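/// Handle the first body chunk: strip the opening `[`, parse whatever
/// complete entries it contains, and queue the next chunk.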
fn handle_processing_response_headers(
rsp: Arc<Mutex<reqwest::Response>>,
mut chunk: bytes::Bytes,
) -> Result<(Pending, VecDeque<Url>)> {
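// Skip the opening '[' of the JSON array; `split_off(1)` returns
// everything after it.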
let mut new_chunk: Vec<u8> = chunk.split_off(1).to_vec();
let (buf, urls, bytes_to_consume) = parse_urls_from_chunk(&mut new_chunk)?;
let rsp_clone1 = rsp.clone();
let rsp_clone2 = rsp.clone();
Ok((
Pending::ProcessingResponseBody((
rsp_clone1,
async move { rsp_clone2.lock().await.chunk().await }.boxed(),
buf,
bytes_to_consume,
)),
urls,
))
}
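/// Handle a subsequent body chunk: skip separator bytes carried over from
/// the previous chunk, parse whatever is now complete, and queue the next
/// chunk.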
fn handle_processing_response_body(
rsp: Arc<Mutex<reqwest::Response>>,
bytes: bytes::Bytes,
buf: &mut Vec<u8>,
mut bytes_to_consume: usize,
) -> Result<(Pending, VecDeque<Url>)> {
buf.extend_from_slice(&bytes);
let bytes_to_consume_this_time = min(bytes_to_consume, buf.len());
if bytes_to_consume_this_time != 0 {
*buf = buf.split_off(bytes_to_consume_this_time);
bytes_to_consume -= bytes_to_consume_this_time;
}
let (buf, urls, bytes_to_consume_next_time) = parse_urls_from_chunk(buf)?;
bytes_to_consume += bytes_to_consume_next_time;
let rsp_clone1 = rsp.clone();
let rsp_clone2 = rsp.clone();
Ok((
Pending::ProcessingResponseBody((
rsp_clone1,
async move { rsp_clone2.lock().await.chunk().await }.boxed(),
buf,
bytes_to_consume,
)),
urls,
))
}
impl<I> futures::stream::Stream for GreedyUrlStream<I>
where
I: Iterator<Item = String>,
{
type Item = Result<Url>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
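// Advance the state machine until we can yield a URL, a future reports
// `Poll::Pending`, or the input is exhausted.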
match &mut self.current_request {
Pending::None => {
return Poll::Ready(self.banked_urls.pop_front());
}
Pending::AwaitingResponse(box_fut) => {
match Pin::new(box_fut).poll(cx) {
Poll::Pending => {
return poll_for_banked(&mut self.banked_urls);
}
Poll::Ready(res) => {
match handle_awaiting_response(res) {
Ok(curr_req) => {
self.current_request = curr_req;
}
Err(err) => {
return Poll::Ready(Some(Err(err)));
}
}
}
}
}
Pending::ProcessingResponseHeaders((rsp, box_fut)) => {
debug!("ProcessingResponseHeaders");
match Pin::new(box_fut).poll(cx) {
Poll::Pending => {
return poll_for_banked(&mut self.banked_urls);
}
Poll::Ready(res) => {
let opt_chunk = match res {
Ok(opt_chunk) => opt_chunk,
Err(err) => {
error!("While processing response headers: {:?}", err);
self.consume_more_urls_or_tags();
return Poll::Ready(Some(Err(HttpSnafu.into_error(err))));
}
};
match opt_chunk {
None => {
warn!("No posts for this tag");
self.consume_more_urls_or_tags();
}
Some(chunk) => {
let (req, urls) = match handle_processing_response_headers(
rsp.clone(),
chunk,
) {
Ok((req, urls)) => (req, urls),
Err(err) => {
error!("While parsing the response body: {:?}", err);
self.consume_more_urls_or_tags();
return Poll::Ready(Some(Err(err)));
}
};
self.current_request = req;
self.banked_urls.extend(urls.into_iter().map(Ok));
}
}
}
}
}
Pending::ProcessingResponseBody((rsp, box_fut, buf, bytes_to_consume)) => {
debug!("ProcessingResponseBody");
match Pin::new(box_fut).poll(cx) {
Poll::Pending => {
return poll_for_banked(&mut self.banked_urls);
}
Poll::Ready(res) => {
let res = match res {
Ok(res) => res,
Err(err) => {
error!("While processing response body: {:?}", err);
self.consume_more_urls_or_tags();
return Poll::Ready(Some(Err(HttpSnafu.into_error(err))));
}
};
match res {
Some(bytes) => {
debug!("Got a chunk of {} bytes.", bytes.len());
let (req, urls) = match handle_processing_response_body(
rsp.clone(),
bytes,
buf,
*bytes_to_consume,
) {
Ok((req, urls)) => (req, urls),
Err(err) => {
error!("While parsing the response body: {:?}", err);
self.consume_more_urls_or_tags();
return Poll::Ready(Some(Err(err)));
}
};
self.current_request = req;
self.banked_urls.extend(urls.into_iter().map(Ok));
}
None => {
debug!("Finished parsing response body.");
self.consume_more_urls_or_tags();
}
}
}
}
}
}
}
}
}
#[cfg(test)]
mod test {
use super::*;
use itertools::intersperse;
use test_log::test;
use tracing::{error, info};
use std::collections::HashMap;
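/// In-process HTTP stub that mimics the two Pinboard endpoints exercised
/// here: `GET /v1/posts/all` and `GET /v1/posts/delete`.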
struct TestServer {
addr: Url,
tags: HashMap<String, Vec<String>>,
deleted_urls: Vec<String>,
}
impl TestServer {
pub async fn new<'a, T, U>(table: T) -> Arc<Mutex<TestServer>>
where
T: IntoIterator<Item = &'a (&'static str, U)>,
U: IntoIterator<Item = &'a &'static str>,
U: 'a + Copy,
{
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let mut tags = HashMap::new();
for (tag, urls) in table {
tags.insert(
String::from(*tag),
urls.into_iter()
.map(|x: &&str| String::from(*x))
.collect::<Vec<String>>(),
);
}
let base_url =
Url::parse(&format!("http://{}", listener.local_addr().unwrap())).unwrap();
let server = Arc::new(Mutex::new(TestServer {
addr: base_url.clone(),
tags,
deleted_urls: Vec::new(),
}));
let server_clone = server.clone();
tokio::spawn(async move {
loop {
let (mut stream, _) = listener.accept().await.unwrap();
let inner_server = server_clone.clone();
let inner_url = base_url.clone();
tokio::spawn(async move {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut incoming = vec![];
loop {
let mut buf = vec![0u8; 1024];
let read = stream.read(&mut buf).await.unwrap();
incoming.extend_from_slice(&buf[..read]);
if incoming.len() > 4 && &incoming[incoming.len() - 4..] == b"\r\n\r\n"
{
break;
}
}
let request = std::str::from_utf8(&incoming).expect("Non-UTF8 request");
debug!("TestServer got a request: {}", request);
if request.starts_with("GET /v1/posts/all") {
let idx = request.find("\r\n").unwrap();
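// Trim the "GET " prefix and " HTTP/1.1" suffix to recover the
// request target.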
let full_url = inner_url.join(&request[4..idx - 9]).unwrap();
let mut tag = None;
for pair in full_url.query_pairs() {
if pair.0 == "tag" {
info!("listing all URLs with tag {}", pair.1);
tag = Some(pair.1);
}
}
stream.write_all(b"HTTP/1.1 200 OK\r\n\r\n[").await.unwrap();
let tag = tag.unwrap().into_owned();
let body = intersperse(
inner_server.lock().await.tags.get(&tag).unwrap().iter().map(|url| {
format!("{{\"href\":\"{}\",\"description\":\"some description\",\"extended\":\"\",\"meta\":\"7f1cc4538d3047b90452e5792ce650df\",\"hash\":\"0af2a7ee48a8b92d8ca4bc4340739097\",\"time\":\"2021-11-17T20:20:08Z\",\"shared\":\"no\",\"toread\":\"no\",\"tags\":\"{} 2022\"}}", url, tag)
}),
String::from(",\n"),
)
.collect::<String>();
stream.write_all(body.as_bytes()).await.unwrap();
stream.write_all(b"]").await.unwrap();
} else if request.starts_with("GET /v1/posts/delete") {
let idx = request.find("\r\n").unwrap();
let full_url = inner_url.join(&request[4..idx - 9]).unwrap();
for pair in full_url.query_pairs() {
if pair.0 == "url" {
info!("deleting Pinboard post {}", pair.1);
inner_server.lock().await.deleted_urls.push(pair.1.into());
}
}
stream
.write_all(b"HTTP/1.1 200 OK\r\n\r\n{\"result_code\":\"done\"}")
.await
.unwrap();
} else {
error!("TestServer 404!");
stream
.write_all(b"HTTP/1.1 404 Not Found\r\n")
.await
.unwrap();
}
});
}
});
server
}
pub fn server_url(&self) -> Url {
self.addr.clone()
}
pub fn deleted_urls(&self) -> Vec<String> {
self.deleted_urls.clone()
}
}
#[test(tokio::test)]
async fn smoke() {
let server = TestServer::new(&[
("foo", &["http://foo.com", "https://fooish.com"]),
("bar", &["http://bar.com", "https://barbinator.com"]),
])
.await;
let client;
{
let guard = server.lock().await;
client =
pinboard::Client::new(guard.server_url(), "sp1ff:FFFFFFFFFFFFFFFFFFFF").unwrap();
}
let mut my_stream = GreedyUrlStream::new(
client.clone(),
vec!["http://www.unwoundstack.com".to_string(), "foo".to_string()].into_iter(),
)
.unwrap();
use futures::stream::StreamExt;
while let Some(url) = my_stream.next().await {
info!("My stream yielded {:?}", url);
client.delete_post(url.unwrap()).await.unwrap();
}
assert_eq!(
server.lock().await.deleted_urls(),
vec![
String::from("http://www.unwoundstack.com/"),
String::from("http://foo.com/"),
String::from("https://fooish.com/")
]
);
}
}
#[cfg(test)]
#[cfg(feature = "personal-link-tests")]
mod link_tests {
use super::*;
use tracing::{debug, error};
use std::path::Path;
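/// Serve the JSON file at `pth` as the body of `GET /v1/posts/all`, using
/// chunked transfer encoding whenever the body exceeds `chunk_size`.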
pub async fn test_server<P: AsRef<Path>>(pth: P, mut chunk_size: usize) -> Url {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let base_url = Url::parse(&format!("http://{}", listener.local_addr().unwrap())).unwrap();
let body = std::fs::read_to_string(pth).unwrap();
tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut incoming = vec![];
loop {
let mut buf = vec![0u8; 1024];
let read = stream.read(&mut buf).await.unwrap();
incoming.extend_from_slice(&buf[..read]);
if incoming.len() > 4 && &incoming[incoming.len() - 4..] == b"\r\n\r\n" {
break;
}
}
let request = std::str::from_utf8(&incoming).expect("Non-UTF8 request");
debug!("TestServer got a request: {}", request);
if request.starts_with("GET /v1/posts/all") {
stream
.write_all(b"HTTP/1.1 200 OK\r\nContent-Type: text/json; charset=utf8\r\n")
.await
.unwrap();
let nbytes = body.len();
if nbytes < chunk_size {
stream
.write_all(format!("Content-Length: {}\r\n\r\n{}", nbytes, body).as_bytes())
.await
.unwrap();
debug!(
"Response body was less than chunk_size ({}); wrote in one shot & exiting.",
chunk_size
);
} else {
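// Body exceeds chunk_size: stream it with HTTP/1.1 chunked transfer
// encoding, shrinking the final chunk to whatever remains.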
stream
.write_all(b"Transfer-Encoding: chunked\r\n\r\n")
.await
.unwrap();
let mut nwritten = 0;
let mut bytes: &[u8] = body.as_bytes();
while nwritten < nbytes {
stream
.write_all(format!("{:X}\r\n", chunk_size).as_bytes())
.await
.unwrap();
stream.write_all(&bytes[0..chunk_size]).await.unwrap();
stream.write_all(b"\r\n").await.unwrap();
nwritten += chunk_size;
let (_, remaining) = bytes.split_at(chunk_size);
bytes = remaining;
if chunk_size > nbytes - nwritten {
chunk_size = nbytes - nwritten;
}
}
stream.write_all(b"0\r\n\r\n").await.unwrap();
debug!("All {} bytes written-- exiting.", nbytes);
}
} else {
error!("TestServer 404!");
stream
.write_all(b"HTTP/1.1 404 Not Found\r\n")
.await
.unwrap();
}
});
base_url
}
#[tokio::test]
async fn linkedin_and_jira() {
let url = test_server(&Path::new("linkedin-and-jira.json"), 4096).await;
let client = pinboard::Client::new(url, "sp1ff:FFFFFFFFFFFFFFFFFFFF").unwrap();
let stream =
GreedyUrlStream::new(client, vec!["linkedin+jira".to_string()].into_iter()).unwrap();
use futures::StreamExt;
assert_eq!(stream.count().await, 4);
}
#[tokio::test]
async fn linkedin() {
let url = test_server(&Path::new("linkedin.json"), 4096).await;
let client = pinboard::Client::new(url, "sp1ff:FFFFFFFFFFFFFFFFFFFF").unwrap();
let mut stream =
GreedyUrlStream::new(client, vec!["linkedin".to_string()].into_iter()).unwrap();
use futures::StreamExt;
let mut count = 0;
while let Some(url) = stream.next().await {
debug!("My stream yielded {:?}", url);
count += 1;
}
assert_eq!(count, 331);
}
}