use axum::{
Router, body::Bytes, extract::State, http::StatusCode, response::IntoResponse, routing::post,
};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
/// Category of telemetry stream named in a subscription's `types` list.
///
/// Each variant is (de)serialized as its lowercase name: `"platform"`,
/// `"function"`, or `"extension"`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TelemetryType {
    /// Wire name `"platform"`.
    #[serde(rename = "platform")]
    Platform,
    /// Wire name `"function"`.
    #[serde(rename = "function")]
    Function,
    /// Wire name `"extension"`.
    #[serde(rename = "extension")]
    Extension,
}
/// Batching limits sent as the `buffering` object of a subscription.
///
/// Every limit is optional; a `None` limit is omitted from the serialized JSON
/// entirely rather than written as `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BufferingConfig {
    /// Wire key `maxItems`.
    #[serde(rename = "maxItems", skip_serializing_if = "Option::is_none")]
    pub max_items: Option<u32>,
    /// Wire key `maxBytes`.
    #[serde(rename = "maxBytes", skip_serializing_if = "Option::is_none")]
    pub max_bytes: Option<u32>,
    /// Wire key `timeoutMs` (milliseconds, per the key name).
    #[serde(rename = "timeoutMs", skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u32>,
}
impl Default for BufferingConfig {
    /// Sets all three limits: 1 000 items, 256 KiB (262 144 bytes), and 25 ms.
    ///
    /// NOTE(review): these look like the minimum values the Telemetry API
    /// accepts for each limit; confirm against the AWS documentation.
    fn default() -> Self {
        const KIB: u32 = 1024;
        Self {
            timeout_ms: Some(25),
            max_bytes: Some(256 * KIB),
            max_items: Some(1_000),
        }
    }
}
/// Delivery endpoint of a telemetry subscription.
///
/// Serialized as `{"protocol": ..., "URI": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DestinationConfig {
    /// Transport name; `TelemetrySubscription`'s constructors set `"HTTP"`.
    pub protocol: String,
    /// Listener address. The wire key is the all-caps `URI`.
    #[serde(rename = "URI")]
    pub uri: String,
}
/// Request body for a Lambda Telemetry API subscription.
///
/// Serializes to an object with the keys `schemaVersion`, `types`,
/// `buffering`, and `destination`.
#[derive(Debug, Clone, Serialize)]
pub struct TelemetrySubscription {
    /// Schema version string, wire key `schemaVersion`.
    #[serde(rename = "schemaVersion")]
    pub schema_version: String,
    /// Telemetry streams requested.
    pub types: Vec<TelemetryType>,
    /// Batching limits.
    pub buffering: BufferingConfig,
    /// Endpoint the subscription points at.
    pub destination: DestinationConfig,
}
impl TelemetrySubscription {
pub fn platform_events(listener_uri: impl Into<String>) -> Self {
Self {
schema_version: "2022-12-13".to_string(),
types: vec![TelemetryType::Platform],
buffering: BufferingConfig::default(),
destination: DestinationConfig {
protocol: "HTTP".to_string(),
uri: listener_uri.into(),
},
}
}
pub fn all_events(listener_uri: impl Into<String>) -> Self {
Self {
schema_version: "2022-12-13".to_string(),
types: vec![
TelemetryType::Platform,
TelemetryType::Function,
TelemetryType::Extension,
],
buffering: BufferingConfig::default(),
destination: DestinationConfig {
protocol: "HTTP".to_string(),
uri: listener_uri.into(),
},
}
}
pub fn with_buffering(mut self, config: BufferingConfig) -> Self {
self.buffering = config;
self
}
}
/// One element of a batch delivered by the Lambda Telemetry API.
///
/// The enum is internally tagged: the JSON object's `"type"` field selects the
/// variant via the explicit `rename` on each one, and the remaining `time` and
/// `record` fields fill the variant. A `"type"` value that matches none of the
/// tags below fails to deserialize, because there is no catch-all variant.
///
/// `time` is kept as the raw string from the payload.
///
/// NOTE(review): `Function` and `ExtensionLog` hold `record` as a `String`. If
/// a function is configured for JSON-formatted logs, the record may arrive as a
/// JSON object instead and will not deserialize. Confirm.
#[non_exhaustive]
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum TelemetryEvent {
    /// Tag `platform.initStart`.
    #[serde(rename = "platform.initStart")]
    InitStart {
        time: String,
        record: InitStartRecord,
    },
    /// Tag `platform.initRuntimeDone`.
    #[serde(rename = "platform.initRuntimeDone")]
    InitRuntimeDone {
        time: String,
        record: InitRuntimeDoneRecord,
    },
    /// Tag `platform.start`.
    #[serde(rename = "platform.start")]
    Start {
        time: String,
        record: StartRecord,
    },
    /// Tag `platform.runtimeDone`.
    #[serde(rename = "platform.runtimeDone")]
    RuntimeDone {
        time: String,
        record: RuntimeDoneRecord,
    },
    /// Tag `platform.report`.
    #[serde(rename = "platform.report")]
    Report {
        time: String,
        record: ReportRecord,
    },
    /// Tag `platform.fault`.
    #[serde(rename = "platform.fault")]
    Fault {
        time: String,
        record: FaultRecord,
    },
    /// Tag `platform.extension` (structured extension record).
    #[serde(rename = "platform.extension")]
    Extension {
        time: String,
        record: ExtensionRecord,
    },
    /// Tag `function`; `record` is the log line as a string.
    #[serde(rename = "function")]
    Function {
        time: String,
        record: String,
    },
    /// Tag `extension`; `record` is the log line as a string.
    #[serde(rename = "extension")]
    ExtensionLog {
        time: String,
        record: String,
    },
}
/// `record` payload of a `platform.initStart` event.
///
/// Only `initializationType` is required. A missing `phase` becomes an empty
/// string, and missing runtime-version fields become `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct InitStartRecord {
    /// Wire key `initializationType`.
    #[serde(rename = "initializationType")]
    pub initialization_type: String,
    /// Wire key `phase`; defaults to `""`.
    #[serde(default)]
    pub phase: String,
    /// Wire key `runtimeVersion`.
    #[serde(rename = "runtimeVersion", default)]
    pub runtime_version: Option<String>,
    /// Wire key `runtimeVersionArn`.
    #[serde(rename = "runtimeVersionArn", default)]
    pub runtime_version_arn: Option<String>,
}
/// `record` payload of a `platform.initRuntimeDone` event.
///
/// `initializationType` is required. `status` and `phase` fall back to empty
/// strings when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct InitRuntimeDoneRecord {
    /// Wire key `initializationType`.
    #[serde(rename = "initializationType")]
    pub initialization_type: String,
    /// Wire key `status`; defaults to `""`.
    #[serde(default)]
    pub status: String,
    /// Wire key `phase`; defaults to `""`.
    #[serde(default)]
    pub phase: String,
}
/// `record` payload of a `platform.start` event.
#[derive(Debug, Clone, Deserialize)]
pub struct StartRecord {
    /// Wire key `requestId` (required).
    #[serde(rename = "requestId")]
    pub request_id: String,
    /// Wire key `version`, if present.
    #[serde(default)]
    pub version: Option<String>,
    /// Wire key `tracing`, if present.
    #[serde(default)]
    pub tracing: Option<TracingRecord>,
}
/// Tracing object attached to several platform records. All fields are optional.
#[derive(Debug, Clone, Deserialize)]
pub struct TracingRecord {
    /// Wire key `spanId`.
    #[serde(rename = "spanId", default)]
    pub span_id: Option<String>,
    /// Wire key `type` (renamed because `type` is a Rust keyword).
    #[serde(rename = "type", default)]
    pub trace_type: Option<String>,
    /// Wire key `value`.
    #[serde(default)]
    pub value: Option<String>,
}
/// `record` payload of a `platform.runtimeDone` event.
///
/// `requestId` and `status` are required. `metrics` and `tracing` are
/// optional, and `spans` defaults to an empty list.
#[derive(Debug, Clone, Deserialize)]
pub struct RuntimeDoneRecord {
    /// Wire key `requestId`.
    #[serde(rename = "requestId")]
    pub request_id: String,
    /// Wire key `status`.
    pub status: String,
    /// Wire key `metrics`.
    #[serde(default)]
    pub metrics: Option<RuntimeMetrics>,
    /// Wire key `tracing`.
    #[serde(default)]
    pub tracing: Option<TracingRecord>,
    /// Wire key `spans`; empty when absent.
    #[serde(default)]
    pub spans: Vec<SpanRecord>,
}
/// `metrics` object of a `platform.runtimeDone` record.
#[derive(Debug, Clone, Deserialize)]
pub struct RuntimeMetrics {
    /// Wire key `durationMs` (required).
    #[serde(rename = "durationMs")]
    pub duration_ms: f64,
    /// Wire key `producedBytes`, if present.
    #[serde(rename = "producedBytes", default)]
    pub produced_bytes: Option<u64>,
}
/// One entry of a `platform.runtimeDone` record's `spans` list.
///
/// NOTE(review): the AWS Telemetry API schema reference documents `start` as a
/// timestamp string (e.g. `"2022-08-02T12:01:23.521Z"`), not a number. If the
/// platform sends that form, any record carrying spans fails to deserialize,
/// because `start` is an `f64` here. Switching to `String` (or a lenient
/// deserializer) would fix it, but it changes this public field's type.
/// Confirm before changing.
#[derive(Debug, Clone, Deserialize)]
pub struct SpanRecord {
    /// Wire key `name`.
    pub name: String,
    /// Wire key `start`, currently parsed as a JSON number (see the note above).
    pub start: f64,
    /// Wire key `durationMs`.
    #[serde(rename = "durationMs")]
    pub duration_ms: f64,
}
/// `record` payload of a `platform.report` event.
#[derive(Debug, Clone, Deserialize)]
pub struct ReportRecord {
    /// Wire key `requestId` (required).
    #[serde(rename = "requestId")]
    pub request_id: String,
    /// Wire key `status` (required).
    pub status: String,
    /// Wire key `metrics` (required).
    pub metrics: ReportMetrics,
    /// Wire key `tracing`, if present.
    #[serde(default)]
    pub tracing: Option<TracingRecord>,
}
/// `metrics` object of a `platform.report` record.
///
/// The four core metrics are required. Init and restore durations are
/// optional. The memory keys use an upper-case `MB` suffix on the wire.
#[derive(Debug, Clone, Deserialize)]
pub struct ReportMetrics {
    /// Wire key `durationMs`.
    #[serde(rename = "durationMs")]
    pub duration_ms: f64,
    /// Wire key `billedDurationMs`.
    #[serde(rename = "billedDurationMs")]
    pub billed_duration_ms: u64,
    /// Wire key `memorySizeMB`.
    #[serde(rename = "memorySizeMB")]
    pub memory_size_mb: u64,
    /// Wire key `maxMemoryUsedMB`.
    #[serde(rename = "maxMemoryUsedMB")]
    pub max_memory_used_mb: u64,
    /// Wire key `initDurationMs`, if present.
    #[serde(rename = "initDurationMs", default)]
    pub init_duration_ms: Option<f64>,
    /// Wire key `restoreDurationMs`, if present.
    #[serde(rename = "restoreDurationMs", default)]
    pub restore_duration_ms: Option<f64>,
}
/// `record` payload of a `platform.fault` event. Both fields are optional.
#[derive(Debug, Clone, Deserialize)]
pub struct FaultRecord {
    /// Wire key `requestId`.
    #[serde(rename = "requestId", default)]
    pub request_id: Option<String>,
    /// Wire key `faultMessage`.
    #[serde(rename = "faultMessage", default)]
    pub fault_message: Option<String>,
}
/// `record` payload of a `platform.extension` event.
///
/// `name` and `state` are required; `events` defaults to an empty list.
#[derive(Debug, Clone, Deserialize)]
pub struct ExtensionRecord {
    /// Wire key `name`.
    pub name: String,
    /// Wire key `state`.
    pub state: String,
    /// Wire key `events`; empty when absent.
    #[serde(default)]
    pub events: Vec<String>,
}
/// Error type for this telemetry module.
#[non_exhaustive]
#[derive(Debug)]
pub enum TelemetryError {
    /// A parse failure. The contained text is appended verbatim after the
    /// `"Parse error: "` prefix in the `Display` output.
    Parse(String),
}

