atm0s_reverse_proxy_relayer/
lib.rs

1use std::{collections::HashMap, net::SocketAddr, time::Instant};
2
3use ::metrics::{counter, gauge, histogram};
4use agent::{
5    quic::AgentQuicListener,
6    tcp::{AgentTcpListener, TunnelTcpStream},
7    tls::{AgentTlsListener, TunnelTlsStream},
8    AgentListener, AgentListenerEvent, AgentSession,
9};
10use anyhow::anyhow;
11use p2p::{
12    alias_service::{AliasGuard, AliasService, AliasServiceRequester},
13    HandshakeProtocol, P2pNetwork, P2pNetworkConfig, P2pService, P2pServiceEvent, P2pServiceRequester, PeerAddress, PeerId,
14};
15use protocol::{
16    cluster::{write_object, AgentTunnelRequest},
17    key::{ClusterRequest, ClusterValidator},
18    proxy::{AgentId, ProxyDestination},
19};
20use quic::TunnelQuicStream;
21use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
22use serde::de::DeserializeOwned;
23use tokio::{
24    io::{copy_bidirectional, AsyncRead, AsyncWrite},
25    select,
26};
27
28mod agent;
29mod metrics;
30mod proxy;
31mod quic;
32
33pub use agent::AgentSessionId;
34pub use metrics::*;
35pub use p2p;
36pub use proxy::{http::HttpDestinationDetector, rtsp::RtspDestinationDetector, tls::TlsDestinationDetector, ProxyDestinationDetector, ProxyTcpListener};
37
/// p2p service id of the alias service, which maps agent ids to the node that currently holds them.
const ALIAS_SERVICE: u16 = 0;
/// p2p service id used to forward proxied connections from this node to the node holding the target agent.
const PROXY_TO_AGENT_SERVICE: u16 = 1;
/// p2p service id handed to the [`TunnelServiceHandle`] for tunnels opened by agents towards the cluster.
const TUNNEL_TO_CLUSTER_SERVICE: u16 = 2;
41
/// Cluster handles given to a [`TunnelServiceHandle`] on every callback.
#[derive(Clone)]
pub struct TunnelServiceCtx {
    /// Requester for the agent-to-cluster tunnel p2p service.
    pub service: P2pServiceRequester,
    /// Requester for the alias service, which resolves which node holds a given agent.
    pub alias: AliasServiceRequester,
}
47
/// Decides how the relayer processes streams that agents open towards the relayer
/// (agent-to-outside tunnels), together with the events of the tunnel p2p service.
///
/// All hooks are called synchronously from the relayer's event loop (`QuicRelayer::new`
/// and `QuicRelayer::recv`); presumably implementations should spawn long-running work
/// instead of blocking inside them.
pub trait TunnelServiceHandle<Ctx> {
    /// Called once while the relayer is being created, after the p2p services exist and
    /// before the agent and proxy listeners are bound.
    fn start(&mut self, _ctx: &TunnelServiceCtx);
    /// Called for every stream an agent opens towards the relayer. `ctx` is the
    /// `ClusterRequest::Context` value carried by the agent listener event.
    fn on_agent_conn<S: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(&mut self, _ctx: &TunnelServiceCtx, _agent_id: AgentId, ctx: Ctx, _stream: S);
    /// Called for every event received on the agent-to-cluster tunnel p2p service.
    fn on_cluster_event(&mut self, _ctx: &TunnelServiceCtx, _event: P2pServiceEvent);
}
54
/// Configuration consumed by `QuicRelayer::new`.
pub struct QuicRelayerConfig<SECURE, TSH> {
    /// Bind address of the plain TCP agent listener.
    pub agent_unsecure_listener: SocketAddr,
    /// Bind address shared by the QUIC and the TLS agent listeners.
    pub agent_secure_listener: SocketAddr,
    /// Bind address of the HTTP proxy listener (destination detected from HTTP).
    pub proxy_http_listener: SocketAddr,
    /// Bind address of the TLS proxy listener (destination detected from TLS).
    pub proxy_tls_listener: SocketAddr,
    /// Bind address of the RTSP proxy listener (destination detected from RTSP).
    pub proxy_rtsp_listener: SocketAddr,
    /// Bind address of the RTSPS proxy listener (destination detected from TLS).
    pub proxy_rtsps_listener: SocketAddr,

