lambda_simulator/
runtime_api.rs

1//! Lambda Runtime API HTTP endpoints implementation.
2//!
3//! Implements the Lambda Runtime API as documented at:
4//! <https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html>
5
6use crate::extension_readiness::ExtensionReadinessTracker;
7use crate::freeze::FreezeState;
8use crate::invocation::{InvocationError, InvocationResponse};
9use crate::simulator::SimulatorConfig;
10use crate::state::{RecordResult, RuntimeState};
11use crate::telemetry::{
12    InitReportMetrics, InitializationType, Phase, PlatformInitReport, PlatformInitRuntimeDone,
13    PlatformReport, PlatformRuntimeDone, PlatformStart, ReportMetrics, RuntimeDoneMetrics,
14    RuntimeStatus, TelemetryEvent, TelemetryEventType, TraceContext,
15};
16use crate::telemetry_state::TelemetryState;
17use axum::{
18    Json, Router,
19    extract::{DefaultBodyLimit, Path, State},
20    http::{HeaderMap, HeaderValue, StatusCode},
21    response::{IntoResponse, Response},
22    routing::{get, post},
23};
24use chrono::Utc;
25use serde_json::{Value, json};
26use std::sync::Arc;
27
28/// Maximum Lambda response payload size (6 MB).
29///
30/// AWS Lambda enforces this limit for synchronous invocations. Responses
31/// exceeding this limit will be rejected with a 413 Payload Too Large error.
32///
33/// See: <https://docs.aws.amazon.com/lambda/latest/dg/gettingstarted-limits.html>
34const MAX_RESPONSE_PAYLOAD_BYTES: usize = 6 * 1024 * 1024;
35
/// Shared state for Runtime API endpoints.
///
/// Every field is `Arc`-wrapped so the struct is cheap to `Clone` for each
/// request handled by axum's `State` extractor.
#[derive(Clone)]
pub(crate) struct RuntimeApiState {
    /// Invocation queue and per-invocation lifecycle bookkeeping.
    pub runtime: Arc<RuntimeState>,
    /// Broadcast channel state for platform telemetry events.
    pub telemetry: Arc<TelemetryState>,
    /// Process-freeze (SIGSTOP) simulation state and epoch counter.
    pub freeze: Arc<FreezeState>,
    /// Tracks which extensions have signalled readiness per invocation.
    pub readiness: Arc<ExtensionReadinessTracker>,
    /// Simulator configuration (function version, memory size, extension
    /// readiness timeout, ...).
    pub config: Arc<SimulatorConfig>,
}
45
46/// Creates the Runtime API router.
47///
48/// # Arguments
49///
50/// * `state` - Shared runtime API state
51///
52/// # Returns
53///
54/// An axum router configured with all Runtime API endpoints.
55pub(crate) fn create_runtime_api_router(state: RuntimeApiState) -> Router {
56    Router::new()
57        .route("/2018-06-01/runtime/invocation/next", get(next_invocation))
58        .route(
59            "/2018-06-01/runtime/invocation/{request_id}/response",
60            post(invocation_response),
61        )
62        .route(
63            "/2018-06-01/runtime/invocation/{request_id}/error",
64            post(invocation_error),
65        )
66        .route("/2018-06-01/runtime/init/error", post(init_error))
67        .layer(DefaultBodyLimit::max(MAX_RESPONSE_PAYLOAD_BYTES + 1024))
68        .with_state(state)
69}
70
71/// Helper function to safely insert a header value.
72#[allow(clippy::result_large_err)]
73fn safe_header_insert(
74    headers: &mut HeaderMap,
75    name: &'static str,
76    value: impl AsRef<str>,
77) -> Result<(), Response> {
78    match HeaderValue::from_str(value.as_ref()) {
79        Ok(header_value) => {
80            headers.insert(name, header_value);
81            Ok(())
82        }
83        Err(_) => Err((
84            StatusCode::INTERNAL_SERVER_ERROR,
85            format!("Failed to create header {}", name),
86        )
87            .into_response()),
88    }
89}
90
91/// GET /2018-06-01/runtime/invocation/next
92///
93/// Retrieves the next invocation. This is a long-poll endpoint that blocks
94/// until an invocation is available.
95///
96/// On the first call, this endpoint:
97/// - Marks the runtime as initialized, ending the extension registration phase
98/// - Emits `platform.initRuntimeDone` and `platform.initReport` telemetry events
99///
100/// Process freezing happens after all extensions signal readiness (via polling
101/// their /next endpoint) following the runtime's response submission.
102async fn next_invocation(State(state): State<RuntimeApiState>) -> Response {
103    // Mark initialized on first call to /next - this ends the extension registration phase
104    let was_first_call = !state.runtime.is_initialized().await;
105    state.runtime.mark_initialized().await;
106
107    if was_first_call {
108        tracing::info!(target: "lambda_lifecycle", "🚀 Runtime ready (first /next call)");
109        tracing::info!(target: "lambda_lifecycle", "⏳ Runtime polling /next (waiting for invocation)");
110    }
111
112    // Emit init telemetry on first call to /next
113    if !state.runtime.mark_init_telemetry_emitted() {
114        let now = Utc::now();
115        let init_started_at = state.runtime.init_started_at();
116        let init_duration_ms = (now - init_started_at).num_milliseconds() as f64;
117
118        // Emit platform.initRuntimeDone
119        let init_runtime_done = PlatformInitRuntimeDone {
120            initialization_type: InitializationType::OnDemand,
121            phase: Phase::Init,
122            status: RuntimeStatus::Success,
123            spans: None,
124            tracing: None,
125        };
126
127        let init_runtime_done_event = TelemetryEvent {
128            time: now,
129            event_type: "platform.initRuntimeDone".to_string(),
130            record: serde_json::json!(init_runtime_done),
131        };
132
133        state
134            .telemetry
135            .broadcast_event(init_runtime_done_event, TelemetryEventType::Platform)
136            .await;
137
138        // Emit platform.initReport
139        let init_report = PlatformInitReport {
140            initialization_type: InitializationType::OnDemand,
141            phase: Phase::Init,
142            status: RuntimeStatus::Success,
143            metrics: InitReportMetrics {
144                duration_ms: init_duration_ms,
145            },
146            spans: None,
147            tracing: None,
148        };
149
150        let init_report_event = TelemetryEvent {
151            time: now,
152            event_type: "platform.initReport".to_string(),
153            record: serde_json::json!(init_report),
154        };
155
156        state
157            .telemetry
158            .broadcast_event(init_report_event, TelemetryEventType::Platform)
159            .await;
160
161        tracing::info!(target: "lambda_lifecycle", "📋 platform.initRuntimeDone (duration: {:.1}ms)", init_duration_ms);
162        tracing::info!(target: "lambda_lifecycle", "📋 platform.initReport");
163    }
164
165    let invocation = state.runtime.next_invocation().await;
166
167    tracing::info!(target: "lambda_lifecycle", "📥 Runtime received invocation (request_id: {})", &invocation.aws_request_id[..8]);
168
169    // Emit platform.start when the runtime receives an invocation
170    let trace_context = TraceContext {
171        trace_type: "X-Amzn-Trace-Id".to_string(),
172        value: invocation.trace_id.clone(),
173        span_id: None,
174    };
175
176    let platform_start = PlatformStart {
177        request_id: invocation.aws_request_id.clone(),
178        version: Some(state.config.function_version.clone()),
179        tracing: Some(trace_context),
180    };
181
182    let platform_start_event = TelemetryEvent {
183        time: Utc::now(),
184        event_type: "platform.start".to_string(),
185        record: serde_json::json!(platform_start),
186    };
187
188    state
189        .telemetry
190        .broadcast_event(platform_start_event, TelemetryEventType::Platform)
191        .await;
192
193    let mut headers = HeaderMap::new();
194
195    if let Err(e) = safe_header_insert(
196        &mut headers,
197        "Lambda-Runtime-Aws-Request-Id",
198        &invocation.aws_request_id,
199    ) {
200        return e;
201    }
202
203    if let Err(e) = safe_header_insert(
204        &mut headers,
205        "Lambda-Runtime-Deadline-Ms",
206        invocation.deadline_ms().to_string(),
207    ) {
208        return e;
209    }
210
211    if let Err(e) = safe_header_insert(
212        &mut headers,
213        "Lambda-Runtime-Invoked-Function-Arn",
214        &invocation.invoked_function_arn,
215    ) {
216        return e;
217    }
218
219    if let Err(e) = safe_header_insert(
220        &mut headers,
221        "Lambda-Runtime-Trace-Id",
222        &invocation.trace_id,
223    ) {
224        return e;
225    }
226
227    if let Some(client_context) = &invocation.client_context
228        && let Err(e) = safe_header_insert(
229            &mut headers,
230            "Lambda-Runtime-Client-Context",
231            client_context,
232        )
233    {
234        return e;
235    }
236
237    if let Some(cognito_identity) = &invocation.cognito_identity
238        && let Err(e) = safe_header_insert(
239            &mut headers,
240            "Lambda-Runtime-Cognito-Identity",
241            cognito_identity,
242        )
243    {
244        return e;
245    }
246
247    let body_str = match serde_json::to_string(&invocation.payload) {
248        Ok(s) => s,
249        Err(e) => {
250            tracing::error!("Failed to serialize invocation payload: {}", e);
251            return (
252                StatusCode::INTERNAL_SERVER_ERROR,
253                "Failed to serialize invocation payload",
254            )
255                .into_response();
256        }
257    };
258
259    (StatusCode::OK, headers, body_str).into_response()
260}
261
/// POST /2018-06-01/runtime/invocation/:request_id/response
///
/// Reports a successful invocation response.
///
/// After recording the response and emitting `platform.runtimeDone`, this
/// spawns a background task to wait for all extensions to signal readiness
/// before emitting `platform.report`. The HTTP response is returned immediately.
///
/// Returns 413 if the response payload exceeds 6 MB.
/// Returns 404 if the request ID is not found.
/// Returns 400 if a response or error has already been recorded for this invocation.
async fn invocation_response(
    State(state): State<RuntimeApiState>,
    Path(request_id): Path<String>,
    body: String,
) -> Response {
    // Enforce Lambda's 6 MB synchronous response limit before any parsing.
    if body.len() > MAX_RESPONSE_PAYLOAD_BYTES {
        return (
            StatusCode::PAYLOAD_TOO_LARGE,
            format!(
                "Response payload size ({} bytes) exceeds Lambda's 6 MB limit",
                body.len()
            ),
        )
            .into_response();
    }

    // Check if the invocation exists. The snapshot is taken here (before
    // recording) so the telemetry below can read started_at / trace_id.
    let inv_state = match state.runtime.get_invocation_state(&request_id).await {
        Some(s) => s,
        None => {
            return (
                StatusCode::NOT_FOUND,
                format!("Unknown request ID: {}", request_id),
            )
                .into_response();
        }
    };

    // The runtime must submit valid JSON; reject anything else up front.
    let payload: Value = match serde_json::from_str(&body) {
        Ok(p) => p,
        Err(e) => {
            return (
                StatusCode::BAD_REQUEST,
                format!("Invalid JSON payload: {}", e),
            )
                .into_response();
        }
    };

    let received_at = Utc::now();
    let response = InvocationResponse {
        request_id: request_id.clone(),
        payload,
        received_at,
    };

    // `record_response` is the authoritative check for duplicate or missing
    // invocations; telemetry is emitted only for the `Recorded` outcome.
    match state.runtime.record_response(response).await {
        RecordResult::Recorded => {}
        RecordResult::AlreadyCompleted => {
            return (
                StatusCode::BAD_REQUEST,
                "Response already submitted for this invocation",
            )
                .into_response();
        }
        RecordResult::NotFound => {
            return (StatusCode::NOT_FOUND, "Unknown request ID").into_response();
        }
    }

    // Proceed with telemetry emission since we successfully recorded
    {
        // started_at may be absent (e.g. the invocation never reached the
        // runtime via /next) — report a zero duration in that case.
        let duration_ms = if let Some(started_at) = inv_state.started_at {
            (received_at - started_at).num_milliseconds() as f64
        } else {
            0.0
        };

        let trace_context = TraceContext {
            trace_type: "X-Amzn-Trace-Id".to_string(),
            value: inv_state.invocation.trace_id.clone(),
            span_id: None,
        };

        let runtime_done = PlatformRuntimeDone {
            request_id: request_id.clone(),
            status: RuntimeStatus::Success,
            metrics: Some(RuntimeDoneMetrics {
                duration_ms,
                produced_bytes: None,
            }),
            spans: None,
            tracing: Some(trace_context.clone()),
        };

        let runtime_done_event = TelemetryEvent {
            time: Utc::now(),
            event_type: "platform.runtimeDone".to_string(),
            record: json!(runtime_done),
        };

        state
            .telemetry
            .broadcast_event(runtime_done_event, TelemetryEventType::Platform)
            .await;

        tracing::info!(target: "lambda_lifecycle", "✅ platform.runtimeDone (status: success, duration: {:.1}ms)", duration_ms);

        // Inform the readiness tracker that the runtime has finished this
        // invocation before the background report task starts waiting on it.
        state.readiness.mark_runtime_done(&request_id).await;

        // platform.report and the eventual freeze happen asynchronously;
        // the 202 below is returned without waiting for extensions.
        spawn_report_task(
            state.clone(),
            request_id.clone(),
            inv_state.invocation.created_at,
            received_at,
            RuntimeStatus::Success,
            trace_context,
        );
    }

    StatusCode::ACCEPTED.into_response()
}
385
/// Spawns a background task to wait for extension readiness, emit platform.report,
/// and freeze the process.
///
/// # Arguments
///
/// * `state` - Shared runtime API state, moved into the spawned task
/// * `request_id` - Invocation the report belongs to
/// * `invocation_created_at` - Start of the reported duration window
/// * `runtime_done_at` - When the runtime submitted its response/error;
///   used as a fallback baseline for extension overhead
/// * `status` - Success or Error, propagated into the report record
/// * `trace_context` - Trace context echoed into the report event
fn spawn_report_task(
    state: RuntimeApiState,
    request_id: String,
    invocation_created_at: chrono::DateTime<Utc>,
    runtime_done_at: chrono::DateTime<Utc>,
    status: RuntimeStatus,
    trace_context: TraceContext,
) {
    let timeout_ms = state.config.extension_ready_timeout_ms;
    // Capture the freeze epoch synchronously, before spawning: if new work
    // arrives while the task waits, the epoch advances and the freeze at the
    // bottom becomes a no-op (see the Ok(false) branch).
    let freeze_epoch = state.freeze.current_epoch();

    tokio::spawn(async move {
        let timeout = std::time::Duration::from_millis(timeout_ms);

        // Wait for all extensions to signal readiness, but never longer than
        // the configured timeout — the report is emitted either way.
        tokio::select! {
            _ = state.readiness.wait_for_all_ready(&request_id) => {
                tracing::debug!("All extensions ready for {}", request_id);
            }
            _ = tokio::time::sleep(timeout) => {
                tracing::warn!(
                    "Extension readiness timeout for {}; proceeding with report",
                    request_id
                );
            }
        }

        let extensions_ready_at = Utc::now();
        // Prefer the tracker's measured overhead; otherwise fall back to
        // wall-clock time elapsed since the runtime finished.
        let extension_overhead_ms = state
            .readiness
            .get_extension_overhead_ms(&request_id)
            .await
            .unwrap_or_else(|| (extensions_ready_at - runtime_done_at).num_milliseconds() as f64);

        // Duration spans invocation creation through extensions-ready;
        // billed duration rounds up via ceil().
        let total_duration_ms =
            (extensions_ready_at - invocation_created_at).num_milliseconds() as f64;
        let billed_duration_ms = total_duration_ms.ceil() as u64;

        let report = PlatformReport {
            request_id: request_id.clone(),
            status,
            metrics: ReportMetrics {
                duration_ms: total_duration_ms,
                billed_duration_ms,
                memory_size_mb: state.config.memory_size_mb as u64,
                // Simulated figure (half the configured size) — real Lambda
                // reports the measured peak here.
                max_memory_used_mb: (state.config.memory_size_mb / 2) as u64,
                init_duration_ms: None,
                restore_duration_ms: None,
                billed_restore_duration_ms: None,
            },
            spans: None,
            tracing: Some(trace_context),
        };

        // Only mention extension overhead in the log when it is measurable.
        if extension_overhead_ms >= 1.0 {
            tracing::info!(
                target: "lambda_lifecycle",
                "📊 platform.report (billed: {}ms, extension overhead: {:.0}ms)",
                billed_duration_ms,
                extension_overhead_ms
            );
        } else {
            tracing::info!(
                target: "lambda_lifecycle",
                "📊 platform.report (billed: {}ms)",
                billed_duration_ms
            );
        }

        let report_event = TelemetryEvent {
            time: Utc::now(),
            event_type: "platform.report".to_string(),
            record: json!(report),
        };

        state
            .telemetry
            .broadcast_event(report_event, TelemetryEventType::Platform)
            .await;

        // Drop per-invocation readiness bookkeeping before freezing.
        state.readiness.cleanup_invocation(&request_id).await;

        // Freeze only if the epoch is unchanged, i.e. no new invocation
        // arrived while this report was being produced.
        match state.freeze.freeze_at_epoch(freeze_epoch) {
            Ok(true) => {
                tracing::info!(target: "lambda_lifecycle", "🧊 Environment frozen (SIGSTOP)");
            }
            Ok(false) => {
                // Epoch mismatch - new work arrived before freeze, which is expected behaviour
            }
            Err(e) => {
                tracing::error!(
                    "Failed to freeze processes after invocation: {}. \
                     Freeze simulation may be inaccurate.",
                    e
                );
            }
        }
    });
}
486
/// POST /2018-06-01/runtime/invocation/:request_id/error
///
/// Reports an invocation error.
///
/// After recording the error and emitting `platform.runtimeDone`, this
/// spawns a background task to wait for all extensions to signal readiness
/// before emitting `platform.report`. The HTTP response is returned immediately.
///
/// Returns 404 if the request ID is not found.
/// Returns 400 if a response or error has already been recorded for this invocation.
async fn invocation_error(
    State(state): State<RuntimeApiState>,
    Path(request_id): Path<String>,
    body: String,
) -> Response {
    // Parse the error payload manually since lambda_runtime doesn't send Content-Type header.
    // Note: parsing precedes the existence check, so an unknown request ID
    // with malformed JSON yields 400 rather than 404.
    let error_payload: Value = match serde_json::from_str(&body) {
        Ok(v) => v,
        Err(e) => {
            return (StatusCode::BAD_REQUEST, format!("Invalid JSON: {}", e)).into_response();
        }
    };
    // Check if the invocation exists. The snapshot is taken here (before
    // recording) so the telemetry below can read started_at / trace_id.
    let inv_state = match state.runtime.get_invocation_state(&request_id).await {
        Some(s) => s,
        None => {
            return (
                StatusCode::NOT_FOUND,
                format!("Unknown request ID: {}", request_id),
            )
                .into_response();
        }
    };

    // Extract the standard Runtime API error fields, tolerating missing or
    // non-string values with generic fallbacks.
    let error_type = error_payload
        .get("errorType")
        .and_then(|v| v.as_str())
        .unwrap_or("UnknownError")
        .to_string();

    let error_message = error_payload
        .get("errorMessage")
        .and_then(|v| v.as_str())
        .unwrap_or("Unknown error")
        .to_string();

    // stackTrace is optional; non-string entries are silently skipped.
    let stack_trace = error_payload
        .get("stackTrace")
        .and_then(|v| v.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|v| v.as_str().map(|s| s.to_string()))
                .collect()
        });

    let received_at = Utc::now();
    let error = InvocationError {
        request_id: request_id.clone(),
        error_type: error_type.clone(),
        error_message,
        stack_trace,
        received_at,
    };

    // `record_error` is the authoritative check for duplicate or missing
    // invocations; telemetry is emitted only for the `Recorded` outcome.
    match state.runtime.record_error(error).await {
        RecordResult::Recorded => {}
        RecordResult::AlreadyCompleted => {
            return (
                StatusCode::BAD_REQUEST,
                "Response already submitted for this invocation",
            )
                .into_response();
        }
        RecordResult::NotFound => {
            return (StatusCode::NOT_FOUND, "Unknown request ID").into_response();
        }
    }

    // Proceed with telemetry emission since we successfully recorded
    {
        // started_at may be absent (invocation never delivered via /next);
        // report a zero duration in that case.
        let duration_ms = if let Some(started_at) = inv_state.started_at {
            (received_at - started_at).num_milliseconds() as f64
        } else {
            0.0
        };

        let trace_context = TraceContext {
            trace_type: "X-Amzn-Trace-Id".to_string(),
            value: inv_state.invocation.trace_id.clone(),
            span_id: None,
        };

        let runtime_done = PlatformRuntimeDone {
            request_id: request_id.clone(),
            status: RuntimeStatus::Error,
            metrics: Some(RuntimeDoneMetrics {
                duration_ms,
                produced_bytes: None,
            }),
            spans: None,
            tracing: Some(trace_context.clone()),
        };

        let runtime_done_event = TelemetryEvent {
            time: Utc::now(),
            event_type: "platform.runtimeDone".to_string(),
            record: json!(runtime_done),
        };

        state
            .telemetry
            .broadcast_event(runtime_done_event, TelemetryEventType::Platform)
            .await;

        tracing::info!(target: "lambda_lifecycle", "❌ platform.runtimeDone (status: error, type: {})", error_type);

        // Inform the readiness tracker that the runtime has finished this
        // invocation before the background report task starts waiting on it.
        state.readiness.mark_runtime_done(&request_id).await;

        // platform.report and the eventual freeze happen asynchronously;
        // the 202 below is returned without waiting for extensions.
        spawn_report_task(
            state.clone(),
            request_id.clone(),
            inv_state.invocation.created_at,
            received_at,
            RuntimeStatus::Error,
            trace_context,
        );
    }

    StatusCode::ACCEPTED.into_response()
}
617
618/// POST /2018-06-01/runtime/init/error
619///
620/// Reports an initialization error.
621async fn init_error(
622    State(state): State<RuntimeApiState>,
623    Json(error_payload): Json<Value>,
624) -> Response {
625    let error_type = error_payload
626        .get("errorType")
627        .and_then(|v| v.as_str())
628        .unwrap_or("UnknownError");
629
630    let error_message = error_payload
631        .get("errorMessage")
632        .and_then(|v| v.as_str())
633        .unwrap_or("Unknown error");
634
635    let error_string = format!("{}: {}", error_type, error_message);
636    state.runtime.record_init_error(error_string).await;
637
638    StatusCode::OK.into_response()
639}