use crate::metadata::{MetadataMap, MetadataValue};
#[cfg(all(feature = "transport", feature = "tls"))]
use crate::transport::server::TlsConnectInfo;
#[cfg(feature = "transport")]
use crate::transport::{server::TcpConnectInfo, Certificate};
use crate::Extensions;
use futures_core::Stream;
#[cfg(feature = "transport")]
use std::sync::Arc;
use std::{net::SocketAddr, time::Duration};
/// A gRPC request: a message of type `T` together with its metadata and
/// extensions.
#[derive(Debug)]
pub struct Request<T> {
    // Request metadata; built from / turned into HTTP headers by
    // `from_http_parts` and `into_http`.
    metadata: MetadataMap,
    // The wrapped message (or message stream for streaming calls).
    message: T,
    // Arbitrary typed extensions, e.g. connection info read by `remote_addr`
    // and `peer_certs`.
    extensions: Extensions,
}
/// Conversion into a [`Request`] carrying a message of type `T`.
///
/// Implemented in this file both for a bare `T` (wrapped with [`Request::new`])
/// and for `Request<T>` (returned unchanged).
pub trait IntoRequest<T>: sealed::Sealed {
    /// Converts `self` into a `Request<T>`.
    fn into_request(self) -> Request<T>;
}
/// Conversion into a streaming [`Request`], whose body is a stream of
/// messages rather than a single message.
pub trait IntoStreamingRequest: sealed::Sealed {
    /// Type of each item yielded by the request stream.
    type Message;

    /// Concrete stream type carried as the request body.
    type Stream: Send + 'static + Stream<Item = Self::Message>;

    /// Wraps `self` into a streaming request.
    fn into_streaming_request(self) -> Request<Self::Stream>;
}
impl<T> Request<T> {
pub fn new(message: T) -> Self {
Request {
metadata: MetadataMap::new(),
message,
extensions: Extensions::new(),
}
}
pub fn get_ref(&self) -> &T {
&self.message
}
pub fn get_mut(&mut self) -> &mut T {
&mut self.message
}
pub fn metadata(&self) -> &MetadataMap {
&self.metadata
}
pub fn metadata_mut(&mut self) -> &mut MetadataMap {
&mut self.metadata
}
pub fn into_inner(self) -> T {
self.message
}
pub fn into_parts(self) -> (MetadataMap, Extensions, T) {
(self.metadata, self.extensions, self.message)
}
pub fn from_parts(metadata: MetadataMap, extensions: Extensions, message: T) -> Self {
Self {
metadata,
extensions,
message,
}
}
pub(crate) fn from_http_parts(parts: http::request::Parts, message: T) -> Self {
Request {
metadata: MetadataMap::from_headers(parts.headers),
message,
extensions: Extensions::from_http(parts.extensions),
}
}
pub fn from_http(http: http::Request<T>) -> Self {
let (parts, message) = http.into_parts();
Request::from_http_parts(parts, message)
}
pub(crate) fn into_http(
self,
uri: http::Uri,
method: http::Method,
version: http::Version,
sanitize_headers: SanitizeHeaders,
) -> http::Request<T> {
let mut request = http::Request::new(self.message);
*request.version_mut() = version;
*request.method_mut() = method;
*request.uri_mut() = uri;
*request.headers_mut() = match sanitize_headers {
SanitizeHeaders::Yes => self.metadata.into_sanitized_headers(),
SanitizeHeaders::No => self.metadata.into_headers(),
};
*request.extensions_mut() = self.extensions.into_http();
request
}
#[doc(hidden)]
pub fn map<F, U>(self, f: F) -> Request<U>
where
F: FnOnce(T) -> U,
{
let message = f(self.message);
Request {
metadata: self.metadata,
message,
extensions: self.extensions,
}
}
pub fn remote_addr(&self) -> Option<SocketAddr> {
#[cfg(feature = "transport")]
{
#[cfg(feature = "tls")]
{
self.extensions()
.get::<TcpConnectInfo>()
.and_then(|i| i.remote_addr())
.or_else(|| {
self.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
.and_then(|i| i.get_ref().remote_addr())
})
}
#[cfg(not(feature = "tls"))]
{
self.extensions()
.get::<TcpConnectInfo>()
.and_then(|i| i.remote_addr())
}
}
#[cfg(not(feature = "transport"))]
{
None
}
}
#[cfg(feature = "transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
#[cfg(feature = "tls")]
{
self.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
.and_then(|i| i.peer_certs())
}
#[cfg(not(feature = "tls"))]
{
None
}
}
pub fn set_timeout(&mut self, deadline: Duration) {
let value: MetadataValue<_> = duration_to_grpc_timeout(deadline).parse().unwrap();
self.metadata_mut()
.insert(crate::metadata::GRPC_TIMEOUT_HEADER, value);
}
pub fn extensions(&self) -> &Extensions {
&self.extensions
}
pub fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}
}
impl<T> IntoRequest<T> for T {
    /// Wraps a bare message in a fresh request with empty metadata and
    /// extensions.
    fn into_request(self) -> Request<T> {
        Request::<T>::new(self)
    }
}
impl<T> IntoRequest<T> for Request<T> {
    /// A `Request<T>` is already in the target form, so it is handed back
    /// untouched.
    fn into_request(self) -> Self {
        self
    }
}
impl<S> IntoStreamingRequest for S
where
    S: Stream + Send + 'static,
{
    type Message = S::Item;
    type Stream = S;

    /// Wraps a bare stream in a new streaming request.
    fn into_streaming_request(self) -> Request<S> {
        Request::new(self)
    }
}
impl<S> IntoStreamingRequest for Request<S>
where
    S: Stream + Send + 'static,
{
    type Message = S::Item;
    type Stream = S;

    /// Already a streaming request, so it is returned as-is.
    fn into_streaming_request(self) -> Request<S> {
        self
    }
}
// Every type implements `Sealed`, so the supertrait bound on `IntoRequest` /
// `IntoStreamingRequest` is always satisfied.
// NOTE(review): because this blanket impl covers all `T`, the bound alone does
// not appear to stop downstream crates from implementing those traits for their
// own types — confirm whether a real seal was intended.
impl<T> sealed::Sealed for T {}
// Private module holding the supertrait used by `IntoRequest` and
// `IntoStreamingRequest`. The module is private, so `Sealed` cannot be named
// outside this module's parent.
mod sealed {
    /// Marker supertrait; implemented for every type by the blanket impl in
    /// the parent module.
    pub trait Sealed {}
}
/// Encodes `duration` as a `grpc-timeout` header value.
///
/// The gRPC over HTTP/2 protocol allows at most eight ASCII digits followed
/// by a single unit character (`H`, `M`, `S`, `m`, `u` or `n`). This picks
/// the most precise unit whose value still fits in eight digits.
///
/// Durations too large to express even in hours (roughly 11,415 years or
/// more, e.g. `Duration::MAX`) are saturated to the largest encodable
/// timeout, `99999999H`, instead of panicking. This lets callers pass a
/// "practically infinite" deadline.
fn duration_to_grpc_timeout(duration: Duration) -> String {
    /// Largest value the spec allows: exactly eight decimal digits.
    const MAX_TIMEOUT_VALUE: u128 = 99_999_999;

    /// Formats `duration` in `unit`, or returns `None` when the converted
    /// value would need more than eight digits.
    fn try_format<T: Into<u128>>(
        duration: Duration,
        unit: char,
        convert: impl FnOnce(Duration) -> T,
    ) -> Option<String> {
        let value = convert(duration).into();
        if value > MAX_TIMEOUT_VALUE {
            None
        } else {
            Some(format!("{}{}", value, unit))
        }
    }

    // Try units from most to least precise and keep the first that fits.
    try_format(duration, 'n', |d| d.as_nanos())
        .or_else(|| try_format(duration, 'u', |d| d.as_micros()))
        .or_else(|| try_format(duration, 'm', |d| d.as_millis()))
        .or_else(|| try_format(duration, 'S', |d| d.as_secs()))
        .or_else(|| try_format(duration, 'M', |d| d.as_secs() / 60))
        .or_else(|| try_format(duration, 'H', |d| d.as_secs() / 60 / 60))
        // Only reachable for durations of ~11,415 years or more. Saturate rather
        // than panic, mirroring how grpc-go clamps oversized timeouts.
        .unwrap_or_else(|| format!("{}H", MAX_TIMEOUT_VALUE))
}
/// Controls how `Request::into_http` turns metadata into HTTP headers.
pub(crate) enum SanitizeHeaders {
    /// Use `MetadataMap::into_sanitized_headers`, stripping the reserved gRPC
    /// headers.
    Yes,
    /// Use `MetadataMap::into_headers`, passing the metadata through as-is.
    No,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::MetadataValue;
    use http::Uri;

