use std::rc::Rc;
use std::{cell::RefCell, convert::TryFrom};
use crate::event::{ExchangeComplete, FiniteEvent};
use crate::grpc::transport::{GrpcCallId, GrpcCallResponse, GrpcResponseParts};
use crate::host::grpc::GrpcHost;
use crate::proxy_wasm::{
traits::{Context, HttpContext},
types::Action,
};
use crate::{
client::HttpCallResponse,
event::{
Event, EventData, Exchange, RequestBody, RequestHeaders, RequestTrailers, ResponseBody,
ResponseHeaders, ResponseTrailers,
},
host::Host,
middleware::{EventHandlerDispatch, EventHandlerStack},
reactor::{http::HttpReactor, root::RootReactor},
types::{Cid, HttpCid},
BoxError,
};
use futures::executor::LocalPool;
#[cfg(feature = "experimental_websocket")]
use crate::reactor::websocket::WebSocketReactor;
#[cfg(feature = "experimental_websocket")]
use pdk_websockets_lib::UpgradeTracker;
#[cfg(feature = "experimental_websocket")]
use futures::task::LocalSpawnExt;
#[cfg(feature = "experimental_websocket")]
use crate::extract::context::WebSocketHandlerFn;
/// Async adapter for a single proxy-wasm HTTP context.
///
/// Host callbacks (`on_http_request_headers`, `on_http_response_body`, ...) are turned into
/// events that are pushed to the HTTP reactor and the event-handler stack, after which the
/// shared executor is driven until it stalls.
///
/// Note: fields are dropped in declaration order, so reordering them changes drop order.
pub(crate) struct AsyncHttpContext {
    /// Id of this HTTP context; used to mark it active and to report it as done.
    context_id: HttpCid,
    /// Shared single-threaded executor, driven with `run_until_stalled` after each callback.
    executor: Rc<RefCell<LocalPool>>,
    /// Host API used to read request/response headers and to build `Exchange`s.
    host: Rc<dyn Host>,
    /// Host API used to read gRPC call response bodies and statuses.
    grpc_host: Rc<dyn GrpcHost>,
    /// Root reactor: tracks the active context id and receives call responses / done notices.
    config_reactor: Rc<RootReactor>,
    /// Per-request HTTP reactor that receives finite events and owns the pause state.
    reactor: Rc<HttpReactor>,
    /// WebSocket reactor created for this context and passed to the WebSocket handlers.
    #[cfg(feature = "experimental_websocket")]
    websocket_reactor: Rc<WebSocketReactor>,
    /// Tracks response `:status`/`upgrade` headers; `ready()` switches bodies to WebSocket mode.
    #[cfg(feature = "experimental_websocket")]
    upgrade_tracker: RefCell<UpgradeTracker>,
    /// Stack of event handlers that request/response header events are dispatched to.
    event_handlers: Rc<RefCell<EventHandlerStack>>,
    /// Optional handler factory for the client→server direction (spawned on request body).
    #[cfg(feature = "experimental_websocket")]
    websocket_upstream_handler: Rc<RefCell<Option<WebSocketHandlerFn>>>,
    /// Optional handler factory for the server→client direction (spawned on response headers/body).
    #[cfg(feature = "experimental_websocket")]
    websocket_downstream_handler: Rc<RefCell<Option<WebSocketHandlerFn>>>,
    /// Set once the upstream spawn step has run, so it runs at most once per context.
    #[cfg(feature = "experimental_websocket")]
    websocket_upstream_spawned: RefCell<bool>,
    /// Set once the downstream spawn step has run, so it runs at most once per context.
    #[cfg(feature = "experimental_websocket")]
    websocket_downstream_spawned: RefCell<bool>,
}
impl AsyncHttpContext {
    /// Builds the async wrapper for one proxy-wasm HTTP context.
    ///
    /// The executor, hosts, reactors and handler stack are shared handles supplied by the
    /// caller and are stored as-is. With `experimental_websocket` enabled, a dedicated
    /// [`WebSocketReactor`] is created for `context_id`. The upgrade tracker starts from its
    /// default state, and both "handler spawned" flags start as `false`.
    // Each argument is a distinct shared collaborator of the context.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        context_id: HttpCid,
        executor: Rc<RefCell<LocalPool>>,
        host: Rc<dyn Host>,
        grpc_host: Rc<dyn GrpcHost>,
        config_reactor: Rc<RootReactor>,
        reactor: Rc<HttpReactor>,
        event_handlers: Rc<RefCell<EventHandlerStack>>,
        #[cfg(feature = "experimental_websocket")] websocket_upstream_handler: Rc<
            RefCell<Option<WebSocketHandlerFn>>,
        >,
        #[cfg(feature = "experimental_websocket")] websocket_downstream_handler: Rc<
            RefCell<Option<WebSocketHandlerFn>>,
        >,
    ) -> Self {
        #[cfg(feature = "experimental_websocket")]
        let websocket_reactor = Rc::new(WebSocketReactor::new(context_id));
        #[cfg(feature = "debug-logs")]
        log::debug!("AsyncHttpContext::new");
        Self {
            context_id,
            executor,
            host,
            grpc_host,
            config_reactor,
            reactor,
            #[cfg(feature = "experimental_websocket")]
            websocket_reactor,
            #[cfg(feature = "experimental_websocket")]
            upgrade_tracker: RefCell::new(UpgradeTracker::default()),
            event_handlers,
            #[cfg(feature = "experimental_websocket")]
            websocket_upstream_handler,
            #[cfg(feature = "experimental_websocket")]
            websocket_downstream_handler,
            #[cfg(feature = "experimental_websocket")]
            websocket_upstream_spawned: RefCell::new(false),
            #[cfg(feature = "experimental_websocket")]
            websocket_downstream_spawned: RefCell::new(false),
        }
    }

    /// Spawns the registered upstream (client → server) WebSocket handler on the executor.
    ///
    /// Runs at most once per context. The first call marks the upstream side as spawned, even
    /// when no upstream handler is registered, and later calls return immediately.
    ///
    /// The handler is called with this context's WebSocket reactor, and the future it returns
    /// is queued as a local task. That task only makes progress when the executor is driven.
    /// Errors the future resolves to, and failures to spawn it, are logged, not propagated.
    #[cfg(feature = "experimental_websocket")]
    fn spawn_websocket_upstream_handlers(&mut self) {
        if *self.websocket_upstream_spawned.borrow() {
            #[cfg(feature = "debug-logs")]
            log::debug!("AsyncHttpContext::spawn_websocket_upstream_handlers Handlers already spawned, skipping");
            return;
        }
        #[cfg(feature = "debug-logs")]
        log::debug!("AsyncHttpContext::spawn_websocket_upstream_handlers WebSocket upgrade detected, spawning handlers NOW");
        if let Some(ref handler) = *self.websocket_upstream_handler.borrow() {
            #[cfg(feature = "debug-logs")]
            log::debug!("AsyncHttpContext::spawn_websocket_upstream_handlers Spawning upstream (client→server) handler");
            let future = handler(Rc::clone(&self.websocket_reactor));
            if let Err(e) = self.executor.borrow().spawner().spawn_local(async move {
                #[cfg(feature = "debug-logs")]
                log::debug!(
                    "AsyncHttpContext::spawn_websocket_upstream_handlers Handler task started"
                );
                if let Err(e) = future.await {
                    log::error!(
                        "AsyncHttpContext::spawn_websocket_upstream_handlers Handler error: {e:?}"
                    );
                }
                #[cfg(feature = "debug-logs")]
                log::debug!(
                    "AsyncHttpContext::spawn_websocket_upstream_handlers Handler task finished"
                );
            }) {
                log::error!("AsyncHttpContext::spawn_websocket_upstream_handlers Failed to spawn WebSocket upstream handler: {e}");
            }
        }
        // Mark as spawned even when no handler is registered, so the lookup is not repeated.
        *self.websocket_upstream_spawned.borrow_mut() = true;
    }

    /// Spawns the registered downstream (server → client) WebSocket handler on the executor.
    ///
    /// Same contract as the upstream variant. It runs at most once per context, and the flag
    /// is set even when no downstream handler is registered. Handler and spawn errors are
    /// logged.
    #[cfg(feature = "experimental_websocket")]
    fn spawn_websocket_downstream_handlers(&mut self) {
        if *self.websocket_downstream_spawned.borrow() {
            #[cfg(feature = "debug-logs")]
            log::debug!("AsyncHttpContext::spawn_websocket_downstream_handlers Downstream handler already spawned, skipping");
            return;
        }
        #[cfg(feature = "debug-logs")]
        log::debug!("AsyncHttpContext::spawn_websocket_downstream_handlers WebSocket upgrade detected, spawning downstream handler NOW");
        if let Some(ref handler) = *self.websocket_downstream_handler.borrow() {
            #[cfg(feature = "debug-logs")]
            log::debug!("AsyncHttpContext::spawn_websocket_downstream_handlers Spawning downstream (server→client) handler");
            let future = handler(Rc::clone(&self.websocket_reactor));
            if let Err(e) = self.executor.borrow().spawner().spawn_local(async move {
                #[cfg(feature = "debug-logs")]
                log::debug!("AsyncHttpContext::spawn_websocket_downstream_handlers Handler task started");
                if let Err(e) = future.await {
                    log::error!("AsyncHttpContext::spawn_websocket_downstream_handlers Handler error: {e:?}");
                }
                #[cfg(feature = "debug-logs")]
                log::debug!("AsyncHttpContext::spawn_websocket_downstream_handlers Handler task finished");
            }) {
                log::error!("AsyncHttpContext::spawn_websocket_downstream_handlers Failed to spawn WebSocket downstream handler: {e}");
            }
        }
        // Mark as spawned even when no handler is registered, so the lookup is not repeated.
        *self.websocket_downstream_spawned.borrow_mut() = true;
    }

    /// Handles a request-body chunk once the connection has been upgraded to WebSocket.
    ///
    /// The call marks this HTTP context as active, wakes the upstream side of the WebSocket
    /// reactor, and drives the executor until it stalls.
    ///
    /// Returns `Action::Pause` if the reactor then reports the upstream as paused, otherwise
    /// `Action::Continue`.
    ///
    /// `_body_size` and `_end_of_stream` are only read by `debug-logs` logging. They carry a
    /// `_` prefix so builds without that feature do not trigger `unused_variables`.
    #[cfg(feature = "experimental_websocket")]
    fn handle_websocket_request_body(&mut self, _body_size: usize, _end_of_stream: bool) -> Action {
        #[cfg(feature = "debug-logs")]
        log::debug!("AsyncHttpContext::handle_websocket_request_body on_http_request_body called: body_size={_body_size}, end_of_stream={_end_of_stream}");
        self.config_reactor
            .set_active_cid(Cid::Http(self.context_id));
        #[cfg(feature = "debug-logs")]
        log::debug!("AsyncHttpContext::handle_websocket_request_body Calling wake_upstream() to wake upstream handler");
        self.websocket_reactor.wake_upstream();
        #[cfg(feature = "debug-logs")]
        log::debug!(
            "AsyncHttpContext::handle_websocket_request_body Running executor.run_until_stalled()"
        );
        self.executor.borrow_mut().run_until_stalled();
        if self.websocket_reactor.upstream_paused() {
            Action::Pause
        } else {
            Action::Continue
        }
    }

    /// Handles a response-body chunk once the connection has been upgraded to WebSocket.
    ///
    /// The call marks this HTTP context as active, wakes the downstream side of the WebSocket
    /// reactor, and drives the executor until it stalls.
    ///
    /// Returns `Action::Pause` if the reactor then reports the downstream as paused, otherwise
    /// `Action::Continue`.
    ///
    /// `_body_size` and `_end_of_stream` are only read by `debug-logs` logging. They carry a
    /// `_` prefix so builds without that feature do not trigger `unused_variables`.
    #[cfg(feature = "experimental_websocket")]
    fn handle_websocket_response_body(&mut self, _body_size: usize, _end_of_stream: bool) -> Action {
        #[cfg(feature = "debug-logs")]
        log::debug!("AsyncHttpContext::handle_websocket_response_body on_http_response_body called: body_size={_body_size}, end_of_stream={_end_of_stream}");
        self.config_reactor
            .set_active_cid(Cid::Http(self.context_id));
        self.websocket_reactor.wake_downstream();
        self.executor.borrow_mut().run_until_stalled();
        if self.websocket_reactor.downstream_paused() {
            #[cfg(feature = "debug-logs")]
            log::debug!("AsyncHttpContext::handle_websocket_response_body PAUSE");
            Action::Pause
        } else {
            #[cfg(feature = "debug-logs")]
            log::debug!("AsyncHttpContext::handle_websocket_response_body CONTINUE");
            Action::Continue
        }
    }

    /// Dispatches `event` through the event-handler stack.
    ///
    /// An [`Exchange`] is built from the HTTP reactor, the host and a clone of the event, and
    /// then wrapped in an [`EventData`]. The handler stack is mutably borrowed for the
    /// duration of the dispatch.
    ///
    /// NOTE(review): the literal `0` passed to `EventData::new` looks like the starting
    /// position in the handler stack; confirm against `EventData::new`.
    ///
    /// # Errors
    /// Returns whatever error the handler stack's `dispatch` reports.
    fn dispatch<S>(&self, event: S) -> Result<(), BoxError>
    where
        S: Event,
        EventHandlerStack: EventHandlerDispatch<S>,
    {
        let exchange: Exchange<S> =
            Exchange::new(self.reactor.clone(), self.host.clone(), Some(event.clone()));
        let event = EventData::new(&exchange, event, 0);
        self.event_handlers.borrow_mut().dispatch(&event)
    }

    /// Feeds a finite HTTP lifecycle event through the async machinery and decides the action.
    ///
    /// Steps:
    /// 1. Marks this context as active and notifies the HTTP reactor of the event.
    /// 2. Dispatches `RequestHeaders` or `ResponseHeaders` events to the handler stack. Any
    ///    other event is only trace-logged here; a handler error is logged, not propagated.
    /// 3. Drives the executor until it stalls.
    /// 4. Returns `Action::Pause` if the reactor is paused or headers-paused, otherwise
    ///    `Action::Continue`.
    ///
    /// When pausing, the event's end-of-stream flag is recorded via `set_eos_paused`. The final
    /// decision is always stored with `set_http_context_paused`.
    fn notify_finite_event(&mut self, finite_event: FiniteEvent) -> Action {
        let end_of_stream = finite_event.end_of_stream();
        self.config_reactor
            .set_active_cid(Cid::Http(self.context_id));
        self.reactor.notify(finite_event.clone());
        // Captured before `try_from` consumes the event, for the error log below.
        let kind = finite_event.kind();
        let event_handler_result = match RequestHeaders::try_from(finite_event) {
            Ok(event) => self.dispatch(event),
            Err(finite_event) => match ResponseHeaders::try_from(finite_event) {
                Ok(event) => self.dispatch(event),
                Err(event) => {
                    log::trace!("No handler dispatched for event '{event:?}'.");
                    Ok(())
                }
            },
        };
        if let Err(err) = event_handler_result {
            log::error!("Failed event handler for {kind:?}: {err:?}");
        }
        self.executor.borrow_mut().run_until_stalled();
        let action = if self.reactor.paused() || self.reactor.headers_paused() {
            self.reactor.set_eos_paused(end_of_stream);
            Action::Pause
        } else {
            Action::Continue
        };
        self.reactor
            .set_http_context_paused(action == Action::Pause);
        action
    }

    /// Converts any [`Event`] into a [`FiniteEvent`] and forwards it to
    /// [`Self::notify_finite_event`].
    fn notify<E: Event>(&mut self, event: E) -> Action {
        self.notify_finite_event(event.into())
    }
}
impl Context for AsyncHttpContext {
    /// Delivers an HTTP callout response to the root reactor, then resumes pending tasks.
    ///
    /// The response is handed to the root reactor first. This context is then marked active,
    /// and the shared executor is driven until it stalls.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        num_headers: usize,
        body_size: usize,
        num_trailers: usize,
    ) {
        #[cfg(feature = "debug-logs")]
        log::debug!(
            "on_http_call_response: {token_id}, {num_headers}, {body_size}, {num_trailers}"
        );
        let callout = HttpCallResponse {
            request_id: token_id.into(),
            num_headers,
            body_size,
            num_trailers,
        };
        let root = &self.config_reactor;
        root.notify_response(callout);
        root.set_active_cid(self.context_id.into());
        self.executor.borrow_mut().run_until_stalled();
    }

    /// Delivers a gRPC call response to the root reactor, then resumes pending tasks.
    ///
    /// The body (from offset 0, `response_size` bytes) and then the status are read from the
    /// gRPC host. They are bundled with the call metadata and handed to the root reactor.
    /// Finally this context is marked active and the executor is driven until it stalls.
    fn on_grpc_call_response(&mut self, token_id: u32, status_code: u32, response_size: usize) {
        #[cfg(feature = "debug-logs")]
        log::debug!("on_grpc_call_response: {token_id}, {status_code}, {response_size}");
        let call = GrpcCallResponse {
            call_id: GrpcCallId::new(token_id),
            status_code,
            response_size,
        };
        let grpc = &self.grpc_host;
        let body = grpc.get_grpc_call_response_body(0, response_size);
        let grpc_status = grpc.get_grpc_status();
        let parts = GrpcResponseParts {
            event: call,
            content: body,
            status: grpc_status,
        };
        let root = &self.config_reactor;
        root.notify_grpc_call_response(parts);
        root.set_active_cid(self.context_id.into());
        self.executor.borrow_mut().run_until_stalled();
    }

    /// Finishes the exchange and reports this context as done to the root reactor.
    ///
    /// An `ExchangeComplete` event is pushed through the normal notification path; the
    /// resulting action is ignored. The executor is then driven once more before the context
    /// is reported done. Always returns `true`.
    fn on_done(&mut self) -> bool {
        #[cfg(feature = "debug-logs")]
        log::debug!("AsyncHttpContext::on_done");
        let _ = self.notify(ExchangeComplete {});
        self.executor.borrow_mut().run_until_stalled();
        let finished = self.context_id;
        self.config_reactor.set_http_context_done(finished);
        true
    }
}
impl HttpContext for AsyncHttpContext {
    /// Handles downstream request headers by dispatching a [`RequestHeaders`] event.
    ///
    /// With `debug-logs` enabled, the `upgrade` and `connection` request headers are logged
    /// when an `upgrade` header is present.
    fn on_http_request_headers(&mut self, num_headers: usize, end_of_stream: bool) -> Action {
        #[cfg(feature = "debug-logs")]
        log::debug!("AsyncHttpContext::on_http_request_headers: num_headers={num_headers}, end_of_stream={end_of_stream}");
        #[cfg(feature = "debug-logs")]
        {
            let upgrade_header = self.host.get_http_request_header("upgrade");
            let connection_header = self.host.get_http_request_header("connection");
            // Only an `upgrade` header requests a protocol switch. A `connection` header alone
            // appears on ordinary keep-alive/close requests and is not an upgrade.
            if upgrade_header.is_some() {
                log::debug!("AsyncHttpContext::on_http_request_headers: WebSocket upgrade detected: upgrade={upgrade_header:?}, connection={connection_header:?}");
            }
        }
        let result = self.notify(RequestHeaders {
            _num_headers: num_headers,
            end_of_stream,
        });
        #[cfg(feature = "debug-logs")]
        log::debug!("AsyncHttpContext::on_http_request_headers: returning {result:?}");
        result
    }