    /// Private key used by the QUIC and TLS agent listeners.
    pub agent_key: PrivatePkcs8KeyDer<'static>,
    /// Certificate used by the QUIC and TLS agent listeners.
    pub agent_cert: CertificateDer<'static>,

    /// Peer id of this node in the p2p (SDN) network.
    pub sdn_peer_id: PeerId,
    /// Listen address of the p2p network.
    pub sdn_listener: SocketAddr,
    /// Seed peers passed to the p2p network (presumably dialed at startup — confirm in `p2p`).
    pub sdn_seeds: Vec<PeerAddress>,
    /// Private key of the p2p network.
    pub sdn_key: PrivatePkcs8KeyDer<'static>,
    /// Certificate of the p2p network.
    pub sdn_cert: CertificateDer<'static>,
    /// Optional address passed to the p2p network as `advertise`
    /// (presumably announced to peers instead of `sdn_listener` — confirm in `p2p`).
    pub sdn_advertise_address: Option<SocketAddr>,
    /// Handshake protocol implementation used by the p2p network.
    pub sdn_secure: SECURE,

    /// Handler for agent-to-cluster tunnels; `start` is called once during construction.
    pub tunnel_service_handle: TSH,
}
76
/// Outcome of one `QuicRelayer::recv` call.
pub enum QuicRelayerEvent {
    /// A new agent session was accepted; the `String` is the session's domain.
    AgentConnected(AgentId, AgentSessionId, String),
    /// An agent listener reported that the session disconnected.
    AgentDisconnected(AgentId, AgentSessionId),
    /// An event was handled internally and there is nothing to report.
    Continue,
}
82
/// Reverse-proxy relayer: accepts agent connections (QUIC, TCP, TLS), accepts proxied
/// client connections (HTTP, TLS, RTSP, RTSPS) and pipes each client to its agent, either
/// locally or through another cluster node over the p2p network.
pub struct QuicRelayer<SECURE, VALIDATE, REQ: ClusterRequest, TSH> {
    agent_quic: AgentQuicListener<VALIDATE, REQ>,
    agent_tcp: AgentTcpListener<VALIDATE, REQ>,
    agent_tls: AgentTlsListener<VALIDATE, REQ>,
    http_proxy: ProxyTcpListener<HttpDestinationDetector>,
    tls_proxy: ProxyTcpListener<TlsDestinationDetector>,
    rtsp_proxy: ProxyTcpListener<RtspDestinationDetector>,
    rtsps_proxy: ProxyTcpListener<TlsDestinationDetector>,

    sdn: P2pNetwork<SECURE>,

    /// Requester for the alias service: used to register local agents and to locate remote ones.
    sdn_alias_requester: AliasServiceRequester,
    /// Service for proxying from the internet to an agent held by another node.
    sdn_proxy_service: P2pService,
    /// Service for tunnels from an agent to the outside; its events go to `tunnel_service_handle`.
    sdn_tunnel_service: P2pService,
    tunnel_service_ctx: TunnelServiceCtx,
    tunnel_service_handle: TSH,

