use alloc::{format, string::String, vec::Vec};
use core::time::Duration;
use future_form::{FutureForm, Local, Sendable, future_form};
use futures::{
future::{Either, select},
pin_mut,
};
use subduction_core::{
handshake::{self, HandshakeMessage, audience::Audience},
peer::id::PeerId,
timestamp::TimestampSeconds,
};
use subduction_crypto::{nonce::Nonce, signer::Signer};
use crate::{
SESSION_ID_HEADER, error::ClientError, http_client::HttpClient, session::SessionId,
transport::HttpLongPollTransport,
};
/// Result of a successful long-poll connect.
///
/// `poll_task` and `send_task` are futures and do nothing until polled; the caller
/// must drive both (for example by spawning them) for messages to move between the
/// transport and the server.
pub struct ConnectResult<K: FutureForm> {
    /// The authenticated long-poll transport produced by the handshake.
    pub authenticated: subduction_core::authenticated::Authenticated<HttpLongPollTransport, K>,
    /// Session ID returned by the server during the handshake; sent in the session
    /// header of every `/lp/recv` and `/lp/send` request.
    pub session_id: SessionId,
    /// Loop that long-polls `/lp/recv` and pushes received bodies into the transport.
    pub poll_task: K::Future<'static, ()>,
    /// Loop that takes outbound messages from the transport and POSTs them to `/lp/send`.
    pub send_task: K::Future<'static, ()>,
}
// Hand-written rather than derived: `K::Future` carries no `Debug` bound, so only
// the session ID is rendered and the remaining fields are elided.
impl<K: FutureForm> core::fmt::Debug for ConnectResult<K> {
    /// Renders `ConnectResult { session_id: .., .. }`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("ConnectResult");
        out.field("session_id", &self.session_id);
        out.finish_non_exhaustive()
    }
}
/// HTTP long-poll client for a single server.
///
/// The handshake, receive and send endpoints are all resolved relative to
/// `base_url`, and every request goes through `http`.
#[derive(Debug, Clone)]
pub struct HttpLongPollClient<H> {
    /// Server root URL with any trailing `/` removed.
    base_url: String,
    /// HTTP client used for every request.
    http: H,
}

