1pub mod address_pool;
2mod connection;
3
4use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::auth::AuthServer;
9use crate::server::connection::QuincyConnection;
10use crate::users_file::UsersFileServerAuthenticator;
11use bytes::Bytes;
12use dashmap::DashMap;
13use futures::stream::FuturesUnordered;
14use futures::StreamExt;
15use quincy::config::ServerConfig;
16use quincy::network::socket::bind_socket;
17use quincy::Result;
18use quinn::{Endpoint, VarInt};
19use tokio::signal;
20use tokio::sync::mpsc::{channel, Receiver, Sender};
21
22use self::address_pool::AddressPool;
23use quincy::constants::{PACKET_BUFFER_SIZE, PACKET_CHANNEL_SIZE, QUINN_RUNTIME};
24use quincy::network::interface::{Interface, InterfaceIO};
25use quincy::network::packet::Packet;
26use quincy::utils::tasks::abort_all;
27use tracing::{debug, info, warn};
28
/// Shared map from a connected client's address (as returned by the connection's
/// `client_address()`) to the sender half of that client's outbound packet channel.
///
/// Packets read from the tunnel interface are routed by destination address
/// through this map.
type ConnectionQueues = Arc<DashMap<IpAddr, Sender<Bytes>>>;
30
/// Quincy server: owns the configuration and the state shared between the
/// connection handler and the packet relay tasks.
pub struct QuincyServer {
    /// Server configuration supplied to `QuincyServer::new`.
    config: ServerConfig,
    /// Per-client outbound packet queues, keyed by the client's assigned address.
    connection_queues: ConnectionQueues,
    /// Pool of tunnel addresses built from `config.tunnel_network`; shared with the
    /// users-file authenticator, and an address is released when its connection ends.
    address_pool: Arc<AddressPool>,
}
37
38impl QuincyServer {
39 pub fn new(config: ServerConfig) -> Result<Self> {
44 let address_pool = AddressPool::new(config.tunnel_network);
45
46 Ok(Self {
47 config,
48 connection_queues: Arc::new(DashMap::new()),
49 address_pool: Arc::new(address_pool),
50 })
51 }
52
53 pub async fn run<I: InterfaceIO>(&self) -> Result<()> {
55 let interface: Interface<I> = Interface::create(
56 self.config.tunnel_network,
57 self.config.connection.mtu,
58 Some(self.config.tunnel_network.network()),
59 None,
60 None,
61 )?;
62 let interface = Arc::new(interface);
63
64 let authenticator = Box::new(UsersFileServerAuthenticator::new(
65 &self.config.authentication,
66 self.address_pool.clone(),
67 )?);
68 let auth_server = AuthServer::new(
69 authenticator,
70 self.config.tunnel_network,
71 Duration::from_secs(self.config.connection.connection_timeout_s),
72 );
73
74 let (sender, receiver) = channel(PACKET_CHANNEL_SIZE);
75
76 let mut tasks = FuturesUnordered::new();
77
78 tasks.extend([
79 tokio::spawn(Self::process_outbound_traffic(
80 interface.clone(),
81 self.connection_queues.clone(),
82 )),
83 tokio::spawn(Self::process_inbound_traffic(
84 self.connection_queues.clone(),
85 interface,
86 receiver,
87 self.config.isolate_clients,
88 )),
89 ]);
90
91 let handler_task = self.handle_connections(auth_server, sender);
92
93 let result = tokio::select! {
94 handler_task_result = handler_task => handler_task_result,
95 Some(task_result) = tasks.next() => task_result?,
96 };
97
98 let _ = abort_all(tasks).await;
99
100 result
101 }
102
103 async fn handle_connections(
109 &self,
110 auth_server: AuthServer,
111 ingress_queue: Sender<Packet>,
112 ) -> Result<()> {
113 let endpoint = self.create_quinn_endpoint()?;
114
115 info!(
116 "Starting connection handler: {}",
117 endpoint.local_addr().expect("Endpoint has a local address")
118 );
119
120 let mut authentication_tasks = FuturesUnordered::new();
121 let mut connection_tasks = FuturesUnordered::new();
122
123 loop {
124 tokio::select! {
125 Some(handshake) = endpoint.accept() => {
127 let client_ip = handshake.remote_address().ip();
128
129 debug!(
130 "Received incoming connection from '{}'",
131 client_ip
132 );
133
134 let quic_connection = match handshake.await {
135 Ok(connection) => connection,
136 Err(e) => {
137 warn!("Connection handshake with client '{client_ip}' failed: {e}");
138 continue;
139 }
140 };
141
142 let connection = QuincyConnection::new(
143 quic_connection,
144 ingress_queue.clone(),
145 );
146
147 authentication_tasks.push(
148 connection.authenticate(&auth_server)
149 );
150 }
151
152 Some(connection) = authentication_tasks.next() => {
154 let connection = match connection {
155 Ok(connection) => connection,
156 Err(e) => {
157 warn!("Failed to authenticate client: {e}");
158 continue;
159 }
160 };
161
162 let client_address = connection.client_address()?.addr();
163 let (connection_sender, connection_receiver) = channel(PACKET_CHANNEL_SIZE);
164
165 connection_tasks.push(tokio::spawn(connection.run(connection_receiver)));
166 self.connection_queues.insert(client_address, connection_sender);
167 }
168
169 Some(connection) = connection_tasks.next() => {
171 let (connection, err) = connection?;
172 let client_address = &connection.client_address()?.addr();
173
174 self.connection_queues.remove(client_address);
175 self.address_pool.release_address(client_address);
176 warn!("Connection with client {client_address} has encountered an error: {err}");
177 }
178
179 _ = signal::ctrl_c() => {
181 info!("Received shutdown signal, shutting down");
182 let _ = abort_all(connection_tasks).await;
183
184 endpoint.close(VarInt::from_u32(0x01), "Server shutdown".as_bytes());
185
186 return Ok(());
187 }
188 }
189 }
190 }
191
192 fn create_quinn_endpoint(&self) -> Result<Endpoint> {
194 let quinn_config = self.config.as_quinn_server_config()?;
195
196 let socket = bind_socket(
197 SocketAddr::new(self.config.bind_address, self.config.bind_port),
198 self.config.connection.send_buffer_size as usize,
199 self.config.connection.recv_buffer_size as usize,
200 self.config.reuse_socket,
201 )?;
202
203 let endpoint_config = self.config.connection.as_endpoint_config()?;
204 let endpoint = Endpoint::new(
205 endpoint_config,
206 Some(quinn_config),
207 socket,
208 QUINN_RUNTIME.clone(),
209 )?;
210
211 Ok(endpoint)
212 }
213
214 async fn process_outbound_traffic(
221 interface: Arc<Interface<impl InterfaceIO>>,
222 connection_queues: ConnectionQueues,
223 ) -> Result<()> {
224 debug!("Started tunnel outbound traffic task (interface -> connection queue)");
225
226 loop {
227 let packet = interface.read_packet().await?;
228 let dest_addr = match packet.destination() {
229 Ok(addr) => addr,
230 Err(e) => {
231 warn!("Received packet with malformed header structure: {e}");
232 continue;
233 }
234 };
235
236 debug!("Destination address for packet: {dest_addr}");
237
238 let connection_queue = match connection_queues.get(&dest_addr) {
239 Some(connection_queue) => connection_queue,
240 None => continue,
241 };
242
243 debug!("Found connection for IP {dest_addr}");
244
245 connection_queue.send(packet.into()).await?;
246 }
247 }
248
249 async fn process_inbound_traffic(
257 connection_queues: ConnectionQueues,
258 interface: Arc<Interface<impl InterfaceIO>>,
259 ingress_queue: Receiver<Packet>,
260 isolate_clients: bool,
261 ) -> Result<()> {
262 debug!("Started tunnel inbound traffic task (tunnel queue -> interface)");
263
264 if isolate_clients {
265 relay_isolated(connection_queues, interface, ingress_queue).await
266 } else {
267 relay_unisolated(connection_queues, interface, ingress_queue).await
268 }
269 }
270}
271
272#[inline]
273async fn relay_isolated(
274 connection_queues: ConnectionQueues,
275 interface: Arc<Interface<impl InterfaceIO>>,
276 mut ingress_queue: Receiver<Packet>,
277) -> Result<()> {
278 loop {
279 let mut packets = Vec::with_capacity(PACKET_BUFFER_SIZE);
280 ingress_queue
281 .recv_many(&mut packets, PACKET_BUFFER_SIZE)
282 .await;
283
284 let filtered_packets = packets
285 .into_iter()
286 .filter(|packet| {
287 let dest_addr = match packet.destination() {
288 Ok(addr) => addr,
289 Err(e) => {
290 warn!("Received packet with malformed header structure: {e}");
291 return false;
292 }
293 };
294 !connection_queues.contains_key(&dest_addr)
295 })
296 .collect::<Vec<_>>();
297
298 interface.write_packets(filtered_packets).await?;
299 }
300}
301
302#[inline]
303async fn relay_unisolated(
304 connection_queues: ConnectionQueues,
305 interface: Arc<Interface<impl InterfaceIO>>,
306 mut ingress_queue: Receiver<Packet>,
307) -> Result<()> {
308 loop {
309 let mut packets = Vec::with_capacity(PACKET_BUFFER_SIZE);
310
311 ingress_queue
312 .recv_many(&mut packets, PACKET_BUFFER_SIZE)
313 .await;
314
315 for packet in packets {
316 let dest_addr = match packet.destination() {
317 Ok(addr) => addr,
318 Err(e) => {
319 warn!("Received packet with malformed header structure: {e}");
320 continue;
321 }
322 };
323
324 match connection_queues.get(&dest_addr) {
325 Some(connection_queue) => connection_queue.send(packet.into()).await?,
327 None => interface.write_packet(packet).await?,
329 }
330 }
331 }
332}