1use crate::error::SqlError;
13use base64::Engine;
14use secrecy::{ExposeSecret, SecretString};
15use std::env;
16
17#[derive(Debug, Clone)]
19pub struct ProxyConfig {
20 pub url: String,
22 pub host: String,
24 pub port: u16,
26 pub username: Option<String>,
28 pub password: Option<SecretString>,
30}
31
32impl ProxyConfig {
33 pub fn parse(url: &str) -> Result<Self, SqlError> {
36 let parsed =
37 ::url::Url::parse(url).map_err(|e| SqlError::InvalidUrl(format!("proxy URL: {e}")))?;
38
39 let host = parsed
40 .host_str()
41 .ok_or_else(|| SqlError::InvalidUrl("proxy URL has no host".to_string()))?
42 .to_string();
43 let port = parsed.port().unwrap_or(8080);
44
45 let (username, password) = if let Some(info) = parsed.password() {
46 (
47 Some(parsed.username().to_string()),
48 Some(SecretString::new(info.to_string().into())),
49 )
50 } else {
51 (None, None)
52 };
53
54 Ok(ProxyConfig {
55 url: url.to_string(),
56 host,
57 port,
58 username,
59 password,
60 })
61 }
62}
63
/// Returns `true` when connections to `target_host` should bypass the proxy.
///
/// The bypass list is read from `NO_PROXY`. If that variable is unset or empty,
/// `no_proxy` is read instead. The list is comma-separated. Blank entries are
/// ignored, and each remaining entry is one of:
///
/// * `*`: bypass the proxy for every host;
/// * `.example.com` or `*.example.com`: match subdomains of `example.com`
///   only, not `example.com` itself;
/// * `example.com`: match `example.com` and all of its subdomains;
/// * an IP literal such as `10.0.0.1`, `::1` or `[::1]`.
///
/// An optional `:port` suffix on an entry is accepted but ignored. Matching is
/// ASCII case-insensitive, and brackets around an IPv6 `target_host` are ignored.
pub fn is_no_proxy(target_host: &str) -> bool {
    /// Drops `[`/`]` around an IPv6 literal, if present.
    fn strip_brackets(host: &str) -> &str {
        host.strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(host)
    }

    /// Removes an optional `:port` suffix from an entry without breaking IPv6.
    fn strip_port(entry: &str) -> &str {
        if let Some(rest) = entry.strip_prefix('[') {
            // `[v6addr]` or `[v6addr]:port`.
            return rest.split(']').next().unwrap_or(rest);
        }
        match entry.split_once(':') {
            // Exactly one colon: `host:port`.
            Some((host, port)) if !port.contains(':') => host,
            // No colon, or a bare IPv6 literal (which cannot carry a port).
            _ => entry,
        }
    }

    /// Matches one trimmed, lower-cased entry against a lower-cased host.
    fn entry_matches(entry: &str, target: &str) -> bool {
        if entry == "*" {
            return true;
        }
        let host = strip_port(entry);
        // `*.example.com` is a common spelling of `.example.com`.
        let host = match host.strip_prefix('*') {
            Some(rest) if rest.starts_with('.') => rest,
            _ => host,
        };
        if host.is_empty() {
            return false;
        }
        if host.starts_with('.') {
            target.ends_with(host)
        } else {
            target == host
                || target
                    .strip_suffix(host)
                    .is_some_and(|prefix| prefix.ends_with('.'))
        }
    }

    let Some(no_proxy) = ["NO_PROXY", "no_proxy"]
        .iter()
        .find_map(|&name| env::var(name).ok().filter(|v| !v.is_empty()))
    else {
        return false;
    };

    let target = strip_brackets(target_host).to_ascii_lowercase();
    no_proxy
        .split(',')
        .map(|entry| entry.trim().to_ascii_lowercase())
        .filter(|entry| !entry.is_empty())
        .any(|entry| entry_matches(&entry, &target))
}
108
109pub fn resolve_proxy_from_env(_target_scheme: &str) -> Option<ProxyConfig> {
117 let try_env = |name: &str| -> Option<ProxyConfig> {
118 env::var(name)
119 .ok()
120 .filter(|s| !s.is_empty())
121 .and_then(|url| ProxyConfig::parse(&url).ok())
122 };
123
124 try_env("ALL_PROXY").or_else(|| {
125 try_env("HTTPS_PROXY").or_else(|| try_env("HTTP_PROXY"))
130 })
131}
132
133pub(crate) async fn http_connect(
136 proxy: &ProxyConfig,
137 target_host: &str,
138 target_port: u16,
139) -> Result<tokio::net::TcpStream, SqlError> {
140 use tokio::io::{AsyncReadExt, AsyncWriteExt};
141
142 let mut stream = tokio::net::TcpStream::connect((proxy.host.as_str(), proxy.port))
143 .await
144 .map_err(|e| {
145 SqlError::ConnectionFailed(format!(
146 "proxy connect to {}:{}: {e}",
147 proxy.host, proxy.port
148 ))
149 })?;
150
151 let mut request = format!(
152 "CONNECT {target_host}:{target_port} HTTP/1.1\r\n\
153 Host: {target_host}:{target_port}\r\n"
154 );
155
156 if let (Some(u), Some(p)) = (&proxy.username, &proxy.password) {
157 let creds = format!("{}:{}", u, p.expose_secret());
158 let encoded = base64::prelude::BASE64_STANDARD.encode(creds);
159 request.push_str(&format!("Proxy-Authorization: Basic {encoded}\r\n"));
160 }
161
162 request.push_str("\r\n");
163
164 stream
165 .write_all(request.as_bytes())
166 .await
167 .map_err(|e| SqlError::ConnectionFailed(format!("proxy write: {e}")))?;
168
169 let mut buf = [0u8; 1024];
172 let n = stream
173 .read(&mut buf)
174 .await
175 .map_err(|e| SqlError::ConnectionFailed(format!("proxy read: {e}")))?;
176
177 let response = std::str::from_utf8(&buf[..n])
178 .map_err(|_| SqlError::ConnectionFailed("proxy returned non-UTF-8 response".to_string()))?;
179
180 let status_line = response.lines().next().unwrap_or("").trim();
181 if !status_line.starts_with("HTTP/1.1 200") && !status_line.starts_with("HTTP/1.0 200") {
182 return Err(SqlError::ConnectionFailed(format!(
183 "proxy error: {status_line}"
184 )));
185 }
186
187 Ok(stream)
188}
189
/// Connection wrapper that forwards every
/// [`crate::connection::AsyncConnection`] call unchanged to `inner`, and also
/// owns the handle of an optional background task.
///
/// Presumably used for database connections established through a proxy
/// (inferred from the name and this module; confirm at the construction site).
pub struct ProxiedConnection {
    /// The connection that every trait method delegates to.
    pub(crate) inner: Box<dyn crate::connection::AsyncConnection>,
    /// Handle of a background task tied to this connection, if one was spawned.
    // NOTE(review): presumably the task that shuttles bytes between a local
    // socket and the proxy; TODO confirm where this struct is constructed.
    // NOTE(review): no `Drop` impl is visible here, and dropping a tokio
    // `JoinHandle` detaches the task rather than cancelling it. Confirm the task
    // exits on its own once `inner` is dropped, or abort it explicitly.
    // Fields drop in declaration order, so `inner` is dropped before this handle.
    pub(crate) forwarder: Option<tokio::task::JoinHandle<()>>,
}
200
/// Pure delegation: every method forwards its arguments to `self.inner` and
/// returns the inner connection's result unchanged. The `forwarder` task is
/// never consulted here.
#[async_trait::async_trait]
impl crate::connection::AsyncConnection for ProxiedConnection {
    async fn execute(&mut self, sql: &str) -> Result<crate::ExecutionSummary, crate::SqlError> {
        self.inner.execute(sql).await
    }

