1pub mod address_pool;
2mod connection;
3pub mod session;
4
5#[cfg(feature = "metrics")]
6mod metrics;
7
8use std::collections::HashMap;
9use std::net::{IpAddr, SocketAddr};
10use std::sync::Arc;
11#[cfg(feature = "metrics")]
12use std::time::Duration;
13use std::time::Instant;
14
15use bytes::Bytes;
16use dashmap::DashMap;
17use futures::stream::FuturesUnordered;
18use futures::StreamExt;
19use quinn::{Endpoint, VarInt};
20use tokio::signal;
21use tokio::sync::mpsc::error::TrySendError;
22use tokio::sync::mpsc::{channel, Receiver, Sender};
23use tracing::{debug, info, warn};
24
25use crate::server::address_pool::AddressPoolManager;
26use crate::server::connection::{Assigned, QuincyConnection};
27use crate::server::session::{ConnectionSession, UserSessionRegistry};
28use crate::users::UsersFile;
29use quincy::config::{
30 AddressRange, AllowedNoiseKeys, NoiseKeyExchange, ServerConfig, ServerProtocolConfig,
31};
32use quincy::constants::{PACKET_BUFFER_SIZE, PACKET_CHANNEL_SIZE, QUINN_RUNTIME};
33use quincy::network::interface::{Interface, InterfaceIO};
34use quincy::network::packet::Packet;
35use quincy::network::socket::bind_socket;
36use quincy::utils::tasks::abort_all;
37use quincy::Result;
38
/// Shared map from a client's tunnel IP address to the sender side of that
/// client's outbound packet channel.
///
/// Entries are inserted when a connection is established and removed when its
/// task finishes; the relay tasks look up destinations here to route packets.
type ConnectionQueues = Arc<DashMap<IpAddr, Sender<Bytes>>>;

/// Outcome of assigning a tunnel IP address to an identified connection.
struct AssignmentResult {
    /// The connection with an assigned address, or the assignment error.
    result: Result<QuincyConnection<Assigned>>,
    /// Handle to the underlying QUIC connection, kept so the connection can be
    /// closed when assignment fails.
    quic_connection: quinn::Connection,
}

