use crate::host::implementation::ProxyWasmStub;
use crate::tester::io::{RequestResponse, UnitHttpResponse};
use crate::tester::unit_test::{add_request_properties, respond_call, Backends};
use crate::StopIterationMode;
use proxy_wasm_stub::traits::HttpContext;
use proxy_wasm_stub::types::{Action, BufferType, MapType};
use std::cell::RefCell;
use std::rc::Rc;
use std::task::Poll;
/// Lifecycle phase of a request being driven through the HTTP context.
///
/// The declaration order is significant: the derived ordering follows it, and
/// `respond_calls` relies on `state >= State::Done` to detect a finished
/// request. Keep `Done` and `Exit` last when adding variants.
#[derive(PartialOrd, Ord, PartialEq, Eq, Copy, Clone, Debug)]
pub enum State {
    /// Initial state; `on_http_request_headers` has not been invoked yet.
    RequestHeaders,
    /// The request-headers callback returned `Pause`; waiting for a resume or
    /// a send-response signal from the host.
    RequestHeadersPaused,
    /// The request-headers callback returned `Pause` while a stop mode is set
    /// and the request has a body; the request-body phase still runs.
    RequestHeadersLimbo,
    /// Request headers were continued; the request body is fed next.
    RequestBody,
    /// The end-of-stream request-body callback returned `Pause`.
    RequestBodyPaused,
    /// A response is available; `on_http_response_headers` is invoked next.
    ResponseHeaders,
    /// The response-headers callback returned `Pause`; waiting for a resume or
    /// a send-response signal from the host.
    ResponseHeadersPaused,
    /// The response-headers callback returned `Pause` while a stop mode is set
    /// and the response has a body; the response-body phase still runs.
    ResponseHeadersLimbo,
    /// Response headers were continued; the response body is fed next.
    ResponseBody,
    /// The end-of-stream response-body callback returned `Pause`.
    ResponseBodyPaused,
    /// Processing finished; `on_log` and `on_done` are invoked next.
    Done,
    /// Terminal state; `poll` returns `Poll::Ready`.
    Exit,
}
/// Bookkeeping for one request that `poll` drives through an `HttpContext`,
/// phase by phase (see `State`).
pub struct InnerUnitTestRequest {
/// Current lifecycle phase; starts at `State::RequestHeaders`.
state: State,
/// Identifier passed to every host call made on behalf of this request.
context_id: u32,
/// Length of the slices the body is split into before being handed to the
/// body callbacks.
chunk_size: usize,
/// Context whose `on_http_*`, `on_log` and `on_done` callbacks are invoked.
http_context: Box<dyn HttpContext>,
/// Shared backends; `backends.backend.call(..)` produces the upstream
/// response, and the handle is also forwarded to `respond_call`.
backends: Rc<RefCell<Backends>>,
/// Shared host stub holding header maps, body buffers, properties, pending
/// calls and the resume / send-response flags for each context.
host: Rc<RefCell<ProxyWasmStub>>,
/// Optional stop-iteration mode. When set, a header `Pause` on a message
/// with a body enters a `*Limbo` state instead of a `*Paused` one.
stop_mode: Option<StopIterationMode>,
/// The message currently being processed: the request at first, replaced
/// by the response once the backend was called or a response was loaded
/// from the host.
request_response: RequestResponse,
/// Body bytes read back from the host buffer after the callbacks continued
/// (or after a resume); prepended to the host buffer content when the
/// message is read back.
body_buffer: Vec<u8>,
}
impl InnerUnitTestRequest {
/// Builds a request that starts in [`State::RequestHeaders`] with an empty
/// body buffer.
///
/// `request` becomes the initial `request_response`; every other argument is
/// stored unchanged in the field of the same name.
// Every argument maps to a distinct field, so the argument-count lint is silenced.
#[allow(clippy::too_many_arguments)]
pub fn new(
    context_id: u32,
    request: RequestResponse,
    http_context: Box<dyn HttpContext>,
    backends: Rc<RefCell<Backends>>,
    host: Rc<RefCell<ProxyWasmStub>>,
    stop_mode: Option<StopIterationMode>,
    chunk_size: usize,
) -> Self {
    let request_response = request;
    let body_buffer = Vec::new();
    Self {
        context_id,
        chunk_size,
        stop_mode,
        http_context,
        backends,
        host,
        request_response,
        body_buffer,
        state: State::RequestHeaders,
    }
}
/// Drives the request as far as it can currently go.
///
/// Makes this request's context the active one on the host, applies any
/// pending resume / send-response signal, and then runs every phase whose
/// state matches, in order: request headers, request body, response headers,
/// response body, completion. Because each phase is a separate `if`, a single
/// call can pass through several phases.
///
/// Returns `Poll::Ready` with the current `request_response` once the state is
/// `State::Exit`, otherwise `Poll::Pending`.
///
/// # Panics
///
/// Panics through `resume()` when the host reports a resume for a state that
/// is not paused, or reports both a resume and a send-response at once.
/// NOTE(review): `chunks(self.chunk_size)` is presumably `slice::chunks`,
/// which panics for a chunk size of zero; confirm callers never pass 0.
pub fn poll(&mut self) -> Poll<UnitHttpResponse> {
let context_id = self.context_id;
// Make this request's context the active one on the shared host.
self.host.borrow_mut().set_context(context_id);
// Apply a resume / send-response signal raised since the previous poll.
self.resume();
// Phase 1: request headers.
if self.state == State::RequestHeaders {
self.set_map(MapType::HttpRequestHeaders);
let headers_len = self.request_response.headers().len();
let eos = self.eos();
let action = self.http_context.on_http_request_headers(headers_len, eos);
// Pause plus a send-response flag: load the response from the host and
// jump to the response phase; the backend is not called on this path.
if action == Action::Pause && self.clear_send_response() {
self.load_response();
self.state = State::ResponseHeaders;
} else if action == Action::Pause {
// With a stop mode and a body still to come, enter Limbo so the body
// phase below still runs; otherwise wait for a resume.
if self.stop_mode.is_some() && !eos {
self.state = State::RequestHeadersLimbo;
} else {
self.state = State::RequestHeadersPaused;
}
} else {
self.state = State::RequestBody;
}
// Pending calls are not answered here when the mode is BodyThenRequests
// and the request has a body.
if self.stop_mode != Some(StopIterationMode::BodyThenRequests) || eos {
self.respond_calls();
}
}
// Headers continued and there is no body: call the backend right away.
if self.state == State::RequestBody && self.eos() {
self.call_backend();
self.state = State::ResponseHeaders;
}
// Phase 2: request body (also entered from RequestHeadersLimbo).
if self.state == State::RequestBody || self.state == State::RequestHeadersLimbo {
self.body_buffer.clear();
// Chunks handed to the callback that have not been continued yet.
let mut buffered_chunks: Vec<u8> = vec![];
for chunk in self.request_response.body.clone().chunks(self.chunk_size) {
buffered_chunks.extend_from_slice(chunk);
// Expose everything buffered so far through the host's request body buffer.
self.set_buffer(BufferType::HttpRequestBody, buffered_chunks.clone());
let action = self
.http_context
.on_http_request_body(buffered_chunks.len(), false);
if action == Action::Pause {
// Pause with a send-response flag aborts the body: drop what was
// collected and switch to the response. A plain Pause keeps buffering.
if self.clear_send_response() {
self.body_buffer.clear();
self.load_response();
self.state = State::ResponseHeaders;
break;
}
} else {
// Continued: move the host buffer content into body_buffer and start
// a fresh accumulation.
let continued_chunk = self.read_and_clear_buffer(BufferType::HttpRequestBody);
self.body_buffer.extend(continued_chunk);
buffered_chunks.clear();
}
}
// The loop may have switched to ResponseHeaders; only signal end of stream
// when it did not.
if self.state == State::RequestBody || self.state == State::RequestHeadersLimbo {
// End-of-stream call. The host buffer is not rewritten here; it holds
// whatever the loop left in it.
let action = self
.http_context
.on_http_request_body(buffered_chunks.len(), true);
if action == Action::Pause && self.clear_send_response() {
self.body_buffer.clear();
self.load_response();
self.state = State::ResponseHeaders;
} else if action == Action::Pause {
self.state = State::RequestBodyPaused;
} else {
// Continued at end of stream: collect the remaining bytes and call
// the backend.
let continued_chunk = self.read_and_clear_buffer(BufferType::HttpRequestBody);
self.body_buffer.extend(continued_chunk);
self.call_backend();
self.state = State::ResponseHeaders;
}
}
self.respond_calls();
}
// Phase 3: response headers. `request_response` now holds the response.
if self.state == State::ResponseHeaders {
let headers_len = self.request_response.headers.len();
let eos = self.eos();
self.set_map(MapType::HttpResponseHeaders);
let action = self.http_context.on_http_response_headers(headers_len, eos);
// Unlike the request-headers phase, the send-response flag is checked
// here whatever action the callback returned.
if self.host.borrow_mut().clear_send_response(context_id) {
self.body_buffer.clear();
self.load_response();
self.state = State::Done;
} else if action == Action::Pause {
if self.stop_mode.is_some() && !eos {
self.state = State::ResponseHeadersLimbo;
} else {
self.state = State::ResponseHeadersPaused;
}
} else {
self.state = State::ResponseBody;
}
// Same rule as in the request-headers phase.
if self.stop_mode != Some(StopIterationMode::BodyThenRequests) || eos {
self.respond_calls();
}
}
// Headers continued and the response has no body: finish immediately.
if self.state == State::ResponseBody && self.eos() {
self.body_buffer.clear();
self.load_response();
self.state = State::Done;
}
// Phase 4: response body (also entered from ResponseHeadersLimbo).
if self.state == State::ResponseBody || self.state == State::ResponseHeadersLimbo {
self.body_buffer.clear();
let mut buffered_chunks = vec![];
for chunk in self.request_response.body.clone().chunks(self.chunk_size) {
buffered_chunks.extend_from_slice(chunk);
self.set_buffer(BufferType::HttpResponseBody, buffered_chunks.clone());
let action = self
.http_context
.on_http_response_body(buffered_chunks.len(), false);
// The send-response flag is not consulted in this loop; any action other
// than Continue simply keeps buffering.
if action == Action::Continue {
let continued_chunk = self.read_and_clear_buffer(BufferType::HttpResponseBody);
self.body_buffer.extend(continued_chunk);
buffered_chunks.clear();
}
}
// End-of-stream call for the response body.
let action = self
.http_context
.on_http_response_body(buffered_chunks.len(), true);
if action == Action::Pause {
self.state = State::ResponseBodyPaused;
} else {
let continued_chunk = self.read_and_clear_buffer(BufferType::HttpResponseBody);
self.body_buffer.extend(continued_chunk);
self.load_response();
self.state = State::Done;
}
self.respond_calls();
}
// Completion: run the final callbacks once, then move to the terminal state.
if self.state == State::Done {
self.http_context.on_log();
self.http_context.on_done();
self.state = State::Exit;
}
if self.state == State::Exit {
return Poll::Ready(self.request_response.clone().into());
}
// Still paused somewhere; the caller has to poll again.
Poll::Pending
}
/// Answers the calls the host has queued for this context.
///
/// Batches are fetched from the host until it reports none. Each call is
/// passed to `respond_call`; as soon as `resume()` reports a state change the
/// method returns and the rest of the current batch is dropped. Does nothing
/// once the state is `State::Done` or later.
fn respond_calls(&mut self) {
    // Nothing to answer once the request has completed.
    if self.state >= State::Done {
        return;
    }
    loop {
        let batch = self.host.borrow_mut().pending_calls(self.context_id);
        if batch.is_empty() {
            return;
        }
        for (call_id, upstream, call) in batch.into_iter() {
            respond_call(
                self.http_context.as_mut(),
                &self.host,
                &self.backends,
                self.context_id,
                call_id,
                upstream,
                call,
            );
            let state_changed = self.resume();
            if state_changed {
                return;
            }
        }
    }
}
/// Consumes the host's resume and send-response flags for this context and
/// advances the state accordingly.
///
/// Returns `true` when either flag was set, `false` otherwise.
///
/// # Panics
///
/// Panics when both flags are set at once, or when a resume arrives while the
/// state is not one of the paused / limbo states.
fn resume(&mut self) -> bool {
    let resume = self.host.borrow_mut().clear_resume(self.context_id);
    let send_response = self.clear_send_response();
    assert!(!(resume && send_response));

    if !resume && !send_response {
        return false;
    }

    if send_response {
        // Replace the current message with the response stored on the host.
        self.body_buffer.clear();
        self.load_response();
        match self.state {
            State::RequestHeadersPaused | State::RequestBodyPaused => {
                self.state = State::ResponseHeaders;
            }
            State::ResponseHeadersPaused => {
                self.state = State::Done;
            }
            // Every other state is left untouched.
            _ => {}
        }
        return true;
    }

    // Resume: continue from the phase that was paused.
    let next = match self.state {
        State::RequestHeadersPaused | State::RequestHeadersLimbo => State::RequestBody,
        State::RequestBodyPaused => {
            let remaining = self.read_and_clear_buffer(BufferType::HttpRequestBody);
            self.body_buffer.extend(remaining);
            self.call_backend();
            State::ResponseHeaders
        }
        State::ResponseHeadersPaused | State::ResponseHeadersLimbo => State::ResponseBody,
        State::ResponseBodyPaused => {
            let remaining = self.read_and_clear_buffer(BufferType::HttpResponseBody);
            self.body_buffer.extend(remaining);
            self.load_response();
            State::Done
        }
        state => panic!("Called resume on non paused state {state:?}"),
    };
    self.state = next;
    true
}
/// Forwards to the host's `clear_send_response` for this context and returns
/// its result.
fn clear_send_response(&mut self) -> bool {
    let id = self.context_id;
    let mut host = self.host.borrow_mut();
    host.clear_send_response(id)
}
/// Publishes the headers of the current message on the host as the map of
/// the given type, with every value converted to bytes.
fn set_map(&self, map_type: MapType) {
    let headers = self.request_response.headers();
    self.host.borrow_mut().create_map(
        self.context_id,
        map_type,
        headers
            .iter()
            .map(|(name, value)| (name.to_string(), value.to_string().into_bytes()))
            .collect(),
    );
}
/// Stores `buffer` on the host as this context's buffer of the given type.
fn set_buffer(&self, buffer_type: BufferType, buffer: Vec<u8>) {
    let mut host = self.host.borrow_mut();
    host.create_buffer(self.context_id, buffer_type, buffer);
}
/// Returns the host buffer of the given type for this context and replaces
/// it with an empty one.
fn read_and_clear_buffer(&self, buffer_type: BufferType) -> Vec<u8> {
    let id = self.context_id;
    let contents = self.host.borrow().read_buffer(id, buffer_type);
    // Reset the host side so the same bytes are not read twice.
    self.host
        .borrow_mut()
        .create_buffer(id, buffer_type, Vec::new());
    contents
}
/// Assembles a `RequestResponse` from the host's state for this context.
///
/// Headers come from the map `map_type`, the body is `body_buffer` followed
/// by the host buffer `buffer_type`, and properties are the context's
/// properties on the host.
///
/// # Panics
///
/// Panics when a header value read from the host is not valid UTF-8.
fn read_request_response(&self, map_type: MapType, buffer_type: BufferType) -> RequestResponse {
    let id = self.context_id;
    // One shared borrow covers all three reads and ends before the message is built.
    let (raw_headers, tail, properties) = {
        let host = self.host.borrow();
        (
            host.read_map(id, map_type),
            host.read_buffer(id, buffer_type),
            host.get_properties(id),
        )
    };
    let headers = raw_headers
        .into_iter()
        .map(|(name, value)| (name, String::from_utf8(value).unwrap()))
        .collect();
    let mut body = self.body_buffer.clone();
    body.extend(tail);
    RequestResponse::create(headers, body, properties)
}
/// Reads the request back from the host, using the request header map and
/// the request body buffer.
fn read_request(&mut self) -> RequestResponse {
    let (map_type, buffer_type) = (MapType::HttpRequestHeaders, BufferType::HttpRequestBody);
    self.read_request_response(map_type, buffer_type)
}
/// Replaces `request_response` with the response read back from the host,
/// using the response header map and the response body buffer.
fn load_response(&mut self) {
    let (map_type, buffer_type) = (MapType::HttpResponseHeaders, BufferType::HttpResponseBody);
    let response = self.read_request_response(map_type, buffer_type);
    self.request_response = response;
}
/// Sends the request, as currently stored on the host, to the backend and
/// makes the reply the current message.
///
/// The body buffer is emptied after the request is read. The reply goes
/// through `add_request_properties`, and its properties are then written to
/// the host for this context.
fn call_backend(&mut self) {
    let id = self.context_id;
    let outgoing = self.read_request();
    self.body_buffer.clear();
    let reply = self.backends.borrow().backend.call(outgoing.into()).inner;
    self.request_response = add_request_properties(reply, id);
    let properties = self.request_response.properties();
    self.host.borrow_mut().set_properties(id, properties);
}
/// Reports whether the current message has an empty body; `poll` passes this
/// as the end-of-stream flag to the header callbacks.
fn eos(&self) -> bool {
    let body = &self.request_response.body;
    body.is_empty()
}
}