1use std::io;
2use std::marker::PhantomData;
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::broadcast;
7use tokio::task::JoinHandle;
8
9use ombrac_macros::{error, info, warn};
10#[cfg(feature = "transport-quic")]
11use ombrac_transport::quic::{
12 Connection as QuicConnection, TransportConfig as QuicTransportConfig,
13 client::{Client as QuicClient, Config as QuicConfig},
14};
15use ombrac_transport::{Connection, Initiator};
16
17use crate::client::Client;
18use crate::config::ServiceConfig;
19#[cfg(feature = "transport-quic")]
20use crate::config::TlsMode;
21
/// Convenience alias for results whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
23
/// Errors surfaced while building or running a service.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// A configuration value is missing, malformed, or unsupported on the
    /// current platform.
    #[error("Configuration error: {0}")]
    Config(String),

    /// An underlying I/O failure, converted automatically from `io::Error`
    /// and displayed using the inner error's own message.
    #[error("{0}")]
    Io(#[from] io::Error),

    /// A failure attributed to the transport layer.
    // NOTE(review): not constructed anywhere in the code visible here —
    // confirm which callers produce this variant.
    #[error("Transport layer error: {0}")]
    Transport(String),

    /// An endpoint could not be set up, or stopped with an error while running.
    #[error("Endpoint failed: {0}")]
    Endpoint(String),
}
38
/// Turns an optional configuration value into a `Result`.
///
/// Expands to an expression that is `Ok(value)` when the option holds a value
/// and `Err(Error::Config(..))` naming the missing field otherwise, so call
/// sites can simply append `?`.
macro_rules! require_config {
    ($value:expr, $name:expr) => {
        $value.ok_or_else(|| {
            let message = format!("'{}' is required but was not provided", $name);
            Error::Config(message)
        })
    };
}
49
/// Strategy for constructing the shared [`Client`] that a service runs on.
///
/// An implementor fixes the transport types through its associated types and
/// builds a ready-to-use client from a [`ServiceConfig`].
pub trait ServiceBuilder {
    /// Transport initiator type; its `Connection` must be [`Self::Connection`].
    type Initiator: Initiator<Connection = Self::Connection>;
    /// Connection type associated with [`Self::Initiator`].
    type Connection: Connection;

    /// Builds a client from `config` and returns it behind an `Arc`.
    ///
    /// The returned future is required to be `Send`.
    fn build(
        config: &Arc<ServiceConfig>,
    ) -> impl Future<Output = Result<Arc<Client<Self::Initiator, Self::Connection>>>> + Send;
}
58
/// [`ServiceBuilder`] implementation backed by the QUIC transport.
///
/// Only compiled when the `transport-quic` feature is enabled.
#[cfg(feature = "transport-quic")]
pub struct QuicServiceBuilder;
61
#[cfg(feature = "transport-quic")]
impl ServiceBuilder for QuicServiceBuilder {
    type Initiator = QuicClient;
    type Connection = QuicConnection;

