use tokio::net::TcpStream;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use crate::overlay::{DynOverlayResolver, OverlayReachability, RoutingMode};
use crate::{Result, TunnelError};
/// WebSocket connector that chooses between a direct URL and an overlay URL
/// according to its configured [`RoutingMode`].
pub struct OverlayAwareConnector {
/// URL dialed for direct connections; its host is also the name handed to
/// the resolver when an overlay URL has to be derived.
direct_url: String,
/// Pre-computed overlay URL. When set, it is returned as the overlay URL
/// without consulting the resolver.
overlay_url: Option<String>,
/// Selects which connection strategy `connect` dispatches to.
routing_mode: RoutingMode,
/// Optional resolver used to derive an overlay URL from `direct_url` when
/// no pre-computed `overlay_url` is present.
resolver: Option<DynOverlayResolver>,
}
impl OverlayAwareConnector {
    /// Builds a connector from a direct URL, an optional pre-computed overlay
    /// URL, the routing mode, and an optional overlay resolver.
    ///
    /// Both URLs are copied into owned strings; nothing is validated or
    /// resolved until [`connect`](Self::connect) is called.
    #[must_use]
    pub fn new(
        direct_url: &str,
        overlay_url: Option<&str>,
        routing_mode: RoutingMode,
        resolver: Option<DynOverlayResolver>,
    ) -> Self {
        Self {
            direct_url: direct_url.to_string(),
            overlay_url: overlay_url.map(ToString::to_string),
            routing_mode,
            resolver,
        }
    }

    /// Opens a WebSocket connection using the strategy selected by the
    /// routing mode.
    ///
    /// # Errors
    ///
    /// Returns a connection error when the WebSocket handshake fails, or when
    /// the mode is `OverlayOnly` and no overlay URL can be determined.
    pub async fn connect(
        &self,
    ) -> Result<(
        WebSocketStream<MaybeTlsStream<TcpStream>>,
        tokio_tungstenite::tungstenite::http::Response<Option<Vec<u8>>>,
    )> {
        match self.routing_mode {
            RoutingMode::DirectOnly => self.connect_direct().await,
            RoutingMode::OverlayOnly => self.connect_overlay_only().await,
            RoutingMode::PreferOverlay => self.connect_prefer_overlay().await,
        }
    }

    /// Connects to `direct_url`, mapping any handshake failure into a
    /// connection error.
    async fn connect_direct(
        &self,
    ) -> Result<(
        WebSocketStream<MaybeTlsStream<TcpStream>>,
        tokio_tungstenite::tungstenite::http::Response<Option<Vec<u8>>>,
    )> {
        tracing::debug!(url = %self.direct_url, "connecting directly");
        connect_async(&self.direct_url)
            .await
            .map_err(TunnelError::connection)
    }

    /// Connects via the overlay URL and never falls back to the direct URL.
    ///
    /// Fails up front (without any network activity) when no overlay URL can
    /// be determined.
    async fn connect_overlay_only(
        &self,
    ) -> Result<(
        WebSocketStream<MaybeTlsStream<TcpStream>>,
        tokio_tungstenite::tungstenite::http::Response<Option<Vec<u8>>>,
    )> {
        let overlay_url = self.resolve_overlay_url().ok_or_else(|| {
            TunnelError::connection_msg("overlay routing required but no overlay URL available")
        })?;
        tracing::debug!(url = %overlay_url, "connecting via overlay (overlay-only mode)");
        connect_async(&overlay_url)
            .await
            .map_err(TunnelError::connection)
    }

    /// Tries the overlay URL first and falls back to the direct URL when the
    /// overlay is unavailable or its handshake fails.
    async fn connect_prefer_overlay(
        &self,
    ) -> Result<(
        WebSocketStream<MaybeTlsStream<TcpStream>>,
        tokio_tungstenite::tungstenite::http::Response<Option<Vec<u8>>>,
    )> {
        if let Some(overlay_url) = self.resolve_overlay_url() {
            tracing::debug!(url = %overlay_url, "attempting overlay connection");
            match connect_async(&overlay_url).await {
                Ok(result) => {
                    tracing::debug!("connected via overlay");
                    return Ok(result);
                }
                Err(e) => {
                    // The overlay error is only logged; the caller sees the
                    // outcome of the direct attempt below.
                    tracing::warn!(
                        error = %e,
                        "overlay connection failed, falling back to direct"
                    );
                }
            }
        } else {
            tracing::debug!("no overlay URL available, using direct connection");
        }
        self.connect_direct().await
    }

