use crate::host::implementation::{FlowType, ProxyWasmStub};
use crate::tester::io::{RequestResponse, UnitFrame, UnitHttpResponse};
use crate::tester::unit_test::{add_request_properties, respond_call, Backends};
use pdk_websockets_lib::{Decoder, Encoder, Frame, SinkResult};
use proxy_wasm_stub::stub::Host;
use proxy_wasm_stub::traits::HttpContext;
use proxy_wasm_stub::types::{Action, BufferType, MapType};
use std::cell::RefCell;
use std::collections::VecDeque;
use std::rc::{Rc, Weak};
use std::task::Poll;
/// Test-facing handle to a WebSocket upgrade that is being driven through a plugin.
///
/// Cloning is cheap: every clone points at the same shared upgrade state.
#[derive(Clone)]
pub struct UnitTestUpgrade {
    pub(crate) inner: Rc<RefCell<InnerUnitUpgrade>>,
}
impl UnitTestUpgrade {
    /// Wraps `inner` in shared, interior-mutable storage.
    pub(crate) fn new(inner: InnerUnitUpgrade) -> Self {
        let shared = Rc::new(RefCell::new(inner));
        UnitTestUpgrade { inner: shared }
    }
    /// Advances the upgrade one step.
    ///
    /// Yields `Ready(Ok(connection))` once the plugin let the upgrade through,
    /// `Ready(Err(response))` if it was rejected, and `Pending` while the plugin
    /// is still holding one of the header phases.
    pub fn poll(&mut self) -> Poll<Result<UpgradeConnection, UnitHttpResponse>> {
        let mut upgrade = self.inner.borrow_mut();
        upgrade.poll()
    }
}
/// Phases of the upgrade state machine driven by `InnerUnitUpgrade::poll`.
///
/// NOTE: variants are declared in progression order and the type derives
/// `PartialOrd`, so reordering them would change comparison results.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
enum UpgradeState {
    /// Request headers have not been delivered to the plugin yet.
    RequestHeaders,
    /// `on_http_request_headers` returned `Pause`; waiting for a resume.
    RequestHeadersPaused,
    /// Request phase passed; the backend response headers are next.
    ResponseHeaders,
    /// `on_http_response_headers` returned `Pause`; waiting for a resume.
    ResponseHeadersPaused,
    /// Both header phases passed; the connection can be handed out.
    Done,
    /// The plugin set the host's send-response flag; the upgrade is refused.
    Rejected,
}
/// State behind a `UnitTestUpgrade`: walks the plugin through the request- and
/// response-header phases of a WebSocket upgrade.
pub(crate) struct InnerUnitUpgrade {
    // Current phase of the upgrade state machine.
    state: UpgradeState,
    // Plugin context id every host call of this upgrade is made under.
    context_id: u32,
    // Forwarded to the `ConnectionInner` built when the upgrade completes.
    chunk_size: usize,
    // The upgrade request supplied by the test.
    request: RequestResponse,
    // Plugin HTTP context; moved into the connection once the upgrade completes.
    http_context: Box<dyn HttpContext>,
    // Backend that produces the upgrade response; also passed to `respond_call`.
    backends: Rc<RefCell<Backends>>,
    // Shared host stub.
    host: Rc<RefCell<ProxyWasmStub>>,
    // Connection built on the first successful completion and reused afterwards.
    cached_connection: Option<Rc<RefCell<ConnectionInner>>>,
    // Backend response captured before `on_http_response_headers` runs.
    cached_response: Option<UnitHttpResponse>,
}
impl InnerUnitUpgrade {
    /// Creates a driver for the upgrade `request` running on plugin context `context_id`.
    ///
    /// The state machine starts at the request-header phase; the plugin is not
    /// invoked until the first `poll`. `chunk_size` is handed to the connection
    /// that is built once the upgrade completes.
    pub(crate) fn new(
        context_id: u32,
        request: RequestResponse,
        http_context: Box<dyn HttpContext>,
        backends: Rc<RefCell<Backends>>,
        host: Rc<RefCell<ProxyWasmStub>>,
        chunk_size: usize,
    ) -> Self {
        Self {
            state: UpgradeState::RequestHeaders,
            context_id,
            chunk_size,
            request,
            http_context,
            backends,
            host,
            cached_connection: None,
            cached_response: None,
        }
    }

