lambda_simulator/
extension.rs

1//! Lambda Extensions API types and state management.
2//!
3//! Implements the Lambda Extensions API as documented at:
4//! <https://docs.aws.amazon.com/lambda/latest/dg/runtimes-extensions-api.html>
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use tokio::sync::{Mutex, Notify};
10use uuid::Uuid;
11
/// Extension identifier returned upon registration.
///
/// A plain `String`; [`RegisteredExtension::new`] fills it with a random UUID v4
/// rendered as text. It is also the key of every per-extension map in
/// [`ExtensionState`].
pub type ExtensionId = String;
14
15/// Lifecycle events that extensions can subscribe to.
16#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
17#[serde(rename_all = "UPPERCASE")]
18pub enum EventType {
19    /// Invocation event - sent when a new invocation starts.
20    Invoke,
21
22    /// Shutdown event - sent when the Lambda environment is shutting down.
23    Shutdown,
24}
25
26/// Reasons for a shutdown event.
27///
28/// Serializes to lowercase (`spindown`, `timeout`, `failure`) per the Extensions API spec.
29/// Deserializes case-insensitively to handle variations like `SPINDOWN` or `Spindown`.
30#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
31pub enum ShutdownReason {
32    /// Normal spindown of the Lambda environment.
33    #[serde(rename = "spindown")]
34    Spindown,
35
36    /// Timeout occurred during execution.
37    #[serde(rename = "timeout")]
38    Timeout,
39
40    /// Failure in the Lambda environment.
41    #[serde(rename = "failure")]
42    Failure,
43}
44
45impl<'de> Deserialize<'de> for ShutdownReason {
46    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
47    where
48        D: serde::Deserializer<'de>,
49    {
50        let s = String::deserialize(deserializer)?;
51        match s.to_lowercase().as_str() {
52            "spindown" => Ok(ShutdownReason::Spindown),
53            "timeout" => Ok(ShutdownReason::Timeout),
54            "failure" => Ok(ShutdownReason::Failure),
55            _ => Err(serde::de::Error::unknown_variant(
56                &s,
57                &["spindown", "timeout", "failure"],
58            )),
59        }
60    }
61}
62
63/// A lifecycle event sent to extensions.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65#[serde(tag = "eventType")]
66pub enum LifecycleEvent {
67    /// Invocation event with request ID.
68    #[serde(rename = "INVOKE")]
69    Invoke {
70        /// Timestamp when the invocation was received.
71        #[serde(rename = "deadlineMs")]
72        deadline_ms: i64,
73
74        /// Request ID for this invocation.
75        #[serde(rename = "requestId")]
76        request_id: String,
77
78        /// ARN of the invoked function.
79        #[serde(rename = "invokedFunctionArn")]
80        invoked_function_arn: String,
81
82        /// X-Ray tracing ID.
83        #[serde(rename = "tracing")]
84        tracing: TracingInfo,
85    },
86
87    /// Shutdown event with reason.
88    #[serde(rename = "SHUTDOWN")]
89    Shutdown {
90        /// Reason for the shutdown.
91        #[serde(rename = "shutdownReason")]
92        shutdown_reason: ShutdownReason,
93
94        /// Timestamp of the shutdown event.
95        #[serde(rename = "deadlineMs")]
96        deadline_ms: i64,
97    },
98}
99
/// X-Ray tracing information attached to an `INVOKE` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracingInfo {
    /// Kind of trace header (serialized as `type`); for AWS X-Ray this is
    /// `X-Amzn-Trace-Id`.
    #[serde(rename = "type")]
    pub trace_type: String,

    /// The trace header value itself.
    pub value: String,
}
110
/// Registration request from an extension.
///
/// Deserialized from the JSON body of the registration call, e.g.
/// `{"events": ["INVOKE", "SHUTDOWN"]}`.
#[derive(Debug, Clone, Deserialize)]
pub struct RegisterRequest {
    /// Events the extension wants to receive (upper-case event names).
    pub events: Vec<EventType>,
}
117
/// Information about a registered extension.
///
/// Created by [`RegisteredExtension::new`] and stored by [`ExtensionState::register`].
#[derive(Debug, Clone)]
pub struct RegisteredExtension {
    /// Unique identifier for this extension (random UUID v4 string assigned in `new`).
    pub id: ExtensionId,

    /// Name of the extension, as supplied by the registering caller.
    pub name: String,

    /// Events this extension subscribed to.
    pub events: Vec<EventType>,

    /// UTC time at which `new` created this registration record.
    pub registered_at: DateTime<Utc>,
}
133
134impl RegisteredExtension {
135    /// Creates a new registered extension.
136    pub fn new(name: String, events: Vec<EventType>) -> Self {
137        Self {
138            id: Uuid::new_v4().to_string(),
139            name,
140            events,
141            registered_at: Utc::now(),
142        }
143    }
144
145    /// Checks if this extension is subscribed to a specific event type.
146    pub fn is_subscribed_to(&self, event_type: &EventType) -> bool {
147        self.events.contains(event_type)
148    }
149}
150
/// Shared state for extension management.
///
/// This manages all registered extensions, their event queues, and
/// coordination for event delivery. Each map sits behind its own
/// `tokio::sync::Mutex`. When several are held at once (as in
/// `broadcast_event`), they are acquired in declaration order:
/// extensions, then queues, then notifiers.
#[derive(Debug)]
pub struct ExtensionState {
    /// Map of extension IDs to their registration information.
    extensions: Mutex<HashMap<ExtensionId, RegisteredExtension>>,

