use futures::Stream;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use url::Url;
use uuid::Uuid;
use async_channel::{unbounded, Receiver, Sender, TryRecvError};
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response};
use thiserror::Error;
/// Query-string key that carries a webhook's UUID in every URL handed out by `Server::spawn`.
const QUERY_PARAM_NAME: &str = "_async_webhook_id";
/// Errors produced while routing an incoming HTTP request to a webhook.
#[derive(Debug, Error)]
enum Error {
    /// The request carried no `_async_webhook_id` query parameter, or its value
    /// was not a parseable UUID — both cases are reported with this variant.
    #[error("Webhook receiver expected a unique UID but none was found")]
    WebhookMissingUid,
}
/// Per-webhook entry: the waker slot its stream fills in when polled, and the
/// sender that delivers the request routed to it.
type Slot = (Arc<Mutex<Option<Waker>>>, Sender<Request<Body>>);

/// Shared table of live webhooks, keyed by the UID embedded in each webhook URL.
///
/// Cloning is cheap: every clone shares the same table (`Arc`) and the same
/// cleanup channel.
#[derive(Clone)]
struct Registry {
    /// Carries the UIDs of dropped webhooks to the cleanup task spawned in
    /// `Registry::new`, which removes them from `requests`.
    drop: Sender<Uuid>,
    /// Live webhooks by UID.
    requests: Arc<Mutex<HashMap<Uuid, Slot>>>,
}
impl Registry {
    /// Creates an empty registry and spawns the task that forgets dropped webhooks.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime (it uses `tokio::task::spawn`).
    fn new() -> Self {
        let (drop, handle_drop) = unbounded::<Uuid>();
        let requests = Arc::new(Mutex::new(HashMap::new()));
        let cleanup_requests = Arc::clone(&requests);
        tokio::task::spawn(async move {
            // `recv` fails once every sender is gone and the queue is drained.
            while let Ok(id) = handle_drop.recv().await {
                cleanup_requests.lock().unwrap().remove(&id);
            }
        });
        Self { drop, requests }
    }

    /// Registers `uid` and returns its waker slot plus the receiver its request
    /// will be delivered on.
    fn register(&self, uid: Uuid) -> (Arc<Mutex<Option<Waker>>>, Receiver<Request<Body>>) {
        let (sender, receiver) = unbounded();
        let waker = Arc::new(Mutex::new(None));
        self.requests
            .lock()
            .unwrap()
            .insert(uid, (Arc::clone(&waker), sender));
        (waker, receiver)
    }

    /// Delivers `t` to the webhook registered under `uid`, if any, and wakes its task.
    ///
    /// The entry is removed, so each webhook receives at most one request. Its
    /// sender is dropped afterwards, which closes the channel. Requests for
    /// unknown UIDs are ignored.
    fn notify(&self, uid: Uuid, t: Request<Body>) {
        // Take the entry out under the map lock, then send and wake without holding it.
        let entry = self.requests.lock().unwrap().remove(&uid);
        if let Some((waker, sender)) = entry {
            // Dropping a `Webhook` only *queues* its UID for removal by the cleanup
            // task, so the receiver may already be gone here. `try_send` then fails
            // with `Closed`. This is not a bug, so discard the request instead of
            // panicking. (An unbounded channel can never be `Full`.)
            let _ = sender.try_send(t);
            // Wake only after sending, so the woken task's `try_recv` sees the request.
            if let Some(task_waker) = waker.lock().unwrap().as_ref() {
                task_waker.wake_by_ref();
            }
        }
    }
}
/// State owned by a [`Webhook`]: its public URL, its registry key, and the
/// receiving half of the channel its request arrives on.
struct WebhookInner {
    /// URL handed out to callers; requests to it are routed to this webhook.
    url: Url,
    /// Key under which this webhook is registered in the `Registry`.
    uid: Uuid,
    /// Receives the request routed to this webhook.
    value: Receiver<Request<Body>>,
    /// Used on drop to ask the registry's cleanup task to forget `uid`.
    dropper: Sender<Uuid>,
    /// Shared with the registry entry; holds the waker of the task polling this webhook.
    waker: Arc<Mutex<Option<Waker>>>,
}
impl Drop for WebhookInner {
    /// Queues this webhook's UID for removal from the registry.
    fn drop(&mut self) {
        // The cleanup receiver is owned by a spawned Tokio task. Once that task is
        // gone (e.g. the runtime has shut down), the channel is closed and the send
        // fails. Panicking inside `drop` (which aborts the process if already
        // unwinding) is worse than leaving the entry in place, so ignore the error.
        // `Uuid` is `Copy`, so no clone is needed.
        let _ = self.dropper.try_send(self.uid);
    }
}
/// A single-use callback endpoint created by `Server::spawn`.
///
/// Poll it as a [`Stream`] to receive the HTTP request sent to [`Webhook::url`];
/// the stream ends after that request. Dropping it queues its removal from the
/// server's registry.
pub struct Webhook {
    inner: WebhookInner,
}
impl Webhook {
pub fn url(&self) -> &Url {
&self.inner.url
}
}
impl Stream for Webhook {
    type Item = Request<Body>;

