use std::time::Instant;
use tracing::debug;
use crate::{
BoxMiddlewareFuture, ClientCallOutcome, ClientContext, ClientMiddleware, ClientRequest,
MetadataEntry, MetadataFlags, MetadataValue,
};
/// Configuration for the [`ClientLogging`] middleware.
///
/// The derived `Default` leaves every option switched off.
#[derive(Debug, Clone, Default)]
pub struct ClientLoggingOptions {
    /// Attach request metadata to the `rpc request` debug event.
    ///
    /// Entries flagged with [`MetadataFlags::SENSITIVE`] have their value replaced by
    /// `"[REDACTED]"`, and byte values are rendered only as `<N bytes>`. Entry keys
    /// are always shown. Defaults to `false`.
    pub log_metadata: bool,
}
/// Client middleware that emits `tracing` debug events around every RPC call.
///
/// Events use the `vox::client` target:
/// - before the call, an `rpc request` event with `service`, `method` and `method_id`
///   (plus redacted `metadata` when [`ClientLoggingOptions::log_metadata`] is set);
/// - after the call, an `rpc response` event that also carries `duration_ms` and an
///   `outcome` of `"response"` or `"error"` (with the error's `Debug` form).
#[derive(Debug, Clone, Default)]
pub struct ClientLogging {
    options: ClientLoggingOptions,
}
impl ClientLogging {
    /// Creates a logging middleware configured by `options`.
    #[must_use]
    pub fn new(options: ClientLoggingOptions) -> Self {
        Self { options }
    }

    /// Returns this middleware with metadata logging switched on or off.
    ///
    /// When enabled, the `rpc request` event includes the request metadata, with
    /// values of entries flagged as sensitive replaced by `"[REDACTED]"`.
    #[must_use]
    pub fn with_metadata(mut self, log_metadata: bool) -> Self {
        self.options.log_metadata = log_metadata;
        self
    }
}
impl ClientMiddleware for ClientLogging {
    /// Records when the call started and emits the `rpc request` debug event.
    ///
    /// Request metadata is only attached (in redacted form) when `log_metadata`
    /// is enabled.
    fn pre<'a, 'call>(
        &'a self,
        context: &'a ClientContext<'a>,
        request: &'a mut ClientRequest<'call, 'a>,
    ) -> BoxMiddlewareFuture<'a> {
        Box::pin(async move {
            // Stash the start instant so `post` can compute how long the call took.
            context.extensions().insert(RequestStart(Instant::now()));

            let descriptor = context.method();
            let service_name = descriptor.map(|d| d.service_name);
            let method_name = descriptor.map(|d| d.method_name);

            if !self.options.log_metadata {
                debug!(
                    target: "vox::client",
                    service = service_name,
                    method = method_name,
                    method_id = %context.method_id(),
                    "rpc request"
                );
                return;
            }

            debug!(
                target: "vox::client",
                service = service_name,
                method = method_name,
                method_id = %context.method_id(),
                metadata = ?RedactedMetadata(request.metadata()),
                "rpc request"
            );
        })
    }

    /// Emits the `rpc response` debug event with the elapsed time and the outcome.
    fn post<'a>(
        &'a self,
        context: &'a ClientContext<'a>,
        outcome: ClientCallOutcome<'a>,
    ) -> BoxMiddlewareFuture<'a> {
        Box::pin(async move {
            let descriptor = context.method();
            let service_name = descriptor.map(|d| d.service_name);
            let method_name = descriptor.map(|d| d.method_name);

            // Milliseconds since the `RequestStart` stored by `pre`.
            let elapsed_ms = context
                .extensions()
                .with::<RequestStart, _>(|start| start.0.elapsed().as_secs_f64() * 1_000.0);

            match outcome {
                ClientCallOutcome::Error(error) => {
                    debug!(
                        target: "vox::client",
                        service = service_name,
                        method = method_name,
                        method_id = %context.method_id(),
                        duration_ms = elapsed_ms,
                        error = ?error,
                        outcome = "error",
                        "rpc response"
                    );
                }
                ClientCallOutcome::Response => {
                    debug!(
                        target: "vox::client",
                        service = service_name,
                        method = method_name,
                        method_id = %context.method_id(),
                        duration_ms = elapsed_ms,
                        outcome = "response",
                        "rpc response"
                    );
                }
            }
        })
    }
}
/// Request extension carrying the instant captured when `pre` ran for a call.
#[derive(Debug)]
struct RequestStart(std::time::Instant);
#[derive(Clone, Copy)]
struct RedactedMetadata<'a>(&'a [MetadataEntry<'a>]);
impl std::fmt::Debug for RedactedMetadata<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut entries = f.debug_list();
for entry in self.0 {
entries.entry(&MetadataEntryDebug(entry));
}
entries.finish()
}
}
/// `Debug` adapter for one metadata entry, rendered as
/// `MetadataEntry { key, value, flags }`.
///
/// When the entry's flags contain `MetadataFlags::SENSITIVE`, the value is shown
/// as the string `"[REDACTED]"`; otherwise it goes through [`MetadataValueDebug`].
struct MetadataEntryDebug<'a>(&'a MetadataEntry<'a>);

