use std::{
env, io,
num::ParseIntError,
pin::Pin,
task::{Context, Poll},
};
use super::Timestamp;
use tokio::{
io::{AsyncWrite, AsyncWriteExt, Stderr, Stdout},
sync::{mpsc, oneshot},
};
/// Selects how each line is framed before it is handed to the writer.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum OutputFormat {
    /// The line's bytes (plus a trailing newline) are written unchanged.
    #[default]
    Standard,
    /// Each line is preceded by a 16-byte binary header carrying a magic
    /// number, the payload length and a timestamp.
    TelemetryLogFd,
}
/// Builder that pairs an asynchronous writer with an output format and a
/// channel capacity; `spawn` turns it into a background writer task.
pub struct Sink<T>
where
    T: AsyncWrite + Send + Unpin + 'static,
{
    /// Destination that receives the bytes.
    writer: T,
    /// Framing applied to every line by the resulting handle.
    format: OutputFormat,
    /// Capacity of the channel between handles and the writer task.
    buffer_size: usize,
}
impl<T: AsyncWrite + Send + Unpin + 'static> Sink<T> {
    /// Creates a sink around `writer` with [`OutputFormat::Standard`] and a
    /// channel capacity of 16.
    pub fn new(writer: T) -> Self {
        Sink {
            writer,
            format: OutputFormat::Standard,
            buffer_size: 16,
        }
    }

    /// Sets the output format used by handles created from this sink.
    pub fn format(mut self, kind: OutputFormat) -> Self {
        self.format = kind;
        self
    }

    /// Sets the capacity of the channel between handles and the writer task.
    ///
    /// A value of `0` is treated as `1` when the sink is spawned.
    pub fn buffer_size(mut self, size: usize) -> Self {
        self.buffer_size = size;
        self
    }

    /// Spawns the background writer task and returns a handle to it.
    ///
    /// The task processes actions in the order they were sent and stops once
    /// every [`SinkHandle`] clone has been dropped. Must be called from within
    /// a Tokio runtime because it uses `tokio::spawn`.
    ///
    /// If writing to or flushing the underlying writer fails, the background
    /// task panics.
    pub fn spawn(self) -> SinkHandle {
        let Sink {
            mut writer,
            format,
            buffer_size,
        } = self;

        // `mpsc::channel` panics on a zero capacity, so clamp to the smallest
        // valid size instead of crashing the caller.
        let (action_tx, mut action_rx) = mpsc::channel(buffer_size.max(1));

        tokio::spawn(async move {
            while let Some(action) = action_rx.recv().await {
                match action {
                    Action::Write(bytes) => writer
                        .write_all(&bytes)
                        .await
                        .expect("failed to write to sink"),
                    Action::Flush(ack_tx) => {
                        writer.flush().await.expect("failed to flush sink");
                        // The flush requester may have been cancelled and dropped
                        // its receiver; that must not take down the writer task.
                        let _ = ack_tx.send(());
                    }
                }
            }
        });

        SinkHandle { action_tx, format }
    }
}
impl Sink<Stdout> {
    /// Creates a sink backed by Tokio's standard output handle.
    pub fn stdout() -> Self {
        let out = tokio::io::stdout();
        Self::new(out)
    }
}
impl Sink<Stderr> {
    /// Creates a sink backed by Tokio's standard error handle.
    pub fn stderr() -> Self {
        let err = tokio::io::stderr();
        Self::new(err)
    }
}
/// Unit type that implements `AsyncWrite` so it can stand in for a real
/// writer (see `Sink::mock`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct MockSink;
impl AsyncWrite for MockSink {
    /// Accepts and discards the whole buffer.
    ///
    /// The full length must be reported as written: reporting `0` for a
    /// non-empty buffer makes `write_all` fail with `WriteZero`.
    fn poll_write(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Poll::Ready(Ok(buf.len()))
    }

