iridium_stomp/connection.rs
1use futures::{SinkExt, StreamExt, future};
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Duration;
6use thiserror::Error;
7use tokio::net::TcpStream;
8use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
9use tokio_util::codec::Framed;
10
11use crate::codec::{StompCodec, StompItem};
12use crate::frame::Frame;
13
/// Heart-beat intervals a client advertises in the STOMP `CONNECT` frame.
///
/// This is a typed alternative to hand-writing the `heart-beat` header
/// string. Formatting a `Heartbeat` with `Display` produces exactly the header
/// value the protocol expects: `"<send_ms>,<receive_ms>"`.
///
/// # Example
///
/// ```
/// use iridium_stomp::Heartbeat;
///
/// // Create a custom heartbeat configuration
/// let hb = Heartbeat::new(5000, 10000);
/// assert_eq!(hb.to_string(), "5000,10000");
///
/// // Use predefined configurations
/// assert_eq!(Heartbeat::disabled().to_string(), "0,0");
/// assert_eq!(Heartbeat::default().to_string(), "10000,10000");
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Heartbeat {
    /// Smallest interval, in milliseconds, at which the client can emit
    /// heart-beats. `0` advertises that the client will send none.
    pub send_ms: u32,

    /// Interval, in milliseconds, at which the client would like to receive
    /// heart-beats. `0` advertises that the client wants none.
    pub receive_ms: u32,
}
43
44impl Heartbeat {
45 /// Create a new heartbeat configuration with the specified intervals.
46 ///
47 /// # Arguments
48 ///
49 /// * `send_ms` - Minimum interval in milliseconds between heartbeats the client can send.
50 /// * `receive_ms` - Minimum interval in milliseconds between heartbeats the client wants to receive.
51 ///
52 /// # Example
53 ///
54 /// ```
55 /// use iridium_stomp::Heartbeat;
56 ///
57 /// let hb = Heartbeat::new(5000, 10000);
58 /// assert_eq!(hb.send_ms, 5000);
59 /// assert_eq!(hb.receive_ms, 10000);
60 /// ```
61 pub fn new(send_ms: u32, receive_ms: u32) -> Self {
62 Self {
63 send_ms,
64 receive_ms,
65 }
66 }
67
68 /// Create a heartbeat configuration that disables heartbeats entirely.
69 ///
70 /// This is equivalent to `Heartbeat::new(0, 0)`.
71 ///
72 /// # Example
73 ///
74 /// ```
75 /// use iridium_stomp::Heartbeat;
76 ///
77 /// let hb = Heartbeat::disabled();
78 /// assert_eq!(hb.send_ms, 0);
79 /// assert_eq!(hb.receive_ms, 0);
80 /// assert_eq!(hb.to_string(), "0,0");
81 /// ```
82 pub fn disabled() -> Self {
83 Self::new(0, 0)
84 }
85
86 /// Create a heartbeat configuration from a Duration for symmetric heartbeats.
87 ///
88 /// Both send and receive intervals will be set to the same value.
89 ///
90 /// The maximum supported Duration is approximately 49.7 days (u32::MAX milliseconds,
91 /// or 4,294,967,295 ms). If a larger Duration is provided, it will be clamped to
92 /// u32::MAX milliseconds to prevent overflow.
93 ///
94 /// # Example
95 ///
96 /// ```
97 /// use iridium_stomp::Heartbeat;
98 /// use std::time::Duration;
99 ///
100 /// let hb = Heartbeat::from_duration(Duration::from_secs(15));
101 /// assert_eq!(hb.send_ms, 15000);
102 /// assert_eq!(hb.receive_ms, 15000);
103 /// ```
104 pub fn from_duration(interval: Duration) -> Self {
105 let ms = interval.as_millis().min(u32::MAX as u128) as u32;
106 Self::new(ms, ms)
107 }
108}
109
110impl Default for Heartbeat {
111 /// Returns the default heartbeat configuration: 10 seconds for both send and receive.
112 fn default() -> Self {
113 Self::new(10000, 10000)
114 }
115}
116
117impl std::fmt::Display for Heartbeat {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 write!(f, "{},{}", self.send_ms, self.receive_ms)
120 }
121}
122
123/// Internal subscription entry stored for each destination.
124#[derive(Clone)]
125pub(crate) struct SubscriptionEntry {
126 pub(crate) id: String,
127 pub(crate) sender: mpsc::Sender<Frame>,
128 pub(crate) ack: String,
129 pub(crate) headers: Vec<(String, String)>,
130}
131
132/// Alias for the subscription dispatch map: destination -> list of
133/// `SubscriptionEntry`.
134pub(crate) type Subscriptions = HashMap<String, Vec<SubscriptionEntry>>;
135
136/// Alias for the pending map: subscription_id -> queue of (message-id, Frame).
137pub(crate) type PendingMap = HashMap<String, VecDeque<(String, Frame)>>;
138
139/// Internal type for resubscribe snapshot entries: (destination, id, ack, headers)
140pub(crate) type ResubEntry = (String, String, String, Vec<(String, String)>);
141
142/// Alias for pending receipt map: receipt-id -> oneshot sender to notify when received.
143pub(crate) type PendingReceipts = HashMap<String, oneshot::Sender<()>>;
144
145/// Errors returned by `Connection` operations.
146#[derive(Error, Debug)]
147pub enum ConnError {
148 /// I/O-level error
149 #[error("io error: {0}")]
150 Io(#[from] std::io::Error),
151 /// Protocol-level error
152 #[error("protocol error: {0}")]
153 Protocol(String),
154 /// Receipt timeout error
155 #[error("receipt timeout: no RECEIPT received for '{0}' within timeout")]
156 ReceiptTimeout(String),
157 /// Server rejected the connection (e.g., authentication failure)
158 ///
159 /// This error is returned when the server sends an ERROR frame in response
160 /// to the CONNECT frame. Common causes include invalid credentials,
161 /// unauthorized access, or broker configuration issues.
162 #[error("server rejected connection: {0}")]
163 ServerRejected(ServerError),
164}
165
166/// Represents an ERROR frame received from the STOMP server.
167///
168/// STOMP servers send ERROR frames to indicate protocol violations, authentication
169/// failures, or other server-side errors. After sending an ERROR frame, the server
170/// typically closes the connection.
171///
172/// # Example
173///
174/// ```ignore
175/// use iridium_stomp::ReceivedFrame;
176///
177/// while let Some(received) = conn.next_frame().await {
178/// match received {
179/// ReceivedFrame::Frame(frame) => {
180/// // Normal message processing
181/// }
182/// ReceivedFrame::Error(err) => {
183/// eprintln!("Server error: {}", err.message);
184/// if let Some(body) = &err.body {
185/// eprintln!("Details: {}", body);
186/// }
187/// break;
188/// }
189/// }
190/// }
191/// ```
192#[derive(Debug, Clone, PartialEq, Eq)]
193pub struct ServerError {
194 /// The error message from the `message` header.
195 pub message: String,
196
197 /// The error body, if present. Contains additional error details.
198 pub body: Option<String>,
199
200 /// The receipt-id if this error is in response to a specific frame.
201 pub receipt_id: Option<String>,
202
203 /// The original ERROR frame for access to additional headers.
204 pub frame: Frame,
205}
206
207impl ServerError {
208 /// Create a `ServerError` from an ERROR frame.
209 ///
210 /// This is primarily used internally but is public for testing and
211 /// advanced use cases where you need to construct a `ServerError` manually.
212 pub fn from_frame(frame: Frame) -> Self {
213 let message = frame
214 .get_header("message")
215 .unwrap_or("unknown error")
216 .to_string();
217
218 let body = if frame.body.is_empty() {
219 None
220 } else {
221 String::from_utf8(frame.body.clone()).ok()
222 };
223
224 let receipt_id = frame.get_header("receipt-id").map(|s| s.to_string());
225
226 Self {
227 message,
228 body,
229 receipt_id,
230 frame,
231 }
232 }
233}
234
235impl std::fmt::Display for ServerError {
236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 write!(f, "STOMP server error: {}", self.message)?;
238 if let Some(body) = &self.body {
239 write!(f, " - {}", body)?;
240 }
241 Ok(())
242 }
243}
244
245impl std::error::Error for ServerError {}
246
247/// The result of receiving a frame from the server.
248///
249/// STOMP servers can send either normal frames (MESSAGE, RECEIPT, etc.) or
250/// ERROR frames indicating a problem. This enum allows callers to handle
251/// both cases with pattern matching.
252///
253/// # Example
254///
255/// ```ignore
256/// use iridium_stomp::ReceivedFrame;
257///
258/// match conn.next_frame().await {
259/// Some(ReceivedFrame::Frame(frame)) => {
260/// println!("Got frame: {}", frame.command);
261/// }
262/// Some(ReceivedFrame::Error(err)) => {
263/// eprintln!("Server error: {}", err);
264/// }
265/// None => {
266/// println!("Connection closed");
267/// }
268/// }
269/// ```
270#[derive(Debug, Clone, PartialEq, Eq)]
271pub enum ReceivedFrame {
272 /// A normal STOMP frame (MESSAGE, RECEIPT, etc.)
273 Frame(Frame),
274 /// An ERROR frame from the server
275 Error(ServerError),
276}
277
278impl ReceivedFrame {
279 /// Returns `true` if this is an error frame.
280 pub fn is_error(&self) -> bool {
281 matches!(self, ReceivedFrame::Error(_))
282 }
283
284 /// Returns `true` if this is a normal frame.
285 pub fn is_frame(&self) -> bool {
286 matches!(self, ReceivedFrame::Frame(_))
287 }
288
289 /// Returns the frame if this is a normal frame, or `None` if it's an error.
290 pub fn into_frame(self) -> Option<Frame> {
291 match self {
292 ReceivedFrame::Frame(f) => Some(f),
293 ReceivedFrame::Error(_) => None,
294 }
295 }
296
297 /// Returns the error if this is an error frame, or `None` if it's a normal frame.
298 pub fn into_error(self) -> Option<ServerError> {
299 match self {
300 ReceivedFrame::Frame(_) => None,
301 ReceivedFrame::Error(e) => Some(e),
302 }
303 }
304}
305
/// Acknowledgement mode requested in a SUBSCRIBE frame (STOMP 1.2).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AckMode {
    /// `ack:auto`
    Auto,
    /// `ack:client`
    Client,
    /// `ack:client-individual`
    ClientIndividual,
}
313
314impl AckMode {
315 fn as_str(&self) -> &'static str {
316 match self {
317 AckMode::Auto => "auto",
318 AckMode::Client => "client",
319 AckMode::ClientIndividual => "client-individual",
320 }
321 }
322}
323
/// Extra settings for the STOMP CONNECT frame.
///
/// Pass this to `Connection::connect_with_options()` to choose the accepted
/// STOMP versions, a virtual host, a `client-id` (needed by some brokers for
/// durable subscriptions), or arbitrary additional headers.
///
/// # Validation
///
/// Values are forwarded to the broker unchanged; only the broker decides
/// whether they are acceptable, so bad settings surface as a connection
/// error. Empty strings are not rejected here but may upset some brokers.
///
/// # Custom Headers
///
/// A header added through `header()` is dropped if its name matches one of
/// the protocol-critical headers (`accept-version`, `host`, `login`,
/// `passcode`, `heart-beat`, `client-id`), compared case-insensitively.
/// Set those through their dedicated builder methods or connect arguments.
///
/// # Example
///
/// ```ignore
/// use iridium_stomp::{Connection, ConnectOptions};
///
/// let options = ConnectOptions::default()
///     .client_id("my-durable-client")
///     .host("my-vhost")
///     .header("custom-header", "value");
///
/// let conn = Connection::connect_with_options(
///     "localhost:61613",
///     "guest",
///     "guest",
///     "10000,10000",
///     options,
/// ).await?;
/// ```
#[derive(Debug, Clone, Default)]
pub struct ConnectOptions {
    /// Value of `accept-version` (e.g. "1.2" or "1.0,1.1,1.2").
    /// `None` means "1.2".
    pub accept_version: Option<String>,

