1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use tokio::sync::{Mutex, Notify};
10use uuid::Uuid;
11
/// Identifier of a registered extension.
///
/// [`RegisteredExtension::new`] fills this with a freshly generated UUID v4
/// rendered as a string.
pub type ExtensionId = String;
14
/// Kind of lifecycle event an extension can subscribe to.
///
/// Because of `rename_all = "UPPERCASE"`, the variants are (de)serialized as
/// `"INVOKE"` and `"SHUTDOWN"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum EventType {
    /// Matches [`LifecycleEvent::Invoke`] events.
    Invoke,

    /// Matches [`LifecycleEvent::Shutdown`] events.
    Shutdown,
}
25
26#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
31pub enum ShutdownReason {
32 #[serde(rename = "spindown")]
34 Spindown,
35
36 #[serde(rename = "timeout")]
38 Timeout,
39
40 #[serde(rename = "failure")]
42 Failure,
43}
44
45impl<'de> Deserialize<'de> for ShutdownReason {
46 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
47 where
48 D: serde::Deserializer<'de>,
49 {
50 let s = String::deserialize(deserializer)?;
51 match s.to_lowercase().as_str() {
52 "spindown" => Ok(ShutdownReason::Spindown),
53 "timeout" => Ok(ShutdownReason::Timeout),
54 "failure" => Ok(ShutdownReason::Failure),
55 _ => Err(serde::de::Error::unknown_variant(
56 &s,
57 &["spindown", "timeout", "failure"],
58 )),
59 }
60 }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65#[serde(tag = "eventType")]
66pub enum LifecycleEvent {
67 #[serde(rename = "INVOKE")]
69 Invoke {
70 #[serde(rename = "deadlineMs")]
72 deadline_ms: i64,
73
74 #[serde(rename = "requestId")]
76 request_id: String,
77
78 #[serde(rename = "invokedFunctionArn")]
80 invoked_function_arn: String,
81
82 #[serde(rename = "tracing")]
84 tracing: TracingInfo,
85 },
86
87 #[serde(rename = "SHUTDOWN")]
89 Shutdown {
90 #[serde(rename = "shutdownReason")]
92 shutdown_reason: ShutdownReason,
93
94 #[serde(rename = "deadlineMs")]
96 deadline_ms: i64,
97 },
98}
99
/// Tracing data carried in the `tracing` field of [`LifecycleEvent::Invoke`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracingInfo {
    /// (De)serialized under the key `type`; `type` is a Rust keyword, hence
    /// the different field name.
    #[serde(rename = "type")]
    pub trace_type: String,

    /// (De)serialized under the key `value`.
    pub value: String,
}
110
/// Deserializable payload listing event types.
///
/// NOTE(review): presumably the body of an extension registration request;
/// the handler that consumes it is not in this file, so confirm there.
#[derive(Debug, Clone, Deserialize)]
pub struct RegisterRequest {
    /// Event types named in the request (`"INVOKE"` / `"SHUTDOWN"` on the wire).
    pub events: Vec<EventType>,
}
117
/// Record kept for each extension known to [`ExtensionState`].
#[derive(Debug, Clone)]
pub struct RegisteredExtension {
    /// Unique identifier; a UUID v4 string when built through `new`.
    pub id: ExtensionId,

    /// Name supplied by the caller at construction time.
    pub name: String,

    /// Event types this extension is subscribed to.
    pub events: Vec<EventType>,

    /// UTC timestamp taken when the record was constructed through `new`.
    pub registered_at: DateTime<Utc>,
}
133
134impl RegisteredExtension {
135 pub fn new(name: String, events: Vec<EventType>) -> Self {
137 Self {
138 id: Uuid::new_v4().to_string(),
139 name,
140 events,
141 registered_at: Utc::now(),
142 }
143 }
144
145 pub fn is_subscribed_to(&self, event_type: &EventType) -> bool {
147 self.events.contains(event_type)
148 }
149}
150
/// Shared, lock-protected bookkeeping for registered extensions: who is
/// registered, which events are pending for each of them, and which of them
/// have acknowledged shutdown.
#[derive(Debug)]
pub struct ExtensionState {
    /// Registered extensions keyed by id.
    extensions: Mutex<HashMap<ExtensionId, RegisteredExtension>>,

