use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use bytes::Bytes;
use futures::future::BoxFuture;
use http_body::Frame;
use tonic::body::Body;
use tower::Service;
/// Tower middleware that inspects a status value carried in a response header
/// (or, failing that, a response trailer) and invokes a caller-supplied
/// rebuild callback when the deduplication logic decides a rebuild is needed.
///
/// Clones share the same callback and the same remembered status through
/// `Arc`, so deduplication is coordinated across every clone of the layer.
#[derive(Clone)]
pub struct ResolveStatusMiddleware<S> {
    /// The wrapped service that actually performs the request.
    inner: S,
    /// Name of the header/trailer whose value is inspected.
    header_name: http::HeaderName,
    /// Last non-empty status value that triggered a rebuild, or `None`.
    last_seen: Arc<Mutex<Option<String>>>,
    /// Callback that performs the rebuild.
    rebuild: Arc<dyn Fn() + Send + Sync>,
}
impl<S> ResolveStatusMiddleware<S> {
pub fn new<F>(inner: S, header_name: http::HeaderName, rebuild: F) -> Self
where
F: Fn() + Send + Sync + 'static,
{
Self {
inner,
rebuild: Arc::new(rebuild),
header_name,
last_seen: Arc::new(Mutex::new(None)),
}
}
}
/// Outcome of comparing a freshly observed status value with the remembered one.
#[derive(Debug, PartialEq, Eq)]
enum DedupAction {
    /// Nothing to do: no status and none remembered, or the same value again.
    None,
    /// A new non-empty value: remember it and rebuild.
    StoreAndRebuild(String),
    /// An empty value: rebuild, but leave the remembered value untouched.
    RebuildKeepLast,
    /// No status this time although one was remembered: forget it.
    Reset,
}

