use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use tokio::sync::{Mutex, Notify};
use uuid::Uuid;
/// Identifier assigned to an extension when it registers (a v4 UUID rendered as a string).
pub type ExtensionId = String;

/// Lifecycle event types an extension can subscribe to.
///
/// On the wire each variant is spelled in upper case: `"INVOKE"` / `"SHUTDOWN"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum EventType {
    #[serde(rename = "INVOKE")]
    Invoke,
    #[serde(rename = "SHUTDOWN")]
    Shutdown,
}
/// Reason carried by a `SHUTDOWN` lifecycle event.
///
/// Serializes to the lowercase variant name (`"spindown"`, `"timeout"`, `"failure"`).
/// `Deserialize` is implemented by hand so that input of any casing is accepted.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ShutdownReason {
    Spindown,
    Timeout,
    Failure,
}
impl<'de> Deserialize<'de> for ShutdownReason {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
match s.to_lowercase().as_str() {
"spindown" => Ok(ShutdownReason::Spindown),
"timeout" => Ok(ShutdownReason::Timeout),
"failure" => Ok(ShutdownReason::Failure),
_ => Err(serde::de::Error::unknown_variant(
&s,
&["spindown", "timeout", "failure"],
)),
}
}
}
/// A lifecycle event queued for extensions.
///
/// Internally tagged: the variant is selected by an `eventType` field (`"INVOKE"` or
/// `"SHUTDOWN"`). Each variant's payload fields use camelCase keys on the wire
/// (e.g. `deadline_ms` <-> `deadlineMs`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "eventType")]
pub enum LifecycleEvent {
    /// Invocation event (`"eventType": "INVOKE"`).
    #[serde(rename = "INVOKE", rename_all = "camelCase")]
    Invoke {
        /// NOTE(review): presumably a Unix-epoch deadline in milliseconds — confirm.
        deadline_ms: i64,
        request_id: String,
        invoked_function_arn: String,
        tracing: TracingInfo,
    },
    /// Shutdown event (`"eventType": "SHUTDOWN"`).
    #[serde(rename = "SHUTDOWN", rename_all = "camelCase")]
    Shutdown {
        shutdown_reason: ShutdownReason,
        /// NOTE(review): presumably a Unix-epoch deadline in milliseconds — confirm.
        deadline_ms: i64,
    },
}
/// Tracing data carried by an `INVOKE` lifecycle event.
///
/// Serialized as `{"type": ..., "value": ...}`. The Rust field is named `trace_type` because
/// `type` is a Rust keyword.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracingInfo {
    /// Kind of trace information (JSON key `type`).
    #[serde(rename = "type")]
    pub trace_type: String,
    /// Trace value; opaque to this module.
    pub value: String,
}
/// Deserialized registration request: the lifecycle event types the extension wants to receive.
#[derive(Debug, Clone, Deserialize)]
pub struct RegisterRequest {
    /// Event types to subscribe to (wire names `"INVOKE"` / `"SHUTDOWN"`).
    pub events: Vec<EventType>,
}
/// Record of an extension that has registered.
#[derive(Debug, Clone)]
pub struct RegisteredExtension {
    /// Id generated at construction (v4 UUID string).
    pub id: ExtensionId,
    /// Name given at registration.
    pub name: String,
    /// Event types this extension is subscribed to.
    pub events: Vec<EventType>,
    /// UTC timestamp taken when the record was created.
    pub registered_at: DateTime<Utc>,
}
impl RegisteredExtension {
    /// Creates a record with a fresh random (v4 UUID) id, stamped with the current UTC time.
    pub fn new(name: String, events: Vec<EventType>) -> Self {
        let id = Uuid::new_v4().to_string();
        let registered_at = Utc::now();
        Self {
            id,
            name,
            events,
            registered_at,
        }
    }