impl std::fmt::Display for TelemetryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::Parse(msg) => ("Parse error: ", msg),
        };
        f.write_str(prefix)?;
        f.write_str(detail)
    }
}

/// No underlying `source()`; the error carries only its message.
impl std::error::Error for TelemetryError {}
/// HTTP endpoint for Lambda Telemetry API deliveries.
///
/// Decoded event batches are forwarded over a Tokio `mpsc` channel.
pub struct TelemetryListener {
    /// TCP port that `run` binds on all IPv4 interfaces.
    port: u16,
    /// Sender that receives every decoded batch.
    event_tx: mpsc::Sender<Vec<TelemetryEvent>>,
    /// Cancelling this token triggers graceful shutdown of the server in `run`.
    cancel_token: CancellationToken,
}
impl TelemetryListener {
    /// Stores the configuration. Nothing is bound until `run` is awaited.
    pub fn new(
        port: u16,
        event_tx: mpsc::Sender<Vec<TelemetryEvent>>,
        cancel_token: CancellationToken,
    ) -> Self {
        Self {
            event_tx,
            cancel_token,
            port,
        }
    }

    /// URI to use as a subscription destination for this listener.
    ///
    /// The host depends on the `AWS_LAMBDA_FUNCTION_NAME` environment variable:
    /// - set to a valid Unicode value: `sandbox.localdomain`;
    /// - otherwise: `127.0.0.1`.
    pub fn listener_uri(&self) -> String {
        let inside_lambda = std::env::var("AWS_LAMBDA_FUNCTION_NAME").is_ok();
        let host = if inside_lambda {
            "sandbox.localdomain"
        } else {
            "127.0.0.1"
        };
        format!("http://{}:{}", host, self.port)
    }

