cargo_lambda_watch/lib.rs

use axum::{Router, extract::Extension, http::header::HeaderName};
use bytes::Bytes;
use cargo_lambda_metadata::{
    DEFAULT_PACKAGE_FUNCTION,
    cargo::{
        CargoMetadata, CargoPackage, filter_binary_targets_from_metadata, kind_bin_filter,
        selected_bin_filter, watch::Watch,
    },
    lambda::Timeout,
};
use cargo_lambda_remote::tls::TlsOptions;
use cargo_options::Run as CargoOptions;
use http_body_util::{BodyExt, combinators::BoxBody};
use hyper::{Request, Response, body::Incoming, client::conn::http1, service::service_fn};
use hyper_util::{
    rt::{TokioExecutor, TokioIo},
    server::conn::auto::Builder,
};
use miette::{IntoDiagnostic, Result, WrapErr};
use opentelemetry::{
    global,
    sdk::{export::trace::stdout, trace, trace::Tracer},
};
use opentelemetry_aws::trace::XrayPropagator;
use rustls::ServerConfig;
use std::{
    collections::{HashMap, HashSet},
    net::{IpAddr, SocketAddr},
    path::Path,
    str::FromStr,
    sync::Arc,
};
use tokio::{
    net::{TcpListener, TcpStream},
    pin,
    time::Duration,
};
use tokio_graceful_shutdown::{SubsystemBuilder, SubsystemHandle, Toplevel};
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;
use tower_http::{
    catch_panic::CatchPanicLayer,
    cors::CorsLayer,
    request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
    timeout::TimeoutLayer,
    trace::TraceLayer,
};
use tracing::{Subscriber, error, info};
use tracing_opentelemetry::OpenTelemetryLayer;
use tracing_subscriber::registry::LookupSpan;

mod error;
mod requests;
mod runtime;

mod scheduler;
use scheduler::*;
mod state;
use state::*;
mod trigger_router;
mod watcher;
use watcher::WatcherConfig;

use crate::{error::ServerError, requests::Action};

pub(crate) const RUNTIME_EMULATOR_PATH: &str = "/.rt";

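/// Entry point for the watch subcommand: resolves the binary targets to build,
/// prepares the file watcher configuration, and runs the Lambda runtime
/// emulator server until a shutdown signal is received.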
#[tracing::instrument(target = "cargo_lambda")]
pub async fn run(
    config: &Watch,
    base_env: &HashMap<String, String>,
    metadata: &CargoMetadata,
    color: &str,
) -> Result<()> {
    tracing::trace!("watching project");

    let manifest_path = config.manifest_path();

    let mut cargo_options = config.cargo_opts.clone();
    cargo_options.color = Some(color.into());
    if cargo_options.manifest_path.is_none() {
        cargo_options.manifest_path = Some(manifest_path.clone());
    }

    let base = dunce::canonicalize(".").into_diagnostic()?;
    let ignore_files = watcher::ignore::discover_files(&base).await;

    let env = config.lambda_environment(base_env).into_diagnostic()?;

    let package_filter = if !cargo_options.packages.is_empty() {
        let packages = cargo_options.packages.clone();
        Some(move |p: &&CargoPackage| packages.contains(&p.name))
    } else {
        None
    };

    let binary_filter = if config.cargo_opts.bin.is_empty() {
        Box::new(kind_bin_filter)
    } else {
        selected_bin_filter(config.cargo_opts.bin.clone())
    };

    let binary_packages =
        filter_binary_targets_from_metadata(metadata, binary_filter, package_filter);

    if binary_packages.is_empty() {
        Err(ServerError::NoBinaryPackages)?;
    }

    let watcher_config = WatcherConfig {
        base,
        ignore_files,
        env,
        ignore_changes: config.ignore_changes,
        only_lambda_apis: config.only_lambda_apis,
        manifest_path: manifest_path.clone(),
        wait: config.wait,
        ..Default::default()
    };

    let runtime_state = build_runtime_state(config, &manifest_path, binary_packages)?;

    let disable_cors = config.disable_cors;
    let timeout = config.timeout.clone();
    let tls_options = config.tls_options.clone();

    let _ = Toplevel::new(move |s| async move {
        s.start(SubsystemBuilder::new("Lambda server", move |s| {
            start_server(
                s,
                runtime_state,
                cargo_options,
                watcher_config,
                tls_options,
                disable_cors,
                timeout,
            )
        }));
    })
    .catch_signals()
    .handle_shutdown_requests(Duration::from_secs(1))
    .await;

    Ok(())
}

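/// Builds a `tracing` layer backed by an OpenTelemetry stdout exporter
/// configured for AWS X-Ray trace propagation. Spans are printed only when
/// `print_traces` is enabled in the watch configuration; otherwise they are
/// written to a sink and discarded.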
pub fn xray_layer<S>(config: &Watch) -> OpenTelemetryLayer<S, Tracer>
where
    S: Subscriber + for<'span> LookupSpan<'span>,
{
    global::set_text_map_propagator(XrayPropagator::default());

    let builder = stdout::new_pipeline().with_trace_config(
        trace::config()
            .with_sampler(trace::Sampler::AlwaysOn)
            .with_id_generator(trace::XrayIdGenerator::default()),
    );
    let tracer = if config.print_traces {
        builder.install_simple()
    } else {
        builder.with_writer(std::io::sink()).install_simple()
    };
    tracing_opentelemetry::layer().with_tracer(tracer)
}

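/// Resolves the addresses the emulator listens on. When TLS is enabled, the
/// runtime server moves to `invoke_port + 1` and the original invoke port is
/// reserved for the TLS proxy.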
fn build_runtime_state(
    config: &Watch,
    manifest_path: &Path,
    binary_packages: HashSet<String>,
) -> Result<RuntimeState> {
    let ip = IpAddr::from_str(&config.invoke_address)
        .into_diagnostic()
        .wrap_err("invalid invoke address")?;
    let (runtime_port, proxy_addr) = if config.tls_options.is_secure() {
        (
            config.invoke_port + 1,
            Some(SocketAddr::from((ip, config.invoke_port))),
        )
    } else {
        (config.invoke_port, None)
    };
    let runtime_addr = SocketAddr::from((ip, runtime_port));

    Ok(RuntimeState::new(
        runtime_addr,
        proxy_addr,
        manifest_path.to_path_buf(),
        config.only_lambda_apis,
        binary_packages,
        config.router.clone(),
    ))
}

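/// Starts the build scheduler, assembles the axum router for the trigger and
/// runtime emulator APIs, and serves it until a shutdown is requested.
/// A TLS proxy subsystem is spawned in front of the runtime server when TLS
/// options are provided.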
async fn start_server(
    subsys: SubsystemHandle,
    runtime_state: RuntimeState,
    cargo_options: CargoOptions,
    watcher_config: WatcherConfig,
    tls_options: TlsOptions,
    disable_cors: bool,
    timeout: Option<Timeout>,
) -> Result<()> {
    let only_lambda_apis = watcher_config.only_lambda_apis;
    let init_default_function =
        runtime_state.is_default_function_enabled() && watcher_config.send_function_init();

    let (runtime_addr, proxy_addr, runtime_url) = runtime_state.addresses();

    let x_request_id = HeaderName::from_static("lambda-runtime-aws-request-id");
    let req_tx = init_scheduler(
        &subsys,
        runtime_state.clone(),
        cargo_options,
        watcher_config,
    );

    let state_ref = Arc::new(runtime_state);
    let mut app = Router::new()
        .merge(trigger_router::routes().with_state(state_ref.clone()))
        .nest(
            RUNTIME_EMULATOR_PATH,
            runtime::routes().with_state(state_ref.clone()),
        )
        .layer(SetRequestIdLayer::new(
            x_request_id.clone(),
            MakeRequestUuid,
        ))
        .layer(PropagateRequestIdLayer::new(x_request_id))
        .layer(Extension(req_tx.clone()))
        .layer(TraceLayer::new_for_http())
        .layer(CatchPanicLayer::new());
    if !disable_cors {
        app = app.layer(CorsLayer::very_permissive());
    }
    if let Some(timeout) = timeout {
        app = app.layer(TimeoutLayer::new(timeout.duration()));
    }
    let app = app.with_state(state_ref);

    if only_lambda_apis {
        info!("");
        info!(
            "the flag --only-lambda-apis is active, the lambda function will not be started by Cargo Lambda"
        );
        info!("the lambda function will depend on the following environment variables");
        info!(
            "you MUST set these variables in the environment where you're running your function:"
        );
        info!("AWS_LAMBDA_FUNCTION_VERSION=1");
        info!("AWS_LAMBDA_FUNCTION_MEMORY_SIZE=4096");
        info!("AWS_LAMBDA_RUNTIME_API={}", runtime_url);
        info!("AWS_LAMBDA_FUNCTION_NAME={DEFAULT_PACKAGE_FUNCTION}");
    } else {
        let print_start_info = if init_default_function {
            // This call ignores any error sending the action.
            // The function can still be lazily loaded later if the send fails.
            req_tx.send(Action::Init).await.is_err()
        } else {
            false
        };

        if print_start_info {
            info!("");
            info!("your function will start running when you send the first invoke request");
            info!("read the invoke guide if you don't know how to continue:");
            info!("https://www.cargo-lambda.info/commands/invoke.html");
        }
    }

    let tls_config = tls_options.server_config()?;
    let tls_tracker = TaskTracker::new();

    if let (Some(tls_config), Some(proxy_addr)) = (tls_config, proxy_addr) {
        let tls_tracker = tls_tracker.clone();

        subsys.start(SubsystemBuilder::new("TLS proxy", move |s| async move {
            start_tls_proxy(s, tls_tracker, tls_config, proxy_addr, runtime_addr).await
        }));
    }

    info!(?runtime_addr, "starting Runtime server");
    let out = axum::serve(
        TcpListener::bind(runtime_addr).await.into_diagnostic()?,
        app.into_make_service(),
    )
    .with_graceful_shutdown(async move {
        subsys.on_shutdown_requested().await;
    })
    .await;

    if let Err(error) = out {
        error!(error = ?error, "failed to serve HTTP requests");
    }

    tls_tracker.close();
    tls_tracker.wait().await;

    Ok(())
}

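/// Accepts TLS connections on the proxy address and forwards the decrypted
/// HTTP traffic to the plain-text runtime server, shutting connections down
/// gracefully when the subsystem is cancelled.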
async fn start_tls_proxy(
    subsys: SubsystemHandle,
    connection_tracker: TaskTracker,
    tls_config: ServerConfig,
    proxy_addr: SocketAddr,
    runtime_addr: SocketAddr,
) -> Result<()> {
    info!(
        ?proxy_addr,
        "starting TLS server, use this address to send secure requests to the runtime"
    );

    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listener = TcpListener::bind(proxy_addr).await.into_diagnostic()?;

    let addr = Arc::new(runtime_addr);

    loop {
        let (stream, _) = listener.accept().await.into_diagnostic()?;
        let acceptor = acceptor.clone();

        let addr = addr.clone();

        connection_tracker.spawn({
            let cancellation_token = subsys.create_cancellation_token();
            let connection_tracker = connection_tracker.clone();

            async move {
                let hyper_service = service_fn(move |request: Request<Incoming>| {
                    proxy(connection_tracker.clone(), request, addr.clone())
                });

                let tls_stream = match acceptor.accept(stream).await {
                    Ok(tls_stream) => tls_stream,
                    Err(e) => {
                        error!(error = ?e, "Failed to accept TLS connection");
                        return Err(e).into_diagnostic();
                    }
                };

                let builder = Builder::new(TokioExecutor::new());
                let conn = builder.serve_connection(TokioIo::new(tls_stream), hyper_service);

                pin!(conn);

                let result = tokio::select! {
                    res = conn.as_mut() => res,
                    _ = cancellation_token.cancelled() => {
                        conn.as_mut().graceful_shutdown();
                        conn.await
                    }
                };

                if let Err(e) = result {
                    error!(error = ?e, "Failed to serve connection");
                }

                Ok(())
            }
        });
    }
}

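/// Forwards a single request received by the TLS proxy to the runtime server
/// over a new HTTP/1.1 connection.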
async fn proxy(
    connection_tracker: TaskTracker,
    req: Request<Incoming>,
    addr: Arc<SocketAddr>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    // Open a plain TCP connection to the runtime server that sits behind the TLS proxy.
    let stream = TcpStream::connect(&*addr).await.unwrap();
    let io = TokioIo::new(stream);

    let (mut sender, conn) = http1::Builder::new()
        .preserve_header_case(true)
        .title_case_headers(true)
        .handshake(io)
        .await?;

    // Drive the client connection in the background so the request below can be sent.
    connection_tracker.spawn(async move {
        if let Err(err) = conn.await {
            error!(error = ?err, "connection to the runtime server failed");
        }
    });

    let resp = sender.send_request(req).await?;
    Ok(resp.map(|b| b.boxed()))
}