use axum::{Router, extract::Extension, http::header::HeaderName};
use bytes::Bytes;
use cargo_lambda_metadata::{
    DEFAULT_PACKAGE_FUNCTION,
    cargo::{
        CargoMetadata, CargoPackage, filter_binary_targets_from_metadata, kind_bin_filter,
        selected_bin_filter, watch::Watch,
    },
    lambda::Timeout,
};
use cargo_lambda_remote::tls::TlsOptions;
use cargo_options::Run as CargoOptions;
use http_body_util::{BodyExt, combinators::BoxBody};
use hyper::{Request, Response, body::Incoming, client::conn::http1, service::service_fn};
use hyper_util::{
    rt::{TokioExecutor, TokioIo},
    server::conn::auto::Builder,
};
use miette::{IntoDiagnostic, Result, WrapErr};
use opentelemetry::{
    global,
    sdk::{export::trace::stdout, trace, trace::Tracer},
};
use opentelemetry_aws::trace::XrayPropagator;
use rustls::ServerConfig;
use std::{
    collections::{HashMap, HashSet},
    net::{IpAddr, SocketAddr},
    path::Path,
    str::FromStr,
    sync::Arc,
};
use tokio::{
    net::{TcpListener, TcpStream},
    pin,
    time::Duration,
};
use tokio_graceful_shutdown::{SubsystemBuilder, SubsystemHandle, Toplevel};
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;
use tower_http::{
    catch_panic::CatchPanicLayer,
    cors::CorsLayer,
    request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
    timeout::TimeoutLayer,
    trace::TraceLayer,
};
use tracing::{Subscriber, error, info};
use tracing_opentelemetry::OpenTelemetryLayer;
use tracing_subscriber::registry::LookupSpan;

mod error;
mod requests;
mod runtime;

mod scheduler;
use scheduler::*;
mod state;
use state::*;
mod trigger_router;
mod watcher;
use watcher::WatcherConfig;

use crate::{error::ServerError, requests::Action};

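/// Path prefix where the Lambda Runtime API emulator routes are mounted.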
pub(crate) const RUNTIME_EMULATOR_PATH: &str = "/.rt";

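/// Entry point for `cargo lambda watch`: discovers the binary targets in the
/// project, builds the runtime state and watcher configuration, and runs the
/// Lambda server until a shutdown signal is received.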
#[tracing::instrument(target = "cargo_lambda")]
pub async fn run(
    config: &Watch,
    base_env: &HashMap<String, String>,
    metadata: &CargoMetadata,
    color: &str,
) -> Result<()> {
    tracing::trace!("watching project");

    let manifest_path = config.manifest_path();

    let mut cargo_options = config.cargo_opts.clone();
    cargo_options.color = Some(color.into());
    if cargo_options.manifest_path.is_none() {
        cargo_options.manifest_path = Some(manifest_path.clone());
    }

    let base = dunce::canonicalize(".").into_diagnostic()?;
    let ignore_files = watcher::ignore::discover_files(&base).await;

    let env = config.lambda_environment(base_env).into_diagnostic()?;

    let package_filter = if !cargo_options.packages.is_empty() {
        let packages = cargo_options.packages.clone();
        Some(move |p: &&CargoPackage| packages.contains(&p.name))
    } else {
        None
    };

    let binary_filter = if config.cargo_opts.bin.is_empty() {
        Box::new(kind_bin_filter)
    } else {
        selected_bin_filter(config.cargo_opts.bin.clone())
    };

    let binary_packages =
        filter_binary_targets_from_metadata(metadata, binary_filter, package_filter);

    if binary_packages.is_empty() {
        Err(ServerError::NoBinaryPackages)?;
    }

    let watcher_config = WatcherConfig {
        base,
        ignore_files,
        env,
        ignore_changes: config.ignore_changes,
        only_lambda_apis: config.only_lambda_apis,
        manifest_path: manifest_path.clone(),
        wait: config.wait,
        ..Default::default()
    };

    let runtime_state = build_runtime_state(config, &manifest_path, binary_packages)?;

    let disable_cors = config.disable_cors;
    let timeout = config.timeout.clone();
    let tls_options = config.tls_options.clone();

    let _ = Toplevel::new(move |s| async move {
        s.start(SubsystemBuilder::new("Lambda server", move |s| {
            start_server(
                s,
                runtime_state,
                cargo_options,
                watcher_config,
                tls_options,
                disable_cors,
                timeout,
            )
        }));
    })
    .catch_signals()
    .handle_shutdown_requests(Duration::from_secs(1))
    .await;

    Ok(())
}

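/// Builds a `tracing` layer that exports spans with AWS X-Ray trace context
/// propagation. Spans are printed to stdout when `print_traces` is enabled,
/// and discarded otherwise.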
pub fn xray_layer<S>(config: &Watch) -> OpenTelemetryLayer<S, Tracer>
where
    S: Subscriber + for<'span> LookupSpan<'span>,
{
    global::set_text_map_propagator(XrayPropagator::default());

    let builder = stdout::new_pipeline().with_trace_config(
        trace::config()
            .with_sampler(trace::Sampler::AlwaysOn)
            .with_id_generator(trace::XrayIdGenerator::default()),
    );
    let tracer = if config.print_traces {
        builder.install_simple()
    } else {
        builder.with_writer(std::io::sink()).install_simple()
    };
    tracing_opentelemetry::layer().with_tracer(tracer)
}

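/// Resolves the configured invoke address into the runtime server address and,
/// when TLS is enabled, the TLS proxy address, bundling them into the shared
/// `RuntimeState`.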
fn build_runtime_state(
    config: &Watch,
    manifest_path: &Path,
    binary_packages: HashSet<String>,
) -> Result<RuntimeState> {
    let ip = IpAddr::from_str(&config.invoke_address)
        .into_diagnostic()
        .wrap_err("invalid invoke address")?;
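    // With TLS enabled, the plain HTTP runtime server moves to the next port
    // and the TLS proxy listens on the configured invoke port.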
    let (runtime_port, proxy_addr) = if config.tls_options.is_secure() {
        (
            config.invoke_port + 1,
            Some(SocketAddr::from((ip, config.invoke_port))),
        )
    } else {
        (config.invoke_port, None)
    };
    let runtime_addr = SocketAddr::from((ip, runtime_port));

    Ok(RuntimeState::new(
        runtime_addr,
        proxy_addr,
        manifest_path.to_path_buf(),
        binary_packages,
        config.router.clone(),
    ))
}

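/// Runs the main HTTP server: mounts the trigger and Runtime API emulator
/// routes, applies the middleware stack, starts the TLS proxy subsystem when
/// configured, and serves requests until shutdown is requested.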
async fn start_server(
    subsys: SubsystemHandle,
    runtime_state: RuntimeState,
    cargo_options: CargoOptions,
    watcher_config: WatcherConfig,
    tls_options: TlsOptions,
    disable_cors: bool,
    timeout: Option<Timeout>,
) -> Result<()> {
    let only_lambda_apis = watcher_config.only_lambda_apis;
    let init_default_function =
        runtime_state.is_default_function_enabled() && watcher_config.send_function_init();

    let (runtime_addr, proxy_addr, runtime_url) = runtime_state.addresses();

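    // Every request is tagged with a generated UUID in Lambda's request id
    // header, which the layers below propagate back on the response.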
    let x_request_id = HeaderName::from_static("lambda-runtime-aws-request-id");
    let req_tx = init_scheduler(
        &subsys,
        runtime_state.clone(),
        cargo_options,
        watcher_config,
    );

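    // Trigger routes are mounted at the root, and the Runtime API emulator is
    // nested under RUNTIME_EMULATOR_PATH.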
    let state_ref = Arc::new(runtime_state);
    let mut app = Router::new()
        .merge(trigger_router::routes().with_state(state_ref.clone()))
        .nest(
            RUNTIME_EMULATOR_PATH,
            runtime::routes().with_state(state_ref.clone()),
        )
        .layer(SetRequestIdLayer::new(
            x_request_id.clone(),
            MakeRequestUuid,
        ))
        .layer(PropagateRequestIdLayer::new(x_request_id))
        .layer(Extension(req_tx.clone()))
        .layer(TraceLayer::new_for_http())
        .layer(CatchPanicLayer::new());
    if !disable_cors {
        app = app.layer(CorsLayer::very_permissive());
    }
    if let Some(timeout) = timeout {
        app = app.layer(TimeoutLayer::new(timeout.duration()));
    }
    let app = app.with_state(state_ref);

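    // With --only-lambda-apis, only the emulator APIs run; print the variables
    // the user needs to set to run the function process themselves.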
    if only_lambda_apis {
        info!("");
        info!(
            "the flag --only-lambda-apis is active, the lambda function will not be started by Cargo Lambda"
        );
        info!("the lambda function will depend on the following environment variables");
        info!(
            "you MUST set these variables in the environment where you're running your function:"
        );
        info!("AWS_LAMBDA_FUNCTION_VERSION=1");
        info!("AWS_LAMBDA_FUNCTION_MEMORY_SIZE=4096");
        info!("AWS_LAMBDA_RUNTIME_API={}", runtime_url);
        info!("AWS_LAMBDA_FUNCTION_NAME={DEFAULT_PACKAGE_FUNCTION}");
    } else {
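        // Ask the scheduler to start the default function right away; if the
        // init action can't be delivered, tell the user how to start it with
        // an invoke request instead.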
        let print_start_info = if init_default_function {
            req_tx.send(Action::Init).await.is_err()
        } else {
            false
        };

        if print_start_info {
            info!("");
            info!("your function will start running when you send the first invoke request");
            info!("read the invoke guide if you don't know how to continue:");
            info!("https://www.cargo-lambda.info/commands/invoke.html");
        }
    }

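    // When TLS options are provided, terminate TLS in a dedicated proxy
    // subsystem that forwards traffic to the plain HTTP runtime server.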
    let tls_config = tls_options.server_config()?;
    let tls_tracker = TaskTracker::new();

    if let (Some(tls_config), Some(proxy_addr)) = (tls_config, proxy_addr) {
        let tls_tracker = tls_tracker.clone();

        subsys.start(SubsystemBuilder::new("TLS proxy", move |s| async move {
            start_tls_proxy(s, tls_tracker, tls_config, proxy_addr, runtime_addr).await
        }));
    }

    info!(?runtime_addr, "starting Runtime server");
    let out = axum::serve(
        TcpListener::bind(runtime_addr).await.into_diagnostic()?,
        app.into_make_service(),
    )
    .with_graceful_shutdown(async move {
        subsys.on_shutdown_requested().await;
    })
    .await;

    if let Err(error) = out {
        error!(error = ?error, "failed to serve HTTP requests");
    }

    tls_tracker.close();
    tls_tracker.wait().await;

    Ok(())
}

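/// Accepts TLS connections on `proxy_addr` and forwards the decrypted HTTP
/// traffic to the runtime server listening on `runtime_addr`.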
async fn start_tls_proxy(
    subsys: SubsystemHandle,
    connection_tracker: TaskTracker,
    tls_config: ServerConfig,
    proxy_addr: SocketAddr,
    runtime_addr: SocketAddr,
) -> Result<()> {
    info!(
        ?proxy_addr,
        "starting TLS server, use this address to send secure requests to the runtime"
    );

    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listener = TcpListener::bind(proxy_addr).await.into_diagnostic()?;

    let addr = Arc::new(runtime_addr);

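    // Accept loop: every connection is served on a tracked task so in-flight
    // requests can drain before shutdown completes.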
    loop {
        let (stream, _) = listener.accept().await.into_diagnostic()?;
        let acceptor = acceptor.clone();

        let addr = addr.clone();

        connection_tracker.spawn({
            let cancellation_token = subsys.create_cancellation_token();
            let connection_tracker = connection_tracker.clone();

            async move {
                let hyper_service = service_fn(move |request: Request<Incoming>| {
                    proxy(connection_tracker.clone(), request, addr.clone())
                });

                let tls_stream = match acceptor.accept(stream).await {
                    Ok(tls_stream) => tls_stream,
                    Err(e) => {
                        error!(error = ?e, "Failed to accept TLS connection");
                        return Err(e).into_diagnostic();
                    }
                };

                let builder = Builder::new(TokioExecutor::new());
                let conn = builder.serve_connection(TokioIo::new(tls_stream), hyper_service);

                pin!(conn);

                let result = tokio::select! {
                    res = conn.as_mut() => res,
                    _ = cancellation_token.cancelled() => {
                        conn.as_mut().graceful_shutdown();
                        conn.await
                    }
                };

                if let Err(e) = result {
                    error!(error = ?e, "Failed to serve connection");
                }

                Ok(())
            }
        });
    }
}

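/// Forwards a single request to the runtime server over a new HTTP/1.1 client
/// connection and returns the upstream response.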
async fn proxy(
    connection_tracker: TaskTracker,
    req: Request<Incoming>,
    addr: Arc<SocketAddr>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    let stream = TcpStream::connect(&*addr).await.unwrap();
    let io = TokioIo::new(stream);

    let (mut sender, conn) = http1::Builder::new()
        .preserve_header_case(true)
        .title_case_headers(true)
        .handshake(io)
        .await?;

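    // Drive the client connection on a tracked background task so the
    // request/response exchange below can make progress.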
    connection_tracker.spawn(async move {
        if let Err(err) = conn.await {
            error!(error = ?err, "connection to the runtime server failed");
        }
    });

    let resp = sender.send_request(req).await?;
    Ok(resp.map(|b| b.boxed()))
}