1pub mod client;
2
3use anyhow::anyhow;
4use client::{ClientFormat, ClientInfo, CnctdClient, QueryParams};
5use cnctd_redis::CnctdRedis;
6use futures_util::{SinkExt, StreamExt};
7use local_ip_address::local_ip;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use state::InitCell;
11use warp::filters::ws::Ws;
12use warp::reject::Reject;
13use warp::ws::{Message as WebSocketMessage, WebSocket};
14use warp::Filter;
15use std::collections::HashMap;
16use std::future::Future;
17use std::pin::Pin;
18use tokio::sync::{mpsc, RwLock};
19use std::{sync::Arc, fmt::Debug};
20
21use crate::router::message::Message;
22use crate::router::SocketRouterFunction;
23use crate::server::server_info::ServerInfo;
24
/// Async callback for binary WebSocket frames: it receives the sending client's id and
/// the raw frame payload, and returns a boxed `Send` future that the connection task
/// awaits before reading the next frame.
pub type OnBinaryHandler =
    Arc<dyn Fn(String, Vec<u8>) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
28
29#[derive(Debug)]
30struct NoClientId;
31
32impl Reject for NoClientId {}
33
34#[derive(Clone)]
35pub struct SocketConfig<R> {
36 pub router: R,
37 pub secret: Option<Vec<u8>>,
38 pub redis_url: Option<String>,
39 pub on_disconnect: Option<Arc<dyn Fn(ClientInfo) + Send + Sync>>,
40 pub on_binary: Option<OnBinaryHandler>,
41}
42
43impl<R> SocketConfig<R> {
44 pub fn new(router: R, secret: Option<Vec<u8>>, redis_url: Option<String>, on_disconnect: Option<Arc<dyn Fn(ClientInfo) + Send + Sync>>,) -> Self {
45 Self {
46 router,
47 secret,
48 redis_url,
49 on_disconnect,
50 on_binary: None,
51 }
52 }
53
54 pub fn with_on_binary(mut self, handler: OnBinaryHandler) -> Self {
55 self.on_binary = Some(handler);
56 self
57 }
58}
59
60
61
62pub static CLIENTS: InitCell<Arc<RwLock<HashMap<String, CnctdClient>>>> = InitCell::new();
63
64pub struct CnctdSocket;
65
66impl CnctdSocket {
67 pub fn build_routes<M, Resp, R>(config: SocketConfig<R>) -> warp::filters::BoxedFilter<(impl warp::Reply,)>
68 where
69 M: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static,
70 Resp: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static,
71 R: SocketRouterFunction<M, Resp> + 'static,
72 {
73 CLIENTS.set(Arc::new(RwLock::new(HashMap::new())));
74
75 let redis;
76
77 match config.redis_url {
78 Some(url) => {
79 match cnctd_redis::CnctdRedis::start(&url) {
80 Ok(_) => {
81 println!("Redis started!");
82 tokio::spawn(async {
83 ServerInfo::set_redis_active(true).await;
84 });
85 redis = true
86 },
87 Err(e) => {
88 println!("Error starting Redis pool: {:?}", e);
89 redis = false
90 }
91 }
92 }
93 None => redis = false
94 };
95
96 let websocket_route = warp::path("ws")
97 .and(warp::ws())
98 .and(warp::any().map(move || config.router.clone()))
99 .and(warp::query::<QueryParams>())
100 .and_then(move |ws: Ws, router: R, params: QueryParams| {
101 let on_disconnect = config.on_disconnect.clone();
102 let on_binary = config.on_binary.clone();
103
104 async move {
105 let client_id = match params.client_id {
107 Some(id) => id,
108 None => {
109 if let Some(ref subs_str) = params.subscriptions {
113 let subscriptions: Vec<String> = subs_str
114 .split(',')
115 .map(|s| s.trim().to_string())
116 .filter(|s| !s.is_empty())
117 .collect();
118 let format = ClientFormat::from_str_opt(params.format.as_deref());
119 match CnctdClient::register_client_with_format(
120 subscriptions,
121 None,
122 format,
123 ).await {
124 Ok(id) => {
125 println!("Inline-registered client: {}", id);
126 id
127 }
128 Err(e) => {
129 eprintln!("Inline registration failed: {:?}", e);
130 return Err(warp::reject::custom(NoClientId));
131 }
132 }
133 } else {
134 return Err(warp::reject::custom(NoClientId));
135 }
136 },
137 };
138
139 Ok(ws.on_upgrade(move |socket| {
141 Self::handle_connection(socket, router, client_id, redis, on_disconnect.clone(), on_binary)
142 }))
143 }
144 });
145
146
147 let routes = websocket_route;
148
149 routes.boxed()
150
151 }
152 pub async fn start<M, Resp, R>(port: &str, router: R, secret: Option<Vec<u8>>, redis_url: Option<String>, on_disconnect: Option<Arc<dyn Fn(ClientInfo) + Send + Sync>>,) -> anyhow::Result<()>
153 where
154 M: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static,
155 Resp: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static,
156 R: SocketRouterFunction<M, Resp> + 'static,
157 {
158 CLIENTS.set(Arc::new(RwLock::new(HashMap::new())));
159
160 let my_local_ip = local_ip()?;
161 println!("WebSocket server running at ws://{}:{}", my_local_ip, port);
162 let ip_address: [u8; 4] = [0, 0, 0, 0];
163 let parsed_port = port.parse::<u16>()?;
164 let socket_addr = std::net::SocketAddr::from((ip_address, parsed_port));
165 let config = SocketConfig::new(router, secret, redis_url, on_disconnect);
166 let routes = Self::build_routes(config);
167
168 warp::serve(routes).run(socket_addr).await;
169
170 Ok(())
171
172 }
173
174 pub async fn broadcast_message(msg: &Message) -> anyhow::Result<()> {
175 let clients = CLIENTS.try_get().ok_or_else(|| anyhow!("Clients not initialized"))?.read().await;
176
177 for (client_id, client) in clients.iter() {
178 if client.subscriptions.contains(&msg.channel) {
179 CnctdClient::message_client(&client_id, msg).await?;
180 }
181 }
182
183 Ok(())
184 }
185
186
187
188 async fn handle_connection<M, Resp, R>(
189 websocket: WebSocket,
190 router: R,
191 client_id: String,
192 redis: bool,
193 on_disconnect: Option<Arc<dyn Fn(ClientInfo) + Send + Sync>>,
194 on_binary: Option<OnBinaryHandler>,
195 ) where
196 M: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static,
197 Resp: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static,
198 R: SocketRouterFunction<M, Resp> + 'static,
199 {
200 let (mut ws_tx, mut ws_rx) = websocket.split();
201 let (resp_tx, mut resp_rx) = mpsc::unbounded_channel::<Result<WebSocketMessage, warp::Error>>();
202
203 {
204 let clients = CLIENTS.get();
205 let mut clients_lock = clients.write().await;
206
207 if let Some(client) = clients_lock.get_mut(&client_id.clone()) {
208 client.sender = Some(resp_tx.clone());
210
211 if redis {
212 match Self::push_client_to_redis(&client_id, &client.clone()).await {
213 Ok(_) => println!("pajama party"),
214 Err(e) => eprintln!("Error pushing client to Redis: {:?}", e),
215 }
216 }
217 println!("Updated client sender: {:?}", client);
218 } else {
219 eprintln!("Client with id {} not found.", client_id);
221 return;
222 }
223 }
224
225 let client_id_clone = client_id.clone();
226 let process_incoming = async move {
228 while let Some(result) = ws_rx.next().await {
229 match result {
230 Ok(msg) => {
231 if msg.is_binary() {
232 if let Some(ref handler) = on_binary {
233 let bytes = msg.into_bytes();
234 handler(client_id_clone.clone(), bytes).await;
235 }
236 } else if let Ok(message_str) = msg.to_str() {
237 if let Ok(message) = serde_json::from_str::<M>(message_str) {
238 match router.route(message, client_id_clone.clone()).await {
239 Some(response) => {
240 if let Ok(response_str) = serde_json::to_string(&response) {
241 let _ = resp_tx.send(Ok(WebSocketMessage::text(response_str)));
242 }
243 },
244 None => {}
245 }
246 }
247 }
248 },
249 Err(e) => eprintln!("WebSocket receive error: {:?}", e),
250 }
251 }
252 };
253
254 let send_responses = async move {
256 while let Some(response) = resp_rx.recv().await {
257 if let Ok(msg) = response {
258 if ws_tx.send(msg).await.is_err() {
259 eprintln!("WebSocket send error");
260 break;
261 }
262 }
263 }
264 };
265
266 tokio::select! {
267 _ = process_incoming => {},
268 _ = send_responses => {},
269 };
270
271 if let Some(callback) = on_disconnect {
272 let client_info = CnctdClient::get_client_info(&client_id).await.unwrap();
273 callback(client_info);
274 }
275
276 match Self::remove_client(&client_id).await {
278 Ok(_) => {},
279 Err(e) => eprintln!("Error removing client: {:?}", e),
280 };
281
282 if redis {
283 match Self::remove_client_from_redis(&client_id).await {
284 Ok(_) => {},
285 Err(e) => eprintln!("Error removing client from Redis: {:?}", e),
286 }
287 }
288
289
290
291 }
292
293 pub async fn remove_client(client_id: &str) -> anyhow::Result<()> {
294 let clients = CLIENTS.try_get().ok_or_else(|| anyhow!("Clients not initialized"))?;
295 let mut clients_lock = clients.write().await;
296
297 if let Some(client) = clients_lock.get(client_id) {
298 let should_remove = client.sender.as_ref().map_or(true, |sender| sender.is_closed());
299
300 if should_remove {
301 println!("Removing client: {}", client_id);
302 clients_lock.remove(client_id);
303 } else {
304 println!("Client {} is active; no removal necessary.", client_id);
305 }
306 }
307
308 Ok(())
309 }
310
311
312
313 pub async fn push_client_to_redis(client_id: &str, client: &CnctdClient) -> anyhow::Result<()> {
314 let client_info = client.to_client_info(client_id).await;
315 CnctdRedis::hset("clients", &client_id, client_info)?;
316
317 Ok(())
318 }
319
320 pub async fn remove_client_from_redis(client_id: &str) -> anyhow::Result<()> {
321 CnctdRedis::hset("clients", client_id, ())?;
322
323 Ok(())
324 }
325
326}
327