    /// Yields the request routed to this webhook, or `None` once its channel is closed.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inner = &self.get_mut().inner;
        // Publish the waker *before* checking the channel. The registry sends first
        // and wakes second, so a request that arrives after the check below still
        // finds this waker.
        inner.waker.lock().unwrap().replace(cx.waker().clone());
        match inner.value.try_recv() {
            Err(TryRecvError::Empty) => Poll::Pending,
            Err(TryRecvError::Closed) => Poll::Ready(None),
            Ok(request) => Poll::Ready(Some(request)),
        }
    }
}
/// Handle to a running HTTP server that hands out [`Webhook`]s.
///
/// Clones share one server. When the last clone is dropped, the server is told
/// to shut down gracefully.
#[derive(Clone)]
pub struct Server {
    inner: Arc<Inner>,
}
struct Inner {
endpoint: url::Url,
registry: Registry,
stop: Sender<()>,
}
impl Drop for Inner {
    fn drop(&mut self) {
        // Ask the HTTP server to stop. This is the only send on this channel, so a
        // failure means the server's shutdown future (and its receiver) is already
        // gone, and there is nothing left to stop.
        let _ = self.stop.try_send(());
    }
}
impl Server {
pub fn start<A: Into<SocketAddr>>(addr: A) -> Self {
let addr = addr.into();
let endpoint = Url::parse(format!("http://{}", addr).as_str()).unwrap();
Self::with_endpoint(addr, endpoint)
}
pub fn start_with_proxy<A: Into<SocketAddr>>(addr: A, proxy: Url) -> Self {
Self::with_endpoint(addr.into(), proxy)
}
fn with_endpoint(addr: SocketAddr, endpoint: Url) -> Self {
let registry = Registry::new();
type BoxError = Box<dyn std::error::Error + Send + Sync>;
async fn handle_request(
registry: Registry,
request: Request<Body>,
) -> Result<Response<Body>, BoxError> {
let registry = registry.clone();
let url = format!("http://localhost{}", request.uri());
let url = Url::parse(url.as_str()).expect("hyper::Uri is url::Url parsable");
let uid = url
.query_pairs()
.find(|(k, _)| k.as_ref() == QUERY_PARAM_NAME)
.map(|(_, uid)| {
Uuid::parse_str(uid.as_ref()).map_err(|_| Error::WebhookMissingUid)
})
.ok_or_else(|| Error::WebhookMissingUid)??;
registry.notify(uid, request);
Ok(Response::new(Body::from("Hello World")))
}
let keep_registry = registry.clone();
let make_service = make_service_fn(move |_conn| {
let registry = registry.clone();
async move {
Ok::<_, Error>(service_fn(move |request| {
let registry = registry.clone();
handle_request(registry, request)
}))
}
});
let (stop, on_stop) = async_channel::bounded::<()>(16);
let server = hyper::server::Server::bind(&addr)
.serve(make_service)
.with_graceful_shutdown(async move {
on_stop.recv().await.ok();
});
tokio::spawn(async { server.await });
Server {
inner: Arc::new(Inner {
registry: keep_registry,
stop,
endpoint,
}),
}
}
pub fn spawn(&self) -> Webhook {
let uid = Uuid::new_v4();
let url = {
let mut url = self.inner.endpoint.clone();
url.query_pairs_mut()
.append_pair(QUERY_PARAM_NAME, uid.to_string().as_str());
url
};
let (waker, value) = self.inner.registry.register(uid.clone());
let dropper = self.inner.registry.drop.clone();
Webhook {
inner: WebhookInner {
url,
uid,
value,
dropper,
waker,
},
}
}
}
#[cfg(test)]
mod tests {
    use crate::{Server, Webhook};
    use futures::{StreamExt, TryStreamExt};
    use serde::{Deserialize, Serialize};

    /// End-to-end check of the callback pattern. A mock API receives a request that
    /// names a webhook URL and POSTs its answer there; the answer must then come out
    /// of the webhook stream.
    #[tokio::test]
    async fn test_callback_pattern() {
        // Request sent to the mock API.
        #[derive(Serialize, Deserialize)]
        struct Request {
            message: String,
            callback: String,
        }
        // Body the mock API POSTs back to the callback URL.
        #[derive(Serialize, Deserialize)]
        struct Response {
            message: String,
        }

        eprintln!("Spawning API Server");
        // `bind_ephemeral` binds immediately on a free port. `run` would only bind once
        // the spawned task is first polled, racing the request sent below, and it
        // needed a hard-coded port.
        let api_addr = {
            use warp::Filter;
            let api_route = warp::any()
                .and(warp::body::json::<Request>())
                .and_then(|req: Request| async move {
                    let client = reqwest::Client::new();
                    eprintln!(
                        "API received message, sending callback to: {}",
                        req.callback.as_str()
                    );
                    match client
                        .post(req.callback.as_str())
                        .json(&Response {
                            message: req.message,
                        })
                        .send()
                        .await
                    {
                        Ok(r) => {
                            eprintln!("Webhook server responded with status: {}", r.status());
                            Ok(warp::reply())
                        }
                        Err(_) => Err(warp::reject::reject()),
                    }
                });
            let (addr, serving) = warp::serve(api_route).bind_ephemeral(([127, 0, 0, 1], 0));
            tokio::spawn(serving);
            addr
        };

        eprintln!("Starting callback server");
        let server: Server = Server::start(([127, 0, 0, 1], 3031));
        eprintln!("Sending API request");
        let mut webhook: Webhook = server.spawn();
        reqwest::Client::new()
            .post(format!("http://{}", api_addr))
            .json(&Request {
                callback: webhook.url().to_string(),
                message: "hey".to_string(),
            })
            .send()
            .await
            // Fail fast: if this request never reaches the API, `webhook.next()` below
            // would wait forever.
            .expect("mock API should accept the request");

        let callback = webhook
            .next()
            .await
            .expect("webhook stream ended without delivering the callback");
        let body = callback
            .into_body()
            .try_fold(Vec::new(), |mut data, chunk| async move {
                data.extend_from_slice(&chunk);
                Ok(data)
            })
            .await
            .expect("callback body should be readable");
        let res: Response =
            serde_json::from_slice(&body).expect("callback body should be a JSON `Response`");
        assert_eq!(res.message, "hey");
    }
}