1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
use axum::{extract::Extension, http::header::HeaderName, Router};
use clap::{Args, ValueHint};
use miette::Result;
use std::{net::SocketAddr, path::PathBuf};
use tokio::time::Duration;
use tokio_graceful_shutdown::{SubsystemHandle, Toplevel};
use tower_http::{
    catch_panic::CatchPanicLayer,
    request_id::{PropagateRequestIdLayer, SetRequestIdLayer},
    trace::TraceLayer,
};
use tracing::info;

mod requests;
mod runtime_router;

mod scheduler;
use scheduler::*;
mod trace;
use trace::*;
mod trigger_router;
mod watch_installer;

/// URL path prefix under which the runtime emulator routes (`runtime_router`) are
/// nested on the invoke server; it is also appended to the server address to build
/// the runtime address handed to `RequestCache`.
const RUNTIME_EMULATOR_PATH: &str = "/.rt";

// Options for the `watch` subcommand (visible alias `start`).
//
// NOTE: clap turns each field's `///` comment into its `--help` text, so those
// lines are user-facing output; maintainer notes below use plain `//` comments.
#[derive(Args, Clone, Debug)]
#[clap(name = "watch", visible_alias = "start")]
pub struct Watch {
    // When set, `run` skips the `cargo-watch` lookup/install; also passed through
    // to the scheduler.
    /// Avoid hot-reload
    #[clap(long)]
    no_reload: bool,

    // The invoke server binds to 127.0.0.1 on this port.
    /// Address port where users send invoke requests
    #[clap(long, short = 'p')]
    #[clap(default_value = "9000")]
    invoke_port: u16,

    // Forwarded to `init_tracing` when the server starts.
    /// Print OpenTelemetry traces after each function invocation
    #[clap(long)]
    print_traces: bool,

    // Forwarded to the scheduler; the only field that is `pub`.
    /// Path to Cargo.toml
    #[clap(
        long,
        value_name = "PATH",
        parse(from_os_str),
        value_hint = ValueHint::FilePath,
        default_value = "Cargo.toml"
    )]
    pub manifest_path: PathBuf,
}

impl Watch {
    pub async fn run(&self) -> Result<()> {
        if !self.no_reload && which::which("cargo-watch").is_err() {
            watch_installer::install().await?;
        }

        let port = self.invoke_port;
        let print_traces = self.print_traces;
        let manifest_path = self.manifest_path.clone();
        let no_reload = self.no_reload;

        Toplevel::new()
            .start("Lambda server", move |s| {
                start_server(s, port, print_traces, manifest_path, no_reload)
            })
            .catch_signals()
            .handle_shutdown_requests(Duration::from_millis(1000))
            .await
            .map_err(|e| miette::miette!("{}", e))
    }
}

async fn start_server(
    subsys: SubsystemHandle,
    invoke_port: u16,
    print_traces: bool,
    manifest_path: PathBuf,
    no_reload: bool,
) -> Result<(), axum::Error> {
    init_tracing(print_traces);

    let addr = SocketAddr::from(([127, 0, 0, 1], invoke_port));
    let runtime_addr = format!("http://{addr}{RUNTIME_EMULATOR_PATH}");

    let req_cache = RequestCache::new(runtime_addr);
    let req_tx = init_scheduler(&subsys, req_cache.clone(), manifest_path, no_reload).await;
    let resp_cache = ResponseCache::new();
    let x_request_id = HeaderName::from_static("lambda-runtime-aws-request-id");

    let app = Router::new()
        .merge(trigger_router::routes())
        .nest(RUNTIME_EMULATOR_PATH, runtime_router::routes())
        .layer(SetRequestIdLayer::new(
            x_request_id.clone(),
            RequestUuidService,
        ))
        .layer(PropagateRequestIdLayer::new(x_request_id))
        .layer(Extension(req_tx.clone()))
        .layer(Extension(req_cache))
        .layer(Extension(resp_cache))
        .layer(TraceLayer::new_for_http())
        .layer(CatchPanicLayer::new());

    info!("invoke server listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .with_graceful_shutdown(subsys.on_shutdown_requested())
        .await
        .map_err(axum::Error::new)
}