ombrac_client/
service.rs

1use std::io;
2use std::marker::PhantomData;
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::broadcast;
7use tokio::task::JoinHandle;
8
9use ombrac_macros::{error, info, warn};
10#[cfg(feature = "transport-quic")]
11use ombrac_transport::quic::{
12    Connection as QuicConnection, TransportConfig as QuicTransportConfig,
13    client::{Client as QuicClient, Config as QuicConfig},
14};
15use ombrac_transport::{Connection, Initiator};
16
17use crate::client::Client;
18use crate::config::ServiceConfig;
19#[cfg(feature = "transport-quic")]
20use crate::config::TlsMode;
21
22pub type Result<T> = std::result::Result<T, Error>;
23
24#[derive(thiserror::Error, Debug)]
25pub enum Error {
26    #[error("Configuration error: {0}")]
27    Config(String),
28
29    #[error("{0}")]
30    Io(#[from] io::Error),
31
32    #[error("Transport layer error: {0}")]
33    Transport(String),
34
35    #[error("Endpoint failed: {0}")]
36    Endpoint(String),
37}
38
/// Turns an optional configuration value into a `Result`.
///
/// Evaluates to `Ok(value)` when `$config_opt` is `Some(value)`. Otherwise it
/// evaluates to `Err(Error::Config(..))` naming `$field_name`. Intended to be
/// followed by `?`.
macro_rules! require_config {
    ($config_opt:expr, $field_name:expr) => {
        match $config_opt {
            Some(value) => Ok(value),
            None => Err(Error::Config(format!(
                "'{}' is required but was not provided",
                $field_name
            ))),
        }
    };
}
49
50pub trait ServiceBuilder {
51    type Initiator: Initiator<Connection = Self::Connection>;
52    type Connection: Connection;
53
54    fn build(
55        config: &Arc<ServiceConfig>,
56    ) -> impl Future<Output = Result<Arc<Client<Self::Initiator, Self::Connection>>>> + Send;
57}
58
/// [`ServiceBuilder`] that reaches the server over a QUIC transport.
#[cfg(feature = "transport-quic")]
pub struct QuicServiceBuilder;

#[cfg(feature = "transport-quic")]
impl ServiceBuilder for QuicServiceBuilder {
    type Connection = QuicConnection;
    type Initiator = QuicClient;

    /// Creates the QUIC transport from `config` and wraps it in a shared
    /// [`Client`].
    ///
    /// The BLAKE3 digest of `config.secret` is passed as the client secret.
    async fn build(
        config: &Arc<ServiceConfig>,
    ) -> Result<Arc<Client<Self::Initiator, Self::Connection>>> {
        let transport = quic_client_from_config(config).await?;
        // Hash the configured secret so the client always receives a
        // fixed-size key.
        let digest = blake3::hash(config.secret.as_bytes());
        let client = Client::new(transport, *digest.as_bytes(), None).await?;
        Ok(Arc::new(client))
    }
}
76
77pub struct Service<T, C>
78where
79    T: Initiator<Connection = C>,
80    C: Connection,
81{
82    client: Arc<Client<T, C>>,
83    handles: Vec<JoinHandle<()>>,
84    shutdown_tx: broadcast::Sender<()>,
85    _transport: PhantomData<T>,
86    _connection: PhantomData<C>,
87}
88
89impl<T, C> Service<T, C>
90where
91    T: Initiator<Connection = C>,
92    C: Connection,
93{
94    pub async fn build<Builder>(config: Arc<ServiceConfig>) -> Result<Self>
95    where
96        Builder: ServiceBuilder<Initiator = T, Connection = C>,
97    {
98        let mut handles = Vec::new();
99        let client = Builder::build(&config).await?;
100        let (shutdown_tx, _) = broadcast::channel(1);
101
102        // Start HTTP endpoint if configured
103        #[cfg(feature = "endpoint-http")]
104        if config.endpoint.http.is_some() {
105            handles.push(Self::spawn_endpoint(
106                "HTTP",
107                Self::endpoint_http_accept_loop(
108                    config.clone(),
109                    client.clone(),
110                    shutdown_tx.subscribe(),
111                ),
112            ));
113        }
114
115        // Start SOCKS endpoint if configured
116        #[cfg(feature = "endpoint-socks")]
117        if config.endpoint.socks.is_some() {
118            handles.push(Self::spawn_endpoint(
119                "SOCKS",
120                Self::endpoint_socks_accept_loop(
121                    config.clone(),
122                    client.clone(),
123                    shutdown_tx.subscribe(),
124                ),
125            ));
126        }
127
128        // Start TUN endpoint if configured
129        #[cfg(feature = "endpoint-tun")]
130        if let Some(tun_config) = &config.endpoint.tun
131            && (tun_config.tun_ipv4.is_some()
132                || tun_config.tun_ipv6.is_some()
133                || tun_config.tun_fd.is_some())
134        {
135            handles.push(Self::spawn_endpoint(
136                "TUN",
137                Self::endpoint_tun_accept_loop(
138                    config.clone(),
139                    client.clone(),
140                    shutdown_tx.subscribe(),
141                ),
142            ));
143        }
144
145        if handles.is_empty() {
146            return Err(Error::Config(
147                "No endpoints were configured or enabled. The service has nothing to do."
148                    .to_string(),
149            ));
150        }
151
152        Ok(Service {
153            client,
154            handles,
155            shutdown_tx,
156            _transport: PhantomData,
157            _connection: PhantomData,
158        })
159    }
160
161    pub async fn rebind(&self) -> io::Result<()> {
162        self.client.rebind().await
163    }
164
165    pub async fn shutdown(self) {
166        let _ = self.shutdown_tx.send(());
167
168        for handle in self.handles {
169            if let Err(_err) = handle.await {
170                error!("A task failed to shut down cleanly: {:?}", _err);
171            }
172        }
173        warn!("Service shutdown complete");
174    }
175
176    fn spawn_endpoint(
177        _name: &'static str,
178        task: impl Future<Output = Result<()>> + Send + 'static,
179    ) -> JoinHandle<()> {
180        tokio::spawn(async move {
181            if let Err(_e) = task.await {
182                error!("The {_name} endpoint shut down due to an error: {_e}");
183            }
184        })
185    }
186
187    #[cfg(feature = "endpoint-http")]
188    async fn endpoint_http_accept_loop(
189        config: Arc<ServiceConfig>,
190        ombrac: Arc<Client<T, C>>,
191        mut shutdown_rx: broadcast::Receiver<()>,
192    ) -> Result<()> {
193        use crate::endpoint::http::Server as HttpServer;
194        let bind_addr = require_config!(config.endpoint.http, "endpoint.http")?;
195        let socket = tokio::net::TcpListener::bind(bind_addr).await?;
196
197        info!("Starting HTTP/HTTPS endpoint, listening on {bind_addr}");
198
199        HttpServer::new(ombrac)
200            .accept_loop(socket, async {
201                let _ = shutdown_rx.recv().await;
202            })
203            .await
204            .map_err(|e| Error::Endpoint(format!("HTTP server failed to run: {}", e)))
205    }
206
207    #[cfg(feature = "endpoint-socks")]
208    async fn endpoint_socks_accept_loop(
209        config: Arc<ServiceConfig>,
210        ombrac: Arc<Client<T, C>>,
211        mut shutdown_rx: broadcast::Receiver<()>,
212    ) -> Result<()> {
213        use crate::endpoint::socks::CommandHandler;
214        use socks_lib::v5::server::auth::NoAuthentication;
215        use socks_lib::v5::server::{Config as SocksConfig, Server as SocksServer};
216
217        let bind_addr = require_config!(config.endpoint.socks, "endpoint.socks")?;
218        let socket = tokio::net::TcpListener::bind(bind_addr).await?;
219
220        info!("Starting SOCKS5 endpoint, listening on {bind_addr}");
221
222        let socks_config = Arc::new(SocksConfig::new(
223            NoAuthentication,
224            CommandHandler::new(ombrac),
225        ));
226        SocksServer::run(socket, socks_config, async {
227            let _ = shutdown_rx.recv().await;
228        })
229        .await
230        .map_err(|e| Error::Endpoint(format!("SOCKS server failed to run: {}", e)))
231    }
232
233    #[cfg(feature = "endpoint-tun")]
234    async fn endpoint_tun_accept_loop(
235        config: Arc<ServiceConfig>,
236        ombrac: Arc<Client<T, C>>,
237        mut shutdown_rx: broadcast::Receiver<()>,
238    ) -> Result<()> {
239        use crate::endpoint::tun::{AsyncDevice, Tun, TunConfig};
240
241        let config = require_config!(config.endpoint.tun.as_ref(), "endpoint.tun")?;
242
243        let device = match config.tun_fd {
244            Some(fd) => {
245                #[cfg(not(windows))]
246                unsafe {
247                    AsyncDevice::from_fd(fd)?
248                }
249
250                #[cfg(windows)]
251                return Err(Error::Config(
252                    "'tun_fd' option is not supported on Windows.".to_string(),
253                ));
254            }
255            None => {
256                #[cfg(not(any(target_os = "android", target_os = "ios")))]
257                {
258                    let device = {
259                        use std::str::FromStr;
260
261                        let mut builder = tun_rs::DeviceBuilder::new();
262                        builder = builder.mtu(config.tun_mtu.unwrap_or(1500));
263
264                        if let Some(ip_str) = &config.tun_ipv4 {
265                            let ip = ipnet::Ipv4Net::from_str(ip_str).map_err(|e| {
266                                Error::Config(format!("Failed to parse IPv4 CIDR '{ip_str}': {e}"))
267                            })?;
268                            builder = builder.ipv4(ip.addr(), ip.netmask(), None);
269                        }
270
271                        if let Some(ip_str) = &config.tun_ipv6 {
272                            let ip = ipnet::Ipv6Net::from_str(ip_str).map_err(|e| {
273                                Error::Config(format!("Failed to parse IPv6 CIDR '{ip_str}': {e}"))
274                            })?;
275                            builder = builder.ipv6(ip.addr(), ip.netmask());
276                        }
277
278                        builder.build_async().map_err(|e| {
279                            Error::Endpoint(format!("Failed to build TUN device: {e}"))
280                        })?
281                    };
282
283                    info!(
284                        "Starting TUN endpoint, Name: {:?}, MTU: {:?}, IP: {:?}",
285                        device.name(),
286                        device.mtu(),
287                        device.addresses()
288                    );
289
290                    device
291                }
292
293                #[cfg(any(target_os = "android", target_os = "ios"))]
294                {
295                    return Err(Error::Config(
296                        "Creating a new TUN device is not supported on this platform. A pre-configured 'tun_fd' must be provided.".to_string()
297                    ));
298                }
299            }
300        };
301
302        let mut tun_config = TunConfig::default();
303        if let Some(value) = &config.fake_dns {
304            tun_config.fakedns_cidr = value.parse().map_err(|e| {
305                Error::Config(format!("Failed to parse fake_dns CIDR '{value}': {e}"))
306            })?;
307        };
308
309        let tun = Tun::new(tun_config.into(), ombrac);
310        let shutdown_signal = async {
311            let _ = shutdown_rx.recv().await;
312        };
313
314        tun.accept_loop(device, shutdown_signal)
315            .await
316            .map_err(|e| Error::Endpoint(format!("TUN device runtime error: {}", e)))
317    }
318}
319
/// Derives the TLS server name from a `host:port` server address.
///
/// Everything from the last `:` onward (the port) is dropped. Square
/// brackets around an IPv6 literal (`[::1]:443` → `::1`) are removed, so the
/// result is a bare host name or IP address usable for certificate
/// verification.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidInput`] when `server` contains no `:` or
/// its host part is empty.
#[cfg_attr(not(feature = "transport-quic"), allow(dead_code))]
fn server_name_from_address(server: &str) -> io::Result<String> {
    let invalid = || {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("Invalid server address: {}", server),
        )
    };

    let (host, _port) = server.rsplit_once(':').ok_or_else(invalid)?;
    // `[v6-literal]` must be unwrapped: the brackets are URL/socket-address
    // syntax, not part of the name a certificate is issued for.
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);

    if host.is_empty() {
        return Err(invalid());
    }
    Ok(host.to_string())
}

