#![allow(clippy::disallowed_types)]
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
use http::StatusCode;
use tower::Service;
use crate::ffi::handles::ResponseHeadEntry;
use crate::ffi::pumps::pump_hyper_body_to_channel;
use crate::http::server::{
serve_with_events, ConnectionEventFn, RemoteNodeId, ServeHandle, ServeOptions,
};
use crate::{Body, CoreError, IrohEndpoint, RequestPayload};
/// Builds a `500 Internal Server Error` response whose body is the given
/// static byte string.
fn internal_error(detail: &'static [u8]) -> hyper::Response<Body> {
    let payload = Body::full(Bytes::from_static(detail));
    // `Response::new` starts out as 200 OK; only the status needs adjusting.
    let mut response = hyper::Response::new(payload);
    *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
    response
}
/// Builds a `503 Service Unavailable` response whose body is the given
/// static byte string.
fn service_unavailable(detail: &'static [u8]) -> hyper::Response<Body> {
    let payload = Body::full(Bytes::from_static(detail));
    // `Response::new` starts out as 200 OK; only the status needs adjusting.
    let mut response = hyper::Response::new(payload);
    *response.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
    response
}
/// Hands the response head (status + headers) for a pending request to the
/// task that is waiting on it.
///
/// The status code and every header name/value are validated *before* the
/// sender registered under `req_handle` is taken out of `handles`, so a
/// rejected call leaves the handle in place.
///
/// # Errors
///
/// * invalid input — `status` is not accepted by `StatusCode::from_u16`, or a
///   header name/value is rejected by the `http` crate parsers.
/// * invalid handle — no sender is registered under `req_handle`.
/// * internal — the receiving side was dropped before the head could be sent.
pub fn respond(
    handles: &crate::ffi::handles::HandleStore,
    req_handle: u64,
    status: u16,
    headers: Vec<(String, String)>,
) -> Result<(), CoreError> {
    if StatusCode::from_u16(status).is_err() {
        return Err(CoreError::invalid_input(format!(
            "invalid HTTP status code: {status}"
        )));
    }

    for (name, value) in &headers {
        if http::HeaderName::from_bytes(name.as_bytes()).is_err() {
            return Err(CoreError::invalid_input(format!(
                "invalid response header name {name:?}"
            )));
        }
        if http::HeaderValue::from_str(value).is_err() {
            return Err(CoreError::invalid_input(format!(
                "invalid response header value for {name:?}"
            )));
        }
    }

    // Only consume the sender once the head is known to be well-formed.
    let Some(sender) = handles.take_req_sender(req_handle) else {
        return Err(CoreError::invalid_handle(req_handle));
    };

    let head = ResponseHeadEntry { status, headers };
    sender
        .send(head)
        .map_err(|_unsent| CoreError::internal("serve task dropped before respond"))
}
/// Drop guard tied to one in-flight request.
///
/// When it goes out of scope it takes (and discards) whatever sender is still
/// registered under `req_handle` in the endpoint's handle store, so the entry
/// does not outlive the request that created it.
struct ReqHeadGuard {
    /// Endpoint whose handle store holds the sender.
    endpoint: IrohEndpoint,
    /// Handle under which the response-head sender was registered.
    req_handle: u64,
}

impl Drop for ReqHeadGuard {
    fn drop(&mut self) {
        let store = self.endpoint.handles();
        // The result is discarded: if `respond` already took the sender there
        // is simply nothing left to remove.
        drop(store.take_req_sender(self.req_handle));
    }
}
/// Shared state used to forward each incoming HTTP request to the
/// FFI-provided request callback.
struct FfiDispatcher {
// Callback invoked once per request with the assembled `RequestPayload`.
on_request: Arc<dyn Fn(RequestPayload) + Send + Sync>,
// Endpoint whose handle store backs the body and response-head handles.
endpoint: IrohEndpoint,
// This node's id; used as the authority part of the `httpi://` request URL.
own_node_id: Arc<String>,
// Upper bound on the estimated request head size in bytes; `None` disables
// the check.
max_header_size: Option<usize>,
}
/// `tower::Service` adapter that forwards every request to a shared
/// [`FfiDispatcher`].
///
/// Cloning is cheap: clones share the same dispatcher through an `Arc`.
#[derive(Clone)]
pub(crate) struct IrohHttpService {
    dispatcher: Arc<FfiDispatcher>,
}