    /// Handles a request-body chunk.
    ///
    /// Once the upgrade tracker reports a completed WebSocket upgrade, the upstream handler is
    /// spawned (at most once) and the chunk is routed to the WebSocket path. Otherwise a
    /// [`RequestBody`] event is dispatched.
    fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
        #[cfg(feature = "debug-logs")]
        log::debug!("on_http_request_body: {body_size}, {end_of_stream}");
        #[cfg(feature = "experimental_websocket")]
        if self.upgrade_tracker.borrow().ready() {
            self.spawn_websocket_upstream_handlers();
            return self.handle_websocket_request_body(body_size, end_of_stream);
        }
        let result = self.notify(RequestBody {
            body_size,
            end_of_stream,
        });
        #[cfg(feature = "debug-logs")]
        log::debug!("on_http_request_body -> {result:?}");
        result
    }

    /// Handles request trailers by dispatching a [`RequestTrailers`] event.
    fn on_http_request_trailers(&mut self, num_trailers: usize) -> Action {
        #[cfg(feature = "debug-logs")]
        log::debug!("on_http_request_trailers: {num_trailers}");
        let result = self.notify(RequestTrailers {
            _num_trailers: num_trailers,
        });
        #[cfg(feature = "debug-logs")]
        log::debug!("on_http_request_trailers -> {result:?}");
        result
    }

    /// Handles upstream response headers.
    ///
    /// With `experimental_websocket`, the response `:status` and `upgrade` headers are first
    /// fed to the upgrade tracker. If the upgrade is then ready, the downstream handler is
    /// spawned (at most once) and the executor is driven. In all cases a [`ResponseHeaders`]
    /// event is then dispatched.
    fn on_http_response_headers(&mut self, num_headers: usize, end_of_stream: bool) -> Action {
        #[cfg(feature = "debug-logs")]
        log::debug!("on_http_response_headers: {num_headers}, {end_of_stream}");
        #[cfg(feature = "experimental_websocket")]
        {
            let status = self.host.get_http_response_header(":status");
            let upgrade = self.host.get_http_response_header("upgrade");
            #[cfg(feature = "debug-logs")]
            log::debug!(
                "AsyncHttpContext::on_http_response_headers: WebSocket upgrade check: status={status:?}, upgrade={upgrade:?}"
            );
            self.upgrade_tracker
                .borrow_mut()
                .track_upgrade_headers(status.as_deref(), upgrade.as_deref());
            let is_ready = self.upgrade_tracker.borrow().ready();
            #[cfg(feature = "debug-logs")]
            log::debug!(
                "AsyncHttpContext::on_http_response_headers: After upgrade check: ready={is_ready}"
            );
            if is_ready {
                self.spawn_websocket_downstream_handlers();
                self.executor.borrow_mut().run_until_stalled();
            }
        }
        let result = self.notify(ResponseHeaders {
            _num_headers: num_headers,
            end_of_stream,
        });
        #[cfg(feature = "debug-logs")]
        log::debug!("on_http_response_headers -> {result:?}");
        result
    }

    /// Handles a response-body chunk.
    ///
    /// Once the upgrade tracker reports a completed WebSocket upgrade, the downstream handler
    /// is spawned (at most once) and the chunk is routed to the WebSocket path. Otherwise a
    /// [`ResponseBody`] event is dispatched.
    fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
        #[cfg(feature = "debug-logs")]
        log::debug!("on_http_response_body: {body_size}, {end_of_stream}");
        #[cfg(feature = "experimental_websocket")]
        if self.upgrade_tracker.borrow().ready() {
            self.spawn_websocket_downstream_handlers();
            return self.handle_websocket_response_body(body_size, end_of_stream);
        }
        let result = self.notify(ResponseBody {
            body_size,
            end_of_stream,
        });
        #[cfg(feature = "debug-logs")]
        log::debug!("on_http_response_body -> {result:?}");
        result
    }

    /// Handles response trailers by dispatching a [`ResponseTrailers`] event.
    fn on_http_response_trailers(&mut self, num_trailers: usize) -> Action {
        #[cfg(feature = "debug-logs")]
        log::debug!("on_http_response_trailers: {num_trailers}");
        let result = self.notify(ResponseTrailers {
            _num_trailers: num_trailers,
        });
        #[cfg(feature = "debug-logs")]
        log::debug!("on_http_response_trailers -> {result:?}");
        result
    }
}
impl Drop for AsyncHttpContext {
    /// Reports this HTTP context as done to the root reactor when the wrapper is dropped.
    ///
    /// NOTE(review): `on_done` makes the same `set_http_context_done` call, so a normally
    /// completed exchange reports twice. This presumably relies on the call being
    /// idempotent; confirm.
    fn drop(&mut self) {
        let finished = self.context_id;
        self.config_reactor.set_http_context_done(finished);
    }
}
#[cfg(all(test, feature = "experimental_websocket"))]
mod tests {
    use super::*;
    use crate::host::DefaultHost;
    use crate::middleware::EventHandlerStack;
    use crate::reactor::http::HttpReactor;
    use crate::reactor::root::RootReactor;
    use std::cell::RefCell;
    use std::rc::Rc;