    // Connected agents per transport: agent id -> session id -> (session, alias guard).
    // An agent entry is removed once its last session disconnects. The AliasGuard comes
    // from `alias_requester.register` and is kept alive with the session (presumably it
    // unregisters the alias on drop — confirm in p2p::alias_service).
    agent_quic_sessions: HashMap<AgentId, HashMap<AgentSessionId, (AgentSession<TunnelQuicStream>, AliasGuard)>>,
    agent_tcp_sessions: HashMap<AgentId, HashMap<AgentSessionId, (AgentSession<TunnelTcpStream>, AliasGuard)>>,
    agent_tls_sessions: HashMap<AgentId, HashMap<AgentSessionId, (AgentSession<TunnelTlsStream>, AliasGuard)>>,
}
106
107impl<SECURE, VALIDATE, REQ: ClusterRequest, TSH> QuicRelayer<SECURE, VALIDATE, REQ, TSH>
108where
109    SECURE: HandshakeProtocol,
110    VALIDATE: ClusterValidator<REQ>,
111    REQ: DeserializeOwned + Send + Sync + 'static,
112    TSH: TunnelServiceHandle<REQ::Context> + Send + Sync + 'static,
113{
114    pub async fn new(mut cfg: QuicRelayerConfig<SECURE, TSH>, validate: VALIDATE) -> anyhow::Result<Self> {
115        let mut sdn = P2pNetwork::new(P2pNetworkConfig {
116            peer_id: cfg.sdn_peer_id,
117            listen_addr: cfg.sdn_listener,
118            advertise: cfg.sdn_advertise_address.map(|a| a.into()),
119            priv_key: cfg.sdn_key,
120            cert: cfg.sdn_cert,
121            tick_ms: 1000,
122            seeds: cfg.sdn_seeds,
123            secure: cfg.sdn_secure,
124        })
125        .await?;
126
127        let mut sdn_alias = AliasService::new(sdn.create_service(ALIAS_SERVICE.into()));
128        let sdn_alias_requester = sdn_alias.requester();
129        tokio::spawn(async move { while sdn_alias.run_loop().await.is_ok() {} });
130        let sdn_proxy_service = sdn.create_service(PROXY_TO_AGENT_SERVICE.into());
131        let sdn_tunnel_service = sdn.create_service(TUNNEL_TO_CLUSTER_SERVICE.into());
132        let tunnel_service_ctx = TunnelServiceCtx {
133            service: sdn_tunnel_service.requester(),
134            alias: sdn_alias_requester.clone(),
135        };
136        cfg.tunnel_service_handle.start(&tunnel_service_ctx);
137
138        Ok(Self {
139            agent_quic: AgentQuicListener::new(cfg.agent_secure_listener, cfg.agent_key.clone_key(), cfg.agent_cert.clone(), validate.clone()).await?,
140            agent_tcp: AgentTcpListener::new(cfg.agent_unsecure_listener, validate.clone()).await?,
141            agent_tls: AgentTlsListener::new(cfg.agent_secure_listener, validate, cfg.agent_key, cfg.agent_cert).await?,
142            http_proxy: ProxyTcpListener::new(cfg.proxy_http_listener, Default::default()).await?,
143            tls_proxy: ProxyTcpListener::new(cfg.proxy_tls_listener, Default::default()).await?,
144            rtsp_proxy: ProxyTcpListener::new(cfg.proxy_rtsp_listener, Default::default()).await?,
145            rtsps_proxy: ProxyTcpListener::new(cfg.proxy_rtsps_listener, Default::default()).await?,
146
147            sdn,
148            sdn_alias_requester,
149            sdn_proxy_service,
150            sdn_tunnel_service,
151            tunnel_service_handle: cfg.tunnel_service_handle,
152            tunnel_service_ctx,
153
154            agent_quic_sessions: HashMap::new(),
155            agent_tcp_sessions: HashMap::new(),
156            agent_tls_sessions: HashMap::new(),
157        })
158    }
159
160    fn process_proxy<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(&mut self, proxy: T, dest: ProxyDestination, is_from_cluster: bool) {
161        let agent_id = match dest.agent_id() {
162            Ok(agent_id) => agent_id,
163            Err(e) => {
164                log::warn!("[QuicRelayer] proxy to {dest:?} failed to get agent id: {e}");
165                return;
166            }
167        };
168        if let Some(sessions) = self.agent_tcp_sessions.get(&agent_id) {
169            let session = sessions.values().next().expect("should have session");
170            let job = proxy_local_to_agent(is_from_cluster, proxy, dest, session.0.clone());
171            tokio::spawn(async move {
172                if let Err(e) = job.await {
173                    counter!(METRICS_PROXY_HTTP_ERROR_COUNT).increment(1);
174                    counter!(METRICS_TUNNEL_AGENT_ERROR_COUNT).increment(1);
175                    log::error!("[QuicRelayer {agent_id}] proxy to agent error {:?}", e);
176                };
177            });
178        } else if let Some(sessions) = self.agent_tls_sessions.get(&agent_id) {
179            let session = sessions.values().next().expect("should have session");
180            let job = proxy_local_to_agent(is_from_cluster, proxy, dest, session.0.clone());
181            tokio::spawn(async move {
182                if let Err(e) = job.await {
183                    counter!(METRICS_PROXY_HTTP_ERROR_COUNT).increment(1);
184                    counter!(METRICS_TUNNEL_AGENT_ERROR_COUNT).increment(1);
185                    log::error!("[QuicRelayer {agent_id}] proxy to agent error {:?}", e);
186                };
187            });
188        } else if let Some(sessions) = self.agent_quic_sessions.get(&agent_id) {
189            let session = sessions.values().next().expect("should have session");
190            let job = proxy_local_to_agent(is_from_cluster, proxy, dest, session.0.clone());
191            tokio::spawn(async move {
192                if let Err(e) = job.await {
193                    counter!(METRICS_PROXY_HTTP_ERROR_COUNT).increment(1);
194                    counter!(METRICS_TUNNEL_AGENT_ERROR_COUNT).increment(1);
195                    log::error!("[QuicRelayer {agent_id}] proxy to agent error {:?}", e);
196                };
197            });
198        } else if !is_from_cluster {
199            // we don't allow two times tunnel over cluster
200            let sdn_requester = self.sdn_proxy_service.requester();
201            let job = proxy_to_cluster(proxy, dest, self.sdn_alias_requester.clone(), sdn_requester);
202            tokio::spawn(async move {
203                if let Err(e) = job.await {
204                    counter!(METRICS_PROXY_HTTP_ERROR_COUNT).increment(1);
205                    counter!(METRICS_TUNNEL_CLUSTER_ERROR_COUNT).increment(1);
206                    log::error!("[QuicRelayer {agent_id}] proxy to cluster error {:?}", e);
207                };
208            });
209        } else {
210            log::warn!("[QuicRelayer {agent_id}] proxy to {dest:?} not match any kind");
211            counter!(METRICS_PROXY_CLUSTER_ERROR_COUNT).increment(1);
212            counter!(METRICS_TUNNEL_AGENT_ERROR_COUNT).increment(1);
213        }
214    }
215
216    pub fn p2p(&mut self) -> &mut P2pNetwork<SECURE> {
217        &mut self.sdn
218    }
219
220    pub async fn recv(&mut self) -> anyhow::Result<QuicRelayerEvent> {
221        select! {
222            tunnel = self.http_proxy.recv() => {
223                let (dest, tunnel) = tunnel?;
224                self.process_proxy(tunnel, dest, false);
225                Ok(QuicRelayerEvent::Continue)
226            },
227            tunnel = self.tls_proxy.recv() => {
228                let (dest, tunnel) = tunnel?;
229                self.process_proxy(tunnel, dest, false);
230                Ok(QuicRelayerEvent::Continue)
231            },
232            tunnel = self.rtsp_proxy.recv() => {
233                let (dest, tunnel) = tunnel?;
234                self.process_proxy(tunnel, dest, false);
235                Ok(QuicRelayerEvent::Continue)
236            },
237            tunnel = self.rtsps_proxy.recv() => {
238                let (dest, tunnel) = tunnel?;
239                self.process_proxy(tunnel, dest, false);
240                Ok(QuicRelayerEvent::Continue)
241            },
242            _ = self.sdn.recv() =>  {
243                Ok(QuicRelayerEvent::Continue)
244            },
245            event = self.agent_quic.recv() => process_incoming_event::<_, _, REQ>(event?, &self.sdn_alias_requester, &mut self.agent_quic_sessions, &mut self.tunnel_service_handle, &self.tunnel_service_ctx),
246            event = self.agent_tcp.recv() => process_incoming_event::<_, _, REQ>(event?, &self.sdn_alias_requester, &mut self.agent_tcp_sessions, &mut self.tunnel_service_handle, &self.tunnel_service_ctx),
247            event = self.agent_tls.recv() => process_incoming_event::<_, _, REQ>(event?, &self.sdn_alias_requester, &mut self.agent_tls_sessions, &mut self.tunnel_service_handle, &self.tunnel_service_ctx),
248            event = self.sdn_proxy_service.recv() => match event.expect("sdn channel crash") {
249                P2pServiceEvent::Unicast(from, ..) => {
250                    log::warn!("[QuicRelayer] proxy service don't accept unicast msg from {from}");
251                    Ok(QuicRelayerEvent::Continue)
252                },
253                P2pServiceEvent::Broadcast(from, ..) => {
254                    log::warn!("[QuicRelayer] proxy service don't accept broadcast msg from {from}");
255                    Ok(QuicRelayerEvent::Continue)
256                },
257                P2pServiceEvent::Stream(_from, meta, stream) => {
258                    if let Ok(proxy_dest) = bincode::deserialize::<ProxyDestination>(&meta) {
259                        self.process_proxy(stream, proxy_dest, true);
260                    }
261                    Ok(QuicRelayerEvent::Continue)
262                },
263            },
264            event = self.sdn_tunnel_service.recv() => {
265                self.tunnel_service_handle.on_cluster_event(&self.tunnel_service_ctx, event.expect("sdn channel crash"));
266                Ok(QuicRelayerEvent::Continue)
267            },
268            _ = tokio::signal::ctrl_c() => {
269                log::info!("[QuicRelayer] shutdown inprogress");
270                self.sdn.shutdown();
271                self.agent_quic.shutdown().await;
272
273                log::info!("[QuicRelayer] shutdown done");
274                Ok(QuicRelayerEvent::Continue)
275            }
276        }
277    }
278}
279
280fn process_incoming_event<S, TSH, REQ>(
281    event: AgentListenerEvent<REQ::Context, S>,
282    alias_requester: &AliasServiceRequester,
283    sessions: &mut HashMap<AgentId, HashMap<AgentSessionId, (AgentSession<S>, AliasGuard)>>,
284    tunnel_service_handle: &mut TSH,
285    tunnel_service_ctx: &TunnelServiceCtx,
286) -> anyhow::Result<QuicRelayerEvent>
287where
288    S: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
289    REQ: ClusterRequest,
290    TSH: TunnelServiceHandle<REQ::Context> + Send + Sync + 'static,
291{
292    match event {
293        AgentListenerEvent::Connected(agent_id, agent_session) => {
294            counter!(METRICS_AGENT_COUNT).increment(1);
295            log::info!("[QuicRelayer] agent {agent_id} {} connected", agent_session.session_id());
296            let session_id = agent_session.session_id();
297            let domain = agent_session.domain().to_owned();
298            let alias = alias_requester.register(*agent_id);
299            sessions.entry(agent_id).or_default().insert(agent_session.session_id(), (agent_session, alias));
300            gauge!(METRICS_AGENT_LIVE).increment(1.0);
301            Ok(QuicRelayerEvent::AgentConnected(agent_id, session_id, domain))
302        }
303        AgentListenerEvent::IncomingStream(agent_id, agent_ctx, stream) => {
304            tunnel_service_handle.on_agent_conn(tunnel_service_ctx, agent_id, agent_ctx, stream);
305            Ok(QuicRelayerEvent::Continue)
306        }
307        AgentListenerEvent::Disconnected(agent_id, session_id) => {
308            log::info!("[QuicRelayer] agent {agent_id} {session_id} disconnected");
309            if let Some(child_sessions) = sessions.get_mut(&agent_id) {
310                child_sessions.remove(&session_id);
311                if child_sessions.is_empty() {
312                    log::info!("[QuicRelayer] agent disconnected all connections {agent_id} {session_id}");
313                    sessions.remove(&agent_id);
314                }
315                gauge!(METRICS_AGENT_LIVE).decrement(1.0);
316            }
317            Ok(QuicRelayerEvent::AgentDisconnected(agent_id, session_id))
318        }
319    }
320}
321
322async fn proxy_local_to_agent<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, S: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
323    is_from_cluster: bool,
324    mut proxy: T,
325    dest: ProxyDestination,
326    agent: AgentSession<S>,
327) -> anyhow::Result<()> {
328    let agent_id = agent.agent_id();
329    let started = Instant::now();
330    if is_from_cluster {
331        counter!(METRICS_PROXY_CLUSTER_COUNT).increment(1);
332    } else {
333        counter!(METRICS_PROXY_HTTP_COUNT).increment(1);
334    }
335    counter!(METRICS_TUNNEL_AGENT_COUNT).increment(1);
336    log::info!("[ProxyLocal {agent_id}] creating stream to agent");
337    let mut stream = agent.create_stream().await?;
338
339    histogram!(METRICS_TUNNEL_AGENT_HISTOGRAM).record(started.elapsed().as_millis() as f32 / 1000.0);
340    log::info!("[ProxyLocal {agent_id}] created stream to agent => writing connect request");
341    write_object::<_, _, 500>(
342        &mut stream,
343        &AgentTunnelRequest {
344            service: dest.service,
345            tls: dest.tls,
346            domain: dest.domain,
347        },
348    )
349    .await?;
350
351    log::info!("[ProxyLocal {agent_id}] proxy data with agent ...");
352
353    gauge!(METRICS_TUNNEL_AGENT_LIVE).increment(1.0);
354    if is_from_cluster {
355        gauge!(METRICS_PROXY_CLUSTER_LIVE).increment(1.0);
356    } else {
357        gauge!(METRICS_PROXY_HTTP_LIVE).increment(1.0);
358    }
359    match copy_bidirectional(&mut proxy, &mut stream).await {
360        Ok(res) => {
361            log::info!("[ProxyLocal {agent_id}] proxy data with agent done with res {res:?}");
362        }
363        Err(e) => {
364            log::error!("[ProxyLocal {agent_id}] proxy data with agent error {e}");
365        }
366    };
367
368    if is_from_cluster {
369        gauge!(METRICS_PROXY_CLUSTER_LIVE).decrement(1.0);
370    } else {
371        gauge!(METRICS_PROXY_HTTP_LIVE).decrement(1.0);
372    }
373    gauge!(METRICS_TUNNEL_AGENT_LIVE).decrement(1.0);
374
375    Ok(())
376}
377
378async fn proxy_to_cluster<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
379    mut proxy: T,
380    dest: ProxyDestination,
381    alias_requeser: AliasServiceRequester,
382    sdn_requester: P2pServiceRequester,
383) -> anyhow::Result<()> {
384    let started = Instant::now();
385    counter!(METRICS_PROXY_HTTP_COUNT).increment(1);
386    counter!(METRICS_TUNNEL_CLUSTER_COUNT).increment(1);
387    let agent_id = dest.agent_id()?;
388    log::info!("[ProxyCluster {agent_id}] finding location of agent {agent_id}");
389    let found_location = alias_requeser.find(*agent_id).await.ok_or(anyhow!("ALIAS_NOT_FOUND"))?;
390    let dest_node = match found_location {
391        p2p::alias_service::AliasFoundLocation::Local => return Err(anyhow!("wrong alias context, cluster shouldn't in local")),
392        p2p::alias_service::AliasFoundLocation::Hint(dest) => dest,
393        p2p::alias_service::AliasFoundLocation::Scan(dest) => dest,
394    };
395    log::info!("[ProxyCluster {agent_id}] found location of agent {agent_id}: {found_location:?} => opening cluster connection to {dest_node}");
396
397    let meta = bincode::serialize(&dest).expect("should convert ProxyDestination to bytes");
398
399    let mut stream = sdn_requester.open_stream(dest_node, meta).await?;
400    histogram!(METRICS_TUNNEL_CLUSTER_HISTOGRAM).record(started.elapsed().as_millis() as f32 / 1000.0);
401
402    log::info!("[ProxyCluster {agent_id}] proxy over {dest_node} ...");
403    gauge!(METRICS_TUNNEL_CLUSTER_LIVE).increment(1.0);
404    gauge!(METRICS_PROXY_HTTP_LIVE).increment(1.0);
405
406    match copy_bidirectional(&mut proxy, &mut stream).await {
407        Ok(res) => {
408            log::info!("[ProxyCluster {agent_id}] proxy over {dest_node} done with res {res:?}");
409        }
410        Err(e) => {
411            log::error!("[ProxyCluster {agent_id}] proxy over {dest_node} error {e}");
412        }
413    }
414
415    gauge!(METRICS_PROXY_HTTP_LIVE).decrement(1.0);
416    gauge!(METRICS_TUNNEL_CLUSTER_LIVE).decrement(1.0);
417    Ok(())
418}