tonic_rustls/server/conn.rs

1use std::net::SocketAddr;
2use tokio::net::TcpStream;
3
4#[cfg(feature = "tls")]
5use std::sync::Arc;
6#[cfg(feature = "tls")]
7use tokio_rustls::rustls::pki_types::CertificateDer;
8#[cfg(feature = "tls")]
9use tokio_rustls::server::TlsStream;
10
/// Trait that connected IO resources implement and use to produce info about the connection.
///
/// The goal for this trait is to allow users to implement
/// custom IO types that can still provide the same connection
/// metadata.
///
/// # Example
///
/// The `ConnectInfo` returned will be accessible through [request extensions][ext]:
///
/// ```
/// use tonic::Request;
/// use tonic_rustls::server::Connected;
///
/// // A custom IO type representing a single accepted connection
/// struct MyConnector {}
///
/// // Return metadata about the connection as `MyConnectInfo`
/// impl Connected for MyConnector {
///     type ConnectInfo = MyConnectInfo;
///
///     fn connect_info(&self) -> Self::ConnectInfo {
///         MyConnectInfo {}
///     }
/// }
///
/// #[derive(Clone)]
/// struct MyConnectInfo {
///     // Metadata about your connection
/// }
///
/// // The connect info can be accessed through request extensions:
/// # fn foo(request: Request<()>) {
/// let connect_info: &MyConnectInfo = request
///     .extensions()
///     .get::<MyConnectInfo>()
///     .expect("bug in tonic");
/// # }
/// ```
///
/// [ext]: crate::Request::extensions
pub trait Connected {
    /// The connection info type this IO resource produces.
    // All of these bounds are required so the value can be stored as a request extension.
    type ConnectInfo: Clone + Send + Sync + 'static;

    /// Creates a value holding information about the connection.
    fn connect_info(&self) -> Self::ConnectInfo;
}
60
/// Connection info for standard TCP streams.
///
/// This type will be accessible through [request extensions][ext] if you're using the default
/// non-TLS connector.
///
/// See [`Connected`] for more details.
///
/// [ext]: crate::Request::extensions
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpConnectInfo {
    /// The local address of this connection, or `None` if it could not be determined.
    pub local_addr: Option<SocketAddr>,
    /// The remote (peer) address of this connection, or `None` if it could not be determined.
    pub remote_addr: Option<SocketAddr>,
}

impl TcpConnectInfo {
    /// Returns the local address the IO resource is bound to, if known.
    pub fn local_addr(&self) -> Option<SocketAddr> {
        self.local_addr
    }

    /// Returns the remote address the IO resource is connected to, if known.
    pub fn remote_addr(&self) -> Option<SocketAddr> {
        self.remote_addr
    }
}
88
89impl Connected for TcpStream {
90    type ConnectInfo = TcpConnectInfo;
91
92    fn connect_info(&self) -> Self::ConnectInfo {
93        TcpConnectInfo {
94            local_addr: self.local_addr().ok(),
95            remote_addr: self.peer_addr().ok(),
96        }
97    }
98}
99
100impl Connected for tokio::io::DuplexStream {
101    type ConnectInfo = ();
102
103    fn connect_info(&self) -> Self::ConnectInfo {}
104}
105
/// TLS streams report the wrapped transport's connection info together with the
/// certificates the TLS session holds for the peer, if any.
#[cfg(feature = "tls")]
impl<T> Connected for TlsStream<T>
where
    T: Connected,
{
    type ConnectInfo = TlsConnectInfo<T::ConnectInfo>;

    fn connect_info(&self) -> Self::ConnectInfo {
        // `get_ref` exposes both the underlying IO and the TLS session state.
        let (io, tls_session) = self.get_ref();
        let transport_info = io.connect_info();

        // Copy the borrowed certificate slice into an owned vector behind an `Arc`
        // so the resulting info is cheap to clone.
        let peer_chain = tls_session
            .peer_certificates()
            .map(|chain| Arc::new(chain.to_vec()));

        TlsConnectInfo {
            inner: transport_info,
            certs: peer_chain,
        }
    }
}
124
/// Connection info for TLS streams.
///
/// Pairs the connection info of the underlying transport with the peer certificates
/// reported by the TLS session, if any. This type will be accessible through
/// [request extensions][ext] if you're using a TLS connector.
///
/// See [`Connected`] for more details.
///
/// [ext]: crate::Request::extensions
#[cfg(feature = "tls")]
#[derive(Debug, Clone)]
pub struct TlsConnectInfo<T> {
    // Connection info of the wrapped transport.
    inner: T,
    // Peer certificates; kept behind an `Arc` so cloning this struct does not copy them.
    certs: Option<Arc<Vec<CertificateDer<'static>>>>,
}

#[cfg(feature = "tls")]
impl<T> TlsConnectInfo<T> {
    /// Borrows the connection info of the underlying transport.
    pub fn get_ref(&self) -> &T {
        let TlsConnectInfo { inner, .. } = self;
        inner
    }

    /// Mutably borrows the connection info of the underlying transport.
    pub fn get_mut(&mut self) -> &mut T {
        let TlsConnectInfo { inner, .. } = self;
        inner
    }

    /// Returns the peer's TLS certificates, or `None` if none were recorded.
    ///
    /// The returned `Arc` shares storage with this value; no certificates are copied.
    pub fn peer_certs(&self) -> Option<Arc<Vec<CertificateDer<'static>>>> {
        self.certs.as_ref().map(Arc::clone)
    }
}