/// QUIC-based tunnel server.
pub struct QuincyServer {
    /// Server configuration this instance was created from.
    config: ServerConfig,
    /// Per-client outbound packet queues, keyed by tunnel IP address.
    connection_queues: ConnectionQueues,
    /// Allocator used to assign and release client tunnel addresses.
    address_pool: Arc<AddressPoolManager>,
    /// Users loaded from the configured users file.
    users: Arc<UsersFile>,
    /// Registry of active sessions per user; hands out per-connection rate limiters.
    session_registry: Arc<UserSessionRegistry>,
}
56
57impl QuincyServer {
58 pub fn new(config: ServerConfig) -> Result<Self> {
65 let users = UsersFile::load(&config.users_file)?;
66
67 let user_pools: HashMap<String, Vec<AddressRange>> = users
68 .users
69 .iter()
70 .filter(|(_, entry)| !entry.address_pool.is_empty())
71 .map(|(name, entry)| (name.clone(), entry.address_pool.clone()))
72 .collect();
73
74 let address_pool = AddressPoolManager::new(config.tunnel_network, user_pools)?;
75
76 Ok(Self {
77 config,
78 connection_queues: Arc::new(DashMap::new()),
79 address_pool: Arc::new(address_pool),
80 users: Arc::new(users),
81 session_registry: Arc::new(UserSessionRegistry::new()),
82 })
83 }
84
85 pub async fn run<I: InterfaceIO>(&self) -> Result<()> {
87 let interface: Interface<I> = Interface::create(
88 self.config.tunnel_network,
89 self.config.connection.mtu,
90 Some(self.config.tunnel_network.network()),
91 self.config.interface_name.clone(),
92 None,
93 None,
94 )?;
95 let interface = Arc::new(interface);
96
97 #[cfg(feature = "metrics")]
98 if self.config.metrics.enabled {
99 use crate::server::metrics::init_metrics;
100
101 init_metrics(&self.config.metrics)?;
102 }
103
104 let (sender, receiver) = channel(PACKET_CHANNEL_SIZE);
105
106 let mut tasks = FuturesUnordered::new();
107
108 tasks.extend([
109 tokio::spawn(Self::process_outbound_traffic(
110 interface.clone(),
111 self.connection_queues.clone(),
112 )),
113 tokio::spawn(Self::process_inbound_traffic(
114 self.connection_queues.clone(),
115 interface,
116 receiver,
117 self.config.isolate_clients,
118 )),
119 ]);
120
121 let handler_task = self.handle_connections(sender);
122
123 let result = tokio::select! {
124 handler_task_result = handler_task => handler_task_result,
125 Some(task_result) = tasks.next() => task_result?,
126 };
127
128 let _ = abort_all(tasks).await;
129
130 result
131 }
132
133 async fn handle_connections(&self, ingress_queue: Sender<Packet>) -> Result<()> {
138 let endpoint = self.create_quinn_endpoint()?;
139
140 info!(
141 "Starting connection handler: {}",
142 endpoint.local_addr().expect("Endpoint has a local address")
143 );
144
145 let protocol = Arc::new(self.config.protocol.clone());
146 let server_address = self.config.tunnel_network;
147 let users = self.users.clone();
148 let address_pool = self.address_pool.clone();
149 let session_registry = self.session_registry.clone();
150
151 let mut assignment_tasks = FuturesUnordered::new();
152 let mut connection_tasks = FuturesUnordered::new();
153
154 loop {
155 tokio::select! {
156 Some(handshake) = endpoint.accept() => {
158 let client_ip = handshake.remote_address().ip();
159
160 debug!(
161 "Received incoming connection from '{}'",
162 client_ip
163 );
164
165 let quic_connection = match handshake.await {
166 Ok(connection) => connection,
167 Err(e) => {
168 warn!("Connection handshake with client '{client_ip}' failed: {e}");
169 continue;
170 }
171 };
172
173 let quic_connection_clone = quic_connection.clone();
174 let connection = QuincyConnection::new(
175 quic_connection,
176 ingress_queue.clone(),
177 );
178
179 let connection = match connection.identify(&protocol, &users) {
181 Ok(conn) => conn,
182 Err(e) => {
183 warn!("Failed to identify client: {e}");
184 quic_connection_clone.close(VarInt::from_u32(0x02), "Session establishment failed".as_bytes());
185 continue;
186 }
187 };
188
189 let address_pool = address_pool.clone();
190 let server_addr = server_address;
191
192 assignment_tasks.push(async move {
193 let result = connection.assign_ip(&address_pool, server_addr).await;
194 AssignmentResult {
195 result,
196 quic_connection: quic_connection_clone,
197 }
198 });
199 }
200
201 Some(assignment) = assignment_tasks.next() => {
203 let connection = match assignment.result {
204 Ok(connection) => connection,
205 Err(e) => {
206 warn!("Failed to assign IP to client: {e}");
207 assignment.quic_connection.close(
208 VarInt::from_u32(0x02),
209 "Session establishment failed".as_bytes(),
210 );
211 continue;
212 }
213 };
214
215 let client_address = connection.client_address();
216 let username = connection.username().to_string();
217
218 let bandwidth_limit = self
221 .users
222 .users
223 .get(&username)
224 .and_then(|entry| entry.bandwidth_limit)
225 .or(self.config.default_bandwidth_limit);
226
227 let rate_limiter = session_registry.add_connection(
229 &username,
230 ConnectionSession {
231 client_address,
232 connected_at: Instant::now(),
233 },
234 bandwidth_limit,
235 );
236
237 let (connection_sender, connection_receiver) = channel(PACKET_CHANNEL_SIZE);
238
239 connection_tasks.push(tokio::spawn(connection.run(
240 connection_receiver,
241 rate_limiter,
242 #[cfg(feature = "metrics")]
243 Duration::from_secs(self.config.metrics.reporting_interval_s),
244 )));
245 self.connection_queues
246 .insert(client_address.addr(), connection_sender);
247 }
248
249 Some(connection) = connection_tasks.next() => {
251 let (connection, err) = connection?;
252 let username = connection.username();
253 let client_address = connection.client_address();
254
255 self.connection_queues.remove(&client_address.addr());
256 self.address_pool.release_address(username, &client_address.addr());
257 session_registry.remove_connection(username, &client_address);
258
259 warn!(
260 "Connection with client {} (user '{username}') has encountered an error: {err}",
261 client_address.addr()
262 );
263 }
264
265 _ = signal::ctrl_c() => {
267 info!("Received shutdown signal, shutting down");
268 let _ = abort_all(connection_tasks).await;
269
270 endpoint.close(VarInt::from_u32(0x01), "Server shutdown".as_bytes());
271
272 return Ok(());
273 }
274 }
275 }
276 }
277
278 fn create_quinn_endpoint(&self) -> Result<Endpoint> {
280 let (allowed_keys, allowed_fingerprints) = match &self.config.protocol {
282 ServerProtocolConfig::Noise(noise) => {
283 let keys = match noise.key_exchange {
284 NoiseKeyExchange::Standard => Some(AllowedNoiseKeys::Standard(
285 self.users.collect_noise_public_keys(),
286 )),
287 NoiseKeyExchange::Hybrid => Some(AllowedNoiseKeys::Hybrid(
288 self.users.collect_noise_pq_public_keys(),
289 )),
290 };
291 (keys, None)
292 }
293 ServerProtocolConfig::Tls(_) => (None, Some(self.users.collect_cert_fingerprints())),
294 };
295
296 let quinn_config = self
297 .config
298 .as_quinn_server_config(allowed_keys, allowed_fingerprints)?;
299
300 let socket = bind_socket(
301 SocketAddr::new(self.config.bind_address, self.config.bind_port),
302 self.config.connection.send_buffer_size as usize,
303 self.config.connection.recv_buffer_size as usize,
304 self.config.reuse_socket,
305 )?;
306
307 let endpoint_config = self
308 .config
309 .connection
310 .as_endpoint_config(self.config.noise_key_exchange())?;
311 let endpoint = Endpoint::new(
312 endpoint_config,
313 Some(quinn_config),
314 socket,
315 QUINN_RUNTIME.clone(),
316 )?;
317
318 Ok(endpoint)
319 }
320
321 async fn process_outbound_traffic(
327 interface: Arc<Interface<impl InterfaceIO>>,
328 connection_queues: ConnectionQueues,
329 ) -> Result<()> {
330 debug!("Started tunnel outbound traffic task (interface -> connection queue)");
331
332 loop {
333 let packet = interface.read_packet().await?;
334 let dest_addr = match packet.destination() {
335 Ok(addr) => addr,
336 Err(e) => {
337 warn!("Received packet with malformed header structure: {e}");
338 continue;
339 }
340 };
341
342 debug!("Destination address for packet: {dest_addr}");
343
344 let connection_queue = match connection_queues.get(&dest_addr) {
345 Some(connection_queue) => connection_queue,
346 None => continue,
347 };
348
349 debug!("Found connection for IP {dest_addr}");
350
351 match connection_queue.try_send(packet.into()) {
352 Ok(()) => {}
353 Err(TrySendError::Full(_)) => {
354 debug!("Dropping outbound packet for {dest_addr}: per-client queue full");
355 }
356 Err(TrySendError::Closed(_)) => {
357 debug!("Dropping outbound packet for {dest_addr}: connection closed");
358 }
359 }
360 }
361 }
362
363 async fn process_inbound_traffic(
371 connection_queues: ConnectionQueues,
372 interface: Arc<Interface<impl InterfaceIO>>,
373 ingress_queue: Receiver<Packet>,
374 isolate_clients: bool,
375 ) -> Result<()> {
376 debug!("Started tunnel inbound traffic task (tunnel queue -> interface)");
377
378 if isolate_clients {
379 relay_isolated(connection_queues, interface, ingress_queue).await
380 } else {
381 relay_unisolated(connection_queues, interface, ingress_queue).await
382 }
383 }
384}
385
386#[inline]
387async fn relay_isolated(
388 connection_queues: ConnectionQueues,
389 interface: Arc<Interface<impl InterfaceIO>>,
390 mut ingress_queue: Receiver<Packet>,
391) -> Result<()> {
392 loop {
393 let mut packets = Vec::with_capacity(PACKET_BUFFER_SIZE);
394 let count = ingress_queue
395 .recv_many(&mut packets, PACKET_BUFFER_SIZE)
396 .await;
397
398 if count == 0 {
400 return Ok(());
401 }
402
403 let filtered_packets = packets
404 .into_iter()
405 .filter(|packet| {
406 let dest_addr = match packet.destination() {
407 Ok(addr) => addr,
408 Err(e) => {
409 warn!("Received packet with malformed header structure: {e}");
410 return false;
411 }
412 };
413 !connection_queues.contains_key(&dest_addr)
414 })
415 .collect::<Vec<_>>();
416
417 interface.write_packets(filtered_packets).await?;
418 }
419}
420
421#[inline]
422async fn relay_unisolated(
423 connection_queues: ConnectionQueues,
424 interface: Arc<Interface<impl InterfaceIO>>,
425 mut ingress_queue: Receiver<Packet>,
426) -> Result<()> {
427 loop {
428 let mut packets = Vec::with_capacity(PACKET_BUFFER_SIZE);
429
430 let count = ingress_queue
431 .recv_many(&mut packets, PACKET_BUFFER_SIZE)
432 .await;
433
434 if count == 0 {
436 return Ok(());
437 }
438
439 for packet in packets {
440 let dest_addr = match packet.destination() {
441 Ok(addr) => addr,
442 Err(e) => {
443 warn!("Received packet with malformed header structure: {e}");
444 continue;
445 }
446 };
447
448 match connection_queues.get(&dest_addr) {
449 Some(connection_queue) => match connection_queue.try_send(packet.into()) {
451 Ok(()) => {}
452 Err(TrySendError::Full(_)) => {
453 debug!("Dropping client-to-client packet for {dest_addr}: queue full");
454 }
455 Err(TrySendError::Closed(_)) => {
456 debug!(
457 "Dropping client-to-client packet for {dest_addr}: connection closed"
458 );
459 }
460 },
461 None => interface.write_packet(packet).await?,
463 }
464 }
465 }
466}