Skip to main content

awsim_lambda/
state.rs

1use awsim_core::{Body, BodyStore, Snapshottable};
2use dashmap::DashMap;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::atomic::{AtomicU32, Ordering};
6use std::sync::{Arc, OnceLock};
7
/// Lambda state — per account and region.
///
/// Every collection is a `DashMap`, so entries are inserted and mutated
/// through a shared `&LambdaState` (for example `InvocationSlot::acquire`
/// only takes `&LambdaState`). Only `functions` is captured by the
/// `Snapshottable` impl; every other map starts empty after a restore.
#[derive(Debug, Default)]
pub struct LambdaState {
    /// function_name → LambdaFunction (snapshot restore keys by the
    /// function's `name`).
    pub functions: DashMap<String, LambdaFunction>,
    /// Event source mappings. NOTE(review): presumably keyed by the
    /// mapping's `uuid` — confirm against the handlers that insert here.
    pub event_source_mappings: DashMap<String, EventSourceMapping>,
    /// Published layer versions. NOTE(review): presumably keyed by layer
    /// name — confirm against the handlers that insert here.
    pub layers: DashMap<String, Vec<LayerVersion>>,
    /// function_name → FunctionUrlConfig
    pub url_configs: DashMap<String, FunctionUrlConfig>,
    /// function_name[:qualifier] → EventInvokeConfig
    pub event_invoke_configs: DashMap<String, EventInvokeConfig>,
    /// Shared body store. Write-once: `set_body_store` keeps the first
    /// store installed and ignores later calls.
    pub body_store: OnceLock<Arc<BodyStore>>,
    /// Per-function active-invocation counter. Used to enforce
    /// `ReservedConcurrentExecutions` (and AWS's
    /// `TooManyRequestsException` on overflow). Lazily populated on
    /// first invoke against the function; entries are never removed
    /// even when the function is deleted, but the counter resets when
    /// the LambdaState is dropped or rebuilt from a snapshot.
    pub active_invocations: DashMap<String, Arc<AtomicU32>>,
}
27
/// RAII guard around an active-invocation slot: decrements the
/// per-function counter when dropped, so `invoke()` releases the slot
/// even when the underlying executor panics (assuming panics unwind).
///
/// The guard never increments anything itself. Take a slot with
/// `InvocationSlot::acquire`, which bumps the counter and returns it, then
/// wrap that handle with `InvocationSlot::from_acquired`. Keep the guard
/// bound (e.g. `let _slot = ...`) for the whole invocation: a guard that is
/// created and immediately discarded releases the slot straight away.
#[must_use = "dropping an InvocationSlot immediately releases the invocation slot"]
pub struct InvocationSlot {
    /// Shared in-flight counter for one function — normally the handle
    /// returned by `acquire`, which is also stored in
    /// `LambdaState::active_invocations`.
    counter: Arc<AtomicU32>,
}
35
36impl InvocationSlot {
37    pub fn acquire(state: &LambdaState, function_name: &str) -> Arc<AtomicU32> {
38        let counter = state
39            .active_invocations
40            .entry(function_name.to_string())
41            .or_insert_with(|| Arc::new(AtomicU32::new(0)))
42            .clone();
43        counter.fetch_add(1, Ordering::SeqCst);
44        counter
45    }
46
47    /// Construct from an already-incremented counter handle. The
48    /// caller is responsible for ensuring the counter was bumped via
49    /// `acquire`; the guard handles the decrement.
50    pub fn from_acquired(counter: Arc<AtomicU32>) -> Self {
51        Self { counter }
52    }
53
54    /// Peek the current in-flight count for `function_name` without
55    /// incrementing. Returns 0 when no entry exists.
56    pub fn current(state: &LambdaState, function_name: &str) -> u32 {
57        state
58            .active_invocations
59            .get(function_name)
60            .map(|c| c.load(Ordering::SeqCst))
61            .unwrap_or(0)
62    }
63}
64
65impl Drop for InvocationSlot {
66    fn drop(&mut self) {
67        // Saturating sub so a programming bug that double-drops the
68        // guard can't wrap the counter to u32::MAX.
69        loop {
70            let cur = self.counter.load(Ordering::SeqCst);
71            if cur == 0 {
72                return;
73            }
74            if self
75                .counter
76                .compare_exchange(cur, cur - 1, Ordering::SeqCst, Ordering::SeqCst)
77                .is_ok()
78            {
79                return;
80            }
81        }
82    }
83}
84
85impl LambdaState {
86    pub fn body_store(&self) -> Option<&Arc<BodyStore>> {
87        self.body_store.get()
88    }
89
90    pub fn set_body_store(&self, store: Arc<BodyStore>) {
91        let _ = self.body_store.set(store);
92    }
93}
94
95impl Snapshottable for LambdaState {
96    type Snapshot = LambdaRegionSnapshot;
97
98    fn to_snapshot(&self, account_id: &str, region: &str) -> Self::Snapshot {
99        let functions = self
100            .functions
101            .iter()
102            .map(|entry| {
103                let f = entry.value();
104                FunctionSnapshot {
105                    account_id: account_id.to_string(),
106                    region: region.to_string(),
107                    name: f.name.clone(),
108                    arn: f.arn.clone(),
109                    runtime: f.runtime.clone(),
110                    role: f.role.clone(),
111                    handler: f.handler.clone(),
112                    description: f.description.clone(),
113                    timeout: f.timeout,
114                    memory_size: f.memory_size,
115                    code_sha256: f.code_sha256.clone(),
116                    code_size: f.code_size,
117                    environment: f.environment.clone(),
118                    version: f.version.clone(),
119                    versions: f
120                        .versions
121                        .iter()
122                        .map(|v| FunctionVersionSnapshot {
123                            version: v.version.clone(),
124                            description: v.description.clone(),
125                            code_sha256: v.code_sha256.clone(),
126                            code_size: v.code_size,
127                            last_modified: v.last_modified.clone(),
128                        })
129                        .collect(),
130                    aliases: f
131                        .aliases
132                        .iter()
133                        .map(|(k, a)| {
134                            (
135                                k.clone(),
136                                AliasSnapshot {
137                                    name: a.name.clone(),
138                                    arn: a.arn.clone(),
139                                    function_version: a.function_version.clone(),
140                                    description: a.description.clone(),
141                                    routing_config: a.routing_config.clone(),
142                                },
143                            )
144                        })
145                        .collect(),
146                    last_modified: f.last_modified.clone(),
147                    state: f.state.clone(),
148                    policy_statements: f.policy_statements.clone(),
149                    tags: f.tags.clone(),
150                    architectures: f.architectures.clone(),
151                    ephemeral_storage_size: f.ephemeral_storage_size,
152                    package_type: f.package_type.clone(),
153                    layers: f.layers.clone(),
154                    vpc_config: f.vpc_config.clone(),
155                    dead_letter_config: f.dead_letter_config.clone(),
156                    tracing_config: f.tracing_config.clone(),
157                    kms_key_arn: f.kms_key_arn.clone(),
158                    file_system_configs: f.file_system_configs.clone(),
159                    logging_config: f.logging_config.clone(),
160                    snap_start: f.snap_start.clone(),
161                    image_config: f.image_config.clone(),
162                    recursive_loop: Some(f.recursive_loop.clone()),
163                }
164            })
165            .collect();
166
167        LambdaRegionSnapshot {
168            account_id: account_id.to_string(),
169            region: region.to_string(),
170            functions,
171        }
172    }
173
174    fn from_snapshot(snapshot: Self::Snapshot) -> (String, String, Self) {
175        let state = LambdaState::default();
176        for fs in snapshot.functions {
177            let versions: Vec<FunctionVersion> = fs
178                .versions
179                .into_iter()
180                .map(|v| FunctionVersion {
181                    version: v.version,
182                    description: v.description,
183                    code_sha256: v.code_sha256,
184                    code_size: v.code_size,
185                    code: None,
186                    last_modified: v.last_modified,
187                })
188                .collect();
189
190            let aliases: HashMap<String, Alias> = fs
191                .aliases
192                .into_iter()
193                .map(|(k, a)| {
194                    (
195                        k,
196                        Alias {
197                            name: a.name,
198                            arn: a.arn,
199                            function_version: a.function_version,
200                            description: a.description,
201                            routing_config: a.routing_config,
202                        },
203                    )
204                })
205                .collect();
206
207            let func = LambdaFunction {
208                name: fs.name.clone(),
209                arn: fs.arn,
210                runtime: fs.runtime,
211                role: fs.role,
212                handler: fs.handler,
213                description: fs.description,
214                timeout: fs.timeout,
215                memory_size: fs.memory_size,
216                code_sha256: fs.code_sha256,
217                code_size: fs.code_size,
218                code: None,
219                environment: fs.environment,
220                version: fs.version,
221                versions,
222                aliases,
223                last_modified: fs.last_modified,
224                state: fs.state,
225                invocations: Vec::new(),
226                policy_statements: fs.policy_statements,
227                tags: fs.tags,
228                reserved_concurrent_executions: None,
229                provisioned_concurrency: HashMap::new(),
230                architectures: fs.architectures,
231                ephemeral_storage_size: fs.ephemeral_storage_size,
232                package_type: fs.package_type,
233                layers: fs.layers,
234                vpc_config: fs.vpc_config,
235                dead_letter_config: fs.dead_letter_config,
236                tracing_config: fs.tracing_config,
237                kms_key_arn: fs.kms_key_arn,
238                file_system_configs: fs.file_system_configs,
239                logging_config: fs.logging_config,
240                snap_start: fs.snap_start,
241                image_config: fs.image_config,
242                recursive_loop: fs.recursive_loop.unwrap_or_else(|| "Terminate".to_string()),
243            };
244            state.functions.insert(fs.name, func);
245        }
246        (snapshot.account_id, snapshot.region, state)
247    }
248}
249
/// Asynchronous-invocation settings for a function, stored in
/// `LambdaState::event_invoke_configs` under `function_name[:qualifier]`.
#[derive(Debug, Clone, Default)]
pub struct EventInvokeConfig {
    pub function_arn: String,
    /// Optional retry limit; presumably `None` means "not configured".
    pub maximum_retry_attempts: Option<i32>,
    /// Optional maximum event age; presumably `None` means "not configured".
    pub maximum_event_age_in_seconds: Option<i32>,
    /// Presumably the destination ARN for successful invocations — confirm.
    pub destination_on_success: Option<String>,
    /// Presumably the destination ARN for failed invocations — confirm.
    pub destination_on_failure: Option<String>,
    /// NOTE(review): looks like a Unix timestamp in seconds — confirm.
    pub last_modified: f64,
}
259
/// A Lambda function as stored in `LambdaState::functions`.
///
/// `Snapshottable for LambdaState` persists most fields. It drops `code`,
/// `invocations`, `reserved_concurrent_executions` and
/// `provisioned_concurrency`, which come back empty/`None` after a restore.
#[derive(Debug, Clone)]
pub struct LambdaFunction {
    pub name: String,
    pub arn: String,
    pub runtime: Option<String>,
    pub role: String,
    pub handler: Option<String>,
    pub description: String,
    pub timeout: u32,
    pub memory_size: u32,
    pub code_sha256: String,
    pub code_size: u64,
    /// Deployment package body. Not written to snapshots; restored
    /// functions have `None` here.
    pub code: Option<Body>,
    pub environment: HashMap<String, String>,
    /// Always "$LATEST" for the live function.
    pub version: String,
    pub versions: Vec<FunctionVersion>,
    pub aliases: HashMap<String, Alias>,
    pub last_modified: String,
    /// "Active", "Pending", "Failed", etc.
    pub state: String,
    /// Invocation records for debugging / admin console.
    /// Not persisted in snapshots.
    pub invocations: Vec<InvocationRecord>,
    /// Resource-based policy statements (for AddPermission / RemovePermission).
    pub policy_statements: HashMap<String, serde_json::Value>,
    /// Tags attached to this function.
    pub tags: HashMap<String, String>,
    /// Reserved concurrent executions ceiling per PutFunctionConcurrency.
    /// `None` means unreserved — the function shares the account pool.
    /// Not persisted in snapshots; reset to `None` on restore.
    pub reserved_concurrent_executions: Option<u32>,
    /// Provisioned concurrency configurations keyed by qualifier (alias name
    /// or function version). Each entry tracks the requested capacity along
    /// with a simulated state machine that flips IN_PROGRESS -> READY.
    /// Not persisted in snapshots; empty on restore.
    pub provisioned_concurrency: HashMap<String, ProvisionedConcurrencyConfig>,
    /// CPU architecture set: `["x86_64"]` or `["arm64"]`. Defaults to
    /// `["x86_64"]` per AWS.
    pub architectures: Vec<String>,
    /// `/tmp` size in MiB. Defaults to 512; AWS allows 512..=10240.
    pub ephemeral_storage_size: u32,
    /// "Zip" or "Image". Defaults to "Zip".
    pub package_type: String,
    /// Optional layer-version ARNs attached to the function.
    pub layers: Vec<String>,
    /// VpcConfig as supplied by the caller plus the synthesized VpcId field.
    pub vpc_config: Option<serde_json::Value>,
    /// DeadLetterConfig (`{ TargetArn }`).
    pub dead_letter_config: Option<serde_json::Value>,
    /// TracingConfig (`{ Mode }`). Defaults to `{ Mode: "PassThrough" }`.
    pub tracing_config: Option<serde_json::Value>,
    /// KMS key ARN used to encrypt environment variables at rest.
    pub kms_key_arn: Option<String>,
    /// EFS mounts: array of `{ Arn, LocalMountPath }`.
    pub file_system_configs: Option<serde_json::Value>,
    /// LoggingConfig (`{ LogFormat, ApplicationLogLevel, SystemLogLevel, LogGroup }`).
    pub logging_config: Option<serde_json::Value>,
    /// SnapStart configuration. Stored as supplied; the serializer adds the
    /// computed `OptimizationStatus` field.
    pub snap_start: Option<serde_json::Value>,
    /// ImageConfig for container-image functions.
    pub image_config: Option<serde_json::Value>,
    /// Self-invoke recursion control. AWS Lambda stamps this on the
    /// function and exposes it via Get / Put FunctionRecursionConfig;
    /// `Allow` lets the function call itself in a loop, `Terminate`
    /// (default) stops recursive invocation chains. Snapshot restore falls
    /// back to "Terminate" when the snapshot has no value.
    pub recursive_loop: String,
}
326
/// Provisioned concurrency configuration for a single (function, qualifier)
/// pair. Real Lambda transitions IN_PROGRESS -> READY asynchronously; we
/// flip immediately because the emulator never has provisioning latency.
/// Stored in `LambdaFunction::provisioned_concurrency` keyed by qualifier.
#[derive(Debug, Clone)]
pub struct ProvisionedConcurrencyConfig {
    /// Alias name or version this configuration applies to.
    pub qualifier: String,
    pub requested_provisioned_concurrent_executions: u32,
    pub allocated_provisioned_concurrent_executions: u32,
    pub available_provisioned_concurrent_executions: u32,
    pub status: String, // IN_PROGRESS | READY | FAILED
    /// Presumably only set when `status` is FAILED — confirm with the handler.
    pub status_reason: Option<String>,
    pub last_modified: String,
}
340
/// A function URL configuration, stored in `LambdaState::url_configs`
/// keyed by function name.
#[derive(Debug, Clone)]
pub struct FunctionUrlConfig {
    /// Kept for potential admin console use.
    #[allow(dead_code)]
    pub function_name: String,
    pub function_arn: String,
    pub function_url: String,
    /// Presumably "NONE" or "AWS_IAM" as in the AWS API — confirm.
    pub auth_type: String,
    /// CORS settings as raw JSON, if configured.
    pub cors: Option<serde_json::Value>,
    pub creation_time: String,
    pub last_modified_time: String,
}
354
/// A published version of a function (entry in `LambdaFunction::versions`).
#[derive(Debug, Clone)]
pub struct FunctionVersion {
    pub version: String,
    pub description: String,
    pub code_sha256: String,
    pub code_size: u64,
    /// Code body for this version. Not written to snapshots; `None` after
    /// a restore.
    pub code: Option<Body>,
    pub last_modified: String,
}
364
/// A named alias pointing at a function version (entry in
/// `LambdaFunction::aliases`).
#[derive(Debug, Clone)]
pub struct Alias {
    pub name: String,
    pub arn: String,
    /// Primary version the alias routes to.
    pub function_version: String,
    pub description: String,
    /// Traffic-shifting weights: `version → fraction in [0, 1]`. When set,
    /// invocations through the alias split between `function_version` and
    /// the listed versions per their weights. Must total ≤ 1; the
    /// implicit remainder is routed to `function_version`.
    pub routing_config: HashMap<String, f64>,
}
377
/// Stored for debugging and the admin console — fields read externally.
/// Kept in `LambdaFunction::invocations`; not persisted in snapshots.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct InvocationRecord {
    pub invocation_id: String,
    /// Presumably "RequestResponse", "Event" or "DryRun" — confirm.
    pub invocation_type: String,
    pub payload: serde_json::Value,
    pub response: serde_json::Value,
    pub status_code: u16,
    pub timestamp: String,
}
389
/// An event source mapping connecting a queue/stream ARN to a function,
/// stored in `LambdaState::event_source_mappings`. Not persisted in
/// snapshots.
#[derive(Debug, Clone)]
pub struct EventSourceMapping {
    pub uuid: String,
    pub event_source_arn: String,
    pub function_arn: String,
    pub batch_size: u32,
    /// Stored for potential future use / admin console.
    #[allow(dead_code)]
    pub enabled: bool,
    pub state: String,
    pub last_modified: String,
    /// TRIM_HORIZON | LATEST | AT_TIMESTAMP — only meaningful for Kinesis/DDB streams.
    pub starting_position: Option<String>,
    /// NOTE(review): presumably a Unix timestamp used with AT_TIMESTAMP — confirm.
    pub starting_position_timestamp: Option<f64>,
    pub maximum_batching_window_in_seconds: u32,
    pub maximum_record_age_in_seconds: Option<i32>,
    pub bisect_batch_on_function_error: bool,
    pub maximum_retry_attempts: Option<i32>,
    pub parallelization_factor: Option<u32>,
    pub tumbling_window_in_seconds: Option<u32>,
    /// Raw FilterCriteria JSON: { "Filters": [{ "Pattern": "..." }, ...] }.
    pub filter_criteria: Option<serde_json::Value>,
    /// DestinationConfig.OnFailure.Destination ARN — receives failed batches.
    pub destination_on_failure: Option<String>,
    pub function_response_types: Vec<String>,
    /// Last poll result, surfaced via Get/List for diagnostics.
    /// "OK", "PROBLEM: <message>", or "No records processed".
    pub last_processing_result: String,
    /// Per-shard iterator state for Kinesis/DDB-stream pollers so we don't
    /// re-deliver records on every tick. Keyed by shard id.
    pub shard_iterators: HashMap<String, String>,
    /// Tags attached via `TagResource` against the ESM ARN.
    pub tags: HashMap<String, String>,
    /// `ScalingConfig.MaximumConcurrency`. Caps the number of
    /// concurrent invocations Lambda will start for this event source
    /// (SQS-only in AWS). AWS allows 2..=1000; `None` means no cap.
    pub scaling_max_concurrency: Option<u32>,
}
428
/// Function-only snapshot without account/region information.
///
/// NOTE(review): not referenced in this file — `Snapshottable for
/// LambdaState` uses `LambdaRegionSnapshot`. Confirm whether anything else
/// still (de)serializes this type before removing it.
#[derive(Debug, Serialize, Deserialize)]
pub struct LambdaStateSnapshot {
    pub functions: Vec<FunctionSnapshot>,
}
433
/// Serialized form of one account/region's `LambdaState`, produced by
/// `to_snapshot` and consumed by `from_snapshot`. Only functions are
/// included.
#[derive(Debug, Serialize, Deserialize)]
pub struct LambdaRegionSnapshot {
    pub account_id: String,
    pub region: String,
    pub functions: Vec<FunctionSnapshot>,
}
440
/// Serialized form of a `LambdaFunction`.
///
/// Code bodies, invocation records, reserved concurrency and provisioned
/// concurrency have no field here, so they are not persisted. Fields marked
/// `#[serde(default)]` (or a `default = "..."` function) deserialize to that
/// default when the key is absent.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionSnapshot {
    /// Written by `to_snapshot`, but `from_snapshot` ignores it and uses the
    /// region-level `LambdaRegionSnapshot::account_id` instead.
    pub account_id: String,
    /// Written by `to_snapshot`, but `from_snapshot` ignores it and uses the
    /// region-level `LambdaRegionSnapshot::region` instead.
    pub region: String,
    pub name: String,
    pub arn: String,
    pub runtime: Option<String>,
    pub role: String,
    pub handler: Option<String>,
    pub description: String,
    pub timeout: u32,
    pub memory_size: u32,
    pub code_sha256: String,
    pub code_size: u64,
    pub environment: HashMap<String, String>,
    pub version: String,
    pub versions: Vec<FunctionVersionSnapshot>,
    pub aliases: HashMap<String, AliasSnapshot>,
    pub last_modified: String,
    pub state: String,
    #[serde(default)]
    pub policy_statements: HashMap<String, serde_json::Value>,
    #[serde(default)]
    pub tags: HashMap<String, String>,
    #[serde(default = "default_architectures")]
    pub architectures: Vec<String>,
    #[serde(default = "default_ephemeral_storage_size")]
    pub ephemeral_storage_size: u32,
    #[serde(default = "default_package_type")]
    pub package_type: String,
    #[serde(default)]
    pub layers: Vec<String>,
    #[serde(default)]
    pub vpc_config: Option<serde_json::Value>,
    #[serde(default)]
    pub dead_letter_config: Option<serde_json::Value>,
    #[serde(default)]
    pub tracing_config: Option<serde_json::Value>,
    #[serde(default)]
    pub kms_key_arn: Option<String>,
    #[serde(default)]
    pub file_system_configs: Option<serde_json::Value>,
    #[serde(default)]
    pub logging_config: Option<serde_json::Value>,
    #[serde(default)]
    pub snap_start: Option<serde_json::Value>,
    #[serde(default)]
    pub image_config: Option<serde_json::Value>,
    /// `to_snapshot` always writes `Some`; `None` (key absent) is restored
    /// as "Terminate".
    #[serde(default)]
    pub recursive_loop: Option<String>,
}
492
/// Serde default for `FunctionSnapshot::architectures`: a single `x86_64`.
fn default_architectures() -> Vec<String> {
    vec![String::from("x86_64")]
}
496
/// Serde default for `FunctionSnapshot::ephemeral_storage_size` (MiB).
fn default_ephemeral_storage_size() -> u32 {
    const DEFAULT_TMP_MIB: u32 = 512;
    DEFAULT_TMP_MIB
}
500
/// Serde default for `FunctionSnapshot::package_type`: a zip-archive function.
fn default_package_type() -> String {
    String::from("Zip")
}
504
/// Serialized form of a `FunctionVersion`. The code body is not included;
/// restored versions get `code: None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionVersionSnapshot {
    pub version: String,
    pub description: String,
    pub code_sha256: String,
    pub code_size: u64,
    pub last_modified: String,
}
513
/// Serialized form of an `Alias`.
#[derive(Debug, Serialize, Deserialize)]
pub struct AliasSnapshot {
    pub name: String,
    pub arn: String,
    pub function_version: String,
    pub description: String,
    /// Deserializes to an empty map (no traffic shifting) when absent.
    #[serde(default)]
    pub routing_config: HashMap<String, f64>,
}
523
/// One published version of a layer, held in `LambdaState::layers`.
/// Not persisted in snapshots.
#[derive(Debug, Clone)]
pub struct LayerVersion {
    pub layer_name: String,
    pub layer_arn: String,
    /// ARN including the version suffix (presumably `layer_arn:version` — confirm).
    pub version_arn: String,
    pub version: u64,
    pub description: String,
    pub compatible_runtimes: Vec<String>,
    pub code_sha256: String,
    pub code_size: u64,
    /// Raw zip bytes stored for future execution support.
    #[allow(dead_code)]
    pub code_data: Option<Vec<u8>>,
    pub created_date: String,
    /// Tags attached via `TagResource` against the layer-version ARN.
    pub tags: HashMap<String, String>,
}