use crate::{Error, Result};
use axum::extract::ws::{Message as AxumMessage, WebSocket};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
use tracing::*;
/// A routing rule that sends WebSocket connections on matching paths to a
/// specific upstream server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsProxyRule {
/// Path pattern: either an exact path, or a `/`-separated pattern in which a
/// `*` segment matches exactly one path segment.
pub pattern: String,
/// Upstream WebSocket URL used for paths matching `pattern`.
pub upstream_url: String,
/// Rules with `enabled == false` are skipped during matching.
pub enabled: bool,
}
/// Configuration for proxying WebSocket connections to upstream servers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsProxyConfig {
/// Fallback upstream URL, used when no enabled rule matches the path.
pub upstream_url: String,
/// Master switch: when `false`, no connection is proxied.
pub enabled: bool,
/// Routing rules, checked in order; the first enabled matching rule wins.
/// Optional in serialized config (defaults to an empty list).
#[serde(default)]
pub rules: Vec<WsProxyRule>,
/// Whether paths matched by no rule are still proxied (to `upstream_url`).
/// Optional in serialized config (defaults to `true`).
#[serde(default = "default_passthrough")]
pub passthrough_by_default: bool,
}
/// Serde default for `WsProxyConfig::passthrough_by_default`: unmatched paths
/// are proxied unless the configuration says otherwise.
const fn default_passthrough() -> bool {
    true
}
impl Default for WsProxyConfig {
fn default() -> Self {
Self {
upstream_url: std::env::var("MOCKFORGE_WS_PROXY_UPSTREAM_URL")
.unwrap_or_else(|_| "ws://localhost:9080".to_string()),
enabled: false,
rules: Vec::new(),
passthrough_by_default: true,
}
}
}
impl WsProxyConfig {
pub fn new(upstream_url: String) -> Self {
Self {
upstream_url,
..Default::default()
}
}
pub fn should_proxy(&self, path: &str) -> bool {
if !self.enabled {
return false;
}
for rule in &self.rules {
if rule.enabled && self.matches_path(&rule.pattern, path) {
return true;
}
}
self.passthrough_by_default
}
pub fn get_upstream_url(&self, path: &str) -> String {
for rule in &self.rules {
if rule.enabled && self.matches_path(&rule.pattern, path) {
return rule.upstream_url.clone();
}
}
self.upstream_url.clone()
}
fn matches_path(&self, pattern: &str, path: &str) -> bool {
if pattern == path {
return true;
}
if pattern.contains('*') {
let pattern_parts: Vec<&str> = pattern.split('/').collect();
let path_parts: Vec<&str> = path.split('/').collect();
if pattern_parts.len() != path_parts.len() {
return false;
}
for (pattern_part, path_part) in pattern_parts.iter().zip(path_parts.iter()) {
if *pattern_part != "*" && *pattern_part != *path_part {
return false;
}
}
return true;
}
false
}
}
/// Converts a message received from the client (axum) into the equivalent
/// tungstenite message for the upstream socket.
///
/// Every variant maps one-to-one; a close frame keeps its numeric code and
/// its reason text.
fn axum_to_tungstenite(msg: AxumMessage) -> TungsteniteMessage {
    use tokio_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFrame};

    match msg {
        AxumMessage::Text(text) => TungsteniteMessage::Text(text.to_string().into()),
        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
        AxumMessage::Close(frame) => {
            let converted = frame.map(|f| CloseFrame {
                code: CloseCode::from(f.code),
                reason: f.reason.to_string().into(),
            });
            TungsteniteMessage::Close(converted)
        }
    }
}
/// Converts a message read from the upstream (tungstenite) socket into the
/// equivalent axum message for the client.
///
/// Text, binary, ping, pong and close messages map one-to-one; a close frame
/// keeps its numeric code and its reason text.
fn tungstenite_to_axum(msg: TungsteniteMessage) -> AxumMessage {
    match msg {
        TungsteniteMessage::Text(text) => AxumMessage::Text(text.to_string().into()),
        TungsteniteMessage::Binary(data) => AxumMessage::Binary(data),
        TungsteniteMessage::Ping(data) => AxumMessage::Ping(data),
        TungsteniteMessage::Pong(data) => AxumMessage::Pong(data),
        TungsteniteMessage::Close(frame) => {
            AxumMessage::Close(frame.map(|f| axum::extract::ws::CloseFrame {
                code: axum::extract::ws::CloseCode::from(u16::from(f.code)),
                reason: f.reason.to_string().into(),
            }))
        }
        // tungstenite documents that raw frames are not yielded while reading,
        // so this arm should not be hit. If it ever is, forward an empty
        // unsolicited Pong (permitted by RFC 6455) instead of an empty Text
        // message. An empty Text would reach the client as a bogus data message.
        TungsteniteMessage::Frame(_) => AxumMessage::Pong(Default::default()),
    }
}
/// Proxies client WebSocket connections to upstream servers according to a
/// [`WsProxyConfig`].
#[derive(Debug, Clone)]
pub struct WsProxyHandler {
    /// Routing configuration consulted for each proxied connection.
    pub config: WsProxyConfig,
}
impl WsProxyHandler {
pub fn new(config: WsProxyConfig) -> Self {
Self { config }
}
pub async fn proxy_connection(&self, path: &str, client_socket: WebSocket) -> Result<()> {
if !self.config.should_proxy(path) {
return Err(Error::internal("WebSocket connection should not be proxied".to_string()));
}
let upstream_url = self.config.get_upstream_url(path);
let (upstream_socket, _) =
tokio_tungstenite::connect_async(&upstream_url).await.map_err(|e| {
Error::internal(format!("Failed to connect to upstream WebSocket: {}", e))
})?;
info!("Connected to upstream WebSocket at {}", upstream_url);
let (mut client_sink, mut client_stream) = client_socket.split();
let (mut upstream_sink, mut upstream_stream) = upstream_socket.split();
let forward_client_to_upstream = tokio::spawn(async move {
while let Some(msg) = client_stream.next().await {
match msg {
Ok(message) => {
let tungstenite_msg = axum_to_tungstenite(message);
if let Err(e) = upstream_sink.send(tungstenite_msg).await {
error!("Failed to send message to upstream: {}", e);
break;
}
}
Err(e) => {
error!("Error receiving message from client: {}", e);
break;
}
}
}
});
let forward_upstream_to_client = tokio::spawn(async move {
while let Some(msg) = upstream_stream.next().await {
match msg {
Ok(message) => {
let axum_msg = tungstenite_to_axum(message);
if let Err(e) = client_sink.send(axum_msg).await {
error!("Failed to send message to client: {}", e);
break;
}
}
Err(e) => {
error!("Error receiving message from upstream: {}", e);
break;
}
}
}
});
tokio::select! {
_ = forward_client_to_upstream => {
info!("Client to upstream forwarding completed");
}
_ = forward_upstream_to_client => {
info!("Upstream to client forwarding completed");
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a rule in tests.
    fn rule(pattern: &str, upstream_url: &str, enabled: bool) -> WsProxyRule {
        WsProxyRule {
            pattern: pattern.to_string(),
            upstream_url: upstream_url.to_string(),
            enabled,
        }
    }

    #[test]
    fn test_ws_proxy_config() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", true));
        config.rules.push(rule("/ws/orders/*", "ws://orders.example.com", true));
        assert!(config.should_proxy("/ws/users/123"));
        assert!(config.should_proxy("/ws/orders/456"));
        assert_eq!(config.get_upstream_url("/ws/users/123"), "ws://users.example.com");
        assert_eq!(config.get_upstream_url("/ws/orders/456"), "ws://orders.example.com");
        assert_eq!(config.get_upstream_url("/ws/products"), "ws://default.example.com");
    }

    #[test]
    fn test_ws_proxy_config_passthrough() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.passthrough_by_default = true;
        config.enabled = true;
        assert!(config.should_proxy("/ws/users"));
        assert!(config.should_proxy("/ws/orders"));
        config.passthrough_by_default = false;
        assert!(!config.should_proxy("/ws/users"));
        assert!(!config.should_proxy("/ws/orders"));
    }

    #[test]
    fn test_ws_proxy_config_disabled_overrides_rules() {
        // `new` leaves the proxy disabled, which must win over rules and passthrough.
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", true));
        assert!(!config.should_proxy("/ws/users/123"));
        assert!(!config.should_proxy("/ws/other"));
    }

    #[test]
    fn test_ws_proxy_disabled_rule_is_ignored() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.passthrough_by_default = false;
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", false));
        assert!(!config.should_proxy("/ws/users/123"));
        assert_eq!(config.get_upstream_url("/ws/users/123"), "ws://default.example.com");
    }

    #[test]
    fn test_ws_proxy_pattern_matching() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.passthrough_by_default = false;
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", true));
        config.rules.push(rule("/ws/chat", "ws://chat.example.com", true));

        // Exact patterns match only the identical path, not longer paths.
        assert!(config.should_proxy("/ws/chat"));
        assert_eq!(config.get_upstream_url("/ws/chat"), "ws://chat.example.com");
        assert!(!config.should_proxy("/ws/chat/room"));

        // `*` matches exactly one segment.
        assert!(!config.should_proxy("/ws/users"));
        assert!(!config.should_proxy("/ws/users/123/profile"));
    }

    #[test]
    fn test_ws_proxy_first_matching_rule_wins() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.rules.push(rule("/ws/*/123", "ws://first.example.com", true));
        config.rules.push(rule("/ws/users/*", "ws://second.example.com", true));
        assert_eq!(config.get_upstream_url("/ws/users/123"), "ws://first.example.com");
        assert_eq!(config.get_upstream_url("/ws/users/456"), "ws://second.example.com");
    }
}