impl Service<hyper::Request<Body>> for IrohHttpService {
    type Response = hyper::Response<Body>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports readiness; this service applies no backpressure itself.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Dispatches `req`, using the `RemoteNodeId` request extension as the
    /// peer id (an empty string when the extension is absent).
    fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
        let peer_id = match req.extensions().get::<RemoteNodeId>() {
            Some(tagged) => Arc::clone(&tagged.0),
            None => Arc::new(String::new()),
        };
        let dispatcher = Arc::clone(&self.dispatcher);

        Box::pin(async move {
            let response = dispatcher.dispatch(req, peer_id).await;
            Ok(response)
        })
    }
}
impl FfiDispatcher {
    /// Estimates the size in bytes of the request head as it will be seen by
    /// the handler.
    ///
    /// Any incoming `peer-id` header is ignored and the injected one (carrying
    /// `peer_id`) is counted instead. All arithmetic saturates, so the result
    /// is clamped at `usize::MAX` rather than overflowing.
    fn estimate_head_bytes(req: &hyper::Request<Body>, method: &str, peer_id: &str) -> usize {
        // NOTE(review): the per-field overhead of 4 presumably stands for
        // ": " plus CRLF, and the 12 for request-line framing — confirm.
        let mut total = 0usize;
        for (name, value) in req.headers().iter() {
            if name.as_str().eq_ignore_ascii_case("peer-id") {
                continue;
            }
            let field = name
                .as_str()
                .len()
                .saturating_add(value.as_bytes().len())
                .saturating_add(4);
            total = total.saturating_add(field);
        }

        let injected_peer_field = "peer-id"
            .len()
            .saturating_add(peer_id.len())
            .saturating_add(4);
        let request_line = req
            .uri()
            .to_string()
            .len()
            .saturating_add(method.len())
            .saturating_add(12);

        total
            .saturating_add(injected_peer_field)
            .saturating_add(request_line)
    }