    /// Nothing is buffered, so flushing completes immediately.
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// There is no underlying resource to close; completes immediately.
    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
impl Sink<MockSink> {
    /// Creates a sink whose writer is a [`MockSink`].
    pub fn mock() -> Self {
        let writer = MockSink;
        Self::new(writer)
    }
}
#[cfg(target_os = "linux")]
impl Sink<tokio::fs::File> {
pub fn lambda_telemetry_log_fd() -> Result<Self, Error> {
env::var("_LAMBDA_TELEMETRY_LOG_FD")
.map_err(Error::VarError)
.and_then(|fd| {
let fd = fd.parse().map_err(Error::ParseIntError)?;
Ok(
Sink::new(unsafe { <tokio::fs::File as std::os::fd::FromRawFd>::from_raw_fd(fd) })
.format(OutputFormat::TelemetryLogFd),
)
})
}
}
/// Requests sent from a `SinkHandle` to the background writer task.
enum Action {
/// Bytes to hand to the writer with `write_all`.
Write(Vec<u8>),
/// Flush the writer, then signal completion through the enclosed sender.
Flush(oneshot::Sender<()>),
}
/// Cloneable handle used to submit lines and flush requests to the writer
/// task started by `Sink::spawn`.
#[derive(Clone)]
pub struct SinkHandle {
// Sending half of the channel drained by the background writer task.
action_tx: mpsc::Sender<Action>,
// Framing applied to each line in `write_line`.
format: OutputFormat,
}
impl SinkHandle {
    /// Queues `s`, terminated by a newline, for writing.
    ///
    /// With [`OutputFormat::TelemetryLogFd`] the payload is preceded by the
    /// binary header built from the payload and `timestamp` (in microseconds).
    /// The call returns once the bytes are enqueued, not once they are written.
    ///
    /// # Panics
    ///
    /// Panics if the action cannot be sent to the writer task.
    pub async fn write_line(&self, s: String, timestamp: Timestamp) {
        let mut payload = s.into_bytes();
        payload.push(b'\n');

        let framed = match self.format {
            OutputFormat::TelemetryLogFd => {
                let micros = timestamp.timestamp_micros();
                let mut frame = build_telemetry_log_fd_format_header(&payload, micros);
                frame.extend_from_slice(&payload);
                frame
            }
            OutputFormat::Standard => payload,
        };

        let request = Action::Write(framed);
        self.action_tx.send(request).await.unwrap();
    }

    /// Asks the writer task to flush and waits for its acknowledgement.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be sent or the acknowledgement never
    /// arrives because the sender was dropped.
    pub async fn flush(&self) {
        let (done_tx, done_rx) = oneshot::channel();
        let request = Action::Flush(done_tx);
        self.action_tx.send(request).await.unwrap();
        done_rx.await.unwrap();
    }
}
/// Builds the 16-byte frame header that precedes a telemetry log line.
///
/// Layout (all fields big-endian):
/// - bytes 0..4: the constant `0xa55a0003`
/// - bytes 4..8: `line.len()` as a `u32`
/// - bytes 8..16: `timestamp` as an `i64`
///
/// The returned vector has length 16 but reserves extra capacity for `line`,
/// so the caller can append the payload without reallocating.
fn build_telemetry_log_fd_format_header(line: &[u8], timestamp: i64) -> Vec<u8> {
    let mut buf = Vec::with_capacity(16 + line.len());
    buf.extend_from_slice(&0xa55a0003u32.to_be_bytes());
    // NOTE(review): the length field is 32 bits wide, so a payload longer than
    // `u32::MAX` bytes would be truncated by this cast.
    buf.extend_from_slice(&(line.len() as u32).to_be_bytes());
    buf.extend_from_slice(&timestamp.to_be_bytes());
    buf
}
/// Errors produced while building a sink from the
/// `_LAMBDA_TELEMETRY_LOG_FD` environment variable.
#[derive(Debug)]
pub enum Error {
    /// The environment variable could not be read.
    VarError(env::VarError),
    /// The environment variable's value is not a valid integer.
    ParseIntError(ParseIntError),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::VarError(err) => {
                write!(f, "failed to read _LAMBDA_TELEMETRY_LOG_FD: {}", err)
            }
            Error::ParseIntError(err) => {
                write!(f, "invalid value in _LAMBDA_TELEMETRY_LOG_FD: {}", err)
            }
        }
    }
}