    async fn query(&mut self, sql: &str) -> Result<crate::QueryResult, crate::SqlError> {
        self.inner.query(sql).await
    }

    // The returned row stream borrows `self` (elided `'_`), so no other call can
    // be made on this connection until the stream is dropped.
    async fn query_stream(
        &mut self,
        sql: &str,
    ) -> Result<(Vec<crate::ColumnInfo>, crate::BoxRowStream<'_>), crate::SqlError> {
        self.inner.query_stream(sql).await
    }

    async fn execute_multi(
        &mut self,
        sql: &str,
    ) -> Result<Vec<crate::StatementResult>, crate::SqlError> {
        self.inner.execute_multi(sql).await
    }

    async fn ping(&mut self) -> Result<(), crate::SqlError> {
        self.inner.ping().await
    }

    async fn list_tables(&mut self, schema: Option<&str>) -> Result<Vec<String>, crate::SqlError> {
        self.inner.list_tables(schema).await
    }

    async fn list_schemas(
        &mut self,
    ) -> Result<Vec<crate::connection::SchemaInfo>, crate::SqlError> {
        self.inner.list_schemas().await
    }

    async fn describe_table(
        &mut self,
        schema: Option<&str>,
        table: &str,
    ) -> Result<crate::QueryResult, crate::SqlError> {
        self.inner.describe_table(schema, table).await
    }

    async fn primary_key(
        &mut self,
        schema: Option<&str>,
        table: &str,
    ) -> Result<Vec<String>, crate::SqlError> {
        self.inner.primary_key(schema, table).await
    }

    async fn list_foreign_keys(
        &mut self,
        schema: Option<&str>,
    ) -> Result<Vec<crate::ForeignKey>, crate::SqlError> {
        self.inner.list_foreign_keys(schema).await
    }

