nakago_ws/
controller.rs

1use std::{any::Any, sync::Arc};
2
3use async_trait::async_trait;
4use axum::{
5    extract::{
6        ws::{Message, WebSocket},
7        WebSocketUpgrade,
8    },
9    response::{IntoResponse, Response},
10};
11use biscuit::Empty;
12use derive_new::new;
13use futures::{SinkExt, StreamExt, TryFutureExt};
14use http::StatusCode;
15use mockall::automock;
16use nakago::{provider, Inject, Provider, Tag};
17use nakago_derive::Provider;
18use tokio::sync::mpsc;
19use tokio_stream::wrappers::UnboundedReceiverStream;
20
21use crate::auth::Token;
22
23use super::Connections;
24
25/// A Handler handles Websocket messages
26#[automock]
27#[async_trait]
28pub trait Handler<Session, T = Empty>: Send + Sync + Any
29where
30    Session: Send + Sync + Any,
31    T: Default + Send + Sync + Any,
32{
33    /// Route the given message to the appropriate handler
34    async fn route(&self, conn_id: &str, msg: Message) -> anyhow::Result<()>;
35
36    /// Get the User from the Subject
37    async fn get_session(&self, token: Token<T>) -> anyhow::Result<Session>;
38}
39
40/// WebSocket Controller
41#[derive(Clone, new)]
42pub struct Controller<Session> {
43    connections: Arc<Connections<Session>>,
44    handler: Arc<Box<dyn Handler<Session>>>,
45}
46
47impl<Session: Default + Send + Sync + Clone + Any> Controller<Session> {
48    /// Handle requests for new WebSocket connections
49    pub async fn upgrade(
50        self: Arc<Self>,
51        token: Token,
52        ws: WebSocketUpgrade,
53    ) -> axum::response::Result<impl IntoResponse> {
54        // Retrieve the request Session
55        let session = self.handler.get_session(token).await.map_err(Error)?;
56
57        Ok(ws.on_upgrade(|socket| async move { self.handle(socket, session).await }))
58    }
59
60    /// Handle `WebSocket` connections by setting up a message handler that deserializes them and
61    /// determines how to handle
62    async fn handle(&self, socket: WebSocket, session: Session) {
63        let (mut ws_write, mut ws_read) = socket.split();
64
65        let (tx, rx) = mpsc::unbounded_channel();
66        let mut rx = UnboundedReceiverStream::new(rx);
67
68        tokio::task::spawn(async move {
69            while let Some(message) = rx.next().await {
70                ws_write
71                    .send(message)
72                    .unwrap_or_else(|err| {
73                        eprintln!("websocket send error: {err}");
74                    })
75                    .await;
76            }
77        });
78
79        let conn_id = self.connections.insert(tx, session).await;
80
81        while let Some(result) = ws_read.next().await {
82            let msg = match result {
83                Ok(msg) => msg,
84                Err(err) => {
85                    eprintln!("websocket error(uid={conn_id}): {err}");
86                    break;
87                }
88            };
89
90            if let Err(err) = self.handler.route(&conn_id, msg).await {
91                eprintln!("json error(uid={conn_id}): {err}");
92                break;
93            }
94        }
95
96        eprintln!("good bye user: {}", conn_id);
97
98        self.connections.remove(&conn_id).await;
99    }
100}
101
102struct Error(anyhow::Error);
103
104impl IntoResponse for Error {
105    fn into_response(self) -> Response {
106        (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
107    }
108}
109
110/// Provide a new WebSocket Event Controller
111#[derive(Default, new)]
112pub struct Provide<Session: Any> {
113    connections_tag: Option<&'static Tag<Connections<Session>>>,
114    handler_tag: Option<&'static Tag<Box<dyn Handler<Session>>>>,
115}
116
117impl<Session: Any> Provide<Session> {
118    /// Set a Tag for the Connections instance this Provider requires
119    pub fn with_connections_tag(self, connections_tag: &'static Tag<Connections<Session>>) -> Self {
120        Self {
121            connections_tag: Some(connections_tag),
122            ..self
123        }
124    }
125
126    /// Set a Tag for the Handler instance this Provider requires
127    pub fn with_handler_tag(self, handler_tag: &'static Tag<Box<dyn Handler<Session>>>) -> Self {
128        Self {
129            handler_tag: Some(handler_tag),
130            ..self
131        }
132    }
133}
134
135#[Provider]
136#[async_trait]
137impl<Session> Provider<Controller<Session>> for Provide<Session>
138where
139    Session: Send + Sync + Any,
140{
141    async fn provide(self: Arc<Self>, i: Inject) -> provider::Result<Arc<Controller<Session>>> {
142        let connections = if let Some(tag) = self.connections_tag {
143            i.get_tag(tag).await?
144        } else {
145            i.get::<Connections<Session>>().await?
146        };
147
148        let handler = if let Some(tag) = self.handler_tag {
149            i.get_tag(tag).await?
150        } else {
151            i.get::<Box<dyn Handler<Session>>>().await?
152        };
153
154        Ok(Arc::new(Controller::new(connections, handler)))
155    }
156}