iridium_stomp/connection.rs
1use futures::{SinkExt, StreamExt, future};
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Duration;
6use thiserror::Error;
7use tokio::net::TcpStream;
8use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
9use tokio_util::codec::Framed;
10
11use crate::codec::{StompCodec, StompItem};
12use crate::frame::Frame;
13
/// Client heartbeat preferences for the STOMP `heart-beat` header.
///
/// A typed alternative to writing the raw header string by hand. Formatting
/// the value with `Display` yields the wire format `"send_ms,receive_ms"`.
///
/// # Example
///
/// ```
/// use iridium_stomp::Heartbeat;
///
/// // Create a custom heartbeat configuration
/// let hb = Heartbeat::new(5000, 10000);
/// assert_eq!(hb.to_string(), "5000,10000");
///
/// // Use predefined configurations
/// assert_eq!(Heartbeat::disabled().to_string(), "0,0");
/// assert_eq!(Heartbeat::default().to_string(), "10000,10000");
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Heartbeat {
    /// Smallest interval, in milliseconds, at which the client is able to
    /// emit heartbeats. `0` means the client never sends any.
    pub send_ms: u32,

    /// Interval, in milliseconds, at which the client would like to receive
    /// heartbeats. `0` means the client does not ask for any.
    pub receive_ms: u32,
}

impl Heartbeat {
    /// Build a heartbeat configuration from explicit intervals.
    ///
    /// # Arguments
    ///
    /// * `send_ms` - smallest interval (ms) at which the client can send heartbeats.
    /// * `receive_ms` - interval (ms) at which the client wants to receive heartbeats.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    ///
    /// let hb = Heartbeat::new(5000, 10000);
    /// assert_eq!(hb.send_ms, 5000);
    /// assert_eq!(hb.receive_ms, 10000);
    /// ```
    pub fn new(send_ms: u32, receive_ms: u32) -> Self {
        Heartbeat { send_ms, receive_ms }
    }

    /// Configuration that turns heartbeats off in both directions.
    ///
    /// Identical to `Heartbeat::new(0, 0)`.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    ///
    /// let hb = Heartbeat::disabled();
    /// assert_eq!(hb.send_ms, 0);
    /// assert_eq!(hb.receive_ms, 0);
    /// assert_eq!(hb.to_string(), "0,0");
    /// ```
    pub fn disabled() -> Self {
        Heartbeat {
            send_ms: 0,
            receive_ms: 0,
        }
    }

    /// Symmetric configuration: the same `interval` for sending and receiving.
    ///
    /// Intervals longer than `u32::MAX` milliseconds (about 49.7 days) are
    /// saturated to `u32::MAX` instead of wrapping around.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    /// use std::time::Duration;
    ///
    /// let hb = Heartbeat::from_duration(Duration::from_secs(15));
    /// assert_eq!(hb.send_ms, 15000);
    /// assert_eq!(hb.receive_ms, 15000);
    /// ```
    pub fn from_duration(interval: Duration) -> Self {
        let millis = interval.as_millis();
        let clamped = if millis > u128::from(u32::MAX) {
            u32::MAX
        } else {
            // In range: the narrowing cast cannot truncate here.
            millis as u32
        };
        Heartbeat::new(clamped, clamped)
    }
}

impl Default for Heartbeat {
    /// Ten seconds in each direction (the same value as
    /// `Connection::DEFAULT_HEARTBEAT`).
    fn default() -> Self {
        Heartbeat::new(10_000, 10_000)
    }
}

impl std::fmt::Display for Heartbeat {
    /// Renders the `heart-beat` header value, e.g. `10000,10000`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Heartbeat {
            send_ms,
            receive_ms,
        } = *self;
        write!(f, "{},{}", send_ms, receive_ms)
    }
}
122
123/// Internal subscription entry stored for each destination.
124#[derive(Clone)]
125pub(crate) struct SubscriptionEntry {
126 pub(crate) id: String,
127 pub(crate) sender: mpsc::Sender<Frame>,
128 pub(crate) ack: String,
129 pub(crate) headers: Vec<(String, String)>,
130}
131
132/// Alias for the subscription dispatch map: destination -> list of
133/// `SubscriptionEntry`.
134pub(crate) type Subscriptions = HashMap<String, Vec<SubscriptionEntry>>;
135
136/// Alias for the pending map: subscription_id -> queue of (message-id, Frame).
137pub(crate) type PendingMap = HashMap<String, VecDeque<(String, Frame)>>;
138
139/// Internal type for resubscribe snapshot entries: (destination, id, ack, headers)
140pub(crate) type ResubEntry = (String, String, String, Vec<(String, String)>);
141
142/// Alias for pending receipt map: receipt-id -> oneshot sender to notify when received.
143pub(crate) type PendingReceipts = HashMap<String, oneshot::Sender<()>>;
144
145/// Errors returned by `Connection` operations.
146#[derive(Error, Debug)]
147pub enum ConnError {
148 /// I/O-level error
149 #[error("io error: {0}")]
150 Io(#[from] std::io::Error),
151 /// Protocol-level error
152 #[error("protocol error: {0}")]
153 Protocol(String),
154 /// Receipt timeout error
155 #[error("receipt timeout: no RECEIPT received for '{0}' within timeout")]
156 ReceiptTimeout(String),
157 /// Server rejected the connection (e.g., authentication failure)
158 ///
159 /// This error is returned when the server sends an ERROR frame in response
160 /// to the CONNECT frame. Common causes include invalid credentials,
161 /// unauthorized access, or broker configuration issues.
162 #[error("server rejected connection: {0}")]
163 ServerRejected(ServerError),
164}
165
166/// Represents an ERROR frame received from the STOMP server.
167///
168/// STOMP servers send ERROR frames to indicate protocol violations, authentication
169/// failures, or other server-side errors. After sending an ERROR frame, the server
170/// typically closes the connection.
171///
172/// # Example
173///
174/// ```ignore
175/// use iridium_stomp::ReceivedFrame;
176///
177/// while let Some(received) = conn.next_frame().await {
178/// match received {
179/// ReceivedFrame::Frame(frame) => {
180/// // Normal message processing
181/// }
182/// ReceivedFrame::Error(err) => {
183/// eprintln!("Server error: {}", err.message);
184/// if let Some(body) = &err.body {
185/// eprintln!("Details: {}", body);
186/// }
187/// break;
188/// }
189/// }
190/// }
191/// ```
192#[derive(Debug, Clone, PartialEq, Eq)]
193pub struct ServerError {
194 /// The error message from the `message` header.
195 pub message: String,
196
197 /// The error body, if present. Contains additional error details.
198 pub body: Option<String>,
199
200 /// The receipt-id if this error is in response to a specific frame.
201 pub receipt_id: Option<String>,
202
203 /// The original ERROR frame for access to additional headers.
204 pub frame: Frame,
205}
206
207impl ServerError {
208 /// Create a `ServerError` from an ERROR frame.
209 ///
210 /// This is primarily used internally but is public for testing and
211 /// advanced use cases where you need to construct a `ServerError` manually.
212 pub fn from_frame(frame: Frame) -> Self {
213 let message = frame
214 .get_header("message")
215 .unwrap_or("unknown error")
216 .to_string();
217
218 let body = if frame.body.is_empty() {
219 None
220 } else {
221 String::from_utf8(frame.body.clone()).ok()
222 };
223
224 let receipt_id = frame.get_header("receipt-id").map(|s| s.to_string());
225
226 Self {
227 message,
228 body,
229 receipt_id,
230 frame,
231 }
232 }
233}
234
235impl std::fmt::Display for ServerError {
236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 write!(f, "STOMP server error: {}", self.message)?;
238 if let Some(body) = &self.body {
239 write!(f, " - {}", body)?;
240 }
241 Ok(())
242 }
243}
244
245impl std::error::Error for ServerError {}
246
247/// The result of receiving a frame from the server.
248///
249/// STOMP servers can send either normal frames (MESSAGE, RECEIPT, etc.) or
250/// ERROR frames indicating a problem. This enum allows callers to handle
251/// both cases with pattern matching.
252///
253/// # Example
254///
255/// ```ignore
256/// use iridium_stomp::ReceivedFrame;
257///
258/// match conn.next_frame().await {
259/// Some(ReceivedFrame::Frame(frame)) => {
260/// println!("Got frame: {}", frame.command);
261/// }
262/// Some(ReceivedFrame::Error(err)) => {
263/// eprintln!("Server error: {}", err);
264/// }
265/// None => {
266/// println!("Connection closed");
267/// }
268/// }
269/// ```
270#[derive(Debug, Clone, PartialEq, Eq)]
271pub enum ReceivedFrame {
272 /// A normal STOMP frame (MESSAGE, RECEIPT, etc.)
273 Frame(Frame),
274 /// An ERROR frame from the server
275 Error(ServerError),
276}
277
278impl ReceivedFrame {
279 /// Returns `true` if this is an error frame.
280 pub fn is_error(&self) -> bool {
281 matches!(self, ReceivedFrame::Error(_))
282 }
283
284 /// Returns `true` if this is a normal frame.
285 pub fn is_frame(&self) -> bool {
286 matches!(self, ReceivedFrame::Frame(_))
287 }
288
289 /// Returns the frame if this is a normal frame, or `None` if it's an error.
290 pub fn into_frame(self) -> Option<Frame> {
291 match self {
292 ReceivedFrame::Frame(f) => Some(f),
293 ReceivedFrame::Error(_) => None,
294 }
295 }
296
297 /// Returns the error if this is an error frame, or `None` if it's a normal frame.
298 pub fn into_error(self) -> Option<ServerError> {
299 match self {
300 ReceivedFrame::Frame(_) => None,
301 ReceivedFrame::Error(e) => Some(e),
302 }
303 }
304}
305
/// Subscription acknowledgement modes defined by STOMP 1.2.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AckMode {
    /// `ack:auto`: messages are not tracked for ACK/NACK by the client.
    Auto,
    /// `ack:client`: acknowledging a message is cumulative, covering it and
    /// every earlier message delivered on the same subscription.
    Client,
    /// `ack:client-individual`: each ACK/NACK covers exactly one message.
    ClientIndividual,
}

