cobalt_aws/lambda.rs
1//! A wrapper around the [lambda_runtime](https://github.com/awslabs/aws-lambda-rust-runtime) crate.
2//!
//! This module provides helpers that make it easier to write Lambda functions which consume messages from
//! various events, such as an SQS queue configured with an [event source mapping](https://docs.aws.amazon.com/lambda/latest/dg/invocation-eventsourcemapping.html),
//! or a [step function](https://docs.aws.amazon.com/step-functions/latest/dg/welcome.html).
6//! It also provides mechanisms for running message handlers locally, without the full AWS Lambda environment.
7
8use anyhow::{Context as _, Result};
9use async_trait::async_trait;
10pub use aws_lambda_events::event::s3::S3Event;
11/// Re-export from [aws_lambda_events::event::sqs::SqsEvent], with [RunnableEventType] implemented.
12pub use aws_lambda_events::event::sqs::SqsEvent;
13use clap::Parser;
14use futures::stream::{self, StreamExt, TryStreamExt};
15use futures::FutureExt;
16use lambda_runtime::{service_fn, LambdaEvent};
17use serde::Serialize;
18use std::ffi::OsString;
19use std::future::Future;
20use std::iter::empty;
21use std::sync::Arc;
22use tracing_subscriber::filter::EnvFilter;
23
24/// Re-export of [lambda_runtime::Error](https://docs.rs/lambda_runtime/latest/lambda_runtime/type.Error.html).
25///
26// We provide this re-export so that the user doesn't need to have lambda_runtime as a direct dependency.
27pub use lambda_runtime::Error;
28
29/// This struct is used to attempt to parse the `AWS_LAMBDA_FUNCTION_NAME` environment variable.
30///
31/// We assume that if this variable is present then we're running in a Lambda function.
32/// https://docs.aws.amazon.com/lambda/latest/dg/configuration-envvars.html
33#[derive(Debug, Parser)]
34struct CheckLambda {
35 #[arg(env)]
36 aws_lambda_function_name: Option<String>,
37}
38
39/// Determine whether the code is being executed within an AWS Lambda.
40///
41/// This function can be used to write binaries that are able to run both locally
42/// or as a Lambda function.
43pub fn running_on_lambda() -> Result<bool> {
44 let check_lambda = CheckLambda::try_parse_from(empty::<OsString>())
45 .context("An error occurred while parsing environment variables for message context.")?;
46 Ok(check_lambda.aws_lambda_function_name.is_some())
47}
48
49/// A trait of the `Context` type, required when using [run_message_handler] to execute the message handler.
50///
51/// All `Context` types must implement the implement the [LambdaContext::from_env] method for their corresponding `Env` type
52/// in order to use the `Context` with [run_message_handler].
53///
54/// When implementing `LambdaContext` you must also specify the `EventType` that the Lambda expects to receive, e.g. [SqsEvent].
55#[async_trait]
56pub trait LambdaContext<Env, EventType>: Sized {
57 /// # Example
58 ///
59 /// ```no_run
60 /// use anyhow::Result;
61 /// use cobalt_aws::lambda::{LambdaContext, SqsEvent};
62 ///
63 /// # use async_trait::async_trait;
64 /// # #[derive(Debug)]
65 /// # pub struct Env {
66 /// # pub greeting: String,
67 /// # }
68 /// #
69 /// /// Shared context we build up before invoking the lambda runner.
70 /// #[derive(Debug)]
71 /// pub struct Context {
72 /// pub greeting: String,
73 /// }
74 ///
75 /// #[async_trait]
76 /// impl LambdaContext<Env, SqsEvent> for Context {
77 /// /// Initialise a shared context object from which will be
78 /// /// passed to all instances of the message handler.
79 /// async fn from_env(env: &Env) -> Result<Context> {
80 /// Ok(Context {
81 /// greeting: env.greeting.clone(),
82 /// })
83 /// }
84 /// }
85 /// ```
86 async fn from_env(env: &Env) -> Result<Self>;
87}
88
89/// Environment variables to configure the lambda handler.
90#[derive(Debug, Parser)]
91struct HandlerEnv {
92 /// How many concurrent records should be processed at once.
93 /// This defaults to 1 to set synchronous processing.
94 /// This value should not exceed the value of `BatchSize`
95 /// configured for the [event source mapping](https://docs.aws.amazon.com/lambda/latest/dg/invocation-eventsourcemapping.html#invocation-eventsourcemapping-batching).
96 #[arg(env, default_value_t = 1)]
97 record_concurrency: usize,
98}
99
100/// A trait that allows a type to be used as the `EventType` trait of a [LambdaContext].
101#[async_trait(?Send)]
102pub trait RunnableEventType<Msg, MsgResult, EventResult> {
103 async fn process<F, Fut, Context>(
104 &self,
105 message_handler: F,
106 ctx: Arc<Context>,
107 ) -> Result<EventResult>
108 where
109 F: Fn(Msg, Arc<Context>) -> Fut,
110 Fut: Future<Output = Result<MsgResult>>,
111 Msg: serde::de::DeserializeOwned;
112}
113
/// Marker trait that allows any cloneable, deserializable struct
/// to be used as the `EventType` of a [LambdaContext].
///
/// Implementing this trait for a type gives it a [RunnableEventType] implementation in which
/// the event itself is cloned and handed to the message handler as the message.
///
/// # Example
///
/// The snippet below is an outline only (`Env`, `Context` and the body of the
/// `LambdaContext` implementation are elided), hence `compile_fail`.
///
/// ```compile_fail
/// use async_trait::async_trait;
/// use cobalt_aws::lambda::{LambdaContext, StepFunctionEvent};
/// use serde::{Deserialize, Serialize};
///
/// /// The structure of the event we expect to see at the Step Function.
/// #[derive(Debug, Deserialize, Serialize, Clone)]
/// pub struct MyEvent {
///     pub greeting: String,
/// }
///
/// impl StepFunctionEvent for MyEvent {}
///
/// #[async_trait]
/// impl LambdaContext<Env, MyEvent> for Context {
///     ...
/// }
/// ```
pub trait StepFunctionEvent: Clone {}
138
139#[async_trait(?Send)]
140impl<T: StepFunctionEvent, MsgResult> RunnableEventType<T, MsgResult, MsgResult> for T {
141 async fn process<F, Fut, Context>(
142 &self,
143 message_handler: F,
144 ctx: Arc<Context>,
145 ) -> Result<MsgResult>
146 where
147 F: Fn(T, Arc<Context>) -> Fut,
148 Fut: Future<Output = Result<MsgResult>>,
149 {
150 message_handler(self.clone(), ctx).await
151 }
152}
153
154#[async_trait(?Send)]
155impl RunnableEventType<S3Event, (), ()> for S3Event {
156 async fn process<F, Fut, Context>(&self, message_handler: F, ctx: Arc<Context>) -> Result<()>
157 where
158 F: Fn(S3Event, Arc<Context>) -> Fut,
159 Fut: Future<Output = Result<()>>,
160 {
161 Ok(message_handler(self.clone(), ctx).await?)
162 }
163}
164
165// When processing an SqsEvent, we expect the message handler to return (), and the
166// lambda itself to also return ().
167#[async_trait(?Send)]
168impl<Msg> RunnableEventType<Msg, (), ()> for SqsEvent {
169 async fn process<F, Fut, Context>(&self, message_handler: F, ctx: Arc<Context>) -> Result<()>
170 where
171 F: Fn(Msg, Arc<Context>) -> Fut,
172 Fut: Future<Output = Result<()>>,
173 Msg: serde::de::DeserializeOwned,
174 {
175 // Process the records in the event batch concurrently, up to `RECORD_CONCURRENCY`.
176 // If any of them fail, return immediately.
177 let handler_env: HandlerEnv = HandlerEnv::try_parse_from(empty::<OsString>())
178 .context("An error occurred while parsing environment variable for handler")?;
179 stream::iter(&self.records)
180 .map(|record| {
181 let body = record.body.clone();
182 let body = body.with_context(|| "No SqsMessage body".to_string())?;
183 let msg = serde_json::from_str::<Msg>(&body)
184 .with_context(|| format!("Error parsing body into message: {}", body))?;
185 Ok((msg, body))
186 })
187 .try_for_each_concurrent(handler_env.record_concurrency, |(msg, body)| async {
188 message_handler(msg, ctx.clone())
189 .map(move |r| {
190 r.with_context(|| format!("Error running message handler {}", body))
191 })
192 .await?;
193 Ok(())
194 })
195 .await
196 }
197}
198
199/// Executes a message handler against all the messages received in a batch
200/// from an SQS event source mapping.
201///
202/// The `run_message_handler` function takes care of the following tasks:
203///
204/// * Executes the Lambda runtime, using [lambda_runtime](https://github.com/awslabs/aws-lambda-rust-runtime).
205/// * Sets up tracing to ensure all `tracing::<...>!()` calls are JSON formatted for consumption by CloudWatch.
206/// * Processes environment variables and makes them available to your handler
207/// * Initialises a shared context object, which is passed to your handler.
208/// * Deserialises a batch of messages and passes each one to your handler.
209/// * Processes messages concurrently, based on the env var `RECORD_CONCURRENCY` (default: 1)
210///
211/// ## Writing a message handler
212///
213/// To write a message handler, you need to define four elements:
214///
215/// * The `Message` structure, which defines the structure of the messages which will be sent to the SQS
216/// queue, and then forwarded on to your Lambda.
217/// * The `Env` structure, which defines the expected environment variables your Lambda will receive.
218/// * The `Context` structure, which is provided the `Env` structure, and represents the shared state
219/// that will be passed into your message handler. This structure needs to implement the [LambdaContext] trait.
220/// * The `message_handler` function, which accepts a `Message` and a `Context`, and performs the desired actions.
221///
222/// # Example
223///
224/// ```no_run
225/// use anyhow::Result;
226/// use async_trait::async_trait;
227/// use clap::Parser;
228/// use serde::Deserialize;
229/// use std::fmt::Debug;
230/// use std::sync::Arc;
231///
232/// use cobalt_aws::lambda::{run_message_handler, Error, LambdaContext, SqsEvent};
233///
234/// #[tokio::main]
235/// async fn main() -> Result<(), Error> {
236/// run_message_handler(message_handler).await
237/// }
238///
239/// /// The structure of the messages we expect to see on the queue.
240/// #[derive(Debug, Deserialize)]
241/// pub struct Message {
242/// pub target: String,
243/// }
244///
245/// /// Configuration we receive from environment variables.
246/// ///
247/// /// Note: all fields should be tagged with the `#[arg(env)]` attribute.
248/// #[derive(Debug, Parser)]
249/// pub struct Env {
250/// #[arg(env)]
251/// pub greeting: String,
252/// }
253///
254/// /// Shared context we build up before invoking the lambda runner.
255/// #[derive(Debug)]
256/// pub struct Context {
257/// pub greeting: String,
258/// }
259///
260/// #[async_trait]
261/// impl LambdaContext<Env, SqsEvent> for Context {
262/// /// Initialise a shared context object from which will be
263/// /// passed to all instances of the message handler.
264/// async fn from_env(env: &Env) -> Result<Context> {
265/// Ok(Context {
266/// greeting: env.greeting.clone(),
267/// })
268/// }
269/// }
270///
271/// /// Process a single message from the SQS queue, within the given context.
272/// async fn message_handler(message: Message, context: Arc<Context>) -> Result<()> {
273/// tracing::debug!("Message: {:?}", message);
274/// tracing::debug!("Context: {:?}", context);
275///
276/// // Log a greeting to the target
277/// tracing::info!("{}, {}!", context.greeting, message.target);
278///
279/// Ok(())
280/// }
281/// ```
282/// # Concurrent processing
283///
284/// By default `run_message_handler` will process the messages in an event batch sequentially.
285/// You can configure `run_message_handler` to process messages concurrently by setting
286/// the `RECORD_CONCURRENCY` env var (default: 1). This value should not exceed the value of `BatchSize`
287/// configured for the [event source mapping](https://docs.aws.amazon.com/lambda/latest/dg/invocation-eventsourcemapping.html#invocation-eventsourcemapping-batching)
288/// (default: 10).
289///
290/// # Error handling
291///
292/// If any errors are raised during init, or from the `message_handler` function, then the entire message
293/// batch will be considered to have failed. Error messages will be logged to stdout in a format compatible
294/// with CloudWatch, and the message batch being processed will be returned to the original queue.
295pub async fn run_message_handler<EventType, EventOutput, F, Fut, R, Msg, Context, Env>(
296 message_handler: F,
297) -> Result<(), Error>
298where
299 EventType: for<'de> serde::de::Deserialize<'de> + RunnableEventType<Msg, R, EventOutput>,
300 F: Fn(Msg, Arc<Context>) -> Fut,
301 Fut: Future<Output = Result<R>>,
302 Msg: serde::de::DeserializeOwned,
303 Context: LambdaContext<Env, EventType> + std::fmt::Debug,
304 EventOutput: Serialize,
305 Env: Parser + std::fmt::Debug,
306{
307 // Perform initial setup outside of the runtime to avoid this code being run
308 // on every invocation of the lambda.
309 //
310 // Ideally an error in this code would cause the runtime to return an initialization error:
311 // https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html#runtimes-api-initerror
312 // however this isn't currently supported by `lambda_runtime` (perhaps an area
313 // for future work). To work around this, we capture any errors during this phase
314 // and pass them into the handler function itself, so that it can raise the error
315 // when the function is invoked.
316 let init_result = async {
317 // Setup tracing
318 tracing_subscriber::fmt()
319 .with_env_filter(
320 EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
321 )
322 .json()
323 .init();
324 // We only care about values provided as environment variables, so we pass in an empty
325 // iterator, rather than having clap parse the command line arguments. This avoids an
326 // unfortunate issue in LocalStack where a command line arg of "handler.handler" is provided
327 // as a command line argument to the process when using an Image based lambda. This triggers
328 // a problem, as clap (wrongly, IMHO) tries to assign this command line option to the first
329 //element of the Env struct, and then ignores any actual environment variable provided.
330 let env = Env::try_parse_from(empty::<OsString>()).context(
331 "An error occurred while parsing environment variables for message context.",
332 )?;
333 let ctx = Arc::new(Context::from_env(&env).await?);
334 tracing::info!("Env: {:?}", env);
335 tracing::info!("Context: {:?}", ctx);
336 Ok::<_, anyhow::Error>(ctx)
337 }
338 .await;
339
340 lambda_runtime::run(service_fn(|event: LambdaEvent<EventType>| async {
341 // Check the result of the init phase. If it failed, log the message
342 // and immediately return.
343 let ctx = match &init_result {
344 Ok(x) => x,
345 Err(e) => {
346 tracing::error!("{:?}", e);
347 return Err(Error::from("Failed to initialise lambda."));
348 }
349 };
350
351 // Process the event and capture any errors
352 let (event, _context) = event.into_parts();
353 let result = event.process(&message_handler, ctx.clone()).await;
354
355 // Log out the full error, as the lambda_runtime only logs the first line of the error
356 // message, which can hide crucial information.
357 match result {
358 Ok(x) => Ok(x),
359 Err(e) => {
360 tracing::error!("{:?}", e);
361 Err(Error::from("Failed to process SQS event."))
362 }
363 }
364 }))
365 .await
366}
367
368/// A trait of the `Context` type, required when using [run_local_handler] to execute the message handler.
369///
370/// All `Context` types must implement the implement the [LocalContext::from_local] and [LocalContext::msg] methods for their corresponding `Msg` type
371/// in order to use the `Context` with [run_local_handler].
372///
373/// # Example
374///
375/// ```no_run
376/// use anyhow::Result;
377/// use serde::Deserialize;
378/// use cobalt_aws::lambda::LocalContext;
379///
380/// # use async_trait::async_trait;
381/// # #[derive(Debug)]
382/// # pub struct Env {
383/// # pub greeting: String,
384/// # }
385/// #
386/// /// Shared context we build up before invoking the local runner.
387/// #[derive(Debug)]
388/// pub struct Context {
389/// pub greeting: String,
390/// }
391///
392/// #[derive(Debug, Deserialize)]
393/// pub struct Message {
394/// pub target: String,
395/// }
396///
397/// #[async_trait]
398/// impl LocalContext<Message> for Context {
399/// /// Initialise a shared context object from which will be
400/// /// passed to all instances of the message handler.
401/// async fn from_local() -> Result<Self> {
402/// Ok(Context {
403/// greeting: "Hello".to_string(),
404/// })
405/// }
406///
407/// /// Construct a message object to be processed by the message handler.
408/// async fn msg(&self) -> Result<Message> {
409/// Ok(Message {
410/// target: "World".to_string(),
411/// })
412/// }
413/// }
414/// ```
415#[async_trait]
416pub trait LocalContext<Msg>: Sized {
417 /// Construct a new local Context object.
418 async fn from_local() -> Result<Self>;
419 /// Construct a message object to be processed by the message handler.
420 async fn msg(&self) -> Result<Msg>;
421}
422
423/// Executes a message handler against the message provided by the [LocalContext].
424///
425/// The `run_local_handler` function takes care of the following tasks:
426///
427/// * Sets up tracing to ensure all `tracing::<...>!()` calls are JSON formatted for consumption by CloudWatch.
428/// * Initialises a shared context object, which is passed to your handler.
429/// * Initialises a single message object, which is passed to your handler.
430///
431/// ## Writing a message handler
432///
433/// To write a message handler, you need to define four elements:
434///
435/// * The `Message` structure, which defines the structure of the messages which will be sent to your message handler.
436/// * The `Context` structure, which represents the shared state
437/// that will be passed into your message handler. This structure needs to implement the [LocalContext] trait.
438/// * The `message_handler` function, which accepts a `Message` and a `Context`, and performs the desired actions.
439///
440/// # Example
441///
442/// ```no_run
443/// use anyhow::Result;
444/// use async_trait::async_trait;
445/// use clap::Parser;
446/// use serde::Deserialize;
447/// use std::fmt::Debug;
448/// use std::sync::Arc;
449///
450/// use cobalt_aws::lambda::{run_local_handler, Error, LocalContext};
451///
452/// #[tokio::main]
453/// async fn main() -> Result<(), Error> {
454/// let result = run_local_handler(message_handler).await?;
455/// Ok(())
456/// }
457///
458/// /// The structure of the messages to be processed by our message handler.
459/// #[derive(Debug, Deserialize)]
460/// pub struct Message {
461/// pub target: String,
462/// }
463///
464///
465/// /// Shared context we build up before invoking the local runner.
466/// #[derive(Debug)]
467/// pub struct Context {
468/// pub greeting: String,
469/// }
470///
471/// #[async_trait]
472/// impl LocalContext<Message> for Context {
473/// /// Initialise a shared context object which will be
474/// /// passed to the message handler.
475/// async fn from_local() -> Result<Self> {
476/// Ok(Context {
477/// greeting: "Hello".to_string(),
478/// })
479/// }
480///
481/// /// Construct a message to be processed.
482/// async fn msg(&self) -> Result<Message> {
483/// Ok(Message {
484/// target: "World".to_string(),
485/// })
486/// }
487/// }
488///
489/// /// Process a single message, within the given context.
490/// async fn message_handler(message: Message, context: Arc<Context>) -> Result<()> {
491/// tracing::debug!("Message: {:?}", message);
492/// tracing::debug!("Context: {:?}", context);
493///
494/// // Log a greeting to the target
495/// tracing::info!("{}, {}!", context.greeting, message.target);
496///
497/// Ok(())
498/// }
499/// ```
500pub async fn run_local_handler<F, Fut, R, Msg, Context>(message_handler: F) -> Result<R, Error>
501where
502 F: Fn(Msg, Arc<Context>) -> Fut,
503 Fut: Future<Output = Result<R>>,
504 Msg: serde::de::DeserializeOwned,
505 Context: LocalContext<Msg> + std::fmt::Debug,
506{
507 tracing_subscriber::fmt()
508 .with_env_filter(
509 EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
510 )
511 .json()
512 .init();
513
514 let ctx = Arc::new(Context::from_local().await?);
515 tracing::info!("Context: {:?}", ctx);
516 Ok(message_handler(ctx.msg().await?, ctx).await?)
517}