1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
use crate::LambdaRuntimeApiClient;
use hyper::{
  body::{Body, Incoming},
  server::conn::http1,
  service::service_fn,
  Request, Response,
};
use hyper_util::rt::TokioIo;
use std::{future::Future, net::SocketAddr};
use tokio::{net::TcpListener, sync::Mutex};

/// A loopback (`127.0.0.1`) HTTP/1 server whose responses are produced by a caller-supplied
/// processor, standing in for the Lambda Runtime API.
///
/// Connections are accepted and served one at a time by [`handle_next`](Self::handle_next),
/// [`serve`](Self::serve) and [`passthrough`](Self::passthrough).
pub struct MockLambdaRuntimeApiServer(TcpListener);

impl MockLambdaRuntimeApiServer {
  /// Create a new server bound to `127.0.0.1:port`.
  ///
  /// # Panics
  ///
  /// Panics if the address cannot be bound (e.g. the port is already in use).
  pub async fn bind(port: u16) -> Self {
    let addr = SocketAddr::from(([127, 0, 0, 1], port));

    Self(
      TcpListener::bind(addr)
        .await
        .expect("Failed to bind mock Lambda runtime API server"),
    )
  }

  /// Accept the next incoming connection and serve it with `processor`.
  ///
  /// The connection is served as HTTP/1 until it is closed. If the client keeps the connection
  /// alive, `processor` may therefore be called for more than one request. Errors raised while
  /// serving the connection are printed to stderr and otherwise ignored.
  ///
  /// # Panics
  ///
  /// Panics if accepting a connection fails.
  pub async fn handle_next<ResBody, Fut>(&self, processor: impl Fn(Request<Incoming>) -> Fut)
  where
    ResBody: hyper::body::Body + 'static,
    <ResBody as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    Fut: Future<Output = hyper::Result<Response<ResBody>>>,
  {
    let (stream, _) = self.0.accept().await.expect("Failed to accept connection");
    let io = TokioIo::new(stream);

    // `processor` already has the `Fn(Request) -> Future` shape `service_fn` expects,
    // so it can be handed over directly without an extra wrapper future.
    if let Err(err) = http1::Builder::new()
      .serve_connection(io, service_fn(processor))
      .await
    {
      // A failed connection is only reported; the caller may go on to accept the next one.
      eprintln!("Error serving connection: {:?}", err);
    }
  }

  /// Accept and serve connections with `processor` in an endless loop.
  ///
  /// This does not block the thread, but the returned future never resolves.
  ///
  /// # Panics
  ///
  /// Panics if accepting a connection fails (see [`handle_next`](Self::handle_next)).
  pub async fn serve<ResBody, Fut>(&self, processor: impl Fn(Request<Incoming>) -> Fut)
  where
    Fut: Future<Output = hyper::Result<Response<ResBody>>>,
    ResBody: hyper::body::Body + 'static,
    <ResBody as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
  {
    loop {
      // Lend the processor by reference so the same one is reused for every connection.
      self.handle_next(&processor).await
    }
  }

  /// Forward every received request to `client` and reply with the client's response.
  ///
  /// Like [`serve`](Self::serve), this loops forever and the returned future never resolves.
  ///
  /// # Panics
  ///
  /// Panics if accepting a connection fails (see [`handle_next`](Self::handle_next)).
  pub async fn passthrough(&self, client: LambdaRuntimeApiClient<Incoming>) {
    // `serve` requires an `Fn` processor, so the closure can only reach the client through a
    // shared reference; the lock provides the access `send_request` is given. A `tokio` mutex
    // is used because the guard is held across `send_request(..).await`.
    // NOTE(review): this assumes `send_request` takes `&mut self`. Confirm that. If it only
    // needs `&self`, the mutex can be removed.
    let client = Mutex::new(client);
    self
      .serve(|req| async { client.lock().await.send_request(req).await })
      .await
  }
}