/// Decides what to do for one response, without touching any state.
///
/// `observed` is the status value read from the response (`None` when the
/// header/trailer was absent); `last_seen` is the currently remembered value.
/// An empty observed value always yields [`DedupAction::RebuildKeepLast`],
/// regardless of what was remembered.
fn classify(observed: Option<&str>, last_seen: Option<&str>) -> DedupAction {
    match (observed, last_seen) {
        (None, None) => DedupAction::None,
        (None, Some(_)) => DedupAction::Reset,
        (Some(""), _) => DedupAction::RebuildKeepLast,
        (Some(value), Some(prev)) if value == prev => DedupAction::None,
        (Some(value), _) => DedupAction::StoreAndRebuild(value.to_owned()),
    }
}
/// Applies the dedup decision for one observed status value.
///
/// Under the `last_seen` lock, this classifies `observed` against the
/// remembered value and updates the remembered value. Then, only if the
/// decision calls for it, it logs and invokes `rebuild`. The lock is released
/// before `rebuild` runs, so the callback may itself touch `last_seen` without
/// deadlocking.
///
/// The remembered value is cloned only when needed. On the steady-state
/// no-op and reset paths nothing is allocated. On the store path the previous
/// value is moved out with `Option::replace` rather than copied.
///
/// # Panics
///
/// Panics if the `last_seen` mutex is poisoned.
fn apply_dedup(
    observed: Option<&str>,
    last_seen: &Mutex<Option<String>>,
    rebuild: &(dyn Fn() + Send + Sync),
) {
    let (action, prev) = {
        let mut guard = last_seen.lock().expect("middleware mutex poisoned");
        let action = classify(observed, guard.as_deref());
        let prev = match &action {
            // No rebuild on these paths; returning here also drops the guard.
            DedupAction::None => return,
            DedupAction::Reset => {
                *guard = None;
                return;
            }
            // State is kept, so a copy is needed purely for the log line.
            DedupAction::RebuildKeepLast => guard.clone(),
            // `replace` hands back the old value, which is exactly `prev`.
            DedupAction::StoreAndRebuild(v) => guard.replace(v.clone()),
        };
        (action, prev)
    };
    tracing::info!(
        observed = observed.unwrap_or("<absent>"),
        previous = prev.as_deref().unwrap_or("<none>"),
        decision = ?action,
        "ResolveStatusMiddleware firing channel rebuild",
    );
    rebuild();
}
impl<S> Service<http::Request<Body>> for ResolveStatusMiddleware<S>
where
S: Service<
http::Request<Body>,
Response = http::Response<Body>,
Error = tonic::transport::Error,
> + Clone
+ Send
+ 'static,
S::Future: Send + 'static,
{
type Response = http::Response<Body>;
type Error = tonic::transport::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: http::Request<Body>) -> Self::Future {
let mut inner = self.inner.clone();
std::mem::swap(&mut inner, &mut self.inner);
let header = self.header_name.clone();
let rebuild = self.rebuild.clone();
let last_seen = self.last_seen.clone();
Box::pin(async move {
let resp = inner.call(req).await?;
let (parts, body) = resp.into_parts();
let wrapped = match parts.headers.get(&header).and_then(|v| v.to_str().ok()) {
Some(v) => {
apply_dedup(Some(v), &last_seen, &*rebuild);
body
}
None => Body::new(TrailerObserver::new(body, header, last_seen, rebuild)),
};
Ok(http::Response::from_parts(parts, wrapped))
})
}
}
/// Response-body wrapper that forwards every frame unchanged while watching
/// for the status trailer.
///
/// It runs the dedup/rebuild logic at most once: when a trailers frame
/// arrives, or at end of stream if no trailers frame was seen.
struct TrailerObserver {
    /// Body whose frames are forwarded.
    inner: Body,
    /// Trailer name to look up.
    header: http::HeaderName,
    /// Shared rebuild callback.
    rebuild: Arc<dyn Fn() + Send + Sync>,
    /// Shared remembered status value.
    last_seen: Arc<Mutex<Option<String>>>,
    /// Set once the dedup logic has run; ensures it runs at most once.
    fired: bool,
    /// Set once a trailers frame has been observed.
    saw_trailers: bool,
}
impl TrailerObserver {
    /// Wraps `inner`; no frame has been seen and nothing has fired yet.
    fn new(
        inner: Body,
        header: http::HeaderName,
        last_seen: Arc<Mutex<Option<String>>>,
        rebuild: Arc<dyn Fn() + Send + Sync>,
    ) -> Self {
        Self {
            header,
            inner,
            rebuild,
            last_seen,
            fired: false,
            saw_trailers: false,
        }
    }