    /// Advances the upgrade as far as it can go without outside input.
    ///
    /// Each call:
    /// 1. returns straight away if the upgrade already finished (`Done`/`Rejected`);
    /// 2. applies a resume or send-response the plugin signalled through the host;
    /// 3. in `RequestHeaders`, publishes the request headers to the host and calls
    ///    `on_http_request_headers` in upstream flow mode;
    /// 4. in `ResponseHeaders`, calls the backend, publishes its response headers and
    ///    calls `on_http_response_headers` in downstream flow mode.
    ///
    /// Returns `Ready(Ok(..))` once both phases passed, `Ready(Err(..))` carrying the
    /// response read back from the host when the upgrade was rejected, and `Pending`
    /// while a header phase is paused.
    ///
    /// NOTE(review): once `Done`, every further call builds another `UpgradeConnection`
    /// around the same cached connection; each of those runs the plugin's `on_log` /
    /// `on_done` when dropped.
    pub(crate) fn poll(&mut self) -> Poll<Result<UpgradeConnection, UnitHttpResponse>> {
        if let Poll::Ready(result) = self.build_result() {
            return Poll::Ready(result);
        }
        let context_id = self.context_id;
        self.host.borrow_mut().set_context(context_id);
        self.host.borrow_mut().set_flow_mode(FlowType::Upstream);
        // A paused phase may have been resumed (or rejected) since the last poll.
        self.resume();
        if self.state == UpgradeState::RequestHeaders {
            self.host.borrow_mut().create_map(
                context_id,
                MapType::HttpRequestHeaders,
                self.request
                    .headers()
                    .iter()
                    .map(|(k, v)| (k.clone(), v.as_bytes().to_vec()))
                    .collect(),
            );
            let action = self
                .http_context
                .on_http_request_headers(self.request.headers().len(), false);
            // A pause combined with a pending send-response is treated as a rejection.
            if action == Action::Pause && self.clear_send_response() {
                self.state = UpgradeState::Rejected;
            } else if action == Action::Pause {
                self.state = UpgradeState::RequestHeadersPaused;
            } else {
                self.state = UpgradeState::ResponseHeaders;
            }
            // Answering dispatched calls may resume the paused phase right away.
            self.respond_calls();
        }
        self.host.borrow_mut().set_flow_mode(FlowType::Downstream);
        if self.state == UpgradeState::ResponseHeaders {
            let backend_response = self
                .backends
                .borrow()
                .backend
                .call(self.request.clone().into())
                .inner;
            let backend_response = add_request_properties(backend_response, context_id);
            self.host.borrow_mut().create_map(
                context_id,
                MapType::HttpResponseHeaders,
                backend_response
                    .headers()
                    .iter()
                    .map(|(k, v)| (k.clone(), v.as_bytes().to_vec()))
                    .collect(),
            );
            let num_headers = backend_response.headers().len();
            // NOTE(review): captured before the plugin runs, so header edits made in
            // `on_http_response_headers` are not reflected here — confirm intended.
            self.cached_response = Some(UnitHttpResponse::from(backend_response));
            let action = self
                .http_context
                .on_http_response_headers(num_headers, false);
            if action == Action::Pause && self.clear_send_response() {
                self.state = UpgradeState::Rejected;
            } else if action == Action::Pause {
                self.state = UpgradeState::ResponseHeadersPaused;
            } else {
                self.state = UpgradeState::Done;
            }
            self.respond_calls();
        }
        self.build_result()
    }

    /// Maps a terminal state to a ready result; any other state is `Pending`.
    fn build_result(&mut self) -> Poll<Result<UpgradeConnection, UnitHttpResponse>> {
        match self.state {
            UpgradeState::Done => Poll::Ready(Ok(self.build_connection())),
            UpgradeState::Rejected => Poll::Ready(Err(self.read_response().into())),
            _ => Poll::Pending,
        }
    }