impl<H> HttpLongPollClient<H> {
    /// Creates a client for the server rooted at `base_url`, issuing requests via `http`.
    ///
    /// All trailing `/` characters are stripped from `base_url`, so endpoint paths such
    /// as `/lp/recv` can be appended to it directly.
    #[must_use]
    pub fn new(base_url: &str, http: H) -> Self {
        let root = base_url.trim_end_matches('/');
        Self {
            http,
            base_url: String::from(root),
        }
    }
}
/// Establishes an authenticated long-poll session with a server.
///
/// `K` selects the future form of the returned future, and `Sig` is the signer
/// used during the handshake.
pub trait Connect<K: FutureForm, Sig: Signer<K>> {
    /// Performs the handshake addressed to `audience` at time `now`, signing with
    /// `signer`, and on success returns the authenticated transport together with the
    /// background tasks that carry its traffic.
    ///
    /// # Errors
    ///
    /// Returns a [`ClientError`] if the handshake or session setup fails.
    // The nested `K::Future<'a, Result<ConnectResult<K>, ClientError>>` return type is
    // flagged by clippy's `type_complexity` lint; it is allowed here rather than aliased.
    #[allow(clippy::type_complexity)]
    fn connect_with_audience<'a>(
        &'a self,
        signer: &'a Sig,
        audience: Audience,
        now: TimestampSeconds,
    ) -> K::Future<'a, Result<ConnectResult<K>, ClientError>>;
}
#[future_form(Sendable where H: Send + Sync, Sig: Sync, H::Error: Send, Local)]
impl<K: FutureForm, Sig: Signer<K>, H: HttpClient<K> + 'static> Connect<K, Sig>
    for HttpLongPollClient<H>
{
    /// Runs the handshake against `{base_url}/lp/handshake` with a fresh random nonce,
    /// then builds the long-poll transport and its two background tasks.
    ///
    /// On success the returned [`ConnectResult`] carries two futures that must be
    /// polled for traffic to flow:
    /// - `poll_task` long-polls `{base_url}/lp/recv`;
    /// - `send_task` POSTs outbound messages to `{base_url}/lp/send`.
    ///
    /// Both loops also stop when the cancel channel registered with the transport via
    /// `set_cancel_guard` yields a value or is closed.
    ///
    /// # Errors
    ///
    /// - [`ClientError::Authentication`] if the handshake fails (carrying the
    ///   underlying error's message).
    /// - [`ClientError::HandshakeDecode`] if the handshake completes without the
    ///   server having supplied a valid session ID header.
    fn connect_with_audience<'a>(
        &'a self,
        signer: &'a Sig,
        audience: Audience,
        now: TimestampSeconds,
    ) -> K::Future<'a, Result<ConnectResult<K>, ClientError>> {
        let http = self.http.clone();
        let base_url = self.base_url.clone();
        K::from_future(async move {
            let nonce = Nonce::random();
            let mut client_handshake = ClientHttpHandshake::<K, H> {
                http: http.clone(),
                base_url: base_url.clone(),
                session_id: None,
                response_bytes: None,
                _k: core::marker::PhantomData,
            };

            // The session ID arrives in a server-controlled response header, so its
            // absence is a remote protocol error rather than a local invariant
            // violation. Carry it out of the callback as an `Option` and report it
            // below instead of panicking inside the handshake.
            let (authenticated, session_id) = handshake::initiate::<K, _, _, _, _>(
                &mut client_handshake,
                |handshake, peer_id| {
                    let conn = HttpLongPollTransport::new(peer_id);
                    (conn, handshake.session_id)
                },
                signer,
                audience,
                now,
                nonce,
            )
            .await
            .map_err(|e| ClientError::Authentication(e.to_string()))?;

            let session_id = session_id.ok_or_else(|| {
                ClientError::HandshakeDecode(
                    "handshake completed without a valid session id header".into(),
                )
            })?;

            let conn = authenticated.inner().clone();

            // One channel shared by both loops; the transport holds the sender.
            let (cancel_tx, cancel_rx) = async_channel::bounded::<()>(1);
            let send_cancel_rx = cancel_rx.clone();
            conn.set_cancel_guard(cancel_tx).await;

            let poll_url = format!("{base_url}/lp/recv");
            let poll_http = http.clone();
            let poll_conn = conn.clone();
            let poll_task = K::from_future(async move {
                poll_loop::<K, H>(poll_http, poll_url, session_id, poll_conn, cancel_rx).await;
            });

            let send_url = format!("{base_url}/lp/send");
            let send_http = http;
            let send_conn = conn;
            let send_task = K::from_future(async move {
                send_loop::<K, H>(send_http, send_url, session_id, send_conn, send_cancel_rx).await;
            });

            Ok(ConnectResult {
                authenticated,
                session_id,
                poll_task,
                send_task,
            })
        })
    }
}
impl<H> HttpLongPollClient<H> {
    /// Connects to a server whose peer ID is known in advance.
    ///
    /// Builds an [`Audience::known`] for `expected_peer_id` and delegates to
    /// [`Connect::connect_with_audience`].
    ///
    /// # Errors
    ///
    /// Returns whatever [`Connect::connect_with_audience`] returns.
    pub fn connect<'a, K: FutureForm, Sig: Signer<K>>(
        &'a self,
        signer: &'a Sig,
        expected_peer_id: PeerId,
        now: TimestampSeconds,
    ) -> K::Future<'a, Result<ConnectResult<K>, ClientError>>
    where
        Self: Connect<K, Sig>,
    {
        let audience = Audience::known(expected_peer_id);
        <Self as Connect<K, Sig>>::connect_with_audience(self, signer, audience, now)
    }

    /// Connects using a discovery audience derived from `service_name`, rather than
    /// a known peer ID.
    ///
    /// Builds an [`Audience::discover`] from the UTF-8 bytes of `service_name` and
    /// delegates to [`Connect::connect_with_audience`].
    ///
    /// # Errors
    ///
    /// Returns whatever [`Connect::connect_with_audience`] returns.
    pub fn connect_discover<'a, K: FutureForm, Sig: Signer<K>>(
        &'a self,
        signer: &'a Sig,
        service_name: &str,
        now: TimestampSeconds,
    ) -> K::Future<'a, Result<ConnectResult<K>, ClientError>>
    where
        Self: Connect<K, Sig>,
    {
        let audience = Audience::discover(service_name.as_bytes());
        <Self as Connect<K, Sig>>::connect_with_audience(self, signer, audience, now)
    }
}
async fn poll_loop<K: FutureForm, H: HttpClient<K>>(
http: H,
url: String,
session_id: SessionId,
conn: HttpLongPollTransport,
cancel: async_channel::Receiver<()>,
) {
loop {
let recv_fut = http.post(
&url,
&[(SESSION_ID_HEADER, &session_id.to_hex())],
Vec::new(),
);
let cancel_fut = cancel.recv();
pin_mut!(recv_fut, cancel_fut);
match select(recv_fut, cancel_fut).await {
Either::Right(_) => {
tracing::debug!("recv poll loop cancelled");
break;
}
Either::Left((result, _)) => match result {
Ok(resp) => match resp.status {
200 => {
if conn.push_inbound(resp.body).await.is_err() {
tracing::error!("inbound channel closed");
break;
}
}
204 => {
}
410 => {
tracing::info!("session closed by server (410 Gone)");
break;
}
status => {
tracing::error!("unexpected recv status: {status}");
futures_timer::Delay::new(Duration::from_secs(1)).await;
}
},
Err(e) => {
tracing::error!("recv request error: {e}");
futures_timer::Delay::new(Duration::from_secs(1)).await;
}
},
}
}
}
/// Forwards the transport's outbound messages to `url` until cancelled or the queue
/// reports an error.
///
/// Each message is POSTed once, with the session ID header and an
/// `application/octet-stream` content type. A response status of 300 or above, or a
/// request error, is logged and the message is not retried.
///
/// The loop ends in either of two cases:
/// - `cancel` yields a value or reports that it is closed;
/// - `conn.pull_outbound()` returns an error.
async fn send_loop<K: FutureForm, H: HttpClient<K>>(
    http: H,
    url: String,
    session_id: SessionId,
    conn: HttpLongPollTransport,
    cancel: async_channel::Receiver<()>,
) {
    // Identical on every request, so encode it once up front.
    let session_hex = session_id.to_hex();

    loop {
        let next_message = conn.pull_outbound();
        let cancelled = cancel.recv();
        pin_mut!(next_message, cancelled);

        let pulled = match select(next_message, cancelled).await {
            Either::Left((pulled, _)) => pulled,
            Either::Right(_) => {
                tracing::debug!("send loop cancelled");
                break;
            }
        };

        let Ok(bytes) = pulled else {
            tracing::debug!("outbound channel closed");
            break;
        };

        let outcome = http
            .post(
                &url,
                &[
                    (SESSION_ID_HEADER, &session_hex),
                    ("content-type", "application/octet-stream"),
                ],
                bytes,
            )
            .await;

        match outcome {
            Ok(resp) if resp.status < 300 => {}
            Ok(resp) => {
                tracing::error!("send returned status {}", resp.status);
            }
            Err(e) => {
                tracing::error!("send request error: {e}");
            }
        }
    }
}
/// Client-side state for the HTTP handshake exchange.
///
/// The `Handshake` impl's `send` POSTs to `{base_url}/lp/handshake` and records the
/// response; `recv` hands the stored response body back to the handshake driver.
struct ClientHttpHandshake<K: FutureForm, H: HttpClient<K>> {
    /// HTTP client used for the handshake POST.
    http: H,
    /// Server root URL; copied from `HttpLongPollClient`, whose constructor strips
    /// trailing `/` characters.
    base_url: String,
    /// Session ID parsed from the handshake response's session header, if any.
    session_id: Option<SessionId>,
    /// Body of the last `200` handshake response, taken (and cleared) by `recv`.
    response_bytes: Option<Vec<u8>>,
    /// Marks the future form `K` without storing a value of it.
    _k: core::marker::PhantomData<K>,
}
#[future_form(Sendable where H: Send, Local)]
impl<K: FutureForm, H: HttpClient<K>> handshake::Handshake<K> for &mut ClientHttpHandshake<K, H> {
    type Error = ClientError;

