1use std::io;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::broadcast;
6use tokio::task::JoinHandle;
7
8use ombrac_macros::{error, info};
9use ombrac_transport::quic::Connection as QuicConnection;
10use ombrac_transport::quic::TransportConfig as QuicTransportConfig;
11use ombrac_transport::quic::client::Client as QuicClient;
12use ombrac_transport::quic::client::Config as QuicConfig;
13use ombrac_transport::quic::error::Error as QuicError;
14
15use crate::client::Client;
16use crate::config::{ServiceConfig, TlsMode};
17
18pub type Result<T> = std::result::Result<T, Error>;
19
20#[derive(thiserror::Error, Debug)]
21pub enum Error {
22 #[error(transparent)]
23 Io(#[from] io::Error),
24
25 #[error("{0}")]
26 Config(String),
27
28 #[error(transparent)]
29 Quic(#[from] QuicError),
30
31 #[error("{0}")]
32 Endpoint(String),
33}
34
/// Turns an optional configuration value into a `Result`.
///
/// Evaluates to `Ok(value)` when the option is `Some`, and to
/// `Err(Error::Config(..))` naming `$name` when it is `None`.
#[cfg(any(
    feature = "endpoint-default",
    feature = "endpoint-socks",
    feature = "endpoint-http",
    feature = "endpoint-tun"
))]
macro_rules! require_config {
    ($value:expr, $name:expr) => {
        match $value {
            Some(present) => Ok(present),
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $name
            ))),
        }
    };
}
51
52pub struct OmbracClient {
81 client: Arc<Client<QuicClient, QuicConnection>>,
82 handles: Vec<JoinHandle<()>>,
83 shutdown_tx: broadcast::Sender<()>,
84}
85
86impl OmbracClient {
87 pub async fn build(config: Arc<ServiceConfig>) -> Result<Self> {
104 let transport = quic_client_from_config(&config).await?;
106
107 info!("binding udp socket to {}", transport.local_addr()?);
108
109 let secret = *blake3::hash(config.secret.as_bytes()).as_bytes();
110 let client = Arc::new(
111 Client::new(
112 transport,
113 secret,
114 config.auth_option.clone().map(Into::into),
115 )
116 .await
117 .map_err(Error::Io)?,
118 );
119
120 let mut _handles = Vec::new();
121 let (shutdown_tx, _) = broadcast::channel(1);
122
123 #[cfg(feature = "endpoint-http")]
125 if config.endpoint.http.is_some() {
126 _handles.push(Self::spawn_endpoint(
127 "HTTP",
128 Self::endpoint_http_accept_loop(
129 config.clone(),
130 client.clone(),
131 shutdown_tx.subscribe(),
132 ),
133 ));
134 }
135
136 #[cfg(feature = "endpoint-socks")]
138 if config.endpoint.socks.is_some() {
139 _handles.push(Self::spawn_endpoint(
140 "SOCKS",
141 Self::endpoint_socks_accept_loop(
142 config.clone(),
143 client.clone(),
144 shutdown_tx.subscribe(),
145 ),
146 ));
147 }
148
149 #[cfg(feature = "endpoint-tun")]
151 if let Some(tun_config) = &config.endpoint.tun
152 && (tun_config.tun_ipv4.is_some()
153 || tun_config.tun_ipv6.is_some()
154 || tun_config.tun_fd.is_some())
155 {
156 _handles.push(Self::spawn_endpoint(
157 "TUN",
158 Self::endpoint_tun_accept_loop(
159 config.clone(),
160 client.clone(),
161 shutdown_tx.subscribe(),
162 ),
163 ));
164 }
165
166 if _handles.is_empty() {
167 return Err(Error::Config(
168 "no endpoints were configured or enabled, the service has nothing to do."
169 .to_string(),
170 ));
171 }
172
173 Ok(OmbracClient {
174 client,
175 handles: _handles,
176 shutdown_tx,
177 })
178 }
179
180 pub async fn rebind(&self) -> io::Result<()> {
182 self.client.rebind().await
183 }
184
185 pub fn client(&self) -> &Arc<Client<QuicClient, QuicConnection>> {
186 &self.client
187 }
188
189 pub async fn shutdown(self) {
216 let _ = self.shutdown_tx.send(());
217
218 for handle in self.handles {
219 if let Err(_err) = handle.await {
220 error!("task failed to shut down cleanly: {:?}", _err);
221 }
222 }
223 }
224
225 #[cfg(any(
226 feature = "endpoint-default",
227 feature = "endpoint-socks",
228 feature = "endpoint-http",
229 feature = "endpoint-tun"
230 ))]
231 fn spawn_endpoint(
232 _name: &'static str,
233 task: impl std::future::Future<Output = Result<()>> + Send + 'static,
234 ) -> JoinHandle<()> {
235 tokio::spawn(async move {
236 if let Err(_e) = task.await {
237 error!("the {_name} endpoint shut down due to an error: {_e}");
238 }
239 })
240 }
241
242 #[cfg(feature = "endpoint-http")]
243 async fn endpoint_http_accept_loop(
244 config: Arc<ServiceConfig>,
245 ombrac: Arc<Client<QuicClient, QuicConnection>>,
246 mut shutdown_rx: broadcast::Receiver<()>,
247 ) -> Result<()> {
248 use crate::endpoint::http::Server as HttpServer;
249 let bind_addr = require_config!(config.endpoint.http, "endpoint.http")?;
250 let socket = tokio::net::TcpListener::bind(bind_addr).await?;
251
252 info!("starting http/https endpoint, listening on {bind_addr}");
253
254 HttpServer::new(ombrac)
255 .accept_loop(socket, async {
256 let _ = shutdown_rx.recv().await;
257 })
258 .await
259 .map_err(|e| Error::Endpoint(format!("http server failed to run: {}", e)))
260 }
261
262 #[cfg(feature = "endpoint-socks")]
263 async fn endpoint_socks_accept_loop(
264 config: Arc<ServiceConfig>,
265 ombrac: Arc<Client<QuicClient, QuicConnection>>,
266 mut shutdown_rx: broadcast::Receiver<()>,
267 ) -> Result<()> {
268 use crate::endpoint::socks::CommandHandler;
269 use socks_lib::v5::server::auth::NoAuthentication;
270 use socks_lib::v5::server::{Config as SocksConfig, Server as SocksServer};
271
272 let bind_addr = require_config!(config.endpoint.socks, "endpoint.socks")?;
273 let socket = tokio::net::TcpListener::bind(bind_addr).await?;
274
275 info!("starting socks endpoint, listening on {bind_addr}");
276
277 let socks_config = Arc::new(SocksConfig::new(
278 NoAuthentication,
279 CommandHandler::new(ombrac),
280 ));
281 SocksServer::run(socket, socks_config, async {
282 let _ = shutdown_rx.recv().await;
283 })
284 .await
285 .map_err(|e| Error::Endpoint(format!("socks server failed to run: {}", e)))
286 }
287
288 #[cfg(feature = "endpoint-tun")]
289 async fn endpoint_tun_accept_loop(
290 config: Arc<ServiceConfig>,
291 ombrac: Arc<Client<QuicClient, QuicConnection>>,
292 mut shutdown_rx: broadcast::Receiver<()>,
293 ) -> Result<()> {
294 use crate::endpoint::tun::{AsyncDevice, Tun, TunConfig};
295
296 let config = require_config!(config.endpoint.tun.as_ref(), "endpoint.tun")?;
297
298 let device = match config.tun_fd {
299 Some(fd) => {
300 #[cfg(not(windows))]
301 unsafe {
302 AsyncDevice::from_fd(fd)?
303 }
304
305 #[cfg(windows)]
306 return Err(Error::Config(
307 "'tun_fd' option is not supported on Windows.".to_string(),
308 ));
309 }
310 None => {
311 #[cfg(not(any(target_os = "android", target_os = "ios")))]
312 {
313 let device = {
314 use std::str::FromStr;
315
316 let mut builder = tun_rs::DeviceBuilder::new();
317 builder = builder.mtu(config.tun_mtu.unwrap_or(1500));
318
319 if let Some(ip_str) = &config.tun_ipv4 {
320 let ip = ipnet::Ipv4Net::from_str(ip_str).map_err(|e| {
321 Error::Config(format!("failed to parse ipv4 cidr '{ip_str}': {e}"))
322 })?;
323 builder = builder.ipv4(ip.addr(), ip.netmask(), None);
324 }
325
326 if let Some(ip_str) = &config.tun_ipv6 {
327 let ip = ipnet::Ipv6Net::from_str(ip_str).map_err(|e| {
328 Error::Config(format!("failed to parse ipv6 cidr '{ip_str}': {e}"))
329 })?;
330 builder = builder.ipv6(ip.addr(), ip.netmask());
331 }
332
333 builder.build_async().map_err(|e| {
334 Error::Endpoint(format!("failed to build tun device: {e}"))
335 })?
336 };
337
338 info!(
339 "starting tun endpoint, name: {:?}, mtu: {:?}, ip: {:?}",
340 device.name(),
341 device.mtu(),
342 device.addresses()
343 );
344
345 device
346 }
347
348 #[cfg(any(target_os = "android", target_os = "ios"))]
349 {
350 return Err(Error::Config(
351 "creating a new tun device is not supported on this platform. A pre-configured 'tun_fd' must be provided.".to_string()
352 ));
353 }
354 }
355 };
356
357 let mut tun_config = TunConfig::default();
358 if let Some(value) = &config.fake_dns {
359 tun_config.fakedns_cidr = value.parse().map_err(|e| {
360 Error::Config(format!("failed to parse fake_dns cidr '{value}': {e}"))
361 })?;
362 };
363 if let Some(value) = config.disable_udp_443 {
364 tun_config.disable_udp_443 = value;
365 };
366
367 let tun = Tun::new(tun_config.into(), ombrac);
368 let shutdown_signal = async {
369 let _ = shutdown_rx.recv().await;
370 };
371
372 tun.accept_loop(device, shutdown_signal)
373 .await
374 .map_err(|e| Error::Endpoint(format!("tun device runtime error: {}", e)))
375 }
376}
377
378async fn quic_client_from_config(config: &ServiceConfig) -> io::Result<QuicClient> {
379 let server = &config.server;
380 let transport_cfg = &config.transport;
381
382 let server_name = match &transport_cfg.server_name {
383 Some(value) => value.clone(),
384 None => {
385 let pos = server.rfind(':').ok_or_else(|| {
386 io::Error::new(
387 io::ErrorKind::InvalidInput,
388 format!("invalid server address: {}", server),
389 )
390 })?;
391 server[..pos].to_string()
392 }
393 };
394
395 let addrs: Vec<_> = tokio::net::lookup_host(server).await?.collect();
396 let server_addr = addrs.into_iter().next().ok_or_else(|| {
397 io::Error::new(
398 io::ErrorKind::NotFound,
399 format!("failed to resolve server address: '{}'", server),
400 )
401 })?;
402
403 let mut quic_config = QuicConfig::new(server_addr, server_name);
404
405 quic_config.enable_zero_rtt = transport_cfg.zero_rtt.unwrap_or(false);
406 if let Some(protocols) = &transport_cfg.alpn_protocols {
407 quic_config.alpn_protocols = protocols.iter().map(|p| p.to_vec()).collect();
408 }
409
410 match transport_cfg.tls_mode.unwrap_or(TlsMode::Tls) {
411 TlsMode::Tls => {
412 if let Some(ca) = &transport_cfg.ca_cert {
413 quic_config.root_ca_path = Some(ca.to_path_buf());
414 }
415 }
416 TlsMode::MTls => {
417 quic_config.root_ca_path = Some(transport_cfg.ca_cert.clone().ok_or_else(|| {
418 io::Error::new(
419 io::ErrorKind::InvalidInput,
420 "ca cert is required for mutual tls mode",
421 )
422 })?);
423 let client_cert = transport_cfg.client_cert.clone().ok_or_else(|| {
424 io::Error::new(
425 io::ErrorKind::InvalidInput,
426 "client cert is required for mutual tls mode",
427 )
428 })?;
429 let client_key = transport_cfg.client_key.clone().ok_or_else(|| {
430 io::Error::new(
431 io::ErrorKind::InvalidInput,
432 "client key is required for mutual tls mode",
433 )
434 })?;
435 quic_config.client_cert_key_paths = Some((client_cert, client_key));
436 }
437 TlsMode::Insecure => {
438 quic_config.skip_server_verification = true;
439 }
440 }
441
442 let mut transport_config = QuicTransportConfig::default();
443 if let Some(timeout) = transport_cfg.idle_timeout {
444 transport_config.max_idle_timeout(Duration::from_millis(timeout))?;
445 }
446 if let Some(interval) = transport_cfg.keep_alive {
447 transport_config.keep_alive_period(Duration::from_millis(interval))?;
448 }
449 if let Some(max_streams) = transport_cfg.max_streams {
450 transport_config.max_open_bidirectional_streams(max_streams)?;
451 }
452 if let Some(congestion) = transport_cfg.congestion {
453 transport_config.congestion(congestion, transport_cfg.cwnd_init)?;
454 }
455 quic_config.transport_config(transport_config);
456
457 Ok(QuicClient::new(quic_config)?)
458}