    /// Returns an `UpgradeConnection` over the shared connection state, creating that
    /// state (and moving the plugin context into it) on the first call.
    ///
    /// The attached response is the captured backend response, or
    /// `UnitHttpResponse::upgrade()` when none was captured.
    fn build_connection(&mut self) -> UpgradeConnection {
        if self.cached_connection.is_none() {
            let context_id = self.context_id;
            self.cached_connection = Some(Rc::new(RefCell::new(ConnectionInner {
                context_id,
                chunk_size: self.chunk_size,
                http_context: self.take_http_context(),
                backends: Rc::clone(&self.backends),
                host: Rc::clone(&self.host),
                connection_state: Default::default(),
            })));
        }
        let response = self
            .cached_response
            .clone()
            .unwrap_or_else(UnitHttpResponse::upgrade);
        let connection = self
            .cached_connection
            .as_ref()
            .expect("connection is created above when missing");
        UpgradeConnection::from_inner(Rc::clone(connection), response)
    }

    /// Consumes the host's resume and send-response flags for this context and
    /// updates the state accordingly.
    ///
    /// A resume moves a paused phase forward; a send-response marks the upgrade
    /// `Rejected`. Returns `true` if either flag was set.
    ///
    /// # Panics
    /// If both flags are set at once, or a resume arrives while no phase is paused.
    fn resume(&mut self) -> bool {
        let resume = self.host.borrow_mut().clear_resume(self.context_id);
        let send_response = self.clear_send_response();
        assert!(!(resume && send_response));
        if resume {
            self.state = match self.state {
                UpgradeState::RequestHeadersPaused => UpgradeState::ResponseHeaders,
                UpgradeState::ResponseHeadersPaused => UpgradeState::Done,
                state => panic!("Called resume on non-paused upgrade state {state:?}"),
            };
            true
        } else if send_response {
            self.state = UpgradeState::Rejected;
            true
        } else {
            false
        }
    }

    /// Consumes the host's send-response flag for this context, returning whether it was set.
    fn clear_send_response(&mut self) -> bool {
        self.host.borrow_mut().clear_send_response(self.context_id)
    }

    /// Answers HTTP calls the plugin dispatched, in `FlowType::Async` mode, re-querying
    /// the host until it reports none pending. The previous flow mode is restored on exit.
    ///
    /// After each answer, `resume()` is checked; if it changed the state this returns
    /// immediately. NOTE(review): calls still left in the current batch at that point
    /// are not answered here — confirm whether `pending_calls` drains the host queue.
    fn respond_calls(&mut self) {
        let prev_flow = self.host.borrow_mut().set_flow_mode(FlowType::Async);
        let mut pending = self.host.borrow_mut().pending_calls(self.context_id);
        while !pending.is_empty() {
            for (id, upstream, call) in pending {
                respond_call(
                    self.http_context.as_mut(),
                    &self.host,
                    &self.backends,
                    self.context_id,
                    id,
                    upstream,
                    call,
                );
                if self.resume() {
                    self.host.borrow_mut().set_flow_mode(prev_flow);
                    return;
                }
            }
            pending = self.host.borrow_mut().pending_calls(self.context_id);
        }
        self.host.borrow_mut().set_flow_mode(prev_flow);
    }

    /// Moves the plugin context out, leaving an inert placeholder in its place.
    ///
    /// Only `build_connection` calls this, and only when it first creates the connection.
    fn take_http_context(&mut self) -> Box<dyn HttpContext> {
        /// Stand-in context whose callbacks all use the trait defaults.
        struct NoopHttp;
        impl proxy_wasm_stub::traits::Context for NoopHttp {}
        impl HttpContext for NoopHttp {}
        let placeholder: Box<dyn HttpContext> = Box::new(NoopHttp);
        std::mem::replace(&mut self.http_context, placeholder)
    }