    /// Forwards one request to the FFI callback and waits for its response
    /// head.
    ///
    /// Steps, in order:
    /// 1. Reject with `431` if a header limit is configured and the estimated
    ///    head size exceeds it.
    /// 2. Copy the headers as strings (dropping any incoming `peer-id`),
    ///    rejecting with `400` if a value cannot be converted via `to_str`,
    ///    then append a `peer-id` header carrying `remote_node_id`.
    /// 3. Allocate the request-body, response-body and response-head handles,
    ///    answering `503` if any allocation fails.
    /// 4. Spawn the task that pumps the request body, invoke `on_request`, and
    ///    await the response head.
    /// 5. Build the response from that head; `500` if the head never arrives
    ///    or cannot be turned into a response.
    async fn dispatch(
        self: Arc<Self>,
        req: hyper::Request<Body>,
        remote_node_id: Arc<String>,
    ) -> hyper::Response<Body> {
        const HANDLE_TABLE_FULL: &[u8] = b"server handle table full";

        let handles = self.endpoint.handles();
        let method = req.method().to_string();
        let path_and_query = req
            .uri()
            .path_and_query()
            .map_or("/", |pq| pq.as_str())
            .to_string();

        tracing::debug!(
            method = %method,
            path = %path_and_query,
            peer = %remote_node_id,
            "iroh-http: incoming request",
        );

        if let Some(limit) = self.max_header_size {
            if Self::estimate_head_bytes(&req, &method, &remote_node_id) > limit {
                return hyper::Response::builder()
                    .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
                    .body(Body::empty())
                    .expect("static response args are valid");
            }
        }

        // Copy headers for the callback; the first value that is not
        // representable as `&str` aborts the request.
        let copied = req
            .headers()
            .iter()
            .filter(|(name, _)| !name.as_str().eq_ignore_ascii_case("peer-id"))
            .map(|(name, value)| {
                value
                    .to_str()
                    .map(|text| (name.as_str().to_string(), text.to_string()))
            })
            .collect::<Result<Vec<(String, String)>, _>>();
        let Ok(mut req_headers) = copied else {
            return hyper::Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .body(Body::full(Bytes::from_static(b"non-UTF8 header value")))
                .expect("static response args are valid");
        };
        // The peer id always comes from the connection, never from the wire.
        req_headers.push(("peer-id".to_string(), String::clone(&remote_node_id)));

        let url = format!("httpi://{}{}", self.own_node_id, path_and_query);

        // NOTE(review): the guard presumably rolls back its insertions when
        // dropped before `commit` — confirm against `HandleStore`.
        let mut guard = handles.insert_guard();

        let (req_body_writer, req_body_reader) = handles.make_body_channel();
        let Ok(req_body_handle) = guard.insert_reader(req_body_reader) else {
            return service_unavailable(HANDLE_TABLE_FULL);
        };

        let (res_body_writer, res_body_reader) = handles.make_body_channel();
        let Ok(res_body_handle) = guard.insert_writer(res_body_writer) else {
            return service_unavailable(HANDLE_TABLE_FULL);
        };

        let (head_tx, head_rx) = tokio::sync::oneshot::channel::<ResponseHeadEntry>();
        let Ok(req_handle) = guard.allocate_req_handle(head_tx) else {
            return service_unavailable(HANDLE_TABLE_FULL);
        };
        guard.commit();

        // From here on, leaving this function (including cancellation while
        // awaiting the head) removes the response-head sender again.
        let _req_head_guard = ReqHeadGuard {
            endpoint: self.endpoint.clone(),
            req_handle,
        };

        tokio::spawn(pump_hyper_body_to_channel(req.into_body(), req_body_writer));

        let payload = RequestPayload {
            req_handle,
            req_body_handle,
            res_body_handle,
            method,
            url,
            headers: req_headers,
            remote_node_id: Arc::unwrap_or_clone(remote_node_id),
            is_bidi: false,
        };
        (self.on_request)(payload);

        let Ok(response_head) = head_rx.await else {
            return internal_error(b"JS handler dropped without responding");
        };

        let builder = response_head.headers.iter().fold(
            hyper::Response::builder().status(response_head.status),
            |partial, (name, value)| partial.header(name.as_str(), value.as_str()),
        );
        builder
            .body(Body::new(res_body_reader))
            .unwrap_or_else(|_| internal_error(b"failed to build response head from JS"))
    }
}
/// Starts serving on `endpoint`, forwarding every request to `on_request` and
/// passing `on_connection_event` through to `serve_with_events`.
///
/// The header-size limit is read from the endpoint; a reported limit of `0`
/// disables the check.
pub fn ffi_serve_with_callback<F>(
    endpoint: IrohEndpoint,
    options: ServeOptions,
    on_request: F,
    on_connection_event: Option<ConnectionEventFn>,
) -> ServeHandle
where
    F: Fn(RequestPayload) + Send + Sync + 'static,
{
    // Zero is the "no limit" sentinel.
    let max_header_size = match endpoint.max_header_size() {
        0 => None,
        limit => Some(limit),
    };
    let own_node_id = Arc::new(endpoint.node_id().to_string());
    let on_request: Arc<dyn Fn(RequestPayload) + Send + Sync> = Arc::new(on_request);

    let service = IrohHttpService {
        dispatcher: Arc::new(FfiDispatcher {
            on_request,
            endpoint: endpoint.clone(),
            own_node_id,
            max_header_size,
        }),
    };
    serve_with_events(endpoint, options, service, on_connection_event)
}
pub fn ffi_serve<F>(endpoint: IrohEndpoint, options: ServeOptions, on_request: F) -> ServeHandle
where
F: Fn(RequestPayload) + Send + Sync + 'static,
{
ffi_serve_with_callback(endpoint, options, on_request, None)
}