//! `cargo_lambda_watch` — library entry point (`lib.rs`).
1use axum::{Router, extract::Extension, http::header::HeaderName};
2use bytes::Bytes;
3use cargo_lambda_metadata::{
4    DEFAULT_PACKAGE_FUNCTION,
5    cargo::{
6        CargoMetadata, CargoPackage, filter_binary_targets_from_metadata, kind_bin_filter,
7        selected_bin_filter, watch::Watch,
8    },
9    env::SystemEnvExtractor,
10    lambda::Timeout,
11};
12use cargo_lambda_remote::tls::TlsOptions;
13use cargo_options::Run as CargoOptions;
14use http_body_util::{BodyExt, combinators::BoxBody};
15use hyper::{Request, Response, body::Incoming, client::conn::http1, service::service_fn};
16use hyper_util::{
17    rt::{TokioExecutor, TokioIo},
18    server::conn::auto::Builder,
19};
20use miette::{IntoDiagnostic, Result, WrapErr};
21use opentelemetry::{
22    global,
23    sdk::{export::trace::stdout, trace, trace::Tracer},
24};
25use opentelemetry_aws::trace::XrayPropagator;
26use rustls::ServerConfig;
27use std::{
28    collections::{HashMap, HashSet},
29    net::{IpAddr, SocketAddr},
30    path::Path,
31    str::FromStr,
32    sync::Arc,
33};
34use tokio::{
35    net::{TcpListener, TcpStream},
36    pin,
37    time::Duration,
38};
39use tokio_graceful_shutdown::{SubsystemBuilder, SubsystemHandle, Toplevel};
40use tokio_rustls::TlsAcceptor;
41use tokio_util::task::TaskTracker;
42use tower_http::{
43    catch_panic::CatchPanicLayer,
44    cors::CorsLayer,
45    request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
46    timeout::TimeoutLayer,
47    trace::TraceLayer,
48};
49use tracing::{Subscriber, error, info};
50use tracing_opentelemetry::OpenTelemetryLayer;
51use tracing_subscriber::registry::LookupSpan;
52
// Internal server modules; only `eventstream` is part of the public API.
mod error;
pub mod eventstream;
mod instance_pool;
mod requests;
mod runtime;

// The scheduler and state modules are glob-imported because this file uses
// several of their items directly (e.g. `init_scheduler`, `RuntimeState`).
mod scheduler;
use scheduler::*;
mod state;
use state::*;
mod trigger_router;
mod watcher;
use watcher::WatcherConfig;

use crate::{error::ServerError, requests::Action};