/// Builds a QUIC transport client from the service configuration.
///
/// Steps:
/// 1. Resolves `config.server` and uses the first returned address.
/// 2. Uses `transport.server_name` as the TLS name, or derives it from the
///    host part of `config.server`.
/// 3. Applies the TLS mode (default: `Tls`).
/// 4. Applies the optional transport settings. `idle_timeout` and
///    `keep_alive` are interpreted as milliseconds.
///
/// # Errors
///
/// Returns an [`io::Error`] when:
/// - the server address is malformed or does not resolve;
/// - mTLS is selected without a CA cert, client cert or client key;
/// - a transport setting or the client construction is rejected.
#[cfg(feature = "transport-quic")]
async fn quic_client_from_config(config: &ServiceConfig) -> io::Result<QuicClient> {
    let server = &config.server;
    let transport_cfg = &config.transport;

    let server_name = match &transport_cfg.server_name {
        Some(value) => value.clone(),
        None => server_name_from_address(server)?,
    };

    // Only the first resolved address is used; no need to collect them all.
    let server_addr = tokio::net::lookup_host(server)
        .await?
        .next()
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                format!("Failed to resolve server address: '{}'", server),
            )
        })?;

    let mut quic_config = QuicConfig::new(server_addr, server_name);

    quic_config.enable_zero_rtt = transport_cfg.zero_rtt.unwrap_or(false);
    if let Some(protocols) = &transport_cfg.alpn_protocols {
        quic_config.alpn_protocols = protocols.iter().map(|p| p.to_vec()).collect();
    }

    match transport_cfg.tls_mode.unwrap_or(TlsMode::Tls) {
        TlsMode::Tls => {
            // A custom CA is optional in plain TLS mode.
            if let Some(ca) = &transport_cfg.ca_cert {
                quic_config.root_ca_path = Some(ca.to_path_buf());
            }
        }
        TlsMode::MTls => {
            quic_config.root_ca_path = Some(transport_cfg.ca_cert.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "CA cert is required for mTLS mode",
                )
            })?);
            let client_cert = transport_cfg.client_cert.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Client cert is required for mTLS mode",
                )
            })?;
            let client_key = transport_cfg.client_key.clone().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "Client key is required for mTLS mode",
                )
            })?;
            quic_config.client_cert_key_paths = Some((client_cert, client_key));
        }
        TlsMode::Insecure => {
            quic_config.skip_server_verification = true;
        }
    }

    let mut transport_config = QuicTransportConfig::default();
    if let Some(timeout) = transport_cfg.idle_timeout {
        transport_config.max_idle_timeout(Duration::from_millis(timeout))?;
    }
    if let Some(interval) = transport_cfg.keep_alive {
        transport_config.keep_alive_period(Duration::from_millis(interval))?;
    }
    if let Some(max_streams) = transport_cfg.max_streams {
        transport_config.max_open_bidirectional_streams(max_streams)?;
    }
    if let Some(congestion) = transport_cfg.congestion {
        transport_config.congestion(congestion, transport_cfg.cwnd_init)?;
    }
    quic_config.transport_config(transport_config);

    Ok(QuicClient::new(quic_config)?)
}