grpc 0.9.0-alpha.1

The official Rust implementation of gRPC: a high performance, open source, universal RPC framework.
Documentation
/*
 *
 * Copyright 2026 gRPC authors.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 */

//! Interceptors providing client-side access to metadata.

use tokio::sync::oneshot;

use crate::client::CallOptions;
use crate::client::InvokeOnce;
use crate::client::RecvStream;
use crate::client::interceptor::Intercept;
use crate::client::interceptor::InterceptOnce;
use crate::core::RequestHeaders;
use crate::metadata::MetadataMap;

/// An interceptor that attaches metadata to outgoing RPC headers.
pub struct AttachHeadersInterceptor {
    md: MetadataMap,
}

impl AttachHeadersInterceptor {
    /// Creates a new interceptor that will attach `md` to the client's outgoing
    /// headers.
    pub fn new(md: MetadataMap) -> Self {
        Self { md }
    }
}

impl<I: InvokeOnce> Intercept<I> for AttachHeadersInterceptor {
    type SendStream = I::SendStream;
    type RecvStream = I::RecvStream;

    async fn intercept(
        &self,
        mut headers: RequestHeaders,
        options: CallOptions,
        next: I,
    ) -> (Self::SendStream, Self::RecvStream) {
        let incoming_meta = headers.metadata_mut();
        incoming_meta.reserve(self.md.len());
        for kv in self.md.iter() {
            match kv {
                crate::metadata::KeyAndValueRef::Ascii(key, value) => {
                    incoming_meta.append(key, value.clone())
                }
                crate::metadata::KeyAndValueRef::Binary(key, value) => {
                    incoming_meta.append_bin(key, value.clone())
                }
            }
        }
        next.invoke_once(headers, options).await
    }
}

/// An interceptor to read the metadata received in the server's headers.
pub struct CaptureHeadersInterceptor {
    tx: oneshot::Sender<MetadataMap>,
}

impl CaptureHeadersInterceptor {
    /// Creates an interceptor and a paired [`oneshot::Receiver`].  When the
    /// interceptor is attached to a call, the server headers' metadata is sent
    /// when it is available.  If the call completes without receiving headers
    /// (e.g. it times out or is a trailers-only response), the matching
    /// [`oneshot::Sender`] will be dropped and the `Receiver` will see an error
    /// instead.
    pub fn new() -> (Self, oneshot::Receiver<MetadataMap>) {
        let (tx, rx) = oneshot::channel();
        (Self { tx }, rx)
    }
}

impl<I: InvokeOnce> InterceptOnce<I> for CaptureHeadersInterceptor {
    type SendStream = I::SendStream;
    type RecvStream = CaptureHeadersRecvStream<I::RecvStream>;

    async fn intercept_once(
        self,
        headers: RequestHeaders,
        options: CallOptions,
        next: I,
    ) -> (Self::SendStream, Self::RecvStream) {
        let (tx, rx) = next.invoke_once(headers, options).await;
        (tx, CaptureHeadersRecvStream::new(rx, self.tx))
    }
}

/// The [`RecvStream`] portion of a [`CaptureHeadersInterceptor`].
pub struct CaptureHeadersRecvStream<R> {
    rx: R,
    tx: Option<oneshot::Sender<MetadataMap>>,
}

impl<R> CaptureHeadersRecvStream<R> {
    fn new(rx: R, tx: oneshot::Sender<MetadataMap>) -> Self {
        Self { rx, tx: Some(tx) }
    }
}

impl<R: RecvStream> RecvStream for CaptureHeadersRecvStream<R> {
    async fn recv(&mut self, msg: &mut dyn super::RecvMessage) -> super::ResponseStreamItem {
        let res = self.rx.recv(msg).await;
        if let super::ResponseStreamItem::Headers(headers) = &res
            && let Some(tx) = self.tx.take()
        {
            _ = tx.send(headers.metadata().clone());
        }
        res
    }
}

/// An interceptor to read the metadata received in the server's trailers.
pub struct CaptureTrailersInterceptor {
    tx: oneshot::Sender<MetadataMap>,
}

impl CaptureTrailersInterceptor {
    /// Creates an interceptor and a paired [`oneshot::Receiver`].  When the
    /// interceptor is attached to a call, the server trailers' metadata is sent
    /// when it is available.  If the call is terminated before trailers are
    /// received, the matching [`oneshot::Sender`] will be dropped, causing the
    /// `Receiver` to error.
    pub fn new() -> (Self, oneshot::Receiver<MetadataMap>) {
        let (tx, rx) = oneshot::channel();
        (Self { tx }, rx)
    }
}

impl<I: InvokeOnce> InterceptOnce<I> for CaptureTrailersInterceptor {
    type SendStream = I::SendStream;
    type RecvStream = CaptureTrailersRecvStream<I::RecvStream>;

