1use std::error::Error;
2
3use actix::{
4 io::{SinkWrite, WriteHandler},
5 Actor, ActorContext, AsyncContext, StreamHandler,
6};
7use actix_web::{
8 error::{InternalError, PayloadError},
9 http::StatusCode,
10 HttpRequest, HttpResponse,
11};
12use actix_web_actors::ws::{self, handshake, CloseReason, ProtocolError, WebsocketContext};
13use bytes::Bytes;
14use futures::{Sink, Stream, StreamExt};
15
16pub struct WebsocketProxy<S>
18where
19 S: Unpin + Sink<ws::Message>,
20{
21 send: SinkWrite<ws::Message, S>,
22}
23
24impl<S> WebsocketProxy<S>
25where
26 S: Unpin + Sink<ws::Message> + 'static,
27{
28 fn error<E>(&mut self, err: E, ctx: &mut <Self as Actor>::Context)
29 where
30 E: Error,
31 {
32 let reason = Some(CloseReason {
33 code: ws::CloseCode::Error,
34 description: Some(err.to_string()),
35 });
36
37 ctx.close(reason.clone());
38 let _ = self.send.write(ws::Message::Close(reason)); self.send.close();
40
41 ctx.stop();
42 }
43}
44
45pub async fn start<T>(
62 req: &HttpRequest,
63 target: String,
64 stream: T,
65) -> Result<HttpResponse, actix_web::Error>
66where
67 T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
68{
69 let mut res = handshake(req)?;
70
71 let (_, conn) = awc::Client::new()
72 .ws(target)
73 .connect()
74 .await
75 .map_err(|e| InternalError::new(e, StatusCode::BAD_GATEWAY))?;
76
77 let (send, recv) = conn.split();
78
79 let out = WebsocketContext::with_factory(stream, |ctx| {
80 ctx.add_stream(recv);
81 WebsocketProxy {
82 send: SinkWrite::new(send, ctx),
83 }
84 });
85
86 Ok(res.streaming(out))
87}
88
89impl<S> WriteHandler<ProtocolError> for WebsocketProxy<S>
90where
91 S: Unpin + 'static + Sink<ws::Message>,
92{
93 fn error(&mut self, err: ProtocolError, ctx: &mut Self::Context) -> actix::Running {
94 self.error(err, ctx);
95 actix::Running::Stop
96 }
97}
98
99impl<S> Actor for WebsocketProxy<S>
100where
101 S: Unpin + 'static + Sink<ws::Message>,
102{
103 type Context = WebsocketContext<Self>;
104}
105
106impl<S> StreamHandler<Result<ws::Frame, ProtocolError>> for WebsocketProxy<S>
108where
109 S: Unpin + Sink<ws::Message> + 'static,
110{
111 fn handle(&mut self, item: Result<ws::Frame, ProtocolError>, ctx: &mut Self::Context) {
112 let frame = match item {
113 Ok(frame) => frame,
114 Err(err) => return self.error(err, ctx),
115 };
116 let msg = match frame {
117 ws::Frame::Text(t) => match t.try_into() {
118 Ok(t) => ws::Message::Text(t),
119 Err(e) => {
120 self.error(e, ctx);
121 return;
122 }
123 },
124 ws::Frame::Binary(b) => ws::Message::Binary(b),
125 ws::Frame::Continuation(c) => ws::Message::Continuation(c),
126 ws::Frame::Ping(p) => ws::Message::Ping(p),
127 ws::Frame::Pong(p) => ws::Message::Pong(p),
128 ws::Frame::Close(r) => ws::Message::Close(r),
129 };
130
131 ctx.write_raw(msg)
132 }
133}
134
135impl<S> StreamHandler<Result<ws::Message, ProtocolError>> for WebsocketProxy<S>
137where
138 S: Unpin + Sink<ws::Message> + 'static,
139{
140 fn handle(&mut self, item: Result<ws::Message, ProtocolError>, ctx: &mut Self::Context) {
141 let msg = match item {
142 Ok(msg) => msg,
143 Err(err) => return self.error(err, ctx),
144 };
145
146 let _ = self.send.write(msg);
148 }
149}
150
#[cfg(test)]
mod tests {
    use actix::{Actor, StreamHandler};
    use actix_web::{
        get,
        web::{self, Path},
        App, Error, HttpRequest, HttpResponse, HttpServer,
    };
    use actix_web_actors::ws;
    use futures::{SinkExt, StreamExt};

    /// Echo actor used as the proxy's upstream target.
    #[derive(Debug)]
    struct TestActor;

    impl Actor for TestActor {
        type Context = ws::WebsocketContext<Self>;
    }

    impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for TestActor {
        fn handle(
            &mut self,
            item: Result<ws::Message, ws::ProtocolError>,
            ctx: &mut Self::Context,
        ) {
            match item.unwrap() {
                ws::Message::Text(txt) => ctx.text(txt),
                ws::Message::Binary(bin) => ctx.binary(bin),
                ws::Message::Ping(arg) => ctx.pong(&arg),
                ws::Message::Close(reason) => ctx.close(reason),
                _ => (),
            }
        }
    }

    /// Proxies the client to the echo server listening on `port`.
    #[get("/proxy/{port}")]
    async fn proxy(
        req: HttpRequest,
        stream: web::Payload,
        port: Path<u16>,
    ) -> Result<HttpResponse, Error> {
        crate::start(&req, format!("ws://127.0.0.1:{}", port), stream).await
    }

    /// Echo server endpoint.
    #[get("/")]
    async fn index(req: HttpRequest, stream: web::Payload) -> Result<HttpResponse, Error> {
        ws::start(TestActor, &req, stream)
    }

    /// Builds a server for `$factory` on a free port; evaluates to `(server, port)`.
    macro_rules! get_server {
        ($factory:expr) => {{
            let port = portpicker::pick_unused_port().expect("No ports free");
            // Loopback only: the proxy dials 127.0.0.1, and test servers must
            // not be reachable from other hosts.
            let server = HttpServer::new(|| App::new().service($factory))
                .bind(("127.0.0.1", port))
                .expect("couldn't start server")
                .run();

            (server, port)
        }};
    }

    #[actix::test]
    async fn with_proxy() {
        let (srv, port) = get_server!(index);
        actix::spawn(srv);
        let (proxysrv, proxyport) = get_server!(proxy);
        actix::spawn(proxysrv);

        let client = awc::Client::new();
        // Dial the IPv4 loopback explicitly so the client cannot resolve
        // `localhost` to `::1` while the servers listen on IPv4 only.
        let (_resp, mut conn) = client
            .ws(format!("ws://127.0.0.1:{}/proxy/{}", proxyport, port))
            .connect()
            .await
            .unwrap();

        conn.send(ws::Message::Text("echo.into".into()))
            .await
            .unwrap();

        let resp = conn.next().await.unwrap().unwrap();
        assert_eq!(resp, ws::Frame::Text("echo.into".into()));

        let bytes = bytes::Bytes::from_static(&[0x11, 0x22, 0x33, 0x55]);

        conn.send(ws::Message::Binary(bytes.clone()))
            .await
            .unwrap();

        let resp = conn.next().await.unwrap().unwrap();
        assert_eq!(resp, ws::Frame::Binary(bytes.clone()));

        conn.send(ws::Message::Ping(bytes.clone())).await.unwrap();

        let resp = conn.next().await.unwrap().unwrap();
        assert_eq!(resp, ws::Frame::Pong(bytes.clone()));
    }
}