use bytes::Bytes;
use http::{HeaderName, HeaderValue, Method, StatusCode};
use http_body_util::{BodyExt, StreamBody};
use hyper::body::Frame;
use hyper_util::rt::TokioIo;
use crate::{
io::IrohStream,
parse_node_addr,
stream::{BodyReader, BodyWriter, HandleStore},
CoreError, FfiDuplexStream, FfiResponse, IrohEndpoint, ALPN, ALPN_DUPLEX,
};
use crate::BoxBody;
/// Adapter exposing a hyper HTTP/1 client connection as a `tower::Service`,
/// so tower middleware (the response decompression layer) can wrap it.
/// Readiness and `call` forward straight to the inner `SendRequest`.
#[cfg(feature = "compression")]
struct HyperClientSvc(hyper::client::conn::http1::SendRequest<BoxBody>);
#[cfg(feature = "compression")]
impl tower::Service<hyper::Request<BoxBody>> for HyperClientSvc {
    type Response = hyper::Response<hyper::body::Incoming>;
    type Error = hyper::Error;
    // `send_request` returns an unnameable future, so box it.
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.0.poll_ready(cx)
    }
    fn call(&mut self, req: hyper::Request<BoxBody>) -> Self::Future {
        Box::pin(self.0.send_request(req))
    }
}
/// FFI entry point: fetch a URL from a remote iroh node over HTTP/1.
///
/// Validates inputs, connects to the remote node through the connection
/// pool, and performs the request via `do_fetch`. When `fetch_token` is
/// set, the matching cancel notify aborts the in-flight request, and the
/// token is released on every exit path — success, request failure, and
/// early validation/connect errors alike.
#[allow(clippy::too_many_arguments)]
pub async fn fetch(
    endpoint: &IrohEndpoint,
    remote_node_id: &str,
    url: &str,
    method: &str,
    headers: &[(String, String)],
    req_body_reader: Option<BodyReader>,
    req_trailer_sender_handle: Option<u64>,
    fetch_token: Option<u64>,
    direct_addrs: Option<&[std::net::SocketAddr]>,
) -> Result<FfiResponse, CoreError> {
    let out = fetch_inner(
        endpoint,
        remote_node_id,
        url,
        method,
        headers,
        req_body_reader,
        req_trailer_sender_handle,
        fetch_token,
        direct_addrs,
    )
    .await;
    // Always release the cancel token. Previously the early `?` returns
    // (invalid URL/method/header, node-addr parse or connect failure)
    // skipped this cleanup and leaked the registered token.
    if let Some(token) = fetch_token {
        endpoint.handles().remove_fetch_token(token);
    }
    out
}

