overtls/
config.rs

1use crate::error::{Error, Result};
2use serde::{Deserialize, Serialize};
3use std::{
4    net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs},
5    path::PathBuf,
6};
7
/// Fallback timeout, in seconds, for the server reachability probe performed by
/// `Config::check_correctness` when `test_timeout_secs` is not configured.
pub(crate) const TEST_TIMEOUT_SECS: u64 = 10;
/// Default maximum pool size.
///
/// NOTE(review): not referenced in this file; presumably the fallback for
/// `Client::pool_max_size` — confirm at the use site.
pub(crate) const DEFAULT_POOL_MAX_SIZE: usize = 50;
10
11#[derive(Clone, Serialize, Deserialize, Debug, Default, PartialEq, Eq)]
12pub struct Config {
13    #[serde(
14        rename(deserialize = "server_settings", serialize = "server_settings"),
15        skip_serializing_if = "Option::is_none"
16    )]
17    pub server: Option<Server>,
18    #[serde(
19        rename(deserialize = "client_settings", serialize = "client_settings"),
20        skip_serializing_if = "Option::is_none"
21    )]
22    pub client: Option<Client>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub remarks: Option<String>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub method: Option<String>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub password: Option<String>,
29    pub tunnel_path: TunnelPath,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub test_timeout_secs: Option<u64>,
32    #[serde(skip)]
33    pub is_server: bool,
34}
35
36#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
37#[serde(untagged)]
38pub enum TunnelPath {
39    Single(String),
40    Multiple(Vec<String>),
41}
42
43impl std::fmt::Display for TunnelPath {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        match self {
46            TunnelPath::Single(s) => write!(f, "{s}"),
47            TunnelPath::Multiple(v) => {
48                let mut s = String::new();
49                for (i, item) in v.iter().enumerate() {
50                    if i > 0 {
51                        s.push(',');
52                    }
53                    s.push_str(item);
54                }
55                write!(f, "{s}")
56            }
57        }
58    }
59}
60
61impl Default for TunnelPath {
62    fn default() -> Self {
63        TunnelPath::Single("/tunnel/".to_string())
64    }
65}
66
67impl TunnelPath {
68    pub fn is_empty(&self) -> bool {
69        match self {
70            TunnelPath::Single(s) => s.is_empty(),
71            TunnelPath::Multiple(v) => v.is_empty(),
72        }
73    }
74
75    pub fn standardize(&mut self) {
76        if self.is_empty() {
77            *self = TunnelPath::default();
78        }
79        match self {
80            TunnelPath::Single(s) => {
81                *s = format!("/{}/", s.trim().trim_matches('/'));
82            }
83            TunnelPath::Multiple(v) => {
84                v.iter_mut().for_each(|s| {
85                    *s = s.trim().trim_matches('/').to_string();
86                    if !s.is_empty() {
87                        *s = format!("/{s}/");
88                    }
89                });
90                v.retain(|s| !s.is_empty());
91            }
92        }
93    }
94
95    pub fn extract(&self) -> Vec<&str> {
96        match self {
97            TunnelPath::Single(s) => vec![s],
98            TunnelPath::Multiple(v) => v.iter().map(|s| s.as_str()).collect(),
99        }
100    }
101}
102
103#[derive(Clone, Serialize, Deserialize, Debug, Default, PartialEq, Eq)]
104pub struct Server {
105    #[serde(skip_serializing_if = "Option::is_none")]
106    pub disable_tls: Option<bool>,
107    #[serde(skip_serializing_if = "Option::is_none")]
108    pub manage_clients: Option<ManageClients>,
109    #[serde(skip_serializing_if = "Option::is_none")]
110    pub certfile: Option<PathBuf>,
111    #[serde(skip_serializing_if = "Option::is_none")]
112    pub keyfile: Option<PathBuf>,
113    #[serde(skip_serializing_if = "Option::is_none")]
114    pub forward_addr: Option<String>,
115    pub listen_host: String,
116    pub listen_port: u16,
117}
118
119#[derive(Clone, Serialize, Deserialize, Debug, Default, PartialEq, Eq)]
120pub struct ManageClients {
121    pub enable: Option<bool>,
122    pub webapi_url: Option<String>,
123    pub webapi_token: Option<String>,
124    pub node_id: Option<usize>,
125    #[serde(rename(deserialize = "api_update_time", serialize = "api_update_time"))]
126    pub api_update_interval_secs: Option<u64>,
127}
128
129#[derive(Clone, Serialize, Deserialize, Debug, Default, Eq)]
130pub struct Client {
131    #[serde(skip_serializing_if = "Option::is_none")]
132    pub disable_tls: Option<bool>,
133    #[serde(skip_serializing_if = "Option::is_none")]
134    pub client_id: Option<String>,
135    pub server_host: String,
136    pub server_port: u16,
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub server_domain: Option<String>,
139    #[serde(skip_serializing_if = "Option::is_none")]
140    pub cafile: Option<String>,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub dangerous_mode: Option<bool>,
143    pub listen_host: String,
144    pub listen_port: u16,
145    #[serde(skip_serializing_if = "Option::is_none")]
146    pub listen_user: Option<String>,
147    #[serde(skip_serializing_if = "Option::is_none")]
148    pub listen_password: Option<String>,
149    #[serde(skip)]
150    pub cache_dns: bool,
151    #[serde(skip)]
152    server_ip_addr: Option<SocketAddr>,
153    #[serde(skip_serializing_if = "Option::is_none")]
154    pub pool_max_size: Option<usize>,
155}
156
157impl PartialEq for Client {
158    fn eq(&self, other: &Self) -> bool {
159        let dangerous_mode = self.dangerous_mode.unwrap_or(false);
160        let other_dangerous_mode = other.dangerous_mode.unwrap_or(false);
161        let cert_matches = if !dangerous_mode && !other_dangerous_mode {
162            self.certificate_content() == other.certificate_content()
163        } else {
164            true
165        };
166        self.server_host == other.server_host
167            && self.server_port == other.server_port
168            && self.server_domain == other.server_domain
169            && cert_matches // self.cafile == other.cafile
170            && self.dangerous_mode == other.dangerous_mode
171            && self.disable_tls == other.disable_tls
172            && self.client_id == other.client_id
173    }
174}
175
176impl Client {
177    pub fn certificate_content(&self) -> Option<String> {
178        self.cafile.as_ref().and_then(|cert| Self::_certificate_content(cert))
179    }
180
181    fn _certificate_content(cert: &str) -> Option<String> {
182        const BEGIN_CERT: &str = "-----BEGIN CERTIFICATE-----";
183        let checker = |s: &str| !s.is_empty() && s.starts_with(BEGIN_CERT) && s.len() > 100;
184        if PathBuf::from(cert).exists() {
185            std::fs::read_to_string(cert).ok().filter(|s| checker(s))
186        } else if checker(cert) {
187            Some(cert.to_string())
188        } else {
189            None
190        }
191    }
192
193    pub fn export_certificate<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
194        match self.certificate_content() {
195            Some(cert) => std::fs::write(path, cert).map_err(|e| e.into()),
196            None => Err(Error::from("certificate not exists")),
197        }
198    }
199
200    pub fn server_ip_addr(&self) -> Option<SocketAddr> {
201        self.server_ip_addr
202    }
203}
204
205impl Config {
206    pub fn certificate_content(&self) -> Option<String> {
207        self.client.as_ref().and_then(|c| c.certificate_content())
208    }
209
210    pub fn export_certificate<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
211        self.client.as_ref().ok_or(Error::from("no client"))?.export_certificate(path)
212    }
213
214    pub fn manage_clients(&self) -> bool {
215        let f = |s: &Server| {
216            let f2 = |c: &ManageClients| c.enable.unwrap_or(false);
217            s.manage_clients.as_ref().map(f2).unwrap_or(false)
218        };
219        self.server.as_ref().map(f).unwrap_or(false)
220    }
221
222    pub fn webapi_url(&self) -> Option<String> {
223        let f = |s: &Server| s.manage_clients.as_ref().map(|c| c.webapi_url.clone()).unwrap_or(None);
224        self.server.as_ref().map(f).unwrap_or(None)
225    }
226
227    pub fn webapi_token(&self) -> Option<String> {
228        let f = |s: &Server| {
229            let f2 = |c: &ManageClients| c.webapi_token.clone();
230            s.manage_clients.as_ref().map(f2).unwrap_or(None)
231        };
232        self.server.as_ref().map(f).unwrap_or(None)
233    }
234
235    pub fn node_id(&self) -> Option<usize> {
236        let f = |s: &Server| s.manage_clients.as_ref().map(|c| c.node_id).unwrap_or(None);
237        self.server.as_ref().map(f).unwrap_or(None)
238    }
239
240    pub fn api_update_interval_secs(&self) -> Option<u64> {
241        let f = |s: &Server| {
242            let f2 = |c: &ManageClients| c.api_update_interval_secs;
243            s.manage_clients.as_ref().map(f2).unwrap_or(None)
244        };
245        self.server.as_ref().map(f).unwrap_or(None)
246    }
247
248    pub fn exist_server(&self) -> bool {
249        self.server.is_some()
250    }
251
252    pub fn exist_client(&self) -> bool {
253        self.client.is_some()
254    }
255
256    pub fn forward_addr(&self) -> Option<String> {
257        if self.is_server {
258            let f = |s: &Server| s.forward_addr.clone();
259            let default = Some(format!("http://{}:80", Ipv4Addr::LOCALHOST));
260            self.server.as_ref().map(f).unwrap_or(default)
261        } else {
262            None
263        }
264    }
265
266    pub fn listen_addr(&self) -> Result<SocketAddr> {
267        let unspec = std::net::IpAddr::from(Ipv4Addr::UNSPECIFIED);
268        if self.is_server {
269            let f = |s: &Server| SocketAddr::new(s.listen_host.parse().unwrap_or(unspec), s.listen_port);
270            self.server.as_ref().map(f).ok_or_else(|| "Server listen address is not set".into())
271        } else {
272            let f = |c: &Client| SocketAddr::new(c.listen_host.parse().unwrap_or(unspec), c.listen_port);
273            self.client.as_ref().map(f).ok_or_else(|| "Client listen address is not set".into())
274        }
275    }
276
277    pub fn set_listen_addr(&mut self, addr: std::net::SocketAddr) {
278        if self.is_server {
279            if let Some(s) = &mut self.server {
280                s.listen_host = addr.ip().to_string();
281                s.listen_port = addr.port();
282            }
283        } else if let Some(c) = &mut self.client {
284            c.listen_host = addr.ip().to_string();
285            c.listen_port = addr.port();
286        }
287    }
288
289    pub fn disable_tls(&self) -> bool {
290        if self.is_server {
291            if let Some(s) = &self.server {
292                return s.disable_tls.unwrap_or(false);
293            }
294        } else if let Some(c) = &self.client {
295            return c.disable_tls.unwrap_or(false);
296        }
297        false
298    }
299
300    pub fn dangerous_mode(&self) -> bool {
301        if let Some(c) = &self.client {
302            return c.dangerous_mode.unwrap_or(false);
303        }
304        false
305    }
306
307    pub fn set_dangerous_mode(&mut self, dangerous_mode: bool) {
308        if let Some(c) = &mut self.client {
309            c.dangerous_mode = Some(dangerous_mode);
310        }
311    }
312
313    pub fn cache_dns(&self) -> bool {
314        self.client.as_ref().is_some_and(|c| c.cache_dns)
315    }
316
317    pub fn set_cache_dns(&mut self, cache_dns: bool) {
318        if let Some(c) = &mut self.client {
319            c.cache_dns = cache_dns;
320        }
321    }
322
323    pub fn check_correctness(&mut self, is_server: bool) -> Result<()> {
324        self.is_server = is_server;
325        if self.is_server {
326            if self.server.is_none() {
327                return Err("Configuration needs server settings".into());
328            }
329            self.client = None;
330        } else {
331            if self.client.is_none() {
332                return Err("Configuration needs client settings".into());
333            }
334            self.server = None;
335        }
336
337        if self.tunnel_path.is_empty() {
338            self.tunnel_path = TunnelPath::default();
339        } else {
340            self.tunnel_path.standardize();
341        }
342
343        if let Some(server) = &mut self.server {
344            if server.listen_host.is_empty() {
345                server.listen_host = Ipv4Addr::UNSPECIFIED.to_string();
346            }
347            if server.listen_port == 0 {
348                server.listen_port = 443;
349            }
350        }
351        if let Some(client) = &mut self.client {
352            let server_host = client.server_host.clone();
353            let server_host = match (server_host.is_empty(), client.server_domain.clone()) {
354                (true, Some(domain)) => match domain.is_empty() {
355                    true => return Err(Error::from("We need server host in client settings")),
356                    false => domain,
357                },
358                (true, None) => return Err(Error::from("We need server host in client settings")),
359                (false, _) => server_host,
360            };
361            if client.server_host.is_empty() {
362                client.server_host.clone_from(&server_host);
363            }
364            if client.server_domain.is_none() || client.server_domain.as_ref().unwrap_or(&"".to_string()).is_empty() {
365                client.server_domain = Some(server_host.clone());
366            }
367
368            if client.server_port == 0 {
369                client.server_port = 443;
370            }
371
372            if !self.is_server {
373                let mut addr = (server_host, client.server_port).to_socket_addrs()?;
374                let addr = addr.next().ok_or("address not available")?;
375                {
376                    let timeout = std::time::Duration::from_secs(self.test_timeout_secs.unwrap_or(TEST_TIMEOUT_SECS));
377                    crate::tcp_stream::std_create(addr, Some(timeout))?;
378                }
379                if client.listen_host.is_empty() {
380                    client.listen_host = if addr.is_ipv4() {
381                        Ipv4Addr::LOCALHOST.to_string()
382                    } else {
383                        Ipv6Addr::LOCALHOST.to_string()
384                    };
385                }
386                client.server_ip_addr = Some(addr);
387            }
388        }
389        Ok(())
390    }
391
392    /// load from overtls config file
393    pub fn from_config_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
394        let f = std::fs::File::open(&path)?;
395        let config: Config = serde_json::from_reader(f)?;
396        Ok(config)
397    }
398
399    /// load from JSON string
400    pub fn from_json_str(json: &str) -> Result<Self> {
401        let config: Config = serde_json::from_str(json)?;
402        Ok(config)
403    }
404
405    /// load from `ssr://...` style url
406    pub fn from_ssr_url(url: &str) -> Result<Self> {
407        let engine = base64easy::EngineKind::UrlSafeNoPad;
408        let url = url.trim_start_matches("ssr://");
409        let url = base64easy::decode(url, engine)?;
410        let url = String::from_utf8(url)?;
411        // split string by `/?`
412        let mut parts = url.split("/?");
413
414        // split first part by `:` and collect to vector
415        let mut parts0 = parts.next().ok_or("url is invalid")?.split(':').collect::<Vec<&str>>();
416        // check if parts length is less than 6
417        if parts0.len() < 6 {
418            return Err("url is invalid".into());
419        }
420        let host = parts0.remove(0);
421        let port = parts0.remove(0);
422        let protocol = parts0.remove(0);
423        let method = parts0.remove(0); // none is default
424        let obfs = parts0.remove(0);
425        let password = String::from_utf8(base64easy::decode(parts0.remove(0), engine)?)?;
426
427        if method != "none" {
428            return Err("method is not none".into());
429        }
430        if obfs != "plain" {
431            return Err("obfs is not plain".into());
432        }
433        if protocol != "origin" {
434            return Err("protocol is not origin".into());
435        }
436        let port = port.parse::<u16>()?;
437
438        // split second part by `&` and collect to vector
439        let parts1 = parts.next().ok_or("url is invalid")?.split('&').collect::<Vec<&str>>();
440        // for each element in parts1, split by `=` and collect to a hashmap
441        let mut map = std::collections::HashMap::new();
442        for part in parts1 {
443            let mut kv = part.split('=');
444            let k = kv.next().ok_or("url is invalid")?;
445            let v = kv.next().ok_or("url is invalid")?;
446            map.insert(k, v);
447        }
448
449        let ot_enable = map.get("ot_enable").map_or("0".to_string(), |r| r.to_string());
450        if ot_enable != "1" {
451            return Err("ot_enable is not 1".into());
452        }
453        let remarks = map.get("remarks").and_then(|r| match base64easy::decode(r, engine) {
454            Ok(decoded) => String::from_utf8(decoded).ok(),
455            Err(_) => None,
456        });
457        let ot_domain = map.get("ot_domain").and_then(|r| match base64easy::decode(r, engine) {
458            Ok(decoded) => String::from_utf8(decoded).ok(),
459            Err(_) => None,
460        });
461        let ot_path = map.get("ot_path").ok_or("ot_path is not set")?;
462        let ot_path = String::from_utf8(base64easy::decode(ot_path, engine)?)?;
463
464        let ot_cert = map
465            .get("ot_cert")
466            .and_then(|r| base64easy::decode(r, engine).ok())
467            .and_then(|decoded| String::from_utf8(decoded).ok())
468            .filter(|s| !s.is_empty());
469
470        let dangerous_mode = map.get("dangerous_mode").and_then(|r| r.parse::<bool>().ok());
471
472        let client = Client {
473            server_host: host.to_string(),
474            server_port: port,
475            server_domain: ot_domain,
476            cafile: ot_cert,
477            dangerous_mode,
478            ..Client::default()
479        };
480
481        let config = Config {
482            password: Some(password),
483            method: Some(method.to_string()),
484            remarks,
485            tunnel_path: TunnelPath::Single(ot_path),
486            client: Some(client),
487            ..Config::default()
488        };
489
490        Ok(config)
491    }
492
493    pub fn generate_ssr_url(&self) -> Result<String> {
494        let client = self.client.as_ref().ok_or(Error::from("client is not set"))?;
495        let engine = base64easy::EngineKind::UrlSafeNoPad;
496        let method = self.method.as_ref().map_or("none".to_string(), |m| m.clone());
497        let password = self.password.as_ref().map_or("password".to_string(), |p| p.clone());
498        let password = base64easy::encode(password.as_bytes(), engine);
499        let remarks = self.remarks.as_ref().map_or("remarks".to_string(), |r| r.clone());
500        let remarks = base64easy::encode(remarks.as_bytes(), engine);
501        let domain = client.server_domain.as_ref().map_or("".to_string(), |d| d.clone());
502        let domain = base64easy::encode(domain.as_bytes(), engine);
503        let err = "tunnel_path is not set";
504        let tunnel_path = base64easy::encode(self.tunnel_path.extract().first().ok_or(err)?.as_bytes(), engine);
505        let host = &client.server_host;
506        let port = client.server_port;
507
508        let mut url = format!("{host}:{port}:origin:{method}:plain:{password}/?remarks={remarks}&ot_enable=1");
509        url.push_str(&format!("&ot_domain={domain}&ot_path={tunnel_path}"));
510
511        let dangerous_mode = client.dangerous_mode.unwrap_or(false);
512        if !dangerous_mode && let Some(ref ca) = client.certificate_content() {
513            let ca = base64easy::encode(ca.as_bytes(), engine);
514            url.push_str(&format!("&ot_cert={ca}"));
515        }
516
517        if let Some(dangerous_mode) = client.dangerous_mode {
518            url.push_str(&format!("&dangerous_mode={dangerous_mode}"));
519        }
520
521        Ok(format!("ssr://{}", base64easy::encode(url.as_bytes(), engine)))
522    }
523}
524
525pub(crate) fn generate_ssr_url<P>(path: P) -> Result<String>
526where
527    P: AsRef<std::path::Path>,
528{
529    let config = Config::from_config_file(path)?;
530    if config.certificate_content().is_some() {
531        log::warn!("Certificate content exists!");
532    }
533    config.generate_ssr_url()
534}
535
/// Round-trips a client configuration through `generate_ssr_url` / `from_ssr_url`.
///
/// Hermetic: `server_domain` is set up front instead of being filled in by `check_correctness`,
/// which would resolve and connect to the server.
#[test]
fn test_config() {
    let client = Client {
        server_host: "www.gov.cn".to_string(),
        server_port: 443,
        server_domain: Some("www.gov.cn".to_string()),
        listen_host: "127.0.0.1".to_string(),
        listen_port: 0,
        dangerous_mode: Some(false),
        ..Client::default()
    };

    let config = Config {
        remarks: Some("test".to_string()),
        method: Some("none".to_string()),
        password: Some("password".to_string()),
        tunnel_path: TunnelPath::Single("/tunnel/".to_string()),
        client: Some(client),
        ..Config::default()
    };

    let url = config.generate_ssr_url().unwrap();
    let decoded = Config::from_ssr_url(&url).unwrap();
    assert_eq!(config, decoded);
}

/// Same round trip, normalised by `check_correctness`. That step resolves the server and opens a
/// TCP connection, so the test is ignored by default (`cargo test -- --ignored` runs it).
#[test]
#[ignore = "requires DNS resolution and outbound TCP access to www.gov.cn:443"]
fn test_config_after_check_correctness() {
    let client = Client {
        server_host: "www.gov.cn".to_string(),
        server_port: 443,
        listen_host: "127.0.0.1".to_string(),
        listen_port: 0,
        dangerous_mode: Some(false),
        ..Client::default()
    };

    let mut config = Config {
        remarks: Some("test".to_string()),
        method: Some("none".to_string()),
        password: Some("password".to_string()),
        tunnel_path: TunnelPath::Single("/tunnel/".to_string()),
        client: Some(client),
        ..Config::default()
    };

    config.check_correctness(false).unwrap();

    let url = config.generate_ssr_url().unwrap();
    let decoded = Config::from_ssr_url(&url).unwrap();
    assert_eq!(config, decoded);
}