use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context as TaskContext, Poll},
time::Instant,
};
use http::Request;
use tower::{Layer, Service};
use crate::metrics::ServiceMetrics;
/// Tower [`Layer`] that wraps an inner service in an [`InstrumentationService`].
///
/// Every service produced by this layer shares the same metrics handle and
/// reports under the same `service_name` label.
#[derive(Debug, Clone)]
pub struct InstrumentationLayer {
    /// Metrics handle shared (by reference count) with every wrapped service.
    metrics: Arc<ServiceMetrics>,
    /// Name reported as the first label value on every metric.
    service_name: String,
}

impl InstrumentationLayer {
    /// Creates a layer that reports into `metrics` under `service_name`.
    pub fn new(metrics: Arc<ServiceMetrics>, service_name: impl Into<String>) -> Self {
        let service_name: String = service_name.into();
        InstrumentationLayer {
            metrics,
            service_name,
        }
    }
}

impl<S> Layer<S> for InstrumentationLayer {
    type Service = InstrumentationService<S>;

    /// Wraps `inner`, handing it a cheap `Arc` clone of the metrics and its
    /// own copy of the service name.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            metrics,
            service_name,
        } = self;
        InstrumentationService {
            inner,
            metrics: Arc::clone(metrics),
            service_name: service_name.clone(),
        }
    }
}
/// Service produced by [`InstrumentationLayer`].
///
/// For every request it:
/// - derives an RPC name from the URI path;
/// - tracks the request in the `active_requests` gauge;
/// - records duration, total, success and error metrics when the response future settles;
/// - attaches an `rpc` tracing span.
#[derive(Debug, Clone)]
pub struct InstrumentationService<S> {
    inner: S,
    metrics: Arc<ServiceMetrics>,
    service_name: String,
}

impl<S, B> Service<Request<B>> for InstrumentationService<S>
where
    S: Service<Request<B>> + Clone + 'static,
    S::Future: Send + 'static,
    S::Response: Send + 'static,
    S::Error: std::fmt::Display + Send + 'static,
    B: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = InstrumentationFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<B>) -> Self::Future {
        let (rpc_service, method) = parse_rpc_path(req.uri().path());
        let full_method = format!("{}.{}", rpc_service, method);

        self.metrics
            .active_requests
            .with_label_values(&[self.service_name.as_str(), full_method.as_str()])
            .inc();
        let start = Instant::now();

        // Every field must be declared when the span is created:
        // `Span::record` silently drops values for fields that are not part of
        // the span's metadata. Late-bound fields start out as `Empty`.
        let span = tracing::span!(
            target: "sunbeam_g2v",
            tracing::Level::INFO,
            "rpc",
            service = %self.service_name,
            rpc_service = %rpc_service,
            method = %method,
            status = tracing::field::Empty,
            error = tracing::field::Empty,
            duration_secs = tracing::field::Empty,
        );

        // `poll_ready` drove *this* instance to readiness, so it must be the one
        // that receives the request. Calling a fresh clone instead would bypass
        // readiness, which tower allows the inner service to punish with a panic.
        // A new clone is left behind for the next poll_ready/call cycle.
        let fresh = self.inner.clone();
        let mut ready = std::mem::replace(&mut self.inner, fresh);
        let inner_fut = {
            let _entered = span.enter();
            ready.call(req)
        };

        InstrumentationFuture {
            inner_fut: Box::pin(inner_fut),
            metrics: Arc::clone(&self.metrics),
            service_name: self.service_name.clone(),
            full_method,
            start,
            span,
            completed: false,
        }
    }
}

/// Response future returned by [`InstrumentationService`].
///
/// Wraps the inner service's future and settles the request's metrics exactly
/// once. That happens either on completion in [`Future::poll`], or in [`Drop`]
/// if the future is dropped (cancelled) before it completes.
pub struct InstrumentationFuture<F> {
    inner_fut: Pin<Box<F>>,
    metrics: Arc<ServiceMetrics>,
    service_name: String,
    /// `"<rpc_service>.<method>"`, computed once in `call` and reused as a label.
    full_method: String,
    start: Instant,
    span: tracing::Span,
    /// Set once the inner future has resolved and its metrics were recorded;
    /// `Drop` uses it to detect cancellation.
    completed: bool,
}

impl<F> InstrumentationFuture<F> {
    /// Label values shared by every per-request metric: `[service, full_method]`.
    fn labels(&self) -> [&str; 2] {
        [self.service_name.as_str(), self.full_method.as_str()]
    }
}

