iridium_stomp/connection.rs
1use futures::{SinkExt, StreamExt, future};
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Duration;
6use thiserror::Error;
7use tokio::net::TcpStream;
8use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
9use tokio_util::codec::Framed;
10
11use crate::codec::{StompCodec, StompItem};
12use crate::frame::Frame;
13
/// Configuration for STOMP heartbeat intervals.
///
/// Provides a type-safe way to configure heartbeat values instead of using
/// raw strings. The `Display` implementation formats the value as required
/// by the STOMP protocol ("send_ms,receive_ms").
///
/// # Example
///
/// ```
/// use iridium_stomp::Heartbeat;
///
/// // Create a custom heartbeat configuration
/// let hb = Heartbeat::new(5000, 10000);
/// assert_eq!(hb.to_string(), "5000,10000");
///
/// // Use predefined configurations
/// assert_eq!(Heartbeat::disabled().to_string(), "0,0");
/// assert_eq!(Heartbeat::default().to_string(), "10000,10000");
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Heartbeat {
    /// Minimum interval (in milliseconds) between heartbeats the client can send.
    /// A value of 0 means the client cannot send heartbeats.
    pub send_ms: u32,

    /// Minimum interval (in milliseconds) between heartbeats the client wants to receive.
    /// A value of 0 means the client does not want to receive heartbeats.
    pub receive_ms: u32,
}

impl Heartbeat {
    /// Create a new heartbeat configuration with the specified intervals.
    ///
    /// Usable in `const` contexts.
    ///
    /// # Arguments
    ///
    /// * `send_ms` - Minimum interval in milliseconds between heartbeats the client can send.
    /// * `receive_ms` - Minimum interval in milliseconds between heartbeats the client wants to receive.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    ///
    /// let hb = Heartbeat::new(5000, 10000);
    /// assert_eq!(hb.send_ms, 5000);
    /// assert_eq!(hb.receive_ms, 10000);
    /// ```
    #[must_use]
    pub const fn new(send_ms: u32, receive_ms: u32) -> Self {
        Self {
            send_ms,
            receive_ms,
        }
    }

    /// Create a heartbeat configuration that disables heartbeats entirely.
    ///
    /// This is equivalent to `Heartbeat::new(0, 0)`.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    ///
    /// let hb = Heartbeat::disabled();
    /// assert_eq!(hb.send_ms, 0);
    /// assert_eq!(hb.receive_ms, 0);
    /// assert_eq!(hb.to_string(), "0,0");
    /// ```
    #[must_use]
    pub const fn disabled() -> Self {
        Self::new(0, 0)
    }

    /// Create a heartbeat configuration from a `Duration` for symmetric heartbeats.
    ///
    /// Both send and receive intervals will be set to the same value.
    ///
    /// The maximum supported `Duration` is approximately 49.7 days (`u32::MAX`
    /// milliseconds, or 4,294,967,295 ms). If a larger `Duration` is provided,
    /// it is clamped to `u32::MAX` milliseconds to prevent overflow.
    /// Sub-millisecond remainders are truncated.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    /// use std::time::Duration;
    ///
    /// let hb = Heartbeat::from_duration(Duration::from_secs(15));
    /// assert_eq!(hb.send_ms, 15000);
    /// assert_eq!(hb.receive_ms, 15000);
    /// ```
    #[must_use]
    pub fn from_duration(interval: Duration) -> Self {
        // Clamp first: after `min` the value is guaranteed to fit in a `u32`,
        // so the narrowing cast below cannot truncate.
        let clamped = interval.as_millis().min(u128::from(u32::MAX));
        let ms = clamped as u32;
        Self::new(ms, ms)
    }
}

impl Default for Heartbeat {
    /// Returns the default heartbeat configuration: 10 seconds for both send and receive.
    fn default() -> Self {
        Self::new(10_000, 10_000)
    }
}

impl std::fmt::Display for Heartbeat {
    /// Formats the configuration as the `heart-beat` header value
    /// `"<send_ms>,<receive_ms>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{},{}", self.send_ms, self.receive_ms)
    }
}
122
123/// Internal subscription entry stored for each destination.
124#[derive(Clone)]
125pub(crate) struct SubscriptionEntry {
126 pub(crate) id: String,
127 pub(crate) sender: mpsc::Sender<Frame>,
128 pub(crate) ack: String,
129 pub(crate) headers: Vec<(String, String)>,
130}
131
132/// Alias for the subscription dispatch map: destination -> list of
133/// `SubscriptionEntry`.
134pub(crate) type Subscriptions = HashMap<String, Vec<SubscriptionEntry>>;
135
136/// Alias for the pending map: subscription_id -> queue of (message-id, Frame).
137pub(crate) type PendingMap = HashMap<String, VecDeque<(String, Frame)>>;
138
139/// Internal type for resubscribe snapshot entries: (destination, id, ack, headers)
140pub(crate) type ResubEntry = (String, String, String, Vec<(String, String)>);
141
142/// Alias for pending receipt map: receipt-id -> oneshot sender to notify when received.
143pub(crate) type PendingReceipts = HashMap<String, oneshot::Sender<()>>;
144
145/// Errors returned by `Connection` operations.
146#[derive(Error, Debug)]
147pub enum ConnError {
148 /// I/O-level error
149 #[error("io error: {0}")]
150 Io(#[from] std::io::Error),
151 /// Protocol-level error
152 #[error("protocol error: {0}")]
153 Protocol(String),
154 /// Receipt timeout error
155 #[error("receipt timeout: no RECEIPT received for '{0}' within timeout")]
156 ReceiptTimeout(String),
157 /// Server rejected the connection (e.g., authentication failure)
158 ///
159 /// This error is returned when the server sends an ERROR frame in response
160 /// to the CONNECT frame. Common causes include invalid credentials,
161 /// unauthorized access, or broker configuration issues.
162 #[error("server rejected connection: {0}")]
163 ServerRejected(ServerError),
164}
165
166/// Represents an ERROR frame received from the STOMP server.
167///
168/// STOMP servers send ERROR frames to indicate protocol violations, authentication
169/// failures, or other server-side errors. After sending an ERROR frame, the server
170/// typically closes the connection.
171///
172/// # Example
173///
174/// ```ignore
175/// use iridium_stomp::ReceivedFrame;
176///
177/// while let Some(received) = conn.next_frame().await {
178/// match received {
179/// ReceivedFrame::Frame(frame) => {
180/// // Normal message processing
181/// }
182/// ReceivedFrame::Error(err) => {
183/// eprintln!("Server error: {}", err.message);
184/// if let Some(body) = &err.body {
185/// eprintln!("Details: {}", body);
186/// }
187/// break;
188/// }
189/// }
190/// }
191/// ```
192#[derive(Debug, Clone, PartialEq, Eq)]
193pub struct ServerError {
194 /// The error message from the `message` header.
195 pub message: String,
196
197 /// The error body, if present. Contains additional error details.
198 pub body: Option<String>,
199
200 /// The receipt-id if this error is in response to a specific frame.
201 pub receipt_id: Option<String>,
202
203 /// The original ERROR frame for access to additional headers.
204 pub frame: Frame,
205}
206
207impl ServerError {
208 /// Create a `ServerError` from an ERROR frame.
209 ///
210 /// This is primarily used internally but is public for testing and
211 /// advanced use cases where you need to construct a `ServerError` manually.
212 pub fn from_frame(frame: Frame) -> Self {
213 let message = frame
214 .get_header("message")
215 .unwrap_or("unknown error")
216 .to_string();
217
218 let body = if frame.body.is_empty() {
219 None
220 } else {
221 String::from_utf8(frame.body.clone()).ok()
222 };
223
224 let receipt_id = frame.get_header("receipt-id").map(|s| s.to_string());
225
226 Self {
227 message,
228 body,
229 receipt_id,
230 frame,
231 }
232 }
233}
234
235impl std::fmt::Display for ServerError {
236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 write!(f, "STOMP server error: {}", self.message)?;
238 if let Some(body) = &self.body {
239 write!(f, " - {}", body)?;
240 }
241 Ok(())
242 }
243}
244
245impl std::error::Error for ServerError {}
246
247/// The result of receiving a frame from the server.
248///
249/// STOMP servers can send either normal frames (MESSAGE, RECEIPT, etc.) or
250/// ERROR frames indicating a problem. This enum allows callers to handle
251/// both cases with pattern matching.
252///
253/// # Example
254///
255/// ```ignore
256/// use iridium_stomp::ReceivedFrame;
257///
258/// match conn.next_frame().await {
259/// Some(ReceivedFrame::Frame(frame)) => {
260/// println!("Got frame: {}", frame.command);
261/// }
262/// Some(ReceivedFrame::Error(err)) => {
263/// eprintln!("Server error: {}", err);
264/// }
265/// None => {
266/// println!("Connection closed");
267/// }
268/// }
269/// ```
270#[derive(Debug, Clone, PartialEq, Eq)]
271pub enum ReceivedFrame {
272 /// A normal STOMP frame (MESSAGE, RECEIPT, etc.)
273 Frame(Frame),
274 /// An ERROR frame from the server
275 Error(ServerError),
276}
277
278impl ReceivedFrame {
279 /// Returns `true` if this is an error frame.
280 pub fn is_error(&self) -> bool {
281 matches!(self, ReceivedFrame::Error(_))
282 }
283
284 /// Returns `true` if this is a normal frame.
285 pub fn is_frame(&self) -> bool {
286 matches!(self, ReceivedFrame::Frame(_))
287 }
288
289 /// Returns the frame if this is a normal frame, or `None` if it's an error.
290 pub fn into_frame(self) -> Option<Frame> {
291 match self {
292 ReceivedFrame::Frame(f) => Some(f),
293 ReceivedFrame::Error(_) => None,
294 }
295 }
296
297 /// Returns the error if this is an error frame, or `None` if it's a normal frame.
298 pub fn into_error(self) -> Option<ServerError> {
299 match self {
300 ReceivedFrame::Frame(_) => None,
301 ReceivedFrame::Error(e) => Some(e),
302 }
303 }
304}
305
/// Subscription acknowledgement modes as defined by STOMP 1.2.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AckMode {
    /// `ack:auto` — the client does not acknowledge messages.
    Auto,
    /// `ack:client` — acknowledgements are cumulative: acknowledging a message
    /// also acknowledges every message delivered before it on the subscription.
    Client,
    /// `ack:client-individual` — each ACK/NACK applies to a single message only.
    ClientIndividual,
}

