use crate::extension_readiness::ExtensionReadinessTracker;
use crate::freeze::FreezeState;
use crate::invocation::{InvocationError, InvocationResponse};
use crate::simulator::SimulatorConfig;
use crate::state::{RecordResult, RuntimeState};
use crate::telemetry::{
InitReportMetrics, InitializationType, Phase, PlatformInitReport, PlatformInitRuntimeDone,
PlatformReport, PlatformRuntimeDone, PlatformStart, ReportMetrics, RuntimeDoneMetrics,
RuntimeStatus, TelemetryEvent, TelemetryEventType, TraceContext,
};
use crate::telemetry_state::TelemetryState;
use axum::{
Json, Router,
extract::{DefaultBodyLimit, Path, State},
http::{HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
routing::{get, post},
};
use chrono::Utc;
use serde_json::{Value, json};
use std::sync::Arc;
/// Lambda's limit for a synchronous invocation response payload (6 MB);
/// enforced in `invocation_response` and used to size the axum body limit.
const MAX_RESPONSE_PAYLOAD_BYTES: usize = 6 * 1024 * 1024;
/// Shared state handed to every Runtime API handler via axum's `State`
/// extractor. All fields are `Arc`s, so cloning the struct is cheap.
#[derive(Clone)]
pub(crate) struct RuntimeApiState {
    // Invocation queue plus init/response/error bookkeeping.
    pub runtime: Arc<RuntimeState>,
    // Broadcast channel for platform.* telemetry events.
    pub telemetry: Arc<TelemetryState>,
    // SIGSTOP-based freeze simulation (see `spawn_report_task`).
    pub freeze: Arc<FreezeState>,
    // Tracks which extensions have signalled readiness per invocation.
    pub readiness: Arc<ExtensionReadinessTracker>,
    // Static simulator configuration (version, memory size, timeouts).
    pub config: Arc<SimulatorConfig>,
}
/// Builds the axum router exposing the Lambda Runtime API surface:
/// `/next` polling, per-request `response`/`error` submission, and
/// `init/error` reporting.
pub(crate) fn create_runtime_api_router(state: RuntimeApiState) -> Router {
    let response_path = "/2018-06-01/runtime/invocation/{request_id}/response";
    let error_path = "/2018-06-01/runtime/invocation/{request_id}/error";
    // Allow slightly more than the payload cap so oversize bodies still reach
    // the handler and get a descriptive 413 from `invocation_response`.
    let body_limit = DefaultBodyLimit::max(MAX_RESPONSE_PAYLOAD_BYTES + 1024);
    Router::new()
        .route("/2018-06-01/runtime/invocation/next", get(next_invocation))
        .route(response_path, post(invocation_response))
        .route(error_path, post(invocation_error))
        .route("/2018-06-01/runtime/init/error", post(init_error))
        .layer(body_limit)
        .with_state(state)
}
/// Inserts `value` under `name` after validating it as an HTTP header value.
///
/// Returns a ready-made 500 response (instead of panicking) when the value
/// contains bytes that are not legal in an HTTP header.
#[allow(clippy::result_large_err)]
fn safe_header_insert(
    headers: &mut HeaderMap,
    name: &'static str,
    value: impl AsRef<str>,
) -> Result<(), Response> {
    let parsed = HeaderValue::from_str(value.as_ref()).map_err(|_| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to create header {}", name),
        )
            .into_response()
    })?;
    headers.insert(name, parsed);
    Ok(())
}
/// GET /2018-06-01/runtime/invocation/next — the runtime's long-poll for work.
///
/// On the first call this marks the runtime initialized and (exactly once)
/// emits the `platform.initRuntimeDone` / `platform.initReport` telemetry
/// pair. For each invocation handed out it emits `platform.start` and returns
/// the payload with the standard `Lambda-Runtime-*` headers attached.
async fn next_invocation(State(state): State<RuntimeApiState>) -> Response {
    let was_first_call = !state.runtime.is_initialized().await;
    state.runtime.mark_initialized().await;
    if was_first_call {
        tracing::info!(target: "lambda_lifecycle", "🚀 Runtime ready (first /next call)");
        tracing::info!(target: "lambda_lifecycle", "⏳ Runtime polling /next (waiting for invocation)");
    }
    // Emit init telemetry exactly once: mark_init_telemetry_emitted() returns
    // whether it had already been emitted before this call.
    if !state.runtime.mark_init_telemetry_emitted() {
        let now = Utc::now();
        let init_started_at = state.runtime.init_started_at();
        let init_duration_ms = (now - init_started_at).num_milliseconds() as f64;
        let init_runtime_done = PlatformInitRuntimeDone {
            initialization_type: InitializationType::OnDemand,
            phase: Phase::Init,
            status: RuntimeStatus::Success,
            spans: None,
            tracing: None,
        };
        let init_runtime_done_event = TelemetryEvent {
            time: now,
            event_type: "platform.initRuntimeDone".to_string(),
            record: serde_json::json!(init_runtime_done),
        };
        state
            .telemetry
            .broadcast_event(init_runtime_done_event, TelemetryEventType::Platform)
            .await;
        let init_report = PlatformInitReport {
            initialization_type: InitializationType::OnDemand,
            phase: Phase::Init,
            status: RuntimeStatus::Success,
            metrics: InitReportMetrics {
                duration_ms: init_duration_ms,
            },
            spans: None,
            tracing: None,
        };
        let init_report_event = TelemetryEvent {
            time: now,
            event_type: "platform.initReport".to_string(),
            record: serde_json::json!(init_report),
        };
        state
            .telemetry
            .broadcast_event(init_report_event, TelemetryEventType::Platform)
            .await;
        tracing::info!(target: "lambda_lifecycle", "📋 platform.initRuntimeDone (duration: {:.1}ms)", init_duration_ms);
        tracing::info!(target: "lambda_lifecycle", "📋 platform.initReport");
    }
    let invocation = state.runtime.next_invocation().await;
    // Fix: `&id[..8]` panics when the id is shorter than 8 bytes or the cut
    // lands inside a multi-byte character; use a checked prefix instead.
    // (Ids are normally UUIDs, but this must not be a panic path.)
    let short_id = invocation
        .aws_request_id
        .get(..8)
        .unwrap_or(&invocation.aws_request_id);
    tracing::info!(target: "lambda_lifecycle", "📥 Runtime received invocation (request_id: {})", short_id);
    let trace_context = TraceContext {
        trace_type: "X-Amzn-Trace-Id".to_string(),
        value: invocation.trace_id.clone(),
        span_id: None,
    };
    let platform_start = PlatformStart {
        request_id: invocation.aws_request_id.clone(),
        version: Some(state.config.function_version.clone()),
        tracing: Some(trace_context),
    };
    let platform_start_event = TelemetryEvent {
        time: Utc::now(),
        event_type: "platform.start".to_string(),
        record: serde_json::json!(platform_start),
    };
    state
        .telemetry
        .broadcast_event(platform_start_event, TelemetryEventType::Platform)
        .await;
    // Attach the standard Lambda runtime headers; any value that cannot be
    // encoded as an HTTP header short-circuits with a 500.
    let mut headers = HeaderMap::new();
    if let Err(e) = safe_header_insert(
        &mut headers,
        "Lambda-Runtime-Aws-Request-Id",
        &invocation.aws_request_id,
    ) {
        return e;
    }
    if let Err(e) = safe_header_insert(
        &mut headers,
        "Lambda-Runtime-Deadline-Ms",
        invocation.deadline_ms().to_string(),
    ) {
        return e;
    }
    if let Err(e) = safe_header_insert(
        &mut headers,
        "Lambda-Runtime-Invoked-Function-Arn",
        &invocation.invoked_function_arn,
    ) {
        return e;
    }
    if let Err(e) = safe_header_insert(
        &mut headers,
        "Lambda-Runtime-Trace-Id",
        &invocation.trace_id,
    ) {
        return e;
    }
    // Optional headers, only present when the invocation carries them.
    if let Some(client_context) = &invocation.client_context
        && let Err(e) = safe_header_insert(
            &mut headers,
            "Lambda-Runtime-Client-Context",
            client_context,
        )
    {
        return e;
    }
    if let Some(cognito_identity) = &invocation.cognito_identity
        && let Err(e) = safe_header_insert(
            &mut headers,
            "Lambda-Runtime-Cognito-Identity",
            cognito_identity,
        )
    {
        return e;
    }
    let body_str = match serde_json::to_string(&invocation.payload) {
        Ok(s) => s,
        Err(e) => {
            tracing::error!("Failed to serialize invocation payload: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to serialize invocation payload",
            )
                .into_response();
        }
    };
    (StatusCode::OK, headers, body_str).into_response()
}
/// POST /2018-06-01/runtime/invocation/{request_id}/response — accepts the
/// handler's success payload.
///
/// Order matters here: validate size, look up the invocation, parse JSON,
/// record the result, then emit `platform.runtimeDone` and schedule the
/// `platform.report` task. Returns 202 Accepted on success.
async fn invocation_response(
    State(state): State<RuntimeApiState>,
    Path(request_id): Path<String>,
    body: String,
) -> Response {
    // Enforce Lambda's 6 MB response cap before any other processing.
    if body.len() > MAX_RESPONSE_PAYLOAD_BYTES {
        return (
            StatusCode::PAYLOAD_TOO_LARGE,
            format!(
                "Response payload size ({} bytes) exceeds Lambda's 6 MB limit",
                body.len()
            ),
        )
            .into_response();
    }
    // Snapshot invocation state up front; it supplies started_at and trace_id
    // for the telemetry emitted after the response is recorded.
    let inv_state = match state.runtime.get_invocation_state(&request_id).await {
        Some(s) => s,
        None => {
            return (
                StatusCode::NOT_FOUND,
                format!("Unknown request ID: {}", request_id),
            )
                .into_response();
        }
    };
    let payload: Value = match serde_json::from_str(&body) {
        Ok(p) => p,
        Err(e) => {
            return (
                StatusCode::BAD_REQUEST,
                format!("Invalid JSON payload: {}", e),
            )
                .into_response();
        }
    };
    let received_at = Utc::now();
    let response = InvocationResponse {
        request_id: request_id.clone(),
        payload,
        received_at,
    };
    // Record first; duplicate submissions are rejected before any telemetry
    // is broadcast for this completion.
    match state.runtime.record_response(response).await {
        RecordResult::Recorded => {}
        RecordResult::AlreadyCompleted => {
            return (
                StatusCode::BAD_REQUEST,
                "Response already submitted for this invocation",
            )
                .into_response();
        }
        RecordResult::NotFound => {
            return (StatusCode::NOT_FOUND, "Unknown request ID").into_response();
        }
    }
    {
        // started_at is None when the runtime never picked the invocation up;
        // report a zero duration rather than a bogus one.
        let duration_ms = if let Some(started_at) = inv_state.started_at {
            (received_at - started_at).num_milliseconds() as f64
        } else {
            0.0
        };
        let trace_context = TraceContext {
            trace_type: "X-Amzn-Trace-Id".to_string(),
            value: inv_state.invocation.trace_id.clone(),
            span_id: None,
        };
        let runtime_done = PlatformRuntimeDone {
            request_id: request_id.clone(),
            status: RuntimeStatus::Success,
            metrics: Some(RuntimeDoneMetrics {
                duration_ms,
                produced_bytes: None,
            }),
            spans: None,
            tracing: Some(trace_context.clone()),
        };
        let runtime_done_event = TelemetryEvent {
            time: Utc::now(),
            event_type: "platform.runtimeDone".to_string(),
            record: json!(runtime_done),
        };
        state
            .telemetry
            .broadcast_event(runtime_done_event, TelemetryEventType::Platform)
            .await;
        tracing::info!(target: "lambda_lifecycle", "✅ platform.runtimeDone (status: success, duration: {:.1}ms)", duration_ms);
        // Mark the runtime side done before spawning the report task, which
        // waits on the extensions' side of the readiness tracker.
        state.readiness.mark_runtime_done(&request_id).await;
        spawn_report_task(
            state.clone(),
            request_id.clone(),
            inv_state.invocation.created_at,
            received_at,
            RuntimeStatus::Success,
            trace_context,
        );
    }
    StatusCode::ACCEPTED.into_response()
}
/// Spawns a background task that waits for all extensions to finish (or a
/// configured timeout), emits `platform.report`, and then freezes the
/// simulated environment.
///
/// The freeze epoch is captured *before* spawning; `freeze_at_epoch` later
/// refuses to freeze if the epoch has moved on — presumably because a newer
/// invocation thawed the environment in the meantime (TODO confirm against
/// `FreezeState`).
fn spawn_report_task(
    state: RuntimeApiState,
    request_id: String,
    invocation_created_at: chrono::DateTime<Utc>,
    runtime_done_at: chrono::DateTime<Utc>,
    status: RuntimeStatus,
    trace_context: TraceContext,
) {
    let timeout_ms = state.config.extension_ready_timeout_ms;
    let freeze_epoch = state.freeze.current_epoch();
    tokio::spawn(async move {
        let timeout = std::time::Duration::from_millis(timeout_ms);
        // Race extension readiness against the configured timeout; either way
        // the report is emitted.
        tokio::select! {
            _ = state.readiness.wait_for_all_ready(&request_id) => {
                tracing::debug!("All extensions ready for {}", request_id);
            }
            _ = tokio::time::sleep(timeout) => {
                tracing::warn!(
                    "Extension readiness timeout for {}; proceeding with report",
                    request_id
                );
            }
        }
        let extensions_ready_at = Utc::now();
        // Prefer the tracker's measured overhead; fall back to wall-clock
        // elapsed since runtimeDone when no measurement exists.
        let extension_overhead_ms = state
            .readiness
            .get_extension_overhead_ms(&request_id)
            .await
            .unwrap_or_else(|| (extensions_ready_at - runtime_done_at).num_milliseconds() as f64);
        let total_duration_ms =
            (extensions_ready_at - invocation_created_at).num_milliseconds() as f64;
        // Billed duration is whole milliseconds, rounded up.
        let billed_duration_ms = total_duration_ms.ceil() as u64;
        let report = PlatformReport {
            request_id: request_id.clone(),
            status,
            metrics: ReportMetrics {
                duration_ms: total_duration_ms,
                billed_duration_ms,
                memory_size_mb: state.config.memory_size_mb as u64,
                // Simulated value: reported as half of configured memory.
                max_memory_used_mb: (state.config.memory_size_mb / 2) as u64,
                init_duration_ms: None,
                restore_duration_ms: None,
                billed_restore_duration_ms: None,
            },
            spans: None,
            tracing: Some(trace_context),
        };
        // Only mention extension overhead in the log when it is measurable.
        if extension_overhead_ms >= 1.0 {
            tracing::info!(
                target: "lambda_lifecycle",
                "📊 platform.report (billed: {}ms, extension overhead: {:.0}ms)",
                billed_duration_ms,
                extension_overhead_ms
            );
        } else {
            tracing::info!(
                target: "lambda_lifecycle",
                "📊 platform.report (billed: {}ms)",
                billed_duration_ms
            );
        }
        let report_event = TelemetryEvent {
            time: Utc::now(),
            event_type: "platform.report".to_string(),
            record: json!(report),
        };
        state
            .telemetry
            .broadcast_event(report_event, TelemetryEventType::Platform)
            .await;
        // Per-invocation readiness bookkeeping is no longer needed.
        state.readiness.cleanup_invocation(&request_id).await;
        match state.freeze.freeze_at_epoch(freeze_epoch) {
            Ok(true) => {
                tracing::info!(target: "lambda_lifecycle", "🧊 Environment frozen (SIGSTOP)");
            }
            Ok(false) => {
                // Epoch advanced since this task was scheduled — freeze
                // deliberately skipped.
            }
            Err(e) => {
                tracing::error!(
                    "Failed to freeze processes after invocation: {}. \
                    Freeze simulation may be inaccurate.",
                    e
                );
            }
        }
    });
}
async fn invocation_error(
State(state): State<RuntimeApiState>,
Path(request_id): Path<String>,
body: String,
) -> Response {
let error_payload: Value = match serde_json::from_str(&body) {
Ok(v) => v,
Err(e) => {
return (StatusCode::BAD_REQUEST, format!("Invalid JSON: {}", e)).into_response();
}
};
let inv_state = match state.runtime.get_invocation_state(&request_id).await {
Some(s) => s,
None => {
return (
StatusCode::NOT_FOUND,
format!("Unknown request ID: {}", request_id),
)
.into_response();
}
};
let error_type = error_payload
.get("errorType")
.and_then(|v| v.as_str())
.unwrap_or("UnknownError")
.to_string();
let error_message = error_payload
.get("errorMessage")
.and_then(|v| v.as_str())
.unwrap_or("Unknown error")
.to_string();
let stack_trace = error_payload
.get("stackTrace")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
});
let received_at = Utc::now();
let error = InvocationError {
request_id: request_id.clone(),
error_type: error_type.clone(),
error_message,
stack_trace,
received_at,
};
match state.runtime.record_error(error).await {
RecordResult::Recorded => {}
RecordResult::AlreadyCompleted => {
return (
StatusCode::BAD_REQUEST,
"Response already submitted for this invocation",
)
.into_response();
}
RecordResult::NotFound => {
return (StatusCode::NOT_FOUND, "Unknown request ID").into_response();
}
}
{
let duration_ms = if let Some(started_at) = inv_state.started_at {
(received_at - started_at).num_milliseconds() as f64
} else {
0.0
};
let trace_context = TraceContext {
trace_type: "X-Amzn-Trace-Id".to_string(),
value: inv_state.invocation.trace_id.clone(),
span_id: None,
};
let runtime_done = PlatformRuntimeDone {
request_id: request_id.clone(),
status: RuntimeStatus::Error,
metrics: Some(RuntimeDoneMetrics {
duration_ms,
produced_bytes: None,
}),
spans: None,
tracing: Some(trace_context.clone()),
};
let runtime_done_event = TelemetryEvent {
time: Utc::now(),
event_type: "platform.runtimeDone".to_string(),
record: json!(runtime_done),
};
state
.telemetry
.broadcast_event(runtime_done_event, TelemetryEventType::Platform)
.await;
tracing::info!(target: "lambda_lifecycle", "❌ platform.runtimeDone (status: error, type: {})", error_type);
state.readiness.mark_runtime_done(&request_id).await;
spawn_report_task(
state.clone(),
request_id.clone(),
inv_state.invocation.created_at,
received_at,
RuntimeStatus::Error,
trace_context,
);
}
StatusCode::ACCEPTED.into_response()
}
async fn init_error(
State(state): State<RuntimeApiState>,
Json(error_payload): Json<Value>,
) -> Response {
let error_type = error_payload
.get("errorType")
.and_then(|v| v.as_str())
.unwrap_or("UnknownError");
let error_message = error_payload
.get("errorMessage")
.and_then(|v| v.as_str())
.unwrap_or("Unknown error");
let error_string = format!("{}: {}", error_type, error_message);
state.runtime.record_init_error(error_string).await;
StatusCode::OK.into_response()
}