mockforge_core/
ws_proxy.rs1use crate::{Error, Result};
4use axum::extract::ws::{Message as AxumMessage, WebSocket};
5use futures::{SinkExt, StreamExt};
6use serde::{Deserialize, Serialize};
7use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
8use tracing::*;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct WsProxyRule {
13 pub pattern: String,
15 pub upstream_url: String,
17 pub enabled: bool,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct WsProxyConfig {
26 pub upstream_url: String,
28 pub enabled: bool,
30 #[serde(default)]
32 pub rules: Vec<WsProxyRule>,
33 #[serde(default = "default_passthrough")]
35 pub passthrough_by_default: bool,
36}
37
/// Serde default for [`WsProxyConfig::passthrough_by_default`]: unmatched paths
/// are proxied unless the configuration says otherwise.
fn default_passthrough() -> bool {
    let passthrough = true;
    passthrough
}
41
42impl Default for WsProxyConfig {
43 fn default() -> Self {
44 Self {
45 upstream_url: std::env::var("MOCKFORGE_WS_PROXY_UPSTREAM_URL")
46 .unwrap_or_else(|_| "ws://localhost:9080".to_string()),
47 enabled: false,
48 rules: Vec::new(),
49 passthrough_by_default: true,
50 }
51 }
52}
53
54impl WsProxyConfig {
55 pub fn new(upstream_url: String) -> Self {
57 Self {
58 upstream_url,
59 ..Default::default()
60 }
61 }
62
63 pub fn should_proxy(&self, path: &str) -> bool {
65 if !self.enabled {
66 return false;
67 }
68
69 for rule in &self.rules {
71 if rule.enabled && self.matches_path(&rule.pattern, path) {
72 return true;
73 }
74 }
75
76 self.passthrough_by_default
78 }
79
80 pub fn get_upstream_url(&self, path: &str) -> String {
82 for rule in &self.rules {
84 if rule.enabled && self.matches_path(&rule.pattern, path) {
85 return rule.upstream_url.clone();
86 }
87 }
88
89 self.upstream_url.clone()
91 }
92
93 fn matches_path(&self, pattern: &str, path: &str) -> bool {
95 if pattern == path {
96 return true;
97 }
98
99 if pattern.contains('*') {
101 let pattern_parts: Vec<&str> = pattern.split('/').collect();
102 let path_parts: Vec<&str> = path.split('/').collect();
103
104 if pattern_parts.len() != path_parts.len() {
105 return false;
106 }
107
108 for (pattern_part, path_part) in pattern_parts.iter().zip(path_parts.iter()) {
109 if *pattern_part != "*" && *pattern_part != *path_part {
110 return false;
111 }
112 }
113 return true;
114 }
115
116 false
117 }
118}
119
120fn axum_to_tungstenite(msg: AxumMessage) -> TungsteniteMessage {
122 match msg {
123 AxumMessage::Text(text) => TungsteniteMessage::Text(text.to_string().into()),
124 AxumMessage::Binary(data) => TungsteniteMessage::Binary(data),
125 AxumMessage::Ping(data) => TungsteniteMessage::Ping(data),
126 AxumMessage::Pong(data) => TungsteniteMessage::Pong(data),
127 AxumMessage::Close(frame) => TungsteniteMessage::Close(frame.map(|f| {
128 tokio_tungstenite::tungstenite::protocol::CloseFrame {
129 code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::from(
130 f.code,
131 ),
132 reason: f.reason.to_string().into(),
133 }
134 })),
135 }
136}
137
138fn tungstenite_to_axum(msg: TungsteniteMessage) -> AxumMessage {
140 match msg {
141 TungsteniteMessage::Text(text) => AxumMessage::Text(text.to_string().into()),
142 TungsteniteMessage::Binary(data) => AxumMessage::Binary(data),
143 TungsteniteMessage::Ping(data) => AxumMessage::Ping(data),
144 TungsteniteMessage::Pong(data) => AxumMessage::Pong(data),
145 TungsteniteMessage::Close(frame) => {
146 AxumMessage::Close(frame.map(|f| axum::extract::ws::CloseFrame {
147 code: axum::extract::ws::CloseCode::from(u16::from(f.code)),
148 reason: f.reason.to_string().into(),
149 }))
150 }
151 TungsteniteMessage::Frame(_) => AxumMessage::Text("".to_string().into()), }
153}
154
155#[derive(Clone)]
157pub struct WsProxyHandler {
158 pub config: WsProxyConfig,
160}
161
162impl WsProxyHandler {
163 pub fn new(config: WsProxyConfig) -> Self {
165 Self { config }
166 }
167
168 pub async fn proxy_connection(&self, path: &str, client_socket: WebSocket) -> Result<()> {
170 if !self.config.should_proxy(path) {
171 return Err(Error::generic("WebSocket connection should not be proxied".to_string()));
172 }
173
174 let upstream_url = self.config.get_upstream_url(path);
176
177 let (upstream_socket, _) =
179 tokio_tungstenite::connect_async(&upstream_url).await.map_err(|e| {
180 Error::generic(format!("Failed to connect to upstream WebSocket: {}", e))
181 })?;
182
183 info!("Connected to upstream WebSocket at {}", upstream_url);
184
185 let (mut client_sink, mut client_stream) = client_socket.split();
187 let (mut upstream_sink, mut upstream_stream) = upstream_socket.split();
188
189 let forward_client_to_upstream = tokio::spawn(async move {
191 while let Some(msg) = client_stream.next().await {
192 match msg {
193 Ok(message) => {
194 let tungstenite_msg = axum_to_tungstenite(message);
195 if let Err(e) = upstream_sink.send(tungstenite_msg).await {
196 error!("Failed to send message to upstream: {}", e);
197 break;
198 }
199 }
200 Err(e) => {
201 error!("Error receiving message from client: {}", e);
202 break;
203 }
204 }
205 }
206 });
207
208 let forward_upstream_to_client = tokio::spawn(async move {
210 while let Some(msg) = upstream_stream.next().await {
211 match msg {
212 Ok(message) => {
213 let axum_msg = tungstenite_to_axum(message);
214 if let Err(e) = client_sink.send(axum_msg).await {
215 error!("Failed to send message to client: {}", e);
216 break;
217 }
218 }
219 Err(e) => {
220 error!("Error receiving message from upstream: {}", e);
221 break;
222 }
223 }
224 }
225 });
226
227 tokio::select! {
229 _ = forward_client_to_upstream => {
230 info!("Client to upstream forwarding completed");
231 }
232 _ = forward_upstream_to_client => {
233 info!("Upstream to client forwarding completed");
234 }
235 }
236
237 Ok(())
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rule with the given pattern, upstream and enabled flag.
    fn rule(pattern: &str, upstream_url: &str, enabled: bool) -> WsProxyRule {
        WsProxyRule {
            pattern: pattern.to_string(),
            upstream_url: upstream_url.to_string(),
            enabled,
        }
    }

    #[test]
    fn test_ws_proxy_config() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", true));
        config.rules.push(rule("/ws/orders/*", "ws://orders.example.com", true));

        assert!(config.should_proxy("/ws/users/123"));
        assert!(config.should_proxy("/ws/orders/456"));

        assert_eq!(config.get_upstream_url("/ws/users/123"), "ws://users.example.com");
        assert_eq!(config.get_upstream_url("/ws/orders/456"), "ws://orders.example.com");
        assert_eq!(config.get_upstream_url("/ws/products"), "ws://default.example.com");
    }

    #[test]
    fn test_ws_proxy_config_passthrough() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.passthrough_by_default = true;
        config.enabled = true;

        assert!(config.should_proxy("/ws/users"));
        assert!(config.should_proxy("/ws/orders"));

        config.passthrough_by_default = false;

        assert!(!config.should_proxy("/ws/users"));
        assert!(!config.should_proxy("/ws/orders"));
    }

    #[test]
    fn test_ws_proxy_config_disabled() {
        // `new` leaves proxying disabled, so even a matching rule is not proxied.
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", true));

        assert!(!config.enabled);
        assert!(!config.should_proxy("/ws/users/123"));
        assert!(!config.should_proxy("/ws/anything"));
    }

    #[test]
    fn test_ws_proxy_rule_matching_edge_cases() {
        let mut config = WsProxyConfig::new("ws://default.example.com".to_string());
        config.enabled = true;
        config.passthrough_by_default = false;
        config.rules.push(rule("/ws/chat", "ws://chat.example.com", true));
        config.rules.push(rule("/ws/users/*", "ws://users.example.com", true));
        config.rules.push(rule("/ws/admin/*", "ws://admin.example.com", false));

        // Patterns without `*` match only the identical path.
        assert!(config.should_proxy("/ws/chat"));
        assert_eq!(config.get_upstream_url("/ws/chat"), "ws://chat.example.com");
        assert!(!config.should_proxy("/ws/chat/room"));

        // `*` stands for exactly one segment; extra segments do not match.
        assert!(!config.should_proxy("/ws/users/1/profile"));
        assert_eq!(config.get_upstream_url("/ws/users/1/profile"), "ws://default.example.com");

        // Disabled rules are ignored entirely.
        assert!(!config.should_proxy("/ws/admin/1"));
        assert_eq!(config.get_upstream_url("/ws/admin/1"), "ws://default.example.com");
    }
}