Skip to main content

lambda_extension/
extension.rs

1use http::Request;
2use http_body_util::BodyExt;
3use hyper::{body::Incoming, server::conn::http1, service::service_fn};
4
5use hyper_util::rt::tokio::TokioIo;
6use lambda_runtime_api_client::Client;
7use serde::{de::DeserializeOwned, Deserialize};
8use std::{
9    convert::Infallible,
10    fmt,
11    future::{ready, Future},
12    marker::PhantomData,
13    net::SocketAddr,
14    path::PathBuf,
15    pin::Pin,
16    sync::Arc,
17};
18use tokio::{net::TcpListener, sync::Mutex};
19use tokio_stream::StreamExt;
20use tower::{MakeService, Service, ServiceExt};
21use tracing::trace;
22
23use crate::{
24    logs::*,
25    requests::{self, Api},
26    telemetry_wrapper, Error, ExtensionError, LambdaEvent, LambdaTelemetry, NextEvent,
27};
28
/// Port the logs listener binds to unless overridden with `Extension::with_log_port_number`.
const DEFAULT_LOG_PORT_NUMBER: u16 = 9002;
/// Port the telemetry listener binds to unless overridden with
/// `Extension::with_telemetry_port_number`.
const DEFAULT_TELEMETRY_PORT_NUMBER: u16 = 9003;
31
32/// An Extension that runs event, log and telemetry processors
33pub struct Extension<'a, E, L, T, TL = String> {
34    extension_name: Option<&'a str>,
35    events: Option<&'a [&'a str]>,
36    events_processor: E,
37    log_types: Option<&'a [&'a str]>,
38    logs_processor: Option<L>,
39    log_buffering: Option<LogBuffering>,
40    log_port_number: u16,
41    telemetry_types: Option<&'a [&'a str]>,
42    telemetry_processor: Option<T>,
43    telemetry_buffering: Option<LogBuffering>,
44    telemetry_port_number: u16,
45    _telemetry_record_type: PhantomData<fn(TL)>,
46}
47
48impl Extension<'_, Identity<LambdaEvent>, MakeIdentity<Vec<LambdaLog>>, MakeIdentity<Vec<LambdaTelemetry>>> {
49    /// Create a new base [`Extension`] with a no-op events processor
50    pub fn new() -> Self {
51        Extension {
52            extension_name: None,
53            events: None,
54            events_processor: Identity::new(),
55            log_types: None,
56            log_buffering: None,
57            logs_processor: None,
58            log_port_number: DEFAULT_LOG_PORT_NUMBER,
59            telemetry_types: None,
60            telemetry_buffering: None,
61            telemetry_processor: None,
62            telemetry_port_number: DEFAULT_TELEMETRY_PORT_NUMBER,
63            _telemetry_record_type: PhantomData,
64        }
65    }
66}
67
68impl Default
69    for Extension<'_, Identity<LambdaEvent>, MakeIdentity<Vec<LambdaLog>>, MakeIdentity<Vec<LambdaTelemetry>>>
70{
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl<'a, E, L, T, TL> Extension<'a, E, L, T, TL>
77where
78    E: Service<LambdaEvent>,
79    E::Future: Future<Output = Result<(), E::Error>>,
80    E::Error: Into<Error> + fmt::Display + fmt::Debug,
81
82    // Fixme: 'static bound might be too restrictive
83    L: MakeService<(), Vec<LambdaLog>, Response = ()> + Send + Sync + 'static,
84    L::Service: Service<Vec<LambdaLog>, Response = ()> + Send + Sync,
85    <L::Service as Service<Vec<LambdaLog>>>::Future: Send + 'a,
86    L::Error: Into<Error> + fmt::Debug,
87    L::MakeError: Into<Error> + fmt::Debug,
88    L::Future: Send,
89
90    // Fixme: 'static bound might be too restrictive
91    T: MakeService<(), Vec<LambdaTelemetry<TL>>, Response = ()> + Send + Sync + 'static,
92    T::Service: Service<Vec<LambdaTelemetry<TL>>, Response = ()> + Send + Sync,
93    <T::Service as Service<Vec<LambdaTelemetry<TL>>>>::Future: Send + 'a,
94    T::Error: Into<Error> + fmt::Debug,
95    T::MakeError: Into<Error> + fmt::Debug,
96    T::Future: Send,
97    TL: DeserializeOwned + Send + 'static,
98{
99    /// Create a new [`Extension`] with a given extension name
100    pub fn with_extension_name(self, extension_name: &'a str) -> Self {
101        Extension {
102            extension_name: Some(extension_name),
103            ..self
104        }
105    }
106
107    /// Create a new [`Extension`] with a list of given events.
108    /// The only accepted events are `INVOKE` and `SHUTDOWN`.
109    pub fn with_events(self, events: &'a [&'a str]) -> Self {
110        Extension {
111            events: Some(events),
112            ..self
113        }
114    }
115
116    /// Create a new [`Extension`] with a service that receives Lambda events.
117    pub fn with_events_processor<N>(self, ep: N) -> Extension<'a, N, L, T, TL>
118    where
119        N: Service<LambdaEvent>,
120        N::Future: Future<Output = Result<(), N::Error>>,
121        N::Error: Into<Error> + fmt::Display,
122    {
123        Extension {
124            events_processor: ep,
125            extension_name: self.extension_name,
126            events: self.events,
127            log_types: self.log_types,
128            log_buffering: self.log_buffering,
129            logs_processor: self.logs_processor,
130            log_port_number: self.log_port_number,
131            telemetry_types: self.telemetry_types,
132            telemetry_buffering: self.telemetry_buffering,
133            telemetry_processor: self.telemetry_processor,
134            telemetry_port_number: self.telemetry_port_number,
135            _telemetry_record_type: self._telemetry_record_type,
136        }
137    }
138
139    /// Create a new [`Extension`] with a service that receives Lambda logs.
140    pub fn with_logs_processor<N, NS>(self, lp: N) -> Extension<'a, E, N, T, TL>
141    where
142        N: Service<()>,
143        N::Future: Future<Output = Result<NS, N::Error>>,
144        N::Error: Into<Error> + fmt::Display,
145    {
146        Extension {
147            logs_processor: Some(lp),
148            events_processor: self.events_processor,
149            extension_name: self.extension_name,
150            events: self.events,
151            log_types: self.log_types,
152            log_buffering: self.log_buffering,
153            log_port_number: self.log_port_number,
154            telemetry_types: self.telemetry_types,
155            telemetry_buffering: self.telemetry_buffering,
156            telemetry_processor: self.telemetry_processor,
157            telemetry_port_number: self.telemetry_port_number,
158            _telemetry_record_type: self._telemetry_record_type,
159        }
160    }
161
162    /// Create a new [`Extension`] with a list of logs types to subscribe.
163    /// The only accepted log types are `function`, `platform`, and `extension`.
164    pub fn with_log_types(self, log_types: &'a [&'a str]) -> Self {
165        Extension {
166            log_types: Some(log_types),
167            ..self
168        }
169    }
170
171    /// Create a new [`Extension`] with specific configuration to buffer logs.
172    pub fn with_log_buffering(self, lb: LogBuffering) -> Self {
173        Extension {
174            log_buffering: Some(lb),
175            ..self
176        }
177    }
178
179    /// Create a new [`Extension`] with a different port number to listen to logs.
180    pub fn with_log_port_number(self, port_number: u16) -> Self {
181        Extension {
182            log_port_number: port_number,
183            ..self
184        }
185    }
186
187    /// Create a new [`Extension`] with a service that receives Lambda telemetry data.
188    ///
189    /// By default, telemetry log records are deserialized as `String`, but
190    /// it's possible to configure Lambda functions to emit logs in JSON format.
191    /// For more information, refer to [`Self::with_telemetry_record_type`].
192    pub fn with_telemetry_processor<N, NS>(self, lp: N) -> Extension<'a, E, L, N, TL>
193    where
194        N: Service<()>,
195        N::Future: Future<Output = Result<NS, N::Error>>,
196        N::Error: Into<Error> + fmt::Display,
197    {
198        Extension {
199            telemetry_processor: Some(lp),
200            events_processor: self.events_processor,
201            extension_name: self.extension_name,
202            events: self.events,
203            log_types: self.log_types,
204            log_buffering: self.log_buffering,
205            logs_processor: self.logs_processor,
206            log_port_number: self.log_port_number,
207            telemetry_types: self.telemetry_types,
208            telemetry_buffering: self.telemetry_buffering,
209            telemetry_port_number: self.telemetry_port_number,
210            _telemetry_record_type: self._telemetry_record_type,
211        }
212    }
213
214    /// Create a new [`Extension`] with a list of telemetry types to subscribe.
215    /// The only accepted telemetry types are `function`, `platform`, and `extension`.
216    pub fn with_telemetry_types(self, telemetry_types: &'a [&'a str]) -> Self {
217        Extension {
218            telemetry_types: Some(telemetry_types),
219            ..self
220        }
221    }
222
223    /// Create a new [`Extension`] with specific configuration to buffer telemetry.
224    pub fn with_telemetry_buffering(self, lb: LogBuffering) -> Self {
225        Extension {
226            telemetry_buffering: Some(lb),
227            ..self
228        }
229    }
230
231    /// Create a new [`Extension`] with a different port number to listen to telemetry.
232    pub fn with_telemetry_port_number(self, port_number: u16) -> Self {
233        Extension {
234            telemetry_port_number: port_number,
235            ..self
236        }
237    }
238
239    /// Register the extension.
240    ///
241    /// Performs the
242    /// [init phase](https://docs.aws.amazon.com/lambda/latest/dg/lambda-runtime-environment.html#runtimes-lifecycle-ib)
243    /// Lambda lifecycle operations to register the extension. When implementing an internal Lambda
244    /// extension, it is safe to call `lambda_runtime::run` once the future returned by this
245    /// function resolves.
246    pub async fn register(self) -> Result<RegisteredExtension<E>, Error> {
247        let client = &Client::builder().build()?;
248
249        let register_res = register(client, self.extension_name, self.events).await?;
250
251        // Logs API subscriptions must be requested during the Lambda init phase (see
252        // https://docs.aws.amazon.com/lambda/latest/dg/runtimes-logs-api.html#runtimes-logs-api-subscribing).
253        if let Some(mut log_processor) = self.logs_processor {
254            trace!("Log processor found");
255
256            validate_buffering_configuration(self.log_buffering)?;
257
258            let addr = SocketAddr::from(([0, 0, 0, 0], self.log_port_number));
259            let service = log_processor.make_service(());
260            let service = Arc::new(Mutex::new(service.await.unwrap()));
261            tokio::task::spawn(async move {
262                trace!("Creating new logs processor Service");
263
264                loop {
265                    let service: Arc<Mutex<_>> = service.clone();
266                    let make_service = service_fn(move |req: Request<Incoming>| log_wrapper(service.clone(), req));
267
268                    let listener = TcpListener::bind(addr).await.unwrap();
269                    let (tcp, _) = listener.accept().await.unwrap();
270                    let io = TokioIo::new(tcp);
271                    tokio::task::spawn(async move {
272                        if let Err(err) = http1::Builder::new().serve_connection(io, make_service).await {
273                            println!("Error serving connection: {err:?}");
274                        }
275                    });
276                }
277            });
278
279            trace!("Log processor started");
280
281            // Call Logs API to start receiving events
282            let req = requests::subscribe_request(
283                Api::LogsApi,
284                &register_res.extension_id,
285                self.log_types,
286                self.log_buffering,
287                self.log_port_number,
288            )?;
289            let res = client.call(req).await?;
290            if !res.status().is_success() {
291                let err = format!("unable to initialize the logs api: {}", res.status());
292                return Err(ExtensionError::boxed(err));
293            }
294            trace!("Registered extension with Logs API");
295        }
296
297        // Telemetry API subscriptions must be requested during the Lambda init phase (see
298        // https://docs.aws.amazon.com/lambda/latest/dg/telemetry-api.html#telemetry-api-registration
299        if let Some(mut telemetry_processor) = self.telemetry_processor {
300            trace!("Telemetry processor found");
301
302            validate_buffering_configuration(self.telemetry_buffering)?;
303
304            let addr = SocketAddr::from(([0, 0, 0, 0], self.telemetry_port_number));
305            let service = telemetry_processor.make_service(());
306            let service = Arc::new(Mutex::new(service.await.unwrap()));
307            tokio::task::spawn(async move {
308                trace!("Creating new telemetry processor Service");
309
310                loop {
311                    let service = service.clone();
312                    let make_service = service_fn(move |req| telemetry_wrapper(service.clone(), req));
313
314                    let listener = TcpListener::bind(addr).await.unwrap();
315                    let (tcp, _) = listener.accept().await.unwrap();
316                    let io = TokioIo::new(tcp);
317                    tokio::task::spawn(async move {
318                        if let Err(err) = http1::Builder::new().serve_connection(io, make_service).await {
319                            println!("Error serving connection: {err:?}");
320                        }
321                    });
322                }
323            });
324
325            trace!("Telemetry processor started");
326
327            // Call Telemetry API to start receiving events
328            let req = requests::subscribe_request(
329                Api::TelemetryApi,
330                &register_res.extension_id,
331                self.telemetry_types,
332                self.telemetry_buffering,
333                self.telemetry_port_number,
334            )?;
335            let res = client.call(req).await?;
336            if !res.status().is_success() {
337                let err = format!("unable to initialize the telemetry api: {}", res.status());
338                return Err(ExtensionError::boxed(err));
339            }
340            trace!("Registered extension with Telemetry API");
341        }
342
343        Ok(RegisteredExtension {
344            extension_id: register_res.extension_id,
345            function_name: register_res.function_name,
346            function_version: register_res.function_version,
347            handler: register_res.handler,
348            account_id: register_res.account_id,
349            events_processor: self.events_processor,
350        })
351    }
352
353    /// Execute the given extension.
354    pub async fn run(self) -> Result<(), Error> {
355        self.register().await?.run().await
356    }
357}
358
359impl<'a, E, L> Extension<'a, E, L, MakeIdentity<Vec<LambdaTelemetry>>> {
360    /// Set the deserialization type for telemetry log records.
361    ///
362    /// By default, telemetry log records are deserialized as `String`, but
363    /// it's possible to configure Lambda functions to emit logs in JSON format.
364    /// Use this method to deserialize into a different type, such as
365    /// `serde_json::Value`.
366    ///
367    /// Must be called before [`Self::with_telemetry_processor`].
368    ///
369    /// ```
370    /// use lambda_extension::{Extension, LambdaTelemetry, SharedService, service_fn};
371    ///
372    /// async fn handler(events: Vec<LambdaTelemetry<serde_json::Value>>) -> Result<(), lambda_extension::Error> {
373    ///     for event in &events {
374    ///         println!("{event:?}");
375    ///     }
376    ///     Ok(())
377    /// }
378    ///
379    /// let _ext = Extension::new()
380    ///     .with_telemetry_record_type::<serde_json::Value>()
381    ///     .with_telemetry_processor(SharedService::new(service_fn(handler)));
382    /// ```
383    pub fn with_telemetry_record_type<N>(self) -> Extension<'a, E, L, MakeIdentity<Vec<LambdaTelemetry<N>>>, N> {
384        Extension {
385            _telemetry_record_type: PhantomData,
386            telemetry_processor: None,
387            events_processor: self.events_processor,
388            extension_name: self.extension_name,
389            events: self.events,
390            log_types: self.log_types,
391            log_buffering: self.log_buffering,
392            logs_processor: self.logs_processor,
393            log_port_number: self.log_port_number,
394            telemetry_types: self.telemetry_types,
395            telemetry_buffering: self.telemetry_buffering,
396            telemetry_port_number: self.telemetry_port_number,
397        }
398    }
399}
400
/// An extension registered by calling [`Extension::register`].
///
/// Call [`RegisteredExtension::run`] to start processing Lambda events.
pub struct RegisteredExtension<E> {
    /// The ID of the registered extension, as returned by Lambda in the registration
    /// response. It identifies this extension in every subsequent Extensions API request.
    pub extension_id: String,
    /// The ID of the account the extension was registered to.
    /// This will be `None` if the register request doesn't send the Lambda-Extension-Accept-Feature header.
    pub account_id: Option<String>,
    /// The name of the Lambda function that the extension is registered with.
    pub function_name: String,
    /// The version of the Lambda function that the extension is registered with.
    pub function_version: String,
    /// The Lambda function handler that AWS Lambda invokes.
    pub handler: String,
    /// The service that receives each event fetched by [`RegisteredExtension::run`].
    events_processor: E,
}
416
417impl<E> RegisteredExtension<E>
418where
419    E: Service<LambdaEvent>,
420    E::Future: Future<Output = Result<(), E::Error>>,
421    E::Error: Into<Box<dyn std::error::Error + Send + Sync>> + fmt::Display + fmt::Debug,
422{
423    /// Execute the extension's run loop.
424    ///
425    /// Performs the
426    /// [invoke](https://docs.aws.amazon.com/lambda/latest/dg/lambda-runtime-environment.html#runtimes-lifecycle-invoke)
427    /// and, for external Lambda extensions registered to receive the `SHUTDOWN` event, the
428    /// [shutdown](https://docs.aws.amazon.com/lambda/latest/dg/lambda-runtime-environment.html#runtimes-lifecycle-shutdown)
429    /// Lambda lifecycle phases.
430    pub async fn run(self) -> Result<(), Error> {
431        let client = &Client::builder().build()?;
432        let mut ep = self.events_processor;
433        let extension_id = &self.extension_id;
434
435        let incoming = async_stream::stream! {
436            loop {
437                trace!("Waiting for next event (incoming loop)");
438                let req = requests::next_event_request(extension_id)?;
439                let res = client.call(req).await;
440                yield res;
441            }
442        };
443
444        tokio::pin!(incoming);
445        while let Some(event) = incoming.next().await {
446            trace!("New event arrived (run loop)");
447            let event = event?;
448            let (_parts, body) = event.into_parts();
449
450            let body = body.collect().await?.to_bytes();
451            trace!("{}", std::str::from_utf8(&body)?); // this may be very verbose
452            let event: NextEvent = serde_json::from_slice(&body)?;
453            let is_invoke = event.is_invoke();
454
455            let event = LambdaEvent::new(event);
456
457            let ep = match ep.ready().await {
458                Ok(ep) => ep,
459                Err(err) => {
460                    println!("Inner service is not ready: {err:?}");
461                    let req = if is_invoke {
462                        requests::init_error(extension_id, &err.to_string(), None)?
463                    } else {
464                        requests::exit_error(extension_id, &err.to_string(), None)?
465                    };
466
467                    client.call(req).await?;
468                    return Err(err.into());
469                }
470            };
471
472            let res = ep.call(event).await;
473            if let Err(err) = res {
474                println!("{err:?}");
475                let req = if is_invoke {
476                    requests::init_error(extension_id, &err.to_string(), None)?
477                } else {
478                    requests::exit_error(extension_id, &err.to_string(), None)?
479                };
480
481                client.call(req).await?;
482                return Err(err.into());
483            }
484        }
485
486        // Unreachable.
487        Ok(())
488    }
489}
490
/// A no-op generic processor.
///
/// As a [`Service`], it ignores every request and resolves immediately with `Ok(())`.
/// `T` is only a marker, so the `Clone`, `Debug` and `Default` impls below place no
/// bounds on it. A `#[derive(Clone)]` would needlessly require `T: Clone`.
pub struct Identity<T> {
    _phantom: std::marker::PhantomData<T>,
}

