1use crate::error::RealtimeError;
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
/// Maximum number of distinct channels a single [`ClientSubscriptions`] set may hold.
///
/// Enforced by `ClientSubscriptions::subscribe`, which returns
/// `RealtimeError::SubscriptionLimit` once this many channels are present.
pub const MAX_SUBSCRIPTIONS_PER_CLIENT: usize = 100;
9
/// A parsed channel such as `repo:alice/myrepo/prs` or `user:alice`.
///
/// The textual form is `<prefix>:<identifier>[/<filter>]`. See `Channel::parse` for the
/// accepted syntax. The `Display` impl renders a channel back into that form.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Channel {
    /// Which kind of resource the channel refers to.
    pub channel_type: ChannelType,
    /// Resource identifier: `owner/name` for repository channels, a single segment for user
    /// and organization channels (as produced by `Channel::parse`).
    pub identifier: String,
    /// Optional narrowing of the channel (e.g. `prs`); `None` means the whole resource.
    /// `Channel::parse` only ever sets this for repository channels.
    pub filter: Option<String>,
}
20
21impl Channel {
22 pub fn parse(s: &str) -> Result<Self, RealtimeError> {
31 let parts: Vec<&str> = s.splitn(2, ':').collect();
32 if parts.len() != 2 {
33 return Err(RealtimeError::InvalidChannel(format!(
34 "missing channel type prefix: {}",
35 s
36 )));
37 }
38
39 let channel_type = match parts[0] {
40 "repo" => ChannelType::Repository,
41 "user" => ChannelType::User,
42 "org" => ChannelType::Organization,
43 _ => {
44 return Err(RealtimeError::InvalidChannel(format!(
45 "unknown channel type: {}",
46 parts[0]
47 )))
48 }
49 };
50
51 let identifier_parts: Vec<&str> = parts[1].splitn(3, '/').collect();
52
53 let (identifier, filter) = match channel_type {
54 ChannelType::Repository => {
55 if identifier_parts.len() < 2 {
56 return Err(RealtimeError::InvalidChannel(format!(
57 "repository channel requires owner/name format: {}",
58 parts[1]
59 )));
60 }
61 let id = format!("{}/{}", identifier_parts[0], identifier_parts[1]);
62 let filter = if identifier_parts.len() > 2 {
63 Some(identifier_parts[2].to_string())
64 } else {
65 None
66 };
67 (id, filter)
68 }
69 ChannelType::User | ChannelType::Organization => {
70 if identifier_parts.is_empty() || identifier_parts[0].is_empty() {
71 return Err(RealtimeError::InvalidChannel(format!(
72 "channel identifier cannot be empty: {}",
73 s
74 )));
75 }
76 (identifier_parts[0].to_string(), None)
77 }
78 };
79
80 Ok(Channel {
81 channel_type,
82 identifier,
83 filter,
84 })
85 }
86
87 pub fn matches(&self, event_channel: &str) -> bool {
89 let event_chan = match Channel::parse(event_channel) {
90 Ok(c) => c,
91 Err(_) => return false,
92 };
93
94 if self.channel_type != event_chan.channel_type {
95 return false;
96 }
97
98 if self.identifier != event_chan.identifier {
99 return false;
100 }
101
102 if self.filter.is_none() {
104 return true;
105 }
106
107 self.filter == event_chan.filter
109 }
110}
111
112impl std::fmt::Display for Channel {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 let prefix = match self.channel_type {
115 ChannelType::Repository => "repo",
116 ChannelType::User => "user",
117 ChannelType::Organization => "org",
118 };
119
120 match &self.filter {
121 Some(filter) => write!(f, "{}:{}/{}", prefix, self.identifier, filter),
122 None => write!(f, "{}:{}", prefix, self.identifier),
123 }
124 }
125}
126
/// Kind of resource a [`Channel`] refers to.
///
/// NOTE(review): with `rename_all = "lowercase"`, serde encodes these as `"repository"`,
/// `"user"` and `"organization"`. That differs from the `repo:` / `org:` prefixes used by
/// `Channel::parse` and `Display`. Confirm this difference is intended for the wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChannelType {
    /// Repository channel; textual prefix `repo`.
    Repository,
    /// User channel; textual prefix `user`.
    User,
    /// Organization channel; textual prefix `org`.
    Organization,
}
138
/// The set of channels a single client is subscribed to.
///
/// Capped at `MAX_SUBSCRIPTIONS_PER_CLIENT` distinct channels by `subscribe`.
#[derive(Debug, Default)]
pub struct ClientSubscriptions {
    // Distinct subscribed channels; a `HashSet` makes duplicate subscriptions no-ops.
    channels: HashSet<Channel>,
}
145
146impl ClientSubscriptions {
147 pub fn new() -> Self {
149 Self {
150 channels: HashSet::new(),
151 }
152 }
153
154 pub fn subscribe(&mut self, channel: Channel) -> Result<bool, RealtimeError> {
156 if self.channels.len() >= MAX_SUBSCRIPTIONS_PER_CLIENT {
157 return Err(RealtimeError::SubscriptionLimit(
158 MAX_SUBSCRIPTIONS_PER_CLIENT,
159 ));
160 }
161
162 Ok(self.channels.insert(channel))
163 }
164
165 pub fn unsubscribe(&mut self, channel: &Channel) -> bool {
167 self.channels.remove(channel)
168 }
169
170 pub fn is_subscribed(&self, channel: &Channel) -> bool {
172 self.channels.contains(channel)
173 }
174
175 pub fn matches_event(&self, event_channel: &str) -> bool {
177 self.channels.iter().any(|c| c.matches(event_channel))
178 }
179
180 pub fn channels(&self) -> impl Iterator<Item = &Channel> {
182 self.channels.iter()
183 }
184
185 pub fn count(&self) -> usize {
187 self.channels.len()
188 }
189
190 pub fn clear(&mut self) {
192 self.channels.clear();
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    // `repo:<owner>/<name>` parses into a repository channel with no filter.
    #[test]
    fn test_channel_parse_repo() {
        let channel = Channel::parse("repo:alice/myrepo").unwrap();
        assert_eq!(channel.channel_type, ChannelType::Repository);
        assert_eq!(channel.identifier, "alice/myrepo");
        assert_eq!(channel.filter, None);
    }

    // A third path segment becomes the filter; the identifier stays `owner/name`.
    #[test]
    fn test_channel_parse_repo_with_filter() {
        let channel = Channel::parse("repo:alice/myrepo/prs").unwrap();
        assert_eq!(channel.channel_type, ChannelType::Repository);
        assert_eq!(channel.identifier, "alice/myrepo");
        assert_eq!(channel.filter, Some("prs".to_string()));
    }

    #[test]
    fn test_channel_parse_user() {
        let channel = Channel::parse("user:alice").unwrap();
        assert_eq!(channel.channel_type, ChannelType::User);
        assert_eq!(channel.identifier, "alice");
        assert_eq!(channel.filter, None);
    }

    #[test]
    fn test_channel_parse_org() {
        let channel = Channel::parse("org:acme").unwrap();
        assert_eq!(channel.channel_type, ChannelType::Organization);
        assert_eq!(channel.identifier, "acme");
        assert_eq!(channel.filter, None);
    }

    // Covers: no `:` separator, unknown prefix, empty repo identifier, repo without `/name`.
    #[test]
    fn test_channel_parse_invalid() {
        assert!(Channel::parse("invalid").is_err());
        assert!(Channel::parse("unknown:test").is_err());
        assert!(Channel::parse("repo:").is_err());
        assert!(Channel::parse("repo:onlyname").is_err());
    }

    // `Display` renders the `<prefix>:<identifier>[/<filter>]` form.
    #[test]
    fn test_channel_to_string() {
        let channel = Channel {
            channel_type: ChannelType::Repository,
            identifier: "alice/myrepo".to_string(),
            filter: None,
        };
        assert_eq!(channel.to_string(), "repo:alice/myrepo");

        let channel_with_filter = Channel {
            channel_type: ChannelType::Repository,
            identifier: "alice/myrepo".to_string(),
            filter: Some("prs".to_string()),
        };
        assert_eq!(channel_with_filter.to_string(), "repo:alice/myrepo/prs");
    }

    #[test]
    fn test_channel_matches() {
        let subscription = Channel::parse("repo:alice/myrepo").unwrap();

        // An unfiltered subscription matches the exact channel...
        assert!(subscription.matches("repo:alice/myrepo"));
        // ...and every filtered sub-channel of the same repository.
        assert!(subscription.matches("repo:alice/myrepo/prs"));
        assert!(subscription.matches("repo:alice/myrepo/issues"));

        // A different repository does not match.
        assert!(!subscription.matches("repo:bob/otherrepo"));
        // A different channel type does not match, even with the same leading name.
        assert!(!subscription.matches("user:alice"));
    }

    #[test]
    fn test_channel_matches_with_filter() {
        let subscription = Channel::parse("repo:alice/myrepo/prs").unwrap();

        // A filtered subscription matches only events carrying the same filter.
        assert!(subscription.matches("repo:alice/myrepo/prs"));

        // The unfiltered parent channel is not delivered to a filtered subscription.
        assert!(!subscription.matches("repo:alice/myrepo"));
        // Nor is a sibling filter.
        assert!(!subscription.matches("repo:alice/myrepo/issues"));
    }

    #[test]
    fn test_client_subscriptions() {
        let mut subs = ClientSubscriptions::new();

        let channel = Channel::parse("repo:alice/myrepo").unwrap();
        assert!(subs.subscribe(channel.clone()).unwrap());
        assert!(subs.is_subscribed(&channel));
        assert_eq!(subs.count(), 1);

        // Subscribing again reports `false` (already present) and does not grow the set.
        assert!(!subs.subscribe(channel.clone()).unwrap());
        assert_eq!(subs.count(), 1);

        assert!(subs.unsubscribe(&channel));
        assert!(!subs.is_subscribed(&channel));
        assert_eq!(subs.count(), 0);
    }

    // Filling the set to the cap, then adding one more new channel, yields `SubscriptionLimit`.
    #[test]
    fn test_client_subscriptions_limit() {
        let mut subs = ClientSubscriptions::new();

        for i in 0..MAX_SUBSCRIPTIONS_PER_CLIENT {
            let channel = Channel::parse(&format!("user:user{}", i)).unwrap();
            subs.subscribe(channel).unwrap();
        }

        let extra = Channel::parse("user:extra").unwrap();
        assert!(matches!(
            subs.subscribe(extra),
            Err(RealtimeError::SubscriptionLimit(_))
        ));
    }

    // `matches_event` is true if any held subscription matches the event string.
    #[test]
    fn test_matches_event() {
        let mut subs = ClientSubscriptions::new();
        subs.subscribe(Channel::parse("repo:alice/myrepo").unwrap())
            .unwrap();
        subs.subscribe(Channel::parse("user:alice").unwrap())
            .unwrap();

        assert!(subs.matches_event("repo:alice/myrepo"));
        assert!(subs.matches_event("repo:alice/myrepo/prs"));
        assert!(subs.matches_event("user:alice"));

        assert!(!subs.matches_event("repo:bob/otherrepo"));
        assert!(!subs.matches_event("user:bob"));
    }
}