maxwell_utils/connection/callback_style_connection.rs

1use std::{
2  cell::{Cell, RefCell},
3  format,
4  future::Future,
5  rc::Rc,
6  sync::atomic::{AtomicU32, Ordering},
7  time::Duration,
8};
9
10use actix::{prelude::*, Addr};
11use anyhow::Error as AnyError;
12use fastwebsockets::{
13  handshake, CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError,
14  WebSocketWrite,
15};
16use futures_intrusive::sync::LocalManualResetEvent;
17use hyper::{
18  header::{CONNECTION, UPGRADE},
19  upgrade::Upgraded,
20  Body, Request as HyperRequest,
21};
22use maxwell_protocol::{self, HandleError, ProtocolMsg, *};
23use tokio::{
24  io::{split as tokio_split, ReadHalf, WriteHalf},
25  net::TcpStream,
26  task::spawn as tokio_spawn,
27  time::{sleep, timeout},
28};
29
30use super::*;
31use crate::arbiter_pool::ArbiterPool;
32
33static ID_SEED: AtomicU32 = AtomicU32::new(0);
34
35// Tie hyper's executor to tokio runtime
36struct SpawnExecutor;
37impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
38where
39  Fut: Future + Send + 'static,
40  Fut::Output: Send + 'static,
41{
42  fn execute(&self, fut: Fut) {
43    tokio_spawn(fut);
44  }
45}
46
47type Sink = WebSocketWrite<WriteHalf<Upgraded>>;
48type Stream = FragmentCollectorRead<ReadHalf<Upgraded>>;
49
50pub trait EventHandler: Send + Sync + Unpin + Sized + 'static {
51  #[inline(always)]
52  fn on_msg(&self, _msg: ProtocolMsg) {}
53  #[inline(always)]
54  fn on_connected(&self, _addr: Addr<CallbackStyleConnection<Self>>) {}
55  #[inline(always)]
56  fn on_disconnected(&self, _addr: Addr<CallbackStyleConnection<Self>>) {}
57  #[inline(always)]
58  fn on_stopped(&self, _addr: Addr<CallbackStyleConnection<Self>>) {}
59}
60
61struct CallbackStyleConnectionInner<EH: EventHandler> {
62  id: u32,
63  addr: RefCell<Option<Addr<CallbackStyleConnection<EH>>>>,
64  endpoint: String,
65  options: ConnectionOptions,
66  sink: RefCell<Option<Sink>>,
67  stream: RefCell<Option<Stream>>,
68  connected_event: LocalManualResetEvent,
69  disconnected_event: LocalManualResetEvent,
70  is_connected: Cell<bool>,
71  msg_ref: Cell<u32>,
72  event_handler: EH,
73  is_stopping: Cell<bool>,
74}
75
76impl<EH: EventHandler> CallbackStyleConnectionInner<EH> {
77  #[inline]
78  pub fn new(endpoint: String, options: ConnectionOptions, event_handler: EH) -> Self {
79    CallbackStyleConnectionInner {
80      id: ID_SEED.fetch_add(1, Ordering::Relaxed),
81      addr: RefCell::new(None),
82      endpoint,
83      options,
84      sink: RefCell::new(None),
85      stream: RefCell::new(None),
86      connected_event: LocalManualResetEvent::new(false),
87      disconnected_event: LocalManualResetEvent::new(true),
88      is_connected: Cell::new(false),
89      msg_ref: Cell::new(1),
90      event_handler,
91      is_stopping: Cell::new(false),
92    }
93  }
94
95  pub async fn connect_repeatedly(self: Rc<Self>) {
96    loop {
97      if self.is_stopping() {
98        break;
99      }
100
101      self.disconnected_event.wait().await;
102
103      self.close_sink().await.unwrap_or_else(|err| {
104        log::error!("Failed to close sink: actor: {}<{}>, err: {}", &self.endpoint, &self.id, err);
105      });
106
107      log::info!("Connecting: actor: {}<{}>", &self.endpoint, &self.id);
108      match self.connect().await {
109        Ok((sink, stream)) => {
110          log::info!("Connected: actor: {}<{}>", &self.endpoint, &self.id);
111          self.set_socket_pair(Some(sink), Some(stream));
112          self.toggle_to_connected();
113        }
114        Err(err) => {
115          log::error!("Failed to connect: actor: {}<{}>, err: {}", &self.endpoint, &self.id, err);
116          self.set_socket_pair(None, None);
117          self.toggle_to_disconnected();
118          sleep(Duration::from_millis(self.options.reconnect_delay as u64)).await;
119        }
120      }
121    }
122  }
123
124  #[inline]
125  pub async fn send(
126    self: Rc<Self>, mut msg: ProtocolMsg,
127  ) -> Result<ProtocolMsg, HandleError<ProtocolMsg>> {
128    let mut msg_ref = maxwell_protocol::get_ref(&msg);
129    if msg_ref == 0 {
130      msg_ref = self.next_msg_ref();
131      maxwell_protocol::set_ref(&mut msg, msg_ref);
132    } else {
133      self.try_set_msg_ref(msg_ref);
134    }
135
136    if !self.is_connected() {
137      for i in 0..3 {
138        if let Err(_) =
139          timeout(Duration::from_millis(i * 500 + 500), self.connected_event.wait()).await
140        {
141          continue;
142        } else {
143          break;
144        }
145      }
146      if !self.is_connected() {
147        let desc = format!("Timeout to send msg: actor: {}<{}>", &self.endpoint, &self.id);
148        log::error!("{:?}", desc);
149        return Err(HandleError::Any { code: 1, desc, msg });
150      }
151    }
152
153    if let Err(err) = self
154      .sink
155      .borrow_mut()
156      .as_mut()
157      .unwrap()
158      .write_frame(Frame::binary(encode(&msg).as_ref().into()))
159      .await
160    {
161      let desc =
162        format!("Failed to send msg: actor: {}<{}>, err: {}", &self.endpoint, &self.id, &err);
163      log::error!("{:?}", desc);
164      log::warn!(
165        "The connection maybe broken, try to reconnect: actor: {}<{}>",
166        &self.endpoint,
167        &self.id
168      );
169      self.toggle_to_disconnected();
170      return Err(HandleError::Any { code: 2, desc, msg });
171    }
172
173    Ok(ProtocolMsg::None)
174  }
175
176  pub async fn receive_repeatedly(self: Rc<Self>) {
177    loop {
178      if self.is_stopping() {
179        break;
180      }
181
182      if !self.is_connected() {
183        self.connected_event.wait().await;
184      }
185
186      match self
187        .stream
188        .borrow_mut()
189        .as_mut()
190        .unwrap()
191        // send_fn is empty because we do not create obligated writes here.
192        .read_frame(&mut move |_| async { Ok::<_, WebSocketError>(()) })
193        .await
194      {
195        Ok(frame) => match frame.opcode {
196          OpCode::Ping => {}
197          OpCode::Pong => {}
198          OpCode::Binary => {
199            self.event_handler.on_msg(decode_bytes(&frame.payload).unwrap());
200          }
201          OpCode::Close => {
202            log::error!(
203              "Disconnected: actor: {}<{}>, reason: {}",
204              &self.endpoint,
205              &self.id,
206              Self::stringify(&frame.payload)
207            );
208            self.toggle_to_disconnected();
209          }
210          other => {
211            log::warn!("Received unknown msg: {:?}({:?})", &frame.payload, &other);
212          }
213        },
214        Err(err) => {
215          log::error!("Protocol error occured: err: {}", &err);
216          self.toggle_to_disconnected();
217        }
218      }
219    }
220  }
221
222  #[inline]
223  pub fn stop(&self) {
224    self.is_stopping.set(true);
225    self.notify_stopped_event();
226  }
227
228  #[inline]
229  fn notify_connected_event(&self) {
230    let addr = self.addr.borrow();
231    self.event_handler.on_connected(addr.as_ref().unwrap().clone());
232  }
233
234  #[inline]
235  fn notify_disconnected_event(&self) {
236    let addr = self.addr.borrow();
237    self.event_handler.on_disconnected(addr.as_ref().unwrap().clone());
238  }
239
240  #[inline]
241  fn notify_stopped_event(&self) {
242    let addr = self.addr.borrow();
243    self.event_handler.on_stopped(addr.as_ref().unwrap().clone());
244  }
245
246  #[inline]
247  pub fn next_msg_ref(&self) -> u32 {
248    let prev_msg_ref = self.msg_ref.get();
249    if prev_msg_ref < MAX_MSG_REF {
250      let curr_msg_ref = prev_msg_ref + 1;
251      self.msg_ref.set(curr_msg_ref);
252      curr_msg_ref
253    } else {
254      self.msg_ref.set(1);
255      1
256    }
257  }
258
259  #[inline]
260  fn try_set_msg_ref(&self, msg_ref: u32) {
261    if msg_ref > self.msg_ref.get() {
262      self.msg_ref.set(msg_ref);
263    }
264  }
265
266  #[inline]
267  fn stringify(payload: &Payload) -> String {
268    match Self::decode_error_payload(payload) {
269      Ok(code) => format!("{:?}({})", code, u16::from(code)),
270      Err(err) => format!("{:?}", err),
271    }
272  }
273
274  #[inline]
275  fn decode_error_payload(payload: &Payload) -> Result<CloseCode, WebSocketError> {
276    match payload.len() {
277      0 => Ok(CloseCode::Normal),
278      1 => return Err(WebSocketError::InvalidCloseFrame),
279      _ => {
280        let code = CloseCode::from(u16::from_be_bytes(payload[0..2].try_into().unwrap()));
281        if !code.is_allowed() {
282          return Err(WebSocketError::InvalidCloseCode);
283        }
284        Ok(code)
285      }
286    }
287  }
288
289  #[inline]
290  async fn connect(&self) -> Result<(Sink, Stream), AnyError> {
291    let stream = TcpStream::connect(&self.endpoint).await?;
292    let req = HyperRequest::builder()
293      .method("GET")
294      .uri("/$ws")
295      .header("Host", &self.endpoint)
296      .header(UPGRADE, "websocket")
297      .header(CONNECTION, "upgrade")
298      .header("CLIENT-ID", &format!("{}", self.id))
299      .header("Sec-WebSocket-Key", handshake::generate_key())
300      .header("Sec-WebSocket-Version", "13")
301      .body(Body::empty())?;
302
303    let (mut ws, _) = handshake::client(&SpawnExecutor, req, stream).await?;
304    ws.set_auto_close(false);
305    ws.set_auto_pong(false);
306    ws.set_max_message_size(self.options.max_frame_size as usize);
307    let (stream, sink) = ws.split(|s| tokio_split(s));
308    Ok((sink, FragmentCollectorRead::new(stream)))
309  }
310
311  #[inline]
312  fn set_socket_pair(&self, sink: Option<Sink>, stream: Option<Stream>) {
313    *self.sink.borrow_mut() = sink;
314    *self.stream.borrow_mut() = stream;
315  }
316
317  #[inline]
318  fn toggle_to_connected(&self) {
319    self.is_connected.set(true);
320    self.connected_event.set();
321    self.disconnected_event.reset();
322    self.notify_connected_event();
323  }
324
325  #[inline]
326  fn toggle_to_disconnected(&self) {
327    self.is_connected.set(false);
328    self.connected_event.reset();
329    self.disconnected_event.set();
330    self.notify_disconnected_event();
331  }
332
333  #[inline]
334  async fn close_sink(&self) -> Result<(), AnyError> {
335    if let Some(sink) = self.sink.try_borrow_mut()?.as_mut() {
336      Ok(sink.write_frame(Frame::close_raw(vec![].into())).await?)
337    } else {
338      Ok(())
339    }
340  }
341
342  #[inline]
343  fn is_connected(&self) -> bool {
344    self.is_connected.get()
345  }
346
347  #[inline]
348  fn is_stopping(&self) -> bool {
349    self.is_stopping.get()
350  }
351}
352
353pub struct CallbackStyleConnection<EH: EventHandler> {
354  inner: Rc<CallbackStyleConnectionInner<EH>>,
355}
356impl<EH: EventHandler> Connection for CallbackStyleConnection<EH> {}
357
358impl<EH: EventHandler> Actor for CallbackStyleConnection<EH> {
359  type Context = Context<Self>;
360
361  #[inline]
362  fn started(&mut self, ctx: &mut Self::Context) {
363    log::info!("Started: actor: {}<{}>", &self.inner.endpoint, &self.inner.id);
364    *self.inner.addr.borrow_mut() = Some(ctx.address());
365
366    ctx.set_mailbox_capacity(self.inner.options.mailbox_capacity as usize);
367
368    Box::pin(Rc::clone(&self.inner).connect_repeatedly().into_actor(self)).spawn(ctx);
369    Box::pin(Rc::clone(&self.inner).receive_repeatedly().into_actor(self)).spawn(ctx);
370  }
371
372  #[inline]
373  fn stopping(&mut self, _: &mut Self::Context) -> Running {
374    log::info!("Stopping: actor: {}<{}>", &self.inner.endpoint, &self.inner.id);
375    self.inner.stop();
376    Running::Stop
377  }
378
379  #[inline]
380  fn stopped(&mut self, _: &mut Self::Context) {
381    log::info!("Stopped: actor: {}<{}>", &self.inner.endpoint, &self.inner.id);
382  }
383}
384
385impl<EH: EventHandler> Handler<ProtocolMsg> for CallbackStyleConnection<EH> {
386  type Result = ResponseFuture<Result<ProtocolMsg, HandleError<ProtocolMsg>>>;
387
388  #[inline]
389  fn handle(&mut self, msg: ProtocolMsg, _ctx: &mut Context<Self>) -> Self::Result {
390    Box::pin(Rc::clone(&self.inner).send(msg))
391  }
392}
393
394impl<EH: EventHandler> Handler<StopMsg> for CallbackStyleConnection<EH> {
395  type Result = ();
396
397  #[inline]
398  fn handle(&mut self, _msg: StopMsg, ctx: &mut Context<Self>) -> Self::Result {
399    log::info!("Received StopMsg: actor: {}<{}>", &self.inner.endpoint, &self.inner.id);
400    ctx.stop();
401  }
402}
403
404impl<EH: EventHandler> Handler<DumpInfoMsg> for CallbackStyleConnection<EH> {
405  type Result = ();
406
407  #[inline]
408  fn handle(&mut self, _msg: DumpInfoMsg, _ctx: &mut Context<Self>) -> Self::Result {
409    log::info!("Connection info: id: {:?}, endpoint: {:?}", self.inner.id, self.inner.endpoint);
410  }
411}
412
413impl<EH: EventHandler> CallbackStyleConnection<EH> {
414  #[inline]
415  pub fn new(endpoint: String, options: ConnectionOptions, event_handler: EH) -> Self {
416    CallbackStyleConnection {
417      inner: Rc::new(CallbackStyleConnectionInner::new(endpoint, options, event_handler)),
418    }
419  }
420
421  #[inline]
422  pub fn start3(endpoint: String, options: ConnectionOptions, event_handler: EH) -> Addr<Self> {
423    CallbackStyleConnection::start_in_arbiter(
424      &ArbiterPool::singleton().fetch_arbiter(),
425      move |_ctx| CallbackStyleConnection::new(endpoint, options, event_handler),
426    )
427  }
428
429  #[inline]
430  pub fn start4(
431    endpoint: String, options: ConnectionOptions, arbiter: ArbiterHandle, event_handler: EH,
432  ) -> Addr<Self> {
433    CallbackStyleConnection::start_in_arbiter(&arbiter, move |_ctx| {
434      CallbackStyleConnection::new(endpoint, options, event_handler)
435    })
436  }
437
438  #[inline]
439  pub async fn stop(addr: Addr<Self>) -> Result<(), HandleError<StopMsg>> {
440    match addr.send(StopMsg).await {
441      Ok(ok) => Ok(ok),
442      Err(err) => match err {
443        MailboxError::Closed => Err(HandleError::MailboxClosed),
444        MailboxError::Timeout => Err(HandleError::Timeout),
445      },
446    }
447  }
448
449  pub fn dump_info(addr: Addr<Self>) {
450    addr.do_send(DumpInfoMsg);
451  }
452}
453
454////////////////////////////////////////////////////////////////////////////////
455/// test cases
456////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
  use std::time::Duration;

  use actix::prelude::*;
  use futures_util::future::join_all;
  use maxwell_protocol::IntoEnum;

  use crate::connection::*;

  /// Test handler that prints every message delivered to it.
  struct PrintingHandler;

  impl super::EventHandler for PrintingHandler {
    fn on_msg(&self, msg: maxwell_protocol::ProtocolMsg) {
      println!("Received msg: {:?}", msg);
    }
  }

  /// Starts a connection actor against the local test endpoint.
  ///
  /// NOTE(review): this expects a server on `localhost:8081`. The tests below
  /// only print results and assert nothing, so a missing server shows up only
  /// in the printed output.
  fn start_local_connection() -> Addr<CallbackStyleConnection<PrintingHandler>> {
    CallbackStyleConnection::<PrintingHandler>::new(
      String::from("localhost:8081"),
      ConnectionOptions::default(),
      PrintingHandler,
    )
    .start()
  }

  #[actix::test]
  async fn test_send_msg() {
    let conn = start_local_connection();
    for _ in 0..1 {
      let ping = maxwell_protocol::PingReq { r#ref: 0 }.into_enum();
      let res = conn.send(ping).timeout_ext(Duration::from_millis(3000)).await;
      println!("received result: {:?}", res);
    }
  }

  #[actix::test]
  async fn test_concurrent() {
    let conn = start_local_connection();

    // Spawn 16 tokio tasks, each sending 10000 pings over the shared connection.
    let tasks: Vec<_> = (0..16_u8)
      .map(|task_id| {
        let conn = conn.clone();
        tokio::spawn(async move {
          println!("Thread {} started.", task_id);
          for _ in 0..10000 {
            let ping = maxwell_protocol::PingReq { r#ref: 0 }.into_enum();
            let res = conn.send(ping).timeout_ext(Duration::from_millis(3000)).await;
            println!("received result: res: {:?}, thread_id {:?}", res, task_id);
          }
        })
      })
      .collect();

    join_all(tasks).await;
  }
}