mockforge_core/
ws_proxy.rs

1//! WebSocket proxy functionality for tunneling connections to upstream services
2
3use crate::{Error, Result};
4use axum::extract::ws::{Message as AxumMessage, WebSocket};
5use futures::{SinkExt, StreamExt};
6use serde::{Deserialize, Serialize};
7use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
8use tracing::*;
9
10/// WebSocket proxy rule
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct WsProxyRule {
13    /// Path pattern (supports wildcards)
14    pub pattern: String,
15    /// Upstream WebSocket URL for this path
16    pub upstream_url: String,
17    /// Whether this rule is enabled
18    pub enabled: bool,
19}
20
21/// WebSocket proxy configuration
22/// Environment variables:
23/// - MOCKFORGE_WS_PROXY_UPSTREAM_URL: Default upstream WebSocket URL for proxy (default: ws://localhost:9080)
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct WsProxyConfig {
26    /// Default upstream WebSocket URL
27    pub upstream_url: String,
28    /// Whether to enable proxy mode
29    pub enabled: bool,
30    /// Per-path proxy rules
31    #[serde(default)]
32    pub rules: Vec<WsProxyRule>,
33    /// Passthrough by default unless an override applies
34    #[serde(default = "default_passthrough")]
35    pub passthrough_by_default: bool,
36}
37
/// Serde fallback for `WsProxyConfig::passthrough_by_default` when the field
/// is missing: paths without a matching rule are proxied.
fn default_passthrough() -> bool {
    true
}
41
42impl Default for WsProxyConfig {
43    fn default() -> Self {
44        Self {
45            upstream_url: std::env::var("MOCKFORGE_WS_PROXY_UPSTREAM_URL")
46                .unwrap_or_else(|_| "ws://localhost:9080".to_string()),
47            enabled: false,
48            rules: Vec::new(),
49            passthrough_by_default: true,
50        }
51    }
52}
53
54impl WsProxyConfig {
55    /// Create a new WebSocket proxy configuration
56    pub fn new(upstream_url: String) -> Self {
57        Self {
58            upstream_url,
59            ..Default::default()
60        }
61    }
62
63    /// Check if a WebSocket connection should be proxied
64    pub fn should_proxy(&self, path: &str) -> bool {
65        if !self.enabled {
66            return false;
67        }
68
69        // Check per-path rules first
70        for rule in &self.rules {
71            if rule.enabled && self.matches_path(&rule.pattern, path) {
72                return true;
73            }
74        }
75
76        // If no specific rule matches, use passthrough behavior
77        self.passthrough_by_default
78    }
79
80    /// Get the upstream URL for a specific path
81    pub fn get_upstream_url(&self, path: &str) -> String {
82        // Check per-path rules first
83        for rule in &self.rules {
84            if rule.enabled && self.matches_path(&rule.pattern, path) {
85                return rule.upstream_url.clone();
86            }
87        }
88
89        // Return default upstream URL
90        self.upstream_url.clone()
91    }
92
93    /// Check if a path matches a pattern
94    fn matches_path(&self, pattern: &str, path: &str) -> bool {
95        if pattern == path {
96            return true;
97        }
98
99        // Simple wildcard matching (* matches any segment)
100        if pattern.contains('*') {
101            let pattern_parts: Vec<&str> = pattern.split('/').collect();
102            let path_parts: Vec<&str> = path.split('/').collect();
103
104            if pattern_parts.len() != path_parts.len() {
105                return false;
106            }
107
108            for (pattern_part, path_part) in pattern_parts.iter().zip(path_parts.iter()) {
109                if *pattern_part != "*" && *pattern_part != *path_part {
110                    return false;
111                }
112            }
113            return true;
114        }
115
116        false
117    }
118}
119
120/// Convert Axum WebSocket message to Tungstenite message
121fn axum_to_tungstenite(msg: AxumMessage) -> TungsteniteMessage {
122    match msg {
123        AxumMessage::Text(text) => TungsteniteMessage::Text(text.to_string().into()),
124        AxumMessage::Binary(data) => TungsteniteMessage::Binary(data),
125        AxumMessage::Ping(data) => TungsteniteMessage::Ping(data),
126        AxumMessage::Pong(data) => TungsteniteMessage::Pong(data),
127        AxumMessage::Close(frame) => TungsteniteMessage::Close(frame.map(|f| {
128            tokio_tungstenite::tungstenite::protocol::CloseFrame {
129                code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::from(
130                    f.code,
131                ),
132                reason: f.reason.to_string().into(),
133            }
134        })),
135    }
136}
137
138/// Convert Tungstenite WebSocket message to Axum message
139fn tungstenite_to_axum(msg: TungsteniteMessage) -> AxumMessage {
140    match msg {
141        TungsteniteMessage::Text(text) => AxumMessage::Text(text.to_string().into()),
142        TungsteniteMessage::Binary(data) => AxumMessage::Binary(data),
143        TungsteniteMessage::Ping(data) => AxumMessage::Ping(data),
144        TungsteniteMessage::Pong(data) => AxumMessage::Pong(data),
145        TungsteniteMessage::Close(frame) => {
146            AxumMessage::Close(frame.map(|f| axum::extract::ws::CloseFrame {
147                code: axum::extract::ws::CloseCode::from(u16::from(f.code)),
148                reason: f.reason.to_string().into(),
149            }))
150        }
151        TungsteniteMessage::Frame(_) => AxumMessage::Text("".to_string().into()), // Should not happen in normal operation
152    }
153}
154
155/// WebSocket proxy handler for tunneling connections to upstream services
156#[derive(Clone)]
157pub struct WsProxyHandler {
158    /// WebSocket proxy configuration
159    pub config: WsProxyConfig,
160}
161
162impl WsProxyHandler {
163    /// Create a new WebSocket proxy handler
164    pub fn new(config: WsProxyConfig) -> Self {
165        Self { config }
166    }
167
168    /// Proxy a WebSocket connection to the upstream service
169    pub async fn proxy_connection(&self, path: &str, client_socket: WebSocket) -> Result<()> {
170        if !self.config.should_proxy(path) {
171            return Err(Error::generic("WebSocket connection should not be proxied".to_string()));
172        }
173
174        // Get the upstream URL for this path
175        let upstream_url = self.config.get_upstream_url(path);
176
177        // Connect to upstream WebSocket server
178        let (upstream_socket, _) =
179            tokio_tungstenite::connect_async(&upstream_url).await.map_err(|e| {
180                Error::generic(format!("Failed to connect to upstream WebSocket: {}", e))
181            })?;
182
183        info!("Connected to upstream WebSocket at {}", upstream_url);
184
185        // Use a simpler approach without shared mutexes
186        let (mut client_sink, mut client_stream) = client_socket.split();
187        let (mut upstream_sink, mut upstream_stream) = upstream_socket.split();
188
189        // Forward messages from client to upstream
190        let forward_client_to_upstream = tokio::spawn(async move {
191            while let Some(msg) = client_stream.next().await {
192                match msg {
193                    Ok(message) => {
194                        let tungstenite_msg = axum_to_tungstenite(message);
195                        if let Err(e) = upstream_sink.send(tungstenite_msg).await {
196                            error!("Failed to send message to upstream: {}", e);
197                            break;
198                        }
199                    }
200                    Err(e) => {
201                        error!("Error receiving message from client: {}", e);
202                        break;
203                    }
204                }
205            }
206        });
207
208        // Forward messages from upstream to client
209        let forward_upstream_to_client = tokio::spawn(async move {
210            while let Some(msg) = upstream_stream.next().await {
211                match msg {
212                    Ok(message) => {
213                        let axum_msg = tungstenite_to_axum(message);
214                        if let Err(e) = client_sink.send(axum_msg).await {
215                            error!("Failed to send message to client: {}", e);
216                            break;
217                        }
218                    }
219                    Err(e) => {
220                        error!("Error receiving message from upstream: {}", e);
221                        break;
222                    }
223                }
224            }
225        });
226
227        // Wait for either task to complete
228        tokio::select! {
229            _ = forward_client_to_upstream => {
230                info!("Client to upstream forwarding completed");
231            }
232            _ = forward_upstream_to_client => {
233                info!("Upstream to client forwarding completed");
234            }
235        }
236
237        Ok(())
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled config with default upstream `ws://default.example.com` and
    /// wildcard rules routing users/orders traffic to dedicated upstreams.
    fn config_with_rules() -> WsProxyConfig {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.rules.push(WsProxyRule {
            pattern: "/ws/users/*".to_string(),
            upstream_url: "ws://users.example.com".to_string(),
            enabled: true,
        });
        config.rules.push(WsProxyRule {
            pattern: "/ws/orders/*".to_string(),
            upstream_url: "ws://orders.example.com".to_string(),
            enabled: true,
        });
        config
    }

    #[test]
    fn test_ws_proxy_config() {
        let config = config_with_rules();

        assert!(config.should_proxy("/ws/users/123"));
        assert!(config.should_proxy("/ws/orders/456"));

        assert_eq!(config.get_upstream_url("/ws/users/123"), "ws://users.example.com");
        assert_eq!(config.get_upstream_url("/ws/orders/456"), "ws://orders.example.com");
        assert_eq!(config.get_upstream_url("/ws/products"), "ws://default.example.com");
    }

    #[test]
    fn test_ws_proxy_config_passthrough() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.passthrough_by_default = true;
        config.enabled = true;

        // With passthrough enabled, all connections should be proxied
        assert!(config.should_proxy("/ws/users"));
        assert!(config.should_proxy("/ws/orders"));

        // Disable passthrough
        config.passthrough_by_default = false;

        // Now only connections with matching rules should be proxied
        assert!(!config.should_proxy("/ws/users"));
        assert!(!config.should_proxy("/ws/orders"));
    }

    #[test]
    fn test_ws_proxy_rules_without_passthrough() {
        // With passthrough off, `should_proxy` is true only via a rule match,
        // so these assertions actually exercise the rule matcher.
        let mut config = config_with_rules();
        config.passthrough_by_default = false;

        assert!(config.should_proxy("/ws/users/123"));
        assert!(config.should_proxy("/ws/orders/456"));
        assert!(!config.should_proxy("/ws/products"));

        // `*` stands for exactly one path segment.
        assert!(!config.should_proxy("/ws/users"));
        assert!(!config.should_proxy("/ws/users/123/extra"));
    }

    #[test]
    fn test_ws_proxy_disabled_config_never_proxies() {
        let config = WsProxyConfig {
            enabled: false,
            ..config_with_rules()
        };

        assert!(!config.should_proxy("/ws/users/123"));
        assert!(!config.should_proxy("/anything"));
    }

    #[test]
    fn test_ws_proxy_disabled_rule_is_ignored() {
        let mut config = config_with_rules();
        config.passthrough_by_default = false;
        config.rules[0].enabled = false;

        assert!(!config.should_proxy("/ws/users/123"));
        assert_eq!(config.get_upstream_url("/ws/users/123"), "ws://default.example.com");
        // The remaining enabled rule is unaffected.
        assert_eq!(config.get_upstream_url("/ws/orders/456"), "ws://orders.example.com");
    }

    #[test]
    fn test_ws_proxy_exact_pattern() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.passthrough_by_default = false;
        config.rules.push(WsProxyRule {
            pattern: "/ws/chat".to_string(),
            upstream_url: "ws://chat.example.com".to_string(),
            enabled: true,
        });

        assert!(config.should_proxy("/ws/chat"));
        assert!(!config.should_proxy("/ws/chat/room"));
        assert_eq!(config.get_upstream_url("/ws/chat"), "ws://chat.example.com");
        assert_eq!(config.get_upstream_url("/ws/chat/room"), "ws://default.example.com");
    }
}