aerosocket_server/
tls_transport.rs1#[cfg(feature = "tls-transport")]
7use aerosocket_core::{Error, Result, Transport};
8#[cfg(feature = "tls-transport")]
9use aerosocket_core::transport::TransportStream;
10#[cfg(feature = "tls-transport")]
11use async_trait::async_trait;
12#[cfg(feature = "tls-transport")]
13use std::net::SocketAddr;
14#[cfg(feature = "tls-transport")]
15use std::sync::Arc;
16
17#[cfg(feature = "tls-transport")]
18use tokio::net::{TcpListener, TcpStream as TokioTcpStream};
19#[cfg(feature = "tls-transport")]
20use tokio_rustls::{TlsAcceptor, server::TlsStream, rustls::ServerConfig as RustlsServerConfig};
21
/// TLS-over-TCP implementation of [`Transport`].
///
/// Pairs a bound [`TcpListener`] with a [`TlsAcceptor`]; `accept` awaits the
/// server-side TLS handshake before handing back a stream.
#[cfg(feature = "tls-transport")]
pub struct TlsTransport {
    /// Listener that receives raw TCP connections.
    listener: TcpListener,
    /// Runs the server-side TLS handshake on each accepted TCP stream.
    acceptor: TlsAcceptor,
    /// Address reported by the listener right after binding.
    local_addr: SocketAddr,
}
32
/// Server-side TLS connection produced by [`TlsTransport`].
///
/// Newtype around a rustls server stream over a Tokio TCP stream that adapts it
/// to the [`TransportStream`] trait.
#[cfg(feature = "tls-transport")]
pub struct TlsStreamWrapper {
    /// Encrypted stream; all reads and writes pass through rustls.
    inner: TlsStream<TokioTcpStream>,
}
38
#[cfg(feature = "tls-transport")]
#[async_trait]
impl Transport for TlsTransport {
    type Stream = TlsStreamWrapper;

    /// Waits for the next TCP connection and completes a TLS handshake on it.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Io`] if the TCP accept fails, or [`Error::Other`] if the
    /// TLS handshake fails.
    ///
    /// NOTE(review): the handshake is awaited inline, so a peer that stalls
    /// mid-handshake delays this call (and any caller accepting in a loop);
    /// consider a handshake timeout or performing it in a spawned task.
    async fn accept(&self) -> Result<Self::Stream> {
        let (socket, _peer) = self.listener.accept().await.map_err(Error::Io)?;

        match self.acceptor.accept(socket).await {
            Ok(inner) => Ok(TlsStreamWrapper { inner }),
            Err(e) => Err(Error::Other(format!(
                "Failed to accept TLS connection: {}",
                e
            ))),
        }
    }

    /// Returns the address captured at bind time; never fails.
    fn local_addr(&self) -> Result<SocketAddr> {
        Ok(self.local_addr)
    }

