1use http::Request;
2use http_body_util::BodyExt;
3use hyper::{body::Incoming, server::conn::http1, service::service_fn};
4
5use hyper_util::rt::tokio::TokioIo;
6use lambda_runtime_api_client::Client;
7use serde::Deserialize;
8use std::{
9 convert::Infallible,
10 fmt,
11 future::{ready, Future},
12 net::SocketAddr,
13 path::PathBuf,
14 pin::Pin,
15 sync::Arc,
16};
17use tokio::{net::TcpListener, sync::Mutex};
18use tokio_stream::StreamExt;
19use tower::{MakeService, Service, ServiceExt};
20use tracing::trace;
21
22use crate::{
23 logs::*,
24 requests::{self, Api},
25 telemetry_wrapper, Error, ExtensionError, LambdaEvent, LambdaTelemetry, NextEvent,
26};
27
28const DEFAULT_LOG_PORT_NUMBER: u16 = 9002;
29const DEFAULT_TELEMETRY_PORT_NUMBER: u16 = 9003;
30
/// Builder for a Lambda extension.
///
/// Holds the registration settings and the events, logs, and telemetry processors until
/// [`Extension::register`] or [`Extension::run`] is called.
pub struct Extension<'a, E, L, T> {
    /// Name sent when registering; when `None`, the file name of the running executable is used.
    extension_name: Option<&'a str>,
    /// Event types to register for; when `None`, `INVOKE` and `SHUTDOWN` are requested.
    events: Option<&'a [&'a str]>,
    /// Service called with every [`LambdaEvent`] received while running.
    events_processor: E,
    /// Log types passed in the Logs API subscription request.
    log_types: Option<&'a [&'a str]>,
    /// Factory for the service that receives `Vec<LambdaLog>` batches; when `None`, no log
    /// listener is started and no Logs API subscription is made.
    logs_processor: Option<L>,
    /// Buffering settings for the Logs API subscription (validated before subscribing).
    log_buffering: Option<LogBuffering>,
    /// Port on `0.0.0.0` where the log listener is bound.
    log_port_number: u16,
    /// Telemetry types passed in the Telemetry API subscription request.
    telemetry_types: Option<&'a [&'a str]>,
    /// Factory for the service that receives `Vec<LambdaTelemetry>` batches; when `None`, no
    /// telemetry listener is started and no Telemetry API subscription is made.
    telemetry_processor: Option<T>,
    /// Buffering settings for the Telemetry API subscription (validated before subscribing).
    telemetry_buffering: Option<LogBuffering>,
    /// Port on `0.0.0.0` where the telemetry listener is bound.
    telemetry_port_number: u16,
}
45
46impl Extension<'_, Identity<LambdaEvent>, MakeIdentity<Vec<LambdaLog>>, MakeIdentity<Vec<LambdaTelemetry>>> {
47 pub fn new() -> Self {
49 Extension {
50 extension_name: None,
51 events: None,
52 events_processor: Identity::new(),
53 log_types: None,
54 log_buffering: None,
55 logs_processor: None,
56 log_port_number: DEFAULT_LOG_PORT_NUMBER,
57 telemetry_types: None,
58 telemetry_buffering: None,
59 telemetry_processor: None,
60 telemetry_port_number: DEFAULT_TELEMETRY_PORT_NUMBER,
61 }
62 }
63}
64
65impl Default
66 for Extension<'_, Identity<LambdaEvent>, MakeIdentity<Vec<LambdaLog>>, MakeIdentity<Vec<LambdaTelemetry>>>
67{
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl<'a, E, L, T> Extension<'a, E, L, T>
74where
75 E: Service<LambdaEvent>,
76 E::Future: Future<Output = Result<(), E::Error>>,
77 E::Error: Into<Error> + fmt::Display + fmt::Debug,
78
79 L: MakeService<(), Vec<LambdaLog>, Response = ()> + Send + Sync + 'static,
81 L::Service: Service<Vec<LambdaLog>, Response = ()> + Send + Sync,
82 <L::Service as Service<Vec<LambdaLog>>>::Future: Send + 'a,
83 L::Error: Into<Error> + fmt::Debug,
84 L::MakeError: Into<Error> + fmt::Debug,
85 L::Future: Send,
86
87 T: MakeService<(), Vec<LambdaTelemetry>, Response = ()> + Send + Sync + 'static,
89 T::Service: Service<Vec<LambdaTelemetry>, Response = ()> + Send + Sync,
90 <T::Service as Service<Vec<LambdaTelemetry>>>::Future: Send + 'a,
91 T::Error: Into<Error> + fmt::Debug,
92 T::MakeError: Into<Error> + fmt::Debug,
93 T::Future: Send,
94{
95 pub fn with_extension_name(self, extension_name: &'a str) -> Self {
97 Extension {
98 extension_name: Some(extension_name),
99 ..self
100 }
101 }
102
103 pub fn with_events(self, events: &'a [&'a str]) -> Self {
106 Extension {
107 events: Some(events),
108 ..self
109 }
110 }
111
112 pub fn with_events_processor<N>(self, ep: N) -> Extension<'a, N, L, T>
114 where
115 N: Service<LambdaEvent>,
116 N::Future: Future<Output = Result<(), N::Error>>,
117 N::Error: Into<Error> + fmt::Display,
118 {
119 Extension {
120 events_processor: ep,
121 extension_name: self.extension_name,
122 events: self.events,
123 log_types: self.log_types,
124 log_buffering: self.log_buffering,
125 logs_processor: self.logs_processor,
126 log_port_number: self.log_port_number,
127 telemetry_types: self.telemetry_types,
128 telemetry_buffering: self.telemetry_buffering,
129 telemetry_processor: self.telemetry_processor,
130 telemetry_port_number: self.telemetry_port_number,
131 }
132 }
133
134 pub fn with_logs_processor<N, NS>(self, lp: N) -> Extension<'a, E, N, T>
136 where
137 N: Service<()>,
138 N::Future: Future<Output = Result<NS, N::Error>>,
139 N::Error: Into<Error> + fmt::Display,
140 {
141 Extension {
142 logs_processor: Some(lp),
143 events_processor: self.events_processor,
144 extension_name: self.extension_name,
145 events: self.events,
146 log_types: self.log_types,
147 log_buffering: self.log_buffering,
148 log_port_number: self.log_port_number,
149 telemetry_types: self.telemetry_types,
150 telemetry_buffering: self.telemetry_buffering,
151 telemetry_processor: self.telemetry_processor,
152 telemetry_port_number: self.telemetry_port_number,
153 }
154 }
155
156 pub fn with_log_types(self, log_types: &'a [&'a str]) -> Self {
159 Extension {
160 log_types: Some(log_types),
161 ..self
162 }
163 }
164
165 pub fn with_log_buffering(self, lb: LogBuffering) -> Self {
167 Extension {
168 log_buffering: Some(lb),
169 ..self
170 }
171 }
172
173 pub fn with_log_port_number(self, port_number: u16) -> Self {
175 Extension {
176 log_port_number: port_number,
177 ..self
178 }
179 }
180
181 pub fn with_telemetry_processor<N, NS>(self, lp: N) -> Extension<'a, E, L, N>
183 where
184 N: Service<()>,
185 N::Future: Future<Output = Result<NS, N::Error>>,
186 N::Error: Into<Error> + fmt::Display,
187 {
188 Extension {
189 telemetry_processor: Some(lp),
190 events_processor: self.events_processor,
191 extension_name: self.extension_name,
192 events: self.events,
193 log_types: self.log_types,
194 log_buffering: self.log_buffering,
195 logs_processor: self.logs_processor,
196 log_port_number: self.log_port_number,
197 telemetry_types: self.telemetry_types,
198 telemetry_buffering: self.telemetry_buffering,
199 telemetry_port_number: self.telemetry_port_number,
200 }
201 }
202
203 pub fn with_telemetry_types(self, telemetry_types: &'a [&'a str]) -> Self {
206 Extension {
207 telemetry_types: Some(telemetry_types),
208 ..self
209 }
210 }
211
212 pub fn with_telemetry_buffering(self, lb: LogBuffering) -> Self {
214 Extension {
215 telemetry_buffering: Some(lb),
216 ..self
217 }
218 }
219
220 pub fn with_telemetry_port_number(self, port_number: u16) -> Self {
222 Extension {
223 telemetry_port_number: port_number,
224 ..self
225 }
226 }
227
228 pub async fn register(self) -> Result<RegisteredExtension<E>, Error> {
236 let client = &Client::builder().build()?;
237
238 let register_res = register(client, self.extension_name, self.events).await?;
239
240 if let Some(mut log_processor) = self.logs_processor {
243 trace!("Log processor found");
244
245 validate_buffering_configuration(self.log_buffering)?;
246
247 let addr = SocketAddr::from(([0, 0, 0, 0], self.log_port_number));
248 let service = log_processor.make_service(());
249 let service = Arc::new(Mutex::new(service.await.unwrap()));
250 tokio::task::spawn(async move {
251 trace!("Creating new logs processor Service");
252
253 loop {
254 let service: Arc<Mutex<_>> = service.clone();
255 let make_service = service_fn(move |req: Request<Incoming>| log_wrapper(service.clone(), req));
256
257 let listener = TcpListener::bind(addr).await.unwrap();
258 let (tcp, _) = listener.accept().await.unwrap();
259 let io = TokioIo::new(tcp);
260 tokio::task::spawn(async move {
261 if let Err(err) = http1::Builder::new().serve_connection(io, make_service).await {
262 println!("Error serving connection: {:?}", err);
263 }
264 });
265 }
266 });
267
268 trace!("Log processor started");
269
270 let req = requests::subscribe_request(
272 Api::LogsApi,
273 ®ister_res.extension_id,
274 self.log_types,
275 self.log_buffering,
276 self.log_port_number,
277 )?;
278 let res = client.call(req).await?;
279 if !res.status().is_success() {
280 let err = format!("unable to initialize the logs api: {}", res.status());
281 return Err(ExtensionError::boxed(err));
282 }
283 trace!("Registered extension with Logs API");
284 }
285
286 if let Some(mut telemetry_processor) = self.telemetry_processor {
289 trace!("Telemetry processor found");
290
291 validate_buffering_configuration(self.telemetry_buffering)?;
292
293 let addr = SocketAddr::from(([0, 0, 0, 0], self.telemetry_port_number));
294 let service = telemetry_processor.make_service(());
295 let service = Arc::new(Mutex::new(service.await.unwrap()));
296 tokio::task::spawn(async move {
297 trace!("Creating new telemetry processor Service");
298
299 loop {
300 let service = service.clone();
301 let make_service = service_fn(move |req| telemetry_wrapper(service.clone(), req));
302
303 let listener = TcpListener::bind(addr).await.unwrap();
304 let (tcp, _) = listener.accept().await.unwrap();
305 let io = TokioIo::new(tcp);
306 tokio::task::spawn(async move {
307 if let Err(err) = http1::Builder::new().serve_connection(io, make_service).await {
308 println!("Error serving connection: {:?}", err);
309 }
310 });
311 }
312 });
313
314 trace!("Telemetry processor started");
315
316 let req = requests::subscribe_request(
318 Api::TelemetryApi,
319 ®ister_res.extension_id,
320 self.telemetry_types,
321 self.telemetry_buffering,
322 self.telemetry_port_number,
323 )?;
324 let res = client.call(req).await?;
325 if !res.status().is_success() {
326 let err = format!("unable to initialize the telemetry api: {}", res.status());
327 return Err(ExtensionError::boxed(err));
328 }
329 trace!("Registered extension with Telemetry API");
330 }
331
332 Ok(RegisteredExtension {
333 extension_id: register_res.extension_id,
334 function_name: register_res.function_name,
335 function_version: register_res.function_version,
336 handler: register_res.handler,
337 account_id: register_res.account_id,
338 events_processor: self.events_processor,
339 })
340 }
341
342 pub async fn run(self) -> Result<(), Error> {
344 self.register().await?.run().await
345 }
346}
347
/// An extension that has completed registration; produced by `Extension::register`.
pub struct RegisteredExtension<E> {
    /// Extension identifier taken from the registration response header.
    pub extension_id: String,
    /// Account ID from the registration response body, if it was present.
    pub account_id: Option<String>,
    /// Function name from the registration response body.
    pub function_name: String,
    /// Function version from the registration response body.
    pub function_version: String,
    /// Handler from the registration response body.
    pub handler: String,
    /// Service called with every event received by `run`.
    events_processor: E,
}
363
364impl<E> RegisteredExtension<E>
365where
366 E: Service<LambdaEvent>,
367 E::Future: Future<Output = Result<(), E::Error>>,
368 E::Error: Into<Box<dyn std::error::Error + Send + Sync>> + fmt::Display + fmt::Debug,
369{
370 pub async fn run(self) -> Result<(), Error> {
378 let client = &Client::builder().build()?;
379 let mut ep = self.events_processor;
380 let extension_id = &self.extension_id;
381
382 let incoming = async_stream::stream! {
383 loop {
384 trace!("Waiting for next event (incoming loop)");
385 let req = requests::next_event_request(extension_id)?;
386 let res = client.call(req).await;
387 yield res;
388 }
389 };
390
391 tokio::pin!(incoming);
392 while let Some(event) = incoming.next().await {
393 trace!("New event arrived (run loop)");
394 let event = event?;
395 let (_parts, body) = event.into_parts();
396
397 let body = body.collect().await?.to_bytes();
398 trace!("{}", std::str::from_utf8(&body)?); let event: NextEvent = serde_json::from_slice(&body)?;
400 let is_invoke = event.is_invoke();
401
402 let event = LambdaEvent::new(event);
403
404 let ep = match ep.ready().await {
405 Ok(ep) => ep,
406 Err(err) => {
407 println!("Inner service is not ready: {err:?}");
408 let req = if is_invoke {
409 requests::init_error(extension_id, &err.to_string(), None)?
410 } else {
411 requests::exit_error(extension_id, &err.to_string(), None)?
412 };
413
414 client.call(req).await?;
415 return Err(err.into());
416 }
417 };
418
419 let res = ep.call(event).await;
420 if let Err(err) = res {
421 println!("{err:?}");
422 let req = if is_invoke {
423 requests::init_error(extension_id, &err.to_string(), None)?
424 } else {
425 requests::exit_error(extension_id, &err.to_string(), None)?
426 };
427
428 client.call(req).await?;
429 return Err(err.into());
430 }
431 }
432
433 Ok(())
435 }
436}
437
/// A no-op service: accepts any `T` and resolves immediately with `Ok(())`.
///
/// It is the default events processor type of `Extension::new`.
pub struct Identity<T> {
    _phantom: std::marker::PhantomData<T>,
}

