mod scheduler;
mod service;
use std::convert::{TryFrom, TryInto};
use std::fmt::Debug;
use std::marker::{PhantomData, Unpin};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use diesel::dsl::*;
use diesel::prelude::*;
use diesel::r2d2::{ConnectionManager, Pool};
use futures::channel::mpsc;
use futures::{Stream, StreamExt, TryStream};
use http::uri::{PathAndQuery, Uri};
use hyper::server::conn::Http;
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use twitter_client::traits::HttpService;
use crate::feed::Feed;
use crate::manifest;
use crate::query;
use crate::schema::*;
use crate::util::ArcService;
use self::scheduler::Scheduler;
use self::service::Service;
/// A WebSub subscriber that serves the connections arriving on `incoming` and, as a
/// [`Stream`], yields `(String, Feed)` pairs received through an internal channel.
///
/// In this file's tests the `String` of each yielded pair is compared against the topic URL.
#[pin_project]
pub struct Subscriber<S, B, I> {
/// Source of incoming connections; every accepted item is served with `server`.
#[pin]
incoming: I,
/// HTTP connection configuration used to serve each accepted connection.
server: Http,
/// Receiving half of the channel whose sending half is handed to `service` on construction.
rx: mpsc::Receiver<(String, Feed)>,
/// Request handler state, shared via `Arc` with every connection being served.
service: Arc<Service<S, B>>,
}
impl<S, B, I> Subscriber<S, B, I>
where
    S: HttpService<B> + Clone + Send + Sync + 'static,
    S::Error: Debug,
    S::Future: Send,
    S::ResponseBody: Send,
    B: Default + From<Vec<u8>> + Send + 'static,
    I: TryStream,
    I::Ok: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Creates a subscriber and spawns its renewal scheduler with `tokio::spawn`.
    ///
    /// * `manifest` - supplies the callback URI prefix and the renewal margin.
    /// * `incoming` - stream of connections on which callback requests arrive.
    /// * `client` - HTTP client handed over to the underlying service.
    /// * `pool` - SQLite connection pool; it is queried here for the first expiry time and
    ///   later by the scheduler callback.
    ///
    /// # Panics
    ///
    /// Panics if a connection cannot be taken from `pool`, if the expiry query fails, or if
    /// `manifest.callback` cannot be normalized by `prepare_callback_prefix`. Because it calls
    /// `tokio::spawn`, it must be called from within a Tokio runtime.
    pub fn new(
        manifest: &manifest::WebSub,
        incoming: I,
        client: S,
        pool: Pool<ConnectionManager<SqliteConnection>>,
    ) -> Self {
        let renewal_margin = manifest.renewal_margin.as_secs();

        // Time of the first scheduler tick: the first stored expiry minus the renewal margin.
        // A negative stored value maps to `0`. The subtraction saturates so that an expiry
        // smaller than the margin also yields `0` instead of overflowing `u64` (a panic in
        // debug builds, a wrapped far-future value in release builds).
        let first_tick = query::expires_at()
            .first::<i64>(&pool.get().unwrap())
            .optional()
            .unwrap()
            .map(|expires_at| {
                u64::try_from(expires_at)
                    .map_or(0, |expires_at| expires_at.saturating_sub(renewal_margin))
            });

        let callback = prepare_callback_prefix(manifest.callback.clone());

        // Zero-capacity channel through which the service hands `(String, Feed)` pairs to the
        // `Stream` implementation of `Subscriber`.
        let (tx, rx) = mpsc::channel(0);

        let service = Arc::new(service::Service {
            callback,
            renewal_margin,
            client,
            pool,
            tx,
            handle: scheduler::Handle::new(first_tick),
            _marker: PhantomData,
        });

        // Scheduler callback: renew subscriptions, then report a time derived from the first
        // expiry among active subscriptions that are not already being renewed.
        // NOTE(review): presumably the returned value is the next tick time - confirm against
        // `Scheduler`.
        tokio::spawn(Scheduler::new(&service, |service| {
            let conn = &*service.pool.get().unwrap();
            service.renew_subscriptions(conn);
            query::expires_at()
                .filter(not(
                    websub_active_subscriptions::id.eq_any(query::renewing_subs())
                ))
                .first::<i64>(conn)
                .optional()
                .unwrap()
                .map(|expires_at| {
                    expires_at
                        .try_into()
                        .map_or(0, |expires_at| service.refresh_time(expires_at))
                })
        }));

        Subscriber {
            incoming,
            server: Http::new(),
            rx,
            service,
        }
    }

    /// Accepts every connection that is currently ready on `incoming` and spawns a task that
    /// serves it with a clone of the shared service.
    ///
    /// Returns `Poll::Ready(Ok(()))` once `incoming` has ended, `Poll::Pending` when no further
    /// connection is ready, and propagates an error yielded by `incoming`.
    fn accept_all(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), I::Error>> {
        let mut this = self.project();
        while let Poll::Ready(option) = this.incoming.as_mut().try_poll_next(cx)? {
            match option {
                None => return Poll::Ready(Ok(())),
                Some(sock) => {
                    let service = ArcService(this.service.clone());
                    tokio::spawn(this.server.serve_connection(sock, service));
                }
            }
        }
        Poll::Pending
    }
}
impl<S, B, I> Subscriber<S, B, I> {
    /// Borrows the service shared by this subscriber.
    pub fn service(&self) -> &Service<S, B> {
        &*self.service
    }
}
impl<S, B, I> Stream for Subscriber<S, B, I>
where
    S: HttpService<B> + Clone + Send + Sync + 'static,
    S::Error: Debug,
    S::Future: Send,
    S::ResponseBody: Send,
    B: Default + From<Vec<u8>> + Send + 'static,
    I: TryStream,
    I::Ok: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    type Item = Result<(String, Feed), I::Error>;

    /// Accepts pending connections, then yields the next `(String, Feed)` pair from the
    /// internal channel.
    ///
    /// An error from the incoming stream is yielded as `Some(Err(_))`. The stream ends when the
    /// channel itself ends, or when the incoming stream has ended and this subscriber holds the
    /// only strong reference to the service.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        log::trace!("Subscriber::poll_next");

        // `?` turns an accept error into `Poll::Ready(Some(Err(_)))`.
        let incoming_done = self.as_mut().accept_all(cx)?.is_ready();

        if let Poll::Ready(item) = self.as_mut().project().rx.poll_next_unpin(cx) {
            return Poll::Ready(item.map(Ok));
        }

        // Nothing is queued. Finish only if no new connection can arrive and no other strong
        // reference to the service remains.
        if incoming_done && Arc::strong_count(&self.service) == 1 {
            Poll::Ready(None)
        } else {
            Poll::Pending
        }
    }
}
/// Returns `prefix` with a trailing `/` appended to its path-and-query component unless that
/// component already ends with one.
///
/// Note that the check and the append operate on the whole path-and-query string, so a query
/// component, if present, is treated as part of the text.
///
/// # Panics
///
/// Panics if `prefix` has no path-and-query component, or if the rebuilt value is not a valid
/// URI.
fn prepare_callback_prefix(prefix: Uri) -> Uri {
    let current = prefix.path_and_query().unwrap().as_str();
    if current.ends_with('/') {
        return prefix;
    }

    let with_slash = PathAndQuery::try_from(format!("{}/", current)).unwrap();
    let mut parts = prefix.into_parts();
    parts.path_and_query = Some(with_slash);
    Uri::from_parts(parts).unwrap()
}
#[cfg(test)]
mod tests {
use std::convert::Infallible;
use std::str;
use std::time::Duration;
use bytes::Bytes;
use diesel::connection::SimpleConnection;
use diesel::r2d2::ConnectionManager;
use futures::channel::oneshot;
use futures::future;
use hmac::{Hmac, Mac, NewMac};
use http::header::CONTENT_TYPE;
use http::Uri;
use http::{Method, Request, Response, StatusCode};
use http_body::Full;
use hyper::server::conn::Http;
use hyper::{service, Client};
use sha1::Sha1;
use crate::feed;
use crate::util::connection::{Connector, Listener};
use crate::util::consts::{
APPLICATION_ATOM_XML, APPLICATION_WWW_FORM_URLENCODED, HUB_SIGNATURE,
};
use crate::util::{self, EitherUnwrapExt, FutureTimeoutExt};
use super::super::hub;
use super::*;
const TOPIC: &str = "http://example.com/feed.xml";
const HUB: &str = "http://example.com/hub";
const FEED: &str = include_str!("testcases/feed.xml");
const MARGIN: Duration = Duration::from_secs(42);
const DELAY: Duration = Duration::from_millis(1);
/// `prepare_callback_prefix` leaves a trailing slash alone and adds one when it is missing.
#[test]
fn callback_prefix() {
    // Pairs of (input, expected result).
    let cases = [
        ("http://example.com/", "http://example.com/"),
        ("http://example.com/websub/", "http://example.com/websub/"),
        ("http://example.com", "http://example.com/"),
        ("http://example.com/websub", "http://example.com/websub/"),
    ];
    for &(input, expected) in cases.iter() {
        let uri = prepare_callback_prefix(input.try_into().unwrap());
        assert_eq!(uri, expected);
    }
}
/// Two stored subscriptions whose expiries are `MARGIN` seconds apart are renewed one after
/// the other, and each old callback is unsubscribed and verified after its renewal.
#[tokio::test]
async fn renew_multi_subs() {
    tokio::time::pause();

    let begin = i64::try_from(util::now_unix().as_secs()).unwrap();

    // Single-connection in-memory database, seeded before the subscriber is created.
    let pool = util::r2d2::pool_with_builder(
        Pool::builder().max_size(1),
        ConnectionManager::new(":memory:"),
    )
    .unwrap();
    let conn = pool.get().unwrap();
    crate::migrations::run(&*conn).unwrap();

    conn.batch_execute(
        "INSERT INTO websub_subscriptions (hub, topic, secret)
VALUES
('http://example.com/hub', 'http://example.com/topic/1', 'secret1'),
('http://example.com/hub', 'http://example.com/topic/2', 'secret2');",
    )
    .unwrap();

    // The first subscription reaches its renewal time one second after `begin`,
    // the second one `MARGIN` seconds later.
    let expiry1 = begin + MARGIN.as_secs() as i64 + 1;
    let expiry2 = expiry1 + MARGIN.as_secs() as i64;
    let values = [
        websub_active_subscriptions::expires_at.eq(expiry1),
        websub_active_subscriptions::expires_at.eq(expiry2),
    ];
    insert_into(websub_active_subscriptions::table)
        .values(&values[..])
        .execute(&*conn)
        .unwrap();
    // The pool holds a single connection, so release it before the subscriber needs it.
    drop(conn);

    let (mut subscriber, client, listener) = prepare_subscriber_with_pool(pool);
    let mut listener = tokio_test::task::spawn(listener);
    let hub = Http::new();

    // Nothing is due yet.
    listener.enter(|cx, listener| assert!(listener.poll_next(cx).is_pending()));

    // Pass the renewal time of the first subscription.
    tokio::time::advance(Duration::from_secs(2)).await;

    // Renewal request for topic 1.
    let form = accept_request(&mut listener, &hub, "/hub").timeout().await;
    let callback1 = match form {
        hub::Form::Subscribe {
            callback, topic, ..
        } => {
            assert_eq!(topic, "http://example.com/topic/1");
            callback
        }
        _ => panic!(),
    };
    listener.enter(|cx, listener| assert!(listener.poll_next(cx).is_pending()));

    // The hub verifies the renewed subscription of topic 1.
    let task = verify_intent(
        &client,
        &callback1,
        &hub::Verify::Subscribe {
            topic: "http://example.com/topic/1",
            challenge: "subscription_challenge1",
            lease_seconds: 2 * MARGIN.as_secs() + 2,
        },
    );
    util::first(task.timeout(), subscriber.next())
        .await
        .unwrap_left();

    // The old subscription of topic 1 (row id 1) is unsubscribed through its old callback.
    let id = super::super::encode_callback_id(&1_i64.to_le_bytes()).to_string();
    let old_callback1 = Uri::try_from(format!("http://example.com/{}", id)).unwrap();
    let form = accept_request(&mut listener, &hub, "/hub").timeout().await;
    match form {
        hub::Form::Unsubscribe {
            callback, topic, ..
        } => {
            assert_eq!(callback, old_callback1);
            assert_eq!(topic, "http://example.com/topic/1");
        }
        _ => panic!(),
    }
    listener.enter(|cx, listener| assert!(listener.poll_next(cx).is_pending()));

    let task = verify_intent(
        &client,
        &old_callback1,
        &hub::Verify::Unsubscribe {
            topic: "http://example.com/topic/1",
            challenge: "unsubscription_challenge1",
        },
    );
    util::first(task.timeout(), subscriber.next())
        .await
        .unwrap_left();

    // Pass the renewal time of the second subscription.
    tokio::time::advance(MARGIN).await;

    // Renewal request for topic 2.
    let form = accept_request(&mut listener, &hub, "/hub").timeout().await;
    let callback2 = match form {
        hub::Form::Subscribe {
            callback, topic, ..
        } => {
            assert_eq!(topic, "http://example.com/topic/2");
            callback
        }
        _ => panic!(),
    };
    listener.enter(|cx, listener| assert!(listener.poll_next(cx).is_pending()));

    let task = verify_intent(
        &client,
        &callback2,
        &hub::Verify::Subscribe {
            topic: "http://example.com/topic/2",
            challenge: "subscription_challenge2",
            lease_seconds: MARGIN.as_secs() + 1,
        },
    );
    util::first(task.timeout(), subscriber.next())
        .await
        .unwrap_left();

    // The old subscription of topic 2 (row id 2) is unsubscribed through its old callback.
    let id = super::super::encode_callback_id(&2_i64.to_le_bytes()).to_string();
    let old_callback2 = Uri::try_from(format!("http://example.com/{}", id)).unwrap();
    let form = accept_request(&mut listener, &hub, "/hub").timeout().await;
    match form {
        hub::Form::Unsubscribe {
            callback, topic, ..
        } => {
            assert_eq!(callback, old_callback2);
            assert_eq!(topic, "http://example.com/topic/2");
        }
        _ => panic!(),
    }
    listener.enter(|cx, listener| assert!(listener.poll_next(cx).is_pending()));

    // Verify the unsubscription at the callback named in the request above, i.e. the old
    // callback of topic 2 (not the one of topic 1).
    let task = verify_intent(
        &client,
        &old_callback2,
        &hub::Verify::Unsubscribe {
            topic: "http://example.com/topic/2",
            challenge: "unsubscription_challenge2",
        },
    );
    util::first(task.timeout(), subscriber.next())
        .await
        .unwrap_left();
}
/// Walks through discovery, subscription, content delivery, renewal and unsubscription of a
/// single topic against an in-process fake hub, then checks that the old rows are gone.
#[tokio::test]
async fn entire_flow() {
// Use Tokio's paused test clock; the test drives it with `tokio::time::advance`.
tokio::time::pause();
let (subscriber, client, listener) = prepare_subscriber();
let mut subscriber = tokio_test::task::spawn(subscriber);
let mut listener = tokio_test::task::spawn(listener);
let hub = Http::new();
// No outgoing connection is expected before anything has been requested.
listener.enter(|cx, listener| assert!(listener.poll_next(cx).is_pending()));
// Discovery: the service requests the topic and the fake server answers with the feed.
let task = tokio::spawn(subscriber.service().discover(TOPIC.to_owned()));
tokio::time::advance(DELAY).await;
let sock = listener.next().timeout().await.unwrap().unwrap();
hub.serve_connection(
sock,
service::service_fn(|req| async move {
assert_eq!(req.method(), Method::GET);
assert_eq!(req.uri().path_and_query().unwrap(), "/feed.xml");
let res = Response::builder()
.header(CONTENT_TYPE, APPLICATION_ATOM_XML)
.body(Full::new(Bytes::from_static(FEED.as_bytes())))
.unwrap();
tokio::time::advance(DELAY).await;
Ok::<_, Infallible>(res)
}),
)
.timeout()
.await
.unwrap();
listener.enter(|cx, listener| assert!(listener.poll_next(cx).is_pending()));
// The discovered feed is yielded from the subscriber stream together with the topic.
assert_eq!(
subscriber.next().await.unwrap().unwrap(),
(
TOPIC.to_owned(),
Feed::parse(feed::MediaType::Xml, FEED.as_bytes()).unwrap()
)
);
// The discovery task reports the topic and the hub found in the feed.
let (topic, hubs) = task.timeout().await.unwrap().unwrap();
assert_eq!(topic, TOPIC);
let hubs = hubs.unwrap().collect::<Vec<_>>();
assert_eq!(hubs, [HUB]);
// Subscription request: the fake hub accepts it at `/hub` and returns the submitted form.
let req_task = subscriber.service().subscribe(
HUB.to_owned(),
topic,
&subscriber.service().pool.get().unwrap(),
);
tokio::time::advance(DELAY).await;
let accept_task = accept_request(&mut listener, &hub, "/hub");
let (result, form) = future::join(req_task, accept_task).timeout().await;
result.unwrap();
let (callback, topic, secret) = match form {
hub::Form::Subscribe {
callback,
topic,
secret,
} => (callback, topic, secret),
_ => panic!(),
};
assert_eq!(topic, TOPIC);
listener.enter(|cx, listener| assert!(listener.poll_next(cx).is_pending()));
// The hub verifies the subscription with a lease of twice the renewal margin.
let task = verify_intent(
&client,
&callback,
&hub::Verify::Subscribe {
topic: TOPIC,
challenge: "subscription_challenge",
lease_seconds: (2 * MARGIN).as_secs(),
},
);
// `unwrap_left` requires the verification task, not the subscriber stream, to be the side
// that produced the result. NOTE(review): presumably `subscriber.next()` is polled here
// only to drive the subscriber while it serves the request - confirm against `util::first`.
util::first(task.timeout(), subscriber.next())
.await
.unwrap_left();
// Content distribution: POST the feed to the callback, signed with HMAC-SHA1 of the secret.
let mut mac = Hmac::<Sha1>::new_from_slice(secret.as_bytes()).unwrap();
mac.update(FEED.as_bytes());
let signature = mac.finalize().into_bytes();
let req = Request::post(&callback)
.header(CONTENT_TYPE, APPLICATION_ATOM_XML)
.header(HUB_SIGNATURE, format!("sha1={}", hex::encode(&*signature)))
.body(Full::new(Bytes::from_static(FEED.as_bytes())))
.unwrap();
let task = client.request(req);
let (res, update) = future::join(task, subscriber.next()).timeout().await;
let res = res.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let (topic, feed) = update.unwrap().unwrap();
assert_eq!(topic, TOPIC);
assert_eq!(
feed,
Feed::parse(feed::MediaType::Xml, FEED.as_bytes()).unwrap()
);
// Renewal: after `MARGIN` has elapsed a new subscription request with a different callback
// is expected.
tokio::time::advance(MARGIN).await;
tokio::time::advance(DELAY).await;
let task = accept_request(&mut listener, &hub, "/hub").timeout();
let form = util::first(task, subscriber.next()).await.unwrap_left();
let (new_callback, topic) = match form {
hub::Form::Subscribe {
callback,
topic,
secret: _,
} => (callback, topic),
_ => panic!(),
};
assert_ne!(new_callback, callback);
assert_eq!(topic, TOPIC);
listener.enter(|cx, listener| assert!(listener.poll_next(cx).is_pending()));
// The hub verifies the renewed subscription.
let task = verify_intent(
&client,
&new_callback,
&hub::Verify::Subscribe {
topic: TOPIC,
challenge: "renewal_challenge",
lease_seconds: (2 * MARGIN).as_secs(),
},
);
util::first(task.timeout(), subscriber.next())
.await
.unwrap_left();
// After the renewal is verified, the old callback is unsubscribed.
tokio::time::advance(DELAY).await;
let task = accept_request(&mut listener, &hub, "/hub").timeout();
let form = util::first(task, subscriber.next()).await.unwrap_left();
let (unsubscribed, topic) = match form {
hub::Form::Unsubscribe { callback, topic } => (callback, topic),
_ => panic!(),
};
assert_eq!(unsubscribed, callback);
assert_eq!(topic, TOPIC);
listener.enter(|cx, listener| assert!(listener.poll_next(cx).is_pending()));
// The hub verifies the unsubscription of the old callback.
let task = verify_intent(
&client,
&callback,
&hub::Verify::Unsubscribe {
topic: TOPIC,
challenge: "unsubscription_challenge",
},
);
util::first(task.timeout(), subscriber.next())
.await
.unwrap_left();
// Nothing further is yielded by the subscriber stream.
subscriber.enter(|cx, subscriber| assert!(subscriber.poll_next(cx).is_pending()));
// The id encoded in the old callback no longer appears in any of the three tables.
let id = callback.path().rsplit('/').next().unwrap();
let id = crate::websub::decode_callback_id(id).unwrap() as i64;
let conn = subscriber.service().pool.get().unwrap();
let row = websub_subscriptions::table.find(id);
assert!(!select(exists(row)).get_result::<bool>(&conn).unwrap());
let row = websub_active_subscriptions::table.find(id);
assert!(!select(exists(row)).get_result::<bool>(&conn).unwrap());
let row =
websub_renewing_subscriptions::table.filter(websub_renewing_subscriptions::old.eq(id));
assert!(!select(exists(row)).get_result::<bool>(&conn).unwrap());
}
/// Builds a subscriber backed by a freshly migrated in-memory SQLite database.
///
/// Returns the same tuple as `prepare_subscriber_with_pool`.
fn prepare_subscriber() -> (
    Subscriber<Client<Connector, Full<Bytes>>, Full<Bytes>, Listener>,
    Client<Connector, Full<Bytes>>,
    Listener,
) {
    let builder = Pool::builder().max_size(1);
    let manager = ConnectionManager::new(":memory:");
    let pool = util::r2d2::pool_with_builder(builder, manager).unwrap();

    {
        // The pool holds a single connection; it must be returned before the subscriber is
        // constructed, hence the inner scope.
        let conn = pool.get().unwrap();
        crate::migrations::run(&*conn).unwrap();
    }

    prepare_subscriber_with_pool(pool)
}
fn prepare_subscriber_with_pool(
pool: Pool<ConnectionManager<SqliteConnection>>,
) -> (
Subscriber<Client<Connector, Full<Bytes>>, Full<Bytes>, Listener>,
Client<Connector, Full<Bytes>>,
Listener,
) {
let (hub_conn, sub_listener) = util::connection();
let (sub_conn, hub_listener) = util::connection();
let sub_client = Client::builder()
.pool_idle_timeout(Duration::from_secs(0))
.pool_max_idle_per_host(0)
.build::<_, Full<Bytes>>(sub_conn);
let hub_client = Client::builder()
.pool_idle_timeout(Duration::from_secs(0))
.pool_max_idle_per_host(0)
.build::<_, Full<Bytes>>(hub_conn);
let manifest = manifest::WebSub {
callback: Uri::from_static("http://example.com/"),
bind: None,
renewal_margin: MARGIN,
};
let subscriber = Subscriber::new(&manifest, sub_listener, sub_client, pool);
(subscriber, hub_client, hub_listener)
}
/// Accepts one connection from `listener`, serves it with `http`, and returns the
/// URL-encoded form that was posted to `path`.
///
/// The request must target `path` and carry the form content type; the response is
/// `202 Accepted` with an empty body.
async fn accept_request<'a>(
    listener: &'a mut Listener,
    http: &'a Http,
    path: &'static str,
) -> hub::Form {
    let (form_tx, form_rx) = oneshot::channel();
    // The handler closure may be invoked through `&mut`, so the one-shot sender is kept in an
    // `Option` and taken out by the (single expected) request.
    let mut form_tx = Some(form_tx);

    let sock = listener.next().await.unwrap().unwrap();
    http.serve_connection(
        sock,
        service::service_fn(move |req| {
            let form_tx = form_tx.take().unwrap();
            #[allow(clippy::borrow_interior_mutable_const)]
            async move {
                assert_eq!(req.uri().path_and_query().unwrap(), path);
                assert_eq!(
                    req.headers().get(CONTENT_TYPE).unwrap(),
                    APPLICATION_WWW_FORM_URLENCODED
                );

                let payload = hyper::body::to_bytes(req).await.unwrap();
                let parsed: hub::Form = serde_urlencoded::from_bytes(&payload).unwrap();
                form_tx.send(parsed).unwrap();

                let response = Response::builder()
                    .status(StatusCode::ACCEPTED)
                    .body(http_body::Empty::<Bytes>::new())
                    .unwrap();
                tokio::time::advance(DELAY).await;
                Ok::<_, Infallible>(response)
            }
        }),
    )
    .await
    .unwrap();

    form_rx.await.unwrap()
}
/// Sends a verification `GET` for `query` to `callback` and returns a future asserting that
/// the response is `200 OK` with the challenge echoed back as the body.
///
/// The request is created before this function returns; the returned future only awaits it.
fn verify_intent<'a>(
    client: &Client<Connector, Full<Bytes>>,
    callback: &Uri,
    query: &hub::Verify<&'a str>,
) -> impl std::future::Future<Output = ()> + 'a {
    let expected = match *query {
        hub::Verify::Subscribe { challenge, .. } => challenge,
        hub::Verify::Unsubscribe { challenge, .. } => challenge,
    };

    let encoded = serde_urlencoded::to_string(query).unwrap();
    let target: Uri = format!("{}?{}", callback, encoded).try_into().unwrap();
    let pending = client.get(target);

    async move {
        let response = pending.await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let echoed = hyper::body::to_bytes(response).timeout().await.unwrap();
        assert_eq!(echoed, expected.as_bytes());
    }
}
}