    /// Per-extension FIFO of events: `broadcast_event` pushes to the back,
    /// `next_event` pops from the front.
    event_queues: Mutex<HashMap<ExtensionId, VecDeque<LifecycleEvent>>>,

    /// Per-extension notifier that `next_event` waits on while its queue is
    /// empty; signalled by `broadcast_event` and `wake_all_extensions`.
    event_notifiers: Mutex<HashMap<ExtensionId, std::sync::Arc<Notify>>>,

    /// Ids passed to `mark_shutdown_acknowledged` since the last
    /// `clear_shutdown_acknowledged`.
    shutdown_acknowledged: Mutex<std::collections::HashSet<ExtensionId>>,

    /// Signalled on every `mark_shutdown_acknowledged`; awaited by
    /// `wait_for_shutdown_acknowledged`.
    shutdown_notify: Notify,
}
172
173impl ExtensionState {
174 pub fn new() -> Self {
176 Self {
177 extensions: Mutex::new(HashMap::new()),
178 event_queues: Mutex::new(HashMap::new()),
179 event_notifiers: Mutex::new(HashMap::new()),
180 shutdown_acknowledged: Mutex::new(std::collections::HashSet::new()),
181 shutdown_notify: Notify::new(),
182 }
183 }
184
185 pub async fn register(&self, name: String, events: Vec<EventType>) -> RegisteredExtension {
196 let extension = RegisteredExtension::new(name, events);
197 let id = extension.id.clone();
198
199 self.extensions
200 .lock()
201 .await
202 .insert(id.clone(), extension.clone());
203 self.event_queues
204 .lock()
205 .await
206 .insert(id.clone(), VecDeque::new());
207 self.event_notifiers
208 .lock()
209 .await
210 .insert(id.clone(), std::sync::Arc::new(Notify::new()));
211
212 extension
213 }
214
215 pub async fn broadcast_event(&self, event: LifecycleEvent) {
221 let event_type = match &event {
222 LifecycleEvent::Invoke { .. } => EventType::Invoke,
223 LifecycleEvent::Shutdown { .. } => EventType::Shutdown,
224 };
225
226 let extensions = self.extensions.lock().await;
227 let mut queues = self.event_queues.lock().await;
228 let notifiers = self.event_notifiers.lock().await;
229
230 for (id, ext) in extensions.iter() {
231 if ext.is_subscribed_to(&event_type) {
232 if let Some(queue) = queues.get_mut(id) {
233 queue.push_back(event.clone());
234 }
235 if let Some(notifier) = notifiers.get(id) {
236 notifier.notify_one();
237 }
238 }
239 }
240 }
241
242 pub async fn next_event(&self, extension_id: &str) -> Option<LifecycleEvent> {
254 loop {
255 {
256 let mut queues = self.event_queues.lock().await;
257 if let Some(queue) = queues.get_mut(extension_id) {
258 if let Some(event) = queue.pop_front() {
259 return Some(event);
260 }
261 } else {
262 return None;
263 }
264 }
265
266 let notifiers = self.event_notifiers.lock().await;
267 if let Some(notifier) = notifiers.get(extension_id) {
268 let notifier = std::sync::Arc::clone(notifier);
269 drop(notifiers);
270 notifier.notified().await;
271 } else {
272 return None;
273 }
274 }
275 }
276
277 pub async fn get_extension(&self, extension_id: &str) -> Option<RegisteredExtension> {
287 self.extensions.lock().await.get(extension_id).cloned()
288 }
289
290 pub async fn get_all_extensions(&self) -> Vec<RegisteredExtension> {
292 self.extensions.lock().await.values().cloned().collect()
293 }
294
295 pub async fn extension_count(&self) -> usize {
297 self.extensions.lock().await.len()
298 }
299
300 pub async fn get_invoke_subscribers(&self) -> Vec<ExtensionId> {
305 self.extensions
306 .lock()
307 .await
308 .values()
309 .filter(|ext| ext.is_subscribed_to(&EventType::Invoke))
310 .map(|ext| ext.id.clone())
311 .collect()
312 }
313
314 pub async fn get_shutdown_subscribers(&self) -> Vec<ExtensionId> {
319 self.extensions
320 .lock()
321 .await
322 .values()
323 .filter(|ext| ext.is_subscribed_to(&EventType::Shutdown))
324 .map(|ext| ext.id.clone())
325 .collect()
326 }
327
328 pub async fn wake_all_extensions(&self) {
332 let notifiers = self.event_notifiers.lock().await;
333 for notifier in notifiers.values() {
334 notifier.notify_one();
335 }
336 }
337
338 #[allow(dead_code)]
343 pub async fn is_queue_empty(&self, extension_id: &str) -> bool {
344 let queues = self.event_queues.lock().await;
345 queues
346 .get(extension_id)
347 .is_none_or(|queue| queue.is_empty())
348 }
349
350 pub async fn mark_shutdown_acknowledged(&self, extension_id: &str) {
355 self.shutdown_acknowledged
356 .lock()
357 .await
358 .insert(extension_id.to_string());
359 self.shutdown_notify.notify_waiters();
360 }
361
362 pub async fn is_shutdown_acknowledged(&self, extension_id: &str) -> bool {
364 self.shutdown_acknowledged
365 .lock()
366 .await
367 .contains(extension_id)
368 }
369
370 pub async fn wait_for_shutdown_acknowledged(&self, extension_ids: &[String]) {
374 loop {
375 let acknowledged = self.shutdown_acknowledged.lock().await;
376 if extension_ids.iter().all(|id| acknowledged.contains(id)) {
377 return;
378 }
379 drop(acknowledged);
380 self.shutdown_notify.notified().await;
381 }
382 }
383
384 pub async fn clear_shutdown_acknowledged(&self) {
388 self.shutdown_acknowledged.lock().await.clear();
389 }
390}
391
392impl Default for ExtensionState {
393 fn default() -> Self {
394 Self::new()
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must serialize to its lowercase JSON string.
    #[test]
    fn test_shutdown_reason_serializes_lowercase() {
        let cases = [
            (ShutdownReason::Spindown, "\"spindown\""),
            (ShutdownReason::Timeout, "\"timeout\""),
            (ShutdownReason::Failure, "\"failure\""),
        ];

        for (reason, expected_json) in cases {
            assert_eq!(serde_json::to_string(&reason).unwrap(), expected_json);
        }
    }

    /// Deserialization must accept lower, upper and mixed case spellings.
    #[test]
    fn test_shutdown_reason_deserializes_case_insensitive() {
        let cases = [
            ("\"spindown\"", ShutdownReason::Spindown),
            ("\"SPINDOWN\"", ShutdownReason::Spindown),
            ("\"Spindown\"", ShutdownReason::Spindown),
            ("\"SpInDoWn\"", ShutdownReason::Spindown),
            ("\"timeout\"", ShutdownReason::Timeout),
            ("\"TIMEOUT\"", ShutdownReason::Timeout),
            ("\"failure\"", ShutdownReason::Failure),
            ("\"FAILURE\"", ShutdownReason::Failure),
        ];

        for (json, expected_reason) in cases {
            let parsed = serde_json::from_str::<ShutdownReason>(json).unwrap();
            assert_eq!(parsed, expected_reason);
        }
    }

    /// A string that names no variant must be rejected.
    #[test]
    fn test_shutdown_reason_deserialize_invalid() {
        let outcome = serde_json::from_str::<ShutdownReason>("\"invalid\"");
        assert!(outcome.is_err());
    }
}