    /// Returns `true` if `event_type` is among this extension's subscribed event types.
    pub fn is_subscribed_to(&self, event_type: &EventType) -> bool {
        self.events.iter().any(|subscribed| subscribed == event_type)
    }
}
/// Shared registry of extensions, their pending lifecycle events, and shutdown acknowledgements.
///
/// Every map is behind its own async `Mutex`, so the state can be shared across tasks.
#[derive(Debug)]
pub struct ExtensionState {
    /// Registered extensions keyed by their id.
    extensions: Mutex<HashMap<ExtensionId, RegisteredExtension>>,
    /// Pending (not yet consumed) events per extension, oldest first.
    event_queues: Mutex<HashMap<ExtensionId, VecDeque<LifecycleEvent>>>,
    /// Per-extension notifier used to wake a task waiting for that extension's next event.
    event_notifiers: Mutex<HashMap<ExtensionId, std::sync::Arc<Notify>>>,
    /// Ids of extensions that have acknowledged shutdown.
    shutdown_acknowledged: Mutex<std::collections::HashSet<ExtensionId>>,
    /// Signalled every time a shutdown acknowledgement is recorded.
    shutdown_notify: Notify,
}
impl ExtensionState {
    /// Creates an empty state with no registered extensions, no queued events and no
    /// shutdown acknowledgements.
    pub fn new() -> Self {
        Self {
            extensions: Mutex::new(HashMap::new()),
            event_queues: Mutex::new(HashMap::new()),
            event_notifiers: Mutex::new(HashMap::new()),
            shutdown_acknowledged: Mutex::new(std::collections::HashSet::new()),
            shutdown_notify: Notify::new(),
        }
    }

    /// Registers a new extension under a freshly generated id and returns a copy of its record.
    ///
    /// The extension's event queue and notifier are installed *before* the extension is
    /// inserted into `extensions`. `broadcast_event` only delivers to entries found in
    /// `extensions`, so with this order it never sees a subscribed extension that lacks a
    /// queue. Such an event would otherwise be silently dropped for that extension.
    pub async fn register(&self, name: String, events: Vec<EventType>) -> RegisteredExtension {
        let extension = RegisteredExtension::new(name, events);
        let id = extension.id.clone();
        self.event_queues
            .lock()
            .await
            .insert(id.clone(), VecDeque::new());
        self.event_notifiers
            .lock()
            .await
            .insert(id.clone(), std::sync::Arc::new(Notify::new()));
        self.extensions
            .lock()
            .await
            .insert(id, extension.clone());
        extension
    }

    /// Appends `event` to the queue of every registered extension subscribed to its type.
    /// It then wakes each such extension's notifier.
    ///
    /// The extension, queue and notifier maps stay locked together for the whole fan-out.
    pub async fn broadcast_event(&self, event: LifecycleEvent) {
        let event_type = match &event {
            LifecycleEvent::Invoke { .. } => EventType::Invoke,
            LifecycleEvent::Shutdown { .. } => EventType::Shutdown,
        };
        let extensions = self.extensions.lock().await;
        let mut queues = self.event_queues.lock().await;
        let notifiers = self.event_notifiers.lock().await;
        for (id, ext) in extensions.iter() {
            if !ext.is_subscribed_to(&event_type) {
                continue;
            }
            if let Some(queue) = queues.get_mut(id) {
                queue.push_back(event.clone());
            }
            if let Some(notifier) = notifiers.get(id) {
                notifier.notify_one();
            }
        }
    }

    /// Removes and returns the oldest queued event for `extension_id`, waiting until one
    /// is available.
    ///
    /// Returns `None` if `extension_id` has no queue or no notifier, for example an id that
    /// never went through [`register`](Self::register).
    pub async fn next_event(&self, extension_id: &str) -> Option<LifecycleEvent> {
        loop {
            {
                let mut queues = self.event_queues.lock().await;
                if let Some(event) = queues.get_mut(extension_id)?.pop_front() {
                    return Some(event);
                }
            }
            // Clone the Arc so the notifier map is not locked while we wait. `notify_one`
            // stores a permit when no task is waiting. An event broadcast after the queue check
            // above but before this wait therefore still completes `notified()` immediately.
            let notifier = {
                let notifiers = self.event_notifiers.lock().await;
                std::sync::Arc::clone(notifiers.get(extension_id)?)
            };
            notifier.notified().await;
        }
    }

    /// Returns a copy of the record for `extension_id`, if registered.
    pub async fn get_extension(&self, extension_id: &str) -> Option<RegisteredExtension> {
        self.extensions.lock().await.get(extension_id).cloned()
    }

    /// Returns copies of all registered extension records (unspecified order).
    pub async fn get_all_extensions(&self) -> Vec<RegisteredExtension> {
        self.extensions.lock().await.values().cloned().collect()
    }

    /// Number of registered extensions.
    pub async fn extension_count(&self) -> usize {
        self.extensions.lock().await.len()
    }

    /// Ids of registered extensions subscribed to `INVOKE` events (unspecified order).
    pub async fn get_invoke_subscribers(&self) -> Vec<ExtensionId> {
        self.subscribers_of(&EventType::Invoke).await
    }

    /// Ids of registered extensions subscribed to `SHUTDOWN` events (unspecified order).
    pub async fn get_shutdown_subscribers(&self) -> Vec<ExtensionId> {
        self.subscribers_of(&EventType::Shutdown).await
    }

