use bytes::Bytes;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant, SystemTime};
use tonic::body::Body;
use tonic::codegen::http::{Request, Response};
use tonic::metadata::MetadataMap;
use tonic::transport::{Channel, Endpoint};
use tower::{Layer, Service, ServiceExt};
/// Request half of a logged gRPC call.
#[derive(Debug, Clone)]
pub struct LogRequest {
/// Path component of the request URI (taken from `req.uri().path()` by `LoggingService`).
pub method: String,
/// Metadata recorded for the request; `LoggingService` derives it from the HTTP request headers.
pub metadata: MetadataMap,
/// Request payload bytes. `LoggingService` in this file always stores an empty buffer here.
pub message: Bytes,
}
/// Response half of a logged gRPC call.
#[derive(Debug, Clone)]
pub struct LogResponse {
/// Metadata recorded for the response; `LoggingService` derives it from the HTTP response headers.
pub metadata: MetadataMap,
/// Response payload bytes. `LoggingService` in this file always stores an empty buffer here.
pub message: Bytes,
/// gRPC status code; `LoggingService` reads it from the `grpc-status` response header.
pub status_code: tonic::Code,
/// gRPC status message; `LoggingService` reads it from the `grpc-message` response header.
pub status_message: String,
/// Time elapsed between the start of the call future and the arrival of the response.
pub duration: Duration,
}
/// Complete record of one gRPC call: the request, the response and wall-clock timestamps.
#[derive(Debug, Clone)]
pub struct Log {
/// What was sent.
pub request: LogRequest,
/// What came back.
pub response: LogResponse,
/// Wall-clock time captured when the call future started executing.
pub started_at: SystemTime,
/// Wall-clock time captured right after the inner service returned a response.
pub ended_at: SystemTime,
}
/// Errors surfaced by this module's gRPC helpers.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// A transport-level failure; converts automatically from `tonic::transport::Error`.
#[error("gRPC transport error: {0}")]
Transport(#[from] tonic::transport::Error),
/// A gRPC status returned as an error; converts automatically from `tonic::Status`.
#[error("gRPC status error: {0}")]
Status(#[from] tonic::Status),
/// The endpoint string could not be turned into an `Endpoint`; carries the
/// underlying error rendered as text (see `connect`).
#[error("invalid URI: {0}")]
InvalidUri(String),
}
/// Stateless marker type implementing `tower::Layer`; applying it to a service
/// yields a `LoggingService` around that service.
#[derive(Clone, Default)]
pub struct LoggingLayer;

impl LoggingLayer {
    /// Creates the layer. It carries no configuration, so this is equivalent
    /// to `LoggingLayer::default()`.
    pub fn new() -> Self {
        LoggingLayer
    }
}
/// Allows `LoggingLayer` to be applied to any service `S`, producing a `LoggingService<S>`.
impl<S> Layer<S> for LoggingLayer {
type Service = LoggingService<S>;
/// Wraps `inner` in a `LoggingService`; the layer itself contributes no state.
fn layer(&self, inner: S) -> Self::Service {
LoggingService { inner }
}
}
/// Service wrapper that forwards requests to `inner` and records a `Log` for each
/// call that produces a response (see its `Service` implementation).
#[derive(Clone)]
pub struct LoggingService<S> {
// The wrapped service that actually performs the call.
inner: S,
}
/// Forwards every request to the wrapped service and publishes a [`Log`] entry
/// for each call that yields a response.
///
/// Only header-level information is captured: the request and response
/// messages are recorded as empty buffers, and the gRPC status is read from the
/// response *headers*. Calls that fail with an error are propagated to the
/// caller without publishing a log entry.
impl<S> Service<Request<Body>> for LoggingService<S>
where
    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
    S::Error: std::fmt::Debug + Send,
    S::Future: Send,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Sends `req` through the wrapped service, timing the call and publishing
    /// a `Log` once the response arrives.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // Move the instance currently stored in `self` (the one callers drove
        // through `poll_ready`) into the future and leave a fresh clone behind
        // for subsequent calls.
        let replacement = self.inner.clone();
        let mut svc = std::mem::replace(&mut self.inner, replacement);

        Box::pin(async move {
            // Wall-clock time for the log record, monotonic clock for the duration.
            let started_at = SystemTime::now();
            let timer = Instant::now();

            // The request is consumed by the inner call, so capture what we
            // need from it up front.
            let request = LogRequest {
                method: req.uri().path().to_string(),
                metadata: extract_metadata_from_headers(req.headers()),
                message: Bytes::new(),
            };

            let ready_svc = match svc.ready().await {
                Ok(ready) => ready,
                Err(e) => {
                    tracing::error!("gRPC service not ready: {:?}", e);
                    return Err(e);
                }
            };
            let response = ready_svc.call(req).await?;

            let ended_at = SystemTime::now();
            let duration = timer.elapsed();

            let (status_code, status_message) = extract_grpc_status(response.headers());
            let response_log = LogResponse {
                metadata: extract_metadata_from_headers(response.headers()),
                message: Bytes::new(),
                status_code,
                status_message,
                duration,
            };

            let entry = Log {
                request,
                response: response_log,
                started_at,
                ended_at,
            };
            // Publishing is best-effort: its result is deliberately ignored so
            // it can never affect the outcome of the call.
            let _ = crate::runner::publish(crate::runner::EventBody::Call(
                crate::runner::CallLog::Grpc(Box::new(entry)),
            ));

            Ok(response)
        })
    }
}
/// A tonic `Channel` wrapped in `LoggingService`; this is what `connect` and
/// `ChannelExt::with_tanu_logging` return.
pub type LoggingChannel = LoggingService<Channel>;
/// Builds a `MetadataMap` from the ASCII-representable entries of `headers`.
///
/// Headers are skipped when their name starts with `:` or ends with `-bin`,
/// when the value cannot be viewed as a string, or when the name or value does
/// not parse as ASCII metadata. Entries are added with `insert`, so a header
/// that occurs more than once presumably ends up with only its last value.
fn extract_metadata_from_headers(headers: &http::HeaderMap) -> MetadataMap {
    use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue};

    let mut collected = MetadataMap::new();
    for (header_name, header_value) in headers.iter() {
        let raw_name = header_name.as_str();
        if raw_name.starts_with(':') || raw_name.ends_with("-bin") {
            continue;
        }

        // Any failed conversion simply drops this header.
        let entry = header_value.to_str().ok().and_then(|text| {
            let name = raw_name.parse::<AsciiMetadataKey>().ok()?;
            let val = text.parse::<AsciiMetadataValue>().ok()?;
            Some((name, val))
        });
        if let Some((name, val)) = entry {
            collected.insert(name, val);
        }
    }
    collected
}
/// Reads the gRPC status code and status message from `headers`.
///
/// * The code comes from `grpc-status`; if the header is absent, is not a
///   string, or is not an integer, `tonic::Code::Ok` is returned.
/// * The message comes from `grpc-message` and is percent-decoded. If decoding
///   fails (the decoded bytes are not valid UTF-8) the raw header text is
///   returned instead of being discarded. An absent or non-string header
///   yields an empty message.
fn extract_grpc_status(headers: &http::HeaderMap) -> (tonic::Code, String) {
    let code = headers
        .get("grpc-status")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.parse::<i32>().ok())
        .map(tonic::Code::from)
        .unwrap_or(tonic::Code::Ok);

    let message = headers
        .get("grpc-message")
        .and_then(|v| v.to_str().ok())
        .map(|raw| {
            // Never throw the message away: an undecodable value is still more
            // useful to the reader in its percent-encoded form than an empty string.
            urlencoding::decode(raw)
                .map(|decoded| decoded.into_owned())
                .unwrap_or_else(|_| raw.to_owned())
        })
        .unwrap_or_default();

    (code, message)
}
/// Extension trait that adds call logging to a channel.
pub trait ChannelExt: Sized {
/// Consumes `self` and returns it wrapped as a `LoggingChannel`.
fn with_tanu_logging(self) -> LoggingChannel;
}
/// Gives every tonic `Channel` a `with_tanu_logging` method.
impl ChannelExt for Channel {
    /// Applies a `LoggingLayer` to this channel and returns the wrapped service.
    fn with_tanu_logging(self) -> LoggingChannel {
        let logging = LoggingLayer::new();
        logging.layer(self)
    }
}
pub async fn connect(endpoint: impl Into<String>) -> Result<LoggingChannel, Error> {
let endpoint =
Endpoint::from_shared(endpoint.into()).map_err(|e| Error::InvalidUri(e.to_string()))?;
let channel = endpoint.connect().await?;
Ok(channel.with_tanu_logging())
}
/// Renders a message payload for display.
///
/// * An empty payload becomes `"<empty>"`.
/// * A payload that is valid UTF-8 and consists solely of ASCII graphic or
///   ASCII whitespace characters is returned as text, unchanged.
/// * Anything else is rendered as a hex dump with 16 bytes per line: a
///   four-digit hexadecimal offset, the bytes in hex, and an ASCII column in
///   which every non-graphic byte (including space) is shown as `.`.
pub fn format_message(bytes: &Bytes) -> String {
    if bytes.is_empty() {
        return "<empty>".to_string();
    }

    let printable = std::str::from_utf8(bytes).ok().filter(|text| {
        text.chars()
            .all(|c| c.is_ascii_graphic() || c.is_ascii_whitespace())
    });
    if let Some(text) = printable {
        return text.to_string();
    }

    let mut rows = Vec::new();
    for (row, chunk) in bytes.chunks(16).enumerate() {
        let mut hex = String::new();
        let mut ascii = String::new();
        for &byte in chunk {
            hex.push_str(&format!("{:02x} ", byte));
            ascii.push(if byte.is_ascii_graphic() {
                byte as char
            } else {
                '.'
            });
        }
        // The hex column is padded to 48 characters (16 bytes * "xx ") so the
        // ASCII column lines up on a short final row.
        rows.push(format!("{:04x} {:48} {}", row * 16, hex, ascii));
    }
    rows.join("\n")
}