impl<F, T, E> Future for InstrumentationFuture<F>
where
    F: Future<Output = Result<T, E>> + Send + 'static,
    T: Send + 'static,
    E: std::fmt::Display + Send + 'static,
{
    type Output = Result<T, E>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Self::Output> {
        let this = &mut *self;
        let result = {
            // Poll inside the request span so events emitted by the inner
            // future are attributed to this RPC.
            let _entered = this.span.enter();
            match this.inner_fut.as_mut().poll(cx) {
                Poll::Ready(result) => result,
                Poll::Pending => return Poll::Pending,
            }
        };

        // Mark as settled first so `Drop` does not decrement the gauge again.
        this.completed = true;
        let secs = this.start.elapsed().as_secs_f64();
        let labels = this.labels();
        let metrics = &this.metrics;

        metrics.active_requests.with_label_values(&labels).dec();
        metrics
            .request_duration
            .with_label_values(&labels)
            .observe(secs);
        metrics.requests_total.with_label_values(&labels).inc();

        match &result {
            Ok(_) => {
                metrics.success_count.with_label_values(&labels).inc();
                this.span.record("status", "ok");
            }
            Err(e) => {
                let error_str = e.to_string();
                let error_code = extract_error_code(&error_str);
                metrics
                    .error_count
                    .with_label_values(&[labels[0], labels[1], error_code])
                    .inc();
                this.span.record("status", "error");
                this.span.record("error", error_str.as_str());
            }
        }
        this.span.record("duration_secs", secs);
        Poll::Ready(result)
    }
}

impl<F> Drop for InstrumentationFuture<F> {
    fn drop(&mut self) {
        if self.completed {
            return;
        }
        // Dropped before the inner future resolved (client cancel, deadline,
        // shutdown). Release the in-flight gauge taken in `call`; otherwise
        // `active_requests` drifts upward forever.
        self.metrics
            .active_requests
            .with_label_values(&self.labels())
            .dec();
        self.span.record("status", "cancelled");
    }
}
/// Splits a gRPC request path such as `/package.Service/Method` into
/// `(service, method)`.
///
/// All leading `/` characters are stripped, then the remainder is split at its
/// *last* `/`. If no `/` remains, the whole trimmed path is returned as the
/// service and the method is reported as `"unknown"`.
fn parse_rpc_path(path: &str) -> (String, String) {
    let trimmed = path.trim_start_matches('/');
    match trimmed.rsplit_once('/') {
        Some((service, method)) => (service.to_owned(), method.to_owned()),
        None => (trimmed.to_owned(), String::from("unknown")),
    }
}
/// Maps an error's display text to a gRPC status-code label for metrics.
///
/// Searches `error_str` (case-sensitively) for the canonical CamelCase name of
/// every non-OK gRPC status code and returns the one that occurs *earliest*.
///
/// This assumes the code is rendered before the free-form message, as in
/// `"Code: message"`. A code name that merely appears inside the message then
/// does not win: `"Unavailable: Internal DNS failure"` maps to `"Unavailable"`.
///
/// Returns `"Unknown"` when no code name is present.
fn extract_error_code(error_str: &str) -> &'static str {
    /// Canonical names of every non-OK gRPC status code.
    const CODES: &[&str] = &[
        "Cancelled",
        "Unknown",
        "InvalidArgument",
        "DeadlineExceeded",
        "NotFound",
        "AlreadyExists",
        "PermissionDenied",
        "ResourceExhausted",
        "FailedPrecondition",
        "Aborted",
        "OutOfRange",
        "Unimplemented",
        "Internal",
        "Unavailable",
        "DataLoss",
        "Unauthenticated",
    ];

    // No code name is a prefix of another, so two codes never share a start
    // position and the earliest match is unambiguous.
    CODES
        .iter()
        .filter_map(|&code| error_str.find(code).map(|pos| (pos, code)))
        .min_by_key(|&(pos, _)| pos)
        .map_or("Unknown", |(_, code)| code)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_rpc_path() {
        // (input path, expected service, expected method)
        let cases = [
            ("/mypackage.MyService/MyMethod", "mypackage.MyService", "MyMethod"),
            ("/Service/Method", "Service", "Method"),
            ("/Single", "Single", "unknown"),
        ];
        for (path, want_service, want_method) in cases.iter() {
            let expected = (want_service.to_string(), want_method.to_string());
            assert_eq!(parse_rpc_path(path), expected, "path {:?}", path);
        }
    }

    #[test]
    fn test_extract_error_code() {
        // (error text, expected label)
        let cases = [
            ("InvalidArgument: bad input", "InvalidArgument"),
            ("NotFound: resource missing", "NotFound"),
            ("Internal: server error", "Internal"),
            ("Unknown error", "Unknown"),
        ];
        for (text, want) in cases.iter() {
            assert_eq!(extract_error_code(text), *want, "text {:?}", text);
        }
    }
}