    /// Reads the response headers and body currently stored in the host for this context.
    ///
    /// # Panics
    /// If a response header value is not valid UTF-8.
    fn read_response(&self) -> RequestResponse {
        let headers = self
            .host
            .borrow()
            .read_map(self.context_id, MapType::HttpResponseHeaders)
            .into_iter()
            .map(|(k, v)| (k, String::from_utf8(v).unwrap()))
            .collect();
        let body = self
            .host
            .borrow()
            .read_buffer(self.context_id, BufferType::HttpResponseBody);
        RequestResponse::create(headers, body, Default::default())
    }
}
/// A completed WebSocket upgrade: the upgrade response plus shared connection
/// state, reachable through client and server handles.
pub struct UpgradeConnection {
    // Shared with every client/server handle created from this connection.
    inner: Rc<RefCell<ConnectionInner>>,
    // Response that completed the upgrade.
    response: UnitHttpResponse,
}
impl Drop for UpgradeConnection {
    /// Ends the plugin stream when the user-facing connection is dropped by running
    /// the plugin's `on_log` and then `on_done`.
    ///
    /// The host is switched to this connection's context first, as `ConnectionInner`
    /// does before every other plugin callback, so host calls made from these
    /// callbacks resolve against this stream rather than whichever context was last active.
    fn drop(&mut self) {
        let mut inner = self.inner.borrow_mut();
        let context_id = inner.context_id;
        inner.host.borrow_mut().set_context(context_id);
        inner.http_context.on_log();
        inner.http_context.on_done();
    }
}
impl UpgradeConnection {
fn from_inner(inner: Rc<RefCell<ConnectionInner>>, response: UnitHttpResponse) -> Self {
Self { inner, response }
}
pub fn response(&self) -> &UnitHttpResponse {
&self.response
}
pub fn client(&self) -> ClientHandle {
ClientHandle {
inner: Rc::clone(&self.inner),
}
}
pub fn server(&self) -> ServerHandle {
ServerHandle {
inner: Rc::clone(&self.inner),
}
}
pub(crate) fn weak_inner(&self) -> Weak<RefCell<ConnectionInner>> {
Rc::downgrade(&self.inner)
}
}
/// Client-side view of an upgraded connection.
pub struct ClientHandle {
    inner: Rc<RefCell<ConnectionInner>>,
}
impl ClientHandle {
    /// Encodes `frame` as a client frame and pushes it through the plugin's
    /// request-body (upstream) path.
    pub fn send_to_server(&self, frame: UnitFrame) {
        let frames = vec![frame.frame];
        let mut connection = self.inner.borrow_mut();
        connection.send_upstream(frames);
    }
    /// Pushes raw `bytes` through the plugin's request-body (upstream) path.
    #[cfg(feature = "experimental_websocket_bytes")]
    pub fn send_bytes_to_server(&self, bytes: Vec<u8>) {
        let mut connection = self.inner.borrow_mut();
        connection.send_upstream_bytes(bytes);
    }
    /// Pops the oldest complete frame the plugin released toward the client, if any.
    pub fn next(&self) -> Option<UnitFrame> {
        let frame = self
            .inner
            .borrow_mut()
            .connection_state
            .client_ready_frames
            .pop_front()?;
        Some(UnitFrame { frame })
    }
    /// Takes every raw byte the plugin has released toward the client so far.
    #[cfg(feature = "experimental_websocket_bytes")]
    pub fn bytes(&self) -> Vec<u8> {
        std::mem::take(&mut self.inner.borrow_mut().connection_state.client_ready_bytes)
    }
}
/// Server-side view of an upgraded connection.
pub struct ServerHandle {
    inner: Rc<RefCell<ConnectionInner>>,
}
impl ServerHandle {
    /// Encodes `frame` as a server frame and pushes it through the plugin's
    /// response-body (downstream) path.
    pub fn send_to_client(&self, frame: UnitFrame) {
        let frames = vec![frame.frame];
        let mut connection = self.inner.borrow_mut();
        connection.send_downstream(frames);
    }
    /// Pushes raw `bytes` through the plugin's response-body (downstream) path.
    #[cfg(feature = "experimental_websocket_bytes")]
    pub fn send_bytes_to_client(&self, bytes: Vec<u8>) {
        let mut connection = self.inner.borrow_mut();
        connection.send_downstream_bytes(bytes);
    }
    /// Pops the oldest complete frame the plugin released toward the server, if any.
    pub fn next(&self) -> Option<UnitFrame> {
        let frame = self
            .inner
            .borrow_mut()
            .connection_state
            .server_ready_frames
            .pop_front()?;
        Some(UnitFrame { frame })
    }
    /// Takes every raw byte the plugin has released toward the server so far.
    #[cfg(feature = "experimental_websocket_bytes")]
    pub fn bytes(&self) -> Vec<u8> {
        std::mem::take(&mut self.inner.borrow_mut().connection_state.server_ready_bytes)
    }
}
/// Shared state of an upgraded connection: the plugin context plus the
/// per-direction buffers that body chunks flow through.
pub(crate) struct ConnectionInner {
    // Plugin context id; selected on the host before plugin callbacks.
    context_id: u32,
    // Maximum number of bytes handed to a single body callback.
    chunk_size: usize,
    // Plugin HTTP context taken over from the upgrade.
    http_context: Box<dyn HttpContext>,
    // Passed to `respond_call` when answering dispatched calls.
    backends: Rc<RefCell<Backends>>,
    // Shared host stub.
    host: Rc<RefCell<ProxyWasmStub>>,
    // Decoders, ready queues and pause flags for both directions.
    connection_state: ConnectionState,
}
/// The way data travels through the plugin.
///
/// `Upstream` feeds the request-body path (client toward server);
/// `Downstream` feeds the response-body path (server toward client).
#[derive(Clone, Copy)]
pub(crate) enum Direction {
    Upstream,
    Downstream,
}
/// Per-direction bookkeeping for an upgraded connection.
///
/// `server_ready_*` receive data the plugin released on the upstream (request-body)
/// path; `client_ready_*` receive data released on the downstream (response-body) path.
#[derive(Default)]
struct ConnectionState {
    // Raw upstream bytes released by the plugin, not yet read by the server handle.
    #[cfg(feature = "experimental_websocket_bytes")]
    server_ready_bytes: Vec<u8>,
    // Decodes released upstream bytes into frames.
    server_ready_decoder: Decoder,
    // Complete upstream frames waiting for the server handle.
    pub(crate) server_ready_frames: VecDeque<Frame>,
    // Set while the plugin's last request-body callback returned `Pause`.
    pub(crate) upstream_paused: bool,
    // Raw downstream bytes released by the plugin, not yet read by the client handle.
    #[cfg(feature = "experimental_websocket_bytes")]
    client_ready_bytes: Vec<u8>,
    // Decodes released downstream bytes into frames.
    client_ready_decoder: Decoder,
    // Complete downstream frames waiting for the client handle.
    pub(crate) client_ready_frames: VecDeque<Frame>,
    // Set while the plugin's last response-body callback returned `Pause`.
    pub(crate) downstream_paused: bool,
}
impl ConnectionState {
    /// Records whether the plugin is currently holding back data in `direction`.
    fn set_paused(&mut self, direction: Direction, value: bool) {
        match direction {
            Direction::Upstream => {
                self.upstream_paused = value;
            }
            Direction::Downstream => {
                self.downstream_paused = value;
            }
        }
    }