    /// Creates the QUIC transport described by `config`, then constructs the
    /// client on top of it.
    ///
    /// The client secret is the BLAKE3 digest of `config.secret`; the optional
    /// handshake option is converted with `Into` before being handed over.
    ///
    /// # Errors
    ///
    /// Propagates failures from creating the QUIC transport and from
    /// `Client::new`.
    async fn build(
        config: &Arc<ServiceConfig>,
    ) -> Result<Arc<Client<Self::Initiator, Self::Connection>>> {
        let quic = quic_client_from_config(config).await?;
        let secret_digest = blake3::hash(config.secret.as_bytes());

        let client = Client::new(
            quic,
            *secret_digest.as_bytes(),
            config.handshake_option.clone().map(Into::into),
        )
        .await?;

        Ok(Arc::new(client))
    }
}
83
/// A set of spawned local endpoint tasks that share one [`Client`].
///
/// Holds the join handles of the endpoint tasks together with the broadcast
/// sender used to tell them to stop.
pub struct Service<T, C>
where
    T: Initiator<Connection = C>,
    C: Connection,
{
    /// Client shared by every endpoint task.
    client: Arc<Client<T, C>>,
    /// Join handles of the spawned endpoint tasks.
    handles: Vec<JoinHandle<()>>,
    /// Sender side of the shutdown broadcast that endpoint tasks subscribe to.
    shutdown_tx: broadcast::Sender<()>,
    /// Zero-sized marker for the initiator type parameter.
    _transport: PhantomData<T>,
    /// Zero-sized marker for the connection type parameter.
    _connection: PhantomData<C>,
}
95
96impl<T, C> Service<T, C>
97where
98 T: Initiator<Connection = C>,
99 C: Connection,
100{
101 pub async fn build<Builder>(config: Arc<ServiceConfig>) -> Result<Self>
102 where
103 Builder: ServiceBuilder<Initiator = T, Connection = C>,
104 {
105 let mut handles = Vec::new();
106 let client = Builder::build(&config).await?;
107 let (shutdown_tx, _) = broadcast::channel(1);
108
109 #[cfg(feature = "endpoint-http")]
111 if config.endpoint.http.is_some() {
112 handles.push(Self::spawn_endpoint(
113 "HTTP",
114 Self::endpoint_http_accept_loop(
115 config.clone(),
116 client.clone(),
117 shutdown_tx.subscribe(),
118 ),
119 ));
120 }
121
122 #[cfg(feature = "endpoint-socks")]
124 if config.endpoint.socks.is_some() {
125 handles.push(Self::spawn_endpoint(
126 "SOCKS",
127 Self::endpoint_socks_accept_loop(
128 config.clone(),
129 client.clone(),
130 shutdown_tx.subscribe(),
131 ),
132 ));
133 }
134
135 #[cfg(feature = "endpoint-tun")]
137 if let Some(tun_config) = &config.endpoint.tun
138 && (tun_config.tun_ipv4.is_some()
139 || tun_config.tun_ipv6.is_some()
140 || tun_config.tun_fd.is_some())
141 {
142 handles.push(Self::spawn_endpoint(
143 "TUN",
144 Self::endpoint_tun_accept_loop(
145 config.clone(),
146 client.clone(),
147 shutdown_tx.subscribe(),
148 ),
149 ));
150 }
151
152 if handles.is_empty() {
153 return Err(Error::Config(
154 "No endpoints were configured or enabled. The service has nothing to do."
155 .to_string(),
156 ));
157 }
158
159 Ok(Service {
160 client,
161 handles,
162 shutdown_tx,
163 _transport: PhantomData,
164 _connection: PhantomData,
165 })
166 }
167
168 pub async fn rebind(&self) -> io::Result<()> {
169 self.client.rebind().await
170 }
171
172 pub async fn shutdown(self) {
173 let _ = self.shutdown_tx.send(());
174
175 for handle in self.handles {
176 if let Err(_err) = handle.await {
177 error!("A task failed to shut down cleanly: {:?}", _err);
178 }
179 }
180 warn!("Service shutdown complete");
181 }
182
183 fn spawn_endpoint(
184 _name: &'static str,
185 task: impl Future<Output = Result<()>> + Send + 'static,
186 ) -> JoinHandle<()> {
187 tokio::spawn(async move {
188 if let Err(_e) = task.await {
189 error!("The {_name} endpoint shut down due to an error: {_e}");
190 }
191 })
192 }
193
194 #[cfg(feature = "endpoint-http")]
195 async fn endpoint_http_accept_loop(
196 config: Arc<ServiceConfig>,
197 ombrac: Arc<Client<T, C>>,
198 mut shutdown_rx: broadcast::Receiver<()>,
199 ) -> Result<()> {
200 use crate::endpoint::http::Server as HttpServer;
201 let bind_addr = require_config!(config.endpoint.http, "endpoint.http")?;
202 let socket = tokio::net::TcpListener::bind(bind_addr).await?;
203
204 info!("Starting HTTP/HTTPS endpoint, listening on {bind_addr}");
205
206 HttpServer::new(ombrac)
207 .accept_loop(socket, async {
208 let _ = shutdown_rx.recv().await;
209 })
210 .await
211 .map_err(|e| Error::Endpoint(format!("HTTP server failed to run: {}", e)))
212 }
213
214 #[cfg(feature = "endpoint-socks")]
215 async fn endpoint_socks_accept_loop(
216 config: Arc<ServiceConfig>,
217 ombrac: Arc<Client<T, C>>,
218 mut shutdown_rx: broadcast::Receiver<()>,
219 ) -> Result<()> {
220 use crate::endpoint::socks::CommandHandler;
221 use socks_lib::v5::server::auth::NoAuthentication;
222 use socks_lib::v5::server::{Config as SocksConfig, Server as SocksServer};
223
224 let bind_addr = require_config!(config.endpoint.socks, "endpoint.socks")?;
225 let socket = tokio::net::TcpListener::bind(bind_addr).await?;
226
227 info!("Starting SOCKS5 endpoint, listening on {bind_addr}");
228
229 let socks_config = Arc::new(SocksConfig::new(
230 NoAuthentication,
231 CommandHandler::new(ombrac),
232 ));
233 SocksServer::run(socket, socks_config, async {
234 let _ = shutdown_rx.recv().await;
235 })
236 .await
237 .map_err(|e| Error::Endpoint(format!("SOCKS server failed to run: {}", e)))
238 }
239
240 #[cfg(feature = "endpoint-tun")]
241 async fn endpoint_tun_accept_loop(
242 config: Arc<ServiceConfig>,
243 ombrac: Arc<Client<T, C>>,
244 mut shutdown_rx: broadcast::Receiver<()>,
245 ) -> Result<()> {
246 use crate::endpoint::tun::{AsyncDevice, Tun, TunConfig};
247
248 let config = require_config!(config.endpoint.tun.as_ref(), "endpoint.tun")?;
249
250 let device = match config.tun_fd {
251 Some(fd) => {
252 #[cfg(not(windows))]
253 unsafe {
254 AsyncDevice::from_fd(fd)?
255 }
256
257 #[cfg(windows)]
258 return Err(Error::Config(
259 "'tun_fd' option is not supported on Windows.".to_string(),
260 ));
261 }
262 None => {
263 #[cfg(not(any(target_os = "android", target_os = "ios")))]
264 {
265 let device = {
266 use std::str::FromStr;
267
268 let mut builder = tun_rs::DeviceBuilder::new();
269 builder = builder.mtu(config.tun_mtu.unwrap_or(1500));
270
271 if let Some(ip_str) = &config.tun_ipv4 {
272 let ip = ipnet::Ipv4Net::from_str(ip_str).map_err(|e| {
273 Error::Config(format!("Failed to parse IPv4 CIDR '{ip_str}': {e}"))
274 })?;
275 builder = builder.ipv4(ip.addr(), ip.netmask(), None);
276 }
277
278 if let Some(ip_str) = &config.tun_ipv6 {
279 let ip = ipnet::Ipv6Net::from_str(ip_str).map_err(|e| {
280 Error::Config(format!("Failed to parse IPv6 CIDR '{ip_str}': {e}"))
281 })?;
282 builder = builder.ipv6(ip.addr(), ip.netmask());
283 }
284
285 builder.build_async().map_err(|e| {
286 Error::Endpoint(format!("Failed to build TUN device: {e}"))
287 })?
288 };
289
290 info!(
291 "Starting TUN endpoint, Name: {:?}, MTU: {:?}, IP: {:?}",
292 device.name(),
293 device.mtu(),
294 device.addresses()
295 );
296
297 device
298 }
299
300 #[cfg(any(target_os = "android", target_os = "ios"))]
301 {
302 return Err(Error::Config(
303 "Creating a new TUN device is not supported on this platform. A pre-configured 'tun_fd' must be provided.".to_string()
304 ));
305 }
306 }
307 };
308
309 let mut tun_config = TunConfig::default();
310 if let Some(value) = &config.fake_dns {
311 tun_config.fakedns_cidr = value.parse().map_err(|e| {
312 Error::Config(format!("Failed to parse fake_dns CIDR '{value}': {e}"))
313 })?;
314 };
315
316 let tun = Tun::new(tun_config.into(), ombrac);
317 let shutdown_signal = async {
318 let _ = shutdown_rx.recv().await;
319 };
320
321 tun.accept_loop(device, shutdown_signal)
322 .await
323 .map_err(|e| Error::Endpoint(format!("TUN device runtime error: {}", e)))
324 }
325}
326
/// Derives the TLS server name from a `host:port` server address.
///
/// The text before the last `:` is taken as the host. A bracketed IPv6
/// literal such as `[::1]:443` yields `::1`: the brackets are address syntax
/// only and are not part of the name handed to the TLS layer.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidInput`] error when `server` contains no
/// `:` separator or when the host part is empty.
// Deliberately not feature-gated so it stays unit-testable without QUIC; it
// is only called from QUIC code, hence the conditional `dead_code` allowance.
#[cfg_attr(not(feature = "transport-quic"), allow(dead_code))]
fn server_name_from_address(server: &str) -> io::Result<String> {
    let invalid = || {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("Invalid server address: {}", server),
        )
    };

    let (host, _port) = server.rsplit_once(':').ok_or_else(invalid)?;

    // `[v6-literal]` -> `v6-literal`; anything else is left untouched.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    if host.is_empty() {
        return Err(invalid());
    }
    Ok(host.to_string())
}