    /// Every reserved gRPC header placed in the metadata must be dropped when
    /// converting with `SanitizeHeaders::Yes`.
    #[test]
    fn reserved_headers_are_excluded() {
        let mut request = Request::new(1);
        MetadataMap::GRPC_RESERVED_HEADERS.iter().for_each(|name| {
            request
                .metadata_mut()
                .insert(*name, MetadataValue::from_static("invalid"));
        });

        let converted = request.into_http(
            Uri::default(),
            http::Method::POST,
            http::Version::HTTP_2,
            SanitizeHeaders::Yes,
        );

        assert!(converted.headers().is_empty());
    }

    /// Half a second needs more than eight digits in nanoseconds, so it falls
    /// back to microseconds.
    #[test]
    fn duration_to_grpc_timeout_less_than_second() {
        let half_second = Duration::from_millis(500);
        assert_eq!(
            duration_to_grpc_timeout(half_second),
            format!("{}u", half_second.as_micros())
        );
    }

    /// Thirty seconds still fits in eight digits of microseconds.
    #[test]
    fn duration_to_grpc_timeout_more_than_second() {
        let thirty_seconds = Duration::from_secs(30);
        assert_eq!(
            duration_to_grpc_timeout(thirty_seconds),
            format!("{}u", thirty_seconds.as_micros())
        );
    }

    /// One hour overflows microseconds and is encoded in milliseconds.
    #[test]
    fn duration_to_grpc_timeout_a_very_long_time() {
        let hour = Duration::from_secs(60 * 60);
        assert_eq!(
            duration_to_grpc_timeout(hour),
            format!("{}m", hour.as_millis())
        );
    }
}