impl std::error::Error for Error {
    /// Exposes the wrapped error so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::VarError(err) => Some(err),
            Error::ParseIntError(err) => Some(err),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::DateTime;
    use serial_test::serial;

    #[test]
    fn default_sink_format() {
        let builder = Sink::new(Vec::new());
        assert_eq!(builder.format, OutputFormat::Standard);
    }

    #[test]
    #[serial]
    fn sink_stdout_format() {
        let sb = Sink::stdout();
        assert_eq!(sb.format, OutputFormat::Standard);
    }

    #[test]
    #[serial]
    fn sink_stderr_format() {
        let sb = Sink::stderr();
        assert_eq!(sb.format, OutputFormat::Standard);
    }

    #[test]
    #[serial]
    fn sink_format() {
        let sink = Sink::stdout().format(OutputFormat::TelemetryLogFd);
        assert_eq!(sink.format, OutputFormat::TelemetryLogFd);
    }

    #[test]
    #[serial]
    fn sink_buffer_size() {
        let sink = Sink::stdout().buffer_size(32);
        assert_eq!(sink.buffer_size, 32);
    }

    #[cfg(target_os = "linux")]
    #[test]
    #[serial]
    fn sink_lambda_telemetry_log_fd() {
        use std::{fs::File, os::fd::IntoRawFd};

        // Valid descriptor: the sink is built with the telemetry format.
        let file = File::create("/dev/null").unwrap();
        let fd = file.into_raw_fd();
        env::set_var("_LAMBDA_TELEMETRY_LOG_FD", fd.to_string());
        let sink = Sink::lambda_telemetry_log_fd().unwrap();
        assert_eq!(sink.format, OutputFormat::TelemetryLogFd);

        // Variable missing.
        env::remove_var("_LAMBDA_TELEMETRY_LOG_FD");
        let result = Sink::lambda_telemetry_log_fd();
        assert!(matches!(result, Err(Error::VarError(_))));

        // Variable present but not an integer.
        env::set_var("_LAMBDA_TELEMETRY_LOG_FD", "invalid");
        let result = Sink::lambda_telemetry_log_fd();
        assert!(matches!(result, Err(Error::ParseIntError(_))));

        env::remove_var("_LAMBDA_TELEMETRY_LOG_FD");
    }

    #[tokio::test]
    async fn sink_write_line() {
        let sink = Sink::new(tokio_test::io::Builder::new().write(b"hello\n").build()).spawn();
        sink.write_line("hello".to_string(), mock_timestamp()).await;
        // `write_line` only enqueues the bytes; flushing waits until the
        // background task has actually performed the write on the mock.
        sink.flush().await;

        assert_eq!(
            build_telemetry_log_fd_format_header("hello\n".as_bytes(), 0),
            &[0xa5, 0x5a, 0x00, 0x03, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0]
        );

        let sink = Sink::new(
            tokio_test::io::Builder::new()
                .write(b"\xa5\x5a\x00\x03\x00\x00\x00\x06\x00\x00\x00\x00\x00\x00\x00\x00hello\n")
                .build(),
        )
        .format(OutputFormat::TelemetryLogFd)
        .spawn();
        sink.write_line("hello".to_string(), mock_timestamp()).await;
        sink.flush().await;
    }

    #[tokio::test]
    async fn sink_flush() {
        let sink = Sink::new(tokio_test::io::Builder::new().write(b"hello\n").build()).spawn();
        sink.write_line("hello".to_string(), mock_timestamp()).await;
        sink.flush().await;
    }

    #[tokio::test]
    async fn sink_handle_clone_able() {
        let sink = Sink::new(
            tokio_test::io::Builder::new()
                .write(b"hello\nworld\n")
                .build(),
        )
        .spawn();
        let sink2 = sink.clone();
        sink.write_line("hello".to_string(), mock_timestamp()).await;
        sink2
            .write_line("world".to_string(), mock_timestamp())
            .await;
        // Both handles feed the same task; wait for it to process the queued
        // writes so the mock's expectations are exercised.
        sink.flush().await;
    }

    /// Timestamp at the Unix epoch, so the header's timestamp field is zero.
    fn mock_timestamp() -> Timestamp {
        DateTime::from_timestamp(0, 0).unwrap()
    }
}