Skip to main content

rat_quickdns/
upstream_handler.rs

1//! 上游服务器处理器模块
2//! 
3//! 基于handler模式的上游服务器管理,避免强制类型转换,提供最优性能
4
5use crate::{
6    transport::{Transport, TransportConfig, HttpsConfig, TlsConfig},
7    utils::{parse_server_address, parse_url_components, get_user_agent},
8    Result, DnsError,
9    dns_info, dns_debug,
10};
11use std::{
12    collections::HashMap,
13    time::Duration,
14};
15use async_trait::async_trait;
16
/// Transport kind used to reach an upstream DNS server.
///
/// The enum is a plain tag without payload, so it is `Copy`: it can be passed
/// by value, returned from handlers and used as a map key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UpstreamType {
    /// Plain DNS over UDP.
    Udp,
    /// Plain DNS over TCP.
    Tcp,
    /// DNS over TLS.
    DoT,
    /// DNS over HTTPS.
    DoH,
}
29
30/// 上游服务器配置(字符串存储)
31#[derive(Debug, Clone)]
32pub struct UpstreamSpec {
33    /// 服务器名称
34    pub name: String,
35    /// 传输类型
36    pub transport_type: UpstreamType,
37    /// 服务器地址(统一字段:IP/域名/URL)
38    pub server: String,
39    /// 预解析的IP地址(可选,避免运行时解析)
40    pub resolved_ip: Option<String>,
41    /// 权重
42    pub weight: u32,
43    /// 期望区域
44    pub region: Option<String>,
45}
46
47/// 上游处理器trait
48#[async_trait]
49pub trait UpstreamHandler: Send + Sync + std::fmt::Debug {
50    /// 处理器类型
51    fn handler_type(&self) -> UpstreamType;
52    
53    /// 从规格创建传输实例
54    async fn create_transport(&self, spec: &UpstreamSpec) -> Result<Box<dyn Transport>>;
55    
56    /// 验证规格是否有效
57    fn validate_spec(&self, spec: &UpstreamSpec) -> Result<()>;
58    
59    /// 获取默认端口
60    fn default_port(&self) -> u16;
61}
62
63/// UDP处理器
64#[derive(Debug, Default)]
65pub struct UdpHandler;
66
67#[async_trait]
68impl UpstreamHandler for UdpHandler {
69    fn handler_type(&self) -> UpstreamType {
70        UpstreamType::Udp
71    }
72    
73    async fn create_transport(&self, spec: &UpstreamSpec) -> Result<Box<dyn Transport>> {
74        let (server, port) = parse_server_address(&spec.server, self.default_port())?;
75        
76        // 优先使用预解析的IP,避免运行时DNS查询
77        let actual_server = spec.resolved_ip.as_ref().unwrap_or(&server);
78        
79        let config = TransportConfig {
80            server: actual_server.clone(),
81            port,
82            timeout: Duration::from_secs(5),
83            tcp_fast_open: false,
84            tcp_nodelay: true,
85            pool_size: 10,
86        };
87        
88        Ok(Box::new(crate::transport::UdpTransport::new(config)))
89    }
90    
91    fn validate_spec(&self, spec: &UpstreamSpec) -> Result<()> {
92        if spec.server.is_empty() {
93            return Err(DnsError::InvalidConfig("UDP server cannot be empty".to_string()));
94        }
95        Ok(())
96    }
97    
98    fn default_port(&self) -> u16 {
99        53
100    }
101}
102
103/// TCP处理器
104#[derive(Debug, Default)]
105pub struct TcpHandler;
106
107#[async_trait]
108impl UpstreamHandler for TcpHandler {
109    fn handler_type(&self) -> UpstreamType {
110        UpstreamType::Tcp
111    }
112    
113    async fn create_transport(&self, spec: &UpstreamSpec) -> Result<Box<dyn Transport>> {
114        let (server, port) = parse_server_address(&spec.server, self.default_port())?;
115        
116        // 优先使用预解析的IP,避免运行时DNS查询
117        let actual_server = spec.resolved_ip.as_ref().unwrap_or(&server);
118        
119        let config = TransportConfig {
120            server: actual_server.clone(),
121            port,
122            timeout: Duration::from_secs(5),
123            tcp_fast_open: false,
124            tcp_nodelay: true,
125            pool_size: 10,
126        };
127        
128        Ok(Box::new(crate::transport::TcpTransport::new(config)))
129    }
130    
131    fn validate_spec(&self, spec: &UpstreamSpec) -> Result<()> {
132        if spec.server.is_empty() {
133            return Err(DnsError::InvalidConfig("TCP server cannot be empty".to_string()));
134        }
135        Ok(())
136    }
137    
138    fn default_port(&self) -> u16 {
139        53
140    }
141}
142
143/// DoT处理器
144#[derive(Debug, Default)]
145pub struct DoTHandler;
146
147#[async_trait]
148impl UpstreamHandler for DoTHandler {
149    fn handler_type(&self) -> UpstreamType {
150        UpstreamType::DoT
151    }
152    
153    async fn create_transport(&self, spec: &UpstreamSpec) -> Result<Box<dyn Transport>> {
154        let (server, port) = parse_server_address(&spec.server, self.default_port())?;
155        
156        // 对于DoT,连接地址优先使用预解析IP,但SNI必须使用原始域名
157        let connection_server = spec.resolved_ip.as_ref().unwrap_or(&server);
158        let sni_name = server.clone(); // SNI使用原始域名,确保证书验证正确
159            
160        let config = TlsConfig {
161            base: TransportConfig {
162                server: connection_server.clone(),
163                port,
164                timeout: Duration::from_secs(10),
165                tcp_fast_open: false,
166                tcp_nodelay: true,
167                pool_size: 5,
168            },
169            server_name: sni_name,
170            verify_cert: true,
171        };
172        
173        Ok(Box::new(crate::transport::TlsTransport::new(config)?))
174    }
175    
176    fn validate_spec(&self, spec: &UpstreamSpec) -> Result<()> {
177        if spec.server.is_empty() {
178            return Err(DnsError::InvalidConfig("DoT server cannot be empty".to_string()));
179        }
180        Ok(())
181    }
182    
183    fn default_port(&self) -> u16 {
184        853
185    }
186}
187
188/// DoH处理器
189#[derive(Debug, Default)]
190pub struct DoHHandler;
191
192#[async_trait]
193impl UpstreamHandler for DoHHandler {
194    fn handler_type(&self) -> UpstreamType {
195        UpstreamType::DoH
196    }
197    
198    async fn create_transport(&self, spec: &UpstreamSpec) -> Result<Box<dyn Transport>> {
199        // 对于DoH,server字段应该是完整的HTTPS URL
200        let url = &spec.server;
201        
202        // 从URL中提取主机名和端口
203        let (hostname, port) = parse_url_components(url)?;
204        
205        // 连接地址优先使用预解析IP,但SNI必须使用原始域名
206        let connection_server = spec.resolved_ip.as_ref().unwrap_or(&hostname);
207        
208        let config = HttpsConfig {
209            base: TransportConfig {
210                server: connection_server.clone(),
211                port,
212                timeout: Duration::from_secs(10),
213                tcp_fast_open: false,
214                tcp_nodelay: true,
215                pool_size: 5,
216            },
217            url: url.clone(),
218            method: crate::transport::HttpMethod::POST,
219            user_agent: get_user_agent(),
220        };
221        
222        Ok(Box::new(crate::transport::HttpsTransport::new(config)?))
223    }
224    
225    fn validate_spec(&self, spec: &UpstreamSpec) -> Result<()> {
226        if spec.server.is_empty() {
227            return Err(DnsError::InvalidConfig("DoH server cannot be empty".to_string()));
228        }
229        
230        // 验证URL格式
231        if !spec.server.starts_with("https://") {
232            return Err(DnsError::InvalidConfig("DoH URL must use HTTPS".to_string()));
233        }
234        
235        Ok(())
236    }
237    
238    fn default_port(&self) -> u16 {
239        443
240    }
241}
242
243/// 上游管理器
244#[derive(Debug)]
245pub struct UpstreamManager {
246    handlers: HashMap<UpstreamType, Box<dyn UpstreamHandler>>,
247    specs: Vec<UpstreamSpec>,
248}
249
250impl Clone for UpstreamManager {
251    fn clone(&self) -> Self {
252        let mut new_manager = Self::default();
253        new_manager.specs = self.specs.clone();
254        new_manager
255    }
256}
257
258impl Default for UpstreamManager {
259    fn default() -> Self {
260        let mut handlers: HashMap<UpstreamType, Box<dyn UpstreamHandler>> = HashMap::new();
261        handlers.insert(UpstreamType::Udp, Box::new(UdpHandler));
262        handlers.insert(UpstreamType::Tcp, Box::new(TcpHandler));
263        handlers.insert(UpstreamType::DoT, Box::new(DoTHandler));
264        handlers.insert(UpstreamType::DoH, Box::new(DoHHandler));
265        
266        Self {
267            handlers,
268            specs: Vec::new(),
269        }
270    }
271}
272
273impl UpstreamManager {
274    /// 创建新的管理器
275    pub fn new() -> Self {
276        Self::default()
277    }
278    
279    /// 添加上游服务器
280    pub fn add_upstream(&mut self, spec: UpstreamSpec) -> Result<()> {
281        dns_info!("Adding upstream server: {} ({:?}) -> {}", spec.name, spec.transport_type, spec.server);
282        
283        // 验证规格
284        if let Some(handler) = self.handlers.get(&spec.transport_type) {
285            handler.validate_spec(&spec)?;
286        } else {
287            return Err(DnsError::InvalidConfig(
288                format!("Unsupported transport type: {:?}", spec.transport_type)
289            ));
290        }
291        
292        self.specs.push(spec);
293        dns_debug!("Successfully added upstream server, total count: {}", self.specs.len());
294        Ok(())
295    }
296    
297    /// 创建传输实例
298    pub async fn create_transport(&self, spec: &UpstreamSpec) -> Result<Box<dyn Transport>> {
299        if let Some(handler) = self.handlers.get(&spec.transport_type) {
300            handler.create_transport(spec).await
301        } else {
302            Err(DnsError::InvalidConfig(
303                format!("No handler for transport type: {:?}", spec.transport_type)
304            ))
305        }
306    }
307    
308    /// 获取所有上游规格
309    pub fn get_specs(&self) -> &[UpstreamSpec] {
310        &self.specs
311    }
312    
313    /// 按类型筛选上游
314    pub fn filter_by_type(&self, transport_type: UpstreamType) -> Vec<&UpstreamSpec> {
315        self.specs.iter()
316            .filter(|spec| spec.transport_type == transport_type)
317            .collect()
318    }
319}
320
321// 解析函数已移至 crate::utils 模块,避免代码重复
322
323/// 构建器辅助函数
324impl UpstreamSpec {
325    /// 创建UDP上游配置
326    pub fn udp(name: String, server: String) -> Self {
327        Self {
328            name,
329            transport_type: UpstreamType::Udp,
330            server,
331            resolved_ip: None,
332            weight: 1,
333            region: None,
334        }
335    }
336    
337    /// 创建TCP上游配置
338    pub fn tcp(name: String, server: String) -> Self {
339        Self {
340            name,
341            transport_type: UpstreamType::Tcp,
342            server,
343            resolved_ip: None,
344            weight: 1,
345            region: None,
346        }
347    }
348    
349    /// 创建DoT上游配置
350    pub fn dot(name: String, server: String) -> Self {
351        Self {
352            name,
353            transport_type: UpstreamType::DoT,
354            server,
355            resolved_ip: None,
356            weight: 1,
357            region: None,
358        }
359    }
360    
361    /// 创建DoH上游配置
362    pub fn doh(name: String, url: String) -> Self {
363        Self {
364            name,
365            transport_type: UpstreamType::DoH,
366            server: url,
367            resolved_ip: None,
368            weight: 1,
369            region: None,
370        }
371    }
372    
373    /// 设置预解析的IP地址
374    pub fn with_resolved_ip(mut self, ip: String) -> Self {
375        self.resolved_ip = Some(ip);
376        self
377    }
378    
379    /// 设置权重
380    pub fn with_weight(mut self, weight: u32) -> Self {
381        self.weight = weight;
382        self
383    }
384    
385    /// 设置区域
386    pub fn with_region(mut self, region: String) -> Self {
387        self.region = Some(region);
388        self
389    }
390}