use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use base64::Engine;
use chrono::{DateTime, Utc};
use serde_json::{json, Value};
use fakecloud_core::delivery::{KmsHook, LambdaDelivery};
use fakecloud_lambda::filter::FilterSet;
use fakecloud_lambda::{LambdaInvocation, SharedLambdaState};
use fakecloud_sqs::SharedSqsState;
#[derive(Clone)]
struct Mapping {
uuid: String,
function_arn: String,
queue_arn: String,
batch_size: i64,
filter: FilterSet,
report_batch_item_failures: bool,
max_batching_window_seconds: i64,
}
#[derive(Default)]
struct BatchingState {
window_started_at: HashMap<String, DateTime<Utc>>,
}
pub struct SqsLambdaPoller {
sqs_state: SharedSqsState,
lambda_state: SharedLambdaState,
lambda_delivery: Option<Arc<dyn LambdaDelivery>>,
kms_hook: Option<Arc<dyn KmsHook>>,
}
impl SqsLambdaPoller {
pub fn new(sqs_state: SharedSqsState, lambda_state: SharedLambdaState) -> Self {
Self {
sqs_state,
lambda_state,
lambda_delivery: None,
kms_hook: None,
}
}
pub fn with_lambda_delivery(mut self, delivery: Arc<dyn LambdaDelivery>) -> Self {
self.lambda_delivery = Some(delivery);
self
}
pub fn with_kms_hook(mut self, hook: Arc<dyn KmsHook>) -> Self {
self.kms_hook = Some(hook);
self
}
pub async fn run(self) {
let mut interval = tokio::time::interval(Duration::from_millis(500));
let mut batching = BatchingState::default();
loop {
interval.tick().await;
self.poll(&mut batching).await;
}
}
async fn poll(&self, batching: &mut BatchingState) {
let mappings: Vec<Mapping> = {
let lambda_accounts = self.lambda_state.read();
lambda_accounts
.iter()
.flat_map(|(_, lambda)| {
lambda
.event_source_mappings
.values()
.filter(|m| m.enabled && m.event_source_arn.contains(":sqs:"))
.map(|m| Mapping {
uuid: m.uuid.clone(),
function_arn: m.function_arn.clone(),
queue_arn: m.event_source_arn.clone(),
batch_size: m.batch_size,
filter: FilterSet::from_strings(m.filter_patterns.iter()),
report_batch_item_failures: m
.function_response_types
.iter()
.any(|s| s == "ReportBatchItemFailures"),
max_batching_window_seconds: m
.maximum_batching_window_in_seconds
.unwrap_or(0),
})
.collect::<Vec<_>>()
})
.collect()
};
if mappings.is_empty() {
return;
}
let now = Utc::now();
for mapping in mappings {
self.process_mapping(&mapping, batching, now).await;
}
}
async fn process_mapping(
&self,
mapping: &Mapping,
batching: &mut BatchingState,
now: DateTime<Utc>,
) {
let (messages, account_id, queue_encrypted) = {
let mut sqs_mas = self.sqs_state.write();
let default_acct = sqs_mas.default_account_id().to_string();
let acct = mapping
.queue_arn
.split(':')
.nth(4)
.unwrap_or(&default_acct)
.to_string();
let sqs = sqs_mas.get_or_create(&acct);
let queue = sqs.queues.values_mut().find(|q| q.arn == mapping.queue_arn);
let queue = match queue {
Some(q) => q,
None => return,
};
let visibility_timeout: i64 = queue
.attributes
.get("VisibilityTimeout")
.and_then(|s| s.parse().ok())
.unwrap_or(30);
let visible_at = now + chrono::Duration::seconds(visibility_timeout);
let encrypted = queue
.attributes
.get("KmsMasterKeyId")
.map(|k| !k.is_empty())
.unwrap_or(false)
|| queue
.attributes
.get("SqsManagedSseEnabled")
.map(String::as_str)
== Some("true");
let mut batch = Vec::new();
let limit = mapping.batch_size.min(10) as usize;
for msg in queue.messages.iter_mut() {
if batch.len() >= limit {
break;
}
if let Some(vis) = msg.visible_at {
if vis > now {
continue;
}
}
msg.visible_at = Some(visible_at);
batch.push(msg.clone());
}
(batch, acct, encrypted)
};
if messages.is_empty() {
batching.window_started_at.remove(&mapping.uuid);
return;
}
let limit = mapping.batch_size.min(10) as usize;
if mapping.max_batching_window_seconds > 0 && messages.len() < limit {
let window_start = batching
.window_started_at
.entry(mapping.uuid.clone())
.or_insert(now);
let elapsed = now.signed_duration_since(*window_start).num_seconds();
if elapsed < mapping.max_batching_window_seconds {
return;
}
}
batching.window_started_at.remove(&mapping.uuid);
let messages = if let (true, Some(hook)) = (queue_encrypted, self.kms_hook.as_ref()) {
messages
.into_iter()
.map(|mut msg| {
if !looks_like_fakecloud_envelope(&msg.body) {
return msg;
}
let mut ctx = HashMap::new();
ctx.insert("aws:sqs:arn".to_string(), mapping.queue_arn.clone());
match hook.decrypt(&account_id, &msg.body, "sqs.amazonaws.com", ctx) {
Ok(bytes) => {
msg.body = String::from_utf8_lossy(&bytes).to_string();
}
Err(err) => {
tracing::warn!(
queue = %mapping.queue_arn,
error = %err,
"SQS->Lambda poller: KMS decrypt failed; delivering opaque body"
);
}
}
msg
})
.collect::<Vec<_>>()
} else {
messages
};
let records: Vec<Value> = messages
.iter()
.map(|msg| {
json!({
"messageId": msg.message_id,
"receiptHandle": msg.receipt_handle,
"body": msg.body,
"attributes": {
"ApproximateReceiveCount": msg.receive_count.to_string(),
"SentTimestamp": msg.sent_timestamp.to_string(),
},
"md5OfBody": msg.md5_of_body,
"eventSource": "aws:sqs",
"eventSourceARN": mapping.queue_arn,
})
})
.collect();
let (matched_records, dropped_ids): (Vec<Value>, Vec<String>) = if mapping.filter.is_empty()
{
(records, Vec::new())
} else {
let mut matched = Vec::new();
let mut dropped = Vec::new();
for (rec, msg) in records.into_iter().zip(messages.iter()) {
if mapping.filter.matches(&rec) {
matched.push(rec);
} else {
dropped.push(msg.message_id.clone());
}
}
(matched, dropped)
};
if !dropped_ids.is_empty() {
let mut sqs_mas = self.sqs_state.write();
let default_acct = sqs_mas.default_account_id().to_string();
let acct = mapping.queue_arn.split(':').nth(4).unwrap_or(&default_acct);
let sqs = sqs_mas.get_or_create(acct);
if let Some(queue) = sqs.queues.values_mut().find(|q| q.arn == mapping.queue_arn) {
queue
.messages
.retain(|m| !dropped_ids.contains(&m.message_id));
}
}
if matched_records.is_empty() {
return;
}
let payload = json!({ "Records": &matched_records }).to_string();
tracing::debug!(
function_arn = %mapping.function_arn,
queue_arn = %mapping.queue_arn,
message_count = matched_records.len(),
dropped_by_filter = dropped_ids.len(),
"SQS->Lambda: delivering messages to function"
);
let invoke_result = if let Some(ref delivery) = self.lambda_delivery {
Some(
delivery
.invoke_lambda(&mapping.function_arn, &payload)
.await,
)
} else {
None
};
let matched_msg_ids: Vec<String> = matched_records
.iter()
.filter_map(|r| {
r.get("messageId")
.and_then(|v| v.as_str())
.map(String::from)
})
.collect();
let (acked_ids, failed_ids) = match &invoke_result {
Some(Ok(body)) if mapping.report_batch_item_failures => {
split_batch_failures(body, &matched_msg_ids)
}
Some(Ok(_)) => (matched_msg_ids.clone(), Vec::new()),
Some(Err(err)) => {
tracing::warn!(
function_arn = %mapping.function_arn,
error = %err,
"SQS->Lambda: function invocation failed; messages will be retried"
);
(Vec::new(), matched_msg_ids.clone())
}
None => (matched_msg_ids.clone(), Vec::new()),
};
if !acked_ids.is_empty() || !failed_ids.is_empty() {
let mut sqs_mas = self.sqs_state.write();
let default_acct = sqs_mas.default_account_id().to_string();
let acct = mapping.queue_arn.split(':').nth(4).unwrap_or(&default_acct);
let sqs = sqs_mas.get_or_create(acct);
if let Some(queue) = sqs.queues.values_mut().find(|q| q.arn == mapping.queue_arn) {
queue
.messages
.retain(|m| !acked_ids.contains(&m.message_id));
for msg in queue.messages.iter_mut() {
if failed_ids.contains(&msg.message_id) {
msg.receive_count = msg.receive_count.saturating_add(1);
msg.visible_at = None;
}
}
}
}
let fn_account = mapping.function_arn.split(':').nth(4).unwrap_or("");
let mut lambda_accounts = self.lambda_state.write();
let lambda = lambda_accounts.get_or_create(fn_account);
lambda.invocations.push(LambdaInvocation {
function_arn: mapping.function_arn.clone(),
payload,
timestamp: now,
source: "aws:sqs".to_string(),
});
}
}
fn looks_like_fakecloud_envelope(body: &str) -> bool {
if body.starts_with("fakecloud-kms:") {
return true;
}
match base64::engine::general_purpose::STANDARD.decode(body) {
Ok(bytes) => bytes.starts_with(&[0x01, 0x02, 0x02, 0x00]),
Err(_) => false,
}
}
fn split_batch_failures(body: &[u8], batch_ids: &[String]) -> (Vec<String>, Vec<String>) {
let parsed: Value = match serde_json::from_slice(body) {
Ok(v) => v,
Err(_) => return (batch_ids.to_vec(), Vec::new()),
};
let raw_failures: Vec<String> = parsed
.get("batchItemFailures")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|f| {
f.get("itemIdentifier")
.and_then(|v| v.as_str())
.map(String::from)
})
.collect()
})
.unwrap_or_default();
let failures: Vec<String> = raw_failures
.into_iter()
.filter(|id| batch_ids.contains(id))
.collect();
if failures.is_empty() {
return (batch_ids.to_vec(), Vec::new());
}
let acked = batch_ids
.iter()
.filter(|id| !failures.contains(id))
.cloned()
.collect();
(acked, failures)
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::{BTreeMap, VecDeque};
use std::pin::Pin;
use std::sync::Mutex;
use fakecloud_core::multi_account::MultiAccountState;
use fakecloud_lambda::{EventSourceMapping, LambdaState};
use fakecloud_sqs::{SqsMessage, SqsQueue, SqsState};
use parking_lot::RwLock;
#[test]
fn batch_failures_parse() {
let body = br#"{"batchItemFailures":[{"itemIdentifier":"a"},{"itemIdentifier":"c"}]}"#;
let (acked, failed) =
split_batch_failures(body, &["a".into(), "b".into(), "c".into(), "d".into()]);
assert_eq!(acked, vec!["b".to_string(), "d".to_string()]);
assert_eq!(failed, vec!["a".to_string(), "c".to_string()]);
}
#[test]
fn batch_failures_empty_list_acks_all() {
let body = br#"{"batchItemFailures":[]}"#;
let (acked, failed) = split_batch_failures(body, &["a".into(), "b".into()]);
assert_eq!(acked, vec!["a".to_string(), "b".to_string()]);
assert!(failed.is_empty());
}
#[test]
fn batch_failures_no_field_acks_all() {
let body = br#"{"ok":true}"#;
let (acked, failed) = split_batch_failures(body, &["a".into(), "b".into()]);
assert_eq!(acked, vec!["a".to_string(), "b".to_string()]);
assert!(failed.is_empty());
}
struct RecordingLambda {
invocations: Arc<Mutex<Vec<String>>>,
}
impl LambdaDelivery for RecordingLambda {
fn invoke_lambda(
&self,
_function_arn: &str,
payload: &str,
) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>, String>> + Send>> {
self.invocations.lock().unwrap().push(payload.to_string());
Box::pin(async move { Ok(b"{}".to_vec()) })
}
}
const ACCOUNT: &str = "123456789012";
const REGION: &str = "us-east-1";
fn build_states() -> (
super::SqsLambdaPoller,
super::SharedSqsState,
super::SharedLambdaState,
) {
let queue_arn = format!("arn:aws:sqs:{REGION}:{ACCOUNT}:k2-poll");
let queue_url = format!("http://localhost:4566/{ACCOUNT}/k2-poll");
let mut attrs = BTreeMap::new();
attrs.insert("VisibilityTimeout".to_string(), "30".to_string());
let queue = SqsQueue {
queue_name: "k2-poll".to_string(),
queue_url: queue_url.clone(),
arn: queue_arn.clone(),
created_at: Utc::now(),
messages: VecDeque::from(vec![SqsMessage {
message_id: "msg-1".to_string(),
receipt_handle: None,
md5_of_body: "d41d8cd98f00b204e9800998ecf8427e".to_string(),
body: "hello".to_string(),
sent_timestamp: 0,
attributes: BTreeMap::new(),
message_attributes: BTreeMap::new(),
visible_at: None,
receive_count: 0,
message_group_id: None,
message_dedup_id: None,
created_at: Utc::now(),
sequence_number: None,
}]),
inflight: Vec::new(),
attributes: attrs,
is_fifo: false,
dedup_cache: BTreeMap::new(),
redrive_policy: None,
tags: BTreeMap::new(),
next_sequence_number: 0,
permission_labels: Vec::new(),
receipt_handle_map: BTreeMap::new(),
};
let mut sqs: MultiAccountState<SqsState> =
MultiAccountState::new(ACCOUNT, REGION, "http://localhost:4566");
{
let s = sqs.default_mut();
s.name_to_url.insert("k2-poll".to_string(), queue_url);
s.queues.insert(queue.queue_url.clone(), queue);
}
let sqs_state = Arc::new(RwLock::new(sqs));
let mut lambda: MultiAccountState<LambdaState> =
MultiAccountState::new(ACCOUNT, REGION, "http://localhost:4566");
{
let l = lambda.default_mut();
let mapping = EventSourceMapping {
uuid: "esm-1".to_string(),
function_arn: format!("arn:aws:lambda:{REGION}:{ACCOUNT}:function:k2-fn"),
event_source_arn: queue_arn,
batch_size: 10,
enabled: true,
state: "Enabled".to_string(),
last_modified: Utc::now(),
filter_patterns: Vec::new(),
maximum_batching_window_in_seconds: None,
starting_position: None,
starting_position_timestamp: None,
parallelization_factor: None,
function_response_types: Vec::new(),
kms_key_arn: None,
metrics_config: None,
destination_config: None,
maximum_retry_attempts: None,
maximum_record_age_in_seconds: None,
bisect_batch_on_function_error: None,
tumbling_window_in_seconds: None,
topics: Vec::new(),
queues: Vec::new(),
};
l.event_source_mappings
.insert(mapping.uuid.clone(), mapping);
}
let lambda_state = Arc::new(RwLock::new(lambda));
let poller = super::SqsLambdaPoller::new(sqs_state.clone(), lambda_state.clone());
(poller, sqs_state, lambda_state)
}
#[tokio::test]
async fn poller_marks_picked_messages_invisible_during_invoke() {
let invocations = Arc::new(Mutex::new(Vec::<String>::new()));
let delivery: Arc<dyn LambdaDelivery> = Arc::new(RecordingLambda {
invocations: invocations.clone(),
});
let (poller, sqs_state, _lambda_state) = build_states();
let poller = poller.with_lambda_delivery(delivery);
let mut batching = BatchingState::default();
poller.poll(&mut batching).await;
assert_eq!(
invocations.lock().unwrap().len(),
1,
"first poll should invoke once"
);
{
let mut sqs = sqs_state.write();
let s = sqs.default_mut();
let queue = s.queues.values_mut().next().unwrap();
queue.messages.push_back(SqsMessage {
message_id: "msg-2".to_string(),
receipt_handle: None,
md5_of_body: "d41d8cd98f00b204e9800998ecf8427e".to_string(),
body: "second".to_string(),
sent_timestamp: 0,
attributes: BTreeMap::new(),
message_attributes: BTreeMap::new(),
visible_at: None,
receive_count: 0,
message_group_id: None,
message_dedup_id: None,
created_at: Utc::now(),
sequence_number: None,
});
}
poller.poll(&mut batching).await;
assert_eq!(
invocations.lock().unwrap().len(),
2,
"second poll should invoke once more"
);
let sqs = sqs_state.read();
let queue = sqs.default_ref().queues.values().next().unwrap();
assert!(
queue.messages.is_empty(),
"successful invoke must delete acked messages, found: {:?}",
queue
.messages
.iter()
.map(|m| &m.message_id)
.collect::<Vec<_>>()
);
}
#[tokio::test]
async fn poller_skips_message_during_visibility_window() {
let (poller, sqs_state, _lambda_state) = build_states();
let mapping = Mapping {
uuid: "esm-1".to_string(),
function_arn: format!("arn:aws:lambda:{REGION}:{ACCOUNT}:function:k2-fn"),
queue_arn: format!("arn:aws:sqs:{REGION}:{ACCOUNT}:k2-poll"),
batch_size: 10,
filter: fakecloud_lambda::filter::FilterSet::from_strings(std::iter::empty::<&str>()),
report_batch_item_failures: false,
max_batching_window_seconds: 0,
};
{
let mut sqs = sqs_state.write();
let s = sqs.default_mut();
let queue = s.queues.values_mut().next().unwrap();
queue.messages[0].visible_at = Some(Utc::now() + chrono::Duration::seconds(60));
}
let invocations = Arc::new(Mutex::new(Vec::<String>::new()));
let delivery: Arc<dyn LambdaDelivery> = Arc::new(RecordingLambda {
invocations: invocations.clone(),
});
let poller = poller.with_lambda_delivery(delivery);
let mut batching = BatchingState::default();
poller
.process_mapping(&mapping, &mut batching, Utc::now())
.await;
assert_eq!(
invocations.lock().unwrap().len(),
0,
"message hidden by visibility timeout must not be invoked"
);
let sqs = sqs_state.read();
let queue = sqs.default_ref().queues.values().next().unwrap();
assert_eq!(queue.messages.len(), 1);
}
}