actix_ws_proxy/
lib.rs

1use std::error::Error;
2
3use actix::{
4    io::{SinkWrite, WriteHandler},
5    Actor, ActorContext, AsyncContext, StreamHandler,
6};
7use actix_web::{
8    error::{InternalError, PayloadError},
9    http::StatusCode,
10    HttpRequest, HttpResponse,
11};
12use actix_web_actors::ws::{self, handshake, CloseReason, ProtocolError, WebsocketContext};
13use bytes::Bytes;
14use futures::{Sink, Stream, StreamExt};
15
16/// WebsocketProxy proxies an incoming websocket connection to another websocket, connected via awc.
17pub struct WebsocketProxy<S>
18where
19    S: Unpin + Sink<ws::Message>,
20{
21    send: SinkWrite<ws::Message, S>,
22}
23
24impl<S> WebsocketProxy<S>
25where
26    S: Unpin + Sink<ws::Message> + 'static,
27{
28    fn error<E>(&mut self, err: E, ctx: &mut <Self as Actor>::Context)
29    where
30        E: Error,
31    {
32        let reason = Some(CloseReason {
33            code: ws::CloseCode::Error,
34            description: Some(err.to_string()),
35        });
36
37        ctx.close(reason.clone());
38        let _ = self.send.write(ws::Message::Close(reason)); // if we can't send an error message, so it goes
39        self.send.close();
40
41        ctx.stop();
42    }
43}
44
45/// start a websocket proxy
46///
47/// `target` should be a URL of the form `ws://<host>` or `wss://<host>`
48/// see awc::Client::ws for more information
49/// req and stream are exactly like the arguments to actix_web_actors::ws::start
50/// ```
51/// # use actix_web::{get, Error, HttpRequest, HttpResponse, web};
52/// #[get("/proxy/{port}")]
53/// async fn proxy(
54///     req: HttpRequest,
55///     stream: web::Payload,
56///     port: web::Path<u16>,
57/// ) -> Result<HttpResponse, Error> {
58///     actix_ws_proxy::start(&req, format!("ws://127.0.0.1:{}", port), stream).await
59/// }
60/// ```
61pub async fn start<T>(
62    req: &HttpRequest,
63    target: String,
64    stream: T,
65) -> Result<HttpResponse, actix_web::Error>
66where
67    T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
68{
69    let mut res = handshake(req)?;
70
71    let (_, conn) = awc::Client::new()
72        .ws(target)
73        .connect()
74        .await
75        .map_err(|e| InternalError::new(e, StatusCode::BAD_GATEWAY))?;
76
77    let (send, recv) = conn.split();
78
79    let out = WebsocketContext::with_factory(stream, |ctx| {
80        ctx.add_stream(recv);
81        WebsocketProxy {
82            send: SinkWrite::new(send, ctx),
83        }
84    });
85
86    Ok(res.streaming(out))
87}
88
89impl<S> WriteHandler<ProtocolError> for WebsocketProxy<S>
90where
91    S: Unpin + 'static + Sink<ws::Message>,
92{
93    fn error(&mut self, err: ProtocolError, ctx: &mut Self::Context) -> actix::Running {
94        self.error(err, ctx);
95        actix::Running::Stop
96    }
97}
98
99impl<S> Actor for WebsocketProxy<S>
100where
101    S: Unpin + 'static + Sink<ws::Message>,
102{
103    type Context = WebsocketContext<Self>;
104}
105
106// This represents messages from upstream, so we send them downstream
107impl<S> StreamHandler<Result<ws::Frame, ProtocolError>> for WebsocketProxy<S>
108where
109    S: Unpin + Sink<ws::Message> + 'static,
110{
111    fn handle(&mut self, item: Result<ws::Frame, ProtocolError>, ctx: &mut Self::Context) {
112        let frame = match item {
113            Ok(frame) => frame,
114            Err(err) => return self.error(err, ctx),
115        };
116        let msg = match frame {
117            ws::Frame::Text(t) => match t.try_into() {
118                Ok(t) => ws::Message::Text(t),
119                Err(e) => {
120                    self.error(e, ctx);
121                    return;
122                }
123            },
124            ws::Frame::Binary(b) => ws::Message::Binary(b),
125            ws::Frame::Continuation(c) => ws::Message::Continuation(c),
126            ws::Frame::Ping(p) => ws::Message::Ping(p),
127            ws::Frame::Pong(p) => ws::Message::Pong(p),
128            ws::Frame::Close(r) => ws::Message::Close(r),
129        };
130
131        ctx.write_raw(msg)
132    }
133}
134
135// This represents messages from downstream, so they are sent upstream
136impl<S> StreamHandler<Result<ws::Message, ProtocolError>> for WebsocketProxy<S>
137where
138    S: Unpin + Sink<ws::Message> + 'static,
139{
140    fn handle(&mut self, item: Result<ws::Message, ProtocolError>, ctx: &mut Self::Context) {
141        let msg = match item {
142            Ok(msg) => msg,
143            Err(err) => return self.error(err, ctx),
144        };
145
146        // if this fails we're probably shutting down
147        let _ = self.send.write(msg);
148    }
149}
150
#[cfg(test)]
mod tests {
    use actix::{Actor, StreamHandler};
    use actix_web::{
        get,
        web::{self, Path},
        App, Error, HttpRequest, HttpResponse, HttpServer,
    };
    use actix_web_actors::ws;
    use futures::{SinkExt, StreamExt};