    async fn intercept_once(
        self,
        headers: RequestHeaders,
        options: CallOptions,
        next: I,
    ) -> (Self::SendStream, Self::RecvStream) {
        let (tx, rx) = next.invoke_once(headers, options).await;
        (tx, CaptureTrailersRecvStream::new(rx, self.tx))
    }
}

/// The [`RecvStream`] portion of a [`CaptureTrailersInterceptor`].
pub struct CaptureTrailersRecvStream<R> {
    rx: R,
    tx: Option<oneshot::Sender<MetadataMap>>,
}

impl<R> CaptureTrailersRecvStream<R> {
    fn new(rx: R, tx: oneshot::Sender<MetadataMap>) -> Self {
        Self { rx, tx: Some(tx) }
    }
}

impl<R: RecvStream> RecvStream for CaptureTrailersRecvStream<R> {
    async fn recv(&mut self, msg: &mut dyn super::RecvMessage) -> super::ResponseStreamItem {
        let res = self.rx.recv(msg).await;
        if let super::ResponseStreamItem::Trailers(trailers) = &res
            && let Some(tx) = self.tx.take()
        {
            _ = tx.send(trailers.metadata().clone());
        }
        res
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ResponseStreamItem;
    use crate::client::test_util::MockInvoker;
    use crate::client::test_util::NopRecvMessage;
    use crate::core::ResponseHeaders;
    use crate::core::Trailers;
    use crate::metadata::BinaryMetadataValue;

    #[tokio::test]
    async fn test_attach_headers_interceptor() {
        // Create test interceptor with metadata to attach.
        let mut md = MetadataMap::new();
        md.insert("x-test-header", "test-value".parse().unwrap());
        md.insert_bin(
            "x-test-header-bin",
            BinaryMetadataValue::from_bytes(b"test-bin"),
        );
        let interceptor = AttachHeadersInterceptor::new(md);

        // Call the interceptor with additional headers in place.
        let (invoker, _) = MockInvoker::new();
        let mut initial_headers = RequestHeaders::default();
        initial_headers
            .metadata_mut()
            .insert("x-initial-header", "initial".parse().unwrap());
        let _ = interceptor
            .intercept(initial_headers, CallOptions::default(), &invoker)
            .await;

        // Verify the received headers include all values.
        let final_headers = invoker.req_headers.lock().unwrap().take().unwrap();
        assert_eq!(
            final_headers.metadata().get("x-test-header").unwrap(),
            "test-value"
        );
        assert_eq!(
            final_headers
                .metadata()
                .get_bin("x-test-header-bin")
                .unwrap(),
            b"test-bin".as_slice()
        );
        assert_eq!(
            final_headers.metadata().get("x-initial-header").unwrap(),
            "initial"
        );
    }

    #[tokio::test]
    async fn test_capture_headers_interceptor() {
        // Create test interceptor.
        let (interceptor, rx) = CaptureHeadersInterceptor::new();

        // Start a call through the interceptor.
        let (invoker, mut controller) = MockInvoker::new();
        let (_, mut recv_stream) = interceptor
            .intercept_once(RequestHeaders::default(), CallOptions::default(), &invoker)
            .await;

        // Send a Headers response on the call.
        let mut resp_md = MetadataMap::new();
        resp_md.insert("x-resp-header", "resp-value".parse().unwrap());
        let mut headers = ResponseHeaders::default();
        *headers.metadata_mut() = resp_md;
        controller
            .send_resp(ResponseStreamItem::Headers(headers))
            .await;

        // Receive the sent Headers response.
        let res = recv_stream.recv(&mut NopRecvMessage).await;
        assert!(matches!(res, ResponseStreamItem::Headers(_)));

        // Verify the received headers are correct.
        let captured_md = rx.await.unwrap();
        assert_eq!(captured_md.get("x-resp-header").unwrap(), "resp-value");
    }

    #[tokio::test]
    async fn test_capture_trailers_interceptor() {
        // Create test interceptor.
        let (interceptor, rx) = CaptureTrailersInterceptor::new();

        // Start a call through the interceptor.
        let (invoker, mut controller) = MockInvoker::new();
        let (_, mut recv_stream) = interceptor
            .intercept_once(RequestHeaders::default(), CallOptions::default(), &invoker)
            .await;

        // Send a Trailers response on the call.
        let mut trailers_md = MetadataMap::new();
        trailers_md.insert("x-trailer", "trailer-value".parse().unwrap());
        let mut trailers = Trailers::new(Ok(()));
        *trailers.metadata_mut() = trailers_md;
        controller
            .send_resp(ResponseStreamItem::Trailers(trailers))
            .await;

        // Receive the sent Trailers response.
        let res = recv_stream.recv(&mut NopRecvMessage).await;
        assert!(matches!(res, ResponseStreamItem::Trailers(_)));

        // Verify the received trailers are correct.
        let captured_md = rx.await.unwrap();
        assert_eq!(captured_md.get("x-trailer").unwrap(), "trailer-value");
    }
}