mockforge_core/
ws_proxy.rs1use crate::{Error, Result};
4use axum::extract::ws::{Message as AxumMessage, WebSocket};
5use futures::{SinkExt, StreamExt};
6use serde::{Deserialize, Serialize};
7use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
8use tracing::*;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct WsProxyRule {
13 pub pattern: String,
15 pub upstream_url: String,
17 pub enabled: bool,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct WsProxyConfig {
26 pub upstream_url: String,
28 pub enabled: bool,
30 #[serde(default)]
32 pub rules: Vec<WsProxyRule>,
33 #[serde(default = "default_passthrough")]
35 pub passthrough_by_default: bool,
36}
37
/// Serde default for `WsProxyConfig::passthrough_by_default`: unless configured otherwise,
/// paths that match no rule are still proxied.
fn default_passthrough() -> bool {
    true
}

42impl Default for WsProxyConfig {
43 fn default() -> Self {
44 Self {
45 upstream_url: std::env::var("MOCKFORGE_WS_PROXY_UPSTREAM_URL")
46 .unwrap_or_else(|_| "ws://localhost:9080".to_string()),
47 enabled: false,
48 rules: Vec::new(),
49 passthrough_by_default: true,
50 }
51 }
52}
53
54impl WsProxyConfig {
55 pub fn new(upstream_url: String) -> Self {
57 Self {
58 upstream_url,
59 ..Default::default()
60 }
61 }
62
63 pub fn should_proxy(&self, path: &str) -> bool {
65 if !self.enabled {
66 return false;
67 }
68
69 for rule in &self.rules {
71 if rule.enabled && self.matches_path(&rule.pattern, path) {
72 return true;
73 }
74 }
75
76 self.passthrough_by_default
78 }
79
80 pub fn get_upstream_url(&self, path: &str) -> String {
82 for rule in &self.rules {
84 if rule.enabled && self.matches_path(&rule.pattern, path) {
85 return rule.upstream_url.clone();
86 }
87 }
88
89 self.upstream_url.clone()
91 }
92
93 fn matches_path(&self, pattern: &str, path: &str) -> bool {
95 if pattern == path {
96 return true;
97 }
98
99 if pattern.contains('*') {
101 let pattern_parts: Vec<&str> = pattern.split('/').collect();
102 let path_parts: Vec<&str> = path.split('/').collect();
103
104 if pattern_parts.len() != path_parts.len() {
105 return false;
106 }
107
108 for (pattern_part, path_part) in pattern_parts.iter().zip(path_parts.iter()) {
109 if *pattern_part != "*" && *pattern_part != *path_part {
110 return false;
111 }
112 }
113 return true;
114 }
115
116 false
117 }
118}
119
120fn axum_to_tungstenite(msg: AxumMessage) -> TungsteniteMessage {
122 match msg {
123 AxumMessage::Text(text) => TungsteniteMessage::Text(text.to_string().into()),
124 AxumMessage::Binary(data) => TungsteniteMessage::Binary(data),
125 AxumMessage::Ping(data) => TungsteniteMessage::Ping(data),
126 AxumMessage::Pong(data) => TungsteniteMessage::Pong(data),
127 AxumMessage::Close(frame) => TungsteniteMessage::Close(frame.map(|f| {
128 tokio_tungstenite::tungstenite::protocol::CloseFrame {
129 code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::from(
130 f.code,
131 ),
132 reason: f.reason.to_string().into(),
133 }
134 })),
135 }
136}
137
138fn tungstenite_to_axum(msg: TungsteniteMessage) -> AxumMessage {
140 match msg {
141 TungsteniteMessage::Text(text) => AxumMessage::Text(text.to_string().into()),
142 TungsteniteMessage::Binary(data) => AxumMessage::Binary(data),
143 TungsteniteMessage::Ping(data) => AxumMessage::Ping(data),
144 TungsteniteMessage::Pong(data) => AxumMessage::Pong(data),
145 TungsteniteMessage::Close(frame) => {
146 AxumMessage::Close(frame.map(|f| axum::extract::ws::CloseFrame {
147 code: axum::extract::ws::CloseCode::from(u16::from(f.code)),
148 reason: f.reason.to_string().into(),
149 }))
150 }
151 TungsteniteMessage::Frame(_) => AxumMessage::Text("".to_string().into()), }
153}
154
155#[derive(Clone)]
157pub struct WsProxyHandler {
158 pub config: WsProxyConfig,
159}
160
161impl WsProxyHandler {
162 pub fn new(config: WsProxyConfig) -> Self {
164 Self { config }
165 }
166
167 pub async fn proxy_connection(&self, path: &str, client_socket: WebSocket) -> Result<()> {
169 if !self.config.should_proxy(path) {
170 return Err(Error::generic("WebSocket connection should not be proxied".to_string()));
171 }
172
173 let upstream_url = self.config.get_upstream_url(path);
175
176 let (upstream_socket, _) =
178 tokio_tungstenite::connect_async(&upstream_url).await.map_err(|e| {
179 Error::generic(format!("Failed to connect to upstream WebSocket: {}", e))
180 })?;
181
182 info!("Connected to upstream WebSocket at {}", upstream_url);
183
184 let (mut client_sink, mut client_stream) = client_socket.split();
186 let (mut upstream_sink, mut upstream_stream) = upstream_socket.split();
187
188 let forward_client_to_upstream = tokio::spawn(async move {
190 while let Some(msg) = client_stream.next().await {
191 match msg {
192 Ok(message) => {
193 let tungstenite_msg = axum_to_tungstenite(message);
194 if let Err(e) = upstream_sink.send(tungstenite_msg).await {
195 error!("Failed to send message to upstream: {}", e);
196 break;
197 }
198 }
199 Err(e) => {
200 error!("Error receiving message from client: {}", e);
201 break;
202 }
203 }
204 }
205 });
206
207 let forward_upstream_to_client = tokio::spawn(async move {
209 while let Some(msg) = upstream_stream.next().await {
210 match msg {
211 Ok(message) => {
212 let axum_msg = tungstenite_to_axum(message);
213 if let Err(e) = client_sink.send(axum_msg).await {
214 error!("Failed to send message to client: {}", e);
215 break;
216 }
217 }
218 Err(e) => {
219 error!("Error receiving message from upstream: {}", e);
220 break;
221 }
222 }
223 }
224 });
225
226 tokio::select! {
228 _ = forward_client_to_upstream => {
229 info!("Client to upstream forwarding completed");
230 }
231 _ = forward_upstream_to_client => {
232 info!("Upstream to client forwarding completed");
233 }
234 }
235
236 Ok(())
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rule routing `pattern` to `upstream_url`.
    fn rule(pattern: &str, upstream_url: &str, enabled: bool) -> WsProxyRule {
        WsProxyRule {
            pattern: pattern.to_string(),
            upstream_url: upstream_url.to_string(),
            enabled,
        }
    }

    #[test]
    fn test_ws_proxy_config() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.rules.push(WsProxyRule {
            pattern: "/ws/users/*".to_string(),
            upstream_url: "ws://users.example.com".to_string(),
            enabled: true,
        });
        config.rules.push(WsProxyRule {
            pattern: "/ws/orders/*".to_string(),
            upstream_url: "ws://orders.example.com".to_string(),
            enabled: true,
        });

        assert!(config.should_proxy("/ws/users/123"));
        assert!(config.should_proxy("/ws/orders/456"));

        assert_eq!(config.get_upstream_url("/ws/users/123"), "ws://users.example.com");
        assert_eq!(config.get_upstream_url("/ws/orders/456"), "ws://orders.example.com");
        assert_eq!(config.get_upstream_url("/ws/products"), "ws://default.example.com");
    }

    #[test]
    fn test_ws_proxy_config_passthrough() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.passthrough_by_default = true;
        config.enabled = true;

        assert!(config.should_proxy("/ws/users"));
        assert!(config.should_proxy("/ws/orders"));

        config.passthrough_by_default = false;

        assert!(!config.should_proxy("/ws/users"));
        assert!(!config.should_proxy("/ws/orders"));
    }

    #[test]
    fn test_ws_proxy_disabled_config_never_proxies() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", true));

        // `new` leaves the proxy disabled, so neither rules nor passthrough apply.
        assert!(!config.enabled);
        assert!(!config.should_proxy("/ws/users/123"));
        assert!(!config.should_proxy("/ws/anything"));
    }

    #[test]
    fn test_ws_proxy_rule_matching_edge_cases() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.passthrough_by_default = false;
        config.rules.push(rule("/ws/disabled/*", "ws://disabled.example.com", false));
        config.rules.push(rule("/ws/exact", "ws://exact.example.com", true));
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", true));

        // Patterns without `*` match only the identical path.
        assert!(config.should_proxy("/ws/exact"));
        assert_eq!(config.get_upstream_url("/ws/exact"), "ws://exact.example.com");
        assert!(!config.should_proxy("/ws/exact/child"));

        // A `*` segment matches exactly one path segment.
        assert!(!config.should_proxy("/ws/users/123/profile"));
        assert_eq!(
            config.get_upstream_url("/ws/users/123/profile"),
            "ws://default.example.com"
        );

        // Disabled rules are ignored by both routing decisions.
        assert!(!config.should_proxy("/ws/disabled/1"));
        assert_eq!(config.get_upstream_url("/ws/disabled/1"), "ws://default.example.com");
    }

    #[test]
    fn test_ws_proxy_first_matching_rule_wins() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.rules.push(rule("/ws/*/123", "ws://first.example.com", true));
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", true));

        assert_eq!(config.get_upstream_url("/ws/users/123"), "ws://first.example.com");
        assert_eq!(config.get_upstream_url("/ws/users/456"), "ws://users.example.com");
    }
}