impl AckMode {
    /// Wire value for the SUBSCRIBE `ack` header.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Auto => "auto",
            Self::Client => "client",
            Self::ClientIndividual => "client-individual",
        }
    }
}
323
324/// Options for customizing the STOMP CONNECT frame.
325///
326/// Use this struct with `Connection::connect_with_options()` to set custom
327/// headers, specify supported STOMP versions, or configure broker-specific
328/// options like `client-id` for durable subscriptions.
329///
330/// # Validation
331///
332/// This struct performs minimal validation. Values are passed to the broker
333/// as-is, and invalid configurations will be rejected by the broker at
334/// connection time. Empty strings are technically accepted but may cause
335/// broker-specific errors.
336///
337/// # Custom Headers
338///
339/// Custom headers added via `header()` cannot override critical STOMP headers
340/// (`accept-version`, `host`, `login`, `passcode`, `heart-beat`, `client-id`).
341/// Such headers are silently ignored. Use the dedicated builder methods to
342/// set these values.
343///
344/// # Example
345///
346/// ```ignore
347/// use iridium_stomp::{Connection, ConnectOptions};
348///
349/// let options = ConnectOptions::default()
350/// .client_id("my-durable-client")
351/// .host("my-vhost")
352/// .header("custom-header", "value");
353///
354/// let conn = Connection::connect_with_options(
355/// "localhost:61613",
356/// "guest",
357/// "guest",
358/// "10000,10000",
359/// options,
360/// ).await?;
361/// ```
362#[derive(Clone, Default)]
363pub struct ConnectOptions {
364 /// STOMP version(s) to accept (e.g., "1.2" or "1.0,1.1,1.2").
365 /// Defaults to "1.2" if not set.
366 pub accept_version: Option<String>,
367
368 /// Client ID for durable subscriptions (required by ActiveMQ, etc.).
369 pub client_id: Option<String>,
370
371 /// Virtual host header value. Defaults to "/" if not set.
372 pub host: Option<String>,
373
374 /// Additional custom headers to include in the CONNECT frame.
375 /// Note: Headers that would override critical STOMP headers are ignored.
376 pub headers: Vec<(String, String)>,
377
378 /// Optional channel to receive heartbeat notifications.
379 /// When set, the connection will send a `()` on this channel each time
380 /// a heartbeat is received from the server.
381 pub heartbeat_tx: Option<mpsc::Sender<()>>,
382}
383
384impl std::fmt::Debug for ConnectOptions {
385 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 f.debug_struct("ConnectOptions")
387 .field("accept_version", &self.accept_version)
388 .field("client_id", &self.client_id)
389 .field("host", &self.host)
390 .field("headers", &self.headers)
391 .field(
392 "heartbeat_tx",
393 &self.heartbeat_tx.as_ref().map(|_| "Some(...)"),
394 )
395 .finish()
396 }
397}
398
399impl ConnectOptions {
400 /// Create a new `ConnectOptions` with default values.
401 pub fn new() -> Self {
402 Self::default()
403 }
404
405 /// Set the STOMP version(s) to accept (builder style).
406 ///
407 /// Examples: "1.2", "1.1,1.2", "1.0,1.1,1.2"
408 pub fn accept_version(mut self, version: impl Into<String>) -> Self {
409 self.accept_version = Some(version.into());
410 self
411 }
412
413 /// Set the client ID for durable subscriptions (builder style).
414 ///
415 /// Required by some brokers (e.g., ActiveMQ) for durable topic subscriptions.
416 pub fn client_id(mut self, id: impl Into<String>) -> Self {
417 self.client_id = Some(id.into());
418 self
419 }
420
421 /// Set the virtual host (builder style).
422 ///
423 /// Defaults to "/" if not set.
424 pub fn host(mut self, host: impl Into<String>) -> Self {
425 self.host = Some(host.into());
426 self
427 }
428
429 /// Add a custom header to the CONNECT frame (builder style).
430 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
431 self.headers.push((key.into(), value.into()));
432 self
433 }
434
435 /// Set a channel to receive heartbeat notifications (builder style).
436 ///
437 /// When set, the connection will send a `()` on this channel each time
438 /// a heartbeat is received from the server. This is useful for CLI tools
439 /// or monitoring applications that want to display heartbeat status.
440 ///
441 /// # Note
442 ///
443 /// Notifications are sent using `try_send()` to avoid blocking the
444 /// connection's background task. If the channel buffer is full,
445 /// notifications will be silently dropped. Use a sufficiently sized
446 /// channel buffer (e.g., 16) to avoid missing notifications.
447 ///
448 /// # Example
449 ///
450 /// ```ignore
451 /// use tokio::sync::mpsc;
452 /// use iridium_stomp::ConnectOptions;
453 ///
454 /// let (tx, mut rx) = mpsc::channel(16);
455 /// let options = ConnectOptions::default()
456 /// .with_heartbeat_notify(tx);
457 ///
458 /// // In another task:
459 /// while rx.recv().await.is_some() {
460 /// println!("Heartbeat received!");
461 /// }
462 /// ```
463 pub fn with_heartbeat_notify(mut self, tx: mpsc::Sender<()>) -> Self {
464 self.heartbeat_tx = Some(tx);
465 self
466 }
467}
468
/// Parse a STOMP `heart-beat` header value of the form `"cx,cy"`.
///
/// `header` may come from either peer, e.g. `"10000,10000"`. Both numbers are
/// milliseconds, and each field is trimmed before parsing.
///
/// Returns `(cx, cy)`. A field that is missing or is not a valid `u64`
/// becomes `0`. Anything after a second comma is ignored.
pub fn parse_heartbeat_header(header: &str) -> (u64, u64) {
    let mut fields = header
        .split(',')
        .map(|field| field.trim().parse::<u64>().unwrap_or(0));
    let cx = fields.next().unwrap_or(0);
    let cy = fields.next().unwrap_or(0);
    (cx, cy)
}
489
/// Negotiate heartbeat intervals between client and server.
///
/// Applies the STOMP 1.1/1.2 `heart-beat` rule to each direction
/// independently:
///
/// - client → server: disabled when the client cannot send
///   (`client_out == 0`) or the server does not want to receive
///   (`server_in == 0`). Otherwise every `max(client_out, server_in)` ms.
/// - server → client: disabled when the server cannot send
///   (`server_out == 0`) or the client does not want to receive
///   (`client_in == 0`). Otherwise every `max(client_in, server_out)` ms.
///
/// Parameters
/// - `client_out`: smallest interval (ms) at which the client can send
///   heartbeats (`cx`).
/// - `client_in`: interval (ms) at which the client wants to receive
///   heartbeats (`cy`).
/// - `server_out`: server's advertised sending interval in ms (`sx`).
/// - `server_in`: server's desired receiving interval in ms (`sy`).
///
/// Returns `(outgoing, incoming)`. Each element is `Some(Duration)` when
/// heartbeats are enabled in that direction and `None` when disabled.
pub fn negotiate_heartbeats(
    client_out: u64,
    client_in: u64,
    server_out: u64,
    server_in: u64,
) -> (Option<Duration>, Option<Duration>) {
    // A zero on either side means that side opted out. Enabling the direction
    // anyway would make the client emit, or expect, heartbeats the peer never
    // agreed to.
    let direction = |sender_ms: u64, receiver_ms: u64| -> Option<Duration> {
        if sender_ms == 0 || receiver_ms == 0 {
            None
        } else {
            Some(Duration::from_millis(sender_ms.max(receiver_ms)))
        }
    };

    let outgoing = direction(client_out, server_in);
    let incoming = direction(server_out, client_in);
    (outgoing, incoming)
}
525
526/// Extract the destination from an ERROR frame.
527///
528/// Tries multiple strategies:
529/// 1. Check for a `destination` header (some brokers include it)
530/// 2. Parse the error message/body for `/topic/...` or `/queue/...` patterns
531///
532/// Returns `None` if no destination can be identified.
533fn extract_destination_from_error(frame: &Frame) -> Option<String> {
534 // Strategy 1: Check for destination header
535 if let Some(dest) = frame.get_header("destination") {
536 return Some(dest.to_string());
537 }
538
539 // Strategy 2: Look for destination pattern in message header or body
540 let message = frame.get_header("message").unwrap_or("");
541 let body = String::from_utf8_lossy(&frame.body);
542
543 // Combine message and body for searching
544 let text = format!("{} {}", message, body);
545
546 // Look for /topic/ or /queue/ patterns
547 for prefix in ["/topic/", "/queue/"] {
548 if let Some(start) = text.find(prefix) {
549 // Extract until whitespace, comma, quote, or end of string
550 let rest = &text[start..];
551 let end = rest
552 .find(|c: char| c.is_whitespace() || c == ',' || c == '"' || c == '\'')
553 .unwrap_or(rest.len());
554 if end > prefix.len() {
555 return Some(rest[..end].to_string());
556 }
557 }
558 }
559
560 None
561}
562
563/// Extract subscription ID from an ERROR frame message.
564///
565/// Looks for patterns like "subscription 1" or "subscription sub-1" in the
566/// error message or body. Artemis uses this format for subscription errors.
567fn extract_subscription_id_from_error(frame: &Frame) -> Option<String> {
568 let message = frame.get_header("message").unwrap_or("");
569 let body = String::from_utf8_lossy(&frame.body);
570 let text = format!("{} {}", message, body);
571
572 // Look for "subscription X" pattern (Artemis format)
573 if let Some(idx) = text.to_lowercase().find("subscription ") {
574 let rest = &text[idx + 13..]; // "subscription " is 13 chars
575 // Extract the subscription ID (could be numeric or alphanumeric like "sub-1")
576 let end = rest
577 .find(|c: char| c.is_whitespace() || c == ',' || c == '"' || c == '\'')
578 .unwrap_or(rest.len());
579 if end > 0 {
580 return Some(rest[..end].to_string());
581 }
582 }
583
584 None
585}
586
587/// Look up a destination by subscription ID in the subscriptions map.
588async fn lookup_destination_by_sub_id(
589 sub_id: &str,
590 subscriptions: &Arc<Mutex<Subscriptions>>,
591) -> Option<String> {
592 let map = subscriptions.lock().await;
593 for (dest, entries) in map.iter() {
594 for entry in entries {
595 if entry.id == sub_id {
596 return Some(dest.clone());
597 }
598 }
599 }
600 None
601}
602
603/// High-level connection object that manages a single TCP/STOMP connection.
604///
605/// The `Connection` spawns a background task that maintains the TCP transport,
606/// sends/receives STOMP frames using `StompCodec`, negotiates heartbeats, and
607/// performs simple reconnect logic with exponential backoff.
608#[derive(Clone)]
609pub struct Connection {
610 outbound_tx: mpsc::Sender<StompItem>,
611 /// The inbound receiver is shared behind a mutex so the `Connection`
612 /// handle may be cloned and callers can call `next_frame` concurrently.
613 inbound_rx: Arc<Mutex<mpsc::Receiver<Frame>>>,
614 shutdown_tx: broadcast::Sender<()>,
615 /// Map of destination -> list of (subscription id, sender) for dispatching
616 /// inbound MESSAGE frames to subscribers.
617 subscriptions: Arc<Mutex<Subscriptions>>,
618 /// Monotonic counter used to allocate subscription ids.
619 sub_id_counter: Arc<AtomicU64>,
620 /// Pending messages awaiting ACK/NACK from the application.
621 ///
622 /// Organized by subscription id. For `client` ack mode the ACK is
623 /// cumulative: acknowledging message `M` for subscription `S` acknowledges
624 /// all messages previously delivered for `S` up to and including `M`.
625 /// For `client-individual` the ACK/NACK applies only to the single
626 /// message.
627 pending: Arc<Mutex<PendingMap>>,
628 /// Pending receipt confirmations.
629 ///
630 /// When a frame is sent with a `receipt` header, the receipt-id is stored
631 /// here with a oneshot sender. When the server responds with a RECEIPT
632 /// frame, the sender is notified.
633 pending_receipts: Arc<Mutex<PendingReceipts>>,
634}
635
636impl Connection {
637 /// Heartbeat value that disables heartbeats entirely.
638 ///
639 /// Use this when you don't want the client or server to send heartbeats.
640 /// Note that some brokers may still require heartbeats for long-lived connections.
641 ///
642 /// # Example
643 ///
644 /// ```ignore
645 /// let conn = Connection::connect(
646 /// "localhost:61613",
647 /// "guest",
648 /// "guest",
649 /// Connection::NO_HEARTBEAT,
650 /// ).await?;
651 /// ```
652 pub const NO_HEARTBEAT: &'static str = "0,0";
653
654 /// Default heartbeat value: 10 seconds for both send and receive.
655 ///
656 /// This is a reasonable default for most applications. The actual heartbeat
657 /// interval will be negotiated with the server (taking the maximum of client
658 /// and server preferences).
659 ///
660 /// # Example
661 ///
662 /// ```ignore
663 /// let conn = Connection::connect(
664 /// "localhost:61613",
665 /// "guest",
666 /// "guest",
667 /// Connection::DEFAULT_HEARTBEAT,
668 /// ).await?;
669 /// ```
670 pub const DEFAULT_HEARTBEAT: &'static str = "10000,10000";
671
672 /// Establish a connection to the STOMP server at `addr` with the given
673 /// credentials and heartbeat header string (e.g. "10000,10000").
674 ///
675 /// This is a convenience wrapper around `connect_with_options()` that uses
676 /// default options (STOMP 1.2, host="/", no client-id).
677 ///
678 /// Parameters
679 /// - `addr`: TCP address (host:port) of the STOMP server.
680 /// - `login`: login username for STOMP `CONNECT`.
681 /// - `passcode`: passcode for STOMP `CONNECT`.
682 /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
683 /// milliseconds) that will be sent in the `CONNECT` frame.
684 ///
685 /// Returns a `Connection` which provides `send_frame`, `next_frame`, and
686 /// `close` helpers. The detailed connection handling (I/O, heartbeats,
687 /// reconnects) runs on a background task spawned by this method.
688 pub async fn connect(
689 addr: &str,
690 login: &str,
691 passcode: &str,
692 client_hb: &str,
693 ) -> Result<Self, ConnError> {
694 Self::connect_with_options(addr, login, passcode, client_hb, ConnectOptions::default())
695 .await
696 }
697
698 /// Establish a connection to the STOMP server with custom options.
699 ///
700 /// Use this method when you need to set a custom `client-id` (for durable
701 /// subscriptions), specify a virtual host, negotiate different STOMP
702 /// versions, or add custom CONNECT headers.
703 ///
704 /// Parameters
705 /// - `addr`: TCP address (host:port) of the STOMP server.
706 /// - `login`: login username for STOMP `CONNECT`.
707 /// - `passcode`: passcode for STOMP `CONNECT`.
708 /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
709 /// milliseconds) that will be sent in the `CONNECT` frame.
710 /// - `options`: custom connection options (version, host, client-id, etc.).
711 ///
712 /// # Errors
713 ///
714 /// Returns an error if:
715 /// - The TCP connection cannot be established (`ConnError::Io`)
716 /// - The server rejects the connection, e.g., due to invalid credentials
717 /// (`ConnError::ServerRejected`)
718 /// - The server closes the connection without responding (`ConnError::Protocol`)
719 ///
720 /// # Example
721 ///
722 /// ```ignore
723 /// use iridium_stomp::{Connection, ConnectOptions};
724 ///
725 /// // Connect with a client-id for durable subscriptions
726 /// let options = ConnectOptions::default()
727 /// .client_id("my-app-instance-1");
728 ///
729 /// let conn = Connection::connect_with_options(
730 /// "localhost:61613",
731 /// "guest",
732 /// "guest",
733 /// "10000,10000",
734 /// options,
735 /// ).await?;
736 /// ```
737 pub async fn connect_with_options(
738 addr: &str,
739 login: &str,
740 passcode: &str,
741 client_hb: &str,
742 options: ConnectOptions,
743 ) -> Result<Self, ConnError> {
744 let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(32);
745 let (in_tx, in_rx) = mpsc::channel::<Frame>(32);
746 let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
747 let sub_id_counter = Arc::new(AtomicU64::new(1));
748 let (shutdown_tx, _) = broadcast::channel::<()>(1);
749 let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
750 let pending_clone = pending.clone();
751 let pending_receipts: Arc<Mutex<PendingReceipts>> = Arc::new(Mutex::new(HashMap::new()));
752 let pending_receipts_clone = pending_receipts.clone();
753
754 let addr = addr.to_string();
755 let login = login.to_string();
756 let passcode = passcode.to_string();
757 let client_hb = client_hb.to_string();
758
759 // Extract options into owned values for the spawned task
760 let accept_version = options.accept_version.unwrap_or_else(|| "1.2".to_string());
761 let host = options.host.unwrap_or_else(|| "/".to_string());
762 let client_id = options.client_id;
763 let custom_headers = options.headers;
764 let heartbeat_notify_tx = options.heartbeat_tx;
765
766 // Perform initial connection and STOMP handshake before spawning background task.
767 // This ensures authentication errors are returned to the caller immediately.
768 let stream = TcpStream::connect(&addr).await?;
769 let mut framed = Framed::new(stream, StompCodec::new());
770
771 // Build and send CONNECT frame
772 let connect = Self::build_connect_frame(
773 &accept_version,
774 &host,
775 &login,
776 &passcode,
777 &client_hb,
778 &client_id,
779 &custom_headers,
780 );
781
782 framed
783 .send(StompItem::Frame(connect))
784 .await
785 .map_err(|e| ConnError::Io(std::io::Error::other(e)))?;
786
787 // Wait for CONNECTED or ERROR response
788 let server_heartbeat = Self::await_connected_response(&mut framed).await?;
789
790 // Calculate heartbeat intervals
791 let (cx, cy) = parse_heartbeat_header(&client_hb);
792 let (sx, sy) = parse_heartbeat_header(&server_heartbeat);
793 let (send_interval, recv_interval) = negotiate_heartbeats(cx, cy, sx, sy);
794
795 // Now spawn background task for ongoing I/O and reconnection
796 let shutdown_tx_clone = shutdown_tx.clone();
797 let subscriptions_clone = subscriptions.clone();
798
799 tokio::spawn(async move {
800 let mut backoff_secs: u64 = 1;
801
802 // Use the already-established connection for the first iteration
803 let mut current_framed = Some(framed);
804 let mut current_send_interval = send_interval;
805 let mut current_recv_interval = recv_interval;
806
807 // Track subscription errors across reconnections. If a subscription
808 // receives too many consecutive errors, we remove it to prevent
809 // error loops (e.g., Artemis sending repeated permission errors).
810 let mut subscription_errors: HashMap<String, u32> = HashMap::new();
811 // Track subscription IDs that have been abandoned so we can ignore
812 // subsequent errors for them.
813 let mut abandoned_sub_ids: std::collections::HashSet<String> =
814 std::collections::HashSet::new();
815 const SUBSCRIPTION_ERROR_THRESHOLD: u32 = 3;
816
817 loop {
818 let mut shutdown_sub = shutdown_tx_clone.subscribe();
819
820 // Check for shutdown before attempting connection
821 tokio::select! {
822 biased;
823 _ = shutdown_sub.recv() => break,
824 _ = future::ready(()) => {},
825 }
826
827 // Either use existing connection or establish new one (reconnect)
828 let framed = if let Some(f) = current_framed.take() {
829 f
830 } else {
831 // Reconnection attempt
832 match TcpStream::connect(&addr).await {
833 Ok(stream) => {
834 let mut framed = Framed::new(stream, StompCodec::new());
835
836 let connect = Self::build_connect_frame(
837 &accept_version,
838 &host,
839 &login,
840 &passcode,
841 &client_hb,
842 &client_id,
843 &custom_headers,
844 );
845
846 if framed.send(StompItem::Frame(connect)).await.is_err() {
847 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
848 backoff_secs = (backoff_secs * 2).min(30);
849 continue;
850 }
851
852 // Wait for CONNECTED (on reconnect, silently retry on ERROR)
853 match Self::await_connected_response(&mut framed).await {
854 Ok(server_hb) => {
855 let (cx, cy) = parse_heartbeat_header(&client_hb);
856 let (sx, sy) = parse_heartbeat_header(&server_hb);
857 let (si, ri) = negotiate_heartbeats(cx, cy, sx, sy);
858 current_send_interval = si;
859 current_recv_interval = ri;
860 framed
861 }
862 Err(_) => {
863 // Reconnect failed (auth error or other), retry with backoff
864 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
865 backoff_secs = (backoff_secs * 2).min(30);
866 continue;
867 }
868 }
869 }
870 Err(_) => {
871 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
872 backoff_secs = (backoff_secs * 2).min(30);
873 continue;
874 }
875 }
876 };
877
878 let (send_interval, recv_interval) = (current_send_interval, current_recv_interval);
879
880 let last_received = Arc::new(AtomicU64::new(current_millis()));
881 let writer_last_sent = Arc::new(AtomicU64::new(current_millis()));
882
883 let (mut sink, mut stream) = framed.split();
884 let in_tx = in_tx.clone();
885 let subscriptions = subscriptions_clone.clone();
886
887 // Clear pending message map on reconnect — messages that were
888 // outstanding before the disconnect are considered lost and
889 // will be redelivered by the server as appropriate.
890 {
891 let mut p = pending_clone.lock().await;
892 p.clear();
893 }
894
895 // Resubscribe any existing subscriptions after reconnect.
896 // We snapshot the subscription entries while holding the lock
897 // and then issue SUBSCRIBE frames using the sink.
898 let subs_snapshot: Vec<ResubEntry> = {
899 let map = subscriptions.lock().await;
900 let mut v: Vec<ResubEntry> = Vec::new();
901 for (dest, vec) in map.iter() {
902 for entry in vec.iter() {
903 v.push((
904 dest.clone(),
905 entry.id.clone(),
906 entry.ack.clone(),
907 entry.headers.clone(),
908 ));
909 }
910 }
911 v
912 };
913
914 for (dest, id, ack, headers) in subs_snapshot {
915 let mut sf = Frame::new("SUBSCRIBE");
916 sf = sf
917 .header("id", &id)
918 .header("destination", &dest)
919 .header("ack", &ack);
920 for (k, v) in headers {
921 sf = sf.header(&k, &v);
922 }
923 let _ = sink.send(StompItem::Frame(sf)).await;
924 }
925
926 let mut hb_tick = match send_interval {
927 Some(d) => tokio::time::interval(d),
928 None => tokio::time::interval(Duration::from_secs(86400)),
929 };
930 let watchdog_half = recv_interval.map(|d| d / 2);
931
932 let conn_start = tokio::time::Instant::now();
933
934 'conn: loop {
935 tokio::select! {
936 _ = shutdown_sub.recv() => { let _ = sink.close().await; break 'conn; }
937 maybe = out_rx.recv() => {
938 match maybe {
939 Some(item) => if sink.send(item).await.is_err() { break 'conn } else { writer_last_sent.store(current_millis(), Ordering::SeqCst); }
940 None => break 'conn,
941 }
942 }
943 item = stream.next() => {
944 match item {
945 Some(Ok(StompItem::Heartbeat)) => {
946 last_received.store(current_millis(), Ordering::SeqCst);
947 if let Some(ref tx) = heartbeat_notify_tx {
948 let _ = tx.try_send(());
949 }
950 }
951 Some(Ok(StompItem::Frame(f))) => {
952 last_received.store(current_millis(), Ordering::SeqCst);
953 // Dispatch MESSAGE frames to any matching subscribers.
954 if f.command == "MESSAGE" {
955 // try to find destination, subscription and message-id headers
956 let mut dest_opt: Option<String> = None;
957 let mut sub_opt: Option<String> = None;
958 let mut msg_id_opt: Option<String> = None;
959 for (k, v) in &f.headers {
960 let kl = k.to_lowercase();
961 if kl == "destination" {
962 dest_opt = Some(v.clone());
963 } else if kl == "subscription" {
964 sub_opt = Some(v.clone());
965 } else if kl == "message-id" {
966 msg_id_opt = Some(v.clone());
967 }
968 }
969
970 // Determine whether we need to track this message as pending
971 let mut need_pending = false;
972 if let Some(sub_id) = &sub_opt {
973 let map = subscriptions.lock().await;
974 for (_dest, vec) in map.iter() {
975 for entry in vec.iter() {
976 if &entry.id == sub_id && entry.ack != "auto" {
977 need_pending = true;
978 }
979 }
980 }
981 } else if let Some(dest) = &dest_opt {
982 let map = subscriptions.lock().await;
983 if let Some(vec) = map.get(dest) {
984 for entry in vec.iter() {
985 if entry.ack != "auto" {
986 need_pending = true;
987 break;
988 }
989 }
990 }
991 }
992
993 // If required, add to pending map (per-subscription) before
994 // delivery so ACK/NACK requests from the application can
995 // reference the message. We require a `message-id` header
996 // to track messages; if missing, we cannot support ACK/NACK.
997 if let Some(msg_id) = msg_id_opt.clone().filter(|_| need_pending) {
998 // If the server provided a subscription id in the
999 // MESSAGE, store pending under that subscription.
1000 if let Some(sub_id) = &sub_opt {
1001 let mut p = pending_clone.lock().await;
1002 let q = p
1003 .entry(sub_id.clone())
1004 .or_insert_with(VecDeque::new);
1005 q.push_back((msg_id.clone(), f.clone()));
1006 } else if let Some(dest) = &dest_opt {
1007 // Destination-based delivery: add the message to
1008 // the pending queue for each matching
1009 // subscription on that destination.
1010 let map = subscriptions.lock().await;
1011 if let Some(vec) = map.get(dest) {
1012 let mut p = pending_clone.lock().await;
1013 for entry in vec.iter() {
1014 let q = p
1015 .entry(entry.id.clone())
1016 .or_insert_with(VecDeque::new);
1017 q.push_back((msg_id.clone(), f.clone()));
1018 }
1019 }
1020 }
1021 }
1022
1023 // Deliver to subscribers.
1024 if let Some(sub_id) = sub_opt {
1025 let mut map = subscriptions.lock().await;
1026 for (_dest, vec) in map.iter_mut() {
1027 vec.retain(|entry| {
1028 if entry.id == sub_id {
1029 let _ = entry.sender.try_send(f.clone());
1030 true
1031 } else {
1032 true
1033 }
1034 });
1035 }
1036 } else if let Some(dest) = dest_opt {
1037 let mut map = subscriptions.lock().await;
1038 if let Some(vec) = map.get_mut(&dest) {
1039 vec.retain(|entry| entry.sender.try_send(f.clone()).is_ok());
1040 }
1041 }
1042 } else if f.command == "RECEIPT" {
1043 // Handle RECEIPT frame: notify any waiting callers
1044 if let Some(receipt_id) = f.get_header("receipt-id") {
1045 let mut receipts = pending_receipts_clone.lock().await;
1046 if let Some(sender) = receipts.remove(receipt_id) {
1047 let _ = sender.send(());
1048 }
1049 }
1050 // Don't forward RECEIPT frames to inbound channel
1051 continue;
1052 } else if f.command == "ERROR" {
1053 // Track subscription-related errors. If we see repeated
1054 // errors for the same destination, remove the subscription
1055 // to prevent error loops.
1056 //
1057 // First, check if this error is for an already-abandoned
1058 // subscription (Artemis keeps sending errors after we abandon).
1059 let sub_id = extract_subscription_id_from_error(&f);
1060 if let Some(ref id) = sub_id
1061 && abandoned_sub_ids.contains(id)
1062 {
1063 // Skip this error - subscription already abandoned
1064 continue;
1065 }
1066
1067 // Try to identify the destination:
1068 // 1. Extract directly from ERROR frame
1069 // 2. Look up by subscription ID (Artemis uses "subscription N")
1070 let dest = if let Some(d) = extract_destination_from_error(&f)
1071 {
1072 Some(d)
1073 } else if let Some(ref id) = sub_id {
1074 lookup_destination_by_sub_id(id, &subscriptions).await
1075 } else {
1076 None
1077 };
1078
1079 if let Some(dest) = dest {
1080 let count = {
1081 let c = subscription_errors
1082 .entry(dest.clone())
1083 .or_insert(0);
1084 *c += 1;
1085 *c
1086 };
1087
1088 if count >= SUBSCRIPTION_ERROR_THRESHOLD {
1089 // Remove the subscription from auto-resubscribe
1090 let mut map = subscriptions.lock().await;
1091 if map.remove(&dest).is_some() {
1092 // Track the subscription ID as abandoned
1093 if let Some(id) = sub_id {
1094 abandoned_sub_ids.insert(id);
1095 }
1096 // Send abandonment notification
1097 let msg = format!(
1098 "Subscription abandoned: {} errors for {}",
1099 count, dest
1100 );
1101 let abandon_frame = Frame::new("ERROR")
1102 .header("message", &msg)
1103 .header("destination", &dest)
1104 .header("x-abandoned", "true");
1105 let _ = in_tx.send(abandon_frame).await;
1106 }
1107 }
1108 }
1109 }
1110
1111 let _ = in_tx.send(f).await;
1112 }
1113 Some(Err(_)) | None => break 'conn,
1114 }
1115 }
1116 _ = hb_tick.tick() => {
1117 if let Some(dur) = send_interval {
1118 let last = writer_last_sent.load(Ordering::SeqCst);
1119 if current_millis().saturating_sub(last) >= dur.as_millis() as u64 {
1120 if sink.send(StompItem::Heartbeat).await.is_err() { break 'conn; }
1121 writer_last_sent.store(current_millis(), Ordering::SeqCst);
1122 }
1123 }
1124 }
1125 _ = async { if let Some(interval) = watchdog_half { tokio::time::sleep(interval).await } else { future::pending::<()>().await } } => {
1126 if let Some(recv_dur) = recv_interval {
1127 let last = last_received.load(Ordering::SeqCst);
1128 if current_millis().saturating_sub(last) > (recv_dur.as_millis() as u64 * 2) {
1129 let _ = sink.close().await; break 'conn;
1130 }
1131 }
1132 }
1133 }
1134 }
1135
1136 if shutdown_sub.try_recv().is_ok() {
1137 break;
1138 }
1139 if conn_start.elapsed() >= Duration::from_secs(backoff_secs.max(5)) {
1140 // Connection was stable — reset backoff
1141 backoff_secs = 1;
1142 } else {
1143 // Connection died quickly — increase backoff
1144 backoff_secs = (backoff_secs * 2).min(30);
1145 }
1146 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
1147 }
1148 });
1149
1150 Ok(Connection {
1151 outbound_tx: out_tx,
1152 inbound_rx: Arc::new(Mutex::new(in_rx)),
1153 shutdown_tx,
1154 subscriptions,
1155 sub_id_counter,
1156 pending,
1157 pending_receipts,
1158 })
1159 }
1160
1161 /// Build a CONNECT frame with all specified headers.
1162 fn build_connect_frame(
1163 accept_version: &str,
1164 host: &str,
1165 login: &str,
1166 passcode: &str,
1167 heartbeat: &str,
1168 client_id: &Option<String>,
1169 custom_headers: &[(String, String)],
1170 ) -> Frame {
1171 let mut connect = Frame::new("CONNECT")
1172 .header("accept-version", accept_version)
1173 .header("host", host)
1174 .header("login", login)
1175 .header("passcode", passcode)
1176 .header("heart-beat", heartbeat);
1177
1178 if let Some(id) = client_id {
1179 connect = connect.header("client-id", id);
1180 }
1181
1182 // Reserved headers that custom_headers cannot override
1183 let reserved = [
1184 "accept-version",
1185 "host",
1186 "login",
1187 "passcode",
1188 "heart-beat",
1189 "client-id",
1190 ];
1191
1192 for (k, v) in custom_headers {
1193 if !reserved.contains(&k.to_lowercase().as_str()) {
1194 connect = connect.header(k, v);
1195 }
1196 }
1197
1198 connect
1199 }
1200
1201 /// Wait for CONNECTED or ERROR response from the server.
1202 ///
1203 /// Returns the server's heartbeat header value on success, or an error
1204 /// if the server sends an ERROR frame or closes the connection.
1205 async fn await_connected_response(
1206 framed: &mut Framed<TcpStream, StompCodec>,
1207 ) -> Result<String, ConnError> {
1208 loop {
1209 match framed.next().await {
1210 Some(Ok(StompItem::Frame(f))) => {
1211 if f.command == "CONNECTED" {
1212 // Extract heartbeat from server
1213 let server_hb = f.get_header("heart-beat").unwrap_or("0,0").to_string();
1214 return Ok(server_hb);
1215 } else if f.command == "ERROR" {
1216 // Server rejected connection (e.g., invalid credentials)
1217 return Err(ConnError::ServerRejected(ServerError::from_frame(f)));
1218 }
1219 // Ignore other frames during CONNECT phase
1220 }
1221 Some(Ok(StompItem::Heartbeat)) => {
1222 // Ignore heartbeats during handshake
1223 continue;
1224 }
1225 Some(Err(e)) => {
1226 return Err(ConnError::Io(e));
1227 }
1228 None => {
1229 return Err(ConnError::Protocol(
1230 "connection closed before CONNECTED received".to_string(),
1231 ));
1232 }
1233 }
1234 }
1235 }
1236
1237 pub async fn send_frame(&self, frame: Frame) -> Result<(), ConnError> {
1238 // Send a frame to the background writer task.
1239 //
1240 // Parameters
1241 // - `frame`: ownership of the `Frame` to send. The frame is converted
1242 // into a `StompItem::Frame` and sent over the internal mpsc channel.
1243 self.outbound_tx
1244 .send(StompItem::Frame(frame))
1245 .await
1246 .map_err(|_| ConnError::Protocol("send channel closed".into()))
1247 }
1248
/// Produce a receipt id of the form `rcpt-<n>`.
///
/// `n` comes from a single process-wide counter that starts at 1, so ids are
/// unique across every `Connection` in the process.
fn generate_receipt_id() -> String {
    static NEXT_RECEIPT: AtomicU64 = AtomicU64::new(1);
    let n = NEXT_RECEIPT.fetch_add(1, Ordering::SeqCst);
    format!("rcpt-{n}")
}
1254
1255 /// Send a frame with a receipt request and return the receipt ID.
1256 ///
1257 /// This method adds a unique `receipt` header to the frame and registers
1258 /// the receipt ID for tracking. Use `wait_for_receipt()` to wait for the
1259 /// server's RECEIPT response.
1260 ///
1261 /// # Parameters
1262 /// - `frame`: the frame to send. A `receipt` header will be added.
1263 ///
1264 /// # Returns
1265 /// The generated receipt ID that can be used with `wait_for_receipt()`.
1266 ///
1267 /// # Example
1268 /// ```ignore
1269 /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
1270 /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
1271 /// ```
1272 pub async fn send_frame_with_receipt(&self, frame: Frame) -> Result<String, ConnError> {
1273 let receipt_id = Self::generate_receipt_id();
1274
1275 // Create the oneshot channel for notification
1276 let (tx, _rx) = oneshot::channel();
1277
1278 // Register the pending receipt
1279 {
1280 let mut receipts = self.pending_receipts.lock().await;
1281 receipts.insert(receipt_id.clone(), tx);
1282 }
1283
1284 // Add receipt header and send the frame
1285 let frame_with_receipt = frame.receipt(&receipt_id);
1286 self.send_frame(frame_with_receipt).await?;
1287
1288 Ok(receipt_id)
1289 }
1290
1291 /// Wait for a receipt confirmation from the server.
1292 ///
1293 /// This method blocks until the server sends a RECEIPT frame with the
1294 /// matching receipt-id, or until the timeout expires.
1295 ///
1296 /// # Parameters
1297 /// - `receipt_id`: the receipt ID returned by `send_frame_with_receipt()`.
1298 /// - `timeout`: maximum time to wait for the receipt.
1299 ///
1300 /// # Returns
1301 /// `Ok(())` if the receipt was received, or `Err(ConnError::ReceiptTimeout)`
1302 /// if the timeout expired.
1303 ///
1304 /// # Example
1305 /// ```ignore
1306 /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
1307 /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
1308 /// println!("Message confirmed!");
1309 /// ```
1310 pub async fn wait_for_receipt(
1311 &self,
1312 receipt_id: &str,
1313 timeout: Duration,
1314 ) -> Result<(), ConnError> {
1315 // Get the receiver for this receipt
1316 let rx = {
1317 let mut receipts = self.pending_receipts.lock().await;
1318 // Re-create the oneshot channel and swap out the sender
1319 let (tx, rx) = oneshot::channel();
1320 if let Some(old_tx) = receipts.insert(receipt_id.to_string(), tx) {
1321 // Drop the old sender - this is expected if called after send_frame_with_receipt
1322 drop(old_tx);
1323 }
1324 rx
1325 };
1326
1327 // Wait for the receipt with timeout
1328 match tokio::time::timeout(timeout, rx).await {
1329 Ok(Ok(())) => Ok(()),
1330 Ok(Err(_)) => {
1331 // Channel was closed without receiving - connection likely dropped
1332 Err(ConnError::Protocol(
1333 "receipt channel closed unexpectedly".into(),
1334 ))
1335 }
1336 Err(_) => {
1337 // Timeout expired - clean up the pending receipt
1338 let mut receipts = self.pending_receipts.lock().await;
1339 receipts.remove(receipt_id);
1340 Err(ConnError::ReceiptTimeout(receipt_id.to_string()))
1341 }
1342 }
1343 }
1344
1345 /// Send a frame and wait for server confirmation via RECEIPT.
1346 ///
1347 /// This is a convenience method that combines `send_frame_with_receipt()`
1348 /// and `wait_for_receipt()`. Use this when you want to ensure a frame
1349 /// was processed by the server before continuing.
1350 ///
1351 /// # Parameters
1352 /// - `frame`: the frame to send.
1353 /// - `timeout`: maximum time to wait for the receipt.
1354 ///
1355 /// # Returns
1356 /// `Ok(())` if the frame was sent and receipt confirmed, or an error if
1357 /// sending failed or the receipt timed out.
1358 ///
1359 /// # Example
1360 /// ```ignore
1361 /// let frame = Frame::new("SEND")
1362 /// .header("destination", "/queue/orders")
1363 /// .set_body(b"order data".to_vec());
1364 ///
1365 /// conn.send_frame_confirmed(frame, Duration::from_secs(5)).await?;
1366 /// println!("Order sent and confirmed!");
1367 /// ```
1368 pub async fn send_frame_confirmed(
1369 &self,
1370 frame: Frame,
1371 timeout: Duration,
1372 ) -> Result<(), ConnError> {
1373 let receipt_id = Self::generate_receipt_id();
1374
1375 // Create the oneshot channel for notification
1376 let (tx, rx) = oneshot::channel();
1377
1378 // Register the pending receipt before sending
1379 {
1380 let mut receipts = self.pending_receipts.lock().await;
1381 receipts.insert(receipt_id.clone(), tx);
1382 }
1383
1384 // Add receipt header and send the frame
1385 let frame_with_receipt = frame.receipt(&receipt_id);
1386 self.send_frame(frame_with_receipt).await?;
1387
1388 // Wait for the receipt with timeout
1389 match tokio::time::timeout(timeout, rx).await {
1390 Ok(Ok(())) => Ok(()),
1391 Ok(Err(_)) => Err(ConnError::Protocol(
1392 "receipt channel closed unexpectedly".into(),
1393 )),
1394 Err(_) => {
1395 // Timeout expired - clean up
1396 let mut receipts = self.pending_receipts.lock().await;
1397 receipts.remove(&receipt_id);
1398 Err(ConnError::ReceiptTimeout(receipt_id))
1399 }
1400 }
1401 }
1402
1403 /// Subscribe to a destination.
1404 ///
1405 /// Parameters
1406 /// - `destination`: the STOMP destination to subscribe to (e.g. "/queue/foo").
1407 /// - `ack`: acknowledgement mode to request from the server.
1408 ///
1409 /// Returns a tuple `(subscription_id, receiver)` where `subscription_id` is
1410 /// the opaque id assigned locally for this subscription and `receiver` is a
1411 /// `mpsc::Receiver<Frame>` which will yield incoming MESSAGE frames for the
1412 /// destination. The caller should read from the receiver to handle messages.
1413 /// Subscribe to a destination using optional extra headers.
1414 ///
1415 /// This variant accepts additional headers which are stored locally and
1416 /// re-sent on reconnect. Use `subscribe` as a convenience wrapper when no
1417 /// extra headers are needed.
1418 pub async fn subscribe_with_headers(
1419 &self,
1420 destination: &str,
1421 ack: AckMode,
1422 extra_headers: Vec<(String, String)>,
1423 ) -> Result<crate::subscription::Subscription, ConnError> {
1424 let id = self
1425 .sub_id_counter
1426 .fetch_add(1, Ordering::SeqCst)
1427 .to_string();
1428 let (tx, rx) = mpsc::channel::<Frame>(16);
1429 {
1430 let mut map = self.subscriptions.lock().await;
1431 map.entry(destination.to_string())
1432 .or_insert_with(Vec::new)
1433 .push(SubscriptionEntry {
1434 id: id.clone(),
1435 sender: tx.clone(),
1436 ack: ack.as_str().to_string(),
1437 headers: extra_headers.clone(),
1438 });
1439 }
1440
1441 let mut f = Frame::new("SUBSCRIBE");
1442 f = f
1443 .header("id", &id)
1444 .header("destination", destination)
1445 .header("ack", ack.as_str());
1446 for (k, v) in &extra_headers {
1447 f = f.header(k, v);
1448 }
1449 self.outbound_tx
1450 .send(StompItem::Frame(f))
1451 .await
1452 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1453
1454 Ok(crate::subscription::Subscription::new(
1455 id,
1456 destination.to_string(),
1457 rx,
1458 self.clone(),
1459 ))
1460 }
1461
1462 /// Convenience wrapper without extra headers.
1463 pub async fn subscribe(
1464 &self,
1465 destination: &str,
1466 ack: AckMode,
1467 ) -> Result<crate::subscription::Subscription, ConnError> {
1468 self.subscribe_with_headers(destination, ack, Vec::new())
1469 .await
1470 }
1471
1472 /// Subscribe with a typed `SubscriptionOptions` structure.
1473 ///
1474 /// `SubscriptionOptions.headers` are forwarded to the broker and persisted
1475 /// for automatic resubscribe after reconnect. If `durable_queue` is set,
1476 /// it will be used as the actual destination instead of `destination`.
1477 pub async fn subscribe_with_options(
1478 &self,
1479 destination: &str,
1480 ack: AckMode,
1481 options: crate::subscription::SubscriptionOptions,
1482 ) -> Result<crate::subscription::Subscription, ConnError> {
1483 let dest = options
1484 .durable_queue
1485 .as_deref()
1486 .unwrap_or(destination)
1487 .to_string();
1488 self.subscribe_with_headers(&dest, ack, options.headers)
1489 .await
1490 }
1491
1492 /// Unsubscribe a previously created subscription by its local subscription id.
1493 pub async fn unsubscribe(&self, subscription_id: &str) -> Result<(), ConnError> {
1494 let mut found = false;
1495 {
1496 let mut map = self.subscriptions.lock().await;
1497 let mut remove_keys: Vec<String> = Vec::new();
1498 for (dest, vec) in map.iter_mut() {
1499 if let Some(pos) = vec.iter().position(|entry| entry.id == subscription_id) {
1500 vec.remove(pos);
1501 found = true;
1502 }
1503 if vec.is_empty() {
1504 remove_keys.push(dest.clone());
1505 }
1506 }
1507 for k in remove_keys {
1508 map.remove(&k);
1509 }
1510 }
1511
1512 if !found {
1513 return Err(ConnError::Protocol("subscription id not found".into()));
1514 }
1515
1516 let mut f = Frame::new("UNSUBSCRIBE");
1517 f = f.header("id", subscription_id);
1518 self.outbound_tx
1519 .send(StompItem::Frame(f))
1520 .await
1521 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1522
1523 Ok(())
1524 }
1525
1526 /// Acknowledge a message previously received in `client` or
1527 /// `client-individual` ack modes.
1528 ///
1529 /// STOMP ack semantics:
1530 /// - `auto`: server considers message delivered immediately; the client
1531 /// should not ack.
1532 /// - `client`: cumulative acknowledgements. ACKing message `M` for
1533 /// subscription `S` acknowledges all messages delivered to `S` up to
1534 /// and including `M`.
1535 /// - `client-individual`: only the named message is acknowledged.
1536 ///
1537 /// Parameters
1538 /// - `subscription_id`: the local subscription id returned by
1539 /// `Connection::subscribe`. This disambiguates which subscription's
1540 /// pending queue to advance for cumulative ACKs.
1541 /// - `message_id`: the `message-id` header value from the received
1542 /// MESSAGE frame to acknowledge.
1543 ///
1544 /// Behavior
1545 /// - The pending queue for `subscription_id` is searched for `message_id`.
1546 /// If the subscription used `client` ack mode, all pending messages up to
1547 /// and including the matched message are removed. If the subscription
1548 /// used `client-individual`, only the matched message is removed.
1549 /// - An `ACK` frame is sent to the server with `id=<message_id>` and
1550 /// `subscription=<subscription_id>` headers.
1551 #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1552 pub async fn ack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1553 // Remove from the local pending queue according to subscription ack mode.
1554 let mut removed_any = false;
1555 {
1556 let mut p = self.pending.lock().await;
1557 if let Some(queue) = p.get_mut(subscription_id) {
1558 if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1559 // Determine ack mode for this subscription (default to client).
1560 let mut ack_mode = "client".to_string();
1561 {
1562 let map = self.subscriptions.lock().await;
1563 'outer: for (_dest, vec) in map.iter() {
1564 for entry in vec.iter() {
1565 if entry.id == subscription_id {
1566 ack_mode = entry.ack.clone();
1567 break 'outer;
1568 }
1569 }
1570 }
1571 }
1572
1573 if ack_mode == "client" {
1574 // cumulative: remove up to and including pos
1575 for _ in 0..=pos {
1576 queue.pop_front();
1577 removed_any = true;
1578 }
1579 } else if queue.remove(pos).is_some() {
1580 // client-individual: remove only the specific message
1581 removed_any = true;
1582 }
1583
1584 if queue.is_empty() {
1585 p.remove(subscription_id);
1586 }
1587 }
1588 }
1589 }
1590
1591 // Send ACK to server (include subscription header for clarity)
1592 let mut f = Frame::new("ACK");
1593 f = f
1594 .header("id", message_id)
1595 .header("subscription", subscription_id);
1596 self.outbound_tx
1597 .send(StompItem::Frame(f))
1598 .await
1599 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1600
1601 // If message wasn't found locally, still send ACK to server; server
1602 // may ignore or treat it as no-op.
1603 let _ = removed_any;
1604 Ok(())
1605 }
1606
1607 /// Negative-acknowledge a message (NACK).
1608 ///
1609 /// Parameters
1610 /// - `subscription_id`: the local subscription id the message was delivered under.
1611 /// - `message_id`: the `message-id` header value from the received MESSAGE.
1612 ///
1613 /// Behavior
1614 /// - Removes the message from the local pending queue (cumulatively if the
1615 /// subscription used `client` ack mode, otherwise only the single
1616 /// message). Sends a `NACK` frame to the server with `id` and
1617 /// `subscription` headers.
1618 #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1619 pub async fn nack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1620 // Mirror ack removal semantics for pending map.
1621 let mut removed_any = false;
1622 {
1623 let mut p = self.pending.lock().await;
1624 if let Some(queue) = p.get_mut(subscription_id) {
1625 if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1626 let mut ack_mode = "client".to_string();
1627 {
1628 let map = self.subscriptions.lock().await;
1629 'outer2: for (_dest, vec) in map.iter() {
1630 for entry in vec.iter() {
1631 if entry.id == subscription_id {
1632 ack_mode = entry.ack.clone();
1633 break 'outer2;
1634 }
1635 }
1636 }
1637 }
1638
1639 if ack_mode == "client" {
1640 for _ in 0..=pos {
1641 queue.pop_front();
1642 removed_any = true;
1643 }
1644 } else if queue.remove(pos).is_some() {
1645 removed_any = true;
1646 }
1647
1648 if queue.is_empty() {
1649 p.remove(subscription_id);
1650 }
1651 }
1652 }
1653 }
1654
1655 let mut f = Frame::new("NACK");
1656 f = f
1657 .header("id", message_id)
1658 .header("subscription", subscription_id);
1659 self.outbound_tx
1660 .send(StompItem::Frame(f))
1661 .await
1662 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1663
1664 let _ = removed_any;
1665 Ok(())
1666 }
1667
1668 /// Helper to send a transaction frame (BEGIN, COMMIT, or ABORT).
1669 async fn send_transaction_frame(
1670 &self,
1671 command: &str,
1672 transaction_id: &str,
1673 ) -> Result<(), ConnError> {
1674 let f = Frame::new(command).header("transaction", transaction_id);
1675 self.outbound_tx
1676 .send(StompItem::Frame(f))
1677 .await
1678 .map_err(|_| ConnError::Protocol("send channel closed".into()))
1679 }
1680
1681 /// Begin a transaction.
1682 ///
1683 /// Parameters
1684 /// - `transaction_id`: unique identifier for the transaction. The caller is
1685 /// responsible for ensuring uniqueness within the connection.
1686 ///
1687 /// Behavior
1688 /// - Sends a `BEGIN` frame to the server with `transaction:<transaction_id>`
1689 /// header. Subsequent `SEND`, `ACK`, and `NACK` frames may include this
1690 /// transaction id to group them into the transaction. The transaction must
1691 /// be finalized with either `commit` or `abort`.
1692 pub async fn begin(&self, transaction_id: &str) -> Result<(), ConnError> {
1693 self.send_transaction_frame("BEGIN", transaction_id).await
1694 }
1695
1696 /// Commit a transaction.
1697 ///
1698 /// Parameters
1699 /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1700 ///
1701 /// Behavior
1702 /// - Sends a `COMMIT` frame to the server with `transaction:<transaction_id>`
1703 /// header. All operations within the transaction are applied atomically.
1704 pub async fn commit(&self, transaction_id: &str) -> Result<(), ConnError> {
1705 self.send_transaction_frame("COMMIT", transaction_id).await
1706 }
1707
1708 /// Abort a transaction.
1709 ///
1710 /// Parameters
1711 /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1712 ///
1713 /// Behavior
1714 /// - Sends an `ABORT` frame to the server with `transaction:<transaction_id>`
1715 /// header. All operations within the transaction are discarded.
1716 pub async fn abort(&self, transaction_id: &str) -> Result<(), ConnError> {
1717 self.send_transaction_frame("ABORT", transaction_id).await
1718 }
1719
1720 /// Receive the next frame from the server.
1721 ///
1722 /// Returns `Some(ReceivedFrame::Frame(..))` for normal frames (MESSAGE, etc.),
1723 /// `Some(ReceivedFrame::Error(..))` for ERROR frames, or `None` if the
1724 /// connection has been closed.
1725 ///
1726 /// # Example
1727 ///
1728 /// ```ignore
1729 /// use iridium_stomp::ReceivedFrame;
1730 ///
1731 /// while let Some(received) = conn.next_frame().await {
1732 /// match received {
1733 /// ReceivedFrame::Frame(frame) => {
1734 /// println!("Got {}: {:?}", frame.command, frame.body);
1735 /// }
1736 /// ReceivedFrame::Error(err) => {
1737 /// eprintln!("Server error: {}", err);
1738 /// break;
1739 /// }
1740 /// }
1741 /// }
1742 /// ```
1743 pub async fn next_frame(&self) -> Option<ReceivedFrame> {
1744 let mut rx = self.inbound_rx.lock().await;
1745 let frame = rx.recv().await?;
1746
1747 // Convert ERROR frames to ServerError for better ergonomics
1748 if frame.command == "ERROR" {
1749 Some(ReceivedFrame::Error(ServerError::from_frame(frame)))
1750 } else {
1751 Some(ReceivedFrame::Frame(frame))
1752 }
1753 }
1754
1755 pub async fn close(self) {
1756 // Signal the background task to shutdown by broadcasting on the
1757 // shutdown channel. Consumers may await task termination separately
1758 // if needed.
1759 let _ = self.shutdown_tx.send(());
1760 }
1761}
1762
/// Milliseconds since the Unix epoch according to the system (wall) clock.
///
/// Returns 0 if the clock reads earlier than the epoch. Heartbeat bookkeeping
/// compares differences of these values. NOTE(review): since this is
/// wall-clock time, a system clock adjustment can make an interval look
/// longer or shorter than it really was.
fn current_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
1770
1771#[cfg(test)]
1772mod tests {
1773 use super::*;
1774 use tokio::sync::mpsc;
1775
1776 // Helper to build a MESSAGE frame with given message-id and subscription/destination headers
1777 fn make_message(
1778 message_id: &str,
1779 subscription: Option<&str>,
1780 destination: Option<&str>,
1781 ) -> Frame {
1782 let mut f = Frame::new("MESSAGE");
1783 f = f.header("message-id", message_id);
1784 if let Some(s) = subscription {
1785 f = f.header("subscription", s);
1786 }
1787 if let Some(d) = destination {
1788 f = f.header("destination", d);
1789 }
1790 f
1791 }
1792
1793 #[tokio::test]
1794 async fn test_cumulative_ack_removes_prefix() {
1795 // setup channels
1796 let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(8);
1797 let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
1798 let (shutdown_tx, _) = broadcast::channel::<()>(1);
1799
1800 let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
1801 let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
1802
1803 let sub_id_counter = Arc::new(AtomicU64::new(1));
1804
1805 // create a subscription entry s1 with client (cumulative) ack
1806 let (sub_sender, _sub_rx) = mpsc::channel::<Frame>(4);
1807 {
1808 let mut map = subscriptions.lock().await;
1809 map.insert(
1810 "/queue/x".to_string(),
1811 vec![SubscriptionEntry {
1812 id: "s1".to_string(),
1813 sender: sub_sender,
1814 ack: "client".to_string(),
1815 headers: Vec::new(),
1816 }],
1817 );
1818 }
1819
1820 // fill pending queue for s1: m1,m2,m3
1821 {
1822 let mut p = pending.lock().await;
1823 let mut q = VecDeque::new();
1824 q.push_back((
1825 "m1".to_string(),
1826 make_message("m1", Some("s1"), Some("/queue/x")),
1827 ));
1828 q.push_back((
1829 "m2".to_string(),
1830 make_message("m2", Some("s1"), Some("/queue/x")),
1831 ));
1832 q.push_back((
1833 "m3".to_string(),
1834 make_message("m3", Some("s1"), Some("/queue/x")),
1835 ));
1836 p.insert("s1".to_string(), q);
1837 }
1838
1839 let conn = Connection {
1840 outbound_tx: out_tx,
1841 inbound_rx: Arc::new(Mutex::new(in_rx)),
1842 shutdown_tx,
1843 subscriptions: subscriptions.clone(),
1844 sub_id_counter,
1845 pending: pending.clone(),
1846 pending_receipts: Arc::new(Mutex::new(HashMap::new())),
1847 };
1848
1849 // ack m2 cumulatively: should remove m1 and m2, leaving m3
1850 conn.ack("s1", "m2").await.expect("ack failed");
1851
1852 // verify pending for s1 contains only m3
1853 {
1854 let p = pending.lock().await;
1855 let q = p.get("s1").expect("missing s1");
1856 assert_eq!(q.len(), 1);
1857 assert_eq!(q.front().unwrap().0, "m3");
1858 }
1859
1860 // verify an ACK frame was emitted
1861 if let Some(item) = out_rx.recv().await {
1862 match item {
1863 StompItem::Frame(f) => assert_eq!(f.command, "ACK"),
1864 _ => panic!("expected frame"),
1865 }
1866 } else {
1867 panic!("no outbound frame sent")
1868 }
1869 }
1870
1871 #[tokio::test]
1872 async fn test_client_individual_ack_removes_only_one() {
1873 // setup channels
1874 let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(8);
1875 let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
1876 let (shutdown_tx, _) = broadcast::channel::<()>(1);
1877
1878 let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
1879 let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
1880
1881 let sub_id_counter = Arc::new(AtomicU64::new(1));
1882
1883 // create a subscription entry s2 with client-individual ack
1884 let (sub_sender, _sub_rx) = mpsc::channel::<Frame>(4);
1885 {
1886 let mut map = subscriptions.lock().await;
1887 map.insert(
1888 "/queue/y".to_string(),
1889 vec![SubscriptionEntry {
1890 id: "s2".to_string(),
1891 sender: sub_sender,
1892 ack: "client-individual".to_string(),
1893 headers: Vec::new(),
1894 }],
1895 );
1896 }
1897
1898 // fill pending queue for s2: a,b,c
1899 {
1900 let mut p = pending.lock().await;
1901 let mut q = VecDeque::new();
1902 q.push_back((
1903 "a".to_string(),
1904 make_message("a", Some("s2"), Some("/queue/y")),
1905 ));
1906 q.push_back((
1907 "b".to_string(),
1908 make_message("b", Some("s2"), Some("/queue/y")),
1909 ));
1910 q.push_back((
1911 "c".to_string(),
1912 make_message("c", Some("s2"), Some("/queue/y")),
1913 ));
1914 p.insert("s2".to_string(), q);
1915 }
1916
1917 let conn = Connection {
1918 outbound_tx: out_tx,
1919 inbound_rx: Arc::new(Mutex::new(in_rx)),
1920 shutdown_tx,
1921 subscriptions: subscriptions.clone(),
1922 sub_id_counter,
1923 pending: pending.clone(),
1924 pending_receipts: Arc::new(Mutex::new(HashMap::new())),
1925 };
1926
1927 // ack only 'b' individually
1928 conn.ack("s2", "b").await.expect("ack failed");
1929
1930 // verify pending for s2 contains a and c
1931 {
1932 let p = pending.lock().await;
1933 let q = p.get("s2").expect("missing s2");
1934 assert_eq!(q.len(), 2);
1935 assert_eq!(q[0].0, "a");
1936 assert_eq!(q[1].0, "c");
1937 }
1938
1939 // verify an ACK frame was emitted
1940 if let Some(item) = out_rx.recv().await {
1941 match item {
1942 StompItem::Frame(f) => assert_eq!(f.command, "ACK"),
1943 _ => panic!("expected frame"),
1944 }
1945 } else {
1946 panic!("no outbound frame sent")
1947 }
1948 }
1949
1950 #[tokio::test]
1951 async fn test_subscription_receive_delivers_message() {
1952 // setup channels
1953 let (out_tx, _out_rx) = mpsc::channel::<StompItem>(8);
1954 let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
1955 let (shutdown_tx, _) = broadcast::channel::<()>(1);
1956
1957 let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
1958 let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
1959
1960 let sub_id_counter = Arc::new(AtomicU64::new(1));
1961
1962 let conn = Connection {
1963 outbound_tx: out_tx,
1964 inbound_rx: Arc::new(Mutex::new(in_rx)),
1965 shutdown_tx,
1966 subscriptions: subscriptions.clone(),
1967 sub_id_counter,
1968 pending: pending.clone(),
1969 pending_receipts: Arc::new(Mutex::new(HashMap::new())),
1970 };
1971
1972 // subscribe
1973 let subscription = conn
1974 .subscribe("/queue/test", AckMode::Auto)
1975 .await
1976 .expect("subscribe failed");
1977
1978 // find the sender stored in the subscriptions map and push a message
1979 {
1980 let map = conn.subscriptions.lock().await;
1981 let vec = map.get("/queue/test").expect("missing subscription vec");
1982 let sender = &vec[0].sender;
1983 let f = make_message("m1", Some(&vec[0].id), Some("/queue/test"));
1984 sender.try_send(f).expect("send to subscription failed");
1985 }
1986
1987 // consume from the subscription receiver
1988 let mut rx = subscription.into_receiver();
1989 if let Some(received) = rx.recv().await {
1990 assert_eq!(received.command, "MESSAGE");
1991 // message-id header should be present
1992 let mut found = false;
1993 for (k, _v) in &received.headers {
1994 if k.to_lowercase() == "message-id" {
1995 found = true;
1996 break;
1997 }
1998 }
1999 assert!(found, "message-id header missing");
2000 } else {
2001 panic!("no message received on subscription")
2002 }
2003 }
2004
2005 #[tokio::test]
2006 async fn test_subscription_ack_removes_pending_and_sends_ack() {
2007 // setup channels
2008 let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(8);
2009 let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
2010 let (shutdown_tx, _) = broadcast::channel::<()>(1);
2011
2012 let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
2013 let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
2014
2015 let sub_id_counter = Arc::new(AtomicU64::new(1));
2016
2017 let conn = Connection {
2018 outbound_tx: out_tx,
2019 inbound_rx: Arc::new(Mutex::new(in_rx)),
2020 shutdown_tx,
2021 subscriptions: subscriptions.clone(),
2022 sub_id_counter,
2023 pending: pending.clone(),
2024 pending_receipts: Arc::new(Mutex::new(HashMap::new())),
2025 };
2026
2027 // subscribe with client ack
2028 let subscription = conn
2029 .subscribe("/queue/ack", AckMode::Client)
2030 .await
2031 .expect("subscribe failed");
2032
2033 let sub_id = subscription.id().to_string();
2034
2035 // drain any initial outbound frames (SUBSCRIBE) emitted by subscribe()
2036 while out_rx.try_recv().is_ok() {}
2037
2038 // populate pending queue for this subscription
2039 {
2040 let mut p = conn.pending.lock().await;
2041 let mut q = VecDeque::new();
2042 q.push_back((
2043 "mid-1".to_string(),
2044 make_message("mid-1", Some(&sub_id), Some("/queue/ack")),
2045 ));
2046 p.insert(sub_id.clone(), q);
2047 }
2048
2049 // ack the message via the subscription helper
2050 subscription.ack("mid-1").await.expect("ack failed");
2051
2052 // ensure pending queue no longer contains the message
2053 {
2054 let p = conn.pending.lock().await;
2055 assert!(p.get(&sub_id).is_none() || p.get(&sub_id).unwrap().is_empty());
2056 }
2057
2058 // verify an ACK frame was emitted
2059 if let Some(item) = out_rx.recv().await {
2060 match item {
2061 StompItem::Frame(f) => assert_eq!(f.command, "ACK"),
2062 _ => panic!("expected frame"),
2063 }
2064 } else {
2065 panic!("no outbound frame sent")
2066 }
2067 }
2068
2069 // Helper function to create a test connection and output receiver
2070 fn setup_test_connection() -> (Connection, mpsc::Receiver<StompItem>) {
2071 let (out_tx, out_rx) = mpsc::channel::<StompItem>(8);
2072 let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
2073 let (shutdown_tx, _) = broadcast::channel::<()>(1);
2074
2075 let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
2076 let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
2077 let sub_id_counter = Arc::new(AtomicU64::new(1));
2078
2079 let conn = Connection {
2080 outbound_tx: out_tx,
2081 inbound_rx: Arc::new(Mutex::new(in_rx)),
2082 shutdown_tx,
2083 subscriptions,
2084 sub_id_counter,
2085 pending,
2086 pending_receipts: Arc::new(Mutex::new(HashMap::new())),
2087 };
2088
2089 (conn, out_rx)
2090 }
2091
2092 // Helper function to verify a frame with a transaction header
2093 fn verify_transaction_frame(frame: Frame, expected_command: &str, expected_tx_id: &str) {
2094 assert_eq!(frame.command, expected_command);
2095 assert!(
2096 frame
2097 .headers
2098 .iter()
2099 .any(|(k, v)| k == "transaction" && v == expected_tx_id),
2100 "transaction header with id '{}' not found",
2101 expected_tx_id
2102 );
2103 }
2104
2105 #[tokio::test]
2106 async fn test_begin_transaction_sends_frame() {
2107 let (conn, mut out_rx) = setup_test_connection();
2108
2109 conn.begin("tx1").await.expect("begin failed");
2110
2111 // verify BEGIN frame was emitted
2112 if let Some(StompItem::Frame(f)) = out_rx.recv().await {
2113 verify_transaction_frame(f, "BEGIN", "tx1");
2114 } else {
2115 panic!("no outbound frame sent")
2116 }
2117 }
2118
2119 #[tokio::test]
2120 async fn test_commit_transaction_sends_frame() {
2121 let (conn, mut out_rx) = setup_test_connection();
2122
2123 conn.commit("tx1").await.expect("commit failed");
2124
2125 // verify COMMIT frame was emitted
2126 if let Some(StompItem::Frame(f)) = out_rx.recv().await {
2127 verify_transaction_frame(f, "COMMIT", "tx1");
2128 } else {
2129 panic!("no outbound frame sent")
2130 }
2131 }
2132
2133 #[tokio::test]
2134 async fn test_abort_transaction_sends_frame() {
2135 let (conn, mut out_rx) = setup_test_connection();
2136
2137 conn.abort("tx1").await.expect("abort failed");
2138
2139 // verify ABORT frame was emitted
2140 if let Some(StompItem::Frame(f)) = out_rx.recv().await {
2141 verify_transaction_frame(f, "ABORT", "tx1");
2142 } else {
2143 panic!("no outbound frame sent")
2144 }
2145 }
2146
2147 #[test]
2148 fn test_extract_destination_from_error_header() {
2149 // When ERROR frame has destination header, extract it directly
2150 let frame = Frame::new("ERROR")
2151 .header("message", "AMQ339016: Error creating STOMP subscription")
2152 .header("destination", "/topic/test.restricted");
2153
2154 let dest = extract_destination_from_error(&frame);
2155 assert_eq!(dest, Some("/topic/test.restricted".to_string()));
2156 }
2157
2158 #[test]
2159 fn test_extract_destination_from_error_message() {
2160 // When destination is in message header text
2161 let frame = Frame::new("ERROR").header(
2162 "message",
2163 "AMQ339016: Error creating subscription for /topic/test.restricted",
2164 );
2165
2166 let dest = extract_destination_from_error(&frame);
2167 assert_eq!(dest, Some("/topic/test.restricted".to_string()));
2168 }
2169
2170 #[test]
2171 fn test_extract_destination_from_error_body() {
2172 // When destination is in body text
2173 let frame = Frame::new("ERROR")
2174 .header("message", "AMQ339016: Error creating subscription")
2175 .set_body(b"User guest is not authorized for /queue/orders".to_vec());
2176
2177 let dest = extract_destination_from_error(&frame);
2178 assert_eq!(dest, Some("/queue/orders".to_string()));
2179 }
2180
2181 #[test]
2182 fn test_extract_destination_from_error_none() {
2183 // When no destination can be identified
2184 let frame = Frame::new("ERROR").header("message", "Generic error without destination info");
2185
2186 let dest = extract_destination_from_error(&frame);
2187 assert_eq!(dest, None);
2188 }
2189
2190 #[test]
2191 fn test_extract_destination_from_error_with_trailing_punct() {
2192 // When destination has trailing punctuation
2193 let frame = Frame::new("ERROR").header(
2194 "message",
2195 "Error for /topic/events, please check permissions",
2196 );
2197
2198 let dest = extract_destination_from_error(&frame);
2199 assert_eq!(dest, Some("/topic/events".to_string()));
2200 }
2201
2202 #[test]
2203 fn test_extract_subscription_id_from_error_artemis_format() {
2204 // Artemis format: "AMQ339016 Error creating subscription 1"
2205 let frame =
2206 Frame::new("ERROR").header("message", "AMQ339016 Error creating subscription 1");
2207
2208 let sub_id = extract_subscription_id_from_error(&frame);
2209 assert_eq!(sub_id, Some("1".to_string()));
2210 }
2211
2212 #[test]
2213 fn test_extract_subscription_id_from_error_numeric() {
2214 // Multiple digit subscription ID
2215 let frame = Frame::new("ERROR").header("message", "Error for subscription 123 on server");
2216
2217 let sub_id = extract_subscription_id_from_error(&frame);
2218 assert_eq!(sub_id, Some("123".to_string()));
2219 }
2220
2221 #[test]
2222 fn test_extract_subscription_id_from_error_none() {
2223 // No subscription ID in error
2224 let frame = Frame::new("ERROR").header("message", "Generic connection error");
2225
2226 let sub_id = extract_subscription_id_from_error(&frame);
2227 assert_eq!(sub_id, None);
2228 }
2229
2230 #[tokio::test]
2231 async fn test_lookup_destination_by_sub_id() {
2232 let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
2233 let (sender, _rx) = mpsc::channel::<Frame>(4);
2234
2235 // Add a subscription
2236 {
2237 let mut map = subscriptions.lock().await;
2238 map.insert(
2239 "/topic/test.restricted".to_string(),
2240 vec![SubscriptionEntry {
2241 id: "1".to_string(),
2242 sender,
2243 ack: "auto".to_string(),
2244 headers: Vec::new(),
2245 }],
2246 );
2247 }
2248
2249 // Should find the destination
2250 let dest = lookup_destination_by_sub_id("1", &subscriptions).await;
2251 assert_eq!(dest, Some("/topic/test.restricted".to_string()));
2252
2253 // Should not find non-existent subscription
2254 let dest = lookup_destination_by_sub_id("999", &subscriptions).await;
2255 assert_eq!(dest, None);
2256 }
2257}