1#[cfg(unix)]
16use crate::protocols::l4::ext::connect_uds;
17use crate::protocols::l4::ext::{
18 connect_with as tcp_connect, set_dscp, set_recv_buf, set_tcp_fastopen_connect,
19};
20use crate::protocols::l4::socket::SocketAddr;
21use crate::protocols::l4::stream::Stream;
22use crate::protocols::{GetSocketDigest, SocketDigest};
23use crate::upstreams::peer::Peer;
24use async_trait::async_trait;
25use log::debug;
26use pingora_error::{Context, Error, ErrorType::*, OrErr, Result};
27use rand::seq::SliceRandom;
28use std::net::SocketAddr as InetSocketAddr;
29#[cfg(unix)]
30use std::os::unix::io::AsRawFd;
31#[cfg(windows)]
32use std::os::windows::io::AsRawSocket;
33
34#[async_trait]
36pub trait Connect: std::fmt::Debug {
37 async fn connect(&self, addr: &SocketAddr) -> Result<Stream>;
38}
39
40#[derive(Clone, Debug, Default)]
42pub struct BindTo {
43 pub addr: Option<InetSocketAddr>,
45 port_range: Option<(u16, u16)>,
47 fallback: bool,
49}
50
51impl BindTo {
52 pub fn set_port_range(&mut self, range: Option<(u16, u16)>) -> Result<()> {
58 if range.is_none() && self.port_range.is_none() {
59 return Ok(());
61 }
62
63 match range {
64 None | Some((0, 0)) => self.port_range = Some((0, 0)),
66 Some((low, high)) if low > 0 && low < high => {
68 self.port_range = Some((low, high));
69 }
70 _ => return Error::e_explain(SocketError, "invalid port range: {range}"),
71 }
72 Ok(())
73 }
74
75 pub fn set_fallback(&mut self, fallback: bool) {
77 self.fallback = fallback
78 }
79
80 pub fn port_range(&self) -> Option<(u16, u16)> {
82 self.port_range
83 }
84
85 pub fn will_fallback(&self) -> bool {
87 self.fallback && self.port_range.is_some()
88 }
89}
90
91pub(crate) async fn connect<P>(peer: &P, bind_to: Option<BindTo>) -> Result<Stream>
93where
94 P: Peer + Send + Sync,
95{
96 if peer.get_proxy().is_some() {
97 return proxy_connect(peer)
98 .await
99 .err_context(|| format!("Fail to establish CONNECT proxy: {}", peer));
100 }
101 let peer_addr = peer.address();
102 let mut stream: Stream =
103 if let Some(custom_l4) = peer.get_peer_options().and_then(|o| o.custom_l4.as_ref()) {
104 custom_l4.connect(peer_addr).await?
105 } else {
106 match peer_addr {
107 SocketAddr::Inet(addr) => {
108 let connect_future = tcp_connect(addr, bind_to.as_ref(), |socket| {
109 #[cfg(unix)]
110 let raw = socket.as_raw_fd();
111 #[cfg(windows)]
112 let raw = socket.as_raw_socket();
113
114 if peer.tcp_fast_open() {
115 set_tcp_fastopen_connect(raw)?;
116 }
117 if let Some(recv_buf) = peer.tcp_recv_buf() {
118 debug!("Setting recv buf size");
119 set_recv_buf(raw, recv_buf)?;
120 }
121 if let Some(dscp) = peer.dscp() {
122 debug!("Setting dscp");
123 set_dscp(raw, dscp)?;
124 }
125
126 if let Some(tweak_hook) = peer
127 .get_peer_options()
128 .and_then(|o| o.upstream_tcp_sock_tweak_hook.clone())
129 {
130 tweak_hook(socket)?;
131 }
132
133 Ok(())
134 });
135 let conn_res = match peer.connection_timeout() {
136 Some(t) => pingora_timeout::timeout(t, connect_future)
137 .await
138 .explain_err(ConnectTimedout, |_| {
139 format!("timeout {t:?} connecting to server {peer}")
140 })?,
141 None => connect_future.await,
142 };
143 match conn_res {
144 Ok(socket) => {
145 debug!("connected to new server: {}", peer.address());
146 Ok(socket.into())
147 }
148 Err(e) => {
149 let c = format!("Fail to connect to {peer}");
150 match e.etype() {
151 SocketError | BindError => Error::e_because(InternalError, c, e),
152 _ => Err(e.more_context(c)),
153 }
154 }
155 }
156 }
157 #[cfg(unix)]
158 SocketAddr::Unix(addr) => {
159 let connect_future = connect_uds(
160 addr.as_pathname()
161 .expect("non-pathname unix sockets not supported as peer"),
162 );
163 let conn_res = match peer.connection_timeout() {
164 Some(t) => pingora_timeout::timeout(t, connect_future)
165 .await
166 .explain_err(ConnectTimedout, |_| {
167 format!("timeout {t:?} connecting to server {peer}")
168 })?,
169 None => connect_future.await,
170 };
171 match conn_res {
172 Ok(socket) => {
173 debug!("connected to new server: {}", peer.address());
174 Ok(socket.into())
175 }
176 Err(e) => {
177 let c = format!("Fail to connect to {peer}");
178 match e.etype() {
179 SocketError | BindError => Error::e_because(InternalError, c, e),
180 _ => Err(e.more_context(c)),
181 }
182 }
183 }
184 }
185 }?
186 };
187
188 let tracer = peer.get_tracer();
189 if let Some(t) = tracer {
190 t.0.on_connected();
191 stream.tracer = Some(t);
192 }
193
194 if let Some(ka) = peer.tcp_keepalive() {
196 stream.set_keepalive(ka)?;
197 }
198 stream.set_nodelay()?;
199
200 #[cfg(unix)]
201 let digest = SocketDigest::from_raw_fd(stream.as_raw_fd());
202 #[cfg(windows)]
203 let digest = SocketDigest::from_raw_socket(stream.as_raw_socket());
204 digest
205 .peer_addr
206 .set(Some(peer_addr.clone()))
207 .expect("newly created OnceCell must be empty");
208 stream.set_socket_digest(digest);
209
210 Ok(stream)
211}
212
213pub(crate) fn bind_to_random<P: Peer>(
214 peer: &P,
215 v4_list: &[InetSocketAddr],
216 v6_list: &[InetSocketAddr],
217) -> Option<BindTo> {
218 fn bind_to_ips(ips: &[InetSocketAddr]) -> Option<InetSocketAddr> {
220 match ips.len() {
221 0 => None,
222 1 => Some(ips[0]),
223 _ => {
224 ips.choose(&mut rand::thread_rng()).copied()
226 }
227 }
228 }
229
230 let mut bind_to = peer.get_peer_options().and_then(|o| o.bind_to.clone());
231 if bind_to.as_ref().map(|b| b.addr).is_some() {
232 return bind_to;
234 }
235
236 let addr = match peer.address() {
237 SocketAddr::Inet(sockaddr) => match sockaddr {
238 InetSocketAddr::V4(_) => bind_to_ips(v4_list),
239 InetSocketAddr::V6(_) => bind_to_ips(v6_list),
240 },
241 #[cfg(unix)]
242 SocketAddr::Unix(_) => None,
243 };
244
245 if addr.is_some() {
246 if let Some(bind_to) = bind_to.as_mut() {
247 bind_to.addr = addr;
248 } else {
249 bind_to = Some(BindTo {
250 addr,
251 ..Default::default()
252 });
253 }
254 }
255 bind_to
256}
257
258use crate::protocols::raw_connect;
259
260#[cfg(unix)]
261async fn proxy_connect<P: Peer>(peer: &P) -> Result<Stream> {
262 let proxy = peer.get_proxy().unwrap();
264 let options = peer.get_peer_options().unwrap();
265
266 let mut headers = proxy
268 .headers
269 .iter()
270 .chain(options.extra_proxy_headers.iter());
271
272 let stream: Box<Stream> = Box::new(
274 connect_uds(&proxy.next_hop)
275 .await
276 .or_err_with(ConnectError, || {
277 format!("CONNECT proxy connect() error to {:?}", &proxy.next_hop)
278 })?
279 .into(),
280 );
281
282 let req_header = raw_connect::generate_connect_header(&proxy.host, proxy.port, &mut headers)?;
283 let fut = raw_connect::connect(stream, &req_header);
284 let (mut stream, digest) = match peer.connection_timeout() {
285 Some(t) => pingora_timeout::timeout(t, fut)
286 .await
287 .explain_err(ConnectTimedout, |_| "establishing CONNECT proxy")?,
288 None => fut.await,
289 }
290 .map_err(|mut e| {
291 e.retry.decide_reuse(false);
293 e
294 })?;
295 debug!("CONNECT proxy established: {:?}", proxy);
296 stream.set_proxy_digest(digest);
297 let stream = stream.into_any().downcast::<Stream>().unwrap(); Ok(*stream)
299}
300
/// CONNECT proxies are reached over Unix domain sockets, which this connector
/// does not support on Windows.
///
/// # Panics
/// Always panics: peers with a proxy configured cannot be used on Windows.
#[cfg(windows)]
async fn proxy_connect<P: Peer>(_peer: &P) -> Result<Stream> {
    panic!("peer proxy not supported on windows")
}
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::upstreams::peer::{BasicPeer, HttpPeer, Proxy};
    use pingora_error::ErrorType;
    use std::collections::BTreeMap;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};
    use tokio::io::AsyncWriteExt;
    #[cfg(unix)]
    use tokio::net::UnixListener;
    use tokio::time::sleep;

    /// Retry `connect(peer, None)` with exponential backoff while it fails with
    /// `ConnectRefused`. The delay starts at 5ms and doubles each time, for at
    /// most 10s. Any other outcome ends the wait; the result itself is discarded.
    #[cfg(unix)]
    async fn wait_for_peer<P>(peer: &P)
    where
        P: Peer + Send + Sync,
    {
        use ErrorType as E;
        let start = Instant::now();
        let mut res = connect(peer, None).await;
        let mut delay = Duration::from_millis(5);
        let max_delay = Duration::from_secs(10);

        while start.elapsed() < max_delay {
            match &res {
                Err(e) if e.etype == E::ConnectRefused => {}
                _ => break,
            }
            sleep(delay).await;
            delay *= 2;
            res = connect(peer, None).await;
        }
    }

    #[tokio::test]
    async fn test_conn_error_refused() {
        // Assumes nothing listens on local port 79.
        let peer = BasicPeer::new("127.0.0.1:79");
        let new_session = connect(&peer, None).await;
        assert_eq!(new_session.unwrap_err().etype(), &ConnectRefused)
    }

    // Ignored by default; presumably environment-dependent ([::3] must be unroutable).
    #[ignore]
    #[tokio::test]
    async fn test_conn_error_no_route() {
        let peer = BasicPeer::new("[::3]:79");
        let new_session = connect(&peer, None).await;
        assert_eq!(new_session.unwrap_err().etype(), &ConnectNoRoute)
    }

    #[tokio::test]
    async fn test_conn_error_addr_not_avail() {
        let peer = HttpPeer::new("127.0.0.1:121".to_string(), false, "".to_string());
        // 192.0.2.0/24 is TEST-NET-1 (RFC 5737): not a local address, so binding fails.
        let addr = "192.0.2.2:0".parse().ok();
        let bind_to = BindTo {
            addr,
            ..Default::default()
        };
        let new_session = connect(&peer, Some(bind_to)).await;
        assert_eq!(new_session.unwrap_err().etype(), &InternalError)
    }

    #[tokio::test]
    async fn test_conn_error_other() {
        // 240.0.0.0/4 is reserved; depending on the OS this errors or times out.
        let peer = HttpPeer::new("240.0.0.1:80".to_string(), false, "".to_string());
        let addr = "127.0.0.1:0".parse().ok();
        let bind_to = BindTo {
            addr,
            ..Default::default()
        };
        let new_session = connect(&peer, Some(bind_to)).await;
        let error = new_session.unwrap_err();
        assert!(error.etype() == &ConnectError || error.etype() == &ConnectTimedout)
    }

    #[tokio::test]
    async fn test_conn_timeout() {
        let mut peer = BasicPeer::new("192.0.2.1:79");
        peer.options.connection_timeout = Some(std::time::Duration::from_millis(1));
        let new_session = connect(&peer, None).await;
        assert_eq!(new_session.unwrap_err().etype(), &ConnectTimedout)
    }

    #[tokio::test]
    async fn test_tweak_hook() {
        const INIT_FLAG: bool = false;

        let flag = Arc::new(AtomicBool::new(INIT_FLAG));

        // Requires outbound network access.
        let mut peer = BasicPeer::new("1.1.1.1:80");

        let move_flag = Arc::clone(&flag);

        // The hook flips the flag, proving it ran exactly once.
        peer.options.upstream_tcp_sock_tweak_hook = Some(Arc::new(move |_| {
            move_flag.fetch_xor(true, Ordering::SeqCst);
            Ok(())
        }));

        connect(&peer, None).await.unwrap();

        assert_eq!(!INIT_FLAG, flag.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_custom_connect() {
        #[derive(Debug)]
        struct MyL4;
        #[async_trait]
        impl Connect for MyL4 {
            async fn connect(&self, _addr: &SocketAddr) -> Result<Stream> {
                tokio::net::TcpStream::connect("1.1.1.1:80")
                    .await
                    .map(|s| s.into())
                    .or_fail()
            }
        }
        // Port 79 would fail; success shows the custom connector (port 80) was used.
        let mut peer = BasicPeer::new("1.1.1.1:79");
        peer.options.custom_l4 = Some(std::sync::Arc::new(MyL4 {}));

        let new_session = connect(&peer, None).await;

        assert!(new_session.is_ok());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_connect_proxy_fail() {
        let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());
        let mut path = PathBuf::new();
        // Assumes no socket exists at this path.
        path.push("/tmp/123");
        peer.proxy = Some(Proxy {
            next_hop: path.into(),
            host: "1.1.1.1".into(),
            port: 80,
            headers: BTreeMap::new(),
        });
        let new_session = connect(&peer, None).await;
        let e = new_session.unwrap_err();
        assert_eq!(e.etype(), &ConnectError);
        assert!(!e.retry());
    }

    #[cfg(unix)]
    const MOCK_UDS_PATH: &str = "/tmp/test_unix_connect_proxy.sock";

    /// Accept one connection on `MOCK_UDS_PATH` and answer the CONNECT with `200 OK`.
    #[cfg(unix)]
    async fn mock_connect_server() {
        let _ = std::fs::remove_file(MOCK_UDS_PATH);
        let listener = UnixListener::bind(MOCK_UDS_PATH).unwrap();
        if let Ok((mut stream, _addr)) = listener.accept().await {
            stream.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap();
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
        let _ = std::fs::remove_file(MOCK_UDS_PATH);
    }

    // Must be unix-only: it relies on the unix-only mock server and socket path.
    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_connect_proxy_work() {
        tokio::spawn(async {
            mock_connect_server().await;
        });
        // Give the mock server time to bind its socket.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());
        let mut path = PathBuf::new();
        path.push(MOCK_UDS_PATH);
        peer.proxy = Some(Proxy {
            next_hop: path.into(),
            host: "1.1.1.1".into(),
            port: 80,
            headers: BTreeMap::new(),
        });
        let new_session = connect(&peer, None).await;
        assert!(new_session.is_ok());
    }

    #[cfg(unix)]
    const MOCK_BAD_UDS_PATH: &str = "/tmp/test_unix_bad_connect_proxy.sock";

    /// Accept one connection on `MOCK_BAD_UDS_PATH` and shut it down without replying.
    #[cfg(unix)]
    async fn mock_connect_bad_server() {
        let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH);
        let listener = UnixListener::bind(MOCK_BAD_UDS_PATH).unwrap();
        if let Ok((mut stream, _addr)) = listener.accept().await {
            stream.shutdown().await.unwrap();
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
        let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH);
    }

    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_connect_proxy_conn_closed() {
        tokio::spawn(async {
            mock_connect_bad_server().await;
        });
        // Give the mock server time to bind its socket.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());
        let mut path = PathBuf::new();
        path.push(MOCK_BAD_UDS_PATH);
        peer.proxy = Some(Proxy {
            next_hop: path.into(),
            host: "1.1.1.1".into(),
            port: 80,
            headers: BTreeMap::new(),
        });
        let new_session = connect(&peer, None).await;
        let err = new_session.unwrap_err();
        assert_eq!(err.etype(), &ConnectionClosed);
        assert!(!err.retry());
    }

    #[cfg(target_os = "linux")]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_bind_to_port_range_on_connect() {
        fn get_ip_local_port_range() -> (u16, u16) {
            let path = "/proc/sys/net/ipv4/ip_local_port_range";
            let file = std::fs::read_to_string(path).unwrap();
            let mut parts = file.split_whitespace();
            (
                parts.next().unwrap().parse().unwrap(),
                parts.next().unwrap().parse().unwrap(),
            )
        }

        /// Start a local TCP server that answers one connection; returns its port.
        async fn mock_inet_connect_server() -> u16 {
            use tokio::net::TcpListener;
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();

            let port = listener.local_addr().unwrap().port();

            tokio::spawn(async move {
                if let Ok((mut stream, _addr)) = listener.accept().await {
                    stream.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap();
                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                }
            });

            port
        }

        /// Whether the session's local port lies in `[lower, upper]`.
        fn in_port_range(session: Stream, lower: u16, upper: u16) -> bool {
            let digest = session.get_socket_digest();
            let local_addr = digest
                .as_ref()
                .and_then(|s| s.local_addr())
                .unwrap()
                .as_inet()
                .unwrap();

            local_addr.port() >= lower && local_addr.port() <= upper
        }

        let port = mock_inet_connect_server().await;

        // Use a two-port range at the bottom of the kernel's ephemeral range.
        let (low, _) = get_ip_local_port_range();
        let high = low + 1;

        let peer = HttpPeer::new(format!("127.0.0.1:{port}"), false, "".to_string());
        let mut bind_to = BindTo {
            addr: "127.0.0.1:0".parse().ok(),
            ..Default::default()
        };

        wait_for_peer(&peer).await;

        bind_to.set_port_range(Some((low, high))).unwrap();

        let mut success_count = 0;
        let mut address_unavailable_count = 0;

        // More attempts than ports in the range, so some must hit AddrNotAvailable.
        for _ in 0..10 {
            match connect(&peer, Some(bind_to.clone())).await {
                Ok(session) => {
                    assert!(in_port_range(session, low, high));
                    success_count += 1;
                }
                Err(e) if format!("{e:?}").contains("AddrNotAvailable") => {
                    address_unavailable_count += 1;
                }
                Err(e) => {
                    panic!("Unexpected error {e:?}")
                }
            }
        }

        assert!(address_unavailable_count > 0);
        assert!(success_count >= (high - low));

        // With fallback enabled, an exhausted range yields a port outside it.
        bind_to.set_fallback(true);
        let session4 = connect(&peer, Some(bind_to.clone())).await.unwrap();
        assert!(!in_port_range(session4, low, high));

        // A port range without an explicit bind address is still honored.
        let low = low + 2;
        let high = low + 1;
        let mut bind_to = BindTo::default();
        bind_to.set_port_range(Some((low, high))).unwrap();
        let session5 = connect(&peer, Some(bind_to.clone())).await.unwrap();
        assert!(in_port_range(session5, low, high));
    }

    #[test]
    fn test_bind_to_port_ranges() {
        let addr = "127.0.0.1:0".parse().ok();
        let mut bind_to = BindTo {
            addr,
            ..Default::default()
        };

        // `None` on a never-set range is a no-op.
        bind_to.set_port_range(None).unwrap();
        assert!(bind_to.port_range.is_none());

        bind_to.set_port_range(Some((0, 0))).unwrap();
        assert_eq!(bind_to.port_range, Some((0, 0)));

        // `None` after a range was set resets to (0, 0) rather than clearing.
        bind_to.set_port_range(None).unwrap();
        assert_eq!(bind_to.port_range, Some((0, 0)));

        assert!(bind_to.set_port_range(Some((2000, 1000))).is_err());

        bind_to.set_port_range(Some((1000, 2000))).unwrap();
        assert_eq!(bind_to.port_range, Some((1000, 2000)));
    }
}