    /// FIFO event queue per extension (extension_id -> pending events).
    event_queues: Mutex<HashMap<ExtensionId, VecDeque<LifecycleEvent>>>,

    /// Per-extension notifier, signalled when an event is queued for that extension.
    event_notifiers: Mutex<HashMap<ExtensionId, std::sync::Arc<Notify>>>,

    /// Extensions that have acknowledged shutdown by polling /next after receiving SHUTDOWN.
    shutdown_acknowledged: Mutex<std::collections::HashSet<ExtensionId>>,

    /// Notifier woken whenever an extension is added to `shutdown_acknowledged`.
    shutdown_notify: Notify,
}
172
173impl ExtensionState {
174    /// Creates a new extension state.
175    pub fn new() -> Self {
176        Self {
177            extensions: Mutex::new(HashMap::new()),
178            event_queues: Mutex::new(HashMap::new()),
179            event_notifiers: Mutex::new(HashMap::new()),
180            shutdown_acknowledged: Mutex::new(std::collections::HashSet::new()),
181            shutdown_notify: Notify::new(),
182        }
183    }
184
185    /// Registers a new extension.
186    ///
187    /// # Arguments
188    ///
189    /// * `name` - Name of the extension (from Lambda-Extension-Name header)
190    /// * `events` - List of events the extension wants to receive
191    ///
192    /// # Returns
193    ///
194    /// The registered extension information including its ID.
195    pub async fn register(&self, name: String, events: Vec<EventType>) -> RegisteredExtension {
196        let extension = RegisteredExtension::new(name, events);
197        let id = extension.id.clone();
198
199        self.extensions
200            .lock()
201            .await
202            .insert(id.clone(), extension.clone());
203        self.event_queues
204            .lock()
205            .await
206            .insert(id.clone(), VecDeque::new());
207        self.event_notifiers
208            .lock()
209            .await
210            .insert(id.clone(), std::sync::Arc::new(Notify::new()));
211
212        extension
213    }
214
215    /// Broadcasts an event to all extensions subscribed to that event type.
216    ///
217    /// # Arguments
218    ///
219    /// * `event` - The lifecycle event to broadcast
220    pub async fn broadcast_event(&self, event: LifecycleEvent) {
221        let event_type = match &event {
222            LifecycleEvent::Invoke { .. } => EventType::Invoke,
223            LifecycleEvent::Shutdown { .. } => EventType::Shutdown,
224        };
225
226        let extensions = self.extensions.lock().await;
227        let mut queues = self.event_queues.lock().await;
228        let notifiers = self.event_notifiers.lock().await;
229
230        for (id, ext) in extensions.iter() {
231            if ext.is_subscribed_to(&event_type) {
232                if let Some(queue) = queues.get_mut(id) {
233                    queue.push_back(event.clone());
234                }
235                if let Some(notifier) = notifiers.get(id) {
236                    notifier.notify_one();
237                }
238            }
239        }
240    }
241
242    /// Waits for and retrieves the next event for a specific extension.
243    ///
244    /// This is a long-poll operation that blocks until an event is available.
245    ///
246    /// # Arguments
247    ///
248    /// * `extension_id` - The ID of the extension requesting the next event
249    ///
250    /// # Returns
251    ///
252    /// The next lifecycle event for this extension.
253    pub async fn next_event(&self, extension_id: &str) -> Option<LifecycleEvent> {
254        loop {
255            {
256                let mut queues = self.event_queues.lock().await;
257                if let Some(queue) = queues.get_mut(extension_id) {
258                    if let Some(event) = queue.pop_front() {
259                        return Some(event);
260                    }
261                } else {
262                    return None;
263                }
264            }
265
266            let notifiers = self.event_notifiers.lock().await;
267            if let Some(notifier) = notifiers.get(extension_id) {
268                let notifier = std::sync::Arc::clone(notifier);
269                drop(notifiers);
270                notifier.notified().await;
271            } else {
272                return None;
273            }
274        }
275    }
276
277    /// Gets information about a registered extension.
278    ///
279    /// # Arguments
280    ///
281    /// * `extension_id` - The ID of the extension
282    ///
283    /// # Returns
284    ///
285    /// The extension information if it exists.
286    pub async fn get_extension(&self, extension_id: &str) -> Option<RegisteredExtension> {
287        self.extensions.lock().await.get(extension_id).cloned()
288    }
289
290    /// Gets all registered extensions.
291    pub async fn get_all_extensions(&self) -> Vec<RegisteredExtension> {
292        self.extensions.lock().await.values().cloned().collect()
293    }
294
295    /// Returns the number of registered extensions.
296    pub async fn extension_count(&self) -> usize {
297        self.extensions.lock().await.len()
298    }
299
300    /// Returns the IDs of all extensions subscribed to INVOKE events.
301    ///
302    /// This is used to determine which extensions need to signal readiness
303    /// after each invocation.
304    pub async fn get_invoke_subscribers(&self) -> Vec<ExtensionId> {
305        self.extensions
306            .lock()
307            .await
308            .values()
309            .filter(|ext| ext.is_subscribed_to(&EventType::Invoke))
310            .map(|ext| ext.id.clone())
311            .collect()
312    }
313
314    /// Returns the IDs of all extensions subscribed to SHUTDOWN events.
315    ///
316    /// This is used to determine which extensions need to receive the
317    /// SHUTDOWN event during graceful shutdown.
318    pub async fn get_shutdown_subscribers(&self) -> Vec<ExtensionId> {
319        self.extensions
320            .lock()
321            .await
322            .values()
323            .filter(|ext| ext.is_subscribed_to(&EventType::Shutdown))
324            .map(|ext| ext.id.clone())
325            .collect()
326    }
327
328    /// Wakes all extensions waiting on /next.
329    ///
330    /// This is used during shutdown to unblock extensions waiting for events.
331    pub async fn wake_all_extensions(&self) {
332        let notifiers = self.event_notifiers.lock().await;
333        for notifier in notifiers.values() {
334            notifier.notify_one();
335        }
336    }
337
338    /// Checks if an extension's event queue is empty.
339    ///
340    /// This is used during shutdown to determine if an extension has
341    /// consumed all events (including the SHUTDOWN event).
342    #[allow(dead_code)]
343    pub async fn is_queue_empty(&self, extension_id: &str) -> bool {
344        let queues = self.event_queues.lock().await;
345        queues
346            .get(extension_id)
347            .is_none_or(|queue| queue.is_empty())
348    }
349
350    /// Marks an extension as having acknowledged shutdown.
351    ///
352    /// This is called when an extension polls `/next` after receiving the
353    /// SHUTDOWN event, signaling it has completed its cleanup work.
354    pub async fn mark_shutdown_acknowledged(&self, extension_id: &str) {
355        self.shutdown_acknowledged
356            .lock()
357            .await
358            .insert(extension_id.to_string());
359        self.shutdown_notify.notify_waiters();
360    }
361
362    /// Checks if an extension has acknowledged shutdown.
363    pub async fn is_shutdown_acknowledged(&self, extension_id: &str) -> bool {
364        self.shutdown_acknowledged
365            .lock()
366            .await
367            .contains(extension_id)
368    }
369
370    /// Waits for all specified extensions to acknowledge shutdown.
371    ///
372    /// Returns when all extensions have polled `/next` after receiving SHUTDOWN.
373    pub async fn wait_for_shutdown_acknowledged(&self, extension_ids: &[String]) {
374        loop {
375            let acknowledged = self.shutdown_acknowledged.lock().await;
376            if extension_ids.iter().all(|id| acknowledged.contains(id)) {
377                return;
378            }
379            drop(acknowledged);
380            self.shutdown_notify.notified().await;
381        }
382    }
383
384    /// Clears shutdown acknowledgment state.
385    ///
386    /// Called when starting a new shutdown sequence.
387    pub async fn clear_shutdown_acknowledged(&self) {
388        self.shutdown_acknowledged.lock().await.clear();
389    }
390}
391
392impl Default for ExtensionState {
393    fn default() -> Self {
394        Self::new()
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` (wrapped in JSON string quotes) as a `ShutdownReason`.
    fn parse(raw: &str) -> Result<ShutdownReason, serde_json::Error> {
        serde_json::from_str::<ShutdownReason>(&format!("\"{raw}\""))
    }

    #[test]
    fn test_shutdown_reason_serializes_lowercase() {
        let expectations = [
            (ShutdownReason::Spindown, "\"spindown\""),
            (ShutdownReason::Timeout, "\"timeout\""),
            (ShutdownReason::Failure, "\"failure\""),
        ];
        for (reason, json) in expectations {
            assert_eq!(serde_json::to_string(&reason).unwrap(), json);
        }
    }

    #[test]
    fn test_shutdown_reason_deserializes_case_insensitive() {
        let expectations = [
            ("spindown", ShutdownReason::Spindown),
            ("SPINDOWN", ShutdownReason::Spindown),
            ("Spindown", ShutdownReason::Spindown),
            ("SpInDoWn", ShutdownReason::Spindown),
            ("timeout", ShutdownReason::Timeout),
            ("TIMEOUT", ShutdownReason::Timeout),
            ("failure", ShutdownReason::Failure),
            ("FAILURE", ShutdownReason::Failure),
        ];
        for (raw, expected) in expectations {
            assert_eq!(parse(raw).unwrap(), expected, "input: {raw}");
        }
    }

    #[test]
    fn test_shutdown_reason_deserialize_invalid() {
        assert!(parse("invalid").is_err());
    }
}