cobalt_aws/
lambda.rs

//! A wrapper around the [lambda_runtime](https://github.com/awslabs/aws-lambda-rust-runtime) crate.
//!
//! This module provides helpers that make it easier to write Lambda functions which consume messages from
//! various events, such as an SQS queue configured with an [event source mapping](https://docs.aws.amazon.com/lambda/latest/dg/invocation-eventsourcemapping.html),
//! or a [step function](https://docs.aws.amazon.com/step-functions/latest/dg/welcome.html).
//! It also provides mechanisms for running message handlers locally, without the full AWS Lambda environment.
7
8use anyhow::{Context as _, Result};
9use async_trait::async_trait;
10pub use aws_lambda_events::event::s3::S3Event;
11/// Re-export from [aws_lambda_events::event::sqs::SqsEvent], with [RunnableEventType] implemented.
12pub use aws_lambda_events::event::sqs::SqsEvent;
13use clap::Parser;
14use futures::stream::{self, StreamExt, TryStreamExt};
15use futures::FutureExt;
16use lambda_runtime::{service_fn, LambdaEvent};
17use serde::Serialize;
18use std::ffi::OsString;
19use std::future::Future;
20use std::iter::empty;
21use std::sync::Arc;
22use tracing_subscriber::filter::EnvFilter;
23
24/// Re-export of [lambda_runtime::Error](https://docs.rs/lambda_runtime/latest/lambda_runtime/type.Error.html).
25///
26// We provide this re-export so that the user doesn't need to have lambda_runtime as a direct dependency.
27pub use lambda_runtime::Error;
28
// Internal probe type used only by `running_on_lambda`. The `///` text below is left
// untouched because clap's derive macro picks doc comments up as help text.
/// This struct is used to attempt to parse the `AWS_LAMBDA_FUNCTION_NAME` environment variable.
///
/// We assume that if this variable is present then we're running in a Lambda function.
/// https://docs.aws.amazon.com/lambda/latest/dg/configuration-envvars.html
#[derive(Debug, Parser)]
struct CheckLambda {
    // Filled in from the environment via `#[arg(env)]`; `None` when the variable is absent,
    // which `running_on_lambda` reports as "not on Lambda".
    #[arg(env)]
    aws_lambda_function_name: Option<String>,
}
38
39/// Determine whether the code is being executed within an AWS Lambda.
40///
41/// This function can be used to write binaries that are able to run both locally
42/// or as a Lambda function.
43pub fn running_on_lambda() -> Result<bool> {
44    let check_lambda = CheckLambda::try_parse_from(empty::<OsString>())
45        .context("An error occurred while parsing environment variables for message context.")?;
46    Ok(check_lambda.aws_lambda_function_name.is_some())
47}
48
/// A trait of the `Context` type, required when using [run_message_handler] to execute the message handler.
///
/// All `Context` types must implement the [LambdaContext::from_env] method for their corresponding `Env` type
/// in order to use the `Context` with [run_message_handler].
///
/// When implementing `LambdaContext` you must also specify the `EventType` that the Lambda expects to receive, e.g. [SqsEvent].
#[async_trait]
pub trait LambdaContext<Env, EventType>: Sized {
    /// Build the shared context from the parsed `Env` configuration.
    ///
    /// [run_message_handler] calls this once, before the Lambda runtime starts, and passes
    /// the resulting context (wrapped in an `Arc`) to every call of the message handler.
    /// If it returns an error, the error is logged and reported when the function is invoked.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use anyhow::Result;
    /// use cobalt_aws::lambda::{LambdaContext, SqsEvent};
    ///
    /// # use async_trait::async_trait;
    /// # #[derive(Debug)]
    /// # pub struct Env {
    /// #     pub greeting: String,
    /// # }
    /// #
    /// /// Shared context we build up before invoking the lambda runner.
    /// #[derive(Debug)]
    /// pub struct Context {
    ///     pub greeting: String,
    /// }
    ///
    /// #[async_trait]
    /// impl LambdaContext<Env, SqsEvent> for Context {
    ///     /// Initialise a shared context object which will be
    ///     /// passed to all instances of the message handler.
    ///     async fn from_env(env: &Env) -> Result<Context> {
    ///         Ok(Context {
    ///             greeting: env.greeting.clone(),
    ///         })
    ///     }
    /// }
    /// ```
    async fn from_env(env: &Env) -> Result<Self>;
}
88
// Handler configuration read from the environment (`RECORD_CONCURRENCY`) when an
// `SqsEvent` is processed. The `///` text below is left untouched because clap's
// derive macro picks doc comments up as help text.
/// Environment variables to configure the lambda handler.
#[derive(Debug, Parser)]
struct HandlerEnv {
    /// How many concurrent records should be processed at once.
    /// This defaults to 1 to set synchronous processing.
    /// This value should not exceed the value of `BatchSize`
    /// configured for the [event source mapping](https://docs.aws.amazon.com/lambda/latest/dg/invocation-eventsourcemapping.html#invocation-eventsourcemapping-batching).
    // NOTE(review): this value is passed straight to `try_for_each_concurrent` as its
    // limit; a value of 0 is not rejected here - confirm how the combinator treats it.
    #[arg(env, default_value_t = 1)]
    record_concurrency: usize,
}
99
/// A trait that allows a type to be used as the `EventType` parameter of a [LambdaContext].
///
/// The type parameters describe how an event is turned into handler calls:
///
/// * `Msg` - the message type handed to the message handler.
/// * `MsgResult` - the success value returned by the message handler.
/// * `EventResult` - the success value produced by processing the whole event.
// `?Send`: the futures produced by `process` are not required to be `Send`.
#[async_trait(?Send)]
pub trait RunnableEventType<Msg, MsgResult, EventResult> {
    /// Process this event by invoking `message_handler` with the shared context `ctx`.
    ///
    /// How many messages an event yields, and how the handler results are combined
    /// into an `EventResult`, is decided by each implementation.
    ///
    /// # Errors
    ///
    /// Implementations return an error when the event cannot be processed, for
    /// example when the message handler itself fails.
    async fn process<F, Fut, Context>(
        &self,
        message_handler: F,
        ctx: Arc<Context>,
    ) -> Result<EventResult>
    where
        F: Fn(Msg, Arc<Context>) -> Fut,
        Fut: Future<Output = Result<MsgResult>>,
        Msg: serde::de::DeserializeOwned;
}
113
/// The `StepFunctionEvent` trait allows any deserializable struct
/// to be used as the EventType for a `LambdaContext`.
///
/// Types implementing this marker trait are passed to the message handler as the
/// message itself (the event is cloned, hence the `Clone` bound), and the handler's
/// result is returned as the result of the event.
///
/// # Example
///
/// ```compile_fail
/// use async_trait::async_trait;
/// use cobalt_aws::lambda::{LambdaContext, StepFunctionEvent};
/// use serde::{Deserialize, Serialize};
///
/// /// The structure of the event we expect to see at the Step Function.
/// #[derive(Debug, Deserialize, Serialize, Clone)]
/// pub struct MyEvent {
///     pub greeting: String,
/// }
///
/// impl StepFunctionEvent for MyEvent {}
///
/// #[async_trait]
/// impl LambdaContext<Env, MyEvent> for Context {
///     ...
/// }
/// ```
pub trait StepFunctionEvent: Clone {}
138
139#[async_trait(?Send)]
140impl<T: StepFunctionEvent, MsgResult> RunnableEventType<T, MsgResult, MsgResult> for T {
141    async fn process<F, Fut, Context>(
142        &self,
143        message_handler: F,
144        ctx: Arc<Context>,
145    ) -> Result<MsgResult>
146    where
147        F: Fn(T, Arc<Context>) -> Fut,
148        Fut: Future<Output = Result<MsgResult>>,
149    {
150        message_handler(self.clone(), ctx).await
151    }
152}
153
154#[async_trait(?Send)]
155impl RunnableEventType<S3Event, (), ()> for S3Event {
156    async fn process<F, Fut, Context>(&self, message_handler: F, ctx: Arc<Context>) -> Result<()>
157    where
158        F: Fn(S3Event, Arc<Context>) -> Fut,
159        Fut: Future<Output = Result<()>>,
160    {
161        Ok(message_handler(self.clone(), ctx).await?)
162    }
163}
164
165// When processing an SqsEvent, we expect the message handler to return (), and the
166// lambda itself to also return ().
167#[async_trait(?Send)]
168impl<Msg> RunnableEventType<Msg, (), ()> for SqsEvent {
169    async fn process<F, Fut, Context>(&self, message_handler: F, ctx: Arc<Context>) -> Result<()>
170    where
171        F: Fn(Msg, Arc<Context>) -> Fut,
172        Fut: Future<Output = Result<()>>,
173        Msg: serde::de::DeserializeOwned,
174    {
175        // Process the records in the event batch concurrently, up to `RECORD_CONCURRENCY`.
176        // If any of them fail, return immediately.
177        let handler_env: HandlerEnv = HandlerEnv::try_parse_from(empty::<OsString>())
178            .context("An error occurred while parsing environment variable for handler")?;
179        stream::iter(&self.records)
180            .map(|record| {
181                let body = record.body.clone();
182                let body = body.with_context(|| "No SqsMessage body".to_string())?;
183                let msg = serde_json::from_str::<Msg>(&body)
184                    .with_context(|| format!("Error parsing body into message: {}", body))?;
185                Ok((msg, body))
186            })
187            .try_for_each_concurrent(handler_env.record_concurrency, |(msg, body)| async {
188                message_handler(msg, ctx.clone())
189                    .map(move |r| {
190                        r.with_context(|| format!("Error running message handler {}", body))
191                    })
192                    .await?;
193                Ok(())
194            })
195            .await
196    }
197}
198
199/// Executes a message handler against all the messages received in a batch
200/// from an SQS event source mapping.
201///
202/// The `run_message_handler` function takes care of the following tasks:
203///
204/// * Executes the Lambda runtime, using [lambda_runtime](https://github.com/awslabs/aws-lambda-rust-runtime).
205/// * Sets up tracing to ensure all `tracing::<...>!()` calls are JSON formatted for consumption by CloudWatch.
206/// * Processes environment variables and makes them available to your handler
207/// * Initialises a shared context object, which is passed to your handler.
208/// * Deserialises a batch of messages and passes each one to your handler.
209/// * Processes messages concurrently, based on the env var `RECORD_CONCURRENCY` (default: 1)
210///
211/// ## Writing a message handler
212///
213/// To write a message handler, you need to define four elements:
214///
215/// * The `Message` structure, which defines the structure of the messages which will be sent to the SQS
216///   queue, and then forwarded on to your Lambda.
217/// * The `Env` structure, which defines the expected environment variables your Lambda will receive.
218/// * The `Context` structure, which is provided the `Env` structure, and represents the shared state
219///   that will be passed into your message handler. This structure needs to implement the [LambdaContext] trait.
220/// * The `message_handler` function, which accepts a `Message` and a `Context`, and performs the desired actions.
221///
222/// # Example
223///
224/// ```no_run
225/// use anyhow::Result;
226/// use async_trait::async_trait;
227/// use clap::Parser;
228/// use serde::Deserialize;
229/// use std::fmt::Debug;
230/// use std::sync::Arc;
231///
232/// use cobalt_aws::lambda::{run_message_handler, Error, LambdaContext, SqsEvent};
233///
234/// #[tokio::main]
235/// async fn main() -> Result<(), Error> {
236///     run_message_handler(message_handler).await
237/// }
238///
239/// /// The structure of the messages we expect to see on the queue.
240/// #[derive(Debug, Deserialize)]
241/// pub struct Message {
242///     pub target: String,
243/// }
244///
245/// /// Configuration we receive from environment variables.
246/// ///
247/// /// Note: all fields should be tagged with the `#[arg(env)]` attribute.
248/// #[derive(Debug, Parser)]
249/// pub struct Env {
250///     #[arg(env)]
251///     pub greeting: String,
252/// }
253///
254/// /// Shared context we build up before invoking the lambda runner.
255/// #[derive(Debug)]
256/// pub struct Context {
257///     pub greeting: String,
258/// }
259///
260/// #[async_trait]
261/// impl LambdaContext<Env, SqsEvent> for Context {
262///     /// Initialise a shared context object from which will be
263///     /// passed to all instances of the message handler.
264///     async fn from_env(env: &Env) -> Result<Context> {
265///         Ok(Context {
266///             greeting: env.greeting.clone(),
267///         })
268///     }
269/// }
270///
271/// /// Process a single message from the SQS queue, within the given context.
272/// async fn message_handler(message: Message, context: Arc<Context>) -> Result<()> {
273///     tracing::debug!("Message: {:?}", message);
274///     tracing::debug!("Context: {:?}", context);
275///
276///     // Log a greeting to the target
277///     tracing::info!("{}, {}!", context.greeting, message.target);
278///
279///     Ok(())
280/// }
281/// ```
282/// # Concurrent processing
283///
284/// By default `run_message_handler` will process the messages in an event batch sequentially.
285/// You can configure `run_message_handler` to process messages concurrently by setting
286/// the `RECORD_CONCURRENCY` env var (default: 1). This value should not exceed the value of `BatchSize`
287/// configured for the [event source mapping](https://docs.aws.amazon.com/lambda/latest/dg/invocation-eventsourcemapping.html#invocation-eventsourcemapping-batching)
288/// (default: 10).
289///
290/// # Error handling
291///
292/// If any errors are raised during init, or from the `message_handler` function, then the entire message
293/// batch will be considered to have failed. Error messages will be logged to stdout in a format compatible
294/// with CloudWatch, and the message batch being processed will be returned to the original queue.
295pub async fn run_message_handler<EventType, EventOutput, F, Fut, R, Msg, Context, Env>(
296    message_handler: F,
297) -> Result<(), Error>
298where
299    EventType: for<'de> serde::de::Deserialize<'de> + RunnableEventType<Msg, R, EventOutput>,
300    F: Fn(Msg, Arc<Context>) -> Fut,
301    Fut: Future<Output = Result<R>>,
302    Msg: serde::de::DeserializeOwned,
303    Context: LambdaContext<Env, EventType> + std::fmt::Debug,
304    EventOutput: Serialize,
305    Env: Parser + std::fmt::Debug,
306{
307    // Perform initial setup outside of the runtime to avoid this code being run
308    // on every invocation of the lambda.
309    //
310    // Ideally an error in this code would cause the runtime to return an initialization error:
311    // https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html#runtimes-api-initerror
312    // however this isn't currently supported by `lambda_runtime` (perhaps an area
313    // for future work). To work around this, we capture any errors during this phase
314    // and pass them into the handler function itself, so that it can raise the error
315    // when the function is invoked.
316    let init_result = async {
317        // Setup tracing
318        tracing_subscriber::fmt()
319            .with_env_filter(
320                EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
321            )
322            .json()
323            .init();
324        // We only care about values provided as environment variables, so we pass in an empty
325        // iterator, rather than having clap parse the command line arguments. This avoids an
326        // unfortunate issue in LocalStack where a command line arg of "handler.handler" is provided
327        // as a command line argument to the process when using an Image based lambda. This triggers
328        // a problem, as clap (wrongly, IMHO) tries to assign this command line option to the first
329        //element of the Env struct, and then ignores any actual environment variable provided.
330        let env = Env::try_parse_from(empty::<OsString>()).context(
331            "An error occurred while parsing environment variables for message context.",
332        )?;
333        let ctx = Arc::new(Context::from_env(&env).await?);
334        tracing::info!("Env: {:?}", env);
335        tracing::info!("Context: {:?}", ctx);
336        Ok::<_, anyhow::Error>(ctx)
337    }
338    .await;
339
340    lambda_runtime::run(service_fn(|event: LambdaEvent<EventType>| async {
341        // Check the result of the init phase. If it failed, log the message
342        // and immediately return.
343        let ctx = match &init_result {
344            Ok(x) => x,
345            Err(e) => {
346                tracing::error!("{:?}", e);
347                return Err(Error::from("Failed to initialise lambda."));
348            }
349        };
350
351        // Process the event and capture any errors
352        let (event, _context) = event.into_parts();
353        let result = event.process(&message_handler, ctx.clone()).await;
354
355        // Log out the full error, as the lambda_runtime only logs the first line of the error
356        // message, which can hide crucial information.
357        match result {
358            Ok(x) => Ok(x),
359            Err(e) => {
360                tracing::error!("{:?}", e);
361                Err(Error::from("Failed to process SQS event."))
362            }
363        }
364    }))
365    .await
366}
367
/// A trait of the `Context` type, required when using [run_local_handler] to execute the message handler.
///
/// All `Context` types must implement the [LocalContext::from_local] and [LocalContext::msg] methods for their corresponding `Msg` type
/// in order to use the `Context` with [run_local_handler].
///
/// # Example
///
/// ```no_run
/// use anyhow::Result;
/// use serde::Deserialize;
/// use cobalt_aws::lambda::LocalContext;
///
/// # use async_trait::async_trait;
/// # #[derive(Debug)]
/// # pub struct Env {
/// #     pub greeting: String,
/// # }
/// #
/// /// Shared context we build up before invoking the local runner.
/// #[derive(Debug)]
/// pub struct Context {
///     pub greeting: String,
/// }
///
/// #[derive(Debug, Deserialize)]
/// pub struct Message {
///     pub target: String,
/// }
///
/// #[async_trait]
/// impl LocalContext<Message> for Context {
///     /// Initialise a shared context object which will be
///     /// passed to the message handler.
///     async fn from_local() -> Result<Self> {
///         Ok(Context {
///             greeting: "Hello".to_string(),
///         })
///     }
///
///     /// Construct a message object to be processed by the message handler.
///     async fn msg(&self) -> Result<Message> {
///         Ok(Message {
///             target: "World".to_string(),
///         })
///     }
/// }
/// ```
#[async_trait]
pub trait LocalContext<Msg>: Sized {
    /// Construct a new local Context object.
    ///
    /// [run_local_handler] calls this once, before requesting the message.
    async fn from_local() -> Result<Self>;
    /// Construct a message object to be processed by the message handler.
    ///
    /// [run_local_handler] calls this once and passes the result to the message handler.
    async fn msg(&self) -> Result<Msg>;
}
422
423/// Executes a message handler against the message provided by the [LocalContext].
424///
425/// The `run_local_handler` function takes care of the following tasks:
426///
427/// * Sets up tracing to ensure all `tracing::<...>!()` calls are JSON formatted for consumption by CloudWatch.
428/// * Initialises a shared context object, which is passed to your handler.
429/// * Initialises a single message object, which is passed to your handler.
430///
431/// ## Writing a message handler
432///
433/// To write a message handler, you need to define four elements:
434///
435/// * The `Message` structure, which defines the structure of the messages which will be sent to your message handler.
436/// * The `Context` structure, which represents the shared state
437///   that will be passed into your message handler. This structure needs to implement the [LocalContext] trait.
438/// * The `message_handler` function, which accepts a `Message` and a `Context`, and performs the desired actions.
439///
440/// # Example
441///
442/// ```no_run
443/// use anyhow::Result;
444/// use async_trait::async_trait;
445/// use clap::Parser;
446/// use serde::Deserialize;
447/// use std::fmt::Debug;
448/// use std::sync::Arc;
449///
450/// use cobalt_aws::lambda::{run_local_handler, Error, LocalContext};
451///
452/// #[tokio::main]
453/// async fn main() -> Result<(), Error> {
454///     let result = run_local_handler(message_handler).await?;
455///     Ok(())
456/// }
457///
458/// /// The structure of the messages to be processed by our message handler.
459/// #[derive(Debug, Deserialize)]
460/// pub struct Message {
461///     pub target: String,
462/// }
463///
464///
465/// /// Shared context we build up before invoking the local runner.
466/// #[derive(Debug)]
467/// pub struct Context {
468///     pub greeting: String,
469/// }
470///
471/// #[async_trait]
472/// impl LocalContext<Message> for Context {
473///     /// Initialise a shared context object which will be
474///     /// passed to the message handler.
475///     async fn from_local() -> Result<Self> {
476///         Ok(Context {
477///             greeting: "Hello".to_string(),
478///         })
479///     }
480///
481///     /// Construct a message to be processed.
482///     async fn msg(&self) -> Result<Message> {
483///         Ok(Message {
484///             target: "World".to_string(),
485///         })
486///     }
487/// }
488///
489/// /// Process a single message, within the given context.
490/// async fn message_handler(message: Message, context: Arc<Context>) -> Result<()> {
491///     tracing::debug!("Message: {:?}", message);
492///     tracing::debug!("Context: {:?}", context);
493///
494///     // Log a greeting to the target
495///     tracing::info!("{}, {}!", context.greeting, message.target);
496///
497///     Ok(())
498/// }
499/// ```
500pub async fn run_local_handler<F, Fut, R, Msg, Context>(message_handler: F) -> Result<R, Error>
501where
502    F: Fn(Msg, Arc<Context>) -> Fut,
503    Fut: Future<Output = Result<R>>,
504    Msg: serde::de::DeserializeOwned,
505    Context: LocalContext<Msg> + std::fmt::Debug,
506{
507    tracing_subscriber::fmt()
508        .with_env_filter(
509            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
510        )
511        .json()
512        .init();
513
514    let ctx = Arc::new(Context::from_local().await?);
515    tracing::info!("Context: {:?}", ctx);
516    Ok(message_handler(ctx.msg().await?, ctx).await?)
517}