    /// Determines the overlay URL to dial, if any.
    ///
    /// A pre-computed `overlay_url` always wins. Otherwise, when a resolver is
    /// present and reports the overlay as active, the host of `direct_url` is
    /// resolved to an overlay IP and substituted into the URL, keeping the
    /// scheme and everything after the host (port, path, query) unchanged.
    ///
    /// Returns `None` when there is no resolver, the overlay is inactive,
    /// `direct_url` is not a `ws://`/`wss://` URL, the host is empty, or the
    /// resolver does not report the host as reachable.
    fn resolve_overlay_url(&self) -> Option<String> {
        if let Some(url) = &self.overlay_url {
            return Some(url.clone());
        }

        let resolver = self.resolver.as_ref()?;
        if !resolver.overlay_active() {
            return None;
        }

        let (scheme, rest) = Self::split_ws_scheme(&self.direct_url)?;
        let (host, after_host) = Self::split_host(rest);
        if host.is_empty() {
            // Nothing meaningful to ask the resolver about.
            return None;
        }

        match resolver.resolve_overlay_ip(host) {
            OverlayReachability::Reachable(ip) => Some(format!("{scheme}://{ip}{after_host}")),
            _ => None,
        }
    }

    /// Splits a WebSocket URL into its scheme name (`"ws"` or `"wss"`) and the
    /// text following `://`. Returns `None` for any other scheme.
    fn split_ws_scheme(url: &str) -> Option<(&'static str, &str)> {
        if let Some(rest) = url.strip_prefix("wss://") {
            Some(("wss", rest))
        } else {
            url.strip_prefix("ws://").map(|rest| ("ws", rest))
        }
    }

