use std::future::Future;
use std::sync::Arc;
use durable_lambda_core::backend::RealBackend;
use durable_lambda_core::context::DurableContext;
use durable_lambda_core::error::DurableError;
use durable_lambda_core::event::parse_invocation;
use durable_lambda_core::response::wrap_handler_result;
use lambda_runtime::{service_fn, LambdaEvent};
use crate::context::BuilderContext;
/// Builder that pairs a user handler function with optional runtime configuration.
///
/// A value of this type is created by [`handler`], optionally configured with
/// `with_tracing` / `with_error_handler`, and finally consumed by `run`.
///
/// Type parameters:
/// - `F`: the user handler, called with the user event as a `serde_json::Value`
///   and a [`BuilderContext`].
/// - `Fut`: the future returned by `F`, resolving to a JSON value or a
///   [`DurableError`].
#[must_use = "a DurableHandlerBuilder does nothing until `run()` is called and awaited"]
pub struct DurableHandlerBuilder<F, Fut>
where
    F: Fn(serde_json::Value, BuilderContext) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<serde_json::Value, DurableError>> + Send,
{
    /// The user-supplied handler function.
    handler: F,
    /// Ties the otherwise-unused `Fut` parameter to the struct; stores no data.
    _phantom: std::marker::PhantomData<Fut>,
    /// Subscriber to install as the global tracing default when `run` starts, if any.
    tracing_subscriber: Option<Box<dyn tracing::Subscriber + Send + Sync + 'static>>,
    /// Optional transformation applied to errors returned by `handler`.
    error_handler: Option<Box<dyn Fn(DurableError) -> DurableError + Send + Sync>>,
}
/// Starts a [`DurableHandlerBuilder`] around the handler function `f`.
///
/// `f` is called with the user event as a `serde_json::Value` together with a
/// [`BuilderContext`], and must return a future resolving to either a JSON
/// value or a [`DurableError`].
///
/// The returned builder has neither a tracing subscriber nor an error handler
/// configured.
pub fn handler<F, Fut>(f: F) -> DurableHandlerBuilder<F, Fut>
where
    F: Fn(serde_json::Value, BuilderContext) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<serde_json::Value, DurableError>> + Send,
{
    DurableHandlerBuilder {
        // Optional configuration starts out unset.
        error_handler: None,
        tracing_subscriber: None,
        _phantom: std::marker::PhantomData::<Fut>,
        handler: f,
    }
}
impl<F, Fut> DurableHandlerBuilder<F, Fut>
where
    F: Fn(serde_json::Value, BuilderContext) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<serde_json::Value, DurableError>> + Send,
{
    /// Stores `subscriber` so that [`run`](Self::run) installs it as the global
    /// default tracing subscriber before the runtime starts.
    ///
    /// Calling this again replaces any previously stored subscriber.
    pub fn with_tracing(
        mut self,
        subscriber: impl tracing::Subscriber + Send + Sync + 'static,
    ) -> Self {
        self.tracing_subscriber = Some(Box::new(subscriber));
        self
    }

    /// Registers a function that transforms any [`DurableError`] returned by the
    /// user handler before the result is handed to `wrap_handler_result`.
    ///
    /// Only errors returned by the user handler pass through it; failures while
    /// parsing the invocation or creating the context are returned before the
    /// handler is called. Calling this again replaces any previously registered
    /// function.
    pub fn with_error_handler(
        mut self,
        handler: impl Fn(DurableError) -> DurableError + Send + Sync + 'static,
    ) -> Self {
        self.error_handler = Some(Box::new(handler));
        self
    }

    /// Consumes the builder and drives the Lambda runtime with the configured handler.
    ///
    /// Steps, in order:
    /// 1. If a subscriber was supplied via `with_tracing`, install it as the
    ///    global default.
    /// 2. Load the default AWS configuration and build a Lambda client wrapped
    ///    in a shared [`RealBackend`].
    /// 3. Call `lambda_runtime::run` with a service that, for each event, parses
    ///    the invocation, builds a [`DurableContext`], calls the user handler,
    ///    applies the optional error handler, and returns the output of
    ///    `wrap_handler_result`.
    ///
    /// # Errors
    ///
    /// Returns an error if installing the tracing subscriber fails (for example
    /// because a global default is already set), or if `lambda_runtime::run`
    /// itself returns an error. Within a single invocation, a failure of
    /// `parse_invocation` or `DurableContext::new` is boxed and returned from
    /// the service function for that invocation.
    pub async fn run(self) -> Result<(), lambda_runtime::Error> {
        // Take the fields out up front so the service closure borrows plain
        // locals rather than a partially moved `self`.
        let Self {
            handler: user_handler,
            tracing_subscriber,
            error_handler: error_hook,
            ..
        } = self;

        if let Some(subscriber) = tracing_subscriber {
            // Surface the failure through the existing `Result` instead of
            // panicking.
            tracing::subscriber::set_global_default(subscriber)?;
        }

        let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
        let client = aws_sdk_lambda::Client::new(&config);
        let shared_backend = Arc::new(RealBackend::new(client));

        lambda_runtime::run(service_fn(|event: LambdaEvent<serde_json::Value>| {
            // Each invocation gets its own handle to the shared backend; the
            // handler and error hook are only borrowed.
            let backend = Arc::clone(&shared_backend);
            let user_handler = &user_handler;
            let error_hook = &error_hook;
            async move {
                let (payload, _lambda_ctx) = event.into_parts();
                let invocation = parse_invocation(&payload)
                    .map_err(Box::<dyn std::error::Error + Send + Sync>::from)?;
                let durable_ctx = DurableContext::new(
                    backend,
                    invocation.durable_execution_arn,
                    invocation.checkpoint_token,
                    invocation.operations,
                    invocation.next_marker,
                )
                .await
                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
                let builder_ctx = BuilderContext::new(durable_ctx);

                // Only errors produced by the user handler go through the
                // optional hook.
                let outcome = user_handler(invocation.user_event, builder_ctx)
                    .await
                    .map_err(|err| match error_hook {
                        Some(hook) => hook(err),
                        None => err,
                    });
                wrap_handler_result(outcome)
            }
        }))
        .await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::BuilderContext;
    use tracing_subscriber::fmt;

    /// Expands to a closure handler that ignores its inputs and resolves to
    /// `{"ok": true}`.
    ///
    /// A macro is used (rather than a helper function) so every test passes a
    /// closure expression directly to `handler`.
    macro_rules! ok_closure {
        () => {
            |_event: serde_json::Value, _ctx: BuilderContext| async move {
                Ok(serde_json::json!({"ok": true}))
            }
        };
    }

    /// A closure returning an async block satisfies the bounds on `handler`.
    #[test]
    fn test_builder_construction_and_type_correctness() {
        let _builder = handler(ok_closure!());
    }

    /// `run` can be called on a freshly built builder; the future is never polled.
    #[test]
    fn test_builder_run_returns_future() {
        let built = handler(ok_closure!());
        let _future = built.run();
    }

    /// `with_tracing` accepts a `tracing_subscriber::fmt` subscriber.
    #[test]
    fn test_with_tracing_stores_subscriber() {
        let fmt_subscriber = fmt().finish();
        let _builder = handler(ok_closure!()).with_tracing(fmt_subscriber);
    }

    /// `with_error_handler` accepts an identity closure over `DurableError`.
    #[test]
    fn test_with_error_handler_stores_handler() {
        let _builder = handler(ok_closure!()).with_error_handler(|e: DurableError| e);
    }

    /// Both configuration methods can be chained on one builder.
    #[test]
    fn test_builder_chaining() {
        let fmt_subscriber = fmt().finish();
        let _builder = handler(ok_closure!())
            .with_tracing(fmt_subscriber)
            .with_error_handler(|e: DurableError| e);
    }

    /// A builder with no optional configuration can still produce the `run` future.
    #[test]
    fn test_builder_without_config_backward_compatible() {
        let built = handler(ok_closure!());
        let _future = built.run();
    }
}