use std::{ops::Deref, sync::Arc};
use http::Request;
use tokio::sync::watch;
use super::Connected;
/// Handle used to read connection metadata once it has been published.
///
/// Wraps the receiving half of a `tokio::sync::watch` channel whose value is
/// `None` until a [`Connected`] is stored through the paired
/// `CaptureConnectionExtension`.
#[derive(Debug, Clone)]
pub struct CaptureConnection {
/// Receiving half of the watch channel; holds `None` until metadata is set.
rx: watch::Receiver<Option<Connected>>,
}
/// Attaches a connection-capture channel to `request`.
///
/// A fresh sender/receiver pair is created; the sending half is inserted into
/// the request's extensions and the receiving [`CaptureConnection`] is
/// returned to the caller. Crate-internal code that finds the extension can
/// publish a [`Connected`] through it, which then becomes visible via the
/// returned handle.
pub fn capture_connection<B>(request: &mut Request<B>) -> CaptureConnection {
    let (extension, capture) = CaptureConnection::new();
    request.extensions_mut().insert(extension);
    capture
}
/// Sending half of the capture channel.
///
/// `capture_connection` inserts a value of this type into a request's
/// extensions; crate-internal code calls `set` on it to publish the
/// [`Connected`] metadata.
#[derive(Clone)]
pub(crate) struct CaptureConnectionExtension {
/// Shared watch sender.
// NOTE(review): the `Arc` presumably exists so this type can derive `Clone`
// regardless of whether `watch::Sender` itself is `Clone` -- confirm.
tx: Arc<watch::Sender<Option<Connected>>>,
}
impl CaptureConnectionExtension {
    /// Publishes a copy of `connected` as the channel's current value.
    ///
    /// The value previously held by the channel is discarded.
    pub(crate) fn set(&self, connected: &Connected) {
        let metadata = Some(connected.clone());
        // `send_replace` hands back the old value; it is of no use here.
        let _ = self.tx.send_replace(metadata);
    }
}
impl CaptureConnection {
    /// Builds a linked sender/receiver pair whose initial value is `None`.
    ///
    /// Returns the sending extension first and the receiving handle second.
    pub(crate) fn new() -> (CaptureConnectionExtension, Self) {
        let (sender, receiver) = watch::channel(None);
        let extension = CaptureConnectionExtension {
            tx: Arc::new(sender),
        };
        let capture = CaptureConnection { rx: receiver };
        (extension, capture)
    }

    /// Returns a borrowed view of the current metadata without waiting.
    ///
    /// The view dereferences to `None` if no [`Connected`] has been set yet.
    pub fn connection_metadata(&self) -> impl Deref<Target = Option<Connected>> + '_ {
        self.rx.borrow()
    }

    /// Waits until metadata is available, then returns a borrowed view of it.
    ///
    /// If a value is already present this returns without waiting. Otherwise
    /// it awaits the next change notification on the channel. The result of
    /// that wait is deliberately ignored, so if the wait ends without a value
    /// having been set (for example because the sending side was dropped) the
    /// returned view dereferences to `None`.
    pub async fn wait_for_connection_metadata(
        &mut self,
    ) -> impl Deref<Target = Option<Connected>> + '_ {
        // The borrow taken for this check is released before the await below.
        if self.rx.borrow().is_none() {
            let _ = self.rx.changed().await;
        }
        self.rx.borrow()
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Metadata set through the extension is visible via the synchronous
    /// accessor, and can be read more than once.
    #[test]
    fn test_sync_capture_connection() {
        let (extension, capture) = CaptureConnection::new();
        assert!(
            capture.connection_metadata().is_none(),
            "connection has not been set"
        );

        extension.set(&Connected::new().proxy(true));

        // Read twice to confirm the value stays available after a first read.
        for _ in 0..2 {
            let metadata = capture.connection_metadata();
            let connected = metadata.as_ref().expect("connected should be set");
            assert!(connected.is_proxied());
        }
    }

    /// A task waiting on the metadata completes once the extension sets it.
    #[tokio::test]
    async fn async_capture_connection() {
        let (extension, mut capture) = CaptureConnection::new();
        assert!(
            capture.connection_metadata().is_none(),
            "connection has not been set"
        );

        let waiter = tokio::spawn(async move {
            assert!(capture
                .wait_for_connection_metadata()
                .await
                .as_ref()
                .expect("connection should be set")
                .is_proxied());
            assert!(
                capture.wait_for_connection_metadata().await.is_some(),
                "should be awaitable multiple times"
            );
            assert!(capture.connection_metadata().is_some());
        });

        // Nothing has been published yet, so the waiter cannot be done.
        assert!(!waiter.is_finished());
        extension.set(&Connected::new().proxy(true));

        assert!(waiter.await.is_ok());
    }

    /// Dropping the sending side without setting a value ends the wait with
    /// `None` rather than hanging.
    #[tokio::test]
    async fn capture_connection_sender_side_dropped() {
        let (extension, mut capture) = CaptureConnection::new();
        assert!(
            capture.connection_metadata().is_none(),
            "connection has not been set"
        );

        drop(extension);

        assert!(capture.wait_for_connection_metadata().await.is_none());
    }
}