    /// Address used for every test server and client connection.
    ///
    /// Binding to loopback (rather than `0.0.0.0`) keeps the test servers unreachable from other
    /// hosts. Using one explicit address for both binding and connecting avoids depending on how
    /// `localhost` resolves.
    const LOOPBACK: &str = "127.0.0.1";

    /// Echo actor: replies to text and binary with the same payload, to ping with pong, and to
    /// close with close.
    #[derive(Debug)]
    struct TestActor;

    impl Actor for TestActor {
        type Context = ws::WebsocketContext<Self>;
    }

    impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for TestActor {
        fn handle(
            &mut self,
            item: Result<ws::Message, ws::ProtocolError>,
            ctx: &mut Self::Context,
        ) {
            match item.unwrap() {
                ws::Message::Text(txt) => ctx.text(txt),
                ws::Message::Binary(bin) => ctx.binary(bin),
                ws::Message::Ping(arg) => ctx.pong(&arg),
                ws::Message::Close(reason) => ctx.close(reason),
                _ => (),
            }
        }
    }

    /// Proxies the connection to the echo server listening on `port`.
    #[get("/proxy/{port}")]
    async fn proxy(
        req: HttpRequest,
        stream: web::Payload,
        port: Path<u16>,
    ) -> Result<HttpResponse, Error> {
        crate::start(&req, format!("ws://{}:{}", LOOPBACK, port), stream).await
    }

    /// Serves the echo actor.
    #[get("/")]
    async fn index(req: HttpRequest, stream: web::Payload) -> Result<HttpResponse, Error> {
        ws::start(TestActor, &req, stream)
    }

    /// Builds (but does not spawn) a single-worker server for `$factory` on a free loopback
    /// port; evaluates to `(server, port)`.
    macro_rules! get_server {
        ($factory:expr) => {{
            let port = portpicker::pick_unused_port().expect("No ports free");
            let server = HttpServer::new(|| App::new().service($factory))
                // One worker is plenty for a single test connection.
                .workers(1)
                .bind((LOOPBACK, port))
                .expect("couldn't start server")
                .run();

            (server, port)
        }};
    }

    #[actix::test]
    async fn with_proxy() {
        let (srv, port) = get_server!(index);
        actix::spawn(srv);
        let (proxysrv, proxyport) = get_server!(proxy);
        actix::spawn(proxysrv);

        let client = awc::Client::new();
        let (_resp, mut conn) = client
            .ws(format!("ws://{}:{}/proxy/{}", LOOPBACK, proxyport, port))
            .connect()
            .await
            .unwrap();

        conn.send(ws::Message::Text("echo.into".into()))
            .await
            .unwrap();

        let resp = conn.next().await.unwrap().unwrap();
        assert_eq!(resp, ws::Frame::Text("echo.into".into()));

        let bytes = bytes::Bytes::from_static(&[0x11, 0x22, 0x33, 0x55]);

        conn.send(awc::ws::Message::Binary(bytes.clone()))
            .await
            .unwrap();

        let resp = conn.next().await.unwrap().unwrap();
        assert_eq!(resp, ws::Frame::Binary(bytes.clone()));

        conn.send(ws::Message::Ping(bytes.clone())).await.unwrap();

        let resp = conn.next().await.unwrap().unwrap();
        assert_eq!(resp, ws::Frame::Pong(bytes.clone()));
    }
}