impl std::fmt::Debug for MetadataEntryDebug<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry = self.0;
        let visible = MetadataValueDebug(&entry.value);
        // Decide up front what the `value` field shows; the other fields are fixed.
        let value: &dyn std::fmt::Debug = if entry.flags.contains(MetadataFlags::SENSITIVE) {
            &"[REDACTED]"
        } else {
            &visible
        };

        f.debug_struct("MetadataEntry")
            .field("key", &entry.key)
            .field("value", value)
            .field("flags", &entry.flags)
            .finish()
    }
}
/// `Debug` adapter for a single metadata value.
///
/// String and integer values use their standard `Debug` rendering. Byte payloads
/// are never printed; only their length appears, as `<N bytes>`.
struct MetadataValueDebug<'a>(&'a MetadataValue<'a>);

impl std::fmt::Debug for MetadataValueDebug<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0 {
            MetadataValue::Bytes(payload) => {
                let byte_count = payload.len();
                write!(f, "<{} bytes>", byte_count)
            }
            MetadataValue::String(text) => std::fmt::Debug::fmt(text, f),
            MetadataValue::U64(number) => std::fmt::Debug::fmt(number, f),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{
        io,
        sync::{Arc, Mutex},
    };

    use tracing_subscriber::{
        Layer,
        filter::LevelFilter,
        fmt::{self, MakeWriter},
        layer::SubscriberExt,
    };

    use super::{ClientLogging, ClientLoggingOptions, RedactedMetadata};
    use crate::{
        Caller, MetadataEntry, MetadataFlags, MetadataValue, MethodDescriptor, MethodId,
        MiddlewareCaller, Payload, RequestCall, ServiceDescriptor, VoxError,
    };
    use vox_types::CallResult;

    /// Descriptor for the single method exercised by the middleware tests.
    static METHOD: MethodDescriptor = MethodDescriptor {
        id: MethodId(7),
        service_name: "Audit",
        method_name: "record",
        args_shape: <() as facet::Facet<'static>>::SHAPE,
        args: &[],
        return_shape: <() as facet::Facet<'static>>::SHAPE,
        retry: crate::RetryPolicy::VOLATILE,
        doc: None,
    };

    /// Service descriptor exposing `METHOD`.
    static SERVICE: ServiceDescriptor = ServiceDescriptor {
        service_name: "Audit",
        methods: &[&METHOD],
        doc: None,
    };

    #[test]
    fn metadata_debug_redacts_sensitive_values() {
        let metadata = vec![
            MetadataEntry {
                key: "authorization",
                value: MetadataValue::String("Bearer secret"),
                flags: MetadataFlags::SENSITIVE,
            },
            MetadataEntry {
                key: "blob",
                value: MetadataValue::Bytes(&[1, 2, 3]),
                flags: MetadataFlags::NONE,
            },
        ];
        assert_eq!(
            format!("{:?}", RedactedMetadata(&metadata)),
            "[MetadataEntry { key: \"authorization\", value: \"[REDACTED]\", flags: MetadataFlags(1) }, MetadataEntry { key: \"blob\", value: <3 bytes>, flags: MetadataFlags(0) }]"
        );
    }