    /// Binds `0.0.0.0:<port>`, serves `POST /` with `handle_telemetry`, and
    /// returns once the cancellation token fires and the server shuts down
    /// gracefully.
    ///
    /// # Errors
    /// Returns the I/O error from binding the socket or from the server itself.
    pub async fn run(self) -> Result<(), std::io::Error> {
        let Self {
            port,
            event_tx,
            cancel_token,
        } = self;

        let shared = Arc::new(ListenerState { event_tx });
        let router = Router::new()
            .route("/", post(handle_telemetry))
            .with_state(shared);

        let bind_addr = SocketAddr::from(([0, 0, 0, 0], port));
        let socket = TcpListener::bind(bind_addr).await?;
        tracing::info!(port, "Telemetry API listener started");

        axum::serve(socket, router)
            .with_graceful_shutdown(cancel_token.cancelled_owned())
            .await
    }
}
/// State shared with `handle_telemetry` through axum's `State` extractor.
struct ListenerState {
    /// Destination for decoded event batches.
    event_tx: mpsc::Sender<Vec<TelemetryEvent>>,
}
async fn handle_telemetry(
State(state): State<Arc<ListenerState>>,
body: Bytes,
) -> impl IntoResponse {
let events: Vec<TelemetryEvent> = match serde_json::from_slice(&body) {
Ok(events) => events,
Err(e) => {
tracing::warn!(error = %e, "Failed to parse telemetry events");
return StatusCode::BAD_REQUEST;
}
};
tracing::debug!(count = events.len(), "Received telemetry events");
match state.event_tx.try_send(events) {
Ok(()) => StatusCode::OK,
Err(mpsc::error::TrySendError::Full(_)) => {
tracing::warn!("Telemetry event channel full");
StatusCode::SERVICE_UNAVAILABLE
}
Err(mpsc::error::TrySendError::Closed(_)) => {
tracing::error!("Telemetry event channel closed");
StatusCode::INTERNAL_SERVER_ERROR
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_telemetry_subscription_platform() {
        let sub = TelemetrySubscription::platform_events("http://localhost:9999");
        assert_eq!(sub.schema_version, "2022-12-13");
        assert_eq!(sub.types, vec![TelemetryType::Platform]);
        assert_eq!(sub.destination.uri, "http://localhost:9999");
    }

    #[test]
    fn test_telemetry_subscription_all() {
        let sub = TelemetrySubscription::all_events("http://localhost:9999");
        assert_eq!(sub.types.len(), 3);
        assert!(sub.types.contains(&TelemetryType::Platform));
        assert!(sub.types.contains(&TelemetryType::Function));
        assert!(sub.types.contains(&TelemetryType::Extension));
    }

    /// Pins the exact JSON sent to the Telemetry API, including the camelCase
    /// keys and the all-caps `URI` key.
    #[test]
    fn test_subscription_serializes_to_api_shape() {
        let sub = TelemetrySubscription::platform_events("http://sandbox.localdomain:8080");
        let json = serde_json::to_value(&sub).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "schemaVersion": "2022-12-13",
                "types": ["platform"],
                "buffering": {"maxItems": 1000, "maxBytes": 262144, "timeoutMs": 25},
                "destination": {"protocol": "HTTP", "URI": "http://sandbox.localdomain:8080"}
            })
        );
    }

    /// Unset buffering limits must be omitted, not serialized as `null`.
    #[test]
    fn test_with_buffering_omits_unset_fields() {
        let sub = TelemetrySubscription::all_events("http://localhost:9999").with_buffering(
            BufferingConfig {
                max_items: None,
                max_bytes: None,
                timeout_ms: Some(100),
            },
        );
        let json = serde_json::to_value(&sub).unwrap();
        assert_eq!(json["buffering"], serde_json::json!({"timeoutMs": 100}));
        assert_eq!(
            json["types"],
            serde_json::json!(["platform", "function", "extension"])
        );
    }

    #[test]
    fn test_parse_start_event() {
        let json = r#"[{
            "type": "platform.start",
            "time": "2022-10-12T00:00:00.000Z",
            "record": {
                "requestId": "test-request-id",
                "version": "$LATEST"
            }
        }]"#;
        let events: Vec<TelemetryEvent> = serde_json::from_str(json).unwrap();
        assert_eq!(events.len(), 1);
        match &events[0] {
            TelemetryEvent::Start { record, .. } => {
                assert_eq!(record.request_id, "test-request-id");
                assert_eq!(record.version, Some("$LATEST".to_string()));
            }
            _ => panic!("Expected Start event"),
        }
    }

    /// The tracing object's `type` key maps onto `trace_type`.
    #[test]
    fn test_parse_start_event_with_tracing() {
        let json = r#"[{
            "type": "platform.start",
            "time": "2022-10-12T00:00:00.000Z",
            "record": {
                "requestId": "req-1",
                "tracing": {"spanId": "span-1", "type": "X-Amzn-Trace-Id", "value": "Root=1-abc"}
            }
        }]"#;
        let events: Vec<TelemetryEvent> = serde_json::from_str(json).unwrap();
        match &events[0] {
            TelemetryEvent::Start { record, .. } => {
                let tracing = record.tracing.as_ref().expect("tracing present");
                assert_eq!(tracing.span_id.as_deref(), Some("span-1"));
                assert_eq!(tracing.trace_type.as_deref(), Some("X-Amzn-Trace-Id"));
                assert_eq!(tracing.value.as_deref(), Some("Root=1-abc"));
                assert!(record.version.is_none());
            }
            _ => panic!("Expected Start event"),
        }
    }

    #[test]
    fn test_parse_report_event() {
        let json = r#"[{
            "type": "platform.report",
            "time": "2022-10-12T00:00:00.000Z",
            "record": {
                "requestId": "test-request-id",
                "status": "success",
                "metrics": {
                    "durationMs": 100.5,
                    "billedDurationMs": 200,
                    "memorySizeMB": 128,
                    "maxMemoryUsedMB": 64
                }
            }
        }]"#;
        let events: Vec<TelemetryEvent> = serde_json::from_str(json).unwrap();
        assert_eq!(events.len(), 1);
        match &events[0] {
            TelemetryEvent::Report { record, .. } => {
                assert_eq!(record.request_id, "test-request-id");
                assert_eq!(record.status, "success");
                assert_eq!(record.metrics.duration_ms, 100.5);
                assert_eq!(record.metrics.billed_duration_ms, 200);
            }
            _ => panic!("Expected Report event"),
        }
    }

    /// Covers the upper-case `MB` keys and the optional init/restore durations.
    #[test]
    fn test_parse_report_event_memory_and_optional_metrics() {
        let json = r#"[{
            "type": "platform.report",
            "time": "2022-10-12T00:00:00.000Z",
            "record": {
                "requestId": "req-2",
                "status": "success",
                "metrics": {
                    "durationMs": 1.0,
                    "billedDurationMs": 2,
                    "memorySizeMB": 128,
                    "maxMemoryUsedMB": 64,
                    "initDurationMs": 150.25
                }
            }
        }]"#;
        let events: Vec<TelemetryEvent> = serde_json::from_str(json).unwrap();
        match &events[0] {
            TelemetryEvent::Report { record, .. } => {
                assert_eq!(record.metrics.memory_size_mb, 128);
                assert_eq!(record.metrics.max_memory_used_mb, 64);
                assert_eq!(record.metrics.init_duration_ms, Some(150.25));
                assert_eq!(record.metrics.restore_duration_ms, None);
            }
            _ => panic!("Expected Report event"),
        }
    }

    #[test]
    fn test_parse_runtime_done_event() {
        // NOTE(review): `start` is numeric here to match `SpanRecord::start: f64`;
        // the AWS schema documents it as a timestamp string.
        let json = r#"[{
            "type": "platform.runtimeDone",
            "time": "2022-10-12T00:00:00.000Z",
            "record": {
                "requestId": "test-request-id",
                "status": "success",
                "metrics": {
                    "durationMs": 50.0
                },
                "spans": [
                    {"name": "responseLatency", "start": 0.0, "durationMs": 10.0}
                ]
            }
        }]"#;
        let events: Vec<TelemetryEvent> = serde_json::from_str(json).unwrap();
        assert_eq!(events.len(), 1);
        match &events[0] {
            TelemetryEvent::RuntimeDone { record, .. } => {
                assert_eq!(record.request_id, "test-request-id");
                assert_eq!(record.spans.len(), 1);
                assert_eq!(record.spans[0].name, "responseLatency");
            }
            _ => panic!("Expected RuntimeDone event"),
        }
    }

    #[test]
    fn test_parse_init_events() {
        let json = r#"[
            {
                "type": "platform.initStart",
                "time": "2022-10-12T00:00:00.000Z",
                "record": {
                    "initializationType": "on-demand",
                    "phase": "init"
                }
            },
            {
                "type": "platform.initRuntimeDone",
                "time": "2022-10-12T00:00:01.000Z",
                "record": {
                    "initializationType": "on-demand",
                    "status": "success",
                    "phase": "init"
                }
            }
        ]"#;
        let events: Vec<TelemetryEvent> = serde_json::from_str(json).unwrap();
        assert_eq!(events.len(), 2);
        match &events[0] {
            TelemetryEvent::InitStart { record, .. } => {
                assert_eq!(record.initialization_type, "on-demand");
            }
            _ => panic!("Expected InitStart event"),
        }
        match &events[1] {
            TelemetryEvent::InitRuntimeDone { record, .. } => {
                assert_eq!(record.status, "success");
            }
            _ => panic!("Expected InitRuntimeDone event"),
        }
    }

    #[test]
    fn test_parse_function_log() {
        let json = r#"[{
            "type": "function",
            "time": "2022-10-12T00:00:00.000Z",
            "record": "Hello from Lambda!"
        }]"#;
        let events: Vec<TelemetryEvent> = serde_json::from_str(json).unwrap();
        assert_eq!(events.len(), 1);
        match &events[0] {
            TelemetryEvent::Function { record, .. } => {
                assert_eq!(record, "Hello from Lambda!");
            }
            _ => panic!("Expected Function event"),
        }
    }

    /// `"extension"` (log line) and `"platform.extension"` (structured record)
    /// must land in different variants.
    #[test]
    fn test_parse_extension_log_and_platform_extension() {
        let json = r#"[
            {"type": "extension", "time": "2022-10-12T00:00:00.000Z", "record": "ext log line"},
            {
                "type": "platform.extension",
                "time": "2022-10-12T00:00:00.000Z",
                "record": {"name": "my-ext", "state": "Ready", "events": ["INVOKE", "SHUTDOWN"]}
            }
        ]"#;
        let events: Vec<TelemetryEvent> = serde_json::from_str(json).unwrap();
        assert_eq!(events.len(), 2);
        match &events[0] {
            TelemetryEvent::ExtensionLog { record, .. } => assert_eq!(record, "ext log line"),
            _ => panic!("Expected ExtensionLog event"),
        }
        match &events[1] {
            TelemetryEvent::Extension { record, .. } => {
                assert_eq!(record.name, "my-ext");
                assert_eq!(record.state, "Ready");
                assert_eq!(record.events, vec!["INVOKE", "SHUTDOWN"]);
            }
            _ => panic!("Expected Extension event"),
        }
    }

    #[test]
    fn test_parse_fault_event() {
        let json = r#"[{
            "type": "platform.fault",
            "time": "2022-10-12T00:00:00.000Z",
            "record": {"requestId": "req-3", "faultMessage": "boom"}
        }]"#;
        let events: Vec<TelemetryEvent> = serde_json::from_str(json).unwrap();
        match &events[0] {
            TelemetryEvent::Fault { record, .. } => {
                assert_eq!(record.request_id.as_deref(), Some("req-3"));
                assert_eq!(record.fault_message.as_deref(), Some("boom"));
            }
            _ => panic!("Expected Fault event"),
        }
    }

    #[test]
    fn test_listener_uri() {
        let (tx, _rx) = mpsc::channel(10);
        let listener = TelemetryListener::new(9999, tx, CancellationToken::new());
        assert_eq!(listener.listener_uri(), "http://127.0.0.1:9999");
    }

    #[test]
    fn test_telemetry_error_display() {
        let err = TelemetryError::Parse("parse error".to_string());
        assert!(format!("{}", err).contains("parse error"));
    }

    #[test]
    fn test_buffering_config_default() {
        let config = BufferingConfig::default();
        assert_eq!(config.max_items, Some(1000));
        assert_eq!(config.max_bytes, Some(256 * 1024));
        assert_eq!(config.timeout_ms, Some(25));
    }
}