impl AckMode {
    /// Returns the value to put in the `ack` header of a SUBSCRIBE frame.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Auto => "auto",
            Self::Client => "client",
            Self::ClientIndividual => "client-individual",
        }
    }
}
323
324/// Options for customizing the STOMP CONNECT frame.
325///
326/// Use this struct with `Connection::connect_with_options()` to set custom
327/// headers, specify supported STOMP versions, or configure broker-specific
328/// options like `client-id` for durable subscriptions.
329///
330/// # Validation
331///
332/// This struct performs minimal validation. Values are passed to the broker
333/// as-is, and invalid configurations will be rejected by the broker at
334/// connection time. Empty strings are technically accepted but may cause
335/// broker-specific errors.
336///
337/// # Custom Headers
338///
339/// Custom headers added via `header()` cannot override critical STOMP headers
340/// (`accept-version`, `host`, `login`, `passcode`, `heart-beat`, `client-id`).
341/// Such headers are silently ignored. Use the dedicated builder methods to
342/// set these values.
343///
344/// # Example
345///
346/// ```ignore
347/// use iridium_stomp::{Connection, ConnectOptions};
348///
349/// let options = ConnectOptions::default()
350/// .client_id("my-durable-client")
351/// .host("my-vhost")
352/// .header("custom-header", "value");
353///
354/// let conn = Connection::connect_with_options(
355/// "localhost:61613",
356/// "guest",
357/// "guest",
358/// "10000,10000",
359/// options,
360/// ).await?;
361/// ```
362#[derive(Clone, Default)]
363pub struct ConnectOptions {
364 /// STOMP version(s) to accept (e.g., "1.2" or "1.0,1.1,1.2").
365 /// Defaults to "1.2" if not set.
366 pub accept_version: Option<String>,
367
368 /// Client ID for durable subscriptions (required by ActiveMQ, etc.).
369 pub client_id: Option<String>,
370
371 /// Virtual host header value. Defaults to "/" if not set.
372 pub host: Option<String>,
373
374 /// Additional custom headers to include in the CONNECT frame.
375 /// Note: Headers that would override critical STOMP headers are ignored.
376 pub headers: Vec<(String, String)>,
377
378 /// Optional channel to receive heartbeat notifications.
379 /// When set, the connection will send a `()` on this channel each time
380 /// a heartbeat is received from the server.
381 pub heartbeat_tx: Option<mpsc::Sender<()>>,
382}
383
384impl std::fmt::Debug for ConnectOptions {
385 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 f.debug_struct("ConnectOptions")
387 .field("accept_version", &self.accept_version)
388 .field("client_id", &self.client_id)
389 .field("host", &self.host)
390 .field("headers", &self.headers)
391 .field(
392 "heartbeat_tx",
393 &self.heartbeat_tx.as_ref().map(|_| "Some(...)"),
394 )
395 .finish()
396 }
397}
398
399impl ConnectOptions {
400 /// Create a new `ConnectOptions` with default values.
401 pub fn new() -> Self {
402 Self::default()
403 }
404
405 /// Set the STOMP version(s) to accept (builder style).
406 ///
407 /// Examples: "1.2", "1.1,1.2", "1.0,1.1,1.2"
408 pub fn accept_version(mut self, version: impl Into<String>) -> Self {
409 self.accept_version = Some(version.into());
410 self
411 }
412
413 /// Set the client ID for durable subscriptions (builder style).
414 ///
415 /// Required by some brokers (e.g., ActiveMQ) for durable topic subscriptions.
416 pub fn client_id(mut self, id: impl Into<String>) -> Self {
417 self.client_id = Some(id.into());
418 self
419 }
420
421 /// Set the virtual host (builder style).
422 ///
423 /// Defaults to "/" if not set.
424 pub fn host(mut self, host: impl Into<String>) -> Self {
425 self.host = Some(host.into());
426 self
427 }
428
429 /// Add a custom header to the CONNECT frame (builder style).
430 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
431 self.headers.push((key.into(), value.into()));
432 self
433 }
434
435 /// Set a channel to receive heartbeat notifications (builder style).
436 ///
437 /// When set, the connection will send a `()` on this channel each time
438 /// a heartbeat is received from the server. This is useful for CLI tools
439 /// or monitoring applications that want to display heartbeat status.
440 ///
441 /// # Note
442 ///
443 /// Notifications are sent using `try_send()` to avoid blocking the
444 /// connection's background task. If the channel buffer is full,
445 /// notifications will be silently dropped. Use a sufficiently sized
446 /// channel buffer (e.g., 16) to avoid missing notifications.
447 ///
448 /// # Example
449 ///
450 /// ```ignore
451 /// use tokio::sync::mpsc;
452 /// use iridium_stomp::ConnectOptions;
453 ///
454 /// let (tx, mut rx) = mpsc::channel(16);
455 /// let options = ConnectOptions::default()
456 /// .with_heartbeat_notify(tx);
457 ///
458 /// // In another task:
459 /// while rx.recv().await.is_some() {
460 /// println!("Heartbeat received!");
461 /// }
462 /// ```
463 pub fn with_heartbeat_notify(mut self, tx: mpsc::Sender<()>) -> Self {
464 self.heartbeat_tx = Some(tx);
465 self
466 }
467}
468
/// Parse the STOMP `heart-beat` header value (format: "cx,cy").
///
/// Parameters
/// - `header`: header string from the server or client (for example
///   "10000,10000"). The values represent milliseconds.
///
/// Returns a tuple `(cx, cy)` where each value is the heartbeat interval in
/// milliseconds. Missing or invalid fields default to `0`; surrounding
/// whitespace is ignored and any fields after the second are discarded.
pub fn parse_heartbeat_header(header: &str) -> (u64, u64) {
    // Each comma-separated field is parsed independently so that one bad
    // field does not invalidate the other.
    let mut fields = header
        .split(',')
        .map(|field| field.trim().parse::<u64>().unwrap_or(0));
    let cx = fields.next().unwrap_or(0);
    let cy = fields.next().unwrap_or(0);
    (cx, cy)
}
489
/// Negotiate heartbeat intervals between client and server.
///
/// Parameters
/// - `client_out`: client's desired outgoing heartbeat interval in
///   milliseconds (how often the client will send heartbeats).
/// - `client_in`: client's desired incoming heartbeat interval in
///   milliseconds (how often the client expects to receive heartbeats).
/// - `server_out`: server's advertised outgoing interval in milliseconds.
/// - `server_in`: server's advertised incoming interval in milliseconds.
///
/// Returns `(outgoing, incoming)` where each element is `Some(Duration)` if
/// heartbeats are enabled in that direction, or `None` if disabled.
///
/// Follows the STOMP 1.2 rule: a direction is disabled when EITHER the
/// sending side advertises `0` ("cannot send") or the receiving side
/// advertises `0` ("does not want to receive"); otherwise the interval is the
/// maximum of the two values. In particular a client that advertises `"0,0"`
/// never sends heartbeats and never expects any, whatever the server offers.
pub fn negotiate_heartbeats(
    client_out: u64,
    client_in: u64,
    server_out: u64,
    server_in: u64,
) -> (Option<Duration>, Option<Duration>) {
    /// Negotiates a single direction from the sender's and receiver's values.
    fn direction(sender_ms: u64, receiver_ms: u64) -> Option<Duration> {
        if sender_ms == 0 || receiver_ms == 0 {
            None
        } else {
            Some(Duration::from_millis(sender_ms.max(receiver_ms)))
        }
    }

    let outgoing = direction(client_out, server_in);
    let incoming = direction(server_out, client_in);
    (outgoing, incoming)
}
525
526/// High-level connection object that manages a single TCP/STOMP connection.
527///
528/// The `Connection` spawns a background task that maintains the TCP transport,
529/// sends/receives STOMP frames using `StompCodec`, negotiates heartbeats, and
530/// performs simple reconnect logic with exponential backoff.
531#[derive(Clone)]
532pub struct Connection {
533 outbound_tx: mpsc::Sender<StompItem>,
534 /// The inbound receiver is shared behind a mutex so the `Connection`
535 /// handle may be cloned and callers can call `next_frame` concurrently.
536 inbound_rx: Arc<Mutex<mpsc::Receiver<Frame>>>,
537 shutdown_tx: broadcast::Sender<()>,
538 /// Map of destination -> list of (subscription id, sender) for dispatching
539 /// inbound MESSAGE frames to subscribers.
540 subscriptions: Arc<Mutex<Subscriptions>>,
541 /// Monotonic counter used to allocate subscription ids.
542 sub_id_counter: Arc<AtomicU64>,
543 /// Pending messages awaiting ACK/NACK from the application.
544 ///
545 /// Organized by subscription id. For `client` ack mode the ACK is
546 /// cumulative: acknowledging message `M` for subscription `S` acknowledges
547 /// all messages previously delivered for `S` up to and including `M`.
548 /// For `client-individual` the ACK/NACK applies only to the single
549 /// message.
550 pending: Arc<Mutex<PendingMap>>,
551 /// Pending receipt confirmations.
552 ///
553 /// When a frame is sent with a `receipt` header, the receipt-id is stored
554 /// here with a oneshot sender. When the server responds with a RECEIPT
555 /// frame, the sender is notified.
556 pending_receipts: Arc<Mutex<PendingReceipts>>,
557}
558
559impl Connection {
560 /// Heartbeat value that disables heartbeats entirely.
561 ///
562 /// Use this when you don't want the client or server to send heartbeats.
563 /// Note that some brokers may still require heartbeats for long-lived connections.
564 ///
565 /// # Example
566 ///
567 /// ```ignore
568 /// let conn = Connection::connect(
569 /// "localhost:61613",
570 /// "guest",
571 /// "guest",
572 /// Connection::NO_HEARTBEAT,
573 /// ).await?;
574 /// ```
575 pub const NO_HEARTBEAT: &'static str = "0,0";
576
577 /// Default heartbeat value: 10 seconds for both send and receive.
578 ///
579 /// This is a reasonable default for most applications. The actual heartbeat
580 /// interval will be negotiated with the server (taking the maximum of client
581 /// and server preferences).
582 ///
583 /// # Example
584 ///
585 /// ```ignore
586 /// let conn = Connection::connect(
587 /// "localhost:61613",
588 /// "guest",
589 /// "guest",
590 /// Connection::DEFAULT_HEARTBEAT,
591 /// ).await?;
592 /// ```
593 pub const DEFAULT_HEARTBEAT: &'static str = "10000,10000";
594
595 /// Establish a connection to the STOMP server at `addr` with the given
596 /// credentials and heartbeat header string (e.g. "10000,10000").
597 ///
598 /// This is a convenience wrapper around `connect_with_options()` that uses
599 /// default options (STOMP 1.2, host="/", no client-id).
600 ///
601 /// Parameters
602 /// - `addr`: TCP address (host:port) of the STOMP server.
603 /// - `login`: login username for STOMP `CONNECT`.
604 /// - `passcode`: passcode for STOMP `CONNECT`.
605 /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
606 /// milliseconds) that will be sent in the `CONNECT` frame.
607 ///
608 /// Returns a `Connection` which provides `send_frame`, `next_frame`, and
609 /// `close` helpers. The detailed connection handling (I/O, heartbeats,
610 /// reconnects) runs on a background task spawned by this method.
611 pub async fn connect(
612 addr: &str,
613 login: &str,
614 passcode: &str,
615 client_hb: &str,
616 ) -> Result<Self, ConnError> {
617 Self::connect_with_options(addr, login, passcode, client_hb, ConnectOptions::default())
618 .await
619 }
620
621 /// Establish a connection to the STOMP server with custom options.
622 ///
623 /// Use this method when you need to set a custom `client-id` (for durable
624 /// subscriptions), specify a virtual host, negotiate different STOMP
625 /// versions, or add custom CONNECT headers.
626 ///
627 /// Parameters
628 /// - `addr`: TCP address (host:port) of the STOMP server.
629 /// - `login`: login username for STOMP `CONNECT`.
630 /// - `passcode`: passcode for STOMP `CONNECT`.
631 /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
632 /// milliseconds) that will be sent in the `CONNECT` frame.
633 /// - `options`: custom connection options (version, host, client-id, etc.).
634 ///
635 /// # Errors
636 ///
637 /// Returns an error if:
638 /// - The TCP connection cannot be established (`ConnError::Io`)
639 /// - The server rejects the connection, e.g., due to invalid credentials
640 /// (`ConnError::ServerRejected`)
641 /// - The server closes the connection without responding (`ConnError::Protocol`)
642 ///
643 /// # Example
644 ///
645 /// ```ignore
646 /// use iridium_stomp::{Connection, ConnectOptions};
647 ///
648 /// // Connect with a client-id for durable subscriptions
649 /// let options = ConnectOptions::default()
650 /// .client_id("my-app-instance-1");
651 ///
652 /// let conn = Connection::connect_with_options(
653 /// "localhost:61613",
654 /// "guest",
655 /// "guest",
656 /// "10000,10000",
657 /// options,
658 /// ).await?;
659 /// ```
660 pub async fn connect_with_options(
661 addr: &str,
662 login: &str,
663 passcode: &str,
664 client_hb: &str,
665 options: ConnectOptions,
666 ) -> Result<Self, ConnError> {
667 let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(32);
668 let (in_tx, in_rx) = mpsc::channel::<Frame>(32);
669 let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
670 let sub_id_counter = Arc::new(AtomicU64::new(1));
671 let (shutdown_tx, _) = broadcast::channel::<()>(1);
672 let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
673 let pending_clone = pending.clone();
674 let pending_receipts: Arc<Mutex<PendingReceipts>> = Arc::new(Mutex::new(HashMap::new()));
675 let pending_receipts_clone = pending_receipts.clone();
676
677 let addr = addr.to_string();
678 let login = login.to_string();
679 let passcode = passcode.to_string();
680 let client_hb = client_hb.to_string();
681
682 // Extract options into owned values for the spawned task
683 let accept_version = options.accept_version.unwrap_or_else(|| "1.2".to_string());
684 let host = options.host.unwrap_or_else(|| "/".to_string());
685 let client_id = options.client_id;
686 let custom_headers = options.headers;
687 let heartbeat_notify_tx = options.heartbeat_tx;
688
689 // Perform initial connection and STOMP handshake before spawning background task.
690 // This ensures authentication errors are returned to the caller immediately.
691 let stream = TcpStream::connect(&addr).await?;
692 let mut framed = Framed::new(stream, StompCodec::new());
693
694 // Build and send CONNECT frame
695 let connect = Self::build_connect_frame(
696 &accept_version,
697 &host,
698 &login,
699 &passcode,
700 &client_hb,
701 &client_id,
702 &custom_headers,
703 );
704
705 framed
706 .send(StompItem::Frame(connect))
707 .await
708 .map_err(|e| ConnError::Io(std::io::Error::other(e)))?;
709
710 // Wait for CONNECTED or ERROR response
711 let server_heartbeat = Self::await_connected_response(&mut framed).await?;
712
713 // Calculate heartbeat intervals
714 let (cx, cy) = parse_heartbeat_header(&client_hb);
715 let (sx, sy) = parse_heartbeat_header(&server_heartbeat);
716 let (send_interval, recv_interval) = negotiate_heartbeats(cx, cy, sx, sy);
717
718 // Now spawn background task for ongoing I/O and reconnection
719 let shutdown_tx_clone = shutdown_tx.clone();
720 let subscriptions_clone = subscriptions.clone();
721
722 tokio::spawn(async move {
723 let mut backoff_secs: u64 = 1;
724
725 // Use the already-established connection for the first iteration
726 let mut current_framed = Some(framed);
727 let mut current_send_interval = send_interval;
728 let mut current_recv_interval = recv_interval;
729
730 loop {
731 let mut shutdown_sub = shutdown_tx_clone.subscribe();
732
733 // Check for shutdown before attempting connection
734 tokio::select! {
735 biased;
736 _ = shutdown_sub.recv() => break,
737 _ = future::ready(()) => {},
738 }
739
740 // Either use existing connection or establish new one (reconnect)
741 let framed = if let Some(f) = current_framed.take() {
742 f
743 } else {
744 // Reconnection attempt
745 match TcpStream::connect(&addr).await {
746 Ok(stream) => {
747 let mut framed = Framed::new(stream, StompCodec::new());
748
749 let connect = Self::build_connect_frame(
750 &accept_version,
751 &host,
752 &login,
753 &passcode,
754 &client_hb,
755 &client_id,
756 &custom_headers,
757 );
758
759 if framed.send(StompItem::Frame(connect)).await.is_err() {
760 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
761 backoff_secs = (backoff_secs * 2).min(30);
762 continue;
763 }
764
765 // Wait for CONNECTED (on reconnect, silently retry on ERROR)
766 match Self::await_connected_response(&mut framed).await {
767 Ok(server_hb) => {
768 let (cx, cy) = parse_heartbeat_header(&client_hb);
769 let (sx, sy) = parse_heartbeat_header(&server_hb);
770 let (si, ri) = negotiate_heartbeats(cx, cy, sx, sy);
771 current_send_interval = si;
772 current_recv_interval = ri;
773 framed
774 }
775 Err(_) => {
776 // Reconnect failed (auth error or other), retry with backoff
777 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
778 backoff_secs = (backoff_secs * 2).min(30);
779 continue;
780 }
781 }
782 }
783 Err(_) => {
784 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
785 backoff_secs = (backoff_secs * 2).min(30);
786 continue;
787 }
788 }
789 };
790
791 let (send_interval, recv_interval) = (current_send_interval, current_recv_interval);
792
793 let last_received = Arc::new(AtomicU64::new(current_millis()));
794 let writer_last_sent = Arc::new(AtomicU64::new(current_millis()));
795
796 let (mut sink, mut stream) = framed.split();
797 let in_tx = in_tx.clone();
798 let subscriptions = subscriptions_clone.clone();
799
800 // Clear pending message map on reconnect — messages that were
801 // outstanding before the disconnect are considered lost and
802 // will be redelivered by the server as appropriate.
803 {
804 let mut p = pending_clone.lock().await;
805 p.clear();
806 }
807
808 // Resubscribe any existing subscriptions after reconnect.
809 // We snapshot the subscription entries while holding the lock
810 // and then issue SUBSCRIBE frames using the sink.
811 let subs_snapshot: Vec<ResubEntry> = {
812 let map = subscriptions.lock().await;
813 let mut v: Vec<ResubEntry> = Vec::new();
814 for (dest, vec) in map.iter() {
815 for entry in vec.iter() {
816 v.push((
817 dest.clone(),
818 entry.id.clone(),
819 entry.ack.clone(),
820 entry.headers.clone(),
821 ));
822 }
823 }
824 v
825 };
826
827 for (dest, id, ack, headers) in subs_snapshot {
828 let mut sf = Frame::new("SUBSCRIBE");
829 sf = sf
830 .header("id", &id)
831 .header("destination", &dest)
832 .header("ack", &ack);
833 for (k, v) in headers {
834 sf = sf.header(&k, &v);
835 }
836 let _ = sink.send(StompItem::Frame(sf)).await;
837 }
838
839 let mut hb_tick = match send_interval {
840 Some(d) => tokio::time::interval(d),
841 None => tokio::time::interval(Duration::from_secs(86400)),
842 };
843 let watchdog_half = recv_interval.map(|d| d / 2);
844
845 backoff_secs = 1;
846
847 'conn: loop {
848 tokio::select! {
849 _ = shutdown_sub.recv() => { let _ = sink.close().await; break 'conn; }
850 maybe = out_rx.recv() => {
851 match maybe {
852 Some(item) => if sink.send(item).await.is_err() { break 'conn } else { writer_last_sent.store(current_millis(), Ordering::SeqCst); }
853 None => break 'conn,
854 }
855 }
856 item = stream.next() => {
857 match item {
858 Some(Ok(StompItem::Heartbeat)) => {
859 last_received.store(current_millis(), Ordering::SeqCst);
860 if let Some(ref tx) = heartbeat_notify_tx {
861 let _ = tx.try_send(());
862 }
863 }
864 Some(Ok(StompItem::Frame(f))) => {
865 last_received.store(current_millis(), Ordering::SeqCst);
866 // Dispatch MESSAGE frames to any matching subscribers.
867 if f.command == "MESSAGE" {
868 // try to find destination, subscription and message-id headers
869 let mut dest_opt: Option<String> = None;
870 let mut sub_opt: Option<String> = None;
871 let mut msg_id_opt: Option<String> = None;
872 for (k, v) in &f.headers {
873 let kl = k.to_lowercase();
874 if kl == "destination" {
875 dest_opt = Some(v.clone());
876 } else if kl == "subscription" {
877 sub_opt = Some(v.clone());
878 } else if kl == "message-id" {
879 msg_id_opt = Some(v.clone());
880 }
881 }
882
883 // Determine whether we need to track this message as pending
884 let mut need_pending = false;
885 if let Some(sub_id) = &sub_opt {
886 let map = subscriptions.lock().await;
887 for (_dest, vec) in map.iter() {
888 for entry in vec.iter() {
889 if &entry.id == sub_id && entry.ack != "auto" {
890 need_pending = true;
891 }
892 }
893 }
894 } else if let Some(dest) = &dest_opt {
895 let map = subscriptions.lock().await;
896 if let Some(vec) = map.get(dest) {
897 for entry in vec.iter() {
898 if entry.ack != "auto" {
899 need_pending = true;
900 break;
901 }
902 }
903 }
904 }
905
906 // If required, add to pending map (per-subscription) before
907 // delivery so ACK/NACK requests from the application can
908 // reference the message. We require a `message-id` header
909 // to track messages; if missing, we cannot support ACK/NACK.
910 if let Some(msg_id) = msg_id_opt.clone().filter(|_| need_pending) {
911 // If the server provided a subscription id in the
912 // MESSAGE, store pending under that subscription.
913 if let Some(sub_id) = &sub_opt {
914 let mut p = pending_clone.lock().await;
915 let q = p
916 .entry(sub_id.clone())
917 .or_insert_with(VecDeque::new);
918 q.push_back((msg_id.clone(), f.clone()));
919 } else if let Some(dest) = &dest_opt {
920 // Destination-based delivery: add the message to
921 // the pending queue for each matching
922 // subscription on that destination.
923 let map = subscriptions.lock().await;
924 if let Some(vec) = map.get(dest) {
925 let mut p = pending_clone.lock().await;
926 for entry in vec.iter() {
927 let q = p
928 .entry(entry.id.clone())
929 .or_insert_with(VecDeque::new);
930 q.push_back((msg_id.clone(), f.clone()));
931 }
932 }
933 }
934 }
935
936 // Deliver to subscribers.
937 if let Some(sub_id) = sub_opt {
938 let mut map = subscriptions.lock().await;
939 for (_dest, vec) in map.iter_mut() {
940 vec.retain(|entry| {
941 if entry.id == sub_id {
942 let _ = entry.sender.try_send(f.clone());
943 true
944 } else {
945 true
946 }
947 });
948 }
949 } else if let Some(dest) = dest_opt {
950 let mut map = subscriptions.lock().await;
951 if let Some(vec) = map.get_mut(&dest) {
952 vec.retain(|entry| entry.sender.try_send(f.clone()).is_ok());
953 }
954 }
955 } else if f.command == "RECEIPT" {
956 // Handle RECEIPT frame: notify any waiting callers
957 if let Some(receipt_id) = f.get_header("receipt-id") {
958 let mut receipts = pending_receipts_clone.lock().await;
959 if let Some(sender) = receipts.remove(receipt_id) {
960 let _ = sender.send(());
961 }
962 }
963 // Don't forward RECEIPT frames to inbound channel
964 continue;
965 }
966
967 let _ = in_tx.send(f).await;
968 }
969 Some(Err(_)) | None => break 'conn,
970 }
971 }
972 _ = hb_tick.tick() => {
973 if let Some(dur) = send_interval {
974 let last = writer_last_sent.load(Ordering::SeqCst);
975 if current_millis().saturating_sub(last) >= dur.as_millis() as u64 {
976 if sink.send(StompItem::Heartbeat).await.is_err() { break 'conn; }
977 writer_last_sent.store(current_millis(), Ordering::SeqCst);
978 }
979 }
980 }
981 _ = async { if let Some(interval) = watchdog_half { tokio::time::sleep(interval).await } else { future::pending::<()>().await } } => {
982 if let Some(recv_dur) = recv_interval {
983 let last = last_received.load(Ordering::SeqCst);
984 if current_millis().saturating_sub(last) > (recv_dur.as_millis() as u64 * 2) {
985 let _ = sink.close().await; break 'conn;
986 }
987 }
988 }
989 }
990 }
991
992 if shutdown_sub.try_recv().is_ok() {
993 break;
994 }
995 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
996 backoff_secs = (backoff_secs * 2).min(30);
997 }
998 });
999
1000 Ok(Connection {
1001 outbound_tx: out_tx,
1002 inbound_rx: Arc::new(Mutex::new(in_rx)),
1003 shutdown_tx,
1004 subscriptions,
1005 sub_id_counter,
1006 pending,
1007 pending_receipts,
1008 })
1009 }
1010
1011 /// Build a CONNECT frame with all specified headers.
1012 fn build_connect_frame(
1013 accept_version: &str,
1014 host: &str,
1015 login: &str,
1016 passcode: &str,
1017 heartbeat: &str,
1018 client_id: &Option<String>,
1019 custom_headers: &[(String, String)],
1020 ) -> Frame {
1021 let mut connect = Frame::new("CONNECT")
1022 .header("accept-version", accept_version)
1023 .header("host", host)
1024 .header("login", login)
1025 .header("passcode", passcode)
1026 .header("heart-beat", heartbeat);
1027
1028 if let Some(id) = client_id {
1029 connect = connect.header("client-id", id);
1030 }
1031
1032 // Reserved headers that custom_headers cannot override
1033 let reserved = [
1034 "accept-version",
1035 "host",
1036 "login",
1037 "passcode",
1038 "heart-beat",
1039 "client-id",
1040 ];
1041
1042 for (k, v) in custom_headers {
1043 if !reserved.contains(&k.to_lowercase().as_str()) {
1044 connect = connect.header(k, v);
1045 }
1046 }
1047
1048 connect
1049 }
1050
1051 /// Wait for CONNECTED or ERROR response from the server.
1052 ///
1053 /// Returns the server's heartbeat header value on success, or an error
1054 /// if the server sends an ERROR frame or closes the connection.
1055 async fn await_connected_response(
1056 framed: &mut Framed<TcpStream, StompCodec>,
1057 ) -> Result<String, ConnError> {
1058 loop {
1059 match framed.next().await {
1060 Some(Ok(StompItem::Frame(f))) => {
1061 if f.command == "CONNECTED" {
1062 // Extract heartbeat from server
1063 let server_hb = f.get_header("heart-beat").unwrap_or("0,0").to_string();
1064 return Ok(server_hb);
1065 } else if f.command == "ERROR" {
1066 // Server rejected connection (e.g., invalid credentials)
1067 return Err(ConnError::ServerRejected(ServerError::from_frame(f)));
1068 }
1069 // Ignore other frames during CONNECT phase
1070 }
1071 Some(Ok(StompItem::Heartbeat)) => {
1072 // Ignore heartbeats during handshake
1073 continue;
1074 }
1075 Some(Err(e)) => {
1076 return Err(ConnError::Io(e));
1077 }
1078 None => {
1079 return Err(ConnError::Protocol(
1080 "connection closed before CONNECTED received".to_string(),
1081 ));
1082 }
1083 }
1084 }
1085 }
1086
1087 pub async fn send_frame(&self, frame: Frame) -> Result<(), ConnError> {
1088 // Send a frame to the background writer task.
1089 //
1090 // Parameters
1091 // - `frame`: ownership of the `Frame` to send. The frame is converted
1092 // into a `StompItem::Frame` and sent over the internal mpsc channel.
1093 self.outbound_tx
1094 .send(StompItem::Frame(frame))
1095 .await
1096 .map_err(|_| ConnError::Protocol("send channel closed".into()))
1097 }
1098
/// Produce the next receipt identifier.
///
/// Identifiers have the form `rcpt-<n>`, where `<n>` comes from a counter
/// that starts at 1 and is shared by every caller of this function.
fn generate_receipt_id() -> String {
    static NEXT_RECEIPT: AtomicU64 = AtomicU64::new(1);
    let sequence = NEXT_RECEIPT.fetch_add(1, Ordering::SeqCst);
    ["rcpt-", sequence.to_string().as_str()].concat()
}
1104
1105 /// Send a frame with a receipt request and return the receipt ID.
1106 ///
1107 /// This method adds a unique `receipt` header to the frame and registers
1108 /// the receipt ID for tracking. Use `wait_for_receipt()` to wait for the
1109 /// server's RECEIPT response.
1110 ///
1111 /// # Parameters
1112 /// - `frame`: the frame to send. A `receipt` header will be added.
1113 ///
1114 /// # Returns
1115 /// The generated receipt ID that can be used with `wait_for_receipt()`.
1116 ///
1117 /// # Example
1118 /// ```ignore
1119 /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
1120 /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
1121 /// ```
1122 pub async fn send_frame_with_receipt(&self, frame: Frame) -> Result<String, ConnError> {
1123 let receipt_id = Self::generate_receipt_id();
1124
1125 // Create the oneshot channel for notification
1126 let (tx, _rx) = oneshot::channel();
1127
1128 // Register the pending receipt
1129 {
1130 let mut receipts = self.pending_receipts.lock().await;
1131 receipts.insert(receipt_id.clone(), tx);
1132 }
1133
1134 // Add receipt header and send the frame
1135 let frame_with_receipt = frame.receipt(&receipt_id);
1136 self.send_frame(frame_with_receipt).await?;
1137
1138 Ok(receipt_id)
1139 }
1140
1141 /// Wait for a receipt confirmation from the server.
1142 ///
1143 /// This method blocks until the server sends a RECEIPT frame with the
1144 /// matching receipt-id, or until the timeout expires.
1145 ///
1146 /// # Parameters
1147 /// - `receipt_id`: the receipt ID returned by `send_frame_with_receipt()`.
1148 /// - `timeout`: maximum time to wait for the receipt.
1149 ///
1150 /// # Returns
1151 /// `Ok(())` if the receipt was received, or `Err(ConnError::ReceiptTimeout)`
1152 /// if the timeout expired.
1153 ///
1154 /// # Example
1155 /// ```ignore
1156 /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
1157 /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
1158 /// println!("Message confirmed!");
1159 /// ```
1160 pub async fn wait_for_receipt(
1161 &self,
1162 receipt_id: &str,
1163 timeout: Duration,
1164 ) -> Result<(), ConnError> {
1165 // Get the receiver for this receipt
1166 let rx = {
1167 let mut receipts = self.pending_receipts.lock().await;
1168 // Re-create the oneshot channel and swap out the sender
1169 let (tx, rx) = oneshot::channel();
1170 if let Some(old_tx) = receipts.insert(receipt_id.to_string(), tx) {
1171 // Drop the old sender - this is expected if called after send_frame_with_receipt
1172 drop(old_tx);
1173 }
1174 rx
1175 };
1176
1177 // Wait for the receipt with timeout
1178 match tokio::time::timeout(timeout, rx).await {
1179 Ok(Ok(())) => Ok(()),
1180 Ok(Err(_)) => {
1181 // Channel was closed without receiving - connection likely dropped
1182 Err(ConnError::Protocol(
1183 "receipt channel closed unexpectedly".into(),
1184 ))
1185 }
1186 Err(_) => {
1187 // Timeout expired - clean up the pending receipt
1188 let mut receipts = self.pending_receipts.lock().await;
1189 receipts.remove(receipt_id);
1190 Err(ConnError::ReceiptTimeout(receipt_id.to_string()))
1191 }
1192 }
1193 }
1194
1195 /// Send a frame and wait for server confirmation via RECEIPT.
1196 ///
1197 /// This is a convenience method that combines `send_frame_with_receipt()`
1198 /// and `wait_for_receipt()`. Use this when you want to ensure a frame
1199 /// was processed by the server before continuing.
1200 ///
1201 /// # Parameters
1202 /// - `frame`: the frame to send.
1203 /// - `timeout`: maximum time to wait for the receipt.
1204 ///
1205 /// # Returns
1206 /// `Ok(())` if the frame was sent and receipt confirmed, or an error if
1207 /// sending failed or the receipt timed out.
1208 ///
1209 /// # Example
1210 /// ```ignore
1211 /// let frame = Frame::new("SEND")
1212 /// .header("destination", "/queue/orders")
1213 /// .set_body(b"order data".to_vec());
1214 ///
1215 /// conn.send_frame_confirmed(frame, Duration::from_secs(5)).await?;
1216 /// println!("Order sent and confirmed!");
1217 /// ```
1218 pub async fn send_frame_confirmed(
1219 &self,
1220 frame: Frame,
1221 timeout: Duration,
1222 ) -> Result<(), ConnError> {
1223 let receipt_id = Self::generate_receipt_id();
1224
1225 // Create the oneshot channel for notification
1226 let (tx, rx) = oneshot::channel();
1227
1228 // Register the pending receipt before sending
1229 {
1230 let mut receipts = self.pending_receipts.lock().await;
1231 receipts.insert(receipt_id.clone(), tx);
1232 }
1233
1234 // Add receipt header and send the frame
1235 let frame_with_receipt = frame.receipt(&receipt_id);
1236 self.send_frame(frame_with_receipt).await?;
1237
1238 // Wait for the receipt with timeout
1239 match tokio::time::timeout(timeout, rx).await {
1240 Ok(Ok(())) => Ok(()),
1241 Ok(Err(_)) => Err(ConnError::Protocol(
1242 "receipt channel closed unexpectedly".into(),
1243 )),
1244 Err(_) => {
1245 // Timeout expired - clean up
1246 let mut receipts = self.pending_receipts.lock().await;
1247 receipts.remove(&receipt_id);
1248 Err(ConnError::ReceiptTimeout(receipt_id))
1249 }
1250 }
1251 }
1252
1253 /// Subscribe to a destination.
1254 ///
1255 /// Parameters
1256 /// - `destination`: the STOMP destination to subscribe to (e.g. "/queue/foo").
1257 /// - `ack`: acknowledgement mode to request from the server.
1258 ///
1259 /// Returns a tuple `(subscription_id, receiver)` where `subscription_id` is
1260 /// the opaque id assigned locally for this subscription and `receiver` is a
1261 /// `mpsc::Receiver<Frame>` which will yield incoming MESSAGE frames for the
1262 /// destination. The caller should read from the receiver to handle messages.
1263 /// Subscribe to a destination using optional extra headers.
1264 ///
1265 /// This variant accepts additional headers which are stored locally and
1266 /// re-sent on reconnect. Use `subscribe` as a convenience wrapper when no
1267 /// extra headers are needed.
1268 pub async fn subscribe_with_headers(
1269 &self,
1270 destination: &str,
1271 ack: AckMode,
1272 extra_headers: Vec<(String, String)>,
1273 ) -> Result<crate::subscription::Subscription, ConnError> {
1274 let id = self
1275 .sub_id_counter
1276 .fetch_add(1, Ordering::SeqCst)
1277 .to_string();
1278 let (tx, rx) = mpsc::channel::<Frame>(16);
1279 {
1280 let mut map = self.subscriptions.lock().await;
1281 map.entry(destination.to_string())
1282 .or_insert_with(Vec::new)
1283 .push(SubscriptionEntry {
1284 id: id.clone(),
1285 sender: tx.clone(),
1286 ack: ack.as_str().to_string(),
1287 headers: extra_headers.clone(),
1288 });
1289 }
1290
1291 let mut f = Frame::new("SUBSCRIBE");
1292 f = f
1293 .header("id", &id)
1294 .header("destination", destination)
1295 .header("ack", ack.as_str());
1296 for (k, v) in &extra_headers {
1297 f = f.header(k, v);
1298 }
1299 self.outbound_tx
1300 .send(StompItem::Frame(f))
1301 .await
1302 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1303
1304 Ok(crate::subscription::Subscription::new(
1305 id,
1306 destination.to_string(),
1307 rx,
1308 self.clone(),
1309 ))
1310 }
1311
1312 /// Convenience wrapper without extra headers.
1313 pub async fn subscribe(
1314 &self,
1315 destination: &str,
1316 ack: AckMode,
1317 ) -> Result<crate::subscription::Subscription, ConnError> {
1318 self.subscribe_with_headers(destination, ack, Vec::new())
1319 .await
1320 }
1321
1322 /// Subscribe with a typed `SubscriptionOptions` structure.
1323 ///
1324 /// `SubscriptionOptions.headers` are forwarded to the broker and persisted
1325 /// for automatic resubscribe after reconnect. If `durable_queue` is set,
1326 /// it will be used as the actual destination instead of `destination`.
1327 pub async fn subscribe_with_options(
1328 &self,
1329 destination: &str,
1330 ack: AckMode,
1331 options: crate::subscription::SubscriptionOptions,
1332 ) -> Result<crate::subscription::Subscription, ConnError> {
1333 let dest = options
1334 .durable_queue
1335 .as_deref()
1336 .unwrap_or(destination)
1337 .to_string();
1338 self.subscribe_with_headers(&dest, ack, options.headers)
1339 .await
1340 }
1341
1342 /// Unsubscribe a previously created subscription by its local subscription id.
1343 pub async fn unsubscribe(&self, subscription_id: &str) -> Result<(), ConnError> {
1344 let mut found = false;
1345 {
1346 let mut map = self.subscriptions.lock().await;
1347 let mut remove_keys: Vec<String> = Vec::new();
1348 for (dest, vec) in map.iter_mut() {
1349 if let Some(pos) = vec.iter().position(|entry| entry.id == subscription_id) {
1350 vec.remove(pos);
1351 found = true;
1352 }
1353 if vec.is_empty() {
1354 remove_keys.push(dest.clone());
1355 }
1356 }
1357 for k in remove_keys {
1358 map.remove(&k);
1359 }
1360 }
1361
1362 if !found {
1363 return Err(ConnError::Protocol("subscription id not found".into()));
1364 }
1365
1366 let mut f = Frame::new("UNSUBSCRIBE");
1367 f = f.header("id", subscription_id);
1368 self.outbound_tx
1369 .send(StompItem::Frame(f))
1370 .await
1371 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1372
1373 Ok(())
1374 }
1375
1376 /// Acknowledge a message previously received in `client` or
1377 /// `client-individual` ack modes.
1378 ///
1379 /// STOMP ack semantics:
1380 /// - `auto`: server considers message delivered immediately; the client
1381 /// should not ack.
1382 /// - `client`: cumulative acknowledgements. ACKing message `M` for
1383 /// subscription `S` acknowledges all messages delivered to `S` up to
1384 /// and including `M`.
1385 /// - `client-individual`: only the named message is acknowledged.
1386 ///
1387 /// Parameters
1388 /// - `subscription_id`: the local subscription id returned by
1389 /// `Connection::subscribe`. This disambiguates which subscription's
1390 /// pending queue to advance for cumulative ACKs.
1391 /// - `message_id`: the `message-id` header value from the received
1392 /// MESSAGE frame to acknowledge.
1393 ///
1394 /// Behavior
1395 /// - The pending queue for `subscription_id` is searched for `message_id`.
1396 /// If the subscription used `client` ack mode, all pending messages up to
1397 /// and including the matched message are removed. If the subscription
1398 /// used `client-individual`, only the matched message is removed.
1399 /// - An `ACK` frame is sent to the server with `id=<message_id>` and
1400 /// `subscription=<subscription_id>` headers.
1401 #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1402 pub async fn ack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1403 // Remove from the local pending queue according to subscription ack mode.
1404 let mut removed_any = false;
1405 {
1406 let mut p = self.pending.lock().await;
1407 if let Some(queue) = p.get_mut(subscription_id) {
1408 if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1409 // Determine ack mode for this subscription (default to client).
1410 let mut ack_mode = "client".to_string();
1411 {
1412 let map = self.subscriptions.lock().await;
1413 'outer: for (_dest, vec) in map.iter() {
1414 for entry in vec.iter() {
1415 if entry.id == subscription_id {
1416 ack_mode = entry.ack.clone();
1417 break 'outer;
1418 }
1419 }
1420 }
1421 }
1422
1423 if ack_mode == "client" {
1424 // cumulative: remove up to and including pos
1425 for _ in 0..=pos {
1426 queue.pop_front();
1427 removed_any = true;
1428 }
1429 } else if queue.remove(pos).is_some() {
1430 // client-individual: remove only the specific message
1431 removed_any = true;
1432 }
1433
1434 if queue.is_empty() {
1435 p.remove(subscription_id);
1436 }
1437 }
1438 }
1439 }
1440
1441 // Send ACK to server (include subscription header for clarity)
1442 let mut f = Frame::new("ACK");
1443 f = f
1444 .header("id", message_id)
1445 .header("subscription", subscription_id);
1446 self.outbound_tx
1447 .send(StompItem::Frame(f))
1448 .await
1449 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1450
1451 // If message wasn't found locally, still send ACK to server; server
1452 // may ignore or treat it as no-op.
1453 let _ = removed_any;
1454 Ok(())
1455 }
1456
1457 /// Negative-acknowledge a message (NACK).
1458 ///
1459 /// Parameters
1460 /// - `subscription_id`: the local subscription id the message was delivered under.
1461 /// - `message_id`: the `message-id` header value from the received MESSAGE.
1462 ///
1463 /// Behavior
1464 /// - Removes the message from the local pending queue (cumulatively if the
1465 /// subscription used `client` ack mode, otherwise only the single
1466 /// message). Sends a `NACK` frame to the server with `id` and
1467 /// `subscription` headers.
1468 #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1469 pub async fn nack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1470 // Mirror ack removal semantics for pending map.
1471 let mut removed_any = false;
1472 {
1473 let mut p = self.pending.lock().await;
1474 if let Some(queue) = p.get_mut(subscription_id) {
1475 if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1476 let mut ack_mode = "client".to_string();
1477 {
1478 let map = self.subscriptions.lock().await;
1479 'outer2: for (_dest, vec) in map.iter() {
1480 for entry in vec.iter() {
1481 if entry.id == subscription_id {
1482 ack_mode = entry.ack.clone();
1483 break 'outer2;
1484 }
1485 }
1486 }
1487 }
1488
1489 if ack_mode == "client" {
1490 for _ in 0..=pos {
1491 queue.pop_front();
1492 removed_any = true;
1493 }
1494 } else if queue.remove(pos).is_some() {
1495 removed_any = true;
1496 }
1497
1498 if queue.is_empty() {
1499 p.remove(subscription_id);
1500 }
1501 }
1502 }
1503 }
1504
1505 let mut f = Frame::new("NACK");
1506 f = f
1507 .header("id", message_id)
1508 .header("subscription", subscription_id);
1509 self.outbound_tx
1510 .send(StompItem::Frame(f))
1511 .await
1512 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1513
1514 let _ = removed_any;
1515 Ok(())
1516 }
1517
1518 /// Helper to send a transaction frame (BEGIN, COMMIT, or ABORT).
1519 async fn send_transaction_frame(
1520 &self,
1521 command: &str,
1522 transaction_id: &str,
1523 ) -> Result<(), ConnError> {
1524 let f = Frame::new(command).header("transaction", transaction_id);
1525 self.outbound_tx
1526 .send(StompItem::Frame(f))
1527 .await
1528 .map_err(|_| ConnError::Protocol("send channel closed".into()))
1529 }
1530
1531 /// Begin a transaction.
1532 ///
1533 /// Parameters
1534 /// - `transaction_id`: unique identifier for the transaction. The caller is
1535 /// responsible for ensuring uniqueness within the connection.
1536 ///
1537 /// Behavior
1538 /// - Sends a `BEGIN` frame to the server with `transaction:<transaction_id>`
1539 /// header. Subsequent `SEND`, `ACK`, and `NACK` frames may include this
1540 /// transaction id to group them into the transaction. The transaction must
1541 /// be finalized with either `commit` or `abort`.
1542 pub async fn begin(&self, transaction_id: &str) -> Result<(), ConnError> {
1543 self.send_transaction_frame("BEGIN", transaction_id).await
1544 }
1545
1546 /// Commit a transaction.
1547 ///
1548 /// Parameters
1549 /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1550 ///
1551 /// Behavior
1552 /// - Sends a `COMMIT` frame to the server with `transaction:<transaction_id>`
1553 /// header. All operations within the transaction are applied atomically.
1554 pub async fn commit(&self, transaction_id: &str) -> Result<(), ConnError> {
1555 self.send_transaction_frame("COMMIT", transaction_id).await
1556 }
1557
1558 /// Abort a transaction.
1559 ///
1560 /// Parameters
1561 /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1562 ///
1563 /// Behavior
1564 /// - Sends an `ABORT` frame to the server with `transaction:<transaction_id>`
1565 /// header. All operations within the transaction are discarded.
1566 pub async fn abort(&self, transaction_id: &str) -> Result<(), ConnError> {
1567 self.send_transaction_frame("ABORT", transaction_id).await
1568 }
1569
1570 /// Receive the next frame from the server.
1571 ///
1572 /// Returns `Some(ReceivedFrame::Frame(..))` for normal frames (MESSAGE, etc.),
1573 /// `Some(ReceivedFrame::Error(..))` for ERROR frames, or `None` if the
1574 /// connection has been closed.
1575 ///
1576 /// # Example
1577 ///
1578 /// ```ignore
1579 /// use iridium_stomp::ReceivedFrame;
1580 ///
1581 /// while let Some(received) = conn.next_frame().await {
1582 /// match received {
1583 /// ReceivedFrame::Frame(frame) => {
1584 /// println!("Got {}: {:?}", frame.command, frame.body);
1585 /// }
1586 /// ReceivedFrame::Error(err) => {
1587 /// eprintln!("Server error: {}", err);
1588 /// break;
1589 /// }
1590 /// }
1591 /// }
1592 /// ```
1593 pub async fn next_frame(&self) -> Option<ReceivedFrame> {
1594 let mut rx = self.inbound_rx.lock().await;
1595 let frame = rx.recv().await?;
1596
1597 // Convert ERROR frames to ServerError for better ergonomics
1598 if frame.command == "ERROR" {
1599 Some(ReceivedFrame::Error(ServerError::from_frame(frame)))
1600 } else {
1601 Some(ReceivedFrame::Frame(frame))
1602 }
1603 }
1604
1605 pub async fn close(self) {
1606 // Signal the background task to shutdown by broadcasting on the
1607 // shutdown channel. Consumers may await task termination separately
1608 // if needed.
1609 let _ = self.shutdown_tx.send(());
1610 }
1611}
1612
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// Yields 0 when the system clock reports a time before the epoch.
fn current_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
1620
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    // Build a MESSAGE frame with the given message-id and optional
    // subscription/destination headers.
    fn make_message(
        message_id: &str,
        subscription: Option<&str>,
        destination: Option<&str>,
    ) -> Frame {
        let mut f = Frame::new("MESSAGE").header("message-id", message_id);
        if let Some(s) = subscription {
            f = f.header("subscription", s);
        }
        if let Some(d) = destination {
            f = f.header("destination", d);
        }
        f
    }

    // Create a `Connection` backed only by in-memory channels, together with
    // the receiving end of its outbound channel so tests can inspect what
    // would have been written to the server.
    fn setup_test_connection() -> (Connection, mpsc::Receiver<StompItem>) {
        let (out_tx, out_rx) = mpsc::channel::<StompItem>(8);
        let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
        let (shutdown_tx, _) = broadcast::channel::<()>(1);

        let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
        let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
        let sub_id_counter = Arc::new(AtomicU64::new(1));

        let conn = Connection {
            outbound_tx: out_tx,
            inbound_rx: Arc::new(Mutex::new(in_rx)),
            shutdown_tx,
            subscriptions,
            sub_id_counter,
            pending,
            pending_receipts: Arc::new(Mutex::new(HashMap::new())),
        };

        (conn, out_rx)
    }

    // Insert a subscription entry directly into the registry (bypassing
    // `subscribe`) and hand back the receiver so the caller can keep the
    // channel open.
    async fn register_subscription(
        conn: &Connection,
        destination: &str,
        id: &str,
        ack: &str,
    ) -> mpsc::Receiver<Frame> {
        let (sender, receiver) = mpsc::channel::<Frame>(4);
        conn.subscriptions.lock().await.insert(
            destination.to_string(),
            vec![SubscriptionEntry {
                id: id.to_string(),
                sender,
                ack: ack.to_string(),
                headers: Vec::new(),
            }],
        );
        receiver
    }

    // Replace the pending queue of `sub_id` with one MESSAGE per id, in order.
    async fn seed_pending(
        conn: &Connection,
        sub_id: &str,
        destination: &str,
        message_ids: &[&str],
    ) {
        let queue: VecDeque<(String, Frame)> = message_ids
            .iter()
            .map(|mid| {
                (
                    mid.to_string(),
                    make_message(mid, Some(sub_id), Some(destination)),
                )
            })
            .collect();
        conn.pending.lock().await.insert(sub_id.to_string(), queue);
    }

    // Assert that the next outbound item is a frame with the given command.
    async fn expect_outbound_command(out_rx: &mut mpsc::Receiver<StompItem>, expected: &str) {
        match out_rx.recv().await {
            Some(StompItem::Frame(f)) => assert_eq!(f.command, expected),
            Some(_) => panic!("expected frame"),
            None => panic!("no outbound frame sent"),
        }
    }

    // Verify a frame's command and its `transaction` header.
    fn verify_transaction_frame(frame: Frame, expected_command: &str, expected_tx_id: &str) {
        assert_eq!(frame.command, expected_command);
        assert!(
            frame
                .headers
                .iter()
                .any(|(k, v)| k == "transaction" && v == expected_tx_id),
            "transaction header with id '{}' not found",
            expected_tx_id
        );
    }

    #[tokio::test]
    async fn test_cumulative_ack_removes_prefix() {
        let (conn, mut out_rx) = setup_test_connection();

        // subscription s1 uses client (cumulative) ack; pending holds m1,m2,m3
        let _sub_rx = register_subscription(&conn, "/queue/x", "s1", "client").await;
        seed_pending(&conn, "s1", "/queue/x", &["m1", "m2", "m3"]).await;

        // ack m2 cumulatively: should remove m1 and m2, leaving m3
        conn.ack("s1", "m2").await.expect("ack failed");

        {
            let p = conn.pending.lock().await;
            let q = p.get("s1").expect("missing s1");
            assert_eq!(q.len(), 1);
            assert_eq!(q.front().unwrap().0, "m3");
        }

        // verify an ACK frame was emitted
        expect_outbound_command(&mut out_rx, "ACK").await;
    }

    #[tokio::test]
    async fn test_client_individual_ack_removes_only_one() {
        let (conn, mut out_rx) = setup_test_connection();

        // subscription s2 uses client-individual ack; pending holds a,b,c
        let _sub_rx = register_subscription(&conn, "/queue/y", "s2", "client-individual").await;
        seed_pending(&conn, "s2", "/queue/y", &["a", "b", "c"]).await;

        // ack only 'b' individually
        conn.ack("s2", "b").await.expect("ack failed");

        // pending for s2 should still contain a and c, in that order
        {
            let p = conn.pending.lock().await;
            let q = p.get("s2").expect("missing s2");
            assert_eq!(q.len(), 2);
            assert_eq!(q[0].0, "a");
            assert_eq!(q[1].0, "c");
        }

        // verify an ACK frame was emitted
        expect_outbound_command(&mut out_rx, "ACK").await;
    }

    #[tokio::test]
    async fn test_subscription_receive_delivers_message() {
        // keep the outbound receiver alive so the SUBSCRIBE send succeeds
        let (conn, _out_rx) = setup_test_connection();

        let subscription = conn
            .subscribe("/queue/test", AckMode::Auto)
            .await
            .expect("subscribe failed");

        // find the sender stored in the subscriptions map and push a message
        {
            let map = conn.subscriptions.lock().await;
            let vec = map.get("/queue/test").expect("missing subscription vec");
            let f = make_message("m1", Some(&vec[0].id), Some("/queue/test"));
            vec[0]
                .sender
                .try_send(f)
                .expect("send to subscription failed");
        }

        // consume from the subscription receiver
        let mut rx = subscription.into_receiver();
        let received = rx
            .recv()
            .await
            .expect("no message received on subscription");
        assert_eq!(received.command, "MESSAGE");
        assert!(
            received
                .headers
                .iter()
                .any(|(k, _)| k.to_lowercase() == "message-id"),
            "message-id header missing"
        );
    }

    #[tokio::test]
    async fn test_subscription_ack_removes_pending_and_sends_ack() {
        let (conn, mut out_rx) = setup_test_connection();

        // subscribe with client ack
        let subscription = conn
            .subscribe("/queue/ack", AckMode::Client)
            .await
            .expect("subscribe failed");
        let sub_id = subscription.id().to_string();

        // drain any initial outbound frames (SUBSCRIBE) emitted by subscribe()
        while out_rx.try_recv().is_ok() {}

        // populate pending queue for this subscription
        seed_pending(&conn, &sub_id, "/queue/ack", &["mid-1"]).await;

        // ack the message via the subscription helper
        subscription.ack("mid-1").await.expect("ack failed");

        // ensure pending queue no longer contains the message
        {
            let p = conn.pending.lock().await;
            assert!(p.get(&sub_id).map_or(true, |q| q.is_empty()));
        }

        // verify an ACK frame was emitted
        expect_outbound_command(&mut out_rx, "ACK").await;
    }

    #[tokio::test]
    async fn test_begin_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();

        conn.begin("tx1").await.expect("begin failed");

        // verify BEGIN frame was emitted
        if let Some(StompItem::Frame(f)) = out_rx.recv().await {
            verify_transaction_frame(f, "BEGIN", "tx1");
        } else {
            panic!("no outbound frame sent")
        }
    }

    #[tokio::test]
    async fn test_commit_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();

        conn.commit("tx1").await.expect("commit failed");

        // verify COMMIT frame was emitted
        if let Some(StompItem::Frame(f)) = out_rx.recv().await {
            verify_transaction_frame(f, "COMMIT", "tx1");
        } else {
            panic!("no outbound frame sent")
        }
    }

    #[tokio::test]
    async fn test_abort_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();

        conn.abort("tx1").await.expect("abort failed");

        // verify ABORT frame was emitted
        if let Some(StompItem::Frame(f)) = out_rx.recv().await {
            verify_transaction_frame(f, "ABORT", "tx1");
        } else {
            panic!("no outbound frame sent")
        }
    }
}