1use std::io;
2use std::marker::PhantomData;
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::broadcast;
7use tokio::task::JoinHandle;
8
9use ombrac_macros::{error, info, warn};
10#[cfg(feature = "transport-quic")]
11use ombrac_transport::quic::{
12 Connection as QuicConnection, TransportConfig as QuicTransportConfig,
13 client::{Client as QuicClient, Config as QuicConfig},
14};
15use ombrac_transport::{Connection, Initiator};
16
17use crate::client::Client;
18use crate::config::ServiceConfig;
19#[cfg(feature = "transport-quic")]
20use crate::config::TlsMode;
21
22pub type Result<T> = std::result::Result<T, Error>;
23
24#[derive(thiserror::Error, Debug)]
25pub enum Error {
26 #[error("Configuration error: {0}")]
27 Config(String),
28
29 #[error("{0}")]
30 Io(#[from] io::Error),
31
32 #[error("Transport layer error: {0}")]
33 Transport(String),
34
35 #[error("Endpoint failed: {0}")]
36 Endpoint(String),
37}
38
/// Unwraps an optional configuration value or produces an [`Error::Config`].
///
/// Expands to a `Result`: `Ok(value)` when `$config_opt` is `Some(value)`, otherwise
/// `Err(Error::Config(..))` naming `$field_name` as the missing setting. The message
/// is only formatted on the `None` path.
macro_rules! require_config {
    ($config_opt:expr, $field_name:expr) => {
        match $config_opt {
            Some(value) => Ok(value),
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $field_name
            ))),
        }
    };
}
49
50pub trait ServiceBuilder {
51 type Initiator: Initiator<Connection = Self::Connection>;
52 type Connection: Connection;
53
54 fn build(
55 config: &Arc<ServiceConfig>,
56 ) -> impl Future<Output = Result<Arc<Client<Self::Initiator, Self::Connection>>>> + Send;
57}
58
/// [`ServiceBuilder`] that uses the QUIC transport.
#[cfg(feature = "transport-quic")]
pub struct QuicServiceBuilder;

#[cfg(feature = "transport-quic")]
impl ServiceBuilder for QuicServiceBuilder {
    type Connection = QuicConnection;
    type Initiator = QuicClient;

