use std::{
    cell::{Cell, RefCell},
    format,
    future::Future,
    rc::Rc,
    sync::atomic::{AtomicU32, Ordering},
    time::Duration,
};

use actix::{prelude::*, Addr};
use anyhow::Error as AnyError;
use fastwebsockets::{
    handshake, CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError,
    WebSocketWrite,
};
use futures_intrusive::sync::{LocalManualResetEvent, LocalMutex};
use hyper::{
    header::{CONNECTION, UPGRADE},
    upgrade::Upgraded,
    Body, Request as HyperRequest,
};
use maxwell_protocol::{self, HandleError, ProtocolMsg, *};
use tokio::{
    io::{split as tokio_split, ReadHalf, WriteHalf},
    net::TcpStream,
    task::spawn as tokio_spawn,
    time::{sleep, timeout},
};

use super::*;
use crate::arbiter_pool::ArbiterPool;
32
33static ID_SEED: AtomicU32 = AtomicU32::new(0);
34
35struct SpawnExecutor;
37impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
38where
39 Fut: Future + Send + 'static,
40 Fut::Output: Send + 'static,
41{
42 fn execute(&self, fut: Fut) {
43 tokio_spawn(fut);
44 }
45}
46
47type Sink = WebSocketWrite<WriteHalf<Upgraded>>;
48type Stream = FragmentCollectorRead<ReadHalf<Upgraded>>;
49
50pub trait EventHandler: Send + Sync + Unpin + Sized + 'static {
51 #[inline(always)]
52 fn on_msg(&self, _msg: ProtocolMsg) {}
53 #[inline(always)]
54 fn on_connected(&self, _addr: Addr<CallbackStyleConnection<Self>>) {}
55 #[inline(always)]
56 fn on_disconnected(&self, _addr: Addr<CallbackStyleConnection<Self>>) {}
57 #[inline(always)]
58 fn on_stopped(&self, _addr: Addr<CallbackStyleConnection<Self>>) {}
59}
60
61struct CallbackStyleConnectionInner<EH: EventHandler> {
62 id: u32,
63 addr: RefCell<Option<Addr<CallbackStyleConnection<EH>>>>,
64 endpoint: String,
65 options: ConnectionOptions,
66 sink: RefCell<Option<Sink>>,
67 stream: RefCell<Option<Stream>>,
68 connected_event: LocalManualResetEvent,
69 disconnected_event: LocalManualResetEvent,
70 is_connected: Cell<bool>,
71 msg_ref: Cell<u32>,
72 event_handler: EH,
73 is_stopping: Cell<bool>,
74}
75
76impl<EH: EventHandler> CallbackStyleConnectionInner<EH> {
77 #[inline]
78 pub fn new(endpoint: String, options: ConnectionOptions, event_handler: EH) -> Self {
79 CallbackStyleConnectionInner {
80 id: ID_SEED.fetch_add(1, Ordering::Relaxed),
81 addr: RefCell::new(None),
82 endpoint,
83 options,
84 sink: RefCell::new(None),
85 stream: RefCell::new(None),
86 connected_event: LocalManualResetEvent::new(false),
87 disconnected_event: LocalManualResetEvent::new(true),
88 is_connected: Cell::new(false),
89 msg_ref: Cell::new(1),
90 event_handler,
91 is_stopping: Cell::new(false),
92 }
93 }
94
95 pub async fn connect_repeatedly(self: Rc<Self>) {
96 loop {
97 if self.is_stopping() {
98 break;
99 }
100
101 self.disconnected_event.wait().await;
102
103 self.close_sink().await.unwrap_or_else(|err| {
104 log::error!("Failed to close sink: actor: {}<{}>, err: {}", &self.endpoint, &self.id, err);
105 });
106
107 log::info!("Connecting: actor: {}<{}>", &self.endpoint, &self.id);
108 match self.connect().await {
109 Ok((sink, stream)) => {
110 log::info!("Connected: actor: {}<{}>", &self.endpoint, &self.id);
111 self.set_socket_pair(Some(sink), Some(stream));
112 self.toggle_to_connected();
113 }
114 Err(err) => {
115 log::error!("Failed to connect: actor: {}<{}>, err: {}", &self.endpoint, &self.id, err);
116 self.set_socket_pair(None, None);
117 self.toggle_to_disconnected();
118 sleep(Duration::from_millis(self.options.reconnect_delay as u64)).await;
119 }
120 }
121 }
122 }
123
124 #[inline]
125 pub async fn send(
126 self: Rc<Self>, mut msg: ProtocolMsg,
127 ) -> Result<ProtocolMsg, HandleError<ProtocolMsg>> {
128 let mut msg_ref = maxwell_protocol::get_ref(&msg);
129 if msg_ref == 0 {
130 msg_ref = self.next_msg_ref();
131 maxwell_protocol::set_ref(&mut msg, msg_ref);
132 } else {
133 self.try_set_msg_ref(msg_ref);
134 }
135
136 if !self.is_connected() {
137 for i in 0..3 {
138 if let Err(_) =
139 timeout(Duration::from_millis(i * 500 + 500), self.connected_event.wait()).await
140 {
141 continue;
142 } else {
143 break;
144 }
145 }
146 if !self.is_connected() {
147 let desc = format!("Timeout to send msg: actor: {}<{}>", &self.endpoint, &self.id);
148 log::error!("{:?}", desc);
149 return Err(HandleError::Any { code: 1, desc, msg });
150 }
151 }
152
153 if let Err(err) = self
154 .sink
155 .borrow_mut()
156 .as_mut()
157 .unwrap()
158 .write_frame(Frame::binary(encode(&msg).as_ref().into()))
159 .await
160 {
161 let desc =
162 format!("Failed to send msg: actor: {}<{}>, err: {}", &self.endpoint, &self.id, &err);
163 log::error!("{:?}", desc);
164 log::warn!(
165 "The connection maybe broken, try to reconnect: actor: {}<{}>",
166 &self.endpoint,
167 &self.id
168 );
169 self.toggle_to_disconnected();
170 return Err(HandleError::Any { code: 2, desc, msg });
171 }
172
173 Ok(ProtocolMsg::None)
174 }
175
176 pub async fn receive_repeatedly(self: Rc<Self>) {
177 loop {
178 if self.is_stopping() {
179 break;
180 }
181
182 if !self.is_connected() {
183 self.connected_event.wait().await;
184 }
185
186 match self
187 .stream
188 .borrow_mut()
189 .as_mut()
190 .unwrap()
191 .read_frame(&mut move |_| async { Ok::<_, WebSocketError>(()) })
193 .await
194 {
195 Ok(frame) => match frame.opcode {
196 OpCode::Ping => {}
197 OpCode::Pong => {}
198 OpCode::Binary => {
199 self.event_handler.on_msg(decode_bytes(&frame.payload).unwrap());
200 }
201 OpCode::Close => {
202 log::error!(
203 "Disconnected: actor: {}<{}>, reason: {}",
204 &self.endpoint,
205 &self.id,
206 Self::stringify(&frame.payload)
207 );
208 self.toggle_to_disconnected();
209 }
210 other => {
211 log::warn!("Received unknown msg: {:?}({:?})", &frame.payload, &other);
212 }
213 },
214 Err(err) => {
215 log::error!("Protocol error occured: err: {}", &err);
216 self.toggle_to_disconnected();
217 }
218 }
219 }
220 }
221
222 #[inline]
223 pub fn stop(&self) {
224 self.is_stopping.set(true);
225 self.notify_stopped_event();
226 }
227
228 #[inline]
229 fn notify_connected_event(&self) {
230 let addr = self.addr.borrow();
231 self.event_handler.on_connected(addr.as_ref().unwrap().clone());
232 }
233
234 #[inline]
235 fn notify_disconnected_event(&self) {
236 let addr = self.addr.borrow();
237 self.event_handler.on_disconnected(addr.as_ref().unwrap().clone());
238 }
239
240 #[inline]
241 fn notify_stopped_event(&self) {
242 let addr = self.addr.borrow();
243 self.event_handler.on_stopped(addr.as_ref().unwrap().clone());
244 }
245
246 #[inline]
247 pub fn next_msg_ref(&self) -> u32 {
248 let prev_msg_ref = self.msg_ref.get();
249 if prev_msg_ref < MAX_MSG_REF {
250 let curr_msg_ref = prev_msg_ref + 1;
251 self.msg_ref.set(curr_msg_ref);
252 curr_msg_ref
253 } else {
254 self.msg_ref.set(1);
255 1
256 }
257 }
258
259 #[inline]
260 fn try_set_msg_ref(&self, msg_ref: u32) {
261 if msg_ref > self.msg_ref.get() {
262 self.msg_ref.set(msg_ref);
263 }
264 }
265
266 #[inline]
267 fn stringify(payload: &Payload) -> String {
268 match Self::decode_error_payload(payload) {
269 Ok(code) => format!("{:?}({})", code, u16::from(code)),
270 Err(err) => format!("{:?}", err),
271 }
272 }
273
274 #[inline]
275 fn decode_error_payload(payload: &Payload) -> Result<CloseCode, WebSocketError> {
276 match payload.len() {
277 0 => Ok(CloseCode::Normal),
278 1 => return Err(WebSocketError::InvalidCloseFrame),
279 _ => {
280 let code = CloseCode::from(u16::from_be_bytes(payload[0..2].try_into().unwrap()));
281 if !code.is_allowed() {
282 return Err(WebSocketError::InvalidCloseCode);
283 }
284 Ok(code)
285 }
286 }
287 }
288
289 #[inline]
290 async fn connect(&self) -> Result<(Sink, Stream), AnyError> {
291 let stream = TcpStream::connect(&self.endpoint).await?;
292 let req = HyperRequest::builder()
293 .method("GET")
294 .uri("/$ws")
295 .header("Host", &self.endpoint)
296 .header(UPGRADE, "websocket")
297 .header(CONNECTION, "upgrade")
298 .header("CLIENT-ID", &format!("{}", self.id))
299 .header("Sec-WebSocket-Key", handshake::generate_key())
300 .header("Sec-WebSocket-Version", "13")
301 .body(Body::empty())?;
302
303 let (mut ws, _) = handshake::client(&SpawnExecutor, req, stream).await?;
304 ws.set_auto_close(false);
305 ws.set_auto_pong(false);
306 ws.set_max_message_size(self.options.max_frame_size as usize);
307 let (stream, sink) = ws.split(|s| tokio_split(s));
308 Ok((sink, FragmentCollectorRead::new(stream)))
309 }
310
311 #[inline]
312 fn set_socket_pair(&self, sink: Option<Sink>, stream: Option<Stream>) {
313 *self.sink.borrow_mut() = sink;
314 *self.stream.borrow_mut() = stream;
315 }
316
317 #[inline]
318 fn toggle_to_connected(&self) {
319 self.is_connected.set(true);
320 self.connected_event.set();
321 self.disconnected_event.reset();
322 self.notify_connected_event();
323 }
324
325 #[inline]
326 fn toggle_to_disconnected(&self) {
327 self.is_connected.set(false);
328 self.connected_event.reset();
329 self.disconnected_event.set();
330 self.notify_disconnected_event();
331 }
332
333 #[inline]
334 async fn close_sink(&self) -> Result<(), AnyError> {
335 if let Some(sink) = self.sink.try_borrow_mut()?.as_mut() {
336 Ok(sink.write_frame(Frame::close_raw(vec![].into())).await?)
337 } else {
338 Ok(())
339 }
340 }
341
342 #[inline]
343 fn is_connected(&self) -> bool {
344 self.is_connected.get()
345 }
346
347 #[inline]
348 fn is_stopping(&self) -> bool {
349 self.is_stopping.get()
350 }
351}
352
353pub struct CallbackStyleConnection<EH: EventHandler> {
354 inner: Rc<CallbackStyleConnectionInner<EH>>,
355}
356impl<EH: EventHandler> Connection for CallbackStyleConnection<EH> {}
357
358impl<EH: EventHandler> Actor for CallbackStyleConnection<EH> {
359 type Context = Context<Self>;
360
361 #[inline]
362 fn started(&mut self, ctx: &mut Self::Context) {
363 log::info!("Started: actor: {}<{}>", &self.inner.endpoint, &self.inner.id);
364 *self.inner.addr.borrow_mut() = Some(ctx.address());
365
366 ctx.set_mailbox_capacity(self.inner.options.mailbox_capacity as usize);
367
368 Box::pin(Rc::clone(&self.inner).connect_repeatedly().into_actor(self)).spawn(ctx);
369 Box::pin(Rc::clone(&self.inner).receive_repeatedly().into_actor(self)).spawn(ctx);
370 }
371
372 #[inline]
373 fn stopping(&mut self, _: &mut Self::Context) -> Running {
374 log::info!("Stopping: actor: {}<{}>", &self.inner.endpoint, &self.inner.id);
375 self.inner.stop();
376 Running::Stop
377 }
378
379 #[inline]
380 fn stopped(&mut self, _: &mut Self::Context) {
381 log::info!("Stopped: actor: {}<{}>", &self.inner.endpoint, &self.inner.id);
382 }
383}
384
385impl<EH: EventHandler> Handler<ProtocolMsg> for CallbackStyleConnection<EH> {
386 type Result = ResponseFuture<Result<ProtocolMsg, HandleError<ProtocolMsg>>>;
387
388 #[inline]
389 fn handle(&mut self, msg: ProtocolMsg, _ctx: &mut Context<Self>) -> Self::Result {
390 Box::pin(Rc::clone(&self.inner).send(msg))
391 }
392}
393
394impl<EH: EventHandler> Handler<StopMsg> for CallbackStyleConnection<EH> {
395 type Result = ();
396
397 #[inline]
398 fn handle(&mut self, _msg: StopMsg, ctx: &mut Context<Self>) -> Self::Result {
399 log::info!("Received StopMsg: actor: {}<{}>", &self.inner.endpoint, &self.inner.id);
400 ctx.stop();
401 }
402}
403
404impl<EH: EventHandler> Handler<DumpInfoMsg> for CallbackStyleConnection<EH> {
405 type Result = ();
406
407 #[inline]
408 fn handle(&mut self, _msg: DumpInfoMsg, _ctx: &mut Context<Self>) -> Self::Result {
409 log::info!("Connection info: id: {:?}, endpoint: {:?}", self.inner.id, self.inner.endpoint);
410 }
411}
412
413impl<EH: EventHandler> CallbackStyleConnection<EH> {
414 #[inline]
415 pub fn new(endpoint: String, options: ConnectionOptions, event_handler: EH) -> Self {
416 CallbackStyleConnection {
417 inner: Rc::new(CallbackStyleConnectionInner::new(endpoint, options, event_handler)),
418 }
419 }
420
421 #[inline]
422 pub fn start3(endpoint: String, options: ConnectionOptions, event_handler: EH) -> Addr<Self> {
423 CallbackStyleConnection::start_in_arbiter(
424 &ArbiterPool::singleton().fetch_arbiter(),
425 move |_ctx| CallbackStyleConnection::new(endpoint, options, event_handler),
426 )
427 }
428
429 #[inline]
430 pub fn start4(
431 endpoint: String, options: ConnectionOptions, arbiter: ArbiterHandle, event_handler: EH,
432 ) -> Addr<Self> {
433 CallbackStyleConnection::start_in_arbiter(&arbiter, move |_ctx| {
434 CallbackStyleConnection::new(endpoint, options, event_handler)
435 })
436 }
437
438 #[inline]
439 pub async fn stop(addr: Addr<Self>) -> Result<(), HandleError<StopMsg>> {
440 match addr.send(StopMsg).await {
441 Ok(ok) => Ok(ok),
442 Err(err) => match err {
443 MailboxError::Closed => Err(HandleError::MailboxClosed),
444 MailboxError::Timeout => Err(HandleError::Timeout),
445 },
446 }
447 }
448
449 pub fn dump_info(addr: Addr<Self>) {
450 addr.do_send(DumpInfoMsg);
451 }
452}
453
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use actix::prelude::*;
    use futures_util::future::join_all;
    use maxwell_protocol::IntoEnum;

    use crate::connection::*;

    /// Test handler that prints every decoded inbound message.
    struct PrintingEventHandler;

    impl super::EventHandler for PrintingEventHandler {
        fn on_msg(&self, msg: maxwell_protocol::ProtocolMsg) {
            println!("Received msg: {:?}", msg);
        }
    }

    /// Creates and starts a connection to the local server these tests expect
    /// to be listening on `localhost:8081`.
    fn start_local_connection() -> Addr<CallbackStyleConnection<PrintingEventHandler>> {
        CallbackStyleConnection::<PrintingEventHandler>::new(
            String::from("localhost:8081"),
            ConnectionOptions::default(),
            PrintingEventHandler,
        )
        .start()
    }

    #[actix::test]
    async fn test_send_msg() {
        let conn = start_local_connection();
        // NOTE(review): `1..2` yields a single iteration.
        for _ in 1..2 {
            let ping = maxwell_protocol::PingReq { r#ref: 0 }.into_enum();
            let res = conn.send(ping).timeout_ext(Duration::from_millis(3000)).await;
            println!("received result: {:?}", res);
        }
    }

    #[actix::test]
    async fn test_concurrent() {
        let conn = start_local_connection();

        let mut tasks = Vec::with_capacity(16);
        for thread_id in 0..16_u8 {
            let conn = conn.clone();
            tasks.push(tokio::spawn(async move {
                println!("Thread {} started.", thread_id);
                for _ in 0..10000 {
                    let req = maxwell_protocol::PingReq { r#ref: 0 }.into_enum();
                    let res = conn.send(req).timeout_ext(Duration::from_millis(3000)).await;
                    println!("received result: res: {:?}, thread_id {:?}", res, thread_id);
                }
            }));
        }

        join_all(tasks).await;
    }
}