// URL path prefix where the Lambda Runtime API emulator is nested in the router.
pub(crate) const RUNTIME_EMULATOR_PATH: &str = "/.rt";
70
71#[tracing::instrument(target = "cargo_lambda")]
72pub async fn run(
73    config: &Watch,
74    base_env: &HashMap<String, String>,
75    metadata: &CargoMetadata,
76    color: &str,
77) -> Result<()> {
78    tracing::trace!("watching project");
79
80    let manifest_path = config.manifest_path();
81
82    let mut cargo_options = config.cargo_opts.clone();
83    cargo_options.color = Some(color.into());
84    if cargo_options.manifest_path.is_none() {
85        cargo_options.manifest_path = Some(manifest_path.clone());
86    }
87
88    let base = dunce::canonicalize(".").into_diagnostic()?;
89    let ignore_files = watcher::ignore::discover_files(&base, SystemEnvExtractor).await;
90
91    let env = config.lambda_environment(base_env).into_diagnostic()?;
92
93    let package_filter = if !cargo_options.packages.is_empty() {
94        let packages = cargo_options.packages.clone();
95        Some(move |p: &&CargoPackage| packages.contains(&p.name))
96    } else {
97        None
98    };
99
100    let binary_filter = if config.cargo_opts.bin.is_empty() {
101        Box::new(kind_bin_filter)
102    } else {
103        selected_bin_filter(config.cargo_opts.bin.clone())
104    };
105
106    let binary_packages =
107        filter_binary_targets_from_metadata(metadata, binary_filter, package_filter);
108
109    if binary_packages.is_empty() {
110        Err(ServerError::NoBinaryPackages)?;
111    }
112
113    let watcher_config = WatcherConfig {
114        base,
115        ignore_files,
116        env,
117        ignore_changes: config.ignore_changes,
118        only_lambda_apis: config.only_lambda_apis,
119        manifest_path: manifest_path.clone(),
120        wait: config.wait,
121        ..Default::default()
122    };
123
124    let runtime_state = build_runtime_state(config, &manifest_path, binary_packages)?;
125
126    let disable_cors = config.disable_cors;
127    let timeout = config.timeout.clone();
128    let tls_options = config.tls_options.clone();
129
130    let _ = Toplevel::new(move |s| async move {
131        s.start(SubsystemBuilder::new("Lambda server", move |s| {
132            start_server(
133                s,
134                runtime_state,
135                cargo_options,
136                watcher_config,
137                tls_options,
138                disable_cors,
139                timeout,
140            )
141        }));
142    })
143    .catch_signals()
144    .handle_shutdown_requests(Duration::from_secs(1))
145    .await;
146
147    Ok(())
148}
149
150pub fn xray_layer<S>(config: &Watch) -> OpenTelemetryLayer<S, Tracer>
151where
152    S: Subscriber + for<'span> LookupSpan<'span>,
153{
154    global::set_text_map_propagator(XrayPropagator::default());
155
156    let builder = stdout::new_pipeline().with_trace_config(
157        trace::config()
158            .with_sampler(trace::Sampler::AlwaysOn)
159            .with_id_generator(trace::XrayIdGenerator::default()),
160    );
161    let tracer = if config.print_traces {
162        builder.install_simple()
163    } else {
164        builder.with_writer(std::io::sink()).install_simple()
165    };
166    tracing_opentelemetry::layer().with_tracer(tracer)
167}
168
169fn build_runtime_state(
170    config: &Watch,
171    manifest_path: &Path,
172    binary_packages: HashSet<String>,
173) -> Result<RuntimeState> {
174    let ip = IpAddr::from_str(&config.invoke_address)
175        .into_diagnostic()
176        .wrap_err("invalid invoke address")?;
177    let (runtime_port, proxy_addr) = if config.tls_options.is_secure() {
178        (
179            config.invoke_port + 1,
180            Some(SocketAddr::from((ip, config.invoke_port))),
181        )
182    } else {
183        (config.invoke_port, None)
184    };
185    let runtime_addr = SocketAddr::from((ip, runtime_port));
186
187    Ok(RuntimeState::new(
188        runtime_addr,
189        proxy_addr,
190        manifest_path.to_path_buf(),
191        config.only_lambda_apis,
192        binary_packages,
193        config.router.clone(),
194        config.concurrency,
195    ))
196}
197
/// Start the runtime emulator HTTP server and its supporting subsystems.
///
/// Order matters here: the scheduler is started first so its request channel
/// can be attached to the router as an `Extension`; the optional TLS proxy is
/// registered as a separate subsystem before the main axum server starts; and
/// the TLS connection tracker is drained only after the server stops serving.
async fn start_server(
    subsys: SubsystemHandle,
    runtime_state: RuntimeState,
    cargo_options: CargoOptions,
    watcher_config: WatcherConfig,
    tls_options: TlsOptions,
    disable_cors: bool,
    timeout: Option<Timeout>,
) -> Result<()> {
    let only_lambda_apis = watcher_config.only_lambda_apis;
    // Eagerly initialize the default function only when both the runtime state
    // and the watcher configuration allow it.
    let init_default_function =
        runtime_state.is_default_function_enabled() && watcher_config.send_function_init();

    let (runtime_addr, proxy_addr, runtime_url) = runtime_state.addresses();

    // Request-id header propagated to Lambda invocations.
    let x_request_id = HeaderName::from_static("lambda-runtime-aws-request-id");
    let req_tx = init_scheduler(
        &subsys,
        runtime_state.clone(),
        cargo_options,
        watcher_config,
    );

    let state_ref = Arc::new(runtime_state);
    // Trigger routes at the root, the Runtime API emulator nested under
    // RUNTIME_EMULATOR_PATH; layers apply to both.
    let mut app = Router::new()
        .merge(trigger_router::routes().with_state(state_ref.clone()))
        .nest(
            RUNTIME_EMULATOR_PATH,
            runtime::routes().with_state(state_ref.clone()),
        )
        .layer(SetRequestIdLayer::new(
            x_request_id.clone(),
            MakeRequestUuid,
        ))
        .layer(PropagateRequestIdLayer::new(x_request_id))
        .layer(Extension(req_tx.clone()))
        .layer(TraceLayer::new_for_http())
        .layer(CatchPanicLayer::new());
    // CORS is permissive unless explicitly disabled, so browser-based clients
    // can call the emulator.
    if !disable_cors {
        app = app.layer(CorsLayer::very_permissive());
    }
    if let Some(timeout) = timeout {
        app = app.layer(TimeoutLayer::new(timeout.duration()));
    }
    let app = app.with_state(state_ref);

    if only_lambda_apis {
        // Only the Runtime API is served: print the environment variables an
        // externally-run function needs to connect back to this emulator.
        info!("");
        info!(
            "the flag --only_lambda_apis is active, the lambda function will not be started by Cargo Lambda"
        );
        info!("the lambda function will depend on the following environment variables");
        info!(
            "you MUST set these variables in the environment where you're running your function:"
        );
        info!("AWS_LAMBDA_FUNCTION_VERSION=1");
        info!("AWS_LAMBDA_FUNCTION_MEMORY_SIZE=4096");
        info!("AWS_LAMBDA_RUNTIME_API={}", runtime_url);
        info!("AWS_LAMBDA_FUNCTION_NAME={DEFAULT_PACKAGE_FUNCTION}");
    } else {
        let print_start_info = if init_default_function {
            // This call ignores any error sending the action.
            // The function can still be lazy loaded later if there is any error.
            req_tx.send(Action::Init).await.is_err()
        } else {
            false
        };

        // Only shown when eager init could not be requested: the function will
        // start on the first invoke instead.
        if print_start_info {
            info!("");
            info!("your function will start running when you send the first invoke request");
            info!("read the invoke guide if you don't know how to continue:");
            info!("https://www.cargo-lambda.info/commands/invoke.html");
        }
    }

    let tls_config = tls_options.server_config()?;
    let tls_tracker = TaskTracker::new();

    // When TLS is configured, a proxy subsystem terminates TLS on proxy_addr
    // and forwards plain HTTP to the runtime server.
    if let (Some(tls_config), Some(proxy_addr)) = (tls_config, proxy_addr) {
        let tls_tracker = tls_tracker.clone();

        subsys.start(SubsystemBuilder::new("TLS proxy", move |s| async move {
            start_tls_proxy(s, tls_tracker, tls_config, proxy_addr, runtime_addr).await
        }));
    }

    info!(?runtime_addr, "starting Runtime server");
    let out = axum::serve(
        TcpListener::bind(runtime_addr).await.into_diagnostic()?,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .with_graceful_shutdown(async move {
        subsys.on_shutdown_requested().await;
    })
    .await;

    // Serve errors are logged rather than returned so the tracker drain below
    // still runs.
    if let Err(error) = out {
        error!(error = ?error, "failed to serve HTTP requests");
    }

    // Stop accepting new tracked tasks and wait for in-flight TLS connections
    // to finish before shutting down.
    tls_tracker.close();
    tls_tracker.wait().await;

    Ok(())
}
304
/// Accept TLS connections on `proxy_addr` and forward each request over plain
/// HTTP to the runtime server at `runtime_addr`.
///
/// Runs an accept loop until the subsystem is cancelled; each accepted
/// connection is served on a task registered in `connection_tracker` so the
/// caller can wait for in-flight connections during shutdown.
async fn start_tls_proxy(
    subsys: SubsystemHandle,
    connection_tracker: TaskTracker,
    tls_config: ServerConfig,
    proxy_addr: SocketAddr,
    runtime_addr: SocketAddr,
) -> Result<()> {
    info!(
        ?proxy_addr,
        "starting TLS server, use this address to send secure requests to the runtime"
    );

    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listener = TcpListener::bind(proxy_addr).await.into_diagnostic()?;

    // Shared so each per-connection task can cheaply clone the upstream address.
    let addr = Arc::new(runtime_addr);

    loop {
        let (stream, _) = listener.accept().await.into_diagnostic()?;
        let acceptor = acceptor.clone();

        let addr = addr.clone();

        connection_tracker.spawn({
            // A fresh token per connection observes subsystem shutdown.
            let cancellation_token = subsys.create_cancellation_token();
            let connection_tracker = connection_tracker.clone();

            async move {
                // Each request on this connection is proxied to the runtime server.
                let hyper_service = service_fn(move |request: Request<Incoming>| {
                    proxy(connection_tracker.clone(), request, addr.clone())
                });

                // Perform the TLS handshake; a failure only drops this connection,
                // not the whole proxy.
                let tls_stream = match acceptor.accept(stream).await {
                    Ok(tls_stream) => tls_stream,
                    Err(e) => {
                        error!(error = ?e, "Failed to accept TLS connection");
                        return Err(e).into_diagnostic();
                    }
                };

                let builder = Builder::new(TokioExecutor::new());
                let conn = builder.serve_connection(TokioIo::new(tls_stream), hyper_service);

                pin!(conn);

                // Serve until the connection completes, or drain it gracefully
                // when shutdown is requested.
                let result = tokio::select! {
                    res = conn.as_mut() => res,
                    _ = cancellation_token.cancelled() => {
                        conn.as_mut().graceful_shutdown();
                        conn.await
                    }
                };

                if let Err(e) = result {
                    error!(error = ?e, "Failed to serve connection");
                }

                Ok(())
            }
        });
    }
}
368
369async fn proxy(
370    connection_tracker: TaskTracker,
371    req: Request<Incoming>,
372    addr: Arc<SocketAddr>,
373) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
374    let stream = TcpStream::connect(&*addr).await.unwrap();
375    let io = TokioIo::new(stream);
376
377    let (mut sender, conn) = http1::Builder::new()
378        .preserve_header_case(true)
379        .title_case_headers(true)
380        .handshake(io)
381        .await?;
382
383    connection_tracker.spawn(async move {
384        if let Err(err) = conn.await {
385            println!("Connection failed: {err:?}");
386        }
387    });
388
389    let resp = sender.send_request(req).await?;
390    Ok(resp.map(|b| b.boxed()))
391}