Skip to main content

ombrac_client/
service.rs

1use std::io;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::broadcast;
6use tokio::task::JoinHandle;
7
8use ombrac_macros::{error, info};
9use ombrac_transport::quic::Connection as QuicConnection;
10use ombrac_transport::quic::TransportConfig as QuicTransportConfig;
11use ombrac_transport::quic::client::Client as QuicClient;
12use ombrac_transport::quic::client::Config as QuicConfig;
13use ombrac_transport::quic::error::Error as QuicError;
14
15use crate::client::Client;
16use crate::config::{ServiceConfig, TlsMode};
17
18pub type Result<T> = std::result::Result<T, Error>;
19
20#[derive(thiserror::Error, Debug)]
21pub enum Error {
22    #[error(transparent)]
23    Io(#[from] io::Error),
24
25    #[error("{0}")]
26    Config(String),
27
28    #[error(transparent)]
29    Quic(#[from] QuicError),
30
31    #[error("{0}")]
32    Endpoint(String),
33}
34
/// Turns an `Option` configuration value into a `Result`, yielding
/// `Error::Config` that names the missing field when the value is `None`.
///
/// The expansion is a `Result<T, Error>` expression, so callers typically
/// follow it with `?`. The field name is only formatted on the `None` path.
#[cfg(any(
    feature = "endpoint-default",
    feature = "endpoint-socks",
    feature = "endpoint-http",
    feature = "endpoint-tun"
))]
macro_rules! require_config {
    ($value:expr, $name:expr) => {
        match $value {
            Some(present) => Ok(present),
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $name
            ))),
        }
    };
}
51
52/// OmbracClient provides a simple, easy-to-use API for starting and managing
53/// the ombrac client using QUIC transport.
54///
55/// This struct hides all transport-specific implementation details and provides
56/// a clean interface for external users.
57///
58/// # Example
59///
60/// ```no_run
61/// use ombrac_client::{OmbracClient, ServiceConfig};
62/// use std::sync::Arc;
63///
64/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
65/// let config = Arc::new(ServiceConfig {
66///     secret: "my-secret".to_string(),
67///     server: "server.example.com:8080".to_string(),
68///     auth_option: None,
69///     endpoint: Default::default(),
70///     transport: Default::default(),
71///     logging: Default::default(),
72/// });
73///
74/// let client = OmbracClient::build(config).await?;
75/// // ... use client ...
76/// client.shutdown().await;
77/// # Ok(())
78/// # }
79/// ```
80pub struct OmbracClient {
81    client: Arc<Client<QuicClient, QuicConnection>>,
82    handles: Vec<JoinHandle<()>>,
83    shutdown_tx: broadcast::Sender<()>,
84}
85
86impl OmbracClient {
87    /// Builds a new client instance from the configuration.
88    ///
89    /// This method:
90    /// 1. Creates a QUIC client from the transport configuration
91    /// 2. Establishes connection with authentication
92    /// 3. Spawns endpoint tasks if configured
93    /// 4. Returns an OmbracClient handle for lifecycle management
94    ///
95    /// # Arguments
96    ///
97    /// * `config` - The service configuration containing transport, endpoint, and secret settings
98    ///
99    /// # Returns
100    ///
101    /// A configured `OmbracClient` instance ready to use, or an error
102    /// if configuration is invalid or client setup fails.
103    pub async fn build(config: Arc<ServiceConfig>) -> Result<Self> {
104        // Build QUIC client from config
105        let transport = quic_client_from_config(&config).await?;
106
107        info!("binding udp socket to {}", transport.local_addr()?);
108
109        let secret = *blake3::hash(config.secret.as_bytes()).as_bytes();
110        let client = Arc::new(
111            Client::new(
112                transport,
113                secret,
114                config.auth_option.clone().map(Into::into),
115            )
116            .await
117            .map_err(Error::Io)?,
118        );
119
120        let mut _handles = Vec::new();
121        let (shutdown_tx, _) = broadcast::channel(1);
122
123        // Start HTTP endpoint if configured
124        #[cfg(feature = "endpoint-http")]
125        if config.endpoint.http.is_some() {
126            _handles.push(Self::spawn_endpoint(
127                "HTTP",
128                Self::endpoint_http_accept_loop(
129                    config.clone(),
130                    client.clone(),
131                    shutdown_tx.subscribe(),
132                ),
133            ));
134        }
135
136        // Start SOCKS endpoint if configured
137        #[cfg(feature = "endpoint-socks")]
138        if config.endpoint.socks.is_some() {
139            _handles.push(Self::spawn_endpoint(
140                "SOCKS",
141                Self::endpoint_socks_accept_loop(
142                    config.clone(),
143                    client.clone(),
144                    shutdown_tx.subscribe(),
145                ),
146            ));
147        }
148
149        // Start TUN endpoint if configured
150        #[cfg(feature = "endpoint-tun")]
151        if let Some(tun_config) = &config.endpoint.tun
152            && (tun_config.tun_ipv4.is_some()
153                || tun_config.tun_ipv6.is_some()
154                || tun_config.tun_fd.is_some())
155        {
156            _handles.push(Self::spawn_endpoint(
157                "TUN",
158                Self::endpoint_tun_accept_loop(
159                    config.clone(),
160                    client.clone(),
161                    shutdown_tx.subscribe(),
162                ),
163            ));
164        }
165
166        if _handles.is_empty() {
167            return Err(Error::Config(
168                "no endpoints were configured or enabled, the service has nothing to do."
169                    .to_string(),
170            ));
171        }
172
173        Ok(OmbracClient {
174            client,
175            handles: _handles,
176            shutdown_tx,
177        })
178    }
179
180    /// Rebind the transport to a new socket to ensure a clean state for reconnection.
181    pub async fn rebind(&self) -> io::Result<()> {
182        self.client.rebind().await
183    }
184
185    pub fn client(&self) -> &Arc<Client<QuicClient, QuicConnection>> {
186        &self.client
187    }
188
189    /// Gracefully shuts down the client.
190    ///
191    /// This method will:
192    /// 1. Send a shutdown signal to stop accepting new connections
193    /// 2. Wait for all endpoint tasks to finish gracefully
194    /// 3. Close existing connections
195    ///
196    /// # Example
197    ///
198    /// ```no_run
199    /// # use ombrac_client::OmbracClient;
200    /// # use std::sync::Arc;
201    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
202    /// # let config = Arc::new(ombrac_client::ServiceConfig {
203    /// #     secret: "test".to_string(),
204    /// #     server: "server:8080".to_string(),
205    /// #     auth_option: None,
206    /// #     endpoint: Default::default(),
207    /// #     transport: Default::default(),
208    /// #     logging: Default::default(),
209    /// # });
210    /// # let client = OmbracClient::build(config).await?;
211    /// client.shutdown().await;
212    /// # Ok(())
213    /// # }
214    /// ```
215    pub async fn shutdown(self) {
216        let _ = self.shutdown_tx.send(());
217
218        for handle in self.handles {
219            if let Err(_err) = handle.await {
220                error!("task failed to shut down cleanly: {:?}", _err);
221            }
222        }
223    }
224
225    #[cfg(any(
226        feature = "endpoint-default",
227        feature = "endpoint-socks",
228        feature = "endpoint-http",
229        feature = "endpoint-tun"
230    ))]
231    fn spawn_endpoint(
232        _name: &'static str,
233        task: impl std::future::Future<Output = Result<()>> + Send + 'static,
234    ) -> JoinHandle<()> {
235        tokio::spawn(async move {
236            if let Err(_e) = task.await {
237                error!("the {_name} endpoint shut down due to an error: {_e}");
238            }
239        })
240    }
241
242    #[cfg(feature = "endpoint-http")]
243    async fn endpoint_http_accept_loop(
244        config: Arc<ServiceConfig>,
245        ombrac: Arc<Client<QuicClient, QuicConnection>>,
246        mut shutdown_rx: broadcast::Receiver<()>,
247    ) -> Result<()> {
248        use crate::endpoint::http::Server as HttpServer;
249        let bind_addr = require_config!(config.endpoint.http, "endpoint.http")?;
250        let socket = tokio::net::TcpListener::bind(bind_addr).await?;
251
252        info!("starting http/https endpoint, listening on {bind_addr}");
253
254        HttpServer::new(ombrac)
255            .accept_loop(socket, async {
256                let _ = shutdown_rx.recv().await;
257            })
258            .await
259            .map_err(|e| Error::Endpoint(format!("http server failed to run: {}", e)))
260    }
261
262    #[cfg(feature = "endpoint-socks")]
263    async fn endpoint_socks_accept_loop(
264        config: Arc<ServiceConfig>,
265        ombrac: Arc<Client<QuicClient, QuicConnection>>,
266        mut shutdown_rx: broadcast::Receiver<()>,
267    ) -> Result<()> {
268        use crate::endpoint::socks::CommandHandler;
269        use socks_lib::v5::server::auth::NoAuthentication;
270        use socks_lib::v5::server::{Config as SocksConfig, Server as SocksServer};
271
272        let bind_addr = require_config!(config.endpoint.socks, "endpoint.socks")?;
273        let socket = tokio::net::TcpListener::bind(bind_addr).await?;
274
275        info!("starting socks endpoint, listening on {bind_addr}");
276
277        let socks_config = Arc::new(SocksConfig::new(
278            NoAuthentication,
279            CommandHandler::new(ombrac),
280        ));
281        SocksServer::run(socket, socks_config, async {
282            let _ = shutdown_rx.recv().await;
283        })
284        .await
285        .map_err(|e| Error::Endpoint(format!("socks server failed to run: {}", e)))
286    }
287
288    #[cfg(feature = "endpoint-tun")]
289    async fn endpoint_tun_accept_loop(
290        config: Arc<ServiceConfig>,
291        ombrac: Arc<Client<QuicClient, QuicConnection>>,
292        mut shutdown_rx: broadcast::Receiver<()>,
293    ) -> Result<()> {
294        use crate::endpoint::tun::{AsyncDevice, Tun, TunConfig};
295
296        let config = require_config!(config.endpoint.tun.as_ref(), "endpoint.tun")?;
297
298        let device = match config.tun_fd {
299            Some(fd) => {
300                #[cfg(not(windows))]
301                unsafe {
302                    AsyncDevice::from_fd(fd)?
303                }
304
305                #[cfg(windows)]
306                return Err(Error::Config(
307                    "'tun_fd' option is not supported on Windows.".to_string(),
308                ));
309            }
310            None => {
311                #[cfg(not(any(target_os = "android", target_os = "ios")))]
312                {
313                    let device = {
314                        use std::str::FromStr;
315
316                        let mut builder = tun_rs::DeviceBuilder::new();
317                        builder = builder.mtu(config.tun_mtu.unwrap_or(1500));
318
319                        if let Some(ip_str) = &config.tun_ipv4 {
320                            let ip = ipnet::Ipv4Net::from_str(ip_str).map_err(|e| {
321                                Error::Config(format!("failed to parse ipv4 cidr '{ip_str}': {e}"))
322                            })?;
323                            builder = builder.ipv4(ip.addr(), ip.netmask(), None);
324                        }
325
326                        if let Some(ip_str) = &config.tun_ipv6 {
327                            let ip = ipnet::Ipv6Net::from_str(ip_str).map_err(|e| {
328                                Error::Config(format!("failed to parse ipv6 cidr '{ip_str}': {e}"))
329                            })?;
330                            builder = builder.ipv6(ip.addr(), ip.netmask());
331                        }
332
333                        builder.build_async().map_err(|e| {
334                            Error::Endpoint(format!("failed to build tun device: {e}"))
335                        })?
336                    };
337
338                    info!(
339                        "starting tun endpoint, name: {:?}, mtu: {:?}, ip: {:?}",
340                        device.name(),
341                        device.mtu(),
342                        device.addresses()
343                    );
344
345                    device
346                }
347
348                #[cfg(any(target_os = "android", target_os = "ios"))]
349                {
350                    return Err(Error::Config(
351                        "creating a new tun device is not supported on this platform. A pre-configured 'tun_fd' must be provided.".to_string()
352                    ));
353                }
354            }
355        };
356
357        let mut tun_config = TunConfig::default();
358        if let Some(value) = &config.fake_dns {
359            tun_config.fakedns_cidr = value.parse().map_err(|e| {
360                Error::Config(format!("failed to parse fake_dns cidr '{value}': {e}"))
361            })?;
362        };
363        if let Some(value) = config.disable_udp_443 {
364            tun_config.disable_udp_443 = value;
365        };
366
367        let tun = Tun::new(tun_config.into(), ombrac);
368        let shutdown_signal = async {
369            let _ = shutdown_rx.recv().await;
370        };
371
372        tun.accept_loop(device, shutdown_signal)
373            .await
374            .map_err(|e| Error::Endpoint(format!("tun device runtime error: {}", e)))
375    }
376}
377
378async fn quic_client_from_config(config: &ServiceConfig) -> io::Result<QuicClient> {
379    let server = &config.server;
380    let transport_cfg = &config.transport;
381
382    let server_name = match &transport_cfg.server_name {
383        Some(value) => value.clone(),
384        None => {
385            let pos = server.rfind(':').ok_or_else(|| {
386                io::Error::new(
387                    io::ErrorKind::InvalidInput,
388                    format!("invalid server address: {}", server),
389                )
390            })?;
391            server[..pos].to_string()
392        }
393    };
394
395    let addrs: Vec<_> = tokio::net::lookup_host(server).await?.collect();
396    let server_addr = addrs.into_iter().next().ok_or_else(|| {
397        io::Error::new(
398            io::ErrorKind::NotFound,
399            format!("failed to resolve server address: '{}'", server),
400        )
401    })?;
402
403    let mut quic_config = QuicConfig::new(server_addr, server_name);
404
405    quic_config.enable_zero_rtt = transport_cfg.zero_rtt.unwrap_or(false);
406    if let Some(protocols) = &transport_cfg.alpn_protocols {
407        quic_config.alpn_protocols = protocols.iter().map(|p| p.to_vec()).collect();
408    }
409
410    match transport_cfg.tls_mode.unwrap_or(TlsMode::Tls) {
411        TlsMode::Tls => {
412            if let Some(ca) = &transport_cfg.ca_cert {
413                quic_config.root_ca_path = Some(ca.to_path_buf());
414            }
415        }
416        TlsMode::MTls => {
417            quic_config.root_ca_path = Some(transport_cfg.ca_cert.clone().ok_or_else(|| {
418                io::Error::new(
419                    io::ErrorKind::InvalidInput,
420                    "ca cert is required for mutual tls mode",
421                )
422            })?);
423            let client_cert = transport_cfg.client_cert.clone().ok_or_else(|| {
424                io::Error::new(
425                    io::ErrorKind::InvalidInput,
426                    "client cert is required for mutual tls mode",
427                )
428            })?;
429            let client_key = transport_cfg.client_key.clone().ok_or_else(|| {
430                io::Error::new(
431                    io::ErrorKind::InvalidInput,
432                    "client key is required for mutual tls mode",
433                )
434            })?;
435            quic_config.client_cert_key_paths = Some((client_cert, client_key));
436        }
437        TlsMode::Insecure => {
438            quic_config.skip_server_verification = true;
439        }
440    }
441
442    let mut transport_config = QuicTransportConfig::default();
443    if let Some(timeout) = transport_cfg.idle_timeout {
444        transport_config.max_idle_timeout(Duration::from_millis(timeout))?;
445    }
446    if let Some(interval) = transport_cfg.keep_alive {
447        transport_config.keep_alive_period(Duration::from_millis(interval))?;
448    }
449    if let Some(max_streams) = transport_cfg.max_streams {
450        transport_config.max_open_bidirectional_streams(max_streams)?;
451    }
452    if let Some(congestion) = transport_cfg.congestion {
453        transport_config.congestion(congestion, transport_cfg.cwnd_init)?;
454    }
455    quic_config.transport_config(transport_config);
456
457    Ok(QuicClient::new(quic_config)?)
458}