1use http::Request;
2use http_body_util::BodyExt;
3use hyper::{body::Incoming, server::conn::http1, service::service_fn};
4
5use hyper_util::rt::tokio::TokioIo;
6use lambda_runtime_api_client::Client;
7use serde::{de::DeserializeOwned, Deserialize};
8use std::{
9 convert::Infallible,
10 fmt,
11 future::{ready, Future},
12 marker::PhantomData,
13 net::SocketAddr,
14 path::PathBuf,
15 pin::Pin,
16 sync::Arc,
17};
18use tokio::{net::TcpListener, sync::Mutex};
19use tokio_stream::StreamExt;
20use tower::{MakeService, Service, ServiceExt};
21use tracing::trace;
22
23use crate::{
24 logs::*,
25 requests::{self, Api},
26 telemetry_wrapper, Error, ExtensionError, LambdaEvent, LambdaTelemetry, NextEvent,
27};
28
/// Local port the Logs API listener binds to unless `with_log_port_number` overrides it.
const DEFAULT_LOG_PORT_NUMBER: u16 = 9002;
/// Local port the Telemetry API listener binds to unless `with_telemetry_port_number` overrides it.
const DEFAULT_TELEMETRY_PORT_NUMBER: u16 = 9003;
31
/// Builder for a Lambda extension.
///
/// `E` handles extension events; `L` and `T` are service factories for the logs and
/// telemetry processors; `TL` is the record type carried by `LambdaTelemetry<TL>`
/// (defaults to `String`).
pub struct Extension<'a, E, L, T, TL = String> {
    /// Name to register under; `None` falls back to the executable's file name.
    extension_name: Option<&'a str>,
    /// Event types to register for; `None` falls back to `INVOKE` and `SHUTDOWN`.
    events: Option<&'a [&'a str]>,
    /// Service called for every `LambdaEvent`.
    events_processor: E,
    /// Log types sent in the Logs API subscription.
    log_types: Option<&'a [&'a str]>,
    /// Factory for the logs service; `None` means no Logs API subscription is made.
    logs_processor: Option<L>,
    /// Buffering configuration sent in the Logs API subscription.
    log_buffering: Option<LogBuffering>,
    /// Local port for the logs listener, also sent in the subscription.
    log_port_number: u16,
    /// Telemetry types sent in the Telemetry API subscription.
    telemetry_types: Option<&'a [&'a str]>,
    /// Factory for the telemetry service; `None` means no Telemetry API subscription is made.
    telemetry_processor: Option<T>,
    /// Buffering configuration sent in the Telemetry API subscription.
    telemetry_buffering: Option<LogBuffering>,
    /// Local port for the telemetry listener, also sent in the subscription.
    telemetry_port_number: u16,
    /// Marker for `TL`; the `fn(TL)` form keeps it from affecting auto traits like `Send`/`Sync`.
    _telemetry_record_type: PhantomData<fn(TL)>,
}
47
48impl Extension<'_, Identity<LambdaEvent>, MakeIdentity<Vec<LambdaLog>>, MakeIdentity<Vec<LambdaTelemetry>>> {
49 pub fn new() -> Self {
51 Extension {
52 extension_name: None,
53 events: None,
54 events_processor: Identity::new(),
55 log_types: None,
56 log_buffering: None,
57 logs_processor: None,
58 log_port_number: DEFAULT_LOG_PORT_NUMBER,
59 telemetry_types: None,
60 telemetry_buffering: None,
61 telemetry_processor: None,
62 telemetry_port_number: DEFAULT_TELEMETRY_PORT_NUMBER,
63 _telemetry_record_type: PhantomData,
64 }
65 }
66}
67
68impl Default
69 for Extension<'_, Identity<LambdaEvent>, MakeIdentity<Vec<LambdaLog>>, MakeIdentity<Vec<LambdaTelemetry>>>
70{
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl<'a, E, L, T, TL> Extension<'a, E, L, T, TL>
77where
78 E: Service<LambdaEvent>,
79 E::Future: Future<Output = Result<(), E::Error>>,
80 E::Error: Into<Error> + fmt::Display + fmt::Debug,
81
82 L: MakeService<(), Vec<LambdaLog>, Response = ()> + Send + Sync + 'static,
84 L::Service: Service<Vec<LambdaLog>, Response = ()> + Send + Sync,
85 <L::Service as Service<Vec<LambdaLog>>>::Future: Send + 'a,
86 L::Error: Into<Error> + fmt::Debug,
87 L::MakeError: Into<Error> + fmt::Debug,
88 L::Future: Send,
89
90 T: MakeService<(), Vec<LambdaTelemetry<TL>>, Response = ()> + Send + Sync + 'static,
92 T::Service: Service<Vec<LambdaTelemetry<TL>>, Response = ()> + Send + Sync,
93 <T::Service as Service<Vec<LambdaTelemetry<TL>>>>::Future: Send + 'a,
94 T::Error: Into<Error> + fmt::Debug,
95 T::MakeError: Into<Error> + fmt::Debug,
96 T::Future: Send,
97 TL: DeserializeOwned + Send + 'static,
98{
99 pub fn with_extension_name(self, extension_name: &'a str) -> Self {
101 Extension {
102 extension_name: Some(extension_name),
103 ..self
104 }
105 }
106
107 pub fn with_events(self, events: &'a [&'a str]) -> Self {
110 Extension {
111 events: Some(events),
112 ..self
113 }
114 }
115
116 pub fn with_events_processor<N>(self, ep: N) -> Extension<'a, N, L, T, TL>
118 where
119 N: Service<LambdaEvent>,
120 N::Future: Future<Output = Result<(), N::Error>>,
121 N::Error: Into<Error> + fmt::Display,
122 {
123 Extension {
124 events_processor: ep,
125 extension_name: self.extension_name,
126 events: self.events,
127 log_types: self.log_types,
128 log_buffering: self.log_buffering,
129 logs_processor: self.logs_processor,
130 log_port_number: self.log_port_number,
131 telemetry_types: self.telemetry_types,
132 telemetry_buffering: self.telemetry_buffering,
133 telemetry_processor: self.telemetry_processor,
134 telemetry_port_number: self.telemetry_port_number,
135 _telemetry_record_type: self._telemetry_record_type,
136 }
137 }
138
139 pub fn with_logs_processor<N, NS>(self, lp: N) -> Extension<'a, E, N, T, TL>
141 where
142 N: Service<()>,
143 N::Future: Future<Output = Result<NS, N::Error>>,
144 N::Error: Into<Error> + fmt::Display,
145 {
146 Extension {
147 logs_processor: Some(lp),
148 events_processor: self.events_processor,
149 extension_name: self.extension_name,
150 events: self.events,
151 log_types: self.log_types,
152 log_buffering: self.log_buffering,
153 log_port_number: self.log_port_number,
154 telemetry_types: self.telemetry_types,
155 telemetry_buffering: self.telemetry_buffering,
156 telemetry_processor: self.telemetry_processor,
157 telemetry_port_number: self.telemetry_port_number,
158 _telemetry_record_type: self._telemetry_record_type,
159 }
160 }
161
162 pub fn with_log_types(self, log_types: &'a [&'a str]) -> Self {
165 Extension {
166 log_types: Some(log_types),
167 ..self
168 }
169 }
170
171 pub fn with_log_buffering(self, lb: LogBuffering) -> Self {
173 Extension {
174 log_buffering: Some(lb),
175 ..self
176 }
177 }
178
179 pub fn with_log_port_number(self, port_number: u16) -> Self {
181 Extension {
182 log_port_number: port_number,
183 ..self
184 }
185 }
186
187 pub fn with_telemetry_processor<N, NS>(self, lp: N) -> Extension<'a, E, L, N, TL>
193 where
194 N: Service<()>,
195 N::Future: Future<Output = Result<NS, N::Error>>,
196 N::Error: Into<Error> + fmt::Display,
197 {
198 Extension {
199 telemetry_processor: Some(lp),
200 events_processor: self.events_processor,
201 extension_name: self.extension_name,
202 events: self.events,
203 log_types: self.log_types,
204 log_buffering: self.log_buffering,
205 logs_processor: self.logs_processor,
206 log_port_number: self.log_port_number,
207 telemetry_types: self.telemetry_types,
208 telemetry_buffering: self.telemetry_buffering,
209 telemetry_port_number: self.telemetry_port_number,
210 _telemetry_record_type: self._telemetry_record_type,
211 }
212 }
213
214 pub fn with_telemetry_types(self, telemetry_types: &'a [&'a str]) -> Self {
217 Extension {
218 telemetry_types: Some(telemetry_types),
219 ..self
220 }
221 }
222
223 pub fn with_telemetry_buffering(self, lb: LogBuffering) -> Self {
225 Extension {
226 telemetry_buffering: Some(lb),
227 ..self
228 }
229 }
230
231 pub fn with_telemetry_port_number(self, port_number: u16) -> Self {
233 Extension {
234 telemetry_port_number: port_number,
235 ..self
236 }
237 }
238
239 pub async fn register(self) -> Result<RegisteredExtension<E>, Error> {
247 let client = &Client::builder().build()?;
248
249 let register_res = register(client, self.extension_name, self.events).await?;
250
251 if let Some(mut log_processor) = self.logs_processor {
254 trace!("Log processor found");
255
256 validate_buffering_configuration(self.log_buffering)?;
257
258 let addr = SocketAddr::from(([0, 0, 0, 0], self.log_port_number));
259 let service = log_processor.make_service(());
260 let service = Arc::new(Mutex::new(service.await.unwrap()));
261 tokio::task::spawn(async move {
262 trace!("Creating new logs processor Service");
263
264 loop {
265 let service: Arc<Mutex<_>> = service.clone();
266 let make_service = service_fn(move |req: Request<Incoming>| log_wrapper(service.clone(), req));
267
268 let listener = TcpListener::bind(addr).await.unwrap();
269 let (tcp, _) = listener.accept().await.unwrap();
270 let io = TokioIo::new(tcp);
271 tokio::task::spawn(async move {
272 if let Err(err) = http1::Builder::new().serve_connection(io, make_service).await {
273 println!("Error serving connection: {err:?}");
274 }
275 });
276 }
277 });
278
279 trace!("Log processor started");
280
281 let req = requests::subscribe_request(
283 Api::LogsApi,
284 ®ister_res.extension_id,
285 self.log_types,
286 self.log_buffering,
287 self.log_port_number,
288 )?;
289 let res = client.call(req).await?;
290 if !res.status().is_success() {
291 let err = format!("unable to initialize the logs api: {}", res.status());
292 return Err(ExtensionError::boxed(err));
293 }
294 trace!("Registered extension with Logs API");
295 }
296
297 if let Some(mut telemetry_processor) = self.telemetry_processor {
300 trace!("Telemetry processor found");
301
302 validate_buffering_configuration(self.telemetry_buffering)?;
303
304 let addr = SocketAddr::from(([0, 0, 0, 0], self.telemetry_port_number));
305 let service = telemetry_processor.make_service(());
306 let service = Arc::new(Mutex::new(service.await.unwrap()));
307 tokio::task::spawn(async move {
308 trace!("Creating new telemetry processor Service");
309
310 loop {
311 let service = service.clone();
312 let make_service = service_fn(move |req| telemetry_wrapper(service.clone(), req));
313
314 let listener = TcpListener::bind(addr).await.unwrap();
315 let (tcp, _) = listener.accept().await.unwrap();
316 let io = TokioIo::new(tcp);
317 tokio::task::spawn(async move {
318 if let Err(err) = http1::Builder::new().serve_connection(io, make_service).await {
319 println!("Error serving connection: {err:?}");
320 }
321 });
322 }
323 });
324
325 trace!("Telemetry processor started");
326
327 let req = requests::subscribe_request(
329 Api::TelemetryApi,
330 ®ister_res.extension_id,
331 self.telemetry_types,
332 self.telemetry_buffering,
333 self.telemetry_port_number,
334 )?;
335 let res = client.call(req).await?;
336 if !res.status().is_success() {
337 let err = format!("unable to initialize the telemetry api: {}", res.status());
338 return Err(ExtensionError::boxed(err));
339 }
340 trace!("Registered extension with Telemetry API");
341 }
342
343 Ok(RegisteredExtension {
344 extension_id: register_res.extension_id,
345 function_name: register_res.function_name,
346 function_version: register_res.function_version,
347 handler: register_res.handler,
348 account_id: register_res.account_id,
349 events_processor: self.events_processor,
350 })
351 }
352
353 pub async fn run(self) -> Result<(), Error> {
355 self.register().await?.run().await
356 }
357}
358
359impl<'a, E, L> Extension<'a, E, L, MakeIdentity<Vec<LambdaTelemetry>>> {
360 pub fn with_telemetry_record_type<N>(self) -> Extension<'a, E, L, MakeIdentity<Vec<LambdaTelemetry<N>>>, N> {
384 Extension {
385 _telemetry_record_type: PhantomData,
386 telemetry_processor: None,
387 events_processor: self.events_processor,
388 extension_name: self.extension_name,
389 events: self.events,
390 log_types: self.log_types,
391 log_buffering: self.log_buffering,
392 logs_processor: self.logs_processor,
393 log_port_number: self.log_port_number,
394 telemetry_types: self.telemetry_types,
395 telemetry_buffering: self.telemetry_buffering,
396 telemetry_port_number: self.telemetry_port_number,
397 }
398 }
399}
400
/// An extension that has been registered with the Lambda Extensions API.
///
/// Produced by [`Extension::register`]; call [`RegisteredExtension::run`] to start
/// processing events.
pub struct RegisteredExtension<E> {
    /// Extension identifier taken from the registration response header.
    pub extension_id: String,
    /// Account id from the registration response body, if it was present.
    pub account_id: Option<String>,
    /// Function name from the registration response body.
    pub function_name: String,
    /// Function version from the registration response body.
    pub function_version: String,
    /// Handler from the registration response body.
    pub handler: String,
    /// Service that handles each event in `run`.
    events_processor: E,
}
416
417impl<E> RegisteredExtension<E>
418where
419 E: Service<LambdaEvent>,
420 E::Future: Future<Output = Result<(), E::Error>>,
421 E::Error: Into<Box<dyn std::error::Error + Send + Sync>> + fmt::Display + fmt::Debug,
422{
423 pub async fn run(self) -> Result<(), Error> {
431 let client = &Client::builder().build()?;
432 let mut ep = self.events_processor;
433 let extension_id = &self.extension_id;
434
435 let incoming = async_stream::stream! {
436 loop {
437 trace!("Waiting for next event (incoming loop)");
438 let req = requests::next_event_request(extension_id)?;
439 let res = client.call(req).await;
440 yield res;
441 }
442 };
443
444 tokio::pin!(incoming);
445 while let Some(event) = incoming.next().await {
446 trace!("New event arrived (run loop)");
447 let event = event?;
448 let (_parts, body) = event.into_parts();
449
450 let body = body.collect().await?.to_bytes();
451 trace!("{}", std::str::from_utf8(&body)?); let event: NextEvent = serde_json::from_slice(&body)?;
453 let is_invoke = event.is_invoke();
454
455 let event = LambdaEvent::new(event);
456
457 let ep = match ep.ready().await {
458 Ok(ep) => ep,
459 Err(err) => {
460 println!("Inner service is not ready: {err:?}");
461 let req = if is_invoke {
462 requests::init_error(extension_id, &err.to_string(), None)?
463 } else {
464 requests::exit_error(extension_id, &err.to_string(), None)?
465 };
466
467 client.call(req).await?;
468 return Err(err.into());
469 }
470 };
471
472 let res = ep.call(event).await;
473 if let Err(err) = res {
474 println!("{err:?}");
475 let req = if is_invoke {
476 requests::init_error(extension_id, &err.to_string(), None)?
477 } else {
478 requests::exit_error(extension_id, &err.to_string(), None)?
479 };
480
481 client.call(req).await?;
482 return Err(err.into());
483 }
484 }
485
486 Ok(())
488 }
489}
490
491#[derive(Clone)]
493pub struct Identity<T> {
494 _phantom: std::marker::PhantomData<T>,
495}
496
497impl<T> Identity<T> {
498 fn new() -> Self {
499 Self {
500 _phantom: std::marker::PhantomData,
501 }
502 }
503}
504
505impl<T> Service<T> for Identity<T> {
506 type Error = Infallible;
507 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
508 type Response = ();
509
510 fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
511 core::task::Poll::Ready(Ok(()))
512 }
513
514 fn call(&mut self, _event: T) -> Self::Future {
515 Box::pin(ready(Ok(())))
516 }
517}
518
519#[derive(Clone)]
521pub struct MakeIdentity<T> {
522 _phantom: std::marker::PhantomData<T>,
523}
524
525impl<T> Service<()> for MakeIdentity<T>
526where
527 T: Send + Sync + 'static,
528{
529 type Error = Infallible;
530 type Response = Identity<T>;
531 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
532
533 fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
534 core::task::Poll::Ready(Ok(()))
535 }
536
537 fn call(&mut self, _: ()) -> Self::Future {
538 Box::pin(ready(Ok(Identity::new())))
539 }
540}
541
/// JSON body of a successful registration response; keys are camelCase on the wire.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct RegisterResponseBody {
    function_name: String,
    function_version: String,
    handler: String,
    /// Optional: a missing key deserializes to `None`.
    account_id: Option<String>,
}
550
/// Everything `register` extracts from a registration response: the extension id
/// read from the response header, plus the fields decoded from the JSON body.
#[derive(Debug)]
struct RegisterResponse {
    extension_id: String,
    function_name: String,
    function_version: String,
    handler: String,
    account_id: Option<String>,
}
559
560async fn register<'a>(
562 client: &'a Client,
563 extension_name: Option<&'a str>,
564 events: Option<&'a [&'a str]>,
565) -> Result<RegisterResponse, Error> {
566 let name = match extension_name {
567 Some(name) => name.into(),
568 None => {
569 let args: Vec<String> = std::env::args().collect();
570 PathBuf::from(args[0].clone())
571 .file_name()
572 .expect("unexpected executable name")
573 .to_str()
574 .expect("unexpect executable name")
575 .to_string()
576 }
577 };
578
579 let events = events.unwrap_or(&["INVOKE", "SHUTDOWN"]);
580
581 let req = requests::register_request(&name, events)?;
582 let res = client.call(req).await?;
583 if !res.status().is_success() {
584 let err = format!("unable to register the extension: {}", res.status());
585 return Err(ExtensionError::boxed(err));
586 }
587
588 let header = res
589 .headers()
590 .get(requests::EXTENSION_ID_HEADER)
591 .ok_or_else(|| ExtensionError::boxed("missing extension id header"))
592 .map_err(|e| ExtensionError::boxed(e.to_string()))?;
593 let extension_id = header.to_str()?.to_string();
594
595 let (_, body) = res.into_parts();
596 let body = body.collect().await?.to_bytes();
597 let response: RegisterResponseBody = serde_json::from_slice(&body)?;
598
599 Ok(RegisterResponse {
600 extension_id,
601 function_name: response.function_name,
602 function_version: response.function_version,
603 handler: response.handler,
604 account_id: response.account_id,
605 })
606}