    /// POSTs one handshake message to `{base_url}/lp/handshake`.
    ///
    /// On a `200` response:
    /// - the body is stored for the next `recv`;
    /// - if the response carries the session ID header, the parsed ID is recorded in
    ///   `session_id`. An absent header leaves `session_id` unchanged.
    ///
    /// # Errors
    ///
    /// - [`ClientError::Request`] if the HTTP request itself fails.
    /// - [`ClientError::HandshakeRejected`] on `401`. The reason is the decoded
    ///   rejection when the body is a rejection message, and the body as lossy UTF-8
    ///   otherwise.
    /// - [`ClientError::UnexpectedStatus`] on any other non-`200` status.
    /// - [`ClientError::HandshakeDecode`] if a `200` response carries a session ID
    ///   header that does not parse.
    fn send(&mut self, bytes: Vec<u8>) -> K::Future<'_, Result<(), Self::Error>> {
        let url = format!("{}/lp/handshake", self.base_url);
        let http = self.http.clone();
        K::from_future(async move {
            let resp = http
                .post(&url, &[("content-type", "application/octet-stream")], bytes)
                .await
                .map_err(|e| ClientError::Request(e.to_string()))?;

            if resp.status == 401 {
                return Err(ClientError::HandshakeRejected {
                    reason: match HandshakeMessage::try_decode(&resp.body) {
                        Ok(HandshakeMessage::Rejection(r)) => {
                            format!("{:?}", r.reason)
                        }
                        _ => String::from_utf8_lossy(&resp.body).into_owned(),
                    },
                });
            }

            if resp.status != 200 {
                return Err(ClientError::UnexpectedStatus {
                    status: resp.status,
                    body: String::from_utf8_lossy(&resp.body).into_owned(),
                });
            }

            // A malformed header used to silently clear `session_id`, which only
            // surfaced later as a missing session. Report it here instead, where the
            // cause is still known.
            if let Some(sid_str) = resp.header(SESSION_ID_HEADER) {
                let Some(session_id) = SessionId::from_hex(sid_str) else {
                    return Err(ClientError::HandshakeDecode(
                        "malformed session id header in handshake response".into(),
                    ));
                };
                self.session_id = Some(session_id);
            }

            self.response_bytes = Some(resp.body);
            Ok(())
        })
    }

    /// Returns the body stored by the preceding successful `send`, clearing it.
    ///
    /// # Errors
    ///
    /// [`ClientError::HandshakeDecode`] if no body is stored, i.e. `send` has not
    /// succeeded since the last `recv`.
    fn recv(&mut self) -> K::Future<'_, Result<Vec<u8>, Self::Error>> {
        let bytes = self.response_bytes.take();
        K::from_future(async move {
            bytes.ok_or_else(|| {
                ClientError::HandshakeDecode(
                    "no response bytes available (send not called?)".into(),
                )
            })
        })
    }
}