1use std::{any::Any, sync::Arc};
2
3use async_trait::async_trait;
4use axum::{
5 extract::{
6 ws::{Message, WebSocket},
7 WebSocketUpgrade,
8 },
9 response::{IntoResponse, Response},
10};
11use biscuit::Empty;
12use derive_new::new;
13use futures::{SinkExt, StreamExt, TryFutureExt};
14use http::StatusCode;
15use mockall::automock;
16use nakago::{provider, Inject, Provider, Tag};
17use nakago_derive::Provider;
18use tokio::sync::mpsc;
19use tokio_stream::wrappers::UnboundedReceiverStream;
20
21use crate::auth::Token;
22
23use super::Connections;
24
25#[automock]
27#[async_trait]
28pub trait Handler<Session, T = Empty>: Send + Sync + Any
29where
30 Session: Send + Sync + Any,
31 T: Default + Send + Sync + Any,
32{
33 async fn route(&self, conn_id: &str, msg: Message) -> anyhow::Result<()>;
35
36 async fn get_session(&self, token: Token<T>) -> anyhow::Result<Session>;
38}
39
40#[derive(Clone, new)]
42pub struct Controller<Session> {
43 connections: Arc<Connections<Session>>,
44 handler: Arc<Box<dyn Handler<Session>>>,
45}
46
47impl<Session: Default + Send + Sync + Clone + Any> Controller<Session> {
48 pub async fn upgrade(
50 self: Arc<Self>,
51 token: Token,
52 ws: WebSocketUpgrade,
53 ) -> axum::response::Result<impl IntoResponse> {
54 let session = self.handler.get_session(token).await.map_err(Error)?;
56
57 Ok(ws.on_upgrade(|socket| async move { self.handle(socket, session).await }))
58 }
59
60 async fn handle(&self, socket: WebSocket, session: Session) {
63 let (mut ws_write, mut ws_read) = socket.split();
64
65 let (tx, rx) = mpsc::unbounded_channel();
66 let mut rx = UnboundedReceiverStream::new(rx);
67
68 tokio::task::spawn(async move {
69 while let Some(message) = rx.next().await {
70 ws_write
71 .send(message)
72 .unwrap_or_else(|err| {
73 eprintln!("websocket send error: {err}");
74 })
75 .await;
76 }
77 });
78
79 let conn_id = self.connections.insert(tx, session).await;
80
81 while let Some(result) = ws_read.next().await {
82 let msg = match result {
83 Ok(msg) => msg,
84 Err(err) => {
85 eprintln!("websocket error(uid={conn_id}): {err}");
86 break;
87 }
88 };
89
90 if let Err(err) = self.handler.route(&conn_id, msg).await {
91 eprintln!("json error(uid={conn_id}): {err}");
92 break;
93 }
94 }
95
96 eprintln!("good bye user: {}", conn_id);
97
98 self.connections.remove(&conn_id).await;
99 }
100}
101
102struct Error(anyhow::Error);
103
104impl IntoResponse for Error {
105 fn into_response(self) -> Response {
106 (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
107 }
108}
109
110#[derive(Default, new)]
112pub struct Provide<Session: Any> {
113 connections_tag: Option<&'static Tag<Connections<Session>>>,
114 handler_tag: Option<&'static Tag<Box<dyn Handler<Session>>>>,
115}
116
117impl<Session: Any> Provide<Session> {
118 pub fn with_connections_tag(self, connections_tag: &'static Tag<Connections<Session>>) -> Self {
120 Self {
121 connections_tag: Some(connections_tag),
122 ..self
123 }
124 }
125
126 pub fn with_handler_tag(self, handler_tag: &'static Tag<Box<dyn Handler<Session>>>) -> Self {
128 Self {
129 handler_tag: Some(handler_tag),
130 ..self
131 }
132 }
133}
134
135#[Provider]
136#[async_trait]
137impl<Session> Provider<Controller<Session>> for Provide<Session>
138where
139 Session: Send + Sync + Any,
140{
141 async fn provide(self: Arc<Self>, i: Inject) -> provider::Result<Arc<Controller<Session>>> {
142 let connections = if let Some(tag) = self.connections_tag {
143 i.get_tag(tag).await?
144 } else {
145 i.get::<Connections<Session>>().await?
146 };
147
148 let handler = if let Some(tag) = self.handler_tag {
149 i.get_tag(tag).await?
150 } else {
151 i.get::<Box<dyn Handler<Session>>>().await?
152 };
153
154 Ok(Arc::new(Controller::new(connections, handler)))
155 }
156}