    /// Appends `bytes` to the host body buffer for `direction` and invokes the
    /// plugin's body callback with the buffer's new total length.
    ///
    /// `end_of_stream` is always passed as `false`. The existing buffer is read via
    /// `get_buffer`, which takes no context id — presumably it reads the host's
    /// current context, so callers select `context_id` first (TODO confirm).
    fn on_body(
        &mut self,
        context_id: u32,
        context: &mut dyn HttpContext,
        host: Rc<RefCell<ProxyWasmStub>>,
        direction: Direction,
        bytes: Vec<u8>,
    ) -> Action {
        match direction {
            Direction::Upstream => {
                let mut buffer = host
                    .borrow()
                    .get_buffer(BufferType::HttpRequestBody, 0, usize::MAX)
                    .unwrap_or_default()
                    .unwrap_or_default();
                buffer.extend_from_slice(&bytes);
                let len = buffer.len();
                host.borrow_mut()
                    .create_buffer(context_id, BufferType::HttpRequestBody, buffer);
                context.on_http_request_body(len, false)
            }
            Direction::Downstream => {
                let mut buffer = host
                    .borrow()
                    .get_buffer(BufferType::HttpResponseBody, 0, usize::MAX)
                    .unwrap_or_default()
                    .unwrap_or_default();
                buffer.extend_from_slice(&bytes);
                let len = buffer.len();
                host.borrow_mut()
                    .create_buffer(context_id, BufferType::HttpResponseBody, buffer);
                context.on_http_response_body(len, false)
            }
        }
    }

