1use std::io;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use crate::net::TcpListener;
7use crate::v5::{Method, Request, Stream};
8
/// Timeout a [`Config`] starts with until `Config::with_timeout` overrides it.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
10
/// Settings shared by every connection a [`Server`] handles.
///
/// `A` is the authenticator type and `H` the request-handler type. No trait
/// bounds are imposed here; they are required where the values are used.
pub struct Config<A, H> {
    /// Invoked with the client's offered methods before the request is read.
    auth: A,
    /// Invoked with the parsed request once authentication has succeeded.
    handler: H,
    /// Upper bound on the time spent reading the offered methods,
    /// authenticating and reading the request.
    timeout: Duration,
}
16
17impl<A, H> Config<A, H> {
18 pub fn new(auth: A, handler: H) -> Self {
19 Self {
20 auth,
21 handler,
22 timeout: DEFAULT_TIMEOUT,
23 }
24 }
25
26 pub fn with_timeout(mut self, timeout: Duration) -> Self {
27 self.timeout = timeout;
28 self
29 }
30}
31
/// Stateless namespace for the accept loop; all functionality lives in
/// associated functions, so the type is never instantiated by this module.
pub struct Server;
34
35impl Server {
36 pub async fn run<H, A>(
37 listener: TcpListener,
38 config: Arc<Config<A, H>>,
39 shutdown_signal: impl Future<Output = ()>,
40 ) -> io::Result<()>
41 where
42 H: Handler + 'static,
43 A: Authenticator + 'static,
44 {
45 tokio::pin!(shutdown_signal);
46
47 loop {
48 tokio::select! {
49 biased;
51
52 _ = &mut shutdown_signal => return Ok(()),
53
54 result = listener.accept() => {
55 let (inner, addr) = match result {
56 Ok(res) => res,
57 Err(_err) => {
58 #[cfg(feature = "tracing")]
59 tracing::error!("Failed to accept connection: {}", _err);
60 continue;
61 }
62 };
63
64 let local_addr = match inner.local_addr() {
65 Ok(addr) => addr,
66 Err(_err) => {
67 #[cfg(feature = "tracing")]
68 tracing::error!("Failed to get local address for connection {}: {}", addr, _err);
69 continue;
70 }
71 };
72
73 let config = config.clone();
74 tokio::spawn(async move {
75 let mut stream = Stream::with(inner, addr, local_addr);
76
77 if let Err(_err) = Self::handle_connection(&mut stream, &config).await {
78 #[cfg(feature = "tracing")]
79 tracing::warn!("Connection {} error: {}", addr, _err);
80 }
81 });
82 }
83 }
84 }
85 }
86
87 async fn handle_connection<H, A, S>(
88 stream: &mut Stream<S>,
89 config: &Config<A, H>,
90 ) -> io::Result<()>
91 where
92 H: Handler + 'static,
93 A: Authenticator + 'static,
94 S: AsyncRead + AsyncWrite + Unpin + Send + Sync,
95 {
96 let request = tokio::time::timeout(config.timeout, async {
98 let methods = stream.read_methods().await?;
99 config.auth.auth(stream, methods).await?;
100 stream.read_request().await
101 })
102 .await
103 .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Timeout during authentication"))??;
104
105 config.handler.handle(stream, request).await
106 }
107}
108
109pub trait Authenticator: Send + Sync {
111 fn auth<T>(
112 &self,
113 stream: &mut Stream<T>,
114 methods: Vec<Method>,
115 ) -> impl Future<Output = io::Result<()>> + Send
116 where
117 T: AsyncRead + AsyncWrite + Unpin + Send + Sync;
118}
119
120pub trait Handler: Send + Sync {
122 fn handle<T>(
123 &self,
124 stream: &mut Stream<T>,
125 request: Request,
126 ) -> impl Future<Output = io::Result<()>> + Send
127 where
128 T: AsyncRead + AsyncWrite + Unpin + Send + Sync;
129}
130
131pub mod auth {
132 use super::*;
133
134 pub struct NoAuthentication;
135
136 impl Authenticator for NoAuthentication {
137 async fn auth<T>(&self, stream: &mut Stream<T>, _methods: Vec<Method>) -> io::Result<()>
138 where
139 T: AsyncRead + AsyncWrite + Unpin + Send + Sync,
140 {
141 stream.write_auth_method(Method::NoAuthentication).await?;
142 Ok(())
143 }
144 }
145
146 pub struct UserPassword {
147 username: String,
148 password: String,
149 }
150
151 impl UserPassword {
152 pub fn new(username: String, password: String) -> Self {
153 Self { username, password }
154 }
155 }
156
157 impl Authenticator for UserPassword {
158 async fn auth<T>(&self, stream: &mut Stream<T>, methods: Vec<Method>) -> io::Result<()>
159 where
160 T: AsyncRead + AsyncWrite + Unpin + Send + Sync,
161 {
162 if !methods.contains(&Method::UsernamePassword) {
163 return Err(io::Error::new(
164 io::ErrorKind::PermissionDenied,
165 "Username/Password authentication required",
166 ));
167 }
168
169 stream.write_auth_method(Method::UsernamePassword).await?;
170
171 let version = stream.read_u8().await?;
173 if version != 0x01 {
174 return Err(io::Error::new(
175 io::ErrorKind::InvalidData,
176 "Invalid subnegotiation version",
177 ));
178 }
179
180 let ulen = stream.read_u8().await?;
181 let mut username = vec![0; ulen as usize];
182 stream.read_exact(&mut username).await?;
183
184 let plen = stream.read_u8().await?;
185 let mut password = vec![0; plen as usize];
186 stream.read_exact(&mut password).await?;
187
188 if username != self.username.as_bytes() || password != self.password.as_bytes() {
190 stream.write_all(&[0x01, 0x01]).await?;
191 return Err(io::Error::new(
192 io::ErrorKind::PermissionDenied,
193 "Invalid username or password",
194 ));
195 }
196
197 stream.write_all(&[0x01, 0x00]).await?;
198
199 Ok(())
200 }
201 }
202}