    #[tokio::test]
    async fn client_logging_emits_redacted_request_and_response_logs() {
        let output = run_logged_call(ClientLoggingOptions { log_metadata: true }).await;

        assert!(output.contains("rpc request"));
        assert!(output.contains("rpc response"));
        assert!(output.contains("authorization"));
        assert!(output.contains("[REDACTED]"));
        assert!(!output.contains("Bearer secret"));
        assert!(output.contains("attempt"));
        assert!(output.contains("Cancelled"));
    }

    #[tokio::test]
    async fn client_logging_omits_metadata_when_disabled() {
        let output = run_logged_call(ClientLoggingOptions::default()).await;

        assert!(output.contains("rpc request"));
        assert!(output.contains("rpc response"));
        assert!(output.contains("Cancelled"));
        // With `log_metadata` off, no part of the metadata may reach the log.
        assert!(!output.contains("authorization"));
        assert!(!output.contains("[REDACTED]"));
        assert!(!output.contains("Bearer secret"));
    }

    /// Sends one call through a [`ClientLogging`] built from `options` and returns
    /// everything logged at debug level while it ran.
    ///
    /// The call carries a sensitive `authorization` entry and a plain `attempt`
    /// entry, and the underlying caller always fails with `VoxError::Cancelled`.
    async fn run_logged_call(options: ClientLoggingOptions) -> String {
        let writer = SharedWriter::default();
        let subscriber = tracing_subscriber::registry().with(
            fmt::layer()
                .without_time()
                .with_ansi(false)
                .with_writer(writer.clone())
                .with_filter(LevelFilter::DEBUG),
        );
        // Thread-scoped default; the guard must stay alive until the call completes.
        let _guard = tracing::subscriber::set_default(subscriber);

        let caller = MiddlewareCaller::new(AlwaysCancelledCaller, &SERVICE)
            .with_middleware(ClientLogging::new(options));
        let _ = caller
            .call(RequestCall {
                method_id: MethodId(7),
                metadata: vec![
                    MetadataEntry {
                        key: "authorization",
                        value: MetadataValue::String("Bearer secret"),
                        flags: MetadataFlags::SENSITIVE,
                    },
                    MetadataEntry {
                        key: "attempt",
                        value: MetadataValue::U64(2),
                        flags: MetadataFlags::NONE,
                    },
                ],
                args: Payload::PostcardBytes(&[]),
                schemas: Default::default(),
            })
            .await;

        writer.output()
    }

    /// Caller stub that fails every call with `VoxError::Cancelled`.
    #[derive(Clone)]
    struct AlwaysCancelledCaller;

    impl Caller for AlwaysCancelledCaller {
        async fn call<'a>(&'a self, _call: RequestCall<'a>) -> CallResult {
            Err(VoxError::Cancelled)
        }
    }

    /// In-memory log sink shared between the subscriber and the test body.
    #[derive(Clone, Default)]
    struct SharedWriter {
        output: Arc<Mutex<Vec<u8>>>,
    }

    impl SharedWriter {
        /// Returns everything written so far, decoded as UTF-8.
        fn output(&self) -> String {
            let bytes = self.output.lock().expect("shared writer mutex poisoned");
            String::from_utf8(bytes.clone()).expect("log output should be utf-8")
        }
    }

    impl<'a> MakeWriter<'a> for SharedWriter {
        type Writer = SharedWriterGuard;

        fn make_writer(&'a self) -> Self::Writer {
            SharedWriterGuard {
                output: Arc::clone(&self.output),
            }
        }
    }

    /// Per-event writer handed out by [`SharedWriter`]; appends to the shared buffer.
    struct SharedWriterGuard {
        output: Arc<Mutex<Vec<u8>>>,
    }

    impl io::Write for SharedWriterGuard {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.output
                .lock()
                .expect("shared writer mutex poisoned")
                .extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }
}