    /// Moves everything the plugin left in the host body buffer for `direction`
    /// into the matching ready queue, then empties that host buffer.
    ///
    /// Raw bytes are kept when `experimental_websocket_bytes` is enabled; frames are
    /// queued only when the decoder reports `SinkResult::Complete`.
    fn clean_buffer(&mut self, context_id: u32, host: &mut ProxyWasmStub, direction: Direction) {
        match direction {
            Direction::Upstream => {
                let bytes = host.read_buffer(context_id, BufferType::HttpRequestBody);
                #[cfg(feature = "experimental_websocket_bytes")]
                self.server_ready_bytes.extend_from_slice(&bytes);
                if let SinkResult::Complete(frames) = self.server_ready_decoder.sink(bytes) {
                    self.server_ready_frames.extend(frames);
                }
                host.create_buffer(context_id, BufferType::HttpRequestBody, vec![]);
            }
            Direction::Downstream => {
                let bytes = host.read_buffer(context_id, BufferType::HttpResponseBody);
                #[cfg(feature = "experimental_websocket_bytes")]
                self.client_ready_bytes.extend_from_slice(&bytes);
                if let SinkResult::Complete(frames) = self.client_ready_decoder.sink(bytes) {
                    self.client_ready_frames.extend(frames);
                }
                host.create_buffer(context_id, BufferType::HttpResponseBody, vec![]);
            }
        }
    }

    /// Applies a resume the plugin requested for `direction`, if any.
    ///
    /// On resume the buffered body is flushed to the ready queue and the direction is
    /// marked as no longer paused, so a later resume without a new pause is caught.
    ///
    /// # Panics
    /// If the plugin resumed a direction that is not paused.
    fn resume(&mut self, context_id: u32, host: &mut ProxyWasmStub, direction: Direction) {
        let resumed = match direction {
            Direction::Upstream => {
                let resume = host.clear_resume_request(context_id);
                if resume && !self.upstream_paused {
                    panic!("Called resume on non-paused request state")
                }
                resume
            }
            Direction::Downstream => {
                let resume = host.clear_resume_response(context_id);
                if resume && !self.downstream_paused {
                    panic!("Called resume on non-paused response state")
                }
                resume
            }
        };
        if resumed {
            self.clean_buffer(context_id, host, direction);
            // The data is flowing again; without this the guard above stays disarmed.
            self.set_paused(direction, false);
        }
    }
}
impl ConnectionInner {
    /// Encodes `frames` as client frames and feeds them through the plugin's
    /// request-body path.
    pub(crate) fn send_upstream(&mut self, frames: Vec<Frame>) {
        let encoded = Encoder::default().encode_client(frames);
        self.send_upstream_bytes(encoded);
    }

    /// Feeds raw `bytes` through the plugin's request-body path in upstream flow mode.
    pub(crate) fn send_upstream_bytes(&mut self, bytes: Vec<u8>) {
        self.host.borrow_mut().set_flow_mode(FlowType::Upstream);
        self.drive(Direction::Upstream, bytes);
    }

