fakecloud 0.13.1

Local AWS cloud emulator — free, open-source LocalStack alternative
//! Kinesis -> Lambda event source mapping poller.
//!
//! Honors:
//! - `FilterCriteria` — non-matching records are dropped (advanced past).
//! - `StartingPosition` — `TRIM_HORIZON` (default), `LATEST`, or
//!   `AT_TIMESTAMP` paired with `StartingPositionTimestamp` to seed
//!   the per-shard checkpoint on first poll.

use std::sync::Arc;
use std::time::Duration;

use base64::Engine;
use chrono::Utc;
use serde_json::{json, Value};

use fakecloud_core::delivery::LambdaDelivery;
use fakecloud_kinesis::state::SharedKinesisState;
use fakecloud_lambda::filter::FilterSet;
use fakecloud_lambda::state::{LambdaInvocation, SharedLambdaState};

/// Snapshot of one enabled Kinesis event source mapping, copied out of the
/// Lambda state so a poll pass can work on it after that lock is released.
#[derive(Clone)]
struct Mapping {
    /// Event source mapping UUID; also the prefix of the per-shard
    /// checkpoint keys (`"{uuid}:{shard_id}"`).
    uuid: String,
    /// ARN of the Kinesis stream the mapping reads from.
    stream_arn: String,
    /// ARN of the Lambda function that receives the batches.
    function_arn: String,
    /// `FilterCriteria` patterns; an empty set lets every record through.
    filter: FilterSet,
    /// Upper bound on records taken from one shard per poll (values below 1
    /// are treated as 1).
    batch_size: i64,
    /// `TRIM_HORIZON`, `LATEST` or `AT_TIMESTAMP`; `None` (or any other
    /// value) behaves like `TRIM_HORIZON`. Only consulted when a shard has
    /// no checkpoint yet.
    starting_position: Option<String>,
    /// Seconds since the Unix epoch, used with `AT_TIMESTAMP`.
    starting_position_timestamp: Option<f64>,
}

/// Background task that feeds records from Kinesis streams to Lambda
/// functions according to the enabled Kinesis event source mappings.
pub struct KinesisLambdaPoller {
    /// Lambda state: source of the event source mappings, and the place
    /// invocations are recorded when no real delivery is configured.
    lambda_state: SharedLambdaState,
    /// Kinesis state: stream shards, their records, and the per-mapping
    /// shard checkpoints.
    kinesis_state: SharedKinesisState,
    /// Optional real invoker; when `None`, batches are only recorded.
    lambda_delivery: Option<Arc<dyn LambdaDelivery>>,
}

impl KinesisLambdaPoller {
    /// Build a poller over the shared Kinesis and Lambda state.
    ///
    /// Without a [`LambdaDelivery`] attached (see
    /// [`KinesisLambdaPoller::with_lambda_delivery`]), every batch is recorded
    /// as a [`LambdaInvocation`] on the target function's account instead of
    /// being invoked.
    pub fn new(kinesis_state: SharedKinesisState, lambda_state: SharedLambdaState) -> Self {
        Self {
            kinesis_state,
            lambda_state,
            lambda_delivery: None,
        }
    }

    /// Route batches through `delivery` instead of only recording them.
    pub fn with_lambda_delivery(mut self, delivery: Arc<dyn LambdaDelivery>) -> Self {
        self.lambda_delivery = Some(delivery);
        self
    }