    async fn bulk_insert_rows(
        &mut self,
        target: crate::connection::BulkInsert<'_>,
    ) -> Result<usize, crate::SqlError> {
        self.inner.bulk_insert_rows(target).await
    }
}
271
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::sync::Mutex;

    /// Serialises every test in this module that reads or writes process
    /// environment variables. The environment is process-global, while the
    /// test harness runs tests on parallel threads.
    static ENV_GUARD: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_GUARD`]. It recovers from poisoning so that one failing
    /// test does not make every later environment test fail as well.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_GUARD.lock().unwrap_or_else(|p| p.into_inner())
    }

    /// Sets `name` to `value`, or removes it when `value` is `None`.
    ///
    /// Callers must hold [`env_lock`].
    fn apply_env(name: &str, value: Option<&OsStr>) {
        // SAFETY: mutating the environment is only sound while no other thread
        // reads or writes it. Every environment-touching test in this module
        // holds `ENV_GUARD` while calling this.
        // NOTE(review): tests elsewhere in the crate that read the environment
        // are not serialised by this lock.
        unsafe {
            match value {
                Some(v) => std::env::set_var(name, v),
                None => std::env::remove_var(name),
            }
        }
    }

    /// Overrides environment variables while alive and restores their previous
    /// state on drop, including when an assertion panics. Tests therefore
    /// neither leak settings into each other nor depend on the proxy
    /// configuration of the machine running them.
    ///
    /// Create it *after* taking [`env_lock`], so that it is dropped (restoring
    /// the environment) before the lock is released.
    struct ScopedEnv {
        saved: Vec<(&'static str, Option<OsString>)>,
    }

    impl ScopedEnv {
        /// Applies `vars`: `Some(value)` sets the variable, `None` removes it.
        fn new(vars: &[(&'static str, Option<&str>)]) -> Self {
            let saved = vars
                .iter()
                .map(|&(name, _)| (name, std::env::var_os(name)))
                .collect();
            for &(name, value) in vars {
                apply_env(name, value.map(OsStr::new));
            }
            ScopedEnv { saved }
        }
    }

    impl Drop for ScopedEnv {
        fn drop(&mut self) {
            for (name, value) in self.saved.iter().rev() {
                apply_env(name, value.as_deref());
            }
        }
    }

    #[test]
    fn test_parse_simple() {
        let cfg = ProxyConfig::parse("http://proxy:8080").unwrap();
        assert_eq!(cfg.host, "proxy");
        assert_eq!(cfg.port, 8080);
        assert_eq!(cfg.username, None);
        assert!(cfg.password.is_none());
    }

    #[test]
    fn test_parse_with_auth() {
        let cfg = ProxyConfig::parse("http://user:pass@proxy:3128").unwrap();
        assert_eq!(cfg.host, "proxy");
        assert_eq!(cfg.port, 3128);
        assert_eq!(cfg.username, Some("user".to_string()));
        assert_eq!(cfg.password.as_ref().unwrap().expose_secret(), "pass");
    }

    #[test]
    fn test_parse_no_port_uses_8080() {
        let cfg = ProxyConfig::parse("http://proxy").unwrap();
        assert_eq!(cfg.port, 8080);
    }

    #[test]
    fn test_is_no_proxy_star() {
        let _guard = env_lock();
        let _env = ScopedEnv::new(&[("NO_PROXY", Some("*"))]);
        assert!(is_no_proxy("anything"));
    }

    #[test]
    fn test_is_no_proxy_exact() {
        let _guard = env_lock();
        let _env = ScopedEnv::new(&[("NO_PROXY", Some("localhost"))]);
        assert!(is_no_proxy("localhost"));
        assert!(!is_no_proxy("otherhost"));
    }

    #[test]
    fn test_is_no_proxy_suffix() {
        let _guard = env_lock();
        let _env = ScopedEnv::new(&[("NO_PROXY", Some(".example.com"))]);
        assert!(is_no_proxy("db.example.com"));
        assert!(!is_no_proxy("example.com"));
    }

    #[test]
    fn test_resolve_proxy_from_env_empty() {
        let _guard = env_lock();
        // Clear every variable the resolver consults so the result does not
        // depend on the proxy settings of the machine running the tests.
        let _env = ScopedEnv::new(&[
            ("ALL_PROXY", None),
            ("HTTPS_PROXY", None),
            ("HTTP_PROXY", None),
        ]);
        assert!(resolve_proxy_from_env("postgres").is_none());
    }

    #[test]
    fn test_resolve_proxy_from_env_skips_unset_and_empty() {
        let _guard = env_lock();
        let _env = ScopedEnv::new(&[
            ("ALL_PROXY", None),
            ("HTTPS_PROXY", Some("")),
            ("HTTP_PROXY", Some("http://fallback:3128")),
        ]);
        let cfg = resolve_proxy_from_env("postgres").expect("HTTP_PROXY should be used");
        assert_eq!(cfg.host, "fallback");
        assert_eq!(cfg.port, 3128);
    }
}