    /// Consumes the transport. Dropping `self` releases the listener.
    async fn close(self) -> Result<()> {
        drop(self);
        Ok(())
    }
}
63
#[cfg(feature = "tls-transport")]
#[async_trait]
impl TransportStream for TlsStreamWrapper {
    /// Reads decrypted application data into `buf` and returns the byte count.
    /// As with Tokio's `AsyncReadExt::read`, `0` signals end of stream for a
    /// non-empty `buf`.
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        tokio::io::AsyncReadExt::read(&mut self.inner, buf)
            .await
            .map_err(Error::Io)
    }

    /// Writes some prefix of `buf` and returns how many bytes were accepted.
    async fn write(&mut self, buf: &[u8]) -> Result<usize> {
        tokio::io::AsyncWriteExt::write(&mut self.inner, buf)
            .await
            .map_err(Error::Io)
    }

    /// Writes all of `buf`, continuing through partial writes.
    async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
        tokio::io::AsyncWriteExt::write_all(&mut self.inner, buf)
            .await
            .map_err(Error::Io)
    }

    /// Flushes any buffered output to the underlying socket.
    async fn flush(&mut self) -> Result<()> {
        tokio::io::AsyncWriteExt::flush(&mut self.inner)
            .await
            .map_err(Error::Io)
    }

    /// Shuts down the write side through Tokio's `shutdown`.
    /// NOTE(review): for rustls this presumably sends close_notify first; verify.
    async fn close(&mut self) -> Result<()> {
        tokio::io::AsyncWriteExt::shutdown(&mut self.inner)
            .await
            .map_err(Error::Io)
    }

    /// Peer address of the underlying TCP socket.
    fn remote_addr(&self) -> Result<SocketAddr> {
        let (tcp, _session) = self.inner.get_ref();
        tcp.peer_addr().map_err(Error::Io)
    }

    /// Local address of the underlying TCP socket.
    fn local_addr(&self) -> Result<SocketAddr> {
        let (tcp, _session) = self.inner.get_ref();
        tcp.local_addr().map_err(Error::Io)
    }
}
107
#[cfg(feature = "tls-transport")]
impl TlsTransport {
    /// Binds a TCP listener on `addr` that will serve TLS using `tls_config`.
    ///
    /// The bound address is read back immediately, so binding to port `0`
    /// reports the OS-assigned port through `local_addr`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Io`] if the socket cannot be bound or its local address
    /// cannot be queried.
    pub async fn bind(addr: SocketAddr, tls_config: RustlsServerConfig) -> Result<Self> {
        let listener = TcpListener::bind(addr).await.map_err(Error::Io)?;
        let bound_addr = listener.local_addr().map_err(Error::Io)?;
        let shared_config = Arc::new(tls_config);

        Ok(Self {
            acceptor: TlsAcceptor::from(shared_config),
            local_addr: bound_addr,
            listener,
        })
    }

    /// Binds using the configuration from [`create_default_tls_config`].
    ///
    /// # Errors
    ///
    /// Propagates the error from `create_default_tls_config`, which currently
    /// always returns one, or any error from [`TlsTransport::bind`].
    pub async fn bind_with_default_config(addr: SocketAddr) -> Result<Self> {
        Self::bind(addr, create_default_tls_config()?).await
    }
}
133
/// Placeholder for a built-in default rustls server configuration.
///
/// No default configuration is provided in this release. Callers must build
/// their own `ServerConfig` and pass it to [`TlsTransport::bind`].
///
/// # Errors
///
/// Always returns [`Error::Other`].
#[cfg(feature = "tls-transport")]
pub fn create_default_tls_config() -> Result<RustlsServerConfig> {
    let message = "TLS configuration not available in this release. Please implement your own TLS config.";
    Err(Error::Other(message.to_string()))
}
140
141#[cfg(not(feature = "tls-transport"))]
142pub struct TlsTransport;
144
145#[cfg(not(feature = "tls-transport"))]
146impl TlsTransport {
147 pub async fn bind(_addr: std::net::SocketAddr, _config: ()) -> aerosocket_core::Result<Self> {
148 Err(aerosocket_core::Error::Other("TLS transport requires the 'tls-transport' feature to be enabled".to_string()))
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// With the feature on, a real config must bind successfully.
    /// `create_default_tls_config` is currently a stub that errors; in that case
    /// check that `bind_with_default_config` surfaces the failure instead of
    /// silently passing with no assertion.
    #[cfg(feature = "tls-transport")]
    #[tokio::test]
    async fn test_tls_transport_creation() {
        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();

        match create_default_tls_config() {
            Ok(config) => {
                let result = TlsTransport::bind(addr, config).await;
                assert!(result.is_ok());
            }
            Err(_) => {
                assert!(TlsTransport::bind_with_default_config(addr).await.is_err());
            }
        }
    }

    /// With the feature off, the stub `bind` must report an error.
    #[cfg(not(feature = "tls-transport"))]
    #[tokio::test]
    async fn test_tls_transport_requires_feature() {
        let addr = "127.0.0.1:0".parse().unwrap();
        assert!(TlsTransport::bind(addr, ()).await.is_err());
    }
}