supabase_client_realtime/
channel.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde_json::Value;
6use tokio::sync::RwLock;
7
8use crate::callback::{Binding, CallbackRegistry};
9use crate::error::RealtimeError;
10use crate::types::{
11 BroadcastConfig, ChannelState, JoinConfig, JoinPayload, PostgresChangesEvent,
12 PostgresChangesFilter, PresenceConfig, PresenceState, SubscriptionStatus,
13};
14
15pub struct ChannelBuilder {
21 pub(crate) name: String,
22 pub(crate) topic: String,
23 pub(crate) broadcast_config: BroadcastConfig,
24 pub(crate) presence_key: String,
25 pub(crate) presence_enabled: bool,
26 pub(crate) postgres_changes: Vec<PostgresChangesFilter>,
27 pub(crate) bindings: Vec<Binding>,
28 pub(crate) is_private: bool,
29 pub(crate) subscribe_timeout: Duration,
30 pub(crate) access_token: Option<String>,
31 pub(crate) client_sender: crate::client::ClientSender,
33}
34
35impl ChannelBuilder {
36 pub fn on_postgres_changes<F>(
38 mut self,
39 event: PostgresChangesEvent,
40 filter: PostgresChangesFilter,
41 callback: F,
42 ) -> Self
43 where
44 F: Fn(crate::types::PostgresChangePayload) + Send + Sync + 'static,
45 {
46 let filter_index = self.postgres_changes.len();
47 let filter = filter.event(event);
49 self.postgres_changes.push(filter);
50 self.bindings.push(Binding::PostgresChanges {
51 filter_index,
52 event,
53 callback: Arc::new(callback),
54 });
55 self
56 }
57
58 pub fn on_broadcast<F>(mut self, event: &str, callback: F) -> Self
60 where
61 F: Fn(Value) + Send + Sync + 'static,
62 {
63 self.bindings.push(Binding::Broadcast {
64 event: event.to_string(),
65 callback: Arc::new(callback),
66 });
67 self
68 }
69
70 pub fn on_presence_sync<F>(mut self, callback: F) -> Self
72 where
73 F: Fn(&PresenceState) + Send + Sync + 'static,
74 {
75 self.presence_enabled = true;
76 self.bindings.push(Binding::PresenceSync(Arc::new(callback)));
77 self
78 }
79
80 pub fn on_presence_join<F>(mut self, callback: F) -> Self
82 where
83 F: Fn(String, Vec<crate::types::PresenceMeta>) + Send + Sync + 'static,
84 {
85 self.presence_enabled = true;
86 self.bindings
87 .push(Binding::PresenceJoin(Arc::new(callback)));
88 self
89 }
90
91 pub fn on_presence_leave<F>(mut self, callback: F) -> Self
93 where
94 F: Fn(String, Vec<crate::types::PresenceMeta>) + Send + Sync + 'static,
95 {
96 self.presence_enabled = true;
97 self.bindings
98 .push(Binding::PresenceLeave(Arc::new(callback)));
99 self
100 }
101
102 pub fn broadcast_ack(mut self, ack: bool) -> Self {
104 self.broadcast_config.ack = ack;
105 self
106 }
107
108 pub fn broadcast_self(mut self, self_send: bool) -> Self {
110 self.broadcast_config.self_send = self_send;
111 self
112 }
113
114 pub fn presence_key(mut self, key: &str) -> Self {
116 self.presence_enabled = true;
117 self.presence_key = key.to_string();
118 self
119 }
120
121 pub fn private(mut self) -> Self {
123 self.is_private = true;
124 self
125 }
126
127 pub fn timeout(mut self, timeout: Duration) -> Self {
129 self.subscribe_timeout = timeout;
130 self
131 }
132
133 pub async fn subscribe<F>(
137 self,
138 status_callback: F,
139 ) -> Result<RealtimeChannel, RealtimeError>
140 where
141 F: Fn(SubscriptionStatus, Option<RealtimeError>) + Send + Sync + 'static,
142 {
143 let join_payload = JoinPayload {
144 config: JoinConfig {
145 broadcast: self.broadcast_config.clone(),
146 presence: PresenceConfig {
147 key: self.presence_key.clone(),
148 },
149 postgres_changes: self.postgres_changes.clone(),
150 },
151 access_token: self.access_token.clone(),
152 };
153
154 let registry = CallbackRegistry::new();
155 {
156 let mut bindings = registry.bindings.write().await;
157 for binding in self.bindings {
158 bindings.push(binding);
159 }
160 }
161 {
162 let mut status_cb = registry.status_callback.write().await;
163 *status_cb = Some(Arc::new(status_callback));
164 }
165
166 let inner = Arc::new(ChannelInner {
167 name: self.name.clone(),
168 topic: self.topic.clone(),
169 state: RwLock::new(ChannelState::Joining),
170 join_ref: RwLock::new(None),
171 join_payload: RwLock::new(join_payload.clone()),
172 registry,
173 presence_state: RwLock::new(PresenceState::new()),
174 pg_change_id_map: RwLock::new(HashMap::new()),
175 client_sender: self.client_sender.clone(),
176 });
177
178 let channel = RealtimeChannel {
179 inner: inner.clone(),
180 };
181
182 self.client_sender
184 .subscribe_channel(channel.clone(), join_payload, self.subscribe_timeout)
185 .await?;
186
187 Ok(channel)
188 }
189}
190
191#[derive(Clone)]
197pub struct RealtimeChannel {
198 pub(crate) inner: Arc<ChannelInner>,
199}
200
201pub(crate) struct ChannelInner {
202 pub(crate) name: String,
203 pub(crate) topic: String,
204 pub(crate) state: RwLock<ChannelState>,
205 pub(crate) join_ref: RwLock<Option<String>>,
206 pub(crate) join_payload: RwLock<JoinPayload>,
207 pub(crate) registry: CallbackRegistry,
208 pub(crate) presence_state: RwLock<PresenceState>,
209 pub(crate) pg_change_id_map: RwLock<HashMap<u64, usize>>,
211 pub(crate) client_sender: crate::client::ClientSender,
212}
213
214impl RealtimeChannel {
215 pub fn topic(&self) -> &str {
217 &self.inner.topic
218 }
219
220 pub fn name(&self) -> &str {
222 &self.inner.name
223 }
224
225 pub async fn state(&self) -> ChannelState {
227 *self.inner.state.read().await
228 }
229
230 pub async fn send_broadcast(
232 &self,
233 event: &str,
234 payload: Value,
235 ) -> Result<(), RealtimeError> {
236 let state = *self.inner.state.read().await;
237 if state != ChannelState::Joined {
238 return Err(RealtimeError::InvalidChannelState {
239 expected: ChannelState::Joined,
240 actual: state,
241 });
242 }
243 let join_ref = self.inner.join_ref.read().await;
244 let join_ref = join_ref
245 .as_deref()
246 .ok_or_else(|| RealtimeError::Internal("No join_ref".to_string()))?;
247 self.inner
248 .client_sender
249 .send_broadcast(&self.inner.topic, event, payload, join_ref)
250 .await
251 }
252
253 pub async fn track(&self, payload: Value) -> Result<(), RealtimeError> {
255 let state = *self.inner.state.read().await;
256 if state != ChannelState::Joined {
257 return Err(RealtimeError::InvalidChannelState {
258 expected: ChannelState::Joined,
259 actual: state,
260 });
261 }
262 let join_ref = self.inner.join_ref.read().await;
263 let join_ref = join_ref
264 .as_deref()
265 .ok_or_else(|| RealtimeError::Internal("No join_ref".to_string()))?;
266 self.inner
267 .client_sender
268 .send_presence_track(&self.inner.topic, payload, join_ref)
269 .await
270 }
271
272 pub async fn untrack(&self) -> Result<(), RealtimeError> {
274 let state = *self.inner.state.read().await;
275 if state != ChannelState::Joined {
276 return Err(RealtimeError::InvalidChannelState {
277 expected: ChannelState::Joined,
278 actual: state,
279 });
280 }
281 let join_ref = self.inner.join_ref.read().await;
282 let join_ref = join_ref
283 .as_deref()
284 .ok_or_else(|| RealtimeError::Internal("No join_ref".to_string()))?;
285 self.inner
286 .client_sender
287 .send_presence_untrack(&self.inner.topic, join_ref)
288 .await
289 }
290
291 pub async fn presence_state(&self) -> PresenceState {
293 self.inner.presence_state.read().await.clone()
294 }
295
296 pub async fn unsubscribe(&self) -> Result<(), RealtimeError> {
298 let state = *self.inner.state.read().await;
299 if state == ChannelState::Closed || state == ChannelState::Leaving {
300 return Ok(());
301 }
302 let join_ref = self.inner.join_ref.read().await;
303 let join_ref = join_ref
304 .as_deref()
305 .ok_or_else(|| RealtimeError::Internal("No join_ref for leave".to_string()))?;
306 self.inner
307 .client_sender
308 .send_leave(&self.inner.topic, join_ref)
309 .await?;
310 *self.inner.state.write().await = ChannelState::Leaving;
311 Ok(())
312 }
313
314 pub async fn update_access_token(&self, token: &str) -> Result<(), RealtimeError> {
316 let state = *self.inner.state.read().await;
317 if state != ChannelState::Joined {
318 return Err(RealtimeError::InvalidChannelState {
319 expected: ChannelState::Joined,
320 actual: state,
321 });
322 }
323 {
325 let mut jp = self.inner.join_payload.write().await;
326 jp.access_token = Some(token.to_string());
327 }
328 let join_ref = self.inner.join_ref.read().await;
329 let join_ref = join_ref
330 .as_deref()
331 .ok_or_else(|| RealtimeError::Internal("No join_ref".to_string()))?;
332 self.inner
333 .client_sender
334 .send_access_token(&self.inner.topic, token, join_ref)
335 .await
336 }
337}