// Implemented by hand because `#[derive(Clone)]` would add a `T: Clone` bound, even though
// only a `PhantomData<T>` is stored.
impl<T> Clone for Identity<T> {
    fn clone(&self) -> Self {
        Self::new()
    }
}

impl<T> Identity<T> {
    /// Creates the no-op service.
    fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
451
452impl<T> Service<T> for Identity<T> {
453 type Error = Infallible;
454 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
455 type Response = ();
456
457 fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
458 core::task::Poll::Ready(Ok(()))
459 }
460
461 fn call(&mut self, _event: T) -> Self::Future {
462 Box::pin(ready(Ok(())))
463 }
464}
465
/// Factory service that produces [`Identity`] services.
///
/// It is the default logs and telemetry processor type of `Extension::new`.
pub struct MakeIdentity<T> {
    _phantom: std::marker::PhantomData<T>,
}

// Implemented by hand because `#[derive(Clone)]` would add a `T: Clone` bound, even though
// only a `PhantomData<T>` is stored.
impl<T> Clone for MakeIdentity<T> {
    fn clone(&self) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
471
472impl<T> Service<()> for MakeIdentity<T>
473where
474 T: Send + Sync + 'static,
475{
476 type Error = Infallible;
477 type Response = Identity<T>;
478 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
479
480 fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
481 core::task::Poll::Ready(Ok(()))
482 }
483
484 fn call(&mut self, _: ()) -> Self::Future {
485 Box::pin(ready(Ok(Identity::new())))
486 }
487}
488
/// JSON body of the registration response; field names are camelCase on the wire.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct RegisterResponseBody {
    function_name: String,
    function_version: String,
    handler: String,
    // May be absent from the response body.
    account_id: Option<String>,
}
497
/// Combined registration result.
///
/// Holds the extension ID from the response header plus the fields parsed from the
/// response body.
#[derive(Debug)]
struct RegisterResponse {
    extension_id: String,
    function_name: String,
    function_version: String,
    handler: String,
    account_id: Option<String>,
}
506
507async fn register<'a>(
509 client: &'a Client,
510 extension_name: Option<&'a str>,
511 events: Option<&'a [&'a str]>,
512) -> Result<RegisterResponse, Error> {
513 let name = match extension_name {
514 Some(name) => name.into(),
515 None => {
516 let args: Vec<String> = std::env::args().collect();
517 PathBuf::from(args[0].clone())
518 .file_name()
519 .expect("unexpected executable name")
520 .to_str()
521 .expect("unexpect executable name")
522 .to_string()
523 }
524 };
525
526 let events = events.unwrap_or(&["INVOKE", "SHUTDOWN"]);
527
528 let req = requests::register_request(&name, events)?;
529 let res = client.call(req).await?;
530 if !res.status().is_success() {
531 let err = format!("unable to register the extension: {}", res.status());
532 return Err(ExtensionError::boxed(err));
533 }
534
535 let header = res
536 .headers()
537 .get(requests::EXTENSION_ID_HEADER)
538 .ok_or_else(|| ExtensionError::boxed("missing extension id header"))
539 .map_err(|e| ExtensionError::boxed(e.to_string()))?;
540 let extension_id = header.to_str()?.to_string();
541
542 let (_, body) = res.into_parts();
543 let body = body.collect().await?.to_bytes();
544 let response: RegisterResponseBody = serde_json::from_slice(&body)?;
545
546 Ok(RegisterResponse {
547 extension_id,
548 function_name: response.function_name,
549 function_version: response.function_version,
550 handler: response.handler,
551 account_id: response.account_id,
552 })
553}