mockforge_core/
ws_proxy.rs1use crate::{Error, Result};
4use axum::extract::ws::{Message as AxumMessage, WebSocket};
5use futures::{SinkExt, StreamExt};
6use serde::{Deserialize, Serialize};
7use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
8use tracing::*;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct WsProxyRule {
13    pub pattern: String,
15    pub upstream_url: String,
17    pub enabled: bool,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct WsProxyConfig {
26    pub upstream_url: String,
28    pub enabled: bool,
30    #[serde(default)]
32    pub rules: Vec<WsProxyRule>,
33    #[serde(default = "default_passthrough")]
35    pub passthrough_by_default: bool,
36}
37
/// Serde default for `WsProxyConfig::passthrough_by_default`.
///
/// Unless configured otherwise, paths that match no rule are proxied.
fn default_passthrough() -> bool {
    // Passthrough is opt-out rather than opt-in.
    true
}
41
42impl Default for WsProxyConfig {
43    fn default() -> Self {
44        Self {
45            upstream_url: std::env::var("MOCKFORGE_WS_PROXY_UPSTREAM_URL")
46                .unwrap_or_else(|_| "ws://localhost:9080".to_string()),
47            enabled: false,
48            rules: Vec::new(),
49            passthrough_by_default: true,
50        }
51    }
52}
53
54impl WsProxyConfig {
55    pub fn new(upstream_url: String) -> Self {
57        Self {
58            upstream_url,
59            ..Default::default()
60        }
61    }
62
63    pub fn should_proxy(&self, path: &str) -> bool {
65        if !self.enabled {
66            return false;
67        }
68
69        for rule in &self.rules {
71            if rule.enabled && self.matches_path(&rule.pattern, path) {
72                return true;
73            }
74        }
75
76        self.passthrough_by_default
78    }
79
80    pub fn get_upstream_url(&self, path: &str) -> String {
82        for rule in &self.rules {
84            if rule.enabled && self.matches_path(&rule.pattern, path) {
85                return rule.upstream_url.clone();
86            }
87        }
88
89        self.upstream_url.clone()
91    }
92
93    fn matches_path(&self, pattern: &str, path: &str) -> bool {
95        if pattern == path {
96            return true;
97        }
98
99        if pattern.contains('*') {
101            let pattern_parts: Vec<&str> = pattern.split('/').collect();
102            let path_parts: Vec<&str> = path.split('/').collect();
103
104            if pattern_parts.len() != path_parts.len() {
105                return false;
106            }
107
108            for (pattern_part, path_part) in pattern_parts.iter().zip(path_parts.iter()) {
109                if *pattern_part != "*" && *pattern_part != *path_part {
110                    return false;
111                }
112            }
113            return true;
114        }
115
116        false
117    }
118}
119
120fn axum_to_tungstenite(msg: AxumMessage) -> TungsteniteMessage {
122    match msg {
123        AxumMessage::Text(text) => TungsteniteMessage::Text(text.to_string().into()),
124        AxumMessage::Binary(data) => TungsteniteMessage::Binary(data),
125        AxumMessage::Ping(data) => TungsteniteMessage::Ping(data),
126        AxumMessage::Pong(data) => TungsteniteMessage::Pong(data),
127        AxumMessage::Close(frame) => TungsteniteMessage::Close(frame.map(|f| {
128            tokio_tungstenite::tungstenite::protocol::CloseFrame {
129                code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::from(
130                    f.code,
131                ),
132                reason: f.reason.to_string().into(),
133            }
134        })),
135    }
136}
137
138fn tungstenite_to_axum(msg: TungsteniteMessage) -> AxumMessage {
140    match msg {
141        TungsteniteMessage::Text(text) => AxumMessage::Text(text.to_string().into()),
142        TungsteniteMessage::Binary(data) => AxumMessage::Binary(data),
143        TungsteniteMessage::Ping(data) => AxumMessage::Ping(data),
144        TungsteniteMessage::Pong(data) => AxumMessage::Pong(data),
145        TungsteniteMessage::Close(frame) => {
146            AxumMessage::Close(frame.map(|f| axum::extract::ws::CloseFrame {
147                code: axum::extract::ws::CloseCode::from(u16::from(f.code)),
148                reason: f.reason.to_string().into(),
149            }))
150        }
151        TungsteniteMessage::Frame(_) => AxumMessage::Text("".to_string().into()), }
153}
154
155#[derive(Clone)]
157pub struct WsProxyHandler {
158    pub config: WsProxyConfig,
159}
160
161impl WsProxyHandler {
162    pub fn new(config: WsProxyConfig) -> Self {
164        Self { config }
165    }
166
167    pub async fn proxy_connection(&self, path: &str, client_socket: WebSocket) -> Result<()> {
169        if !self.config.should_proxy(path) {
170            return Err(Error::generic("WebSocket connection should not be proxied".to_string()));
171        }
172
173        let upstream_url = self.config.get_upstream_url(path);
175
176        let (upstream_socket, _) =
178            tokio_tungstenite::connect_async(&upstream_url).await.map_err(|e| {
179                Error::generic(format!("Failed to connect to upstream WebSocket: {}", e))
180            })?;
181
182        info!("Connected to upstream WebSocket at {}", upstream_url);
183
184        let (mut client_sink, mut client_stream) = client_socket.split();
186        let (mut upstream_sink, mut upstream_stream) = upstream_socket.split();
187
188        let forward_client_to_upstream = tokio::spawn(async move {
190            while let Some(msg) = client_stream.next().await {
191                match msg {
192                    Ok(message) => {
193                        let tungstenite_msg = axum_to_tungstenite(message);
194                        if let Err(e) = upstream_sink.send(tungstenite_msg).await {
195                            error!("Failed to send message to upstream: {}", e);
196                            break;
197                        }
198                    }
199                    Err(e) => {
200                        error!("Error receiving message from client: {}", e);
201                        break;
202                    }
203                }
204            }
205        });
206
207        let forward_upstream_to_client = tokio::spawn(async move {
209            while let Some(msg) = upstream_stream.next().await {
210                match msg {
211                    Ok(message) => {
212                        let axum_msg = tungstenite_to_axum(message);
213                        if let Err(e) = client_sink.send(axum_msg).await {
214                            error!("Failed to send message to client: {}", e);
215                            break;
216                        }
217                    }
218                    Err(e) => {
219                        error!("Error receiving message from upstream: {}", e);
220                        break;
221                    }
222                }
223            }
224        });
225
226        tokio::select! {
228            _ = forward_client_to_upstream => {
229                info!("Client to upstream forwarding completed");
230            }
231            _ = forward_upstream_to_client => {
232                info!("Upstream to client forwarding completed");
233            }
234        }
235
236        Ok(())
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a rule from borrowed strings to keep the tests terse.
    fn rule(pattern: &str, upstream_url: &str, enabled: bool) -> WsProxyRule {
        WsProxyRule {
            pattern: pattern.to_string(),
            upstream_url: upstream_url.to_string(),
            enabled,
        }
    }

    #[test]
    fn test_ws_proxy_config() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.rules.push(WsProxyRule {
            pattern: "/ws/users/*".to_string(),
            upstream_url: "ws://users.example.com".to_string(),
            enabled: true,
        });
        config.rules.push(WsProxyRule {
            pattern: "/ws/orders/*".to_string(),
            upstream_url: "ws://orders.example.com".to_string(),
            enabled: true,
        });

        assert!(config.should_proxy("/ws/users/123"));
        assert!(config.should_proxy("/ws/orders/456"));

        assert_eq!(config.get_upstream_url("/ws/users/123"), "ws://users.example.com");
        assert_eq!(config.get_upstream_url("/ws/orders/456"), "ws://orders.example.com");
        assert_eq!(config.get_upstream_url("/ws/products"), "ws://default.example.com");
    }

    #[test]
    fn test_ws_proxy_config_passthrough() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.passthrough_by_default = true;
        config.enabled = true;

        assert!(config.should_proxy("/ws/users"));
        assert!(config.should_proxy("/ws/orders"));

        config.passthrough_by_default = false;

        assert!(!config.should_proxy("/ws/users"));
        assert!(!config.should_proxy("/ws/orders"));
    }

    #[test]
    fn test_ws_proxy_disabled_never_proxies() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = false;
        config.passthrough_by_default = true;
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", true));

        assert!(!config.should_proxy("/ws/users/123"));
        assert!(!config.should_proxy("/anything"));
    }

    #[test]
    fn test_ws_proxy_rule_matching_edge_cases() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.passthrough_by_default = false;
        config.rules.push(rule("/ws/exact", "ws://exact.example.com", true));
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", true));
        config.rules.push(rule("/ws/off/*", "ws://off.example.com", false));

        // Patterns without `*` match only the identical path.
        assert!(config.should_proxy("/ws/exact"));
        assert!(!config.should_proxy("/ws/exact/more"));
        assert_eq!(config.get_upstream_url("/ws/exact"), "ws://exact.example.com");

        // `*` stands for exactly one segment.
        assert!(!config.should_proxy("/ws/users"));
        assert!(!config.should_proxy("/ws/users/123/extra"));

        // Disabled rules are ignored for both decisions.
        assert!(!config.should_proxy("/ws/off/1"));
        assert_eq!(config.get_upstream_url("/ws/off/1"), "ws://default.example.com");
    }

    #[test]
    fn test_ws_proxy_first_matching_rule_wins() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.rules.push(rule("/ws/*/stream", "ws://first.example.com", true));
        config.rules.push(rule("/ws/users/stream", "ws://second.example.com", true));

        assert_eq!(config.get_upstream_url("/ws/users/stream"), "ws://first.example.com");
    }

    #[test]
    fn test_default_passthrough_matches_serde_default() {
        assert_eq!(WsProxyConfig::default().passthrough_by_default, default_passthrough());
        assert!(!WsProxyConfig::default().enabled);
        assert!(WsProxyConfig::default().rules.is_empty());
    }
}