lambda_extension/
extension.rs

1use http::Request;
2use http_body_util::BodyExt;
3use hyper::{body::Incoming, server::conn::http1, service::service_fn};
4
5use hyper_util::rt::tokio::TokioIo;
6use lambda_runtime_api_client::Client;
7use serde::Deserialize;
8use std::{
9    convert::Infallible,
10    fmt,
11    future::{ready, Future},
12    net::SocketAddr,
13    path::PathBuf,
14    pin::Pin,
15    sync::Arc,
16};
17use tokio::{net::TcpListener, sync::Mutex};
18use tokio_stream::StreamExt;
19use tower::{MakeService, Service, ServiceExt};
20use tracing::trace;
21
22use crate::{
23    logs::*,
24    requests::{self, Api},
25    telemetry_wrapper, Error, ExtensionError, LambdaEvent, LambdaTelemetry, NextEvent,
26};
27
/// Port the logs listener uses unless overridden via `with_log_port_number`.
const DEFAULT_LOG_PORT_NUMBER: u16 = 9002;
/// Port the telemetry listener uses unless overridden via `with_telemetry_port_number`.
const DEFAULT_TELEMETRY_PORT_NUMBER: u16 = 9003;
30
/// An Extension that runs event, log and telemetry processors
pub struct Extension<'a, E, L, T> {
    /// Name used when registering; when `None`, registration falls back to the file name of
    /// the running executable.
    extension_name: Option<&'a str>,
    /// Events to register for; when `None`, registration uses `INVOKE` and `SHUTDOWN`.
    events: Option<&'a [&'a str]>,
    /// Service that receives each [`LambdaEvent`] in the run loop.
    events_processor: E,
    /// Log types forwarded to the Logs API subscription request.
    log_types: Option<&'a [&'a str]>,
    /// Factory for the service that receives log batches; when `None`, no logs listener is
    /// started and no Logs API subscription is made.
    logs_processor: Option<L>,
    /// Buffering configuration validated and forwarded to the Logs API subscription request.
    log_buffering: Option<LogBuffering>,
    /// Port the logs listener binds to.
    log_port_number: u16,
    /// Telemetry types forwarded to the Telemetry API subscription request.
    telemetry_types: Option<&'a [&'a str]>,
    /// Factory for the service that receives telemetry batches; when `None`, no telemetry
    /// listener is started and no Telemetry API subscription is made.
    telemetry_processor: Option<T>,
    /// Buffering configuration validated and forwarded to the Telemetry API subscription request.
    telemetry_buffering: Option<LogBuffering>,
    /// Port the telemetry listener binds to.
    telemetry_port_number: u16,
}
45
46impl Extension<'_, Identity<LambdaEvent>, MakeIdentity<Vec<LambdaLog>>, MakeIdentity<Vec<LambdaTelemetry>>> {
47    /// Create a new base [`Extension`] with a no-op events processor
48    pub fn new() -> Self {
49        Extension {
50            extension_name: None,
51            events: None,
52            events_processor: Identity::new(),
53            log_types: None,
54            log_buffering: None,
55            logs_processor: None,
56            log_port_number: DEFAULT_LOG_PORT_NUMBER,
57            telemetry_types: None,
58            telemetry_buffering: None,
59            telemetry_processor: None,
60            telemetry_port_number: DEFAULT_TELEMETRY_PORT_NUMBER,
61        }
62    }
63}
64
65impl Default
66    for Extension<'_, Identity<LambdaEvent>, MakeIdentity<Vec<LambdaLog>>, MakeIdentity<Vec<LambdaTelemetry>>>
67{
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl<'a, E, L, T> Extension<'a, E, L, T>
74where
75    E: Service<LambdaEvent>,
76    E::Future: Future<Output = Result<(), E::Error>>,
77    E::Error: Into<Error> + fmt::Display + fmt::Debug,
78
79    // Fixme: 'static bound might be too restrictive
80    L: MakeService<(), Vec<LambdaLog>, Response = ()> + Send + Sync + 'static,
81    L::Service: Service<Vec<LambdaLog>, Response = ()> + Send + Sync,
82    <L::Service as Service<Vec<LambdaLog>>>::Future: Send + 'a,
83    L::Error: Into<Error> + fmt::Debug,
84    L::MakeError: Into<Error> + fmt::Debug,
85    L::Future: Send,
86
87    // Fixme: 'static bound might be too restrictive
88    T: MakeService<(), Vec<LambdaTelemetry>, Response = ()> + Send + Sync + 'static,
89    T::Service: Service<Vec<LambdaTelemetry>, Response = ()> + Send + Sync,
90    <T::Service as Service<Vec<LambdaTelemetry>>>::Future: Send + 'a,
91    T::Error: Into<Error> + fmt::Debug,
92    T::MakeError: Into<Error> + fmt::Debug,
93    T::Future: Send,
94{
95    /// Create a new [`Extension`] with a given extension name
96    pub fn with_extension_name(self, extension_name: &'a str) -> Self {
97        Extension {
98            extension_name: Some(extension_name),
99            ..self
100        }
101    }
102
103    /// Create a new [`Extension`] with a list of given events.
104    /// The only accepted events are `INVOKE` and `SHUTDOWN`.
105    pub fn with_events(self, events: &'a [&'a str]) -> Self {
106        Extension {
107            events: Some(events),
108            ..self
109        }
110    }
111
112    /// Create a new [`Extension`] with a service that receives Lambda events.
113    pub fn with_events_processor<N>(self, ep: N) -> Extension<'a, N, L, T>
114    where
115        N: Service<LambdaEvent>,
116        N::Future: Future<Output = Result<(), N::Error>>,
117        N::Error: Into<Error> + fmt::Display,
118    {
119        Extension {
120            events_processor: ep,
121            extension_name: self.extension_name,
122            events: self.events,
123            log_types: self.log_types,
124            log_buffering: self.log_buffering,
125            logs_processor: self.logs_processor,
126            log_port_number: self.log_port_number,
127            telemetry_types: self.telemetry_types,
128            telemetry_buffering: self.telemetry_buffering,
129            telemetry_processor: self.telemetry_processor,
130            telemetry_port_number: self.telemetry_port_number,
131        }
132    }
133
134    /// Create a new [`Extension`] with a service that receives Lambda logs.
135    pub fn with_logs_processor<N, NS>(self, lp: N) -> Extension<'a, E, N, T>
136    where
137        N: Service<()>,
138        N::Future: Future<Output = Result<NS, N::Error>>,
139        N::Error: Into<Error> + fmt::Display,
140    {
141        Extension {
142            logs_processor: Some(lp),
143            events_processor: self.events_processor,
144            extension_name: self.extension_name,
145            events: self.events,
146            log_types: self.log_types,
147            log_buffering: self.log_buffering,
148            log_port_number: self.log_port_number,
149            telemetry_types: self.telemetry_types,
150            telemetry_buffering: self.telemetry_buffering,
151            telemetry_processor: self.telemetry_processor,
152            telemetry_port_number: self.telemetry_port_number,
153        }
154    }
155
156    /// Create a new [`Extension`] with a list of logs types to subscribe.
157    /// The only accepted log types are `function`, `platform`, and `extension`.
158    pub fn with_log_types(self, log_types: &'a [&'a str]) -> Self {
159        Extension {
160            log_types: Some(log_types),
161            ..self
162        }
163    }
164
165    /// Create a new [`Extension`] with specific configuration to buffer logs.
166    pub fn with_log_buffering(self, lb: LogBuffering) -> Self {
167        Extension {
168            log_buffering: Some(lb),
169            ..self
170        }
171    }
172
173    /// Create a new [`Extension`] with a different port number to listen to logs.
174    pub fn with_log_port_number(self, port_number: u16) -> Self {
175        Extension {
176            log_port_number: port_number,
177            ..self
178        }
179    }
180
181    /// Create a new [`Extension`] with a service that receives Lambda telemetry data.
182    pub fn with_telemetry_processor<N, NS>(self, lp: N) -> Extension<'a, E, L, N>
183    where
184        N: Service<()>,
185        N::Future: Future<Output = Result<NS, N::Error>>,
186        N::Error: Into<Error> + fmt::Display,
187    {
188        Extension {
189            telemetry_processor: Some(lp),
190            events_processor: self.events_processor,
191            extension_name: self.extension_name,
192            events: self.events,
193            log_types: self.log_types,
194            log_buffering: self.log_buffering,
195            logs_processor: self.logs_processor,
196            log_port_number: self.log_port_number,
197            telemetry_types: self.telemetry_types,
198            telemetry_buffering: self.telemetry_buffering,
199            telemetry_port_number: self.telemetry_port_number,
200        }
201    }
202
203    /// Create a new [`Extension`] with a list of telemetry types to subscribe.
204    /// The only accepted telemetry types are `function`, `platform`, and `extension`.
205    pub fn with_telemetry_types(self, telemetry_types: &'a [&'a str]) -> Self {
206        Extension {
207            telemetry_types: Some(telemetry_types),
208            ..self
209        }
210    }
211
212    /// Create a new [`Extension`] with specific configuration to buffer telemetry.
213    pub fn with_telemetry_buffering(self, lb: LogBuffering) -> Self {
214        Extension {
215            telemetry_buffering: Some(lb),
216            ..self
217        }
218    }
219
220    /// Create a new [`Extension`] with a different port number to listen to telemetry.
221    pub fn with_telemetry_port_number(self, port_number: u16) -> Self {
222        Extension {
223            telemetry_port_number: port_number,
224            ..self
225        }
226    }
227
228    /// Register the extension.
229    ///
230    /// Performs the
231    /// [init phase](https://docs.aws.amazon.com/lambda/latest/dg/lambda-runtime-environment.html#runtimes-lifecycle-ib)
232    /// Lambda lifecycle operations to register the extension. When implementing an internal Lambda
233    /// extension, it is safe to call `lambda_runtime::run` once the future returned by this
234    /// function resolves.
235    pub async fn register(self) -> Result<RegisteredExtension<E>, Error> {
236        let client = &Client::builder().build()?;
237
238        let register_res = register(client, self.extension_name, self.events).await?;
239
240        // Logs API subscriptions must be requested during the Lambda init phase (see
241        // https://docs.aws.amazon.com/lambda/latest/dg/runtimes-logs-api.html#runtimes-logs-api-subscribing).
242        if let Some(mut log_processor) = self.logs_processor {
243            trace!("Log processor found");
244
245            validate_buffering_configuration(self.log_buffering)?;
246
247            let addr = SocketAddr::from(([0, 0, 0, 0], self.log_port_number));
248            let service = log_processor.make_service(());
249            let service = Arc::new(Mutex::new(service.await.unwrap()));
250            tokio::task::spawn(async move {
251                trace!("Creating new logs processor Service");
252
253                loop {
254                    let service: Arc<Mutex<_>> = service.clone();
255                    let make_service = service_fn(move |req: Request<Incoming>| log_wrapper(service.clone(), req));
256
257                    let listener = TcpListener::bind(addr).await.unwrap();
258                    let (tcp, _) = listener.accept().await.unwrap();
259                    let io = TokioIo::new(tcp);
260                    tokio::task::spawn(async move {
261                        if let Err(err) = http1::Builder::new().serve_connection(io, make_service).await {
262                            println!("Error serving connection: {:?}", err);
263                        }
264                    });
265                }
266            });
267
268            trace!("Log processor started");
269
270            // Call Logs API to start receiving events
271            let req = requests::subscribe_request(
272                Api::LogsApi,
273                &register_res.extension_id,
274                self.log_types,
275                self.log_buffering,
276                self.log_port_number,
277            )?;
278            let res = client.call(req).await?;
279            if !res.status().is_success() {
280                let err = format!("unable to initialize the logs api: {}", res.status());
281                return Err(ExtensionError::boxed(err));
282            }
283            trace!("Registered extension with Logs API");
284        }
285
286        // Telemetry API subscriptions must be requested during the Lambda init phase (see
287        // https://docs.aws.amazon.com/lambda/latest/dg/telemetry-api.html#telemetry-api-registration
288        if let Some(mut telemetry_processor) = self.telemetry_processor {
289            trace!("Telemetry processor found");
290
291            validate_buffering_configuration(self.telemetry_buffering)?;
292
293            let addr = SocketAddr::from(([0, 0, 0, 0], self.telemetry_port_number));
294            let service = telemetry_processor.make_service(());
295            let service = Arc::new(Mutex::new(service.await.unwrap()));
296            tokio::task::spawn(async move {
297                trace!("Creating new telemetry processor Service");
298
299                loop {
300                    let service = service.clone();
301                    let make_service = service_fn(move |req| telemetry_wrapper(service.clone(), req));
302
303                    let listener = TcpListener::bind(addr).await.unwrap();
304                    let (tcp, _) = listener.accept().await.unwrap();
305                    let io = TokioIo::new(tcp);
306                    tokio::task::spawn(async move {
307                        if let Err(err) = http1::Builder::new().serve_connection(io, make_service).await {
308                            println!("Error serving connection: {:?}", err);
309                        }
310                    });
311                }
312            });
313
314            trace!("Telemetry processor started");
315
316            // Call Telemetry API to start receiving events
317            let req = requests::subscribe_request(
318                Api::TelemetryApi,
319                &register_res.extension_id,
320                self.telemetry_types,
321                self.telemetry_buffering,
322                self.telemetry_port_number,
323            )?;
324            let res = client.call(req).await?;
325            if !res.status().is_success() {
326                let err = format!("unable to initialize the telemetry api: {}", res.status());
327                return Err(ExtensionError::boxed(err));
328            }
329            trace!("Registered extension with Telemetry API");
330        }
331
332        Ok(RegisteredExtension {
333            extension_id: register_res.extension_id,
334            function_name: register_res.function_name,
335            function_version: register_res.function_version,
336            handler: register_res.handler,
337            account_id: register_res.account_id,
338            events_processor: self.events_processor,
339        })
340    }
341
342    /// Execute the given extension.
343    pub async fn run(self) -> Result<(), Error> {
344        self.register().await?.run().await
345    }
346}
347
/// An extension registered by calling [`Extension::register`].
pub struct RegisteredExtension<E> {
    /// The ID of the registered extension. This ID is unique per extension and remains constant
    pub extension_id: String,
    /// The ID of the account the extension was registered to.
    /// This will be `None` if the register request doesn't send the Lambda-Extension-Accept-Feature header
    pub account_id: Option<String>,
    /// The name of the Lambda function that the extension is registered with
    pub function_name: String,
    /// The version of the Lambda function that the extension is registered with
    pub function_version: String,
    /// The Lambda function handler that AWS Lambda invokes
    pub handler: String,
    /// Service called with every event received by [`RegisteredExtension::run`]
    events_processor: E,
}
363
364impl<E> RegisteredExtension<E>
365where
366    E: Service<LambdaEvent>,
367    E::Future: Future<Output = Result<(), E::Error>>,
368    E::Error: Into<Box<dyn std::error::Error + Send + Sync>> + fmt::Display + fmt::Debug,
369{
370    /// Execute the extension's run loop.
371    ///
372    /// Performs the
373    /// [invoke](https://docs.aws.amazon.com/lambda/latest/dg/lambda-runtime-environment.html#runtimes-lifecycle-invoke)
374    /// and, for external Lambda extensions registered to receive the `SHUTDOWN` event, the
375    /// [shutdown](https://docs.aws.amazon.com/lambda/latest/dg/lambda-runtime-environment.html#runtimes-lifecycle-shutdown)
376    /// Lambda lifecycle phases.
377    pub async fn run(self) -> Result<(), Error> {
378        let client = &Client::builder().build()?;
379        let mut ep = self.events_processor;
380        let extension_id = &self.extension_id;
381
382        let incoming = async_stream::stream! {
383            loop {
384                trace!("Waiting for next event (incoming loop)");
385                let req = requests::next_event_request(extension_id)?;
386                let res = client.call(req).await;
387                yield res;
388            }
389        };
390
391        tokio::pin!(incoming);
392        while let Some(event) = incoming.next().await {
393            trace!("New event arrived (run loop)");
394            let event = event?;
395            let (_parts, body) = event.into_parts();
396
397            let body = body.collect().await?.to_bytes();
398            trace!("{}", std::str::from_utf8(&body)?); // this may be very verbose
399            let event: NextEvent = serde_json::from_slice(&body)?;
400            let is_invoke = event.is_invoke();
401
402            let event = LambdaEvent::new(event);
403
404            let ep = match ep.ready().await {
405                Ok(ep) => ep,
406                Err(err) => {
407                    println!("Inner service is not ready: {err:?}");
408                    let req = if is_invoke {
409                        requests::init_error(extension_id, &err.to_string(), None)?
410                    } else {
411                        requests::exit_error(extension_id, &err.to_string(), None)?
412                    };
413
414                    client.call(req).await?;
415                    return Err(err.into());
416                }
417            };
418
419            let res = ep.call(event).await;
420            if let Err(err) = res {
421                println!("{err:?}");
422                let req = if is_invoke {
423                    requests::init_error(extension_id, &err.to_string(), None)?
424                } else {
425                    requests::exit_error(extension_id, &err.to_string(), None)?
426                };
427
428                client.call(req).await?;
429                return Err(err.into());
430            }
431        }
432
433        // Unreachable.
434        Ok(())
435    }
436}
437
/// A no-op generic processor
///
/// It stores no data; the type parameter only records the request type `T`.
#[derive(Clone)]
pub struct Identity<T> {
    // Zero-sized marker tying the struct to `T` without holding a value of it.
    _phantom: std::marker::PhantomData<T>,
}
443
444impl<T> Identity<T> {
445    fn new() -> Self {
446        Self {
447            _phantom: std::marker::PhantomData,
448        }
449    }
450}
451
452impl<T> Service<T> for Identity<T> {
453    type Error = Infallible;
454    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
455    type Response = ();
456
457    fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
458        core::task::Poll::Ready(Ok(()))
459    }
460
461    fn call(&mut self, _event: T) -> Self::Future {
462        Box::pin(ready(Ok(())))
463    }
464}
465
/// Service factory to generate no-op generic processors
///
/// It stores no data; the type parameter only records the request type `T` of the
/// [`Identity`] services it produces.
#[derive(Clone)]
pub struct MakeIdentity<T> {
    // Zero-sized marker tying the struct to `T` without holding a value of it.
    _phantom: std::marker::PhantomData<T>,
}
471
472impl<T> Service<()> for MakeIdentity<T>
473where
474    T: Send + Sync + 'static,
475{
476    type Error = Infallible;
477    type Response = Identity<T>;
478    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
479
480    fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
481        core::task::Poll::Ready(Ok(()))
482    }
483
484    fn call(&mut self, _: ()) -> Self::Future {
485        Box::pin(ready(Ok(Identity::new())))
486    }
487}
488
/// JSON body of a successful register response, deserialized from camelCase keys.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct RegisterResponseBody {
    /// Read from the `functionName` key.
    function_name: String,
    /// Read from the `functionVersion` key.
    function_version: String,
    /// Read from the `handler` key.
    handler: String,
    /// Read from the `accountId` key; `None` when the key is absent.
    account_id: Option<String>,
}
497
/// Result of registering the extension: the extension ID taken from the response header
/// combined with the fields of the response body.
#[derive(Debug)]
struct RegisterResponse {
    /// Value of the extension ID response header.
    extension_id: String,
    /// Copied from the response body.
    function_name: String,
    /// Copied from the response body.
    function_version: String,
    /// Copied from the response body.
    handler: String,
    /// Copied from the response body; may be `None`.
    account_id: Option<String>,
}
506
507/// Initialize and register the extension in the Extensions API
508async fn register<'a>(
509    client: &'a Client,
510    extension_name: Option<&'a str>,
511    events: Option<&'a [&'a str]>,
512) -> Result<RegisterResponse, Error> {
513    let name = match extension_name {
514        Some(name) => name.into(),
515        None => {
516            let args: Vec<String> = std::env::args().collect();
517            PathBuf::from(args[0].clone())
518                .file_name()
519                .expect("unexpected executable name")
520                .to_str()
521                .expect("unexpect executable name")
522                .to_string()
523        }
524    };
525
526    let events = events.unwrap_or(&["INVOKE", "SHUTDOWN"]);
527
528    let req = requests::register_request(&name, events)?;
529    let res = client.call(req).await?;
530    if !res.status().is_success() {
531        let err = format!("unable to register the extension: {}", res.status());
532        return Err(ExtensionError::boxed(err));
533    }
534
535    let header = res
536        .headers()
537        .get(requests::EXTENSION_ID_HEADER)
538        .ok_or_else(|| ExtensionError::boxed("missing extension id header"))
539        .map_err(|e| ExtensionError::boxed(e.to_string()))?;
540    let extension_id = header.to_str()?.to_string();
541
542    let (_, body) = res.into_parts();
543    let body = body.collect().await?.to_bytes();
544    let response: RegisterResponseBody = serde_json::from_slice(&body)?;
545
546    Ok(RegisterResponse {
547        extension_id,
548        function_name: response.function_name,
549        function_version: response.function_version,
550        handler: response.handler,
551        account_id: response.account_id,
552    })
553}