ombrac_client/
service.rs

1use std::io;
2use std::marker::PhantomData;
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::broadcast;
7use tokio::task::JoinHandle;
8
9use ombrac_macros::{error, info, warn};
10#[cfg(feature = "transport-quic")]
11use ombrac_transport::quic::{
12    Connection as QuicConnection, TransportConfig as QuicTransportConfig,
13    client::{Client as QuicClient, Config as QuicConfig},
14};
15use ombrac_transport::{Connection, Initiator};
16
17use crate::client::Client;
18use crate::config::ServiceConfig;
19#[cfg(feature = "transport-quic")]
20use crate::config::TlsMode;
21
/// Convenience alias for results whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
23
/// Errors reported while building or running the client service.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The configuration is missing a required value, contains a value that
    /// could not be parsed, or requests something unsupported on this platform.
    #[error("Configuration error: {0}")]
    Config(String),

    /// An I/O error; `#[from]` lets `?` convert an [`io::Error`] into this variant.
    #[error("{0}")]
    Io(#[from] io::Error),

    /// A transport layer failure, carried as a message.
    // NOTE(review): not constructed anywhere in the code visible here.
    #[error("Transport layer error: {0}")]
    Transport(String),

    /// A local endpoint (HTTP, SOCKS or TUN) could not be set up or stopped with an error.
    #[error("Endpoint failed: {0}")]
    Endpoint(String),
}
38
/// Unwraps an optional configuration value.
///
/// Expands to an expression of type `Result<T, Error>`: `Ok(value)` when the
/// option holds a value, otherwise `Err(Error::Config(..))` whose message names
/// the missing field. Callers normally follow the invocation with `?`.
macro_rules! require_config {
    ($config_opt:expr, $field_name:expr) => {
        match $config_opt {
            Some(value) => Ok(value),
            // The message is only formatted on the failure path.
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $field_name
            ))),
        }
    };
}
49
/// Strategy for constructing the [`Client`] used by a [`Service`] on top of a
/// particular transport.
pub trait ServiceBuilder {
    /// Transport initiator type; its connections must be of type `Self::Connection`.
    type Initiator: Initiator<Connection = Self::Connection>;
    /// Connection type produced by the initiator.
    type Connection: Connection;

    /// Builds a shared client from the service configuration.
    ///
    /// The returned future is required to be `Send`.
    fn build(
        config: &Arc<ServiceConfig>,
    ) -> impl Future<Output = Result<Arc<Client<Self::Initiator, Self::Connection>>>> + Send;
}
58
/// [`ServiceBuilder`] implementation backed by the QUIC transport.
///
/// Only compiled when the `transport-quic` feature is enabled.
#[cfg(feature = "transport-quic")]
pub struct QuicServiceBuilder;
61
#[cfg(feature = "transport-quic")]
impl ServiceBuilder for QuicServiceBuilder {
    type Initiator = QuicClient;
    type Connection = QuicConnection;