    /// Value of `client-id`; omitted from CONNECT when `None`.
    pub client_id: Option<String>,

    /// Value of the `host` (virtual host) header. `None` means "/".
    pub host: Option<String>,

    /// Additional CONNECT headers, in insertion order. Entries naming a
    /// protocol-critical header are ignored.
    pub headers: Vec<(String, String)>,
}
378
379impl ConnectOptions {
380 /// Create a new `ConnectOptions` with default values.
381 pub fn new() -> Self {
382 Self::default()
383 }
384
385 /// Set the STOMP version(s) to accept (builder style).
386 ///
387 /// Examples: "1.2", "1.1,1.2", "1.0,1.1,1.2"
388 pub fn accept_version(mut self, version: impl Into<String>) -> Self {
389 self.accept_version = Some(version.into());
390 self
391 }
392
393 /// Set the client ID for durable subscriptions (builder style).
394 ///
395 /// Required by some brokers (e.g., ActiveMQ) for durable topic subscriptions.
396 pub fn client_id(mut self, id: impl Into<String>) -> Self {
397 self.client_id = Some(id.into());
398 self
399 }
400
401 /// Set the virtual host (builder style).
402 ///
403 /// Defaults to "/" if not set.
404 pub fn host(mut self, host: impl Into<String>) -> Self {
405 self.host = Some(host.into());
406 self
407 }
408
409 /// Add a custom header to the CONNECT frame (builder style).
410 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
411 self.headers.push((key.into(), value.into()));
412 self
413 }
414}
415
/// Parse a STOMP `heart-beat` header value of the form `"cx,cy"`.
///
/// Parameters
/// - `header`: the raw header text from either peer (e.g. "10000,10000"),
///   with both numbers in milliseconds.
///
/// Returns `(cx, cy)`. Each field is trimmed before parsing; a field that is
/// missing or not a valid `u64` becomes `0`. Fields after the second are ignored.
pub fn parse_heartbeat_header(header: &str) -> (u64, u64) {
    let mut values = header
        .split(',')
        .map(|field| field.trim().parse::<u64>().unwrap_or(0));
    let first = values.next().unwrap_or(0);
    let second = values.next().unwrap_or(0);
    (first, second)
}
436
/// Negotiate heartbeat intervals between client and server.
///
/// Applies the STOMP 1.1/1.2 rule independently to each direction. If the
/// sending side advertises `0` (it cannot send heart-beats) or the receiving
/// side advertises `0` (it does not want them), heart-beats in that direction
/// are disabled. Otherwise the interval is the larger of the two values.
///
/// A plain `max` would be wrong here. With a client header of "0,0" and a
/// server header of "10000,10000" it would still enable both directions. The
/// client would then expect heart-beats the server was told not to send, and
/// would drop a healthy connection.
///
/// Parameters
/// - `client_out`: client's outgoing interval in milliseconds (`cx` of the
///   client's `heart-beat` header).
/// - `client_in`: client's desired incoming interval in milliseconds (`cy`).
/// - `server_out`: server's advertised outgoing interval in milliseconds (`sx`).
/// - `server_in`: server's desired incoming interval in milliseconds (`sy`).
///
/// Returns `(outgoing, incoming)` from the client's point of view. Each value
/// is `Some(interval)` when heart-beats flow in that direction, `None` otherwise.
pub fn negotiate_heartbeats(
    client_out: u64,
    client_in: u64,
    server_out: u64,
    server_in: u64,
) -> (Option<Duration>, Option<Duration>) {
    // One direction: `sender_ms` is what the sender offers, `receiver_ms` what
    // the receiver asks for. A zero on either side disables the direction.
    fn direction(sender_ms: u64, receiver_ms: u64) -> Option<Duration> {
        if sender_ms == 0 || receiver_ms == 0 {
            None
        } else {
            Some(Duration::from_millis(sender_ms.max(receiver_ms)))
        }
    }

    let outgoing = direction(client_out, server_in);
    let incoming = direction(server_out, client_in);
    (outgoing, incoming)
}
472
473/// High-level connection object that manages a single TCP/STOMP connection.
474///
475/// The `Connection` spawns a background task that maintains the TCP transport,
476/// sends/receives STOMP frames using `StompCodec`, negotiates heartbeats, and
477/// performs simple reconnect logic with exponential backoff.
478#[derive(Clone)]
479pub struct Connection {
480 outbound_tx: mpsc::Sender<StompItem>,
481 /// The inbound receiver is shared behind a mutex so the `Connection`
482 /// handle may be cloned and callers can call `next_frame` concurrently.
483 inbound_rx: Arc<Mutex<mpsc::Receiver<Frame>>>,
484 shutdown_tx: broadcast::Sender<()>,
485 /// Map of destination -> list of (subscription id, sender) for dispatching
486 /// inbound MESSAGE frames to subscribers.
487 subscriptions: Arc<Mutex<Subscriptions>>,
488 /// Monotonic counter used to allocate subscription ids.
489 sub_id_counter: Arc<AtomicU64>,
490 /// Pending messages awaiting ACK/NACK from the application.
491 ///
492 /// Organized by subscription id. For `client` ack mode the ACK is
493 /// cumulative: acknowledging message `M` for subscription `S` acknowledges
494 /// all messages previously delivered for `S` up to and including `M`.
495 /// For `client-individual` the ACK/NACK applies only to the single
496 /// message.
497 pending: Arc<Mutex<PendingMap>>,
498 /// Pending receipt confirmations.
499 ///
500 /// When a frame is sent with a `receipt` header, the receipt-id is stored
501 /// here with a oneshot sender. When the server responds with a RECEIPT
502 /// frame, the sender is notified.
503 pending_receipts: Arc<Mutex<PendingReceipts>>,
504}
505
506impl Connection {
507 /// Heartbeat value that disables heartbeats entirely.
508 ///
509 /// Use this when you don't want the client or server to send heartbeats.
510 /// Note that some brokers may still require heartbeats for long-lived connections.
511 ///
512 /// # Example
513 ///
514 /// ```ignore
515 /// let conn = Connection::connect(
516 /// "localhost:61613",
517 /// "guest",
518 /// "guest",
519 /// Connection::NO_HEARTBEAT,
520 /// ).await?;
521 /// ```
522 pub const NO_HEARTBEAT: &'static str = "0,0";
523
524 /// Default heartbeat value: 10 seconds for both send and receive.
525 ///
526 /// This is a reasonable default for most applications. The actual heartbeat
527 /// interval will be negotiated with the server (taking the maximum of client
528 /// and server preferences).
529 ///
530 /// # Example
531 ///
532 /// ```ignore
533 /// let conn = Connection::connect(
534 /// "localhost:61613",
535 /// "guest",
536 /// "guest",
537 /// Connection::DEFAULT_HEARTBEAT,
538 /// ).await?;
539 /// ```
540 pub const DEFAULT_HEARTBEAT: &'static str = "10000,10000";
541
542 /// Establish a connection to the STOMP server at `addr` with the given
543 /// credentials and heartbeat header string (e.g. "10000,10000").
544 ///
545 /// This is a convenience wrapper around `connect_with_options()` that uses
546 /// default options (STOMP 1.2, host="/", no client-id).
547 ///
548 /// Parameters
549 /// - `addr`: TCP address (host:port) of the STOMP server.
550 /// - `login`: login username for STOMP `CONNECT`.
551 /// - `passcode`: passcode for STOMP `CONNECT`.
552 /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
553 /// milliseconds) that will be sent in the `CONNECT` frame.
554 ///
555 /// Returns a `Connection` which provides `send_frame`, `next_frame`, and
556 /// `close` helpers. The detailed connection handling (I/O, heartbeats,
557 /// reconnects) runs on a background task spawned by this method.
558 pub async fn connect(
559 addr: &str,
560 login: &str,
561 passcode: &str,
562 client_hb: &str,
563 ) -> Result<Self, ConnError> {
564 Self::connect_with_options(addr, login, passcode, client_hb, ConnectOptions::default())
565 .await
566 }
567
568 /// Establish a connection to the STOMP server with custom options.
569 ///
570 /// Use this method when you need to set a custom `client-id` (for durable
571 /// subscriptions), specify a virtual host, negotiate different STOMP
572 /// versions, or add custom CONNECT headers.
573 ///
574 /// Parameters
575 /// - `addr`: TCP address (host:port) of the STOMP server.
576 /// - `login`: login username for STOMP `CONNECT`.
577 /// - `passcode`: passcode for STOMP `CONNECT`.
578 /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
579 /// milliseconds) that will be sent in the `CONNECT` frame.
580 /// - `options`: custom connection options (version, host, client-id, etc.).
581 ///
582 /// # Errors
583 ///
584 /// Returns an error if:
585 /// - The TCP connection cannot be established (`ConnError::Io`)
586 /// - The server rejects the connection, e.g., due to invalid credentials
587 /// (`ConnError::ServerRejected`)
588 /// - The server closes the connection without responding (`ConnError::Protocol`)
589 ///
590 /// # Example
591 ///
592 /// ```ignore
593 /// use iridium_stomp::{Connection, ConnectOptions};
594 ///
595 /// // Connect with a client-id for durable subscriptions
596 /// let options = ConnectOptions::default()
597 /// .client_id("my-app-instance-1");
598 ///
599 /// let conn = Connection::connect_with_options(
600 /// "localhost:61613",
601 /// "guest",
602 /// "guest",
603 /// "10000,10000",
604 /// options,
605 /// ).await?;
606 /// ```
607 pub async fn connect_with_options(
608 addr: &str,
609 login: &str,
610 passcode: &str,
611 client_hb: &str,
612 options: ConnectOptions,
613 ) -> Result<Self, ConnError> {
614 let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(32);
615 let (in_tx, in_rx) = mpsc::channel::<Frame>(32);
616 let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
617 let sub_id_counter = Arc::new(AtomicU64::new(1));
618 let (shutdown_tx, _) = broadcast::channel::<()>(1);
619 let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
620 let pending_clone = pending.clone();
621 let pending_receipts: Arc<Mutex<PendingReceipts>> = Arc::new(Mutex::new(HashMap::new()));
622 let pending_receipts_clone = pending_receipts.clone();
623
624 let addr = addr.to_string();
625 let login = login.to_string();
626 let passcode = passcode.to_string();
627 let client_hb = client_hb.to_string();
628
629 // Extract options into owned values for the spawned task
630 let accept_version = options.accept_version.unwrap_or_else(|| "1.2".to_string());
631 let host = options.host.unwrap_or_else(|| "/".to_string());
632 let client_id = options.client_id;
633 let custom_headers = options.headers;
634
635 // Perform initial connection and STOMP handshake before spawning background task.
636 // This ensures authentication errors are returned to the caller immediately.
637 let stream = TcpStream::connect(&addr).await?;
638 let mut framed = Framed::new(stream, StompCodec::new());
639
640 // Build and send CONNECT frame
641 let connect = Self::build_connect_frame(
642 &accept_version,
643 &host,
644 &login,
645 &passcode,
646 &client_hb,
647 &client_id,
648 &custom_headers,
649 );
650
651 framed
652 .send(StompItem::Frame(connect))
653 .await
654 .map_err(|e| ConnError::Io(std::io::Error::other(e)))?;
655
656 // Wait for CONNECTED or ERROR response
657 let server_heartbeat = Self::await_connected_response(&mut framed).await?;
658
659 // Calculate heartbeat intervals
660 let (cx, cy) = parse_heartbeat_header(&client_hb);
661 let (sx, sy) = parse_heartbeat_header(&server_heartbeat);
662 let (send_interval, recv_interval) = negotiate_heartbeats(cx, cy, sx, sy);
663
664 // Now spawn background task for ongoing I/O and reconnection
665 let shutdown_tx_clone = shutdown_tx.clone();
666 let subscriptions_clone = subscriptions.clone();
667
668 tokio::spawn(async move {
669 let mut backoff_secs: u64 = 1;
670
671 // Use the already-established connection for the first iteration
672 let mut current_framed = Some(framed);
673 let mut current_send_interval = send_interval;
674 let mut current_recv_interval = recv_interval;
675
676 loop {
677 let mut shutdown_sub = shutdown_tx_clone.subscribe();
678
679 // Check for shutdown before attempting connection
680 tokio::select! {
681 biased;
682 _ = shutdown_sub.recv() => break,
683 _ = future::ready(()) => {},
684 }
685
686 // Either use existing connection or establish new one (reconnect)
687 let framed = if let Some(f) = current_framed.take() {
688 f
689 } else {
690 // Reconnection attempt
691 match TcpStream::connect(&addr).await {
692 Ok(stream) => {
693 let mut framed = Framed::new(stream, StompCodec::new());
694
695 let connect = Self::build_connect_frame(
696 &accept_version,
697 &host,
698 &login,
699 &passcode,
700 &client_hb,
701 &client_id,
702 &custom_headers,
703 );
704
705 if framed.send(StompItem::Frame(connect)).await.is_err() {
706 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
707 backoff_secs = (backoff_secs * 2).min(30);
708 continue;
709 }
710
711 // Wait for CONNECTED (on reconnect, silently retry on ERROR)
712 match Self::await_connected_response(&mut framed).await {
713 Ok(server_hb) => {
714 let (cx, cy) = parse_heartbeat_header(&client_hb);
715 let (sx, sy) = parse_heartbeat_header(&server_hb);
716 let (si, ri) = negotiate_heartbeats(cx, cy, sx, sy);
717 current_send_interval = si;
718 current_recv_interval = ri;
719 framed
720 }
721 Err(_) => {
722 // Reconnect failed (auth error or other), retry with backoff
723 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
724 backoff_secs = (backoff_secs * 2).min(30);
725 continue;
726 }
727 }
728 }
729 Err(_) => {
730 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
731 backoff_secs = (backoff_secs * 2).min(30);
732 continue;
733 }
734 }
735 };
736
737 let (send_interval, recv_interval) = (current_send_interval, current_recv_interval);
738
739 let last_received = Arc::new(AtomicU64::new(current_millis()));
740 let writer_last_sent = Arc::new(AtomicU64::new(current_millis()));
741
742 let (mut sink, mut stream) = framed.split();
743 let in_tx = in_tx.clone();
744 let subscriptions = subscriptions_clone.clone();
745
746 // Clear pending message map on reconnect — messages that were
747 // outstanding before the disconnect are considered lost and
748 // will be redelivered by the server as appropriate.
749 {
750 let mut p = pending_clone.lock().await;
751 p.clear();
752 }
753
754 // Resubscribe any existing subscriptions after reconnect.
755 // We snapshot the subscription entries while holding the lock
756 // and then issue SUBSCRIBE frames using the sink.
757 let subs_snapshot: Vec<ResubEntry> = {
758 let map = subscriptions.lock().await;
759 let mut v: Vec<ResubEntry> = Vec::new();
760 for (dest, vec) in map.iter() {
761 for entry in vec.iter() {
762 v.push((
763 dest.clone(),
764 entry.id.clone(),
765 entry.ack.clone(),
766 entry.headers.clone(),
767 ));
768 }
769 }
770 v
771 };
772
773 for (dest, id, ack, headers) in subs_snapshot {
774 let mut sf = Frame::new("SUBSCRIBE");
775 sf = sf
776 .header("id", &id)
777 .header("destination", &dest)
778 .header("ack", &ack);
779 for (k, v) in headers {
780 sf = sf.header(&k, &v);
781 }
782 let _ = sink.send(StompItem::Frame(sf)).await;
783 }
784
785 let mut hb_tick = match send_interval {
786 Some(d) => tokio::time::interval(d),
787 None => tokio::time::interval(Duration::from_secs(86400)),
788 };
789 let watchdog_half = recv_interval.map(|d| d / 2);
790
791 backoff_secs = 1;
792
793 'conn: loop {
794 tokio::select! {
795 _ = shutdown_sub.recv() => { let _ = sink.close().await; break 'conn; }
796 maybe = out_rx.recv() => {
797 match maybe {
798 Some(item) => if sink.send(item).await.is_err() { break 'conn } else { writer_last_sent.store(current_millis(), Ordering::SeqCst); }
799 None => break 'conn,
800 }
801 }
802 item = stream.next() => {
803 match item {
804 Some(Ok(StompItem::Heartbeat)) => { last_received.store(current_millis(), Ordering::SeqCst); }
805 Some(Ok(StompItem::Frame(f))) => {
806 last_received.store(current_millis(), Ordering::SeqCst);
807 // Dispatch MESSAGE frames to any matching subscribers.
808 if f.command == "MESSAGE" {
809 // try to find destination, subscription and message-id headers
810 let mut dest_opt: Option<String> = None;
811 let mut sub_opt: Option<String> = None;
812 let mut msg_id_opt: Option<String> = None;
813 for (k, v) in &f.headers {
814 let kl = k.to_lowercase();
815 if kl == "destination" {
816 dest_opt = Some(v.clone());
817 } else if kl == "subscription" {
818 sub_opt = Some(v.clone());
819 } else if kl == "message-id" {
820 msg_id_opt = Some(v.clone());
821 }
822 }
823
824 // Determine whether we need to track this message as pending
825 let mut need_pending = false;
826 if let Some(sub_id) = &sub_opt {
827 let map = subscriptions.lock().await;
828 for (_dest, vec) in map.iter() {
829 for entry in vec.iter() {
830 if &entry.id == sub_id && entry.ack != "auto" {
831 need_pending = true;
832 }
833 }
834 }
835 } else if let Some(dest) = &dest_opt {
836 let map = subscriptions.lock().await;
837 if let Some(vec) = map.get(dest) {
838 for entry in vec.iter() {
839 if entry.ack != "auto" {
840 need_pending = true;
841 break;
842 }
843 }
844 }
845 }
846
847 // If required, add to pending map (per-subscription) before
848 // delivery so ACK/NACK requests from the application can
849 // reference the message. We require a `message-id` header
850 // to track messages; if missing, we cannot support ACK/NACK.
851 if let Some(msg_id) = msg_id_opt.clone().filter(|_| need_pending) {
852 // If the server provided a subscription id in the
853 // MESSAGE, store pending under that subscription.
854 if let Some(sub_id) = &sub_opt {
855 let mut p = pending_clone.lock().await;
856 let q = p
857 .entry(sub_id.clone())
858 .or_insert_with(VecDeque::new);
859 q.push_back((msg_id.clone(), f.clone()));
860 } else if let Some(dest) = &dest_opt {
861 // Destination-based delivery: add the message to
862 // the pending queue for each matching
863 // subscription on that destination.
864 let map = subscriptions.lock().await;
865 if let Some(vec) = map.get(dest) {
866 let mut p = pending_clone.lock().await;
867 for entry in vec.iter() {
868 let q = p
869 .entry(entry.id.clone())
870 .or_insert_with(VecDeque::new);
871 q.push_back((msg_id.clone(), f.clone()));
872 }
873 }
874 }
875 }
876
877 // Deliver to subscribers.
878 if let Some(sub_id) = sub_opt {
879 let mut map = subscriptions.lock().await;
880 for (_dest, vec) in map.iter_mut() {
881 vec.retain(|entry| {
882 if entry.id == sub_id {
883 let _ = entry.sender.try_send(f.clone());
884 true
885 } else {
886 true
887 }
888 });
889 }
890 } else if let Some(dest) = dest_opt {
891 let mut map = subscriptions.lock().await;
892 if let Some(vec) = map.get_mut(&dest) {
893 vec.retain(|entry| entry.sender.try_send(f.clone()).is_ok());
894 }
895 }
896 } else if f.command == "RECEIPT" {
897 // Handle RECEIPT frame: notify any waiting callers
898 if let Some(receipt_id) = f.get_header("receipt-id") {
899 let mut receipts = pending_receipts_clone.lock().await;
900 if let Some(sender) = receipts.remove(receipt_id) {
901 let _ = sender.send(());
902 }
903 }
904 // Don't forward RECEIPT frames to inbound channel
905 continue;
906 }
907
908 let _ = in_tx.send(f).await;
909 }
910 Some(Err(_)) | None => break 'conn,
911 }
912 }
913 _ = hb_tick.tick() => {
914 if let Some(dur) = send_interval {
915 let last = writer_last_sent.load(Ordering::SeqCst);
916 if current_millis().saturating_sub(last) >= dur.as_millis() as u64 {
917 if sink.send(StompItem::Heartbeat).await.is_err() { break 'conn; }
918 writer_last_sent.store(current_millis(), Ordering::SeqCst);
919 }
920 }
921 }
922 _ = async { if let Some(interval) = watchdog_half { tokio::time::sleep(interval).await } else { future::pending::<()>().await } } => {
923 if let Some(recv_dur) = recv_interval {
924 let last = last_received.load(Ordering::SeqCst);
925 if current_millis().saturating_sub(last) > (recv_dur.as_millis() as u64 * 2) {
926 let _ = sink.close().await; break 'conn;
927 }
928 }
929 }
930 }
931 }
932
933 if shutdown_sub.try_recv().is_ok() {
934 break;
935 }
936 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
937 backoff_secs = (backoff_secs * 2).min(30);
938 }
939 });
940
941 Ok(Connection {
942 outbound_tx: out_tx,
943 inbound_rx: Arc::new(Mutex::new(in_rx)),
944 shutdown_tx,
945 subscriptions,
946 sub_id_counter,
947 pending,
948 pending_receipts,
949 })
950 }
951
952 /// Build a CONNECT frame with all specified headers.
953 fn build_connect_frame(
954 accept_version: &str,
955 host: &str,
956 login: &str,
957 passcode: &str,
958 heartbeat: &str,
959 client_id: &Option<String>,
960 custom_headers: &[(String, String)],
961 ) -> Frame {
962 let mut connect = Frame::new("CONNECT")
963 .header("accept-version", accept_version)
964 .header("host", host)
965 .header("login", login)
966 .header("passcode", passcode)
967 .header("heart-beat", heartbeat);
968
969 if let Some(id) = client_id {
970 connect = connect.header("client-id", id);
971 }
972
973 // Reserved headers that custom_headers cannot override
974 let reserved = [
975 "accept-version",
976 "host",
977 "login",
978 "passcode",
979 "heart-beat",
980 "client-id",
981 ];
982
983 for (k, v) in custom_headers {
984 if !reserved.contains(&k.to_lowercase().as_str()) {
985 connect = connect.header(k, v);
986 }
987 }
988
989 connect
990 }
991
992 /// Wait for CONNECTED or ERROR response from the server.
993 ///
994 /// Returns the server's heartbeat header value on success, or an error
995 /// if the server sends an ERROR frame or closes the connection.
996 async fn await_connected_response(
997 framed: &mut Framed<TcpStream, StompCodec>,
998 ) -> Result<String, ConnError> {
999 loop {
1000 match framed.next().await {
1001 Some(Ok(StompItem::Frame(f))) => {
1002 if f.command == "CONNECTED" {
1003 // Extract heartbeat from server
1004 let server_hb = f.get_header("heart-beat").unwrap_or("0,0").to_string();
1005 return Ok(server_hb);
1006 } else if f.command == "ERROR" {
1007 // Server rejected connection (e.g., invalid credentials)
1008 return Err(ConnError::ServerRejected(ServerError::from_frame(f)));
1009 }
1010 // Ignore other frames during CONNECT phase
1011 }
1012 Some(Ok(StompItem::Heartbeat)) => {
1013 // Ignore heartbeats during handshake
1014 continue;
1015 }
1016 Some(Err(e)) => {
1017 return Err(ConnError::Io(e));
1018 }
1019 None => {
1020 return Err(ConnError::Protocol(
1021 "connection closed before CONNECTED received".to_string(),
1022 ));
1023 }
1024 }
1025 }
1026 }
1027
1028 pub async fn send_frame(&self, frame: Frame) -> Result<(), ConnError> {
1029 // Send a frame to the background writer task.
1030 //
1031 // Parameters
1032 // - `frame`: ownership of the `Frame` to send. The frame is converted
1033 // into a `StompItem::Frame` and sent over the internal mpsc channel.
1034 self.outbound_tx
1035 .send(StompItem::Frame(frame))
1036 .await
1037 .map_err(|_| ConnError::Protocol("send channel closed".into()))
1038 }
1039
1040 /// Send a frame with a receipt request and return the receipt ID.
1041 ///
1042 /// This method adds a unique `receipt` header to the frame and registers
1043 /// the receipt ID for tracking. Use `wait_for_receipt()` to wait for the
1044 /// server's RECEIPT response.
1045 ///
1046 /// # Parameters
1047 /// - `frame`: the frame to send. A `receipt` header will be added.
1048 ///
1049 /// # Returns
1050 /// The generated receipt ID that can be used with `wait_for_receipt()`.
1051 ///
1052 /// # Example
1053 /// ```ignore
1054 /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
1055 /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
1056 /// ```
1057 pub async fn send_frame_with_receipt(&self, frame: Frame) -> Result<String, ConnError> {
1058 use std::sync::atomic::AtomicU64;
1059 use std::sync::atomic::Ordering::SeqCst;
1060
1061 // Generate a unique receipt ID using a static counter
1062 static RECEIPT_COUNTER: AtomicU64 = AtomicU64::new(1);
1063 let receipt_id = format!("rcpt-{}", RECEIPT_COUNTER.fetch_add(1, SeqCst));
1064
1065 // Create the oneshot channel for notification
1066 let (tx, _rx) = oneshot::channel();
1067
1068 // Register the pending receipt
1069 {
1070 let mut receipts = self.pending_receipts.lock().await;
1071 receipts.insert(receipt_id.clone(), tx);
1072 }
1073
1074 // Add receipt header and send the frame
1075 let frame_with_receipt = frame.receipt(&receipt_id);
1076 self.send_frame(frame_with_receipt).await?;
1077
1078 Ok(receipt_id)
1079 }
1080
1081 /// Wait for a receipt confirmation from the server.
1082 ///
1083 /// This method blocks until the server sends a RECEIPT frame with the
1084 /// matching receipt-id, or until the timeout expires.
1085 ///
1086 /// # Parameters
1087 /// - `receipt_id`: the receipt ID returned by `send_frame_with_receipt()`.
1088 /// - `timeout`: maximum time to wait for the receipt.
1089 ///
1090 /// # Returns
1091 /// `Ok(())` if the receipt was received, or `Err(ConnError::ReceiptTimeout)`
1092 /// if the timeout expired.
1093 ///
1094 /// # Example
1095 /// ```ignore
1096 /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
1097 /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
1098 /// println!("Message confirmed!");
1099 /// ```
1100 pub async fn wait_for_receipt(
1101 &self,
1102 receipt_id: &str,
1103 timeout: Duration,
1104 ) -> Result<(), ConnError> {
1105 // Get the receiver for this receipt
1106 let rx = {
1107 let mut receipts = self.pending_receipts.lock().await;
1108 // Re-create the oneshot channel and swap out the sender
1109 let (tx, rx) = oneshot::channel();
1110 if let Some(old_tx) = receipts.insert(receipt_id.to_string(), tx) {
1111 // Drop the old sender - this is expected if called after send_frame_with_receipt
1112 drop(old_tx);
1113 }
1114 rx
1115 };
1116
1117 // Wait for the receipt with timeout
1118 match tokio::time::timeout(timeout, rx).await {
1119 Ok(Ok(())) => Ok(()),
1120 Ok(Err(_)) => {
1121 // Channel was closed without receiving - connection likely dropped
1122 Err(ConnError::Protocol(
1123 "receipt channel closed unexpectedly".into(),
1124 ))
1125 }
1126 Err(_) => {
1127 // Timeout expired - clean up the pending receipt
1128 let mut receipts = self.pending_receipts.lock().await;
1129 receipts.remove(receipt_id);
1130 Err(ConnError::ReceiptTimeout(receipt_id.to_string()))
1131 }
1132 }
1133 }
1134
1135 /// Send a frame and wait for server confirmation via RECEIPT.
1136 ///
1137 /// This is a convenience method that combines `send_frame_with_receipt()`
1138 /// and `wait_for_receipt()`. Use this when you want to ensure a frame
1139 /// was processed by the server before continuing.
1140 ///
1141 /// # Parameters
1142 /// - `frame`: the frame to send.
1143 /// - `timeout`: maximum time to wait for the receipt.
1144 ///
1145 /// # Returns
1146 /// `Ok(())` if the frame was sent and receipt confirmed, or an error if
1147 /// sending failed or the receipt timed out.
1148 ///
1149 /// # Example
1150 /// ```ignore
1151 /// let frame = Frame::new("SEND")
1152 /// .header("destination", "/queue/orders")
1153 /// .set_body(b"order data".to_vec());
1154 ///
1155 /// conn.send_frame_confirmed(frame, Duration::from_secs(5)).await?;
1156 /// println!("Order sent and confirmed!");
1157 /// ```
1158 pub async fn send_frame_confirmed(
1159 &self,
1160 frame: Frame,
1161 timeout: Duration,
1162 ) -> Result<(), ConnError> {
1163 // Generate receipt ID and register before sending
1164 use std::sync::atomic::AtomicU64;
1165 use std::sync::atomic::Ordering::SeqCst;
1166
1167 static RECEIPT_COUNTER: AtomicU64 = AtomicU64::new(1);
1168 let receipt_id = format!("rcpt-{}", RECEIPT_COUNTER.fetch_add(1, SeqCst));
1169
1170 // Create the oneshot channel for notification
1171 let (tx, rx) = oneshot::channel();
1172
1173 // Register the pending receipt before sending
1174 {
1175 let mut receipts = self.pending_receipts.lock().await;
1176 receipts.insert(receipt_id.clone(), tx);
1177 }
1178
1179 // Add receipt header and send the frame
1180 let frame_with_receipt = frame.receipt(&receipt_id);
1181 self.send_frame(frame_with_receipt).await?;
1182
1183 // Wait for the receipt with timeout
1184 match tokio::time::timeout(timeout, rx).await {
1185 Ok(Ok(())) => Ok(()),
1186 Ok(Err(_)) => Err(ConnError::Protocol(
1187 "receipt channel closed unexpectedly".into(),
1188 )),
1189 Err(_) => {
1190 // Timeout expired - clean up
1191 let mut receipts = self.pending_receipts.lock().await;
1192 receipts.remove(&receipt_id);
1193 Err(ConnError::ReceiptTimeout(receipt_id))
1194 }
1195 }
1196 }
1197
1198 /// Subscribe to a destination.
1199 ///
1200 /// Parameters
1201 /// - `destination`: the STOMP destination to subscribe to (e.g. "/queue/foo").
1202 /// - `ack`: acknowledgement mode to request from the server.
1203 ///
1204 /// Returns a tuple `(subscription_id, receiver)` where `subscription_id` is
1205 /// the opaque id assigned locally for this subscription and `receiver` is a
1206 /// `mpsc::Receiver<Frame>` which will yield incoming MESSAGE frames for the
1207 /// destination. The caller should read from the receiver to handle messages.
1208 /// Subscribe to a destination using optional extra headers.
1209 ///
1210 /// This variant accepts additional headers which are stored locally and
1211 /// re-sent on reconnect. Use `subscribe` as a convenience wrapper when no
1212 /// extra headers are needed.
1213 pub async fn subscribe_with_headers(
1214 &self,
1215 destination: &str,
1216 ack: AckMode,
1217 extra_headers: Vec<(String, String)>,
1218 ) -> Result<crate::subscription::Subscription, ConnError> {
1219 let id = self
1220 .sub_id_counter
1221 .fetch_add(1, Ordering::SeqCst)
1222 .to_string();
1223 let (tx, rx) = mpsc::channel::<Frame>(16);
1224 {
1225 let mut map = self.subscriptions.lock().await;
1226 map.entry(destination.to_string())
1227 .or_insert_with(Vec::new)
1228 .push(SubscriptionEntry {
1229 id: id.clone(),
1230 sender: tx.clone(),
1231 ack: ack.as_str().to_string(),
1232 headers: extra_headers.clone(),
1233 });
1234 }
1235
1236 let mut f = Frame::new("SUBSCRIBE");
1237 f = f
1238 .header("id", &id)
1239 .header("destination", destination)
1240 .header("ack", ack.as_str());
1241 for (k, v) in &extra_headers {
1242 f = f.header(k, v);
1243 }
1244 self.outbound_tx
1245 .send(StompItem::Frame(f))
1246 .await
1247 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1248
1249 Ok(crate::subscription::Subscription::new(
1250 id,
1251 destination.to_string(),
1252 rx,
1253 self.clone(),
1254 ))
1255 }
1256
1257 /// Convenience wrapper without extra headers.
1258 pub async fn subscribe(
1259 &self,
1260 destination: &str,
1261 ack: AckMode,
1262 ) -> Result<crate::subscription::Subscription, ConnError> {
1263 self.subscribe_with_headers(destination, ack, Vec::new())
1264 .await
1265 }
1266
1267 /// Subscribe with a typed `SubscriptionOptions` structure.
1268 ///
1269 /// `SubscriptionOptions.headers` are forwarded to the broker and persisted
1270 /// for automatic resubscribe after reconnect. If `durable_queue` is set,
1271 /// it will be used as the actual destination instead of `destination`.
1272 pub async fn subscribe_with_options(
1273 &self,
1274 destination: &str,
1275 ack: AckMode,
1276 options: crate::subscription::SubscriptionOptions,
1277 ) -> Result<crate::subscription::Subscription, ConnError> {
1278 let dest = options
1279 .durable_queue
1280 .as_deref()
1281 .unwrap_or(destination)
1282 .to_string();
1283 self.subscribe_with_headers(&dest, ack, options.headers)
1284 .await
1285 }
1286
1287 /// Unsubscribe a previously created subscription by its local subscription id.
1288 pub async fn unsubscribe(&self, subscription_id: &str) -> Result<(), ConnError> {
1289 let mut found = false;
1290 {
1291 let mut map = self.subscriptions.lock().await;
1292 let mut remove_keys: Vec<String> = Vec::new();
1293 for (dest, vec) in map.iter_mut() {
1294 if let Some(pos) = vec.iter().position(|entry| entry.id == subscription_id) {
1295 vec.remove(pos);
1296 found = true;
1297 }
1298 if vec.is_empty() {
1299 remove_keys.push(dest.clone());
1300 }
1301 }
1302 for k in remove_keys {
1303 map.remove(&k);
1304 }
1305 }
1306
1307 if !found {
1308 return Err(ConnError::Protocol("subscription id not found".into()));
1309 }
1310
1311 let mut f = Frame::new("UNSUBSCRIBE");
1312 f = f.header("id", subscription_id);
1313 self.outbound_tx
1314 .send(StompItem::Frame(f))
1315 .await
1316 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1317
1318 Ok(())
1319 }
1320
1321 /// Acknowledge a message previously received in `client` or
1322 /// `client-individual` ack modes.
1323 ///
1324 /// STOMP ack semantics:
1325 /// - `auto`: server considers message delivered immediately; the client
1326 /// should not ack.
1327 /// - `client`: cumulative acknowledgements. ACKing message `M` for
1328 /// subscription `S` acknowledges all messages delivered to `S` up to
1329 /// and including `M`.
1330 /// - `client-individual`: only the named message is acknowledged.
1331 ///
1332 /// Parameters
1333 /// - `subscription_id`: the local subscription id returned by
1334 /// `Connection::subscribe`. This disambiguates which subscription's
1335 /// pending queue to advance for cumulative ACKs.
1336 /// - `message_id`: the `message-id` header value from the received
1337 /// MESSAGE frame to acknowledge.
1338 ///
1339 /// Behavior
1340 /// - The pending queue for `subscription_id` is searched for `message_id`.
1341 /// If the subscription used `client` ack mode, all pending messages up to
1342 /// and including the matched message are removed. If the subscription
1343 /// used `client-individual`, only the matched message is removed.
1344 /// - An `ACK` frame is sent to the server with `id=<message_id>` and
1345 /// `subscription=<subscription_id>` headers.
1346 #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1347 pub async fn ack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1348 // Remove from the local pending queue according to subscription ack mode.
1349 let mut removed_any = false;
1350 {
1351 let mut p = self.pending.lock().await;
1352 if let Some(queue) = p.get_mut(subscription_id) {
1353 if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1354 // Determine ack mode for this subscription (default to client).
1355 let mut ack_mode = "client".to_string();
1356 {
1357 let map = self.subscriptions.lock().await;
1358 'outer: for (_dest, vec) in map.iter() {
1359 for entry in vec.iter() {
1360 if entry.id == subscription_id {
1361 ack_mode = entry.ack.clone();
1362 break 'outer;
1363 }
1364 }
1365 }
1366 }
1367
1368 if ack_mode == "client" {
1369 // cumulative: remove up to and including pos
1370 for _ in 0..=pos {
1371 queue.pop_front();
1372 removed_any = true;
1373 }
1374 } else if queue.remove(pos).is_some() {
1375 // client-individual: remove only the specific message
1376 removed_any = true;
1377 }
1378
1379 if queue.is_empty() {
1380 p.remove(subscription_id);
1381 }
1382 }
1383 }
1384 }
1385
1386 // Send ACK to server (include subscription header for clarity)
1387 let mut f = Frame::new("ACK");
1388 f = f
1389 .header("id", message_id)
1390 .header("subscription", subscription_id);
1391 self.outbound_tx
1392 .send(StompItem::Frame(f))
1393 .await
1394 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1395
1396 // If message wasn't found locally, still send ACK to server; server
1397 // may ignore or treat it as no-op.
1398 let _ = removed_any;
1399 Ok(())
1400 }
1401
1402 /// Negative-acknowledge a message (NACK).
1403 ///
1404 /// Parameters
1405 /// - `subscription_id`: the local subscription id the message was delivered under.
1406 /// - `message_id`: the `message-id` header value from the received MESSAGE.
1407 ///
1408 /// Behavior
1409 /// - Removes the message from the local pending queue (cumulatively if the
1410 /// subscription used `client` ack mode, otherwise only the single
1411 /// message). Sends a `NACK` frame to the server with `id` and
1412 /// `subscription` headers.
1413 #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1414 pub async fn nack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1415 // Mirror ack removal semantics for pending map.
1416 let mut removed_any = false;
1417 {
1418 let mut p = self.pending.lock().await;
1419 if let Some(queue) = p.get_mut(subscription_id) {
1420 if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1421 let mut ack_mode = "client".to_string();
1422 {
1423 let map = self.subscriptions.lock().await;
1424 'outer2: for (_dest, vec) in map.iter() {
1425 for entry in vec.iter() {
1426 if entry.id == subscription_id {
1427 ack_mode = entry.ack.clone();
1428 break 'outer2;
1429 }
1430 }
1431 }
1432 }
1433
1434 if ack_mode == "client" {
1435 for _ in 0..=pos {
1436 queue.pop_front();
1437 removed_any = true;
1438 }
1439 } else if queue.remove(pos).is_some() {
1440 removed_any = true;
1441 }
1442
1443 if queue.is_empty() {
1444 p.remove(subscription_id);
1445 }
1446 }
1447 }
1448 }
1449
1450 let mut f = Frame::new("NACK");
1451 f = f
1452 .header("id", message_id)
1453 .header("subscription", subscription_id);
1454 self.outbound_tx
1455 .send(StompItem::Frame(f))
1456 .await
1457 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1458
1459 let _ = removed_any;
1460 Ok(())
1461 }
1462
1463 /// Helper to send a transaction frame (BEGIN, COMMIT, or ABORT).
1464 async fn send_transaction_frame(
1465 &self,
1466 command: &str,
1467 transaction_id: &str,
1468 ) -> Result<(), ConnError> {
1469 let f = Frame::new(command).header("transaction", transaction_id);
1470 self.outbound_tx
1471 .send(StompItem::Frame(f))
1472 .await
1473 .map_err(|_| ConnError::Protocol("send channel closed".into()))
1474 }
1475
1476 /// Begin a transaction.
1477 ///
1478 /// Parameters
1479 /// - `transaction_id`: unique identifier for the transaction. The caller is
1480 /// responsible for ensuring uniqueness within the connection.
1481 ///
1482 /// Behavior
1483 /// - Sends a `BEGIN` frame to the server with `transaction:<transaction_id>`
1484 /// header. Subsequent `SEND`, `ACK`, and `NACK` frames may include this
1485 /// transaction id to group them into the transaction. The transaction must
1486 /// be finalized with either `commit` or `abort`.
1487 pub async fn begin(&self, transaction_id: &str) -> Result<(), ConnError> {
1488 self.send_transaction_frame("BEGIN", transaction_id).await
1489 }
1490
1491 /// Commit a transaction.
1492 ///
1493 /// Parameters
1494 /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1495 ///
1496 /// Behavior
1497 /// - Sends a `COMMIT` frame to the server with `transaction:<transaction_id>`
1498 /// header. All operations within the transaction are applied atomically.
1499 pub async fn commit(&self, transaction_id: &str) -> Result<(), ConnError> {
1500 self.send_transaction_frame("COMMIT", transaction_id).await
1501 }
1502
1503 /// Abort a transaction.
1504 ///
1505 /// Parameters
1506 /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1507 ///
1508 /// Behavior
1509 /// - Sends an `ABORT` frame to the server with `transaction:<transaction_id>`
1510 /// header. All operations within the transaction are discarded.
1511 pub async fn abort(&self, transaction_id: &str) -> Result<(), ConnError> {
1512 self.send_transaction_frame("ABORT", transaction_id).await
1513 }
1514
1515 /// Receive the next frame from the server.
1516 ///
1517 /// Returns `Some(ReceivedFrame::Frame(..))` for normal frames (MESSAGE, etc.),
1518 /// `Some(ReceivedFrame::Error(..))` for ERROR frames, or `None` if the
1519 /// connection has been closed.
1520 ///
1521 /// # Example
1522 ///
1523 /// ```ignore
1524 /// use iridium_stomp::ReceivedFrame;
1525 ///
1526 /// while let Some(received) = conn.next_frame().await {
1527 /// match received {
1528 /// ReceivedFrame::Frame(frame) => {
1529 /// println!("Got {}: {:?}", frame.command, frame.body);
1530 /// }
1531 /// ReceivedFrame::Error(err) => {
1532 /// eprintln!("Server error: {}", err);
1533 /// break;
1534 /// }
1535 /// }
1536 /// }
1537 /// ```
1538 pub async fn next_frame(&self) -> Option<ReceivedFrame> {
1539 let mut rx = self.inbound_rx.lock().await;
1540 let frame = rx.recv().await?;
1541
1542 // Convert ERROR frames to ServerError for better ergonomics
1543 if frame.command == "ERROR" {
1544 Some(ReceivedFrame::Error(ServerError::from_frame(frame)))
1545 } else {
1546 Some(ReceivedFrame::Frame(frame))
1547 }
1548 }
1549
1550 pub async fn close(self) {
1551 // Signal the background task to shutdown by broadcasting on the
1552 // shutdown channel. Consumers may await task termination separately
1553 // if needed.
1554 let _ = self.shutdown_tx.send(());
1555 }
1556}
1557
/// Milliseconds since the Unix epoch according to the system (wall) clock,
/// or 0 if the clock reads earlier than the epoch.
///
/// Heartbeat bookkeeping compares two readings of this value.
/// NOTE(review): the wall clock is not monotonic, so clock adjustments can
/// distort those intervals. Confirm no caller needs an epoch-based value
/// before switching to `std::time::Instant`.
fn current_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
1565
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Build a MESSAGE frame with the given `message-id` and optional
    /// `subscription` / `destination` headers.
    fn make_message(
        message_id: &str,
        subscription: Option<&str>,
        destination: Option<&str>,
    ) -> Frame {
        let mut f = Frame::new("MESSAGE").header("message-id", message_id);
        if let Some(s) = subscription {
            f = f.header("subscription", s);
        }
        if let Some(d) = destination {
            f = f.header("destination", d);
        }
        f
    }

    /// Create a `Connection` wired to in-memory channels (no socket, no
    /// background task). Returns it with the receiving end of its outbound
    /// channel so tests can inspect emitted frames.
    fn setup_test_connection() -> (Connection, mpsc::Receiver<StompItem>) {
        let (out_tx, out_rx) = mpsc::channel::<StompItem>(8);
        let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
        let (shutdown_tx, _) = broadcast::channel::<()>(1);

        let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
        let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
        let sub_id_counter = Arc::new(AtomicU64::new(1));

        let conn = Connection {
            outbound_tx: out_tx,
            inbound_rx: Arc::new(Mutex::new(in_rx)),
            shutdown_tx,
            subscriptions,
            sub_id_counter,
            pending,
            pending_receipts: Arc::new(Mutex::new(HashMap::new())),
        };

        (conn, out_rx)
    }

    /// Register a local subscription entry `id` under `destination` with the
    /// given ack mode string.
    async fn add_subscription(conn: &Connection, destination: &str, id: &str, ack: &str) {
        // The receiving half is not needed by the ack/nack tests.
        let (sender, _rx) = mpsc::channel::<Frame>(4);
        conn.subscriptions
            .lock()
            .await
            .entry(destination.to_string())
            .or_default()
            .push(SubscriptionEntry {
                id: id.to_string(),
                sender,
                ack: ack.to_string(),
                headers: Vec::new(),
            });
    }

    /// Replace the pending queue of `sub_id` with MESSAGE frames for
    /// `message_ids`, in order.
    async fn fill_pending(conn: &Connection, sub_id: &str, destination: &str, message_ids: &[&str]) {
        let queue: VecDeque<_> = message_ids
            .iter()
            .map(|mid| {
                (
                    mid.to_string(),
                    make_message(mid, Some(sub_id), Some(destination)),
                )
            })
            .collect();
        conn.pending.lock().await.insert(sub_id.to_string(), queue);
    }

    /// Message ids still pending for `sub_id`, in queue order (empty if none).
    async fn pending_ids(conn: &Connection, sub_id: &str) -> Vec<String> {
        conn.pending
            .lock()
            .await
            .get(sub_id)
            .map(|q| q.iter().map(|(mid, _)| mid.clone()).collect())
            .unwrap_or_default()
    }

    /// Receive the next outbound item, failing the test unless it is a frame.
    async fn next_outbound_frame(out_rx: &mut mpsc::Receiver<StompItem>) -> Frame {
        match out_rx.recv().await {
            Some(StompItem::Frame(f)) => f,
            Some(_) => panic!("expected frame"),
            None => panic!("no outbound frame sent"),
        }
    }

    /// Assert `frame` is `expected_command` with a matching `transaction` header.
    fn verify_transaction_frame(frame: Frame, expected_command: &str, expected_tx_id: &str) {
        assert_eq!(frame.command, expected_command);
        assert!(
            frame
                .headers
                .iter()
                .any(|(k, v)| k == "transaction" && v == expected_tx_id),
            "transaction header with id '{}' not found",
            expected_tx_id
        );
    }

    #[tokio::test]
    async fn test_cumulative_ack_removes_prefix() {
        let (conn, mut out_rx) = setup_test_connection();
        add_subscription(&conn, "/queue/x", "s1", "client").await;
        fill_pending(&conn, "s1", "/queue/x", &["m1", "m2", "m3"]).await;

        // ack m2 cumulatively: should remove m1 and m2, leaving m3
        conn.ack("s1", "m2").await.expect("ack failed");

        assert_eq!(pending_ids(&conn, "s1").await, vec!["m3"]);
        assert_eq!(next_outbound_frame(&mut out_rx).await.command, "ACK");
    }

    #[tokio::test]
    async fn test_client_individual_ack_removes_only_one() {
        let (conn, mut out_rx) = setup_test_connection();
        add_subscription(&conn, "/queue/y", "s2", "client-individual").await;
        fill_pending(&conn, "s2", "/queue/y", &["a", "b", "c"]).await;

        // ack only 'b' individually
        conn.ack("s2", "b").await.expect("ack failed");

        assert_eq!(pending_ids(&conn, "s2").await, vec!["a", "c"]);
        assert_eq!(next_outbound_frame(&mut out_rx).await.command, "ACK");
    }

    #[tokio::test]
    async fn test_cumulative_nack_removes_prefix_and_sends_nack() {
        let (conn, mut out_rx) = setup_test_connection();
        add_subscription(&conn, "/queue/n", "s3", "client").await;
        fill_pending(&conn, "s3", "/queue/n", &["n1", "n2", "n3"]).await;

        conn.nack("s3", "n2").await.expect("nack failed");

        assert_eq!(pending_ids(&conn, "s3").await, vec!["n3"]);
        let f = next_outbound_frame(&mut out_rx).await;
        assert_eq!(f.command, "NACK");
        assert_eq!(f.get_header("id"), Some("n2"));
        assert_eq!(f.get_header("subscription"), Some("s3"));
    }

    #[tokio::test]
    async fn test_subscription_receive_delivers_message() {
        // `_out_rx` must stay alive so subscribe() can emit its SUBSCRIBE frame.
        let (conn, _out_rx) = setup_test_connection();

        let subscription = conn
            .subscribe("/queue/test", AckMode::Auto)
            .await
            .expect("subscribe failed");

        // find the sender stored in the subscriptions map and push a message
        {
            let map = conn.subscriptions.lock().await;
            let entry = &map.get("/queue/test").expect("missing subscription vec")[0];
            let f = make_message("m1", Some(&entry.id), Some("/queue/test"));
            entry.sender.try_send(f).expect("send to subscription failed");
        }

        let mut rx = subscription.into_receiver();
        let received = rx
            .recv()
            .await
            .expect("no message received on subscription");
        assert_eq!(received.command, "MESSAGE");
        assert!(
            received
                .headers
                .iter()
                .any(|(k, _)| k.eq_ignore_ascii_case("message-id")),
            "message-id header missing"
        );
    }

    #[tokio::test]
    async fn test_subscription_ack_removes_pending_and_sends_ack() {
        let (conn, mut out_rx) = setup_test_connection();

        let subscription = conn
            .subscribe("/queue/ack", AckMode::Client)
            .await
            .expect("subscribe failed");
        let sub_id = subscription.id().to_string();

        // drain the SUBSCRIBE frame emitted by subscribe()
        while out_rx.try_recv().is_ok() {}

        fill_pending(&conn, &sub_id, "/queue/ack", &["mid-1"]).await;

        subscription.ack("mid-1").await.expect("ack failed");

        assert!(pending_ids(&conn, &sub_id).await.is_empty());
        assert_eq!(next_outbound_frame(&mut out_rx).await.command, "ACK");
    }

    #[tokio::test]
    async fn test_unsubscribe_unknown_id_is_error() {
        let (conn, mut out_rx) = setup_test_connection();

        assert!(conn.unsubscribe("missing").await.is_err());
        assert!(
            out_rx.try_recv().is_err(),
            "no frame should be sent for an unknown id"
        );
    }

    #[tokio::test]
    async fn test_unsubscribe_removes_entry_and_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();

        let subscription = conn
            .subscribe("/queue/u", AckMode::Auto)
            .await
            .expect("subscribe failed");
        let sub_id = subscription.id().to_string();
        assert_eq!(next_outbound_frame(&mut out_rx).await.command, "SUBSCRIBE");

        conn.unsubscribe(&sub_id).await.expect("unsubscribe failed");

        assert!(!conn.subscriptions.lock().await.contains_key("/queue/u"));
        let f = next_outbound_frame(&mut out_rx).await;
        assert_eq!(f.command, "UNSUBSCRIBE");
        assert_eq!(f.get_header("id"), Some(sub_id.as_str()));
    }

    #[test]
    fn test_connect_frame_ignores_reserved_custom_headers() {
        let custom = vec![
            ("Login".to_string(), "intruder".to_string()),
            ("x-custom".to_string(), "yes".to_string()),
        ];
        let f = Connection::build_connect_frame(
            "1.2",
            "localhost",
            "user",
            "secret",
            "0,0",
            &Some("cid".to_string()),
            &custom,
        );

        assert_eq!(f.command, "CONNECT");
        let logins: Vec<&str> = f
            .headers
            .iter()
            .filter(|(k, _)| k.eq_ignore_ascii_case("login"))
            .map(|(_, v)| v.as_str())
            .collect();
        assert_eq!(logins, vec!["user"]);
        assert_eq!(f.get_header("client-id"), Some("cid"));
        assert_eq!(f.get_header("x-custom"), Some("yes"));
    }

    #[tokio::test]
    async fn test_send_frame_with_receipt_tags_frame_and_registers_waiter() {
        let (conn, mut out_rx) = setup_test_connection();

        let receipt_id = conn
            .send_frame_with_receipt(Frame::new("SEND").header("destination", "/queue/r"))
            .await
            .expect("send failed");

        assert!(conn.pending_receipts.lock().await.contains_key(&receipt_id));
        let f = next_outbound_frame(&mut out_rx).await;
        assert_eq!(f.get_header("receipt"), Some(receipt_id.as_str()));
    }

    #[tokio::test]
    async fn test_wait_for_receipt_times_out_and_cleans_up() {
        let (conn, _out_rx) = setup_test_connection();

        let receipt_id = conn
            .send_frame_with_receipt(Frame::new("SEND"))
            .await
            .expect("send failed");
        let result = conn
            .wait_for_receipt(&receipt_id, Duration::from_millis(10))
            .await;

        assert!(matches!(
            result,
            Err(ConnError::ReceiptTimeout(ref id)) if id == &receipt_id
        ));
        assert!(!conn.pending_receipts.lock().await.contains_key(&receipt_id));
    }

    #[tokio::test]
    async fn test_begin_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();
        conn.begin("tx1").await.expect("begin failed");
        verify_transaction_frame(next_outbound_frame(&mut out_rx).await, "BEGIN", "tx1");
    }

    #[tokio::test]
    async fn test_commit_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();
        conn.commit("tx1").await.expect("commit failed");
        verify_transaction_frame(next_outbound_frame(&mut out_rx).await, "COMMIT", "tx1");
    }

    #[tokio::test]
    async fn test_abort_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();
        conn.abort("tx1").await.expect("abort failed");
        verify_transaction_frame(next_outbound_frame(&mut out_rx).await, "ABORT", "tx1");
    }
}