1use std::net::TcpStream;
10use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
11
12use noxu_sync::Mutex;
13
14use crate::error::{RepError, Result};
15
/// Static settings for a [`Subscription`].
///
/// The feeder address (`feeder_host`, `feeder_port`) is what
/// `Subscription::start` connects to, and `start_vlsn` seeds the
/// subscription's current VLSN.
#[derive(Debug, Clone)]
pub struct SubscriptionConfig {
    /// Name of this subscriber (carried in the config; not used by the
    /// connection logic in this module).
    pub subscriber_name: String,
    /// Name of the replication group (carried in the config; not used by the
    /// connection logic in this module).
    pub group_name: String,
    /// Host of the feeder to connect to: a hostname or an IP literal.
    pub feeder_host: String,
    /// TCP port of the feeder to connect to.
    pub feeder_port: u16,
    /// VLSN the subscription reports as current before any entry is processed.
    pub start_vlsn: u64,
}
34
35pub trait SubscriptionCallback: Send + Sync {
41 fn on_entry(&self, vlsn: u64, entry_type: u8, data: &[u8]);
48
49 fn on_error(&self, error: &RepError);
51
52 fn on_caught_up(&self, vlsn: u64);
54}
55
/// Lifecycle state of a [`Subscription`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubscriptionState {
    /// Freshly constructed; `start` has not been called yet.
    Idle,
    /// Set by `start` before it attempts to connect to the feeder.
    Connecting,
    /// Connected to the feeder.
    Active,
    /// Reached from `Active` via `mark_caught_up`.
    CaughtUp,
    /// The connect attempt failed, or `mark_error` was called.
    Error,
    /// `shutdown` was called. This state is terminal: `start` refuses it and
    /// `mark_error` leaves it unchanged.
    Shutdown,
}
72
73pub struct Subscription {
79 config: SubscriptionConfig,
81 state: Mutex<SubscriptionState>,
83 current_vlsn: Mutex<u64>,
85 entries_received: AtomicU64,
87 shutdown: AtomicBool,
89 connection: Mutex<Option<TcpStream>>,
95}
96
97impl Subscription {
98 pub fn new(config: SubscriptionConfig) -> Self {
100 let start_vlsn = config.start_vlsn;
101 Self {
102 config,
103 state: Mutex::new(SubscriptionState::Idle),
104 current_vlsn: Mutex::new(start_vlsn),
105 entries_received: AtomicU64::new(0),
106 shutdown: AtomicBool::new(false),
107 connection: Mutex::new(None),
108 }
109 }
110
111 pub fn get_state(&self) -> SubscriptionState {
113 *self.state.lock()
114 }
115
116 pub fn get_current_vlsn(&self) -> u64 {
118 *self.current_vlsn.lock()
119 }
120
121 pub fn get_entries_received(&self) -> u64 {
123 self.entries_received.load(Ordering::Relaxed)
124 }
125
126 pub fn get_config(&self) -> &SubscriptionConfig {
128 &self.config
129 }
130
131 pub fn start(&self) -> Result<()> {
141 let mut state = self.state.lock();
142 match *state {
143 SubscriptionState::Idle => {
144 *state = SubscriptionState::Connecting;
145
146 let addr_str = format!(
149 "{}:{}",
150 self.config.feeder_host, self.config.feeder_port
151 );
152 match TcpStream::connect(&addr_str) {
153 Ok(stream) => {
154 *self.connection.lock() = Some(stream);
155 *state = SubscriptionState::Active;
156 Ok(())
157 }
158 Err(e) => {
159 *state = SubscriptionState::Error;
160 Err(RepError::SubscriptionError(format!(
161 "failed to connect to feeder at {}: {}",
162 addr_str, e
163 )))
164 }
165 }
166 }
167 SubscriptionState::Shutdown => Err(RepError::SubscriptionError(
168 "cannot start a shutdown subscription".into(),
169 )),
170 other => Err(RepError::SubscriptionError(format!(
171 "cannot start from state {:?}",
172 other
173 ))),
174 }
175 }
176
177 pub fn get_connection(&self) -> Option<TcpStream> {
182 self.connection.lock().as_ref().and_then(|s| s.try_clone().ok())
183 }
184
185 pub fn process_entry(&self, vlsn: u64, _entry_type: u8, _data: Vec<u8>) {
190 if self.shutdown.load(Ordering::SeqCst) {
191 return;
192 }
193 *self.current_vlsn.lock() = vlsn;
194 self.entries_received.fetch_add(1, Ordering::Relaxed);
195 }
196
197 pub fn mark_caught_up(&self) {
199 let mut state = self.state.lock();
200 if *state == SubscriptionState::Active {
201 *state = SubscriptionState::CaughtUp;
202 }
203 }
204
205 pub fn mark_error(&self) {
207 let mut state = self.state.lock();
208 if *state != SubscriptionState::Shutdown {
209 *state = SubscriptionState::Error;
210 }
211 }
212
213 pub fn shutdown(&self) {
219 self.shutdown.store(true, Ordering::SeqCst);
220 *self.state.lock() = SubscriptionState::Shutdown;
221 if let Some(stream) = self.connection.lock().take() {
223 let _ = stream.shutdown(std::net::Shutdown::Both);
224 }
225 }
226
227 pub fn is_shutdown(&self) -> bool {
229 self.shutdown.load(Ordering::SeqCst)
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;

    /// Config aimed at 127.0.0.1:1. Nothing is expected to listen there, so
    /// `start()` should fail.
    fn test_config_no_connect() -> SubscriptionConfig {
        SubscriptionConfig {
            subscriber_name: "sub1".into(),
            group_name: "group1".into(),
            feeder_host: "127.0.0.1".into(),
            feeder_port: 1,
            start_vlsn: 0,
        }
    }

    /// Binds an ephemeral localhost listener and returns a config aimed at it.
    ///
    /// The caller must keep the listener alive for connects to succeed.
    fn test_config_with_listener() -> (SubscriptionConfig, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let feeder_port = listener.local_addr().unwrap().port();
        let config = SubscriptionConfig {
            feeder_port,
            ..test_config_no_connect()
        };
        (config, listener)
    }

    /// Builds a subscription against a fresh listener and starts it.
    ///
    /// The listener is returned so the test can keep it alive.
    fn started() -> (Subscription, TcpListener) {
        let (config, listener) = test_config_with_listener();
        let sub = Subscription::new(config);
        sub.start().unwrap();
        (sub, listener)
    }

    #[test]
    fn test_initial_state() {
        let sub = Subscription::new(test_config_no_connect());
        assert_eq!(sub.get_state(), SubscriptionState::Idle);
        assert_eq!(sub.get_current_vlsn(), 0);
        assert_eq!(sub.get_entries_received(), 0);
        assert!(!sub.is_shutdown());
    }

    #[test]
    fn test_start() {
        let (sub, _listener) = started();
        assert_eq!(sub.get_state(), SubscriptionState::Active);
        assert!(sub.get_connection().is_some());
    }

    #[test]
    fn test_start_fails_when_no_listener() {
        let sub = Subscription::new(test_config_no_connect());
        assert!(sub.start().is_err());
        assert_eq!(sub.get_state(), SubscriptionState::Error);
    }

    #[test]
    fn test_start_from_active_fails() {
        let (sub, _listener) = started();
        assert!(sub.start().is_err());
    }

    #[test]
    fn test_start_after_shutdown_fails() {
        let sub = Subscription::new(test_config_no_connect());
        sub.shutdown();
        assert!(sub.start().is_err());
    }

    #[test]
    fn test_process_entries() {
        let (sub, _listener) = started();
        for (vlsn, entry_type, byte) in [(1, 1, 0x01), (2, 1, 0x02), (3, 2, 0x03)] {
            sub.process_entry(vlsn, entry_type, vec![byte]);
        }
        assert_eq!(sub.get_current_vlsn(), 3);
        assert_eq!(sub.get_entries_received(), 3);
    }

    #[test]
    fn test_process_entry_after_shutdown_ignored() {
        let (sub, _listener) = started();
        sub.process_entry(1, 1, vec![0x01]);

        sub.shutdown();
        sub.process_entry(2, 1, vec![0x02]);

        // The post-shutdown entry must not have been recorded.
        assert_eq!(sub.get_current_vlsn(), 1);
        assert_eq!(sub.get_entries_received(), 1);
    }

    #[test]
    fn test_mark_caught_up() {
        let (sub, _listener) = started();
        assert_eq!(sub.get_state(), SubscriptionState::Active);

        sub.mark_caught_up();
        assert_eq!(sub.get_state(), SubscriptionState::CaughtUp);
    }

    #[test]
    fn test_mark_caught_up_from_idle_no_change() {
        let sub = Subscription::new(test_config_no_connect());
        sub.mark_caught_up();
        assert_eq!(sub.get_state(), SubscriptionState::Idle);
    }

    #[test]
    fn test_mark_error() {
        let (sub, _listener) = started();
        sub.mark_error();
        assert_eq!(sub.get_state(), SubscriptionState::Error);
    }

    #[test]
    fn test_mark_error_after_shutdown_no_change() {
        let sub = Subscription::new(test_config_no_connect());
        sub.shutdown();
        sub.mark_error();
        assert_eq!(sub.get_state(), SubscriptionState::Shutdown);
    }

    #[test]
    fn test_shutdown() {
        let (sub, _listener) = started();
        assert!(!sub.is_shutdown());

        sub.shutdown();
        assert!(sub.is_shutdown());
        assert_eq!(sub.get_state(), SubscriptionState::Shutdown);
        assert!(sub.get_connection().is_none());
    }

    #[test]
    fn test_config_accessor() {
        let sub = Subscription::new(test_config_no_connect());
        let cfg = sub.get_config();
        assert_eq!(cfg.subscriber_name, "sub1");
        assert_eq!(cfg.group_name, "group1");
        assert_eq!(cfg.feeder_host, "127.0.0.1");
        assert_eq!(cfg.feeder_port, 1);
    }

    #[test]
    fn test_start_vlsn_nonzero() {
        let config = SubscriptionConfig {
            start_vlsn: 42,
            ..test_config_no_connect()
        };
        let sub = Subscription::new(config);
        assert_eq!(sub.get_current_vlsn(), 42);
    }

    #[test]
    fn test_full_lifecycle() {
        let (config, _listener) = test_config_with_listener();
        let sub = Subscription::new(config);
        assert_eq!(sub.get_state(), SubscriptionState::Idle);

        sub.start().unwrap();
        assert_eq!(sub.get_state(), SubscriptionState::Active);
        assert!(sub.get_connection().is_some());

        (1..=10u64).for_each(|vlsn| sub.process_entry(vlsn, 1, vec![vlsn as u8]));
        assert_eq!(sub.get_current_vlsn(), 10);
        assert_eq!(sub.get_entries_received(), 10);

        sub.mark_caught_up();
        assert_eq!(sub.get_state(), SubscriptionState::CaughtUp);

        sub.shutdown();
        assert_eq!(sub.get_state(), SubscriptionState::Shutdown);
        assert!(sub.is_shutdown());
        assert!(sub.get_connection().is_none());
    }
}