    /// Creates the QUIC transport from `config`, then wraps it in a shared [`Client`].
    ///
    /// The client is keyed with the BLAKE3 digest of `config.secret`. Any
    /// failure while creating the transport or the client is propagated.
    async fn build(
        config: &Arc<ServiceConfig>,
    ) -> Result<Arc<Client<Self::Initiator, Self::Connection>>> {
        let quic = quic_client_from_config(config).await?;

        // Hash the configured secret and hand the raw digest bytes to the client.
        let digest = blake3::hash(config.secret.as_bytes());
        let key = *digest.as_bytes();

        let built =
            Client::new(quic, key, config.handshake_option.clone().map(Into::into)).await?;
        Ok(Arc::new(built))
    }
}
83
/// A running client service: the shared [`Client`] plus the endpoint tasks
/// spawned for it.
pub struct Service<T, C>
where
    T: Initiator<Connection = C>,
    C: Connection,
{
    /// Client shared with every endpoint task.
    client: Arc<Client<T, C>>,
    /// Join handles of the spawned endpoint tasks; awaited on shutdown.
    handles: Vec<JoinHandle<()>>,
    /// Sender side of the broadcast channel that endpoint tasks subscribe to
    /// for the shutdown signal.
    shutdown_tx: broadcast::Sender<()>,
    /// Zero-sized marker for the initiator type parameter.
    _transport: PhantomData<T>,
    /// Zero-sized marker for the connection type parameter.
    _connection: PhantomData<C>,
}
95
96impl<T, C> Service<T, C>
97where
98    T: Initiator<Connection = C>,
99    C: Connection,
100{
101    pub async fn build<Builder>(config: Arc<ServiceConfig>) -> Result<Self>
102    where
103        Builder: ServiceBuilder<Initiator = T, Connection = C>,
104    {
105        let mut handles = Vec::new();
106        let client = Builder::build(&config).await?;
107        let (shutdown_tx, _) = broadcast::channel(1);
108
109        // Start HTTP endpoint if configured
110        #[cfg(feature = "endpoint-http")]
111        if config.endpoint.http.is_some() {
112            handles.push(Self::spawn_endpoint(
113                "HTTP",
114                Self::endpoint_http_accept_loop(
115                    config.clone(),
116                    client.clone(),
117                    shutdown_tx.subscribe(),
118                ),
119            ));
120        }
121
122        // Start SOCKS endpoint if configured
123        #[cfg(feature = "endpoint-socks")]
124        if config.endpoint.socks.is_some() {
125            handles.push(Self::spawn_endpoint(
126                "SOCKS",
127                Self::endpoint_socks_accept_loop(
128                    config.clone(),
129                    client.clone(),
130                    shutdown_tx.subscribe(),
131                ),
132            ));
133        }
134
135        // Start TUN endpoint if configured
136        #[cfg(feature = "endpoint-tun")]
137        if let Some(tun_config) = &config.endpoint.tun
138            && (tun_config.tun_ipv4.is_some()
139                || tun_config.tun_ipv6.is_some()
140                || tun_config.tun_fd.is_some())
141        {
142            handles.push(Self::spawn_endpoint(
143                "TUN",
144                Self::endpoint_tun_accept_loop(
145                    config.clone(),
146                    client.clone(),
147                    shutdown_tx.subscribe(),
148                ),
149            ));
150        }
151
152        if handles.is_empty() {
153            return Err(Error::Config(
154                "No endpoints were configured or enabled. The service has nothing to do."
155                    .to_string(),
156            ));
157        }
158
159        Ok(Service {
160            client,
161            handles,
162            shutdown_tx,
163            _transport: PhantomData,
164            _connection: PhantomData,
165        })
166    }
167
    /// Calls `rebind` on the underlying client and returns its result unchanged.
    pub async fn rebind(&self) -> io::Result<()> {
        self.client.rebind().await
    }
171
172    pub async fn shutdown(self) {
173        let _ = self.shutdown_tx.send(());
174
175        for handle in self.handles {
176            if let Err(_err) = handle.await {
177                error!("A task failed to shut down cleanly: {:?}", _err);
178            }
179        }
180        warn!("Service shutdown complete");
181    }
182
183    fn spawn_endpoint(
184        _name: &'static str,
185        task: impl Future<Output = Result<()>> + Send + 'static,
186    ) -> JoinHandle<()> {
187        tokio::spawn(async move {
188            if let Err(_e) = task.await {
189                error!("The {_name} endpoint shut down due to an error: {_e}");
190            }
191        })
192    }
193
194    #[cfg(feature = "endpoint-http")]
195    async fn endpoint_http_accept_loop(
196        config: Arc<ServiceConfig>,
197        ombrac: Arc<Client<T, C>>,
198        mut shutdown_rx: broadcast::Receiver<()>,
199    ) -> Result<()> {
200        use crate::endpoint::http::Server as HttpServer;
201        let bind_addr = require_config!(config.endpoint.http, "endpoint.http")?;
202        let socket = tokio::net::TcpListener::bind(bind_addr).await?;
203
204        info!("Starting HTTP/HTTPS endpoint, listening on {bind_addr}");
205
206        HttpServer::new(ombrac)
207            .accept_loop(socket, async {
208                let _ = shutdown_rx.recv().await;
209            })
210            .await
211            .map_err(|e| Error::Endpoint(format!("HTTP server failed to run: {}", e)))
212    }
213
214    #[cfg(feature = "endpoint-socks")]
215    async fn endpoint_socks_accept_loop(
216        config: Arc<ServiceConfig>,
217        ombrac: Arc<Client<T, C>>,
218        mut shutdown_rx: broadcast::Receiver<()>,
219    ) -> Result<()> {
220        use crate::endpoint::socks::CommandHandler;
221        use socks_lib::v5::server::auth::NoAuthentication;
222        use socks_lib::v5::server::{Config as SocksConfig, Server as SocksServer};
223
224        let bind_addr = require_config!(config.endpoint.socks, "endpoint.socks")?;
225        let socket = tokio::net::TcpListener::bind(bind_addr).await?;
226
227        info!("Starting SOCKS5 endpoint, listening on {bind_addr}");
228
229        let socks_config = Arc::new(SocksConfig::new(
230            NoAuthentication,
231            CommandHandler::new(ombrac),
232        ));
233        SocksServer::run(socket, socks_config, async {
234            let _ = shutdown_rx.recv().await;
235        })
236        .await
237        .map_err(|e| Error::Endpoint(format!("SOCKS server failed to run: {}", e)))
238    }
239
    /// Obtains a TUN device and runs the TUN accept loop on it.
    ///
    /// The device is either adopted from the configured `tun_fd` or, when no
    /// descriptor is given, created from the configured MTU and IPv4/IPv6 CIDRs.
    /// The accept loop receives a future that resolves once `shutdown_rx.recv()`
    /// returns.
    ///
    /// # Errors
    ///
    /// * [`Error::Config`] when `endpoint.tun` is absent, when `tun_fd` is set on
    ///   Windows, when no `tun_fd` is set on Android/iOS, or when an address or
    ///   the `fake_dns` CIDR fails to parse.
    /// * [`Error::Endpoint`] when the device cannot be built or the accept loop
    ///   fails.
    /// * Whatever error `AsyncDevice::from_fd` reports, converted with `?`.
    #[cfg(feature = "endpoint-tun")]
    async fn endpoint_tun_accept_loop(
        config: Arc<ServiceConfig>,
        ombrac: Arc<Client<T, C>>,
        mut shutdown_rx: broadcast::Receiver<()>,
    ) -> Result<()> {
        use crate::endpoint::tun::{AsyncDevice, Tun, TunConfig};

        // From here on `config` refers to the TUN section only.
        let config = require_config!(config.endpoint.tun.as_ref(), "endpoint.tun")?;

        let device = match config.tun_fd {
            Some(fd) => {
                // SAFETY: NOTE(review) — presumably `from_fd` requires `fd` to be a
                // valid, open TUN descriptor that nothing else closes; this is not
                // checked here and relies on the configuration. TODO confirm
                // against the `AsyncDevice::from_fd` contract.
                #[cfg(not(windows))]
                unsafe {
                    AsyncDevice::from_fd(fd)?
                }

                #[cfg(windows)]
                return Err(Error::Config(
                    "'tun_fd' option is not supported on Windows.".to_string(),
                ));
            }
            None => {
                // Desktop-class platforms: create and configure a new device.
                #[cfg(not(any(target_os = "android", target_os = "ios")))]
                {
                    let device = {
                        use std::str::FromStr;

                        let mut builder = tun_rs::DeviceBuilder::new();
                        // MTU falls back to 1500 when not configured.
                        builder = builder.mtu(config.tun_mtu.unwrap_or(1500));

                        if let Some(ip_str) = &config.tun_ipv4 {
                            let ip = ipnet::Ipv4Net::from_str(ip_str).map_err(|e| {
                                Error::Config(format!("Failed to parse IPv4 CIDR '{ip_str}': {e}"))
                            })?;
                            builder = builder.ipv4(ip.addr(), ip.netmask(), None);
                        }

                        if let Some(ip_str) = &config.tun_ipv6 {
                            let ip = ipnet::Ipv6Net::from_str(ip_str).map_err(|e| {
                                Error::Config(format!("Failed to parse IPv6 CIDR '{ip_str}': {e}"))
                            })?;
                            builder = builder.ipv6(ip.addr(), ip.netmask());
                        }

                        builder.build_async().map_err(|e| {
                            Error::Endpoint(format!("Failed to build TUN device: {e}"))
                        })?
                    };

                    info!(
                        "Starting TUN endpoint, Name: {:?}, MTU: {:?}, IP: {:?}",
                        device.name(),
                        device.mtu(),
                        device.addresses()
                    );

                    device
                }

                // Android/iOS: a device cannot be created here, so a descriptor
                // must have been supplied through `tun_fd`.
                #[cfg(any(target_os = "android", target_os = "ios"))]
                {
                    return Err(Error::Config(
                        "Creating a new TUN device is not supported on this platform. A pre-configured 'tun_fd' must be provided.".to_string()
                    ));
                }
            }
        };

        // Start from the default TUN settings and override the fake-DNS range
        // only when one is configured.
        let mut tun_config = TunConfig::default();
        if let Some(value) = &config.fake_dns {
            tun_config.fakedns_cidr = value.parse().map_err(|e| {
                Error::Config(format!("Failed to parse fake_dns CIDR '{value}': {e}"))
            })?;
        };

        let tun = Tun::new(tun_config.into(), ombrac);
        // Resolves on the first message or when the channel reports an error.
        let shutdown_signal = async {
            let _ = shutdown_rx.recv().await;
        };

        tun.accept_loop(device, shutdown_signal)
            .await
            .map_err(|e| Error::Endpoint(format!("TUN device runtime error: {}", e)))
    }
325}
326
/// Builds a QUIC transport client from the service configuration.
///
/// Steps, in order:
/// 1. Pick the TLS server name: the explicit `transport.server_name` if set,
///    otherwise the host part of `config.server` (see
///    [`server_name_from_address`]).
/// 2. Resolve `config.server` and use the first address returned.
/// 3. Apply zero-RTT, ALPN and the TLS mode (`Tls` when unset).
/// 4. Apply the optional transport tuning (idle timeout and keep-alive in
///    milliseconds, stream limit, congestion control).
///
/// # Errors
///
/// * `InvalidInput` when the server address has no `:port` suffix, or when
///   mTLS mode lacks the CA certificate, client certificate or client key.
/// * `NotFound` when resolution yields no address.
/// * Any error from DNS resolution, the transport setters or `QuicClient::new`.
#[cfg(feature = "transport-quic")]
async fn quic_client_from_config(config: &ServiceConfig) -> io::Result<QuicClient> {
    let server = &config.server;
    let transport_cfg = &config.transport;

    let server_name = match &transport_cfg.server_name {
        Some(value) => value.clone(),
        None => server_name_from_address(server)?,
    };

    // Only the first resolved address is used, so take it straight from the
    // iterator instead of collecting every result first.
    let server_addr = tokio::net::lookup_host(server)
        .await?
        .next()
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                format!("Failed to resolve server address: '{}'", server),
            )
        })?;

    let mut quic_config = QuicConfig::new(server_addr, server_name);

    quic_config.enable_zero_rtt = transport_cfg.zero_rtt.unwrap_or(false);
    if let Some(protocols) = &transport_cfg.alpn_protocols {
        quic_config.alpn_protocols = protocols.iter().map(|p| p.to_vec()).collect();
    }

    match transport_cfg.tls_mode.unwrap_or(TlsMode::Tls) {
        TlsMode::Tls => {
            // A custom CA is optional in plain TLS mode.
            if let Some(ca) = &transport_cfg.ca_cert {
                quic_config.root_ca_path = Some(ca.to_path_buf());
            }
        }
        TlsMode::MTls => {
            // Mutual TLS needs all three of CA, client certificate and client key.
            quic_config.root_ca_path = Some(transport_cfg.ca_cert.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "CA cert is required for mTLS mode",
                )
            })?);
            let client_cert = transport_cfg.client_cert.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Client cert is required for mTLS mode",
                )
            })?;
            let client_key = transport_cfg.client_key.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Client key is required for mTLS mode",
                )
            })?;
            quic_config.client_cert_key_paths = Some((client_cert, client_key));
        }
        TlsMode::Insecure => {
            quic_config.skip_server_verification = true;
        }
    }

    let mut transport_config = QuicTransportConfig::default();
    if let Some(timeout) = transport_cfg.idle_timeout {
        transport_config.max_idle_timeout(Duration::from_millis(timeout))?;
    }
    if let Some(interval) = transport_cfg.keep_alive {
        transport_config.keep_alive_period(Duration::from_millis(interval))?;
    }
    if let Some(max_streams) = transport_cfg.max_streams {
        transport_config.max_open_bidirectional_streams(max_streams)?;
    }
    if let Some(congestion) = transport_cfg.congestion {
        transport_config.congestion(congestion, transport_cfg.cwnd_init)?;
    }
    quic_config.transport_config(transport_config);

    Ok(QuicClient::new(quic_config)?)
}

/// Extracts the host part of a `host:port` address for use as the TLS server name.
///
/// The split happens at the last `:`. A bracketed IPv6 literal such as
/// `[::1]:443` is returned without its brackets (`::1`), because the bracketed
/// form is neither a valid DNS name nor a parseable IP address.
///
/// # Errors
///
/// Returns `InvalidInput` when `server` contains no `:`.
// Not gated on `transport-quic` so this pure helper can be unit-tested without
// the QUIC transport; the `allow` silences the unused warning in that build.
#[cfg_attr(not(feature = "transport-quic"), allow(dead_code))]
fn server_name_from_address(server: &str) -> io::Result<String> {
    let (host, _port) = server.rsplit_once(':').ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("Invalid server address: {}", server),
        )
    })?;

    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_string())
}