1use axum::{Router, extract::Extension, http::header::HeaderName};
2use bytes::Bytes;
3use cargo_lambda_metadata::{
4 DEFAULT_PACKAGE_FUNCTION,
5 cargo::{
6 CargoMetadata, CargoPackage, filter_binary_targets_from_metadata, kind_bin_filter,
7 selected_bin_filter, watch::Watch,
8 },
9 env::SystemEnvExtractor,
10 lambda::Timeout,
11};
12use cargo_lambda_remote::tls::TlsOptions;
13use cargo_options::Run as CargoOptions;
14use http_body_util::{BodyExt, combinators::BoxBody};
15use hyper::{Request, Response, body::Incoming, client::conn::http1, service::service_fn};
16use hyper_util::{
17 rt::{TokioExecutor, TokioIo},
18 server::conn::auto::Builder,
19};
20use miette::{IntoDiagnostic, Result, WrapErr};
21use opentelemetry::{
22 global,
23 sdk::{export::trace::stdout, trace, trace::Tracer},
24};
25use opentelemetry_aws::trace::XrayPropagator;
26use rustls::ServerConfig;
27use std::{
28 collections::{HashMap, HashSet},
29 net::{IpAddr, SocketAddr},
30 path::Path,
31 str::FromStr,
32 sync::Arc,
33};
34use tokio::{
35 net::{TcpListener, TcpStream},
36 pin,
37 time::Duration,
38};
39use tokio_graceful_shutdown::{SubsystemBuilder, SubsystemHandle, Toplevel};
40use tokio_rustls::TlsAcceptor;
41use tokio_util::task::TaskTracker;
42use tower_http::{
43 catch_panic::CatchPanicLayer,
44 cors::CorsLayer,
45 request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
46 timeout::TimeoutLayer,
47 trace::TraceLayer,
48};
49use tracing::{Subscriber, error, info};
50use tracing_opentelemetry::OpenTelemetryLayer;
51use tracing_subscriber::registry::LookupSpan;
52
53mod error;
54mod requests;
55mod runtime;
56
57mod scheduler;
58use scheduler::*;
59mod state;
60use state::*;
61mod trigger_router;
62mod watcher;
63use watcher::WatcherConfig;
64
65use crate::{error::ServerError, requests::Action};
66
/// URL path prefix under which the Lambda Runtime API emulator routes are
/// nested in the server router built by `start_server`.
pub(crate) const RUNTIME_EMULATOR_PATH: &str = "/.rt";
68
69#[tracing::instrument(target = "cargo_lambda")]
70pub async fn run(
71 config: &Watch,
72 base_env: &HashMap<String, String>,
73 metadata: &CargoMetadata,
74 color: &str,
75) -> Result<()> {
76 tracing::trace!("watching project");
77
78 let manifest_path = config.manifest_path();
79
80 let mut cargo_options = config.cargo_opts.clone();
81 cargo_options.color = Some(color.into());
82 if cargo_options.manifest_path.is_none() {
83 cargo_options.manifest_path = Some(manifest_path.clone());
84 }
85
86 let base = dunce::canonicalize(".").into_diagnostic()?;
87 let ignore_files = watcher::ignore::discover_files(&base, SystemEnvExtractor).await;
88
89 let env = config.lambda_environment(base_env).into_diagnostic()?;
90
91 let package_filter = if !cargo_options.packages.is_empty() {
92 let packages = cargo_options.packages.clone();
93 Some(move |p: &&CargoPackage| packages.contains(&p.name))
94 } else {
95 None
96 };
97
98 let binary_filter = if config.cargo_opts.bin.is_empty() {
99 Box::new(kind_bin_filter)
100 } else {
101 selected_bin_filter(config.cargo_opts.bin.clone())
102 };
103
104 let binary_packages =
105 filter_binary_targets_from_metadata(metadata, binary_filter, package_filter);
106
107 if binary_packages.is_empty() {
108 Err(ServerError::NoBinaryPackages)?;
109 }
110
111 let watcher_config = WatcherConfig {
112 base,
113 ignore_files,
114 env,
115 ignore_changes: config.ignore_changes,
116 only_lambda_apis: config.only_lambda_apis,
117 manifest_path: manifest_path.clone(),
118 wait: config.wait,
119 ..Default::default()
120 };
121
122 let runtime_state = build_runtime_state(config, &manifest_path, binary_packages)?;
123
124 let disable_cors = config.disable_cors;
125 let timeout = config.timeout.clone();
126 let tls_options = config.tls_options.clone();
127
128 let _ = Toplevel::new(move |s| async move {
129 s.start(SubsystemBuilder::new("Lambda server", move |s| {
130 start_server(
131 s,
132 runtime_state,
133 cargo_options,
134 watcher_config,
135 tls_options,
136 disable_cors,
137 timeout,
138 )
139 }));
140 })
141 .catch_signals()
142 .handle_shutdown_requests(Duration::from_secs(1))
143 .await;
144
145 Ok(())
146}
147
148pub fn xray_layer<S>(config: &Watch) -> OpenTelemetryLayer<S, Tracer>
149where
150 S: Subscriber + for<'span> LookupSpan<'span>,
151{
152 global::set_text_map_propagator(XrayPropagator::default());
153
154 let builder = stdout::new_pipeline().with_trace_config(
155 trace::config()
156 .with_sampler(trace::Sampler::AlwaysOn)
157 .with_id_generator(trace::XrayIdGenerator::default()),
158 );
159 let tracer = if config.print_traces {
160 builder.install_simple()
161 } else {
162 builder.with_writer(std::io::sink()).install_simple()
163 };
164 tracing_opentelemetry::layer().with_tracer(tracer)
165}
166
167fn build_runtime_state(
168 config: &Watch,
169 manifest_path: &Path,
170 binary_packages: HashSet<String>,
171) -> Result<RuntimeState> {
172 let ip = IpAddr::from_str(&config.invoke_address)
173 .into_diagnostic()
174 .wrap_err("invalid invoke address")?;
175 let (runtime_port, proxy_addr) = if config.tls_options.is_secure() {
176 (
177 config.invoke_port + 1,
178 Some(SocketAddr::from((ip, config.invoke_port))),
179 )
180 } else {
181 (config.invoke_port, None)
182 };
183 let runtime_addr = SocketAddr::from((ip, runtime_port));
184
185 Ok(RuntimeState::new(
186 runtime_addr,
187 proxy_addr,
188 manifest_path.to_path_buf(),
189 config.only_lambda_apis,
190 binary_packages,
191 config.router.clone(),
192 ))
193}
194
195async fn start_server(
196 subsys: SubsystemHandle,
197 runtime_state: RuntimeState,
198 cargo_options: CargoOptions,
199 watcher_config: WatcherConfig,
200 tls_options: TlsOptions,
201 disable_cors: bool,
202 timeout: Option<Timeout>,
203) -> Result<()> {
204 let only_lambda_apis = watcher_config.only_lambda_apis;
205 let init_default_function =
206 runtime_state.is_default_function_enabled() && watcher_config.send_function_init();
207
208 let (runtime_addr, proxy_addr, runtime_url) = runtime_state.addresses();
209
210 let x_request_id = HeaderName::from_static("lambda-runtime-aws-request-id");
211 let req_tx = init_scheduler(
212 &subsys,
213 runtime_state.clone(),
214 cargo_options,
215 watcher_config,
216 );
217
218 let state_ref = Arc::new(runtime_state);
219 let mut app = Router::new()
220 .merge(trigger_router::routes().with_state(state_ref.clone()))
221 .nest(
222 RUNTIME_EMULATOR_PATH,
223 runtime::routes().with_state(state_ref.clone()),
224 )
225 .layer(SetRequestIdLayer::new(
226 x_request_id.clone(),
227 MakeRequestUuid,
228 ))
229 .layer(PropagateRequestIdLayer::new(x_request_id))
230 .layer(Extension(req_tx.clone()))
231 .layer(TraceLayer::new_for_http())
232 .layer(CatchPanicLayer::new());
233 if !disable_cors {
234 app = app.layer(CorsLayer::very_permissive());
235 }
236 if let Some(timeout) = timeout {
237 app = app.layer(TimeoutLayer::new(timeout.duration()));
238 }
239 let app = app.with_state(state_ref);
240
241 if only_lambda_apis {
242 info!("");
243 info!(
244 "the flag --only_lambda_apis is active, the lambda function will not be started by Cargo Lambda"
245 );
246 info!("the lambda function will depend on the following environment variables");
247 info!(
248 "you MUST set these variables in the environment where you're running your function:"
249 );
250 info!("AWS_LAMBDA_FUNCTION_VERSION=1");
251 info!("AWS_LAMBDA_FUNCTION_MEMORY_SIZE=4096");
252 info!("AWS_LAMBDA_RUNTIME_API={}", runtime_url);
253 info!("AWS_LAMBDA_FUNCTION_NAME={DEFAULT_PACKAGE_FUNCTION}");
254 } else {
255 let print_start_info = if init_default_function {
256 req_tx.send(Action::Init).await.is_err()
259 } else {
260 false
261 };
262
263 if print_start_info {
264 info!("");
265 info!("your function will start running when you send the first invoke request");
266 info!("read the invoke guide if you don't know how to continue:");
267 info!("https://www.cargo-lambda.info/commands/invoke.html");
268 }
269 }
270
271 let tls_config = tls_options.server_config()?;
272 let tls_tracker = TaskTracker::new();
273
274 if let (Some(tls_config), Some(proxy_addr)) = (tls_config, proxy_addr) {
275 let tls_tracker = tls_tracker.clone();
276
277 subsys.start(SubsystemBuilder::new("TLS proxy", move |s| async move {
278 start_tls_proxy(s, tls_tracker, tls_config, proxy_addr, runtime_addr).await
279 }));
280 }
281
282 info!(?runtime_addr, "starting Runtime server");
283 let out = axum::serve(
284 TcpListener::bind(runtime_addr).await.into_diagnostic()?,
285 app.into_make_service(),
286 )
287 .with_graceful_shutdown(async move {
288 subsys.on_shutdown_requested().await;
289 })
290 .await;
291
292 if let Err(error) = out {
293 error!(error = ?error, "failed to serve HTTP requests");
294 }
295
296 tls_tracker.close();
297 tls_tracker.wait().await;
298
299 Ok(())
300}
301
302async fn start_tls_proxy(
303 subsys: SubsystemHandle,
304 connection_tracker: TaskTracker,
305 tls_config: ServerConfig,
306 proxy_addr: SocketAddr,
307 runtime_addr: SocketAddr,
308) -> Result<()> {
309 info!(
310 ?proxy_addr,
311 "starting TLS server, use this address to send secure requests to the runtime"
312 );
313
314 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
315
316 let listener = TcpListener::bind(proxy_addr).await.into_diagnostic()?;
317
318 let addr = Arc::new(runtime_addr);
319
320 loop {
321 let (stream, _) = listener.accept().await.into_diagnostic()?;
322 let acceptor = acceptor.clone();
323
324 let addr = addr.clone();
325
326 connection_tracker.spawn({
327 let cancellation_token = subsys.create_cancellation_token();
328 let connection_tracker = connection_tracker.clone();
329
330 async move {
331 let hyper_service = service_fn(move |request: Request<Incoming>| {
332 proxy(connection_tracker.clone(), request, addr.clone())
333 });
334
335 let tls_stream = match acceptor.accept(stream).await {
336 Ok(tls_stream) => tls_stream,
337 Err(e) => {
338 error!(error = ?e, "Failed to accept TLS connection");
339 return Err(e).into_diagnostic();
340 }
341 };
342
343 let builder = Builder::new(TokioExecutor::new());
344 let conn = builder.serve_connection(TokioIo::new(tls_stream), hyper_service);
345
346 pin!(conn);
347
348 let result = tokio::select! {
349 res = conn.as_mut() => res,
350 _ = cancellation_token.cancelled() => {
351 conn.as_mut().graceful_shutdown();
352 conn.await
353 }
354 };
355
356 if let Err(e) = result {
357 error!(error = ?e, "Failed to serve connection");
358 }
359
360 Ok(())
361 }
362 });
363 }
364}
365
366async fn proxy(
367 connection_tracker: TaskTracker,
368 req: Request<Incoming>,
369 addr: Arc<SocketAddr>,
370) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
371 let stream = TcpStream::connect(&*addr).await.unwrap();
372 let io = TokioIo::new(stream);
373
374 let (mut sender, conn) = http1::Builder::new()
375 .preserve_header_case(true)
376 .title_case_headers(true)
377 .handshake(io)
378 .await?;
379
380 connection_tracker.spawn(async move {
381 if let Err(err) = conn.await {
382 println!("Connection failed: {err:?}");
383 }
384 });
385
386 let resp = sender.send_request(req).await?;
387 Ok(resp.map(|b| b.boxed()))
388}