    /// Shared implementation of the `get_*_subscribers` queries.
    async fn subscribers_of(&self, event_type: &EventType) -> Vec<ExtensionId> {
        self.extensions
            .lock()
            .await
            .values()
            .filter(|ext| ext.is_subscribed_to(event_type))
            .map(|ext| ext.id.clone())
            .collect()
    }

    /// Signals every extension's notifier once.
    ///
    /// NOTE(review): `next_event` re-checks its queue after waking and resumes waiting if the
    /// queue is empty. This call alone therefore does not make a pending `next_event`
    /// return — confirm the intended use.
    pub async fn wake_all_extensions(&self) {
        let notifiers = self.event_notifiers.lock().await;
        for notifier in notifiers.values() {
            notifier.notify_one();
        }
    }

    /// Returns `true` if `extension_id` has no pending events, or has no queue at all.
    // May have no non-test callers; keep the query available without a dead-code warning.
    #[allow(dead_code)]
    pub async fn is_queue_empty(&self, extension_id: &str) -> bool {
        let queues = self.event_queues.lock().await;
        queues
            .get(extension_id)
            .is_none_or(|queue| queue.is_empty())
    }

    /// Records that `extension_id` acknowledged shutdown and wakes all tasks currently waiting
    /// in [`wait_for_shutdown_acknowledged`](Self::wait_for_shutdown_acknowledged).
    pub async fn mark_shutdown_acknowledged(&self, extension_id: &str) {
        self.shutdown_acknowledged
            .lock()
            .await
            .insert(extension_id.to_string());
        self.shutdown_notify.notify_waiters();
    }

    /// Returns `true` if `extension_id` has acknowledged shutdown.
    pub async fn is_shutdown_acknowledged(&self, extension_id: &str) -> bool {
        self.shutdown_acknowledged
            .lock()
            .await
            .contains(extension_id)
    }

    /// Waits until every id in `extension_ids` has acknowledged shutdown.
    ///
    /// Returns immediately for an empty slice. There is no built-in timeout; callers that need
    /// one must wrap this future themselves.
    pub async fn wait_for_shutdown_acknowledged(&self, extension_ids: &[String]) {
        loop {
            // Create the `Notified` future BEFORE checking the set. `notify_waiters()` stores
            // no permit, but a `Notified` receives it as soon as it exists, even unpolled.
            // Creating it after the check would lose an acknowledgement that lands between
            // the check and the wait, and could block forever.
            let notified = self.shutdown_notify.notified();
            {
                let acknowledged = self.shutdown_acknowledged.lock().await;
                if extension_ids.iter().all(|id| acknowledged.contains(id)) {
                    return;
                }
            }
            notified.await;
        }
    }

    /// Forgets all recorded shutdown acknowledgements.
    pub async fn clear_shutdown_acknowledged(&self) {
        self.shutdown_acknowledged.lock().await.clear();
    }
}
impl Default for ExtensionState {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant serializes to its lowercase wire name.
    #[test]
    fn test_shutdown_reason_serializes_lowercase() {
        let cases = [
            (ShutdownReason::Spindown, "\"spindown\""),
            (ShutdownReason::Timeout, "\"timeout\""),
            (ShutdownReason::Failure, "\"failure\""),
        ];
        for (reason, expected) in cases {
            let json = serde_json::to_string(&reason).unwrap();
            assert_eq!(json, expected);
        }
    }

    /// Any casing of a known reason deserializes to the matching variant.
    #[test]
    fn test_shutdown_reason_deserializes_case_insensitive() {
        let cases = [
            ("\"spindown\"", ShutdownReason::Spindown),
            ("\"SPINDOWN\"", ShutdownReason::Spindown),
            ("\"Spindown\"", ShutdownReason::Spindown),
            ("\"SpInDoWn\"", ShutdownReason::Spindown),
            ("\"timeout\"", ShutdownReason::Timeout),
            ("\"TIMEOUT\"", ShutdownReason::Timeout),
            ("\"failure\"", ShutdownReason::Failure),
            ("\"FAILURE\"", ShutdownReason::Failure),
        ];
        for (input, expected) in cases {
            let parsed = serde_json::from_str::<ShutdownReason>(input).unwrap();
            assert_eq!(parsed, expected, "input: {input}");
        }
    }

    /// Unknown reasons are rejected.
    #[test]
    fn test_shutdown_reason_deserialize_invalid() {
        let outcome = serde_json::from_str::<ShutdownReason>("\"invalid\"");
        assert!(outcome.is_err());
    }
}