    /// Splits the text after `scheme://` into `(host, remainder)`.
    ///
    /// The host ends at the first `:`, `/`, `?` or `#`, so URLs without an
    /// explicit port (`ws://node/path`) yield just the host rather than the
    /// host plus path. A bracketed literal such as `[::1]` is kept whole,
    /// brackets included, because the colons inside it are not port
    /// separators.
    fn split_host(after_scheme: &str) -> (&str, &str) {
        let host_len = if after_scheme.starts_with('[') {
            after_scheme
                .find(']')
                .map_or(after_scheme.len(), |close| close + 1)
        } else {
            after_scheme
                .find(|c: char| matches!(c, ':' | '/' | '?' | '#'))
                .unwrap_or(after_scheme.len())
        };
        after_scheme.split_at(host_len)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::overlay::OverlayReachability;
    use std::net::Ipv4Addr;
    use std::sync::Arc;

    /// Resolver stub that answers every lookup with a fixed, optional IP and
    /// a fixed "overlay active" flag.
    struct MockResolver {
        overlay_ip: Option<Ipv4Addr>,
        active: bool,
    }

    impl MockResolver {
        /// Active overlay that resolves every node name to `ip`.
        fn active_with_ip(ip: Ipv4Addr) -> Self {
            Self {
                overlay_ip: Some(ip),
                active: true,
            }
        }

        /// Inactive overlay with no IP to hand out.
        fn inactive() -> Self {
            Self {
                overlay_ip: None,
                active: false,
            }
        }
    }

    impl crate::overlay::OverlayResolver for MockResolver {
        fn resolve_overlay_ip(&self, _node_name: &str) -> OverlayReachability {
            self.overlay_ip.map_or(
                OverlayReachability::Unavailable,
                OverlayReachability::Reachable,
            )
        }

        fn resolve_direct_endpoint(&self, _node_name: &str) -> Option<String> {
            None
        }

        fn local_overlay_ip(&self) -> Option<Ipv4Addr> {
            self.overlay_ip
        }

        fn overlay_active(&self) -> bool {
            self.active
        }
    }

    /// Connector without any resolver attached.
    fn connector_without_resolver(
        direct_url: &str,
        overlay_url: Option<&str>,
        mode: RoutingMode,
    ) -> OverlayAwareConnector {
        OverlayAwareConnector::new(direct_url, overlay_url, mode, None)
    }

    /// `PreferOverlay` connector with no pre-computed overlay URL, backed by
    /// the given mock resolver.
    fn prefer_overlay_with(direct_url: &str, mock: MockResolver) -> OverlayAwareConnector {
        let shared = Arc::new(mock);
        OverlayAwareConnector::new(direct_url, None, RoutingMode::PreferOverlay, Some(shared))
    }

    #[test]
    fn test_connector_new() {
        let connector = connector_without_resolver(
            "ws://localhost:3669/tunnel/v1",
            None,
            RoutingMode::PreferOverlay,
        );

        assert_eq!(connector.direct_url, "ws://localhost:3669/tunnel/v1");
        assert!(connector.overlay_url.is_none());
        assert_eq!(connector.routing_mode, RoutingMode::PreferOverlay);
        assert!(connector.resolver.is_none());
    }

    #[test]
    fn test_connector_with_overlay_url() {
        let connector = connector_without_resolver(
            "ws://public-host:3669/tunnel/v1",
            Some("ws://10.0.0.5:3669/tunnel/v1"),
            RoutingMode::PreferOverlay,
        );

        assert_eq!(
            connector.overlay_url.as_deref(),
            Some("ws://10.0.0.5:3669/tunnel/v1")
        );
    }

    #[test]
    fn test_resolve_overlay_url_precomputed() {
        let connector = connector_without_resolver(
            "ws://public-host:3669/tunnel/v1",
            Some("ws://10.0.0.5:3669/tunnel/v1"),
            RoutingMode::PreferOverlay,
        );

        let resolved = connector.resolve_overlay_url();
        assert_eq!(resolved.as_deref(), Some("ws://10.0.0.5:3669/tunnel/v1"));
    }

    #[test]
    fn test_resolve_overlay_url_from_resolver() {
        let connector = prefer_overlay_with(
            "ws://remote-node:3669/tunnel/v1",
            MockResolver::active_with_ip(Ipv4Addr::new(10, 0, 0, 42)),
        );

        let overlay_url = connector.resolve_overlay_url();
        assert_eq!(
            overlay_url.as_deref(),
            Some("ws://10.0.0.42:3669/tunnel/v1")
        );
    }

    #[test]
    fn test_resolve_overlay_url_inactive_resolver() {
        let connector =
            prefer_overlay_with("ws://remote-node:3669/tunnel/v1", MockResolver::inactive());

        let overlay_url = connector.resolve_overlay_url();
        assert!(overlay_url.is_none());
    }

    #[test]
    fn test_resolve_overlay_url_no_resolver_no_precomputed() {
        let connector = connector_without_resolver(
            "ws://remote-node:3669/tunnel/v1",
            None,
            RoutingMode::PreferOverlay,
        );

        let overlay_url = connector.resolve_overlay_url();
        assert!(overlay_url.is_none());
    }

    #[test]
    fn test_resolve_overlay_url_preserves_wss_scheme() {
        let connector = prefer_overlay_with(
            "wss://remote-node:3669/tunnel/v1",
            MockResolver::active_with_ip(Ipv4Addr::new(10, 0, 0, 42)),
        );

        let overlay_url = connector.resolve_overlay_url();
        assert_eq!(
            overlay_url.as_deref(),
            Some("wss://10.0.0.42:3669/tunnel/v1")
        );
    }

    #[test]
    fn test_direct_only_does_not_resolve_overlay() {
        let connector = connector_without_resolver(
            "ws://127.0.0.1:3669/tunnel/v1",
            Some("ws://10.0.0.5:3669/tunnel/v1"),
            RoutingMode::DirectOnly,
        );

        assert_eq!(connector.routing_mode, RoutingMode::DirectOnly);
        assert_eq!(connector.direct_url, "ws://127.0.0.1:3669/tunnel/v1");
    }

    #[tokio::test]
    async fn test_overlay_only_no_overlay_url_returns_error() {
        let connector = connector_without_resolver(
            "ws://127.0.0.1:3669/tunnel/v1",
            None,
            RoutingMode::OverlayOnly,
        );

        // With neither a pre-computed URL nor a resolver, the overlay-only
        // path must fail before attempting any network I/O.
        let result = connector.connect().await;
        assert!(result.is_err());

        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("overlay"),
            "Expected overlay-related error, got: {err}"
        );
    }

    #[test]
    fn test_prefer_overlay_resolves_url_correctly() {
        let connector = prefer_overlay_with(
            "ws://node-b:3669/tunnel/v1",
            MockResolver::active_with_ip(Ipv4Addr::new(10, 0, 0, 1)),
        );

        let overlay_url = connector.resolve_overlay_url();
        assert!(overlay_url.is_some());
        assert_eq!(overlay_url.as_deref(), Some("ws://10.0.0.1:3669/tunnel/v1"));
    }

    #[test]
    fn test_prefer_overlay_falls_back_when_no_overlay() {
        let connector = connector_without_resolver(
            "ws://node-b:3669/tunnel/v1",
            None,
            RoutingMode::PreferOverlay,
        );

        let overlay_url = connector.resolve_overlay_url();
        assert!(overlay_url.is_none());
        assert_eq!(connector.direct_url, "ws://node-b:3669/tunnel/v1");
    }
}