/// Builds a QUIC client from the `server` and `transport` sections of `config`.
///
/// The TLS server name is `transport.server_name` when set, otherwise the
/// host part of `server`. The server address is resolved by name lookup and
/// the first result is used. TLS mode defaults to `TlsMode::Tls`; mTLS
/// requires a CA certificate, a client certificate and a client key; insecure
/// mode disables server verification. Idle timeout and keep-alive are read as
/// milliseconds.
///
/// # Errors
///
/// Returns an `io::Error` when the server address is malformed or does not
/// resolve, when a file required for mTLS is not configured, or when a
/// transport option or the client constructor rejects its input.
#[cfg(feature = "transport-quic")]
async fn quic_client_from_config(config: &ServiceConfig) -> io::Result<QuicClient> {
    let server = &config.server;
    let transport_cfg = &config.transport;

    let server_name = match &transport_cfg.server_name {
        Some(value) => value.clone(),
        None => server_name_from_address(server)?,
    };

    // Only the first resolved address is needed, so take it straight from the
    // iterator instead of collecting every result first.
    let mut resolved = tokio::net::lookup_host(server).await?;
    let server_addr = resolved.next().ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::NotFound,
            format!("Failed to resolve server address: '{}'", server),
        )
    })?;

    let mut quic_config = QuicConfig::new(server_addr, server_name);

    quic_config.enable_zero_rtt = transport_cfg.zero_rtt.unwrap_or(false);
    if let Some(protocols) = &transport_cfg.alpn_protocols {
        quic_config.alpn_protocols = protocols.iter().map(|p| p.to_vec()).collect();
    }

    match transport_cfg.tls_mode.unwrap_or(TlsMode::Tls) {
        TlsMode::Tls => {
            // A custom CA is optional in plain TLS mode.
            if let Some(ca) = &transport_cfg.ca_cert {
                quic_config.root_ca_path = Some(ca.to_path_buf());
            }
        }
        TlsMode::MTls => {
            quic_config.root_ca_path = Some(transport_cfg.ca_cert.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "CA cert is required for mTLS mode",
                )
            })?);
            let client_cert = transport_cfg.client_cert.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Client cert is required for mTLS mode",
                )
            })?;
            let client_key = transport_cfg.client_key.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Client key is required for mTLS mode",
                )
            })?;
            quic_config.client_cert_key_paths = Some((client_cert, client_key));
        }
        TlsMode::Insecure => {
            quic_config.skip_server_verification = true;
        }
    }

    let mut transport_config = QuicTransportConfig::default();
    if let Some(timeout) = transport_cfg.idle_timeout {
        transport_config.max_idle_timeout(Duration::from_millis(timeout))?;
    }
    if let Some(interval) = transport_cfg.keep_alive {
        transport_config.keep_alive_period(Duration::from_millis(interval))?;
    }
    if let Some(max_streams) = transport_cfg.max_streams {
        transport_config.max_open_bidirectional_streams(max_streams)?;
    }
    if let Some(congestion) = transport_cfg.congestion {
        transport_config.congestion(congestion, transport_cfg.cwnd_init)?;
    }
    quic_config.transport_config(transport_config);

    Ok(QuicClient::new(quic_config)?)
}