    /// Builds a context with default hosts, fresh reactors, an empty handler stack and no
    /// WebSocket handlers registered.
    fn create_test_context() -> AsyncHttpContext {
        let cid = HttpCid::from(1);
        AsyncHttpContext::new(
            cid,
            Rc::new(RefCell::new(futures::executor::LocalPool::new())),
            Rc::new(DefaultHost) as Rc<dyn Host>,
            Rc::new(crate::host::grpc::DefaultGrpcHost) as Rc<dyn crate::host::grpc::GrpcHost>,
            Rc::new(RootReactor::new(crate::types::RootCid::from(0))),
            Rc::new(HttpReactor::new(cid)),
            Rc::new(RefCell::new(EventHandlerStack::default())),
            Rc::new(RefCell::new(None)),
            Rc::new(RefCell::new(None)),
        )
    }

    /// Runs the WebSocket request-body path with the upstream pause flag set as given.
    fn request_body_action(upstream_paused: bool) -> Action {
        let mut ctx = create_test_context();
        ctx.websocket_reactor.set_upstream_paused(upstream_paused);
        ctx.handle_websocket_request_body(100, false)
    }

    /// Runs the WebSocket response-body path with the downstream pause flag set as given.
    fn response_body_action(downstream_paused: bool) -> Action {
        let mut ctx = create_test_context();
        ctx.websocket_reactor.set_downstream_paused(downstream_paused);
        ctx.handle_websocket_response_body(100, false)
    }