    /// Creates the QUIC transport from `config` and wraps it in a shared [`Client`].
    ///
    /// The client receives the BLAKE3 digest of `config.secret` as its secret.
    async fn build(
        config: &Arc<ServiceConfig>,
    ) -> Result<Arc<Client<Self::Initiator, Self::Connection>>> {
        let initiator = quic_client_from_config(config).await?;

        let digest = blake3::hash(config.secret.as_bytes());
        let secret_bytes = *digest.as_bytes();

        let inner = Client::new(initiator, secret_bytes, None).await?;
        Ok(Arc::new(inner))
    }
}
76
77pub struct Service<T, C>
78where
79 T: Initiator<Connection = C>,
80 C: Connection,
81{
82 client: Arc<Client<T, C>>,
83 handles: Vec<JoinHandle<()>>,
84 shutdown_tx: broadcast::Sender<()>,
85 _transport: PhantomData<T>,
86 _connection: PhantomData<C>,
87}
88
89impl<T, C> Service<T, C>
90where
91 T: Initiator<Connection = C>,
92 C: Connection,
93{
94 pub async fn build<Builder>(config: Arc<ServiceConfig>) -> Result<Self>
95 where
96 Builder: ServiceBuilder<Initiator = T, Connection = C>,
97 {
98 let mut handles = Vec::new();
99 let client = Builder::build(&config).await?;
100 let (shutdown_tx, _) = broadcast::channel(1);
101
102 #[cfg(feature = "endpoint-http")]
104 if config.endpoint.http.is_some() {
105 handles.push(Self::spawn_endpoint(
106 "HTTP",
107 Self::endpoint_http_accept_loop(
108 config.clone(),
109 client.clone(),
110 shutdown_tx.subscribe(),
111 ),
112 ));
113 }
114
115 #[cfg(feature = "endpoint-socks")]
117 if config.endpoint.socks.is_some() {
118 handles.push(Self::spawn_endpoint(
119 "SOCKS",
120 Self::endpoint_socks_accept_loop(
121 config.clone(),
122 client.clone(),
123 shutdown_tx.subscribe(),
124 ),
125 ));
126 }
127
128 #[cfg(feature = "endpoint-tun")]
130 if let Some(tun_config) = &config.endpoint.tun
131 && (tun_config.tun_ipv4.is_some()
132 || tun_config.tun_ipv6.is_some()
133 || tun_config.tun_fd.is_some())
134 {
135 handles.push(Self::spawn_endpoint(
136 "TUN",
137 Self::endpoint_tun_accept_loop(
138 config.clone(),
139 client.clone(),
140 shutdown_tx.subscribe(),
141 ),
142 ));
143 }
144
145 if handles.is_empty() {
146 return Err(Error::Config(
147 "No endpoints were configured or enabled. The service has nothing to do."
148 .to_string(),
149 ));
150 }
151
152 Ok(Service {
153 client,
154 handles,
155 shutdown_tx,
156 _transport: PhantomData,
157 _connection: PhantomData,
158 })
159 }
160
161 pub async fn rebind(&self) -> io::Result<()> {
162 self.client.rebind().await
163 }
164
165 pub async fn shutdown(self) {
166 let _ = self.shutdown_tx.send(());
167
168 for handle in self.handles {
169 if let Err(_err) = handle.await {
170 error!("A task failed to shut down cleanly: {:?}", _err);
171 }
172 }
173 warn!("Service shutdown complete");
174 }
175
176 fn spawn_endpoint(
177 _name: &'static str,
178 task: impl Future<Output = Result<()>> + Send + 'static,
179 ) -> JoinHandle<()> {
180 tokio::spawn(async move {
181 if let Err(_e) = task.await {
182 error!("The {_name} endpoint shut down due to an error: {_e}");
183 }
184 })
185 }
186
187 #[cfg(feature = "endpoint-http")]
188 async fn endpoint_http_accept_loop(
189 config: Arc<ServiceConfig>,
190 ombrac: Arc<Client<T, C>>,
191 mut shutdown_rx: broadcast::Receiver<()>,
192 ) -> Result<()> {
193 use crate::endpoint::http::Server as HttpServer;
194 let bind_addr = require_config!(config.endpoint.http, "endpoint.http")?;
195 let socket = tokio::net::TcpListener::bind(bind_addr).await?;
196
197 info!("Starting HTTP/HTTPS endpoint, listening on {bind_addr}");
198
199 HttpServer::new(ombrac)
200 .accept_loop(socket, async {
201 let _ = shutdown_rx.recv().await;
202 })
203 .await
204 .map_err(|e| Error::Endpoint(format!("HTTP server failed to run: {}", e)))
205 }
206
207 #[cfg(feature = "endpoint-socks")]
208 async fn endpoint_socks_accept_loop(
209 config: Arc<ServiceConfig>,
210 ombrac: Arc<Client<T, C>>,
211 mut shutdown_rx: broadcast::Receiver<()>,
212 ) -> Result<()> {
213 use crate::endpoint::socks::CommandHandler;
214 use socks_lib::v5::server::auth::NoAuthentication;
215 use socks_lib::v5::server::{Config as SocksConfig, Server as SocksServer};
216
217 let bind_addr = require_config!(config.endpoint.socks, "endpoint.socks")?;
218 let socket = tokio::net::TcpListener::bind(bind_addr).await?;
219
220 info!("Starting SOCKS5 endpoint, listening on {bind_addr}");
221
222 let socks_config = Arc::new(SocksConfig::new(
223 NoAuthentication,
224 CommandHandler::new(ombrac),
225 ));
226 SocksServer::run(socket, socks_config, async {
227 let _ = shutdown_rx.recv().await;
228 })
229 .await
230 .map_err(|e| Error::Endpoint(format!("SOCKS server failed to run: {}", e)))
231 }
232
233 #[cfg(feature = "endpoint-tun")]
234 async fn endpoint_tun_accept_loop(
235 config: Arc<ServiceConfig>,
236 ombrac: Arc<Client<T, C>>,
237 mut shutdown_rx: broadcast::Receiver<()>,
238 ) -> Result<()> {
239 use crate::endpoint::tun::{AsyncDevice, Tun, TunConfig};
240
241 let config = require_config!(config.endpoint.tun.as_ref(), "endpoint.tun")?;
242
243 let device = match config.tun_fd {
244 Some(fd) => {
245 #[cfg(not(windows))]
246 unsafe {
247 AsyncDevice::from_fd(fd)?
248 }
249
250 #[cfg(windows)]
251 return Err(Error::Config(
252 "'tun_fd' option is not supported on Windows.".to_string(),
253 ));
254 }
255 None => {
256 #[cfg(not(any(target_os = "android", target_os = "ios")))]
257 {
258 let device = {
259 use std::str::FromStr;
260
261 let mut builder = tun_rs::DeviceBuilder::new();
262 builder = builder.mtu(config.tun_mtu.unwrap_or(1500));
263
264 if let Some(ip_str) = &config.tun_ipv4 {
265 let ip = ipnet::Ipv4Net::from_str(ip_str).map_err(|e| {
266 Error::Config(format!("Failed to parse IPv4 CIDR '{ip_str}': {e}"))
267 })?;
268 builder = builder.ipv4(ip.addr(), ip.netmask(), None);
269 }
270
271 if let Some(ip_str) = &config.tun_ipv6 {
272 let ip = ipnet::Ipv6Net::from_str(ip_str).map_err(|e| {
273 Error::Config(format!("Failed to parse IPv6 CIDR '{ip_str}': {e}"))
274 })?;
275 builder = builder.ipv6(ip.addr(), ip.netmask());
276 }
277
278 builder.build_async().map_err(|e| {
279 Error::Endpoint(format!("Failed to build TUN device: {e}"))
280 })?
281 };
282
283 info!(
284 "Starting TUN endpoint, Name: {:?}, MTU: {:?}, IP: {:?}",
285 device.name(),
286 device.mtu(),
287 device.addresses()
288 );
289
290 device
291 }
292
293 #[cfg(any(target_os = "android", target_os = "ios"))]
294 {
295 return Err(Error::Config(
296 "Creating a new TUN device is not supported on this platform. A pre-configured 'tun_fd' must be provided.".to_string()
297 ));
298 }
299 }
300 };
301
302 let mut tun_config = TunConfig::default();
303 if let Some(value) = &config.fake_dns {
304 tun_config.fakedns_cidr = value.parse().map_err(|e| {
305 Error::Config(format!("Failed to parse fake_dns CIDR '{value}': {e}"))
306 })?;
307 };
308
309 let tun = Tun::new(tun_config.into(), ombrac);
310 let shutdown_signal = async {
311 let _ = shutdown_rx.recv().await;
312 };
313
314 tun.accept_loop(device, shutdown_signal)
315 .await
316 .map_err(|e| Error::Endpoint(format!("TUN device runtime error: {}", e)))
317 }
318}
319
/// Derives the default TLS server name from a `host:port` server address.
///
/// Everything before the last `:` is the host. Square brackets around an IPv6 literal
/// (e.g. `[::1]:443`) are stripped, so the result is a bare address usable as a TLS
/// server name.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidInput`] when the address has no `:` separator or the
/// host part is empty.
#[cfg(any(feature = "transport-quic", test))]
fn default_server_name(server: &str) -> io::Result<String> {
    // Captures only a shared reference, so the closure is `Copy` and can be reused.
    let invalid = || {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("Invalid server address: {}", server),
        )
    };

    let (host, _port) = server.rsplit_once(':').ok_or_else(invalid)?;
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    if host.is_empty() {
        return Err(invalid());
    }
    Ok(host.to_string())
}

