axum_realtime_kit/ws/
handler.rs

// axum-realtime-kit/src/ws/handler.rs

//! Defines the `MessageHandler` trait, the core abstraction for implementing
//! application-specific WebSocket logic.
5
6use crate::ws::types::{ConnectionId, Topic, WsSink};
7use async_trait::async_trait;
8use axum::http::StatusCode;
9use serde::{Serialize, de::DeserializeOwned};
10use serde_json::Value;
11use std::fmt::Debug;
12use std::sync::Arc;
13
14/// A context object passed to message handler methods.
15///
16/// It contains all the relevant state for a single connection, allowing the
17/// handler to access application state, user information, and the client's
18/// WebSocket sink for direct communication.
19#[derive(Debug)]
20pub struct ConnectionContext<S, U>
21where
22    S: Send + Sync + 'static,
23    U: Send + Sync + 'static,
24{
25    /// The unique ID of the connection.
26    pub conn_id: ConnectionId,
27    /// The ID of the authenticated user. Its type is defined by the `MessageHandler`.
28    pub user_id: U,
29    /// The topic this connection is subscribed to.
30    pub topic: Topic,
31    /// A clone of the shared application state.
32    pub app_state: Arc<S>,
33    /// A clone of the WebSocket sink for sending messages directly to the client.
34    pub sink: WsSink,
35}
36
37/// A standard error type for handler methods, consisting of an HTTP status code
38/// and an optional, more specific error message.
39#[derive(Debug)]
40pub enum HandlerError {
41    /// A custom error with a status code and message.
42    Custom(StatusCode, Option<String>),
43    /// An error that occurred during response serialization.
44    Serialization(serde_json::Error),
45}
46
47impl From<(StatusCode, Option<String>)> for HandlerError {
48    fn from(value: (StatusCode, Option<String>)) -> Self {
49        HandlerError::Custom(value.0, value.1)
50    }
51}
52
53// Allow easy conversion from serde_json::Error.
54impl From<serde_json::Error> for HandlerError {
55    fn from(value: serde_json::Error) -> Self {
56        HandlerError::Serialization(value)
57    }
58}
59
60/// The central trait that a user's application must implement.
61///
62/// This trait decouples the generic `WebsocketService` from any specific
63/// business logic, allowing users to "plug in" their own message handling,
64/// event types, and application state.
65#[async_trait]
66pub trait MessageHandler: Send + Sync + 'static {
67    /// The type for messages coming from the client (e.g., a `WsClientMessage` enum).
68    /// Must be deserializable from a string (e.g., JSON).
69    type ClientMessage: DeserializeOwned + Send + Debug;
70
71    /// The type for events broadcast via Redis (e.g., a `ServerEvent` enum).
72    /// Must be serializable to a string (e.g., JSON).
73    type ServerEvent: Serialize + Send + Sync + Debug;
74
75    /// The shared application state (e.g., a struct holding database connection pools).
76    type AppState: Send + Sync + 'static;
77
78    /// The type for the user identifier. This can be `i64`, `Uuid`, `String`, etc.
79    type UserId: Send + Sync + Clone + Debug + 'static;
80
81    /// A lifecycle hook called immediately after a client successfully connects
82    /// and is subscribed.
83    ///
84    /// The default implementation does nothing.
85    async fn on_connect(&self, _context: &ConnectionContext<Self::AppState, Self::UserId>) {
86        // Default is a no-op
87    }
88
89    /// Handles a message that expects a direct response to the sender, not a broadcast.
90    ///
91    /// This is ideal for read-only operations like "list items" or "get item".
92    /// If `Ok(Some(response))` is returned, the `response` is serialized and sent
93    /// directly to the client who sent the message.
94    ///
95    /// If `Ok(None)` is returned, the `WebsocketService` will proceed to call
96    /// `handle_broadcast_message` with the same message. This allows for a single
97    /// message to potentially trigger both a direct response and a broadcast.
98    ///
99    /// The default implementation returns `Ok(None)`, falling back to the broadcast handler.
100    async fn handle_direct_message(
101        &self,
102        _msg: &Self::ClientMessage,
103        _context: &ConnectionContext<Self::AppState, Self::UserId>,
104    ) -> Result<Option<Value>, HandlerError> {
105        Ok(None)
106    }
107
108    /// Handles a message that may result in an event being broadcast to all subscribers.
109    ///
110    /// This is ideal for state-changing operations like "create item" or "delete item".
111    /// If `Ok(Some(event))` is returned, the `event` is serialized and published to
112    //  Redis, which then broadcasts it to all clients on the topic (including the sender).
113    ///
114    /// If `Ok(None)` is returned, the message was handled successfully, but no event
115    /// needs to be broadcast.
116    async fn handle_broadcast_message(
117        &self,
118        msg: Self::ClientMessage,
119        context: &ConnectionContext<Self::AppState, Self::UserId>,
120    ) -> Result<Option<Self::ServerEvent>, HandlerError>;
121
122    /// A lifecycle hook called after a client has disconnected.
123    ///
124    /// Useful for logging or performing cleanup related to the user's session.
125    /// The default implementation does nothing.
126    async fn on_disconnect(&self, _context: &ConnectionContext<Self::AppState, Self::UserId>) {
127        // Default is a no-op
128    }
129}