guts_realtime/
subscription.rs

1//! Subscription management for real-time channels.
2
3use crate::error::RealtimeError;
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
/// Upper bound on how many distinct channels a single client may be
/// subscribed to at the same time.
pub const MAX_SUBSCRIPTIONS_PER_CLIENT: usize = 100;
9
10/// A channel that clients can subscribe to.
11#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub struct Channel {
13    /// Channel type.
14    pub channel_type: ChannelType,
15    /// Channel identifier.
16    pub identifier: String,
17    /// Optional sub-channel filter.
18    pub filter: Option<String>,
19}
20
21impl Channel {
22    /// Parse a channel string into a Channel.
23    ///
24    /// Formats:
25    /// - `repo:owner/name` - All events for a repository
26    /// - `repo:owner/name/prs` - PR events only
27    /// - `repo:owner/name/issues` - Issue events only
28    /// - `user:username` - User notifications
29    /// - `org:orgname` - Organization events
30    pub fn parse(s: &str) -> Result<Self, RealtimeError> {
31        let parts: Vec<&str> = s.splitn(2, ':').collect();
32        if parts.len() != 2 {
33            return Err(RealtimeError::InvalidChannel(format!(
34                "missing channel type prefix: {}",
35                s
36            )));
37        }
38
39        let channel_type = match parts[0] {
40            "repo" => ChannelType::Repository,
41            "user" => ChannelType::User,
42            "org" => ChannelType::Organization,
43            _ => {
44                return Err(RealtimeError::InvalidChannel(format!(
45                    "unknown channel type: {}",
46                    parts[0]
47                )))
48            }
49        };
50
51        let identifier_parts: Vec<&str> = parts[1].splitn(3, '/').collect();
52
53        let (identifier, filter) = match channel_type {
54            ChannelType::Repository => {
55                if identifier_parts.len() < 2 {
56                    return Err(RealtimeError::InvalidChannel(format!(
57                        "repository channel requires owner/name format: {}",
58                        parts[1]
59                    )));
60                }
61                let id = format!("{}/{}", identifier_parts[0], identifier_parts[1]);
62                let filter = if identifier_parts.len() > 2 {
63                    Some(identifier_parts[2].to_string())
64                } else {
65                    None
66                };
67                (id, filter)
68            }
69            ChannelType::User | ChannelType::Organization => {
70                if identifier_parts.is_empty() || identifier_parts[0].is_empty() {
71                    return Err(RealtimeError::InvalidChannel(format!(
72                        "channel identifier cannot be empty: {}",
73                        s
74                    )));
75                }
76                (identifier_parts[0].to_string(), None)
77            }
78        };
79
80        Ok(Channel {
81            channel_type,
82            identifier,
83            filter,
84        })
85    }
86
87    /// Check if an event channel matches this subscription channel.
88    pub fn matches(&self, event_channel: &str) -> bool {
89        let event_chan = match Channel::parse(event_channel) {
90            Ok(c) => c,
91            Err(_) => return false,
92        };
93
94        if self.channel_type != event_chan.channel_type {
95            return false;
96        }
97
98        if self.identifier != event_chan.identifier {
99            return false;
100        }
101
102        // If subscription has no filter, it matches all sub-channels
103        if self.filter.is_none() {
104            return true;
105        }
106
107        // If subscription has a filter, event must match exactly
108        self.filter == event_chan.filter
109    }
110}
111
112impl std::fmt::Display for Channel {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        let prefix = match self.channel_type {
115            ChannelType::Repository => "repo",
116            ChannelType::User => "user",
117            ChannelType::Organization => "org",
118        };
119
120        match &self.filter {
121            Some(filter) => write!(f, "{}:{}/{}", prefix, self.identifier, filter),
122            None => write!(f, "{}:{}", prefix, self.identifier),
123        }
124    }
125}
126
127/// Channel types.
128#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
129#[serde(rename_all = "lowercase")]
130pub enum ChannelType {
131    /// Repository channel (e.g., repo:owner/name).
132    Repository,
133    /// User notification channel (e.g., user:alice).
134    User,
135    /// Organization channel (e.g., org:acme).
136    Organization,
137}
138
139/// Manages subscriptions for a single client.
140#[derive(Debug, Default)]
141pub struct ClientSubscriptions {
142    /// Set of subscribed channels.
143    channels: HashSet<Channel>,
144}
145
146impl ClientSubscriptions {
147    /// Create a new subscription manager.
148    pub fn new() -> Self {
149        Self {
150            channels: HashSet::new(),
151        }
152    }
153
154    /// Subscribe to a channel.
155    pub fn subscribe(&mut self, channel: Channel) -> Result<bool, RealtimeError> {
156        if self.channels.len() >= MAX_SUBSCRIPTIONS_PER_CLIENT {
157            return Err(RealtimeError::SubscriptionLimit(
158                MAX_SUBSCRIPTIONS_PER_CLIENT,
159            ));
160        }
161
162        Ok(self.channels.insert(channel))
163    }
164
165    /// Unsubscribe from a channel.
166    pub fn unsubscribe(&mut self, channel: &Channel) -> bool {
167        self.channels.remove(channel)
168    }
169
170    /// Check if subscribed to a channel.
171    pub fn is_subscribed(&self, channel: &Channel) -> bool {
172        self.channels.contains(channel)
173    }
174
175    /// Check if any subscription matches the event channel.
176    pub fn matches_event(&self, event_channel: &str) -> bool {
177        self.channels.iter().any(|c| c.matches(event_channel))
178    }
179
180    /// Get all subscribed channels.
181    pub fn channels(&self) -> impl Iterator<Item = &Channel> {
182        self.channels.iter()
183    }
184
185    /// Get subscription count.
186    pub fn count(&self) -> usize {
187        self.channels.len()
188    }
189
190    /// Clear all subscriptions.
191    pub fn clear(&mut self) {
192        self.channels.clear();
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `input`, failing the test with the offending string on error.
    fn parsed(input: &str) -> Channel {
        match Channel::parse(input) {
            Ok(channel) => channel,
            Err(err) => panic!("expected {:?} to parse, got {:?}", input, err),
        }
    }

    /// Assert every field of `channel` in one call.
    fn assert_fields(channel: &Channel, kind: ChannelType, id: &str, filter: Option<&str>) {
        assert_eq!(channel.channel_type, kind);
        assert_eq!(channel.identifier, id);
        assert_eq!(channel.filter.as_deref(), filter);
    }

    #[test]
    fn test_channel_parse_repo() {
        assert_fields(
            &parsed("repo:alice/myrepo"),
            ChannelType::Repository,
            "alice/myrepo",
            None,
        );
    }

    #[test]
    fn test_channel_parse_repo_with_filter() {
        assert_fields(
            &parsed("repo:alice/myrepo/prs"),
            ChannelType::Repository,
            "alice/myrepo",
            Some("prs"),
        );
    }

    #[test]
    fn test_channel_parse_user() {
        assert_fields(&parsed("user:alice"), ChannelType::User, "alice", None);
    }

    #[test]
    fn test_channel_parse_org() {
        assert_fields(&parsed("org:acme"), ChannelType::Organization, "acme", None);
    }

    #[test]
    fn test_channel_parse_invalid() {
        let rejected = ["invalid", "unknown:test", "repo:", "repo:onlyname"];
        for input in rejected.iter() {
            assert!(Channel::parse(input).is_err(), "{:?} should be rejected", input);
        }
    }

    #[test]
    fn test_channel_to_string() {
        let repo = |filter: Option<&str>| Channel {
            channel_type: ChannelType::Repository,
            identifier: String::from("alice/myrepo"),
            filter: filter.map(String::from),
        };

        assert_eq!(repo(None).to_string(), "repo:alice/myrepo");
        assert_eq!(repo(Some("prs")).to_string(), "repo:alice/myrepo/prs");
    }

    #[test]
    fn test_channel_matches() {
        let subscription = parsed("repo:alice/myrepo");

        // An unfiltered repo subscription covers the bare channel and its sub-channels...
        let covered = [
            "repo:alice/myrepo",
            "repo:alice/myrepo/prs",
            "repo:alice/myrepo/issues",
        ];
        for event in covered.iter() {
            assert!(subscription.matches(event), "{} should match", event);
        }

        // ...but nothing from another repository or channel type.
        for event in ["repo:bob/otherrepo", "user:alice"].iter() {
            assert!(!subscription.matches(event), "{} should not match", event);
        }
    }

    #[test]
    fn test_channel_matches_with_filter() {
        let subscription = parsed("repo:alice/myrepo/prs");

        assert!(subscription.matches("repo:alice/myrepo/prs"));
        for event in ["repo:alice/myrepo", "repo:alice/myrepo/issues"].iter() {
            assert!(!subscription.matches(event), "{} should not match", event);
        }
    }

    #[test]
    fn test_client_subscriptions() {
        let channel = parsed("repo:alice/myrepo");
        let mut subs = ClientSubscriptions::new();

        let added = subs.subscribe(channel.clone()).unwrap();
        assert!(added);
        assert!(subs.is_subscribed(&channel));
        assert_eq!(subs.count(), 1);

        // Subscribing twice is a no-op that reports `false`.
        let added_again = subs.subscribe(channel.clone()).unwrap();
        assert!(!added_again);
        assert_eq!(subs.count(), 1);

        assert!(subs.unsubscribe(&channel));
        assert!(!subs.is_subscribed(&channel));
        assert_eq!(subs.count(), 0);
    }

    #[test]
    fn test_client_subscriptions_limit() {
        let mut subs = ClientSubscriptions::new();
        (0..MAX_SUBSCRIPTIONS_PER_CLIENT)
            .map(|i| parsed(&format!("user:user{}", i)))
            .for_each(|channel| {
                subs.subscribe(channel).unwrap();
            });

        let overflow = subs.subscribe(parsed("user:extra"));
        assert!(matches!(overflow, Err(RealtimeError::SubscriptionLimit(_))));
    }

    #[test]
    fn test_matches_event() {
        let mut subs = ClientSubscriptions::new();
        for spec in ["repo:alice/myrepo", "user:alice"].iter() {
            subs.subscribe(parsed(spec)).unwrap();
        }

        for event in ["repo:alice/myrepo", "repo:alice/myrepo/prs", "user:alice"].iter() {
            assert!(subs.matches_event(event), "{} should match", event);
        }
        for event in ["repo:bob/otherrepo", "user:bob"].iter() {
            assert!(!subs.matches_event(event), "{} should not match", event);
        }
    }
}