1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
use crate::LambdaRuntimeApiClient;
use hyper::{
  body::{Body, Incoming},
  server::conn::http1,
  service::service_fn,
  Request, Response,
};
use hyper_util::rt::TokioIo;
use std::{future::Future, net::SocketAddr};
use tokio::net::TcpListener;

/// A mock server for the Lambda Runtime API.
/// Use [`Self::bind`] to create a new server, and [`Self::serve`] to start serving requests.
///
/// If you want to handle each connection manually, use [`Self::handle_next`].
/// If you want to forward requests to the real Lambda Runtime API, use [`Self::passthrough`].
///
/// Internally this is a thin wrapper around a tokio [`TcpListener`] bound on the loopback
/// interface (`127.0.0.1`).
#[derive(Debug)]
pub struct MockLambdaRuntimeApiServer(TcpListener);

impl MockLambdaRuntimeApiServer {
  /// Create a new server bound to `127.0.0.1:port`.
  ///
  /// Pass `0` to let the OS pick a free port; the chosen address can then be
  /// retrieved with [`Self::local_addr`].
  ///
  /// # Panics
  ///
  /// Panics if the listener cannot be bound (for example, if the port is already in use).
  pub async fn bind(port: u16) -> Self {
    let addr = SocketAddr::from(([127, 0, 0, 1], port));

    Self(
      TcpListener::bind(addr)
        .await
        .expect("Failed to bind for proxy server"),
    )
  }

  /// Return the local address this server is listening on.
  ///
  /// Mainly useful after [`Self::bind`] was called with port `0`, where the actual
  /// port is chosen by the OS.
  ///
  /// # Errors
  ///
  /// Returns any I/O error reported by the underlying listener.
  pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
    self.0.local_addr()
  }

  /// Accept the next incoming connection and serve it with the provided processor.
  ///
  /// This returns as soon as the connection has been accepted. The connection itself
  /// is served as HTTP/1 on a spawned tokio task, where every request received on it
  /// is passed to `processor`. Errors while serving the connection are written to
  /// stderr rather than returned.
  ///
  /// # Panics
  ///
  /// Panics if accepting a connection fails.
  pub async fn handle_next<ResBody, Fut>(
    &self,
    processor: impl Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
  ) where
    ResBody: hyper::body::Body + Send + 'static,
    <ResBody as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
    Fut: Future<Output = hyper::Result<Response<ResBody>>> + Send,
    <ResBody as Body>::Data: Send,
  {
    let (stream, _) = self.0.accept().await.expect("Failed to accept connection");
    let io = TokioIo::new(stream);

    // in lambda's execution environment there is usually only one connection
    // but we can't rely on that, so spawn a task for each connection
    tokio::spawn(async move {
      if let Err(err) = http1::Builder::new()
        .serve_connection(io, service_fn(|req| async { processor(req).await }))
        .await
      {
        // Diagnostics belong on stderr so they don't mix with regular program output.
        eprintln!("Error serving connection: {:?}", err);
      }
    });
  }

  /// Accept and serve connections forever, handling each one with a clone of `processor`.
  ///
  /// The returned future never completes: it repeatedly calls [`Self::handle_next`].
  /// It does not block the OS thread; it yields to the async runtime while waiting
  /// for new connections.
  ///
  /// # Panics
  ///
  /// Panics if accepting a connection fails.
  pub async fn serve<ResBody, Fut>(
    &self,
    processor: impl Fn(Request<Incoming>) -> Fut + Send + Sync + Clone + 'static,
  ) where
    Fut: Future<Output = hyper::Result<Response<ResBody>>> + Send,
    ResBody: hyper::body::Body + Send + 'static,
    <ResBody as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
    <ResBody as Body>::Data: Send,
  {
    loop {
      self.handle_next(processor.clone()).await
    }
  }

  /// Serve connections forever, forwarding every request through
  /// [`LambdaRuntimeApiClient::forward`] and responding with the forwarded response.
  ///
  /// Like [`Self::serve`], the returned future never completes and does not block
  /// the OS thread.
  ///
  /// # Panics
  ///
  /// Panics if accepting a connection fails.
  pub async fn passthrough(&self) {
    self
      .serve(|req| async { LambdaRuntimeApiClient::forward(req).await })
      .await
  }
}