    /// Poll every enabled Kinesis event source mapping every 500 ms, forever.
    pub async fn run(self) {
        let mut interval = tokio::time::interval(Duration::from_millis(500));
        // The default `Burst` behavior fires all missed ticks back-to-back
        // after a slow poll (e.g. a long-running Lambda invoke); `Delay`
        // keeps consecutive polls at least one period apart.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        loop {
            interval.tick().await;
            self.poll().await;
        }
    }

    /// Run one pass over every enabled Kinesis mapping.
    async fn poll(&self) {
        // Snapshot first so the Lambda lock is not held across `.await`.
        let mappings = self.enabled_kinesis_mappings();
        for mapping in &mappings {
            self.process_mapping(mapping).await;
        }
    }

    /// Copy out every enabled event source mapping, across all accounts,
    /// whose source ARN contains `:kinesis:`.
    fn enabled_kinesis_mappings(&self) -> Vec<Mapping> {
        let lambda_accounts = self.lambda_state.read();
        lambda_accounts
            .iter()
            .flat_map(|(_, lambda)| {
                lambda
                    .event_source_mappings
                    .values()
                    .filter(|m| m.enabled && m.event_source_arn.contains(":kinesis:"))
                    .map(|m| Mapping {
                        uuid: m.uuid.clone(),
                        function_arn: m.function_arn.clone(),
                        stream_arn: m.event_source_arn.clone(),
                        batch_size: m.batch_size,
                        filter: FilterSet::from_strings(m.filter_patterns.iter()),
                        starting_position: m.starting_position.clone(),
                        starting_position_timestamp: m.starting_position_timestamp,
                    })
                    .collect::<Vec<_>>()
            })
            .collect()
    }

    /// Deliver at most one batch per shard of `mapping`'s stream.
    ///
    /// Under the Kinesis lock, missing per-shard checkpoints are seeded from
    /// `StartingPosition`, and up to `batch_size` records past each
    /// checkpoint are snapshotted. The lock is released before invoking.
    ///
    /// A shard's checkpoint advances only when its batch is delivered or
    /// entirely filtered out. A failed invoke is therefore retried on the
    /// next poll (at-least-once delivery). Returns silently when the
    /// stream's account or the stream itself does not exist.
    async fn process_mapping(&self, mapping: &Mapping) {
        let deliveries = {
            let mut kinesis_accounts = self.kinesis_state.write();
            let account_id = Self::arn_field(&mapping.stream_arn, 4).unwrap_or("");
            let Some(kinesis) = kinesis_accounts.get_mut(account_id) else {
                return;
            };
            let Some(stream_name) = kinesis
                .streams
                .iter()
                .find(|(_, s)| s.stream_arn == mapping.stream_arn)
                .map(|(name, _)| name.clone())
            else {
                return;
            };

            // Seed checkpoints only for shards that have none yet; later
            // polls simply continue from the stored position.
            let init_pairs: Vec<(String, usize)> = {
                let stream = kinesis
                    .streams
                    .get(&stream_name)
                    .expect("stream exists, just looked up");
                stream
                    .shards
                    .iter()
                    .filter_map(|shard| {
                        let key = format!("{}:{}", mapping.uuid, shard.shard_id);
                        if kinesis.lambda_checkpoints.contains_key(&key) {
                            return None;
                        }
                        let init = match mapping
                            .starting_position
                            .as_deref()
                            .unwrap_or("TRIM_HORIZON")
                        {
                            "LATEST" => shard.records.len(),
                            "AT_TIMESTAMP" => {
                                // Compare in the same fractional-second form
                                // the event reports as
                                // `approximateArrivalTimestamp`. Truncating
                                // both sides to whole seconds would start up
                                // to a second before the requested time.
                                let target = mapping.starting_position_timestamp.unwrap_or(0.0);
                                shard
                                    .records
                                    .iter()
                                    .position(|r| {
                                        r.approximate_arrival_timestamp.timestamp_millis() as f64
                                            / 1000.0
                                            >= target
                                    })
                                    .unwrap_or(shard.records.len())
                            }
                            _ => 0, // TRIM_HORIZON
                        };
                        Some((shard.shard_id.clone(), init))
                    })
                    .collect()
            };
            for (shard_id, init) in init_pairs {
                kinesis.set_lambda_checkpoint(&mapping.uuid, &shard_id, init);
            }

            let stream = kinesis
                .streams
                .get(&stream_name)
                .expect("stream exists, just looked up");
            let limit = mapping.batch_size.max(1) as usize;
            stream
                .shards
                .iter()
                .filter_map(|shard| {
                    let start = kinesis.lambda_checkpoint(&mapping.uuid, &shard.shard_id);
                    if start >= shard.records.len() {
                        return None;
                    }
                    let end = shard.records.len().min(start.saturating_add(limit));
                    let records = shard.records[start..end].to_vec();
                    Some((shard.shard_id.clone(), end, records))
                })
                .collect::<Vec<_>>()
        };

        // Field 3 of `arn:partition:kinesis:region:account:stream/name`.
        let region = Self::arn_field(&mapping.stream_arn, 3).unwrap_or("us-east-1");

        for (shard_id, end, records) in deliveries {
            // Build each record's event JSON and keep only those matching
            // FilterCriteria (an empty filter matches everything). Dropped
            // records are still consumed: the checkpoint moves past them.
            // NOTE(review): they still occupy batch slots here, whereas AWS
            // says filtered-out records do not count toward batch size.
            let matched: Vec<Value> = records
                .iter()
                .map(|record| {
                    json!({
                        "awsRegion": region,
                        "eventID": format!("{}:{}", shard_id, record.sequence_number),
                        "eventName": "aws:kinesis:record",
                        "eventSource": "aws:kinesis",
                        "eventSourceARN": mapping.stream_arn,
                        "eventVersion": "1.0",
                        "invokeIdentityArn": "arn:aws:iam::123456789012:role/lambda-role",
                        "kinesis": {
                            "approximateArrivalTimestamp": record.approximate_arrival_timestamp.timestamp_millis() as f64 / 1000.0,
                            "data": base64::engine::general_purpose::STANDARD.encode(&record.data),
                            "kinesisSchemaVersion": "1.0",
                            "partitionKey": record.partition_key,
                            "sequenceNumber": record.sequence_number,
                        }
                    })
                })
                .filter(|event| mapping.filter.is_empty() || mapping.filter.matches(event))
                .collect();

            // Everything was filtered out: consume the records without
            // invoking, since filtered-out records are never retried.
            if matched.is_empty() {
                self.advance_checkpoint(mapping, &shard_id, end);
                continue;
            }

            let payload = json!({ "Records": matched }).to_string();

            match &self.lambda_delivery {
                Some(delivery) => {
                    if let Err(error) = delivery
                        .invoke_lambda(&mapping.function_arn, &payload)
                        .await
                    {
                        tracing::warn!(
                            function_arn = %mapping.function_arn,
                            stream_arn = %mapping.stream_arn,
                            shard_id = %shard_id,
                            error = %error,
                            "Kinesis->Lambda: function invocation failed"
                        );
                        // Keep the checkpoint so the next poll retries this
                        // batch (at-least-once delivery).
                        continue;
                    }
                    self.advance_checkpoint(mapping, &shard_id, end);
                }
                None => {
                    self.advance_checkpoint(mapping, &shard_id, end);
                    self.record_invocation(mapping, payload);
                }
            }
        }
    }

    /// Move `mapping`'s checkpoint for `shard_id` to record index `position`
    /// in the stream account's Kinesis state.
    fn advance_checkpoint(&self, mapping: &Mapping, shard_id: &str, position: usize) {
        let account_id = Self::arn_field(&mapping.stream_arn, 4).unwrap_or("");
        let mut kinesis_accounts = self.kinesis_state.write();
        let kinesis = kinesis_accounts.get_or_create(account_id);
        kinesis.set_lambda_checkpoint(&mapping.uuid, shard_id, position);
    }

    /// Record a delivered batch as a [`LambdaInvocation`] on the function's
    /// account. Used when no real [`LambdaDelivery`] is configured.
    fn record_invocation(&self, mapping: &Mapping, payload: String) {
        let fn_account = Self::arn_field(&mapping.function_arn, 4).unwrap_or("");
        let mut lambda_accounts = self.lambda_state.write();
        let lambda = lambda_accounts.get_or_create(fn_account);
        lambda.invocations.push(LambdaInvocation {
            function_arn: mapping.function_arn.clone(),
            payload,
            timestamp: Utc::now(),
            source: "aws:kinesis".to_string(),
        });
    }

    /// Return the `index`-th `:`-separated field of `arn`
    /// (`arn:partition:service:region:account:resource`), or `None` when
    /// that field is missing or empty.
    fn arn_field(arn: &str, index: usize) -> Option<&str> {
        arn.split(':').nth(index).filter(|field| !field.is_empty())
    }
}