    /// Runs the dedup logic for `observed`, but only the first time it is called.
    fn fire(&mut self, observed: Option<&str>) {
        // Flip the flag and learn its previous value in one step.
        let already_fired = std::mem::replace(&mut self.fired, true);
        if !already_fired {
            apply_dedup(observed, &self.last_seen, self.rebuild.as_ref());
        }
    }
}
/// Transparent pass-through body: every poll result from `inner` is returned
/// unchanged, while trailers and end-of-stream are observed on the way out.
impl http_body::Body for TrailerObserver {
    type Data = Bytes;
    type Error = tonic::Status;

    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        // All fields are `Unpin`, so the wrapper can be accessed directly.
        let this = self.get_mut();
        let result = Pin::new(&mut this.inner).poll_frame(cx);
        if let Poll::Ready(next) = &result {
            match next {
                // Data frames have no trailers map and are ignored here.
                Some(Ok(frame)) => {
                    if let Some(trailers) = frame.trailers_ref() {
                        let observed = trailers
                            .get(&this.header)
                            .and_then(|value| value.to_str().ok());
                        this.saw_trailers = true;
                        this.fire(observed);
                    }
                }
                // NOTE(review): this branch only runs if the consumer polls
                // through to `Ready(None)`. A consumer that stops once
                // `is_end_stream()` is true would skip it; confirm against
                // the actual consumer.
                None if !this.saw_trailers => this.fire(None),
                _ => {}
            }
        }
        result
    }

    fn is_end_stream(&self) -> bool {
        self.inner.is_end_stream()
    }

    fn size_hint(&self) -> http_body::SizeHint {
        self.inner.size_hint()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    #[test]
    fn dedup_none_to_value_rebuilds() {
        assert_eq!(
            classify(Some("not-primary"), None),
            DedupAction::StoreAndRebuild("not-primary".to_string()),
        );
    }

    #[test]
    fn dedup_same_value_is_noop() {
        assert_eq!(classify(Some("v"), Some("v")), DedupAction::None);
    }

    #[test]
    fn dedup_different_value_rebuilds() {
        assert_eq!(
            classify(Some("w"), Some("v")),
            DedupAction::StoreAndRebuild("w".to_string()),
        );
    }

    #[test]
    fn dedup_no_trailer_resets_when_seen() {
        assert_eq!(classify(None, Some("v")), DedupAction::Reset);
    }

    #[test]
    fn dedup_no_trailer_steady_state_is_noop() {
        assert_eq!(classify(None, None), DedupAction::None);
    }

    #[test]
    fn dedup_empty_always_rebuilds_without_state_change() {
        assert_eq!(classify(Some(""), None), DedupAction::RebuildKeepLast);
        assert_eq!(classify(Some(""), Some("v")), DedupAction::RebuildKeepLast);
    }

    /// Drives `apply_dedup` through a sequence of observations and checks
    /// both the remembered state and the number of rebuilds after each step.
    #[test]
    fn apply_dedup_tracks_state_and_counts_rebuilds() {
        let last_seen = Mutex::new(None);
        let rebuilds = AtomicUsize::new(0);
        let rebuild = || {
            rebuilds.fetch_add(1, Ordering::SeqCst);
        };

        // First value: stored, rebuild.
        apply_dedup(Some("a"), &last_seen, &rebuild);
        assert_eq!(rebuilds.load(Ordering::SeqCst), 1);
        assert_eq!(last_seen.lock().unwrap().as_deref(), Some("a"));

        // Same value again: deduplicated.
        apply_dedup(Some("a"), &last_seen, &rebuild);
        assert_eq!(rebuilds.load(Ordering::SeqCst), 1);
        assert_eq!(last_seen.lock().unwrap().as_deref(), Some("a"));

        // Empty value: rebuild, remembered value untouched.
        apply_dedup(Some(""), &last_seen, &rebuild);
        assert_eq!(rebuilds.load(Ordering::SeqCst), 2);
        assert_eq!(last_seen.lock().unwrap().as_deref(), Some("a"));

        // Absent: reset without rebuilding.
        apply_dedup(None, &last_seen, &rebuild);
        assert_eq!(rebuilds.load(Ordering::SeqCst), 2);
        assert_eq!(last_seen.lock().unwrap().as_deref(), None);

        // Same value after a reset rebuilds again.
        apply_dedup(Some("a"), &last_seen, &rebuild);
        assert_eq!(rebuilds.load(Ordering::SeqCst), 3);

        // Different value rebuilds and replaces the remembered one.
        apply_dedup(Some("b"), &last_seen, &rebuild);
        assert_eq!(rebuilds.load(Ordering::SeqCst), 4);
        assert_eq!(last_seen.lock().unwrap().as_deref(), Some("b"));

        // Absent twice: reset, then steady-state no-op.
        apply_dedup(None, &last_seen, &rebuild);
        apply_dedup(None, &last_seen, &rebuild);
        assert_eq!(rebuilds.load(Ordering::SeqCst), 4);
        assert_eq!(last_seen.lock().unwrap().as_deref(), None);
    }

    /// The rebuild callback must run after the state lock is released.
    #[test]
    fn apply_dedup_releases_lock_before_rebuild() {
        let last_seen = Mutex::new(None);
        let calls = AtomicUsize::new(0);
        let rebuild = || {
            assert!(
                last_seen.try_lock().is_ok(),
                "rebuild ran while the state lock was held"
            );
            calls.fetch_add(1, Ordering::SeqCst);
        };
        apply_dedup(Some("x"), &last_seen, &rebuild);
        apply_dedup(Some(""), &last_seen, &rebuild);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}