1use socket2::{SockRef, TcpKeepalive};
2use tokio::io::AsyncReadExt;
3use tokio::io::AsyncWriteExt;
4
5use kube::{api::Api, Client};
6
7use crate::{
8 config::ForwardConfig,
9 config::PodSelector,
10 error::{PortForwardError, Result},
11 metrics::ForwardMetrics,
12 util::ServiceInfo,
13};
14use anyhow;
15use chrono::DateTime;
16use chrono::Utc;
17use k8s_openapi::api::core::v1::Pod;
18use std::sync::Arc;
19use tokio::sync::{broadcast, RwLock};
20
21use tracing::{debug, error, info, warn};
22
23use futures::TryStreamExt;
24
25use std::net::SocketAddr;
26use tokio::{
27 io::{AsyncRead, AsyncWrite},
28 net::TcpListener,
29};
30use tokio_stream::wrappers::TcpListenerStream;
31
32#[derive(Debug)]
33pub struct HealthCheck {
34 pub last_check: Arc<RwLock<Option<DateTime<Utc>>>>,
35 pub failures: Arc<RwLock<u32>>,
36}
37
38impl Default for HealthCheck {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl HealthCheck {
45 pub fn new() -> Self {
46 Self {
47 last_check: Arc::new(RwLock::new(None)),
48 failures: Arc::new(RwLock::new(0)),
49 }
50 }
51
52 pub async fn check_connection(&self, local_port: u16, protocol: &str) -> bool {
53 match protocol.to_uppercase().as_str() {
54 "UDP" => {
55 use tokio::net::UdpSocket;
57
58 match UdpSocket::bind("127.0.0.1:0").await {
60 Ok(test_socket) => {
61 if test_socket
63 .connect(format!("127.0.0.1:{}", local_port))
64 .await
65 .is_ok()
66 {
67 let test_packet = vec![
69 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ];
76
77 match test_socket.send(&test_packet).await {
78 Ok(_) => {
79 match UdpSocket::bind(format!("127.0.0.1:{}", local_port)).await
81 {
82 Ok(_) => {
83 let mut failures = self.failures.write().await;
85 *failures += 1;
86 false
87 }
88 Err(_) => {
89 *self.failures.write().await = 0;
91 *self.last_check.write().await = Some(Utc::now());
92 true
93 }
94 }
95 }
96 Err(_) => {
97 let mut failures = self.failures.write().await;
98 *failures += 1;
99 false
100 }
101 }
102 } else {
103 let mut failures = self.failures.write().await;
104 *failures += 1;
105 false
106 }
107 }
108 Err(_) => {
109 let mut failures = self.failures.write().await;
110 *failures += 1;
111 false
112 }
113 }
114 }
115 _ => {
116 use tokio::net::TcpStream;
118 match TcpStream::connect(format!("127.0.0.1:{}", local_port)).await {
119 Ok(_) => {
120 *self.failures.write().await = 0;
121 *self.last_check.write().await = Some(Utc::now());
122 true
123 }
124 Err(_) => {
125 let mut failures = self.failures.write().await;
126 *failures += 1;
127 false
128 }
129 }
130 }
131 }
132 }
133}
134
/// Lifecycle state of a [`PortForward`].
///
/// Derives `Eq` as well as `PartialEq`: every payload is `Eq`, so the state can be used
/// wherever total equality is required (e.g. as a map key or in `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ForwardState {
    /// Newly created, or currently (re)establishing the forward.
    Starting,
    /// The local socket is bound and the forwarding task is running.
    Connected,
    /// The forward was lost or has been shut down.
    Disconnected,
    /// The forward gave up; the string describes why.
    Failed(String),
    /// A stop request is in progress.
    Stopping,
}
144
145#[derive(Debug, Clone)]
146pub struct PortForward {
147 pub config: ForwardConfig,
148 pub service_info: ServiceInfo,
149 pub state: Arc<RwLock<ForwardState>>,
150 pub shutdown: broadcast::Sender<()>,
151 pub metrics: ForwardMetrics,
152}
153
154impl PortForward {
155 pub fn new(config: ForwardConfig, service_info: ServiceInfo) -> Self {
156 let (shutdown_tx, _) = broadcast::channel(1);
157 Self {
158 metrics: ForwardMetrics::new(config.name.clone()),
159 config,
160 service_info,
161 state: Arc::new(RwLock::new(ForwardState::Starting)),
162 shutdown: shutdown_tx,
163 }
164 }
165
166 pub async fn start(&self, client: Client) -> Result<()> {
167 let mut retry_count = 0;
168 let mut shutdown_rx = self.shutdown.subscribe();
169
170 loop {
171 if retry_count >= self.config.options.max_retries
172 && !self.config.options.persistent_connection
173 {
174 let err_msg = "Max retry attempts reached".to_string();
175 *self.state.write().await = ForwardState::Failed(err_msg.clone());
176 return Err(PortForwardError::ConnectionError(err_msg));
177 }
178
179 self.metrics.record_connection_attempt();
180
181 match self.establish_forward(&client).await {
182 Ok(()) => {
183 *self.state.write().await = ForwardState::Connected;
184 self.metrics.record_connection_success();
185 self.metrics.set_connection_status(true);
186 info!("Port-forward established for {}", self.config.name);
187
188 tokio::select! {
190 _ = shutdown_rx.recv() => {
191 info!("Received shutdown signal for {}", self.config.name);
192 break;
193 }
194 _ = self.monitor_connection(&client) => {
195 warn!("Connection lost for {}, attempting to reconnect", self.config.name);
196 *self.state.write().await = ForwardState::Disconnected;
197 }
198 }
199 }
200 Err(e) => {
201 warn!(
202 "Failed to establish port-forward for {}: {}",
203 self.config.name, e
204 );
205 self.metrics.record_connection_failure();
206 self.metrics.set_connection_status(false);
207 retry_count += 1;
208 tokio::time::sleep(self.config.options.retry_interval).await;
209 continue;
210 }
211 }
212 }
213
214 Ok(())
215 }
216
217 pub async fn monitor_connection(&self, client: &Client) -> Result<()> {
218 let health_check = HealthCheck::new();
219 let mut interval = tokio::time::interval(self.config.options.health_check_interval);
220 let mut consecutive_failures = 0;
221 let protocol = self
222 .config
223 .ports
224 .protocol
225 .clone()
226 .unwrap_or_else(|| "TCP".to_string());
227 let max_failures = if protocol.to_uppercase() == "UDP" {
228 5
229 } else {
230 3
231 }; tokio::time::sleep(std::time::Duration::from_secs(2)).await;
235
236 let state = self.state.read().await;
238 if !matches!(*state, ForwardState::Connected | ForwardState::Starting) {
239 return Err(PortForwardError::ConnectionError(
240 "Cannot monitor connection: not in Connected or Starting state".to_string(),
241 ));
242 }
243 drop(state);
244
245 let mut initial_attempts = 0;
247 let max_initial_attempts = 3;
248 while initial_attempts < max_initial_attempts {
249 if health_check
250 .check_connection(
251 self.config.ports.local,
252 &self
253 .config
254 .ports
255 .protocol
256 .clone()
257 .expect("Protocol configuration"),
258 )
259 .await
260 {
261 debug!("Initial health check passed for {}", self.config.name);
262 break;
263 }
264 initial_attempts += 1;
265 if initial_attempts < max_initial_attempts {
266 debug!(
267 "Initial health check attempt {}/{} failed for {}, retrying...",
268 initial_attempts, max_initial_attempts, self.config.name
269 );
270 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
271 }
272 }
273
274 if initial_attempts >= max_initial_attempts {
275 return Err(PortForwardError::ConnectionError(
276 "Failed initial health checks".to_string(),
277 ));
278 }
279
280 loop {
281 interval.tick().await;
282
283 let protocol = self
285 .config
286 .ports
287 .protocol
288 .clone()
289 .unwrap_or_else(|| "TCP".to_string());
290 if !health_check
291 .check_connection(self.config.ports.local, &protocol)
292 .await
293 {
294 consecutive_failures += 1;
295 if consecutive_failures > 1 {
296 warn!(
297 "Health check failed for {} ({}/{})",
298 self.config.name, consecutive_failures, max_failures
299 );
300 } else {
301 debug!(
302 "Health check failed for {} ({}/{})",
303 self.config.name, consecutive_failures, max_failures
304 );
305 }
306 continue;
307 }
308 consecutive_failures = 0; if let Ok(pod) = self.get_pod(client).await {
312 if let Some(status) = &pod.status {
313 if let Some(phase) = &status.phase {
314 if phase != "Running" {
315 return Err(PortForwardError::ConnectionError(
316 "Pod is no longer running".to_string(),
317 ));
318 }
319 }
320 }
321 } else {
322 return Err(PortForwardError::ConnectionError(
323 "Pod not found".to_string(),
324 ));
325 }
326 }
327 }
328
329 pub async fn establish_forward(&self, client: &Client) -> Result<()> {
330 let mut current_state = self.state.write().await;
332 match *current_state {
333 ForwardState::Connected => {
334 debug!("Port forward {} is already connected", self.config.name);
335 return Ok(());
336 }
337 ForwardState::Starting => {
338 debug!("Port forward {} is already starting", self.config.name);
339 return Ok(());
340 }
341 _ => {
342 *current_state = ForwardState::Starting;
343 }
344 }
345 drop(current_state); self.metrics.record_connection_attempt();
348 let pod = self.get_pod(client).await?;
350 let pod_name = pod.metadata.name.clone().ok_or_else(|| {
352 self.metrics.record_connection_failure();
353 PortForwardError::ConnectionError("Pod name not found".to_string())
354 })?;
355
356 let _pods: Api<Pod> = Api::namespaced(client.clone(), &self.service_info.namespace);
358
359 let mut retry_count = 0;
361 let max_bind_retries = 3;
362 let bind_retry_delay = std::time::Duration::from_secs(1);
363
364 debug!(
366 "Creating TCP listener for the local port: {}",
367 self.config.ports.local
368 );
369
370 let protocol = self
372 .config
373 .ports
374 .protocol
375 .clone()
376 .unwrap_or_else(|| "TCP".to_string());
377 debug!(
378 "Creating {} listener for the local port: {}",
379 protocol, self.config.ports.local
380 );
381
382 let addr = SocketAddr::from(([127, 0, 0, 1], self.config.ports.local));
383
384 match protocol.to_uppercase().as_str() {
385 "TCP" => {
386 let listener = loop {
387 match TcpListener::bind(addr).await {
388 Ok(listener) => break listener,
389 Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
390 if retry_count >= max_bind_retries {
391 self.metrics.record_connection_failure();
392 return Err(PortForwardError::ConnectionError(format!(
393 "Port {} is already in use. Please choose a different local port",
394 self.config.ports.local
395 )));
396 }
397 if let Err(release_err) = self.try_release_port().await {
399 warn!(
400 "Failed to release port {}: {}",
401 self.config.ports.local, release_err
402 );
403 }
404 retry_count += 1;
405 debug!(
406 "Port {} in use, retrying in {:?}...",
407 self.config.ports.local, bind_retry_delay
408 );
409 tokio::time::sleep(bind_retry_delay).await;
410 continue;
411 }
412 Err(e) => {
413 self.metrics.record_connection_failure();
414 return Err(PortForwardError::ConnectionError(format!(
415 "Failed to bind to port: {}",
416 e
417 )));
418 }
419 }
420 };
421
422 let ka = TcpKeepalive::new().with_time(std::time::Duration::from_secs(30));
424 let sf = SockRef::from(&listener);
425 let _ = sf.set_tcp_keepalive(&ka);
426
427 self.handle_tcp_forward(client, pod_name, listener).await?;
428 }
429 "UDP" => {
430 let socket = tokio::net::UdpSocket::bind(addr).await.map_err(|e| {
431 PortForwardError::ConnectionError(format!("Failed to bind UDP socket: {}", e))
432 })?;
433
434 self.handle_udp_forward(client, pod_name, socket).await?;
435 }
436 _ => {
437 return Err(PortForwardError::ConnectionError(format!(
438 "Unsupported protocol: {}",
439 protocol
440 )));
441 }
442 };
443
444 *self.state.write().await = ForwardState::Connected;
446 self.metrics.record_connection_success();
447 self.metrics.set_connection_status(true);
448
449 Ok(())
450 }
451
452 pub async fn forward_connection(
453 pods: &Api<Pod>,
454 pod_name: String,
455 port: u16,
456 mut client_conn: impl AsyncRead + AsyncWrite + Unpin,
457 ) -> anyhow::Result<()> {
458 debug!("Starting port forward for port {}", port);
459
460 let mut pf = pods
462 .portforward(&pod_name, &[port])
463 .await
464 .map_err(|e| anyhow::anyhow!("Failed to create portforward: {}", e))?;
465
466 let mut upstream_conn = pf
468 .take_stream(port) .ok_or_else(|| {
470 anyhow::anyhow!("Failed to get port forward stream for port {}", port)
471 })?;
472
473 debug!("Port forward stream established for port {}", port);
474
475 match tokio::time::timeout(
477 std::time::Duration::from_secs(30), tokio::io::copy_bidirectional(&mut client_conn, &mut upstream_conn),
479 )
480 .await
481 {
482 Ok(Ok(_)) => {
483 debug!("Connection closed normally for port {}", port);
484 }
485 Ok(Err(e)) => {
486 warn!("Error during data transfer for port {}: {}", port, e);
487 return Err(anyhow::anyhow!("Data transfer error: {}", e));
488 }
489 Err(_) => {
490 warn!("Connection timeout for port {}", port);
491 return Err(anyhow::anyhow!("Connection timeout"));
492 }
493 }
494
495 drop(upstream_conn);
497
498 if let Err(e) = pf.join().await {
500 warn!("Port forwarder join error: {}", e);
501 }
502
503 Ok(())
504 }
505
506 async fn handle_udp_forward(
507 &self,
508 client: &Client,
509 pod_name: String,
510 socket: tokio::net::UdpSocket,
511 ) -> Result<()> {
512 let state = self.state.clone();
513 let name = self.config.name.clone();
514 let remote_port = self.config.ports.remote;
515 let mut shutdown = self.shutdown.subscribe();
516 let metrics = self.metrics.clone();
517 let pods: Api<Pod> = Api::namespaced(client.clone(), &self.service_info.namespace);
518
519 tokio::spawn(async move {
520 let mut buf = vec![0u8; 65535]; let socket = Arc::new(socket);
522
523 loop {
524 tokio::select! {
525 result = socket.recv_from(&mut buf) => {
526 match result {
527 Ok((len, peer)) => {
528 let pods = pods.clone();
529 let pod_name = pod_name.clone();
530 let metrics = metrics.clone();
531 let socket = socket.clone();
532 let data = buf[..len].to_vec();
533
534 tokio::spawn(async move {
535 if let Err(e) = Self::handle_udp_packet(&pods, pod_name, remote_port, socket, data, peer).await {
536 error!("Failed to forward UDP packet: {}", e);
537 metrics.record_connection_failure();
538 } else {
539 metrics.record_connection_success();
540 }
541 });
542 }
543 Err(e) => {
544 error!("UDP receive error: {}", e);
545 metrics.record_connection_failure();
546 break;
547 }
548 }
549 }
550 _ = shutdown.recv() => {
551 info!("Received shutdown signal for UDP forward {}", name);
552 *state.write().await = ForwardState::Disconnected;
553 metrics.set_connection_status(false);
554 break;
555 }
556 }
557 }
558 });
559
560 Ok(())
561 }
562
563 pub async fn handle_udp_packet(
564 pods: &Api<Pod>,
565 pod_name: String,
566 port: u16,
567 socket: Arc<tokio::net::UdpSocket>,
568 data: Vec<u8>,
569 peer: SocketAddr,
570 ) -> anyhow::Result<()> {
571 let mut pf = pods
573 .portforward(&pod_name, &[port])
574 .await
575 .map_err(|e| anyhow::anyhow!("Failed to create UDP portforward: {}", e))?;
576
577 let mut upstream_conn = pf.take_stream(port).ok_or_else(|| {
579 anyhow::anyhow!("Failed to get UDP port forward stream for port {}", port)
580 })?;
581
582 let len_bytes = (data.len() as u16).to_be_bytes();
584 upstream_conn.write_all(&len_bytes).await?;
585
586 upstream_conn.write_all(&data).await?;
588 upstream_conn.flush().await?;
589
590 let mut len_buf = [0u8; 2];
592 match upstream_conn.read_exact(&mut len_buf).await {
593 Ok(_) => {
594 let response_length = u16::from_be_bytes(len_buf) as usize;
595
596 let mut response = vec![0u8; response_length];
598 match upstream_conn.read_exact(&mut response).await {
599 Ok(_) => {
600 socket.send_to(&response, peer).await?;
602 }
603 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
604 debug!("No response data received from upstream");
606 }
607 Err(e) => return Err(e.into()),
608 }
609 }
610 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
611 debug!("Connection closed by upstream after sending data");
613 }
614 Err(e) => return Err(e.into()),
615 }
616
617 Ok(())
618 }
619
620 pub async fn handle_tcp_forward(
621 &self,
622 client: &Client,
623 pod_name: String,
624 listener: TcpListener,
625 ) -> Result<()> {
626 let state = self.state.clone();
627 let name = self.config.name.clone();
628 let remote_port = self.config.ports.remote;
629 let mut shutdown = self.shutdown.subscribe();
630 let metrics = self.metrics.clone();
631 let pods: Api<Pod> = Api::namespaced(client.clone(), &self.service_info.namespace);
632
633 tokio::spawn(async move {
634 let mut listener_stream = TcpListenerStream::new(listener);
635
636 loop {
637 tokio::select! {
638 Ok(Some(client_conn)) = listener_stream.try_next() => {
639 if let Ok(peer_addr) = client_conn.peer_addr() {
640 info!(%peer_addr, "New TCP connection for {}", name);
641 metrics.record_connection_attempt();
642 }
643 let pods = pods.clone();
644 let pod_name = pod_name.clone();
645 let metrics = metrics.clone();
646
647 tokio::spawn(async move {
648 if let Err(e) = Self::forward_connection(&pods, pod_name, remote_port, client_conn).await {
649 error!("Failed to forward TCP connection: {}", e);
650 metrics.record_connection_failure();
651 } else {
652 metrics.record_connection_success();
653 }
654 });
655 }
656 _ = shutdown.recv() => {
657 info!("Received shutdown signal for TCP forward {}", name);
658 *state.write().await = ForwardState::Disconnected;
659 metrics.set_connection_status(false);
660 break;
661 }
662 else => {
663 error!("Port forward {} listener closed", name);
664 *state.write().await = ForwardState::Failed("Listener closed unexpectedly".to_string());
665 metrics.set_connection_status(false);
666 metrics.record_connection_failure();
667 break;
668 }
669 }
670 }
671 });
672
673 Ok(())
674 }
675
676 pub async fn get_pod(&self, client: &Client) -> Result<Pod> {
677 let pods: Api<Pod> = Api::namespaced(client.clone(), &self.service_info.namespace);
678
679 let pod_list = pods
681 .list(&kube::api::ListParams::default())
682 .await
683 .map_err(PortForwardError::KubeError)?;
684
685 for pod in pod_list.items {
686 if self
687 .clone()
688 .matches_pod_selector(&pod, &self.config.pod_selector)
689 {
690 if let Some(status) = &pod.status {
691 if let Some(phase) = &status.phase {
692 if phase == "Running" {
693 return Ok(pod);
694 }
695 }
696 }
697 }
698 }
699
700 Err(PortForwardError::ConnectionError(format!(
701 "No ready pods found matching selector for service {}",
702 self.service_info.name
703 )))
704 }
705
706 pub fn matches_pod_selector(self, pod: &Pod, selector: &PodSelector) -> bool {
707 if selector.label.is_none() && selector.annotation.is_none() {
709 return pod
710 .metadata
711 .labels
712 .as_ref()
713 .is_some_and(|labels| labels.values().any(|v| v == &self.service_info.name));
714 }
715
716 if let Some(label_selector) = &selector.label {
718 let (key, value) = self.clone().parse_selector(label_selector);
719 if pod
720 .metadata
721 .labels
722 .as_ref()
723 .is_none_or(|labels| labels.get(key).is_none_or(|v| v != value))
724 {
725 return false;
726 }
727 }
728
729 if let Some(annotation_selector) = &selector.annotation {
731 let (key, value) = self.clone().parse_selector(annotation_selector);
732 if pod
733 .metadata
734 .annotations
735 .as_ref()
736 .is_none_or(|annotations| annotations.get(key).is_none_or(|v| v != value))
737 {
738 return false;
739 }
740 }
741
742 true
743 }
744
745 pub fn parse_selector(self, selector: &str) -> (&str, &str) {
746 let parts: Vec<&str> = selector.split('=').collect();
747 match parts.as_slice() {
748 [key, value] => (*key, *value),
749 _ => ("", ""), }
751 }
752
753 pub async fn try_release_port(&self) -> std::io::Result<()> {
754 let addr = SocketAddr::from(([127, 0, 0, 1], self.config.ports.local));
755
756 let state = self.state.read().await;
758 match *state {
759 ForwardState::Connected | ForwardState::Starting => {
760 debug!(
761 "Port {} is in use by our own active connection (state: {:?})",
762 self.config.ports.local, state
763 );
764 return Ok(());
765 }
766 _ => drop(state),
767 }
768
769 let socket = tokio::net::TcpSocket::new_v4()?;
771
772 match tokio::net::TcpStream::connect(addr).await {
774 Ok(_) => {
775 let state = self.state.read().await;
777 if matches!(*state, ForwardState::Connected | ForwardState::Starting) {
778 debug!(
779 "Port {} is in use by our active connection (verified)",
780 self.config.ports.local
781 );
782 Ok(())
783 } else {
784 debug!(
785 "Port {} is in use by another process",
786 self.config.ports.local
787 );
788 Err(std::io::Error::new(
789 std::io::ErrorKind::AddrInUse,
790 "Port is actively in use by another process",
791 ))
792 }
793 }
794 Err(_) => {
795 match socket.bind(addr) {
797 Ok(_) => {
798 debug!("Port {} is free", self.config.ports.local);
799 Ok(())
800 }
801 Err(e) => {
802 debug!("Port {} bind error: {}", self.config.ports.local, e);
803 Err(e)
804 }
805 }
806 }
807 }
808 }
809
810 pub async fn stop(&self) {
811 *self.state.write().await = ForwardState::Stopping;
813
814 let _ = self.shutdown.send(());
816
817 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
819
820 if let Err(e) = self.try_release_port().await {
822 warn!(
823 "Failed to release port {} during shutdown: {}",
824 self.config.ports.local, e
825 );
826 }
827
828 *self.state.write().await = ForwardState::Disconnected;
830
831 self.metrics.set_connection_status(false);
833 }
834}
835
836pub struct PortForwardManager {
838 pub forwards: Arc<RwLock<Vec<Arc<PortForward>>>>,
839 pub client: Client,
840}
841
842impl PortForwardManager {
843 pub fn new(client: Client) -> Self {
844 Self {
845 forwards: Arc::new(RwLock::new(Vec::new())),
846 client,
847 }
848 }
849
850 pub async fn add_forward(
851 &self,
852 config: ForwardConfig,
853 service_info: ServiceInfo,
854 ) -> Result<()> {
855 let forward = Arc::new(PortForward::new(config, service_info));
856 self.forwards.write().await.push(forward.clone());
857
858 let client = self.client.clone();
860 tokio::spawn(async move {
861 if let Err(e) = forward.start(client).await {
862 error!("Port-forward failed: {}", e);
863 }
864 });
865
866 Ok(())
867 }
868
869 pub async fn stop_all(&self) {
870 for forward in self.forwards.read().await.iter() {
871 forward.stop().await;
872 }
873 }
874}