    #[test]
    fn websocket_request_body_returns_pause_when_upstream_paused() {
        assert_eq!(
            request_body_action(true),
            Action::Pause,
            "Should return Pause when upstream is paused"
        );
    }

    #[test]
    fn websocket_request_body_returns_continue_when_upstream_not_paused() {
        assert_eq!(
            request_body_action(false),
            Action::Continue,
            "Should return Continue when upstream is not paused"
        );
    }

    #[test]
    fn websocket_response_body_returns_pause_when_downstream_paused() {
        assert_eq!(
            response_body_action(true),
            Action::Pause,
            "Should return Pause when downstream is paused"
        );
    }

    #[test]
    fn websocket_response_body_returns_continue_when_downstream_not_paused() {
        assert_eq!(
            response_body_action(false),
            Action::Continue,
            "Should return Continue when downstream is not paused"
        );
    }

    #[test]
    fn spawn_websocket_handlers_only_spawns_once() {
        let mut ctx = create_test_context();
        let checkpoints = [
            "Should be marked as spawned after first call",
            "Should still be marked as spawned, not re-spawned",
        ];
        for message in checkpoints {
            ctx.spawn_websocket_upstream_handlers();
            assert!(*ctx.websocket_upstream_spawned.borrow(), "{message}");
        }
    }

    #[test]
    fn spawn_websocket_handlers_starts_false() {
        let spawned = *create_test_context().websocket_upstream_spawned.borrow();
        assert!(!spawned, "Should start as not spawned");
    }

    #[test]
    fn spawn_downstream_handlers_only_spawns_once() {
        let mut ctx = create_test_context();
        let checkpoints = [
            "Should be marked as spawned after first call",
            "Should still be marked as spawned, not re-spawned",
        ];
        for message in checkpoints {
            ctx.spawn_websocket_downstream_handlers();
            assert!(*ctx.websocket_downstream_spawned.borrow(), "{message}");
        }
    }

    #[test]
    fn spawn_downstream_handlers_starts_false() {
        let spawned = *create_test_context().websocket_downstream_spawned.borrow();
        assert!(!spawned, "Should start as not spawned for downstream");
    }
}