/// Body of [`fetch`], separated out so the caller can run the fetch-token
/// cleanup unconditionally after it resolves.
#[allow(clippy::too_many_arguments)]
async fn fetch_inner(
    endpoint: &IrohEndpoint,
    remote_node_id: &str,
    url: &str,
    method: &str,
    headers: &[(String, String)],
    req_body_reader: Option<BodyReader>,
    req_trailer_sender_handle: Option<u64>,
    fetch_token: Option<u64>,
    direct_addrs: Option<&[std::net::SocketAddr]>,
) -> Result<FfiResponse, CoreError> {
    // Reject plain http(s) URLs early with a targeted hint toward the
    // expected "httpi://" scheme.
    {
        let lower = url.to_ascii_lowercase();
        if lower.starts_with("https://") || lower.starts_with("http://") {
            // `to_ascii_lowercase` is byte-length-preserving, so indexes
            // found in `lower` are valid for `url`.
            let scheme_end = lower
                .find("://")
                .map(|i| i.saturating_add(3))
                .unwrap_or(lower.len());
            return Err(CoreError::invalid_input(format!(
                "iroh-http URLs must use the \"httpi://\" scheme, not \"{}\". \
                 Example: httpi://nodeId/path",
                &url[..scheme_end]
            )));
        }
    }
    let http_method = Method::from_bytes(method.as_bytes())
        .map_err(|_| CoreError::invalid_input(format!("invalid HTTP method {:?}", method)))?;
    // Validate header names/values up front so errors surface as
    // `invalid_input` instead of a later request-build failure.
    for (name, value) in headers {
        HeaderName::from_bytes(name.as_bytes())
            .map_err(|_| CoreError::invalid_input(format!("invalid header name {:?}", name)))?;
        HeaderValue::from_str(value).map_err(|_| {
            CoreError::invalid_input(format!("invalid header value for {:?}", name))
        })?;
    }
    let cancel_notify = fetch_token.and_then(|t| endpoint.handles().get_fetch_cancel_notify(t));
    let handles = endpoint.handles();
    // Handle 0 is the "no trailers" sentinel; otherwise claim the pending
    // trailer receiver registered by the FFI side.
    let req_trailer_rx = req_trailer_sender_handle.and_then(|h| {
        if h == 0 {
            None
        } else {
            handles.claim_pending_trailer_rx(h)
        }
    });
    let parsed = parse_node_addr(remote_node_id)?;
    let node_id = parsed.node_id;
    let mut addr = iroh::EndpointAddr::new(node_id);
    // Merge direct addresses embedded in the node string with any supplied
    // explicitly by the caller.
    for a in &parsed.direct_addrs {
        addr = addr.with_ip_addr(*a);
    }
    if let Some(addrs) = direct_addrs {
        for a in addrs {
            addr = addr.with_ip_addr(*a);
        }
    }
    let ep_raw = endpoint.raw().clone();
    let addr_clone = addr.clone();
    let max_header_size = endpoint.max_header_size();
    // Reuse a pooled connection when one exists; otherwise dial.
    let pooled = endpoint
        .pool()
        .get_or_connect(node_id, ALPN, || async move {
            ep_raw
                .connect(addr_clone, ALPN)
                .await
                .map_err(|e| format!("connect: {e}"))
        })
        .await
        .map_err(CoreError::connection_failed)?;
    let conn = pooled.conn.clone();
    let remote_str = pooled.remote_id_str.clone();
    let result = do_fetch(
        handles,
        conn,
        &remote_str,
        url,
        http_method,
        headers,
        req_body_reader,
        req_trailer_rx,
        max_header_size,
    );
    // Race the request against the cancel notification, when one exists.
    if let Some(notify) = cancel_notify {
        tokio::select! {
            _ = notify.notified() => Err(CoreError::cancelled()),
            r = result => r,
        }
    } else {
        result.await
    }
}
/// Perform one HTTP/1 request over an already-established iroh connection.
///
/// Opens a fresh bidirectional stream, drives a hyper client handshake over
/// it, sends the request (optionally streaming a body plus trailers),
/// enforces the response-header size limit, and hands the response body and
/// trailers back to the FFI layer as handle-addressed channels.
#[allow(clippy::too_many_arguments)]
async fn do_fetch(
    handles: &HandleStore,
    conn: iroh::endpoint::Connection,
    remote_str: &str,
    url: &str,
    method: Method,
    headers: &[(String, String)],
    req_body_reader: Option<BodyReader>,
    req_trailer_rx: Option<crate::stream::TrailerRx>,
    max_header_size: usize,
) -> Result<FfiResponse, CoreError> {
    // One request per stream: open a dedicated bidirectional stream.
    let (send, recv) = conn
        .open_bi()
        .await
        .map_err(|e| CoreError::connection_failed(format!("open_bi: {e}")))?;
    let io = TokioIo::new(IrohStream::new(send, recv));
    // `mut` is only needed in the non-compression path (direct
    // `send_request`); with compression the sender is moved into the service.
    #[allow(unused_mut)] let (mut sender, conn_task) = hyper::client::conn::http1::Builder::new()
        .max_buf_size(max_header_size.max(8192))
        .max_headers(128)
        .handshake::<_, BoxBody>(io)
        .await
        .map_err(|e| CoreError::connection_failed(format!("hyper handshake: {e}")))?;
    // Drive the connection in the background for the lifetime of the request.
    tokio::spawn(conn_task);
    let path = extract_path(url);
    let mut req_builder = hyper::Request::builder()
        .method(method)
        .uri(&path)
        .header(hyper::header::HOST, remote_str)
        // Advertise that we accept trailers in the response.
        .header("te", "trailers");
    #[cfg(feature = "compression")]
    {
        // Only advertise zstd when the caller didn't state a preference.
        let has_accept_encoding = headers
            .iter()
            .any(|(k, _)| k.eq_ignore_ascii_case("accept-encoding"));
        if !has_accept_encoding {
            req_builder = req_builder.header("accept-encoding", "zstd");
        }
    }
    for (k, v) in headers {
        req_builder = req_builder.header(k.as_str(), v.as_str());
    }
    // Stream the request body from the FFI reader when present; otherwise
    // send an empty body.
    let req_body: BoxBody = if let Some(reader) = req_body_reader {
        crate::box_body(body_from_reader(reader, req_trailer_rx))
    } else {
        crate::box_body(http_body_util::Empty::new())
    };
    let req = req_builder
        .body(req_body)
        .map_err(|e| CoreError::internal(format!("build request: {e}")))?;
    #[cfg(feature = "compression")]
    let resp = {
        use tower::ServiceExt;
        // Wrap the sender so tower-http transparently decompresses the
        // response body based on its Content-Encoding.
        let svc = tower::ServiceBuilder::new()
            .layer(tower_http::decompression::DecompressionLayer::new())
            .service(HyperClientSvc(sender));
        svc.oneshot(req)
            .await
            .map_err(|e| CoreError::connection_failed(format!("send_request: {e}")))?
    };
    #[cfg(not(feature = "compression"))]
    let resp = sender
        .send_request(req)
        .await
        .map_err(|e| CoreError::connection_failed(format!("send_request: {e}")))?;
    let status = resp.status().as_u16();
    // Approximate the serialized response-header size (name + value + 4
    // bytes per header for ": " and CRLF, plus a small fixed base) and
    // enforce the configured cap.
    let header_bytes: usize = resp
        .headers()
        .iter()
        .map(|(k, v)| {
            k.as_str()
                .len()
                .saturating_add(v.as_bytes().len())
                .saturating_add(4) })
        .fold(16usize, |acc, x| acc.saturating_add(x)); if header_bytes > max_header_size {
        return Err(CoreError::header_too_large(format!(
            "response header size {header_bytes} exceeds limit {max_header_size}"
        )));
    }
    // Copy headers out as UTF-8 strings for the FFI boundary; reject any
    // value that is not valid UTF-8 rather than silently mangling it.
    let mut resp_headers: Vec<(String, String)> = Vec::new();
    for (k, v) in resp.headers().iter() {
        match v.to_str() {
            Ok(s) => resp_headers.push((k.as_str().to_string(), s.to_string())),
            Err(_) => {
                return Err(CoreError::invalid_input(format!(
                    "non-UTF8 response header value for '{}'",
                    k.as_str()
                )));
            }
        }
    }
    // Insert handles through a guard and `commit` only once everything
    // succeeded — presumably the guard rolls back uncommitted inserts on
    // early `?` returns (NOTE(review): confirm against HandleStore).
    let mut guard = handles.insert_guard();
    let (trailer_tx, trailer_rx) = tokio::sync::oneshot::channel::<Vec<(String, String)>>();
    let trailer_handle = guard.insert_trailer_receiver(trailer_rx)?;
    let (res_writer, res_reader) = handles.make_body_channel();
    let body = resp.into_body();
    // Pump the response body (and eventual trailers) into the channel in the
    // background; the FFI side reads it via `body_handle`.
    tokio::spawn(pump_hyper_body_to_channel(body, res_writer, trailer_tx));
    let body_handle = guard.insert_reader(res_reader)?;
    let response_url = format!("httpi://{remote_str}{path}");
    guard.commit();
    Ok(FfiResponse {
        status,
        headers: resp_headers,
        body_handle,
        url: response_url,
        trailers_handle: trailer_handle,
    })
}
/// Drain an HTTP body into `writer` with no byte cap and no overflow
/// signal, forwarding any trailers through `trailer_tx`.
///
/// Thin convenience wrapper over [`pump_hyper_body_to_channel_limited`]
/// using the writer's own drain timeout as the per-frame timeout.
pub(crate) async fn pump_hyper_body_to_channel<B>(
    body: B,
    writer: BodyWriter,
    trailer_tx: tokio::sync::oneshot::Sender<Vec<(String, String)>>,
) where
    B: http_body::Body<Data = Bytes>,
    B::Error: std::fmt::Debug,
{
    let frame_timeout = writer.drain_timeout;
    pump_hyper_body_to_channel_limited(body, writer, trailer_tx, None, frame_timeout, None).await
}
/// Drain an HTTP body into a `BodyWriter` channel, optionally enforcing a
/// byte cap and a per-frame read timeout, then deliver collected trailers.
///
/// Termination cases visible in the loop below:
/// - per-frame timeout or frame error: logged, body abandoned, trailers
///   (possibly empty) still sent;
/// - `max_bytes` exceeded: logged, `overflow_tx` fired (if any), trailers
///   still sent;
/// - channel receiver gone (`send_chunk` fails): returns immediately
///   without sending trailers.
pub(crate) async fn pump_hyper_body_to_channel_limited<B>(
    body: B,
    writer: BodyWriter,
    trailer_tx: tokio::sync::oneshot::Sender<Vec<(String, String)>>,
    max_bytes: Option<usize>,
    frame_timeout: std::time::Duration,
    overflow_tx: Option<tokio::sync::oneshot::Sender<()>>,
) where
    B: http_body::Body<Data = Bytes>,
    B::Error: std::fmt::Debug,
{
    let mut body = Box::pin(body);
    let mut total = 0usize;
    let mut trailers_vec: Vec<(String, String)> = Vec::new();
    loop {
        // Bound every frame read so a stalled peer can't hold the pump
        // open forever.
        let frame_result = match tokio::time::timeout(frame_timeout, body.frame()).await {
            Err(_elapsed) => {
                tracing::warn!("iroh-http: body frame read timed out after {frame_timeout:?}");
                break;
            }
            // `None` frame means the body is complete.
            Ok(None) => break,
            Ok(Some(r)) => r,
        };
        match frame_result {
            Err(e) => {
                tracing::warn!("iroh-http: body frame error: {e:?}");
                break;
            }
            Ok(frame) => {
                if frame.is_data() {
                    let data = frame.into_data().expect("is_data checked above");
                    total = total.saturating_add(data.len());
                    if let Some(limit) = max_bytes {
                        if total > limit {
                            tracing::warn!("iroh-http: request body exceeded {limit} bytes");
                            // Moving `overflow_tx` here is sound only
                            // because `break` follows unconditionally.
                            if let Some(tx) = overflow_tx {
                                let _ = tx.send(());
                            }
                            break;
                        }
                    }
                    // Receiver dropped: nobody will read trailers either,
                    // so bail out without signalling them.
                    if writer.send_chunk(data).await.is_err() {
                        return; }
                } else if frame.is_trailers() {
                    let hdrs = frame.into_trailers().expect("is_trailers checked above");
                    // Keep only UTF-8 trailer values; the FFI boundary
                    // carries them as strings.
                    trailers_vec = hdrs
                        .iter()
                        .filter_map(|(k, v)| match v.to_str() {
                            Ok(s) => Some((k.as_str().to_string(), s.to_string())),
                            Err(_) => {
                                tracing::warn!(
                                    "iroh-http: dropping non-UTF8 trailer value for '{}'",
                                    k.as_str()
                                );
                                None
                            }
                        })
                        .collect();
                }
            }
        }
    }
    // Close the body channel before handing over trailers so readers
    // observe end-of-body first.
    drop(writer);
    let _ = trailer_tx.send(trailers_vec);
}
/// Build a streaming request body from an FFI `BodyReader`, optionally
/// appending trailers received on `trailer_rx` after the last data chunk.
///
/// Yields data frames until the reader is exhausted, then — if a trailer
/// receiver was supplied — waits up to the reader's drain timeout for
/// trailers, emits them as a final trailer frame (when non-empty and
/// parseable as header name/value pairs), and ends.
pub(crate) fn body_from_reader(
    reader: BodyReader,
    trailer_rx: Option<tokio::sync::oneshot::Receiver<Vec<(String, String)>>>,
) -> StreamBody<impl futures::Stream<Item = Result<Frame<Bytes>, std::convert::Infallible>>> {
    use futures::stream;
    // State: (reader, pending trailer rx, done). `done` is set once the
    // trailer frame has been emitted so the next poll ends the stream.
    let s = stream::unfold(
        (reader, trailer_rx, false),
        |(reader, trailer_rx, done)| async move {
            if done {
                return None;
            }
            match reader.next_chunk().await {
                Some(data) => Some((Ok(Frame::data(data)), (reader, trailer_rx, false))),
                None => {
                    // Body exhausted; try to pick up trailers before ending.
                    if let Some(rx) = trailer_rx {
                        let timeout = reader.drain_timeout;
                        match tokio::time::timeout(timeout, rx).await {
                            Ok(Ok(trailers)) => {
                                let mut map = http::HeaderMap::new();
                                // Skip pairs that are not valid header
                                // names/values rather than failing the body.
                                for (k, v) in trailers {
                                    if let (Ok(name), Ok(val)) = (
                                        HeaderName::from_bytes(k.as_bytes()),
                                        HeaderValue::from_str(&v),
                                    ) {
                                        map.append(name, val);
                                    }
                                }
                                if !map.is_empty() {
                                    return Some((Ok(Frame::trailers(map)), (reader, None, true)));
                                }
                            }
                            // Sender dropped without sending: no trailers.
                            Ok(Err(_)) => {
                            }
                            Err(_) => {
                                tracing::warn!(
                                    "iroh-http: trailer wait timed out after {timeout:?}; \
                                     completing body without trailers"
                                );
                            }
                        }
                    }
                    None
                }
            }
        },
    );
    StreamBody::new(s)
}
/// Extract the request path (plus query string) from a URL-ish string,
/// dropping any fragment.
///
/// Accepts full `scheme://authority[/path][?query][#frag]` URLs, bare
/// absolute paths (`/p`), and bare relative paths (`p` → `/p`). Always
/// returns a string starting with `/`.
pub(crate) fn extract_path(url: &str) -> String {
    let raw = if let Some(idx) = url.find("://") {
        let after_scheme = url.get(idx.saturating_add(3)..).unwrap_or("");
        // Cut the fragment first so a '/' or '?' inside it can't be
        // mistaken for the path or query start.
        let after_scheme = after_scheme.split('#').next().unwrap_or("");
        // Whichever of '/' or '?' comes first marks the end of the
        // authority. Previously a '/' inside the query string (e.g.
        // "node?x=1/2") was wrongly taken as the path start.
        match (after_scheme.find('/'), after_scheme.find('?')) {
            (Some(s), Some(q)) if q < s => format!("/{}", &after_scheme[q..]),
            (Some(s), _) => after_scheme[s..].to_string(),
            (None, Some(q)) => format!("/{}", &after_scheme[q..]),
            (None, None) => "/".to_string(),
        }
    } else if url.starts_with('/') {
        url.to_string()
    } else {
        format!("/{url}")
    };
    // Strip a fragment from the bare-path forms as well.
    match raw.find('#') {
        Some(pos) => raw[..pos].to_string(),
        None => raw,
    }
}
/// Establish a raw bidirectional byte stream to a remote node via an HTTP
/// CONNECT + `Upgrade: iroh-duplex` handshake.
///
/// On a 101 response, the upgraded connection is pumped to/from a pair of
/// body channels and their read/write handles are returned to the FFI
/// caller as an [`FfiDuplexStream`]. Any non-101 status is rejected.
pub async fn raw_connect(
    endpoint: &IrohEndpoint,
    remote_node_id: &str,
    path: &str,
    headers: &[(String, String)],
) -> Result<FfiDuplexStream, CoreError> {
    // Validate header names/values up front, mirroring `fetch`.
    for (name, value) in headers {
        HeaderName::from_bytes(name.as_bytes())
            .map_err(|_| CoreError::invalid_input(format!("invalid header name {:?}", name)))?;
        HeaderValue::from_str(value).map_err(|_| {
            CoreError::invalid_input(format!("invalid header value for {:?}", name))
        })?;
    }
    let parsed = parse_node_addr(remote_node_id)?;
    let node_id = parsed.node_id;
    let mut addr = iroh::EndpointAddr::new(node_id);
    for a in &parsed.direct_addrs {
        addr = addr.with_ip_addr(*a);
    }
    let ep_raw = endpoint.raw().clone();
    let addr_clone = addr.clone();
    let max_header_size = endpoint.max_header_size();
    let handles = endpoint.handles();
    // Duplex streams use their own ALPN so they pool separately from
    // regular HTTP requests.
    let pooled = endpoint
        .pool()
        .get_or_connect(node_id, ALPN_DUPLEX, || async move {
            ep_raw
                .connect(addr_clone, ALPN_DUPLEX)
                .await
                .map_err(|e| format!("connect duplex: {e}"))
        })
        .await
        .map_err(CoreError::connection_failed)?;
    let (send, recv) = pooled
        .conn
        .open_bi()
        .await
        .map_err(|e| CoreError::connection_failed(format!("open_bi: {e}")))?;
    let io = TokioIo::new(IrohStream::new(send, recv));
    // NOTE(review): unlike `do_fetch`, no `max_headers` cap is set here —
    // confirm whether that asymmetry is intentional.
    let (mut sender, conn_task) = hyper::client::conn::http1::Builder::new()
        .max_buf_size(max_header_size.max(8192))
        .handshake::<_, BoxBody>(io)
        .await
        .map_err(|e| CoreError::connection_failed(format!("hyper handshake (duplex): {e}")))?;
    // Drive the HTTP/1 connection until the upgrade completes.
    tokio::spawn(conn_task);
    let mut req_builder = hyper::Request::builder()
        .method(Method::from_bytes(b"CONNECT").expect("CONNECT is a valid HTTP method"))
        .uri(path)
        .header(hyper::header::CONNECTION, "upgrade")
        .header(hyper::header::UPGRADE, "iroh-duplex");
    for (k, v) in headers {
        req_builder = req_builder.header(k.as_str(), v.as_str());
    }
    let req = req_builder
        .body(crate::box_body(http_body_util::Empty::new()))
        .map_err(|e| CoreError::internal(format!("build duplex request: {e}")))?;
    let resp = sender
        .send_request(req)
        .await
        .map_err(|e| CoreError::connection_failed(format!("send duplex request: {e}")))?;
    let status = resp.status();
    // The server must agree to switch protocols; anything else is a refusal.
    if status != StatusCode::SWITCHING_PROTOCOLS {
        return Err(CoreError::peer_rejected(format!(
            "server rejected duplex: expected 101, got {status}"
        )));
    }
    let upgraded = hyper::upgrade::on(resp)
        .await
        .map_err(|e| CoreError::connection_failed(format!("upgrade error: {e}")))?;
    // Two channels: server->client bytes land in `server_write`/`server_read`,
    // client->server bytes flow through `client_write`/`client_read`.
    let (server_write, server_read) = handles.make_body_channel();
    let (client_write, client_read) = handles.make_body_channel();
    let read_handle = handles.insert_reader(server_read)?;
    let write_handle = handles.insert_writer(client_write)?;
    let io = TokioIo::new(upgraded);
    // Background task copies bytes in both directions until either side
    // closes.
    tokio::spawn(crate::stream::pump_duplex(io, server_write, client_read));
    Ok(FfiDuplexStream {
        read_handle,
        write_handle,
    })
}
#[cfg(test)]
mod tests {
    use super::extract_path;

    /// Table-driven helper: asserts `extract_path(input) == expected` for
    /// every pair, reporting the failing input.
    fn check(cases: &[(&str, &str)]) {
        for (input, expected) in cases {
            assert_eq!(extract_path(input), *expected, "input: {input}");
        }
    }

    #[test]
    fn extract_path_basic() {
        check(&[
            ("httpi://node/foo/bar", "/foo/bar"),
            ("httpi://node/", "/"),
            ("httpi://node", "/"),
        ]);
    }

    #[test]
    fn extract_path_query_string() {
        check(&[
            ("httpi://node/path?x=1", "/path?x=1"),
            ("httpi://node?x=1", "/?x=1"),
        ]);
    }

    #[test]
    fn extract_path_fragment() {
        check(&[
            ("httpi://node/path#frag", "/path"),
            ("httpi://node/path?q=1#frag", "/path?q=1"),
            ("/local#frag", "/local"),
        ]);
    }

    #[test]
    fn extract_path_bare_path() {
        check(&[("/already", "/already"), ("no-slash", "/no-slash")]);
    }
}