    /// Encodes `frames` as server frames and feeds them through the plugin's
    /// response-body path.
    pub(crate) fn send_downstream(&mut self, frames: Vec<Frame>) {
        let encoded = Encoder::default().encode_server(frames);
        self.send_downstream_bytes(encoded);
    }

    /// Feeds raw `bytes` through the plugin's response-body path in downstream flow mode.
    pub(crate) fn send_downstream_bytes(&mut self, bytes: Vec<u8>) {
        self.host.borrow_mut().set_flow_mode(FlowType::Downstream);
        self.drive(Direction::Downstream, bytes);
    }

    /// Hands `bytes` to the plugin in pieces of at most `chunk_size` bytes.
    ///
    /// After each piece, a `Continue` flushes the host buffer to the ready queue and
    /// clears the pause flag; a `Pause` sets the flag and lets the next piece accumulate
    /// on top of the held data. Afterwards, dispatched calls are answered and any
    /// resume the plugin requested is applied.
    ///
    /// A `chunk_size` of 0 is treated as 1; the previous index arithmetic never
    /// advanced in that case and looped forever.
    ///
    /// # Panics
    /// If the plugin returns an action other than `Continue` or `Pause`.
    fn drive(&mut self, direction: Direction, bytes: Vec<u8>) {
        self.host.borrow_mut().set_context(self.context_id);
        let chunk_size = self.chunk_size.max(1);
        for chunk in bytes.chunks(chunk_size) {
            let action = self.connection_state.on_body(
                self.context_id,
                self.http_context.as_mut(),
                Rc::clone(&self.host),
                direction,
                chunk.to_vec(),
            );
            match action {
                Action::Continue => {
                    self.connection_state.clean_buffer(
                        self.context_id,
                        &mut self.host.borrow_mut(),
                        direction,
                    );
                    self.connection_state.set_paused(direction, false);
                }
                Action::Pause => {
                    self.connection_state.set_paused(direction, true);
                }
                _ => {
                    panic!("unexpected action: {action:?}");
                }
            }
        }
        self.respond_calls();
        self.resume(direction);
    }

    /// Applies any resume the plugin requested for `direction`.
    ///
    /// # Panics
    /// If the plugin set the host's send-response flag, which is not supported once
    /// the connection has been upgraded.
    pub(crate) fn resume(&mut self, direction: Direction) {
        if self.host.borrow_mut().clear_send_response(self.context_id) {
            panic!("Called send response on websocket flow.")
        }
        self.connection_state
            .resume(self.context_id, &mut self.host.borrow_mut(), direction);
    }

    /// Answers every HTTP call the plugin dispatched on this context, in
    /// `FlowType::Async` mode, until the host reports none pending. The previous
    /// flow mode is restored afterwards.
    fn respond_calls(&mut self) {
        self.host.borrow_mut().set_context(self.context_id);
        let prev_flow = self.host.borrow_mut().set_flow_mode(FlowType::Async);
        let mut pending = self.host.borrow_mut().pending_calls(self.context_id);
        while !pending.is_empty() {
            for (id, upstream, call) in pending {
                respond_call(
                    self.http_context.as_mut(),
                    &self.host,
                    &self.backends,
                    self.context_id,
                    id,
                    upstream,
                    call,
                );
            }
            pending = self.host.borrow_mut().pending_calls(self.context_id);
        }
        self.host.borrow_mut().set_flow_mode(prev_flow);
    }
}
impl Drop for ConnectionInner {
    /// Runs the plugin's `on_done` under this connection's context once the last
    /// reference to the shared connection state goes away.
    ///
    /// NOTE(review): `UpgradeConnection`'s `Drop` also calls `on_done`, so the plugin
    /// can observe it more than once per stream — confirm this is intended.
    fn drop(&mut self) {
        let context_id = self.context_id;
        {
            let mut host = self.host.borrow_mut();
            host.set_context(context_id);
        }
        self.http_context.on_done();
    }
}