/// Builds a QUIC client from the service configuration.
///
/// The server address is resolved and the first result is used. The TLS server name
/// comes from `transport.server_name`, or is derived from the host part of `server`
/// when unset. TLS behaviour follows `transport.tls_mode` (default [`TlsMode::Tls`]).
/// The optional timeouts (in milliseconds), stream limit and congestion settings are
/// applied when present.
///
/// # Errors
///
/// Returns an [`io::Error`] in any of these cases:
/// - the server address is malformed or does not resolve;
/// - mTLS is selected without a CA certificate, client certificate or client key;
/// - a transport setting is rejected;
/// - the client cannot be created.
#[cfg(feature = "transport-quic")]
async fn quic_client_from_config(config: &ServiceConfig) -> io::Result<QuicClient> {
    let server = &config.server;
    let transport_cfg = &config.transport;

    let server_name = match &transport_cfg.server_name {
        Some(value) => value.clone(),
        None => default_server_name(server)?,
    };

    // Only the first resolved address is used.
    let server_addr = tokio::net::lookup_host(server)
        .await?
        .next()
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                format!("Failed to resolve server address: '{}'", server),
            )
        })?;

    let mut quic_config = QuicConfig::new(server_addr, server_name);

    quic_config.enable_zero_rtt = transport_cfg.zero_rtt.unwrap_or(false);
    if let Some(protocols) = &transport_cfg.alpn_protocols {
        quic_config.alpn_protocols = protocols.iter().map(|p| p.to_vec()).collect();
    }

    match transport_cfg.tls_mode.unwrap_or(TlsMode::Tls) {
        TlsMode::Tls => {
            // A custom CA is optional in plain TLS mode.
            if let Some(ca) = &transport_cfg.ca_cert {
                quic_config.root_ca_path = Some(ca.to_path_buf());
            }
        }
        TlsMode::MTls => {
            // mTLS requires the CA plus a client certificate/key pair.
            quic_config.root_ca_path = Some(transport_cfg.ca_cert.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "CA cert is required for mTLS mode",
                )
            })?);
            let client_cert = transport_cfg.client_cert.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Client cert is required for mTLS mode",
                )
            })?;
            let client_key = transport_cfg.client_key.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Client key is required for mTLS mode",
                )
            })?;
            quic_config.client_cert_key_paths = Some((client_cert, client_key));
        }
        TlsMode::Insecure => {
            quic_config.skip_server_verification = true;
        }
    }

    let mut transport_config = QuicTransportConfig::default();
    if let Some(timeout) = transport_cfg.idle_timeout {
        transport_config.max_idle_timeout(Duration::from_millis(timeout))?;
    }
    if let Some(interval) = transport_cfg.keep_alive {
        transport_config.keep_alive_period(Duration::from_millis(interval))?;
    }
    if let Some(max_streams) = transport_cfg.max_streams {
        transport_config.max_open_bidirectional_streams(max_streams)?;
    }
    if let Some(congestion) = transport_cfg.congestion {
        transport_config.congestion(congestion, transport_cfg.cwnd_init)?;
    }
    quic_config.transport_config(transport_config);

    Ok(QuicClient::new(quic_config)?)
}