1use crate::host::implementation::{FlowType, ProxyWasmStub};
6use crate::tester::io::{RequestResponse, UnitFrame, UnitHttpResponse};
7use crate::tester::unit_test::{add_request_properties, respond_call, Backends};
8use pdk_websockets_lib::{Decoder, Encoder, Frame, SinkResult};
9use proxy_wasm_stub::stub::Host;
10use proxy_wasm_stub::traits::HttpContext;
11use proxy_wasm_stub::types::{Action, BufferType, MapType};
12use std::cell::RefCell;
13use std::collections::VecDeque;
14use std::rc::{Rc, Weak};
15use std::task::Poll;
16#[derive(Clone)]
23pub struct UnitTestUpgrade {
24 pub(crate) inner: Rc<RefCell<InnerUnitUpgrade>>,
25}
26
27impl UnitTestUpgrade {
28 pub(crate) fn new(inner: InnerUnitUpgrade) -> Self {
29 Self {
30 inner: Rc::new(RefCell::new(inner)),
31 }
32 }
33
34 pub fn poll(&mut self) -> Poll<Result<UpgradeConnection, UnitHttpResponse>> {
40 self.inner.borrow_mut().poll()
41 }
42}
43
/// Phase of an HTTP upgrade as seen by the filter under test.
///
/// Variant order is significant: `PartialOrd` is derived, so states compare by their
/// declaration order.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
enum UpgradeState {
    /// Request headers have not been delivered to the filter yet.
    RequestHeaders,
    /// The filter paused on the request headers.
    RequestHeadersPaused,
    /// Request headers were let through; response headers not delivered yet.
    ResponseHeaders,
    /// The filter paused on the response headers.
    ResponseHeadersPaused,
    /// Both header phases completed; the connection is established.
    Done,
    /// The filter answered with a local response instead of upgrading.
    Rejected,
}
53
54pub(crate) struct InnerUnitUpgrade {
55 state: UpgradeState,
56 context_id: u32,
57 chunk_size: usize,
58 request: RequestResponse,
59 http_context: Box<dyn HttpContext>,
60 backends: Rc<RefCell<Backends>>,
61 host: Rc<RefCell<ProxyWasmStub>>,
62 cached_connection: Option<Rc<RefCell<ConnectionInner>>>,
63 cached_response: Option<UnitHttpResponse>,
64}
65
66impl InnerUnitUpgrade {
67 pub(crate) fn new(
68 context_id: u32,
69 request: RequestResponse,
70 http_context: Box<dyn HttpContext>,
71 backends: Rc<RefCell<Backends>>,
72 host: Rc<RefCell<ProxyWasmStub>>,
73 chunk_size: usize,
74 ) -> Self {
75 Self {
76 state: UpgradeState::RequestHeaders,
77 context_id,
78 chunk_size,
79 request,
80 http_context,
81 backends,
82 host,
83 cached_connection: None,
84 cached_response: None,
85 }
86 }
87
88 pub(crate) fn poll(&mut self) -> Poll<Result<UpgradeConnection, UnitHttpResponse>> {
89 if let Poll::Ready(result) = self.build_result() {
90 return Poll::Ready(result);
91 }
92
93 let context_id = self.context_id;
94 self.host.borrow_mut().set_context(context_id);
95
96 self.host.borrow_mut().set_flow_mode(FlowType::Upstream);
97 self.resume();
98
99 if self.state == UpgradeState::RequestHeaders {
100 self.host.borrow_mut().create_map(
101 context_id,
102 MapType::HttpRequestHeaders,
103 self.request
104 .headers()
105 .iter()
106 .map(|(k, v)| (k.clone(), v.as_bytes().to_vec()))
107 .collect(),
108 );
109
110 let action = self
111 .http_context
112 .on_http_request_headers(self.request.headers().len(), false);
113
114 if action == Action::Pause && self.clear_send_response() {
115 self.state = UpgradeState::Rejected;
116 } else if action == Action::Pause {
117 self.state = UpgradeState::RequestHeadersPaused;
118 } else {
119 self.state = UpgradeState::ResponseHeaders;
120 }
121
122 self.respond_calls();
123 }
124
125 self.host.borrow_mut().set_flow_mode(FlowType::Downstream);
126
127 if self.state == UpgradeState::ResponseHeaders {
128 let backend_response = self
129 .backends
130 .borrow()
131 .backend
132 .call(self.request.clone().into())
133 .inner;
134 let backend_response = add_request_properties(backend_response, context_id);
135
136 self.host.borrow_mut().create_map(
137 context_id,
138 MapType::HttpResponseHeaders,
139 backend_response
140 .headers()
141 .iter()
142 .map(|(k, v)| (k.clone(), v.as_bytes().to_vec()))
143 .collect(),
144 );
145
146 let num_headers = backend_response.headers().len();
147 self.cached_response = Some(UnitHttpResponse::from(backend_response));
148
149 let action = self
150 .http_context
151 .on_http_response_headers(num_headers, false);
152
153 if action == Action::Pause && self.clear_send_response() {
154 self.state = UpgradeState::Rejected;
155 } else if action == Action::Pause {
156 self.state = UpgradeState::ResponseHeadersPaused;
157 } else {
158 self.state = UpgradeState::Done;
159 }
160
161 self.respond_calls();
162 }
163
164 self.build_result()
165 }
166
167 fn build_result(&mut self) -> Poll<Result<UpgradeConnection, UnitHttpResponse>> {
168 match self.state {
169 UpgradeState::Done => Poll::Ready(Ok(self.build_connection())),
170 UpgradeState::Rejected => Poll::Ready(Err(self.read_response().into())),
171 _ => Poll::Pending,
172 }
173 }
174 fn build_connection(&mut self) -> UpgradeConnection {
175 if self.cached_connection.is_none() {
176 let context_id = self.context_id;
177 self.cached_connection = Some(Rc::new(RefCell::new(ConnectionInner {
178 context_id,
179 chunk_size: self.chunk_size,
180 http_context: self.take_http_context(),
181 backends: Rc::clone(&self.backends),
182 host: Rc::clone(&self.host),
183 connection_state: Default::default(),
184 })));
185 }
186 let response = self
187 .cached_response
188 .clone()
189 .unwrap_or_else(UnitHttpResponse::upgrade);
190 UpgradeConnection::from_inner(
191 Rc::clone(self.cached_connection.as_ref().unwrap()),
192 response,
193 )
194 }
195
196 fn resume(&mut self) -> bool {
197 let resume = self.host.borrow_mut().clear_resume(self.context_id);
198 let send_response = self.clear_send_response();
199 assert!(!(resume && send_response));
200
201 if resume {
202 self.state = match self.state {
203 UpgradeState::RequestHeadersPaused => UpgradeState::ResponseHeaders,
204 UpgradeState::ResponseHeadersPaused => UpgradeState::Done,
205 state => panic!("Called resume on non-paused upgrade state {state:?}"),
206 };
207 true
208 } else if send_response {
209 self.state = UpgradeState::Rejected;
210 true
211 } else {
212 false
213 }
214 }
215
216 fn clear_send_response(&mut self) -> bool {
217 self.host.borrow_mut().clear_send_response(self.context_id)
218 }
219
220 fn respond_calls(&mut self) {
221 let prev_flow = self.host.borrow_mut().set_flow_mode(FlowType::Async);
222 let mut pending = self.host.borrow_mut().pending_calls(self.context_id);
223 while !pending.is_empty() {
224 for (id, upstream, call) in pending {
225 respond_call(
226 self.http_context.as_mut(),
227 &self.host,
228 &self.backends,
229 self.context_id,
230 id,
231 upstream,
232 call,
233 );
234 if self.resume() {
235 self.host.borrow_mut().set_flow_mode(prev_flow);
236 return;
237 }
238 }
239 pending = self.host.borrow_mut().pending_calls(self.context_id);
240 }
241 self.host.borrow_mut().set_flow_mode(prev_flow);
242 }
243
244 fn take_http_context(&mut self) -> Box<dyn HttpContext> {
245 struct NoopHttp;
246 impl proxy_wasm_stub::traits::Context for NoopHttp {}
247 impl HttpContext for NoopHttp {}
248
249 let mut placeholder: Box<dyn HttpContext> = Box::new(NoopHttp);
250 std::mem::swap(&mut self.http_context, &mut placeholder);
251 placeholder
252 }
253
254 fn read_response(&self) -> RequestResponse {
255 let headers = self
256 .host
257 .borrow()
258 .read_map(self.context_id, MapType::HttpResponseHeaders)
259 .into_iter()
260 .map(|(k, v)| (k, String::from_utf8(v).unwrap()))
261 .collect();
262 let body = self
263 .host
264 .borrow()
265 .read_buffer(self.context_id, BufferType::HttpResponseBody);
266 RequestResponse::create(headers, body, Default::default())
267 }
268}
269
270pub struct UpgradeConnection {
279 inner: Rc<RefCell<ConnectionInner>>,
280 response: UnitHttpResponse,
281}
282
283impl Drop for UpgradeConnection {
284 fn drop(&mut self) {
285 self.inner.borrow_mut().http_context.on_log();
286 self.inner.borrow_mut().http_context.on_done();
287 }
288}
289
290impl UpgradeConnection {
291 fn from_inner(inner: Rc<RefCell<ConnectionInner>>, response: UnitHttpResponse) -> Self {
292 Self { inner, response }
293 }
294
295 pub fn response(&self) -> &UnitHttpResponse {
297 &self.response
298 }
299
300 pub fn client(&self) -> ClientHandle {
302 ClientHandle {
303 inner: Rc::clone(&self.inner),
304 }
305 }
306
307 pub fn server(&self) -> ServerHandle {
309 ServerHandle {
310 inner: Rc::clone(&self.inner),
311 }
312 }
313
314 pub(crate) fn weak_inner(&self) -> Weak<RefCell<ConnectionInner>> {
315 Rc::downgrade(&self.inner)
316 }
317}
318
319pub struct ClientHandle {
323 inner: Rc<RefCell<ConnectionInner>>,
324}
325
326impl ClientHandle {
327 pub fn send_to_server(&self, frame: UnitFrame) {
331 self.inner.borrow_mut().send_upstream(vec![frame.frame])
332 }
333
334 #[cfg(feature = "experimental_websocket_bytes")]
335 pub fn send_bytes_to_server(&self, bytes: Vec<u8>) {
337 self.inner.borrow_mut().send_upstream_bytes(bytes)
338 }
339
340 pub fn next(&self) -> Option<UnitFrame> {
342 self.inner
343 .borrow_mut()
344 .connection_state
345 .client_ready_frames
346 .pop_front()
347 .map(|frame| UnitFrame { frame })
348 }
349
350 #[cfg(feature = "experimental_websocket_bytes")]
351 pub fn bytes(&self) -> Vec<u8> {
353 self.inner
354 .borrow_mut()
355 .connection_state
356 .client_ready_bytes
357 .split_off(0)
358 }
359}
360
361pub struct ServerHandle {
363 inner: Rc<RefCell<ConnectionInner>>,
364}
365
366impl ServerHandle {
367 pub fn send_to_client(&self, frame: UnitFrame) {
371 self.inner.borrow_mut().send_downstream(vec![frame.frame])
372 }
373
374 #[cfg(feature = "experimental_websocket_bytes")]
375 pub fn send_bytes_to_client(&self, bytes: Vec<u8>) {
377 self.inner.borrow_mut().send_downstream_bytes(bytes)
378 }
379
380 pub fn next(&self) -> Option<UnitFrame> {
382 self.inner
383 .borrow_mut()
384 .connection_state
385 .server_ready_frames
386 .pop_front()
387 .map(|frame| UnitFrame { frame })
388 }
389
390 #[cfg(feature = "experimental_websocket_bytes")]
391 pub fn bytes(&self) -> Vec<u8> {
393 self.inner
394 .borrow_mut()
395 .connection_state
396 .server_ready_bytes
397 .split_off(0)
398 }
399}
400
401pub(crate) struct ConnectionInner {
404 context_id: u32,
405 chunk_size: usize,
406 http_context: Box<dyn HttpContext>,
407 backends: Rc<RefCell<Backends>>,
408 host: Rc<RefCell<ProxyWasmStub>>,
409 connection_state: ConnectionState,
410}
411
/// Which way data travels through the filter.
#[derive(Clone, Copy)]
pub(crate) enum Direction {
    /// Client to server; delivered through the request-body callback.
    Upstream,
    /// Server to client; delivered through the response-body callback.
    Downstream,
}
417
418#[derive(Default)]
419struct ConnectionState {
420 #[cfg(feature = "experimental_websocket_bytes")]
421 server_ready_bytes: Vec<u8>,
422 server_ready_decoder: Decoder,
423 pub(crate) server_ready_frames: VecDeque<Frame>,
424 pub(crate) upstream_paused: bool,
425 #[cfg(feature = "experimental_websocket_bytes")]
426 client_ready_bytes: Vec<u8>,
427 client_ready_decoder: Decoder,
428 pub(crate) client_ready_frames: VecDeque<Frame>,
429 pub(crate) downstream_paused: bool,
430}
431impl ConnectionState {
432 fn set_paused(&mut self, direction: Direction, value: bool) {
433 match direction {
434 Direction::Upstream => {
435 self.upstream_paused = value;
436 }
437 Direction::Downstream => {
438 self.downstream_paused = value;
439 }
440 }
441 }
442
443 fn on_body(
444 &mut self,
445 context_id: u32,
446 context: &mut dyn HttpContext,
447 host: Rc<RefCell<ProxyWasmStub>>,
448 direction: Direction,
449 bytes: Vec<u8>,
450 ) -> Action {
451 match direction {
452 Direction::Upstream => {
453 let mut buffer = host
454 .borrow()
455 .get_buffer(BufferType::HttpRequestBody, 0, usize::MAX)
456 .unwrap_or_default()
457 .unwrap_or_default();
458
459 buffer.extend_from_slice(&bytes);
460
461 let len = buffer.len();
462 host.borrow_mut()
463 .create_buffer(context_id, BufferType::HttpRequestBody, buffer);
464
465 context.on_http_request_body(len, false)
466 }
467 Direction::Downstream => {
468 let mut buffer = host
469 .borrow()
470 .get_buffer(BufferType::HttpResponseBody, 0, usize::MAX)
471 .unwrap_or_default()
472 .unwrap_or_default();
473
474 buffer.extend_from_slice(&bytes);
475
476 let len = buffer.len();
477 host.borrow_mut()
478 .create_buffer(context_id, BufferType::HttpResponseBody, buffer);
479 context.on_http_response_body(len, false)
480 }
481 }
482 }
483
484 fn clean_buffer(&mut self, context_id: u32, host: &mut ProxyWasmStub, direction: Direction) {
485 match direction {
486 Direction::Upstream => {
487 let bytes = host.read_buffer(context_id, BufferType::HttpRequestBody);
488 #[cfg(feature = "experimental_websocket_bytes")]
489 self.server_ready_bytes.extend_from_slice(&bytes);
490 if let SinkResult::Complete(frames) = self.server_ready_decoder.sink(bytes) {
491 frames
492 .into_iter()
493 .for_each(|frame| self.server_ready_frames.push_back(frame))
494 }
495 host.create_buffer(context_id, BufferType::HttpRequestBody, vec![]);
496 }
497 Direction::Downstream => {
498 let bytes = host.read_buffer(context_id, BufferType::HttpResponseBody);
499 #[cfg(feature = "experimental_websocket_bytes")]
500 self.client_ready_bytes.extend_from_slice(&bytes);
501 if let SinkResult::Complete(frames) = self.client_ready_decoder.sink(bytes) {
502 frames
503 .into_iter()
504 .for_each(|frame| self.client_ready_frames.push_back(frame))
505 }
506 host.create_buffer(context_id, BufferType::HttpResponseBody, vec![]);
507 }
508 }
509 }
510
511 fn resume(&mut self, context_id: u32, host: &mut ProxyWasmStub, direction: Direction) {
512 match direction {
513 Direction::Upstream => {
514 let resume = host.clear_resume_request(context_id);
515 if resume && !self.upstream_paused {
516 panic!("Called resume on non-paused request state")
517 }
518 if resume {
519 self.clean_buffer(context_id, host, direction)
520 }
521 }
522 Direction::Downstream => {
523 let resume = host.clear_resume_response(context_id);
524 if resume && !self.downstream_paused {
525 panic!("Called resume on non-paused response state")
526 }
527 if resume {
528 self.clean_buffer(context_id, host, direction)
529 }
530 }
531 }
532 }
533}
534
535impl ConnectionInner {
536 pub(crate) fn send_upstream(&mut self, frames: Vec<Frame>) {
537 let encoded = Encoder::default().encode_client(frames);
538 self.send_upstream_bytes(encoded);
539 }
540
541 pub(crate) fn send_upstream_bytes(&mut self, bytes: Vec<u8>) {
542 self.host.borrow_mut().set_flow_mode(FlowType::Upstream);
543 self.drive(Direction::Upstream, bytes);
544 }
545
546 pub(crate) fn send_downstream(&mut self, frames: Vec<Frame>) {
547 let encoded = Encoder::default().encode_server(frames);
548 self.send_downstream_bytes(encoded);
549 }
550
551 pub(crate) fn send_downstream_bytes(&mut self, bytes: Vec<u8>) {
552 self.host.borrow_mut().set_flow_mode(FlowType::Downstream);
553 self.drive(Direction::Downstream, bytes);
554 }
555
556 fn drive(&mut self, direction: Direction, bytes: Vec<u8>) {
557 self.host.borrow_mut().set_context(self.context_id);
558
559 let mut start = 0;
560 let mut end = 0;
561
562 while end < bytes.len() {
563 end += self.chunk_size;
564 if end > bytes.len() {
565 end = bytes.len();
566 }
567
568 let buffer = bytes[start..end].to_vec();
569 let ctx = self.http_context.as_mut();
570 let action = self.connection_state.on_body(
571 self.context_id,
572 ctx,
573 Rc::clone(&self.host),
574 direction,
575 buffer,
576 );
577 match action {
578 Action::Continue => {
579 self.connection_state.clean_buffer(
580 self.context_id,
581 &mut self.host.borrow_mut(),
582 direction,
583 );
584 self.connection_state.set_paused(direction, false);
585 }
586 Action::Pause => {
587 self.connection_state.set_paused(direction, true);
588 }
589 _ => {
590 panic!("unexpected action: {action:?}");
591 }
592 }
593 start = end
594 }
595
596 self.respond_calls();
597 self.resume(direction);
598 }
599
600 pub(crate) fn resume(&mut self, direction: Direction) {
601 if self.host.borrow_mut().clear_send_response(self.context_id) {
602 panic!("Called send response on websocket flow.")
603 }
604
605 self.connection_state
606 .resume(self.context_id, &mut self.host.borrow_mut(), direction);
607 }
608
609 fn respond_calls(&mut self) {
610 self.host.borrow_mut().set_context(self.context_id);
611 let prev_flow = self.host.borrow_mut().set_flow_mode(FlowType::Async);
612 let mut pending = self.host.borrow_mut().pending_calls(self.context_id);
613 while !pending.is_empty() {
614 for (id, upstream, call) in pending {
615 respond_call(
616 self.http_context.as_mut(),
617 &self.host,
618 &self.backends,
619 self.context_id,
620 id,
621 upstream,
622 call,
623 );
624 }
625 pending = self.host.borrow_mut().pending_calls(self.context_id);
626 }
627 self.host.borrow_mut().set_flow_mode(prev_flow);
628 }
629}
630
631impl Drop for ConnectionInner {
632 fn drop(&mut self) {
633 self.host.borrow_mut().set_context(self.context_id);
634 self.http_context.on_done();
635 }
636}