impl<T> Clone for Identity<T> {
    fn clone(&self) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T> fmt::Debug for Identity<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Identity").finish()
    }
}

impl<T> Default for Identity<T> {
    fn default() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
496
497impl<T> Identity<T> {
498    fn new() -> Self {
499        Self {
500            _phantom: std::marker::PhantomData,
501        }
502    }
503}
504
505impl<T> Service<T> for Identity<T> {
506    type Error = Infallible;
507    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
508    type Response = ();
509
510    fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
511        core::task::Poll::Ready(Ok(()))
512    }
513
514    fn call(&mut self, _event: T) -> Self::Future {
515        Box::pin(ready(Ok(())))
516    }
517}
518
/// Service factory to generate no-op generic processors.
///
/// As a `Service<()>`, each call yields a new [`Identity`]. `T` is only a marker, so the
/// `Clone`, `Debug` and `Default` impls below place no bounds on it. A
/// `#[derive(Clone)]` would needlessly require `T: Clone`.
pub struct MakeIdentity<T> {
    _phantom: std::marker::PhantomData<T>,
}

impl<T> Clone for MakeIdentity<T> {
    fn clone(&self) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T> fmt::Debug for MakeIdentity<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MakeIdentity").finish()
    }
}

impl<T> Default for MakeIdentity<T> {
    fn default() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
524
525impl<T> Service<()> for MakeIdentity<T>
526where
527    T: Send + Sync + 'static,
528{
529    type Error = Infallible;
530    type Response = Identity<T>;
531    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
532
533    fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
534        core::task::Poll::Ready(Ok(()))
535    }
536
537    fn call(&mut self, _: ()) -> Self::Future {
538        Box::pin(ready(Ok(Identity::new())))
539    }
540}
541
542#[derive(Debug, Deserialize)]
543#[serde(rename_all = "camelCase")]
544struct RegisterResponseBody {
545    function_name: String,
546    function_version: String,
547    handler: String,
548    account_id: Option<String>,
549}
550
/// Result of a registration: the extension id read from the response header, plus the
/// fields decoded from the response body.
#[derive(Debug)]
struct RegisterResponse {
    /// Value of the extension id response header.
    extension_id: String,
    /// Name of the function the extension is registered with.
    function_name: String,
    /// Version of that function.
    function_version: String,
    /// Handler that Lambda invokes for that function.
    handler: String,
    /// Account ID, when present in the response body.
    account_id: Option<String>,
}
559
560/// Initialize and register the extension in the Extensions API
561async fn register<'a>(
562    client: &'a Client,
563    extension_name: Option<&'a str>,
564    events: Option<&'a [&'a str]>,
565) -> Result<RegisterResponse, Error> {
566    let name = match extension_name {
567        Some(name) => name.into(),
568        None => {
569            let args: Vec<String> = std::env::args().collect();
570            PathBuf::from(args[0].clone())
571                .file_name()
572                .expect("unexpected executable name")
573                .to_str()
574                .expect("unexpect executable name")
575                .to_string()
576        }
577    };
578
579    let events = events.unwrap_or(&["INVOKE", "SHUTDOWN"]);
580
581    let req = requests::register_request(&name, events)?;
582    let res = client.call(req).await?;
583    if !res.status().is_success() {
584        let err = format!("unable to register the extension: {}", res.status());
585        return Err(ExtensionError::boxed(err));
586    }
587
588    let header = res
589        .headers()
590        .get(requests::EXTENSION_ID_HEADER)
591        .ok_or_else(|| ExtensionError::boxed("missing extension id header"))
592        .map_err(|e| ExtensionError::boxed(e.to_string()))?;
593    let extension_id = header.to_str()?.to_string();
594
595    let (_, body) = res.into_parts();
596    let body = body.collect().await?.to_bytes();
597    let response: RegisterResponseBody = serde_json::from_slice(&body)?;
598
599    Ok(RegisterResponse {
600        extension_id,
601        function_name: response.function_name,
602        function_version: response.function_version,
603        handler: response.handler,
604        account_id: response.account_id,
605    })
606}