iridium_stomp/connection.rs
1use futures::{SinkExt, StreamExt, future};
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Duration;
6use thiserror::Error;
7use tokio::net::TcpStream;
8use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
9use tokio_util::codec::Framed;
10
11use crate::codec::{StompCodec, StompItem};
12use crate::frame::Frame;
13
/// Client heartbeat settings for the STOMP `heart-beat` header.
///
/// A typed alternative to hand-written header strings. The `Display`
/// output is exactly the header value the protocol expects:
/// `"send_ms,receive_ms"`.
///
/// # Example
///
/// ```
/// use iridium_stomp::Heartbeat;
///
/// // Create a custom heartbeat configuration
/// let hb = Heartbeat::new(5000, 10000);
/// assert_eq!(hb.to_string(), "5000,10000");
///
/// // Use predefined configurations
/// assert_eq!(Heartbeat::disabled().to_string(), "0,0");
/// assert_eq!(Heartbeat::default().to_string(), "10000,10000");
/// ```
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Heartbeat {
    /// Smallest interval, in milliseconds, at which the client can emit
    /// heartbeats. `0` means the client will not send any.
    pub send_ms: u32,

    /// Interval, in milliseconds, at which the client would like to receive
    /// heartbeats. `0` means the client does not want any.
    pub receive_ms: u32,
}
43
44impl Heartbeat {
45 /// Create a new heartbeat configuration with the specified intervals.
46 ///
47 /// # Arguments
48 ///
49 /// * `send_ms` - Minimum interval in milliseconds between heartbeats the client can send.
50 /// * `receive_ms` - Minimum interval in milliseconds between heartbeats the client wants to receive.
51 ///
52 /// # Example
53 ///
54 /// ```
55 /// use iridium_stomp::Heartbeat;
56 ///
57 /// let hb = Heartbeat::new(5000, 10000);
58 /// assert_eq!(hb.send_ms, 5000);
59 /// assert_eq!(hb.receive_ms, 10000);
60 /// ```
61 pub fn new(send_ms: u32, receive_ms: u32) -> Self {
62 Self {
63 send_ms,
64 receive_ms,
65 }
66 }
67
68 /// Create a heartbeat configuration that disables heartbeats entirely.
69 ///
70 /// This is equivalent to `Heartbeat::new(0, 0)`.
71 ///
72 /// # Example
73 ///
74 /// ```
75 /// use iridium_stomp::Heartbeat;
76 ///
77 /// let hb = Heartbeat::disabled();
78 /// assert_eq!(hb.send_ms, 0);
79 /// assert_eq!(hb.receive_ms, 0);
80 /// assert_eq!(hb.to_string(), "0,0");
81 /// ```
82 pub fn disabled() -> Self {
83 Self::new(0, 0)
84 }
85
86 /// Create a heartbeat configuration from a Duration for symmetric heartbeats.
87 ///
88 /// Both send and receive intervals will be set to the same value.
89 ///
90 /// The maximum supported Duration is approximately 49.7 days (u32::MAX milliseconds,
91 /// or 4,294,967,295 ms). If a larger Duration is provided, it will be clamped to
92 /// u32::MAX milliseconds to prevent overflow.
93 ///
94 /// # Example
95 ///
96 /// ```
97 /// use iridium_stomp::Heartbeat;
98 /// use std::time::Duration;
99 ///
100 /// let hb = Heartbeat::from_duration(Duration::from_secs(15));
101 /// assert_eq!(hb.send_ms, 15000);
102 /// assert_eq!(hb.receive_ms, 15000);
103 /// ```
104 pub fn from_duration(interval: Duration) -> Self {
105 let ms = interval.as_millis().min(u32::MAX as u128) as u32;
106 Self::new(ms, ms)
107 }
108}
109
110impl Default for Heartbeat {
111 /// Returns the default heartbeat configuration: 10 seconds for both send and receive.
112 fn default() -> Self {
113 Self::new(10000, 10000)
114 }
115}
116
117impl std::fmt::Display for Heartbeat {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 write!(f, "{},{}", self.send_ms, self.receive_ms)
120 }
121}
122
123/// Internal subscription entry stored for each destination.
124#[derive(Clone)]
125pub(crate) struct SubscriptionEntry {
126 pub(crate) id: String,
127 pub(crate) sender: mpsc::Sender<Frame>,
128 pub(crate) ack: String,
129 pub(crate) headers: Vec<(String, String)>,
130}
131
132/// Alias for the subscription dispatch map: destination -> list of
133/// `SubscriptionEntry`.
134pub(crate) type Subscriptions = HashMap<String, Vec<SubscriptionEntry>>;
135
136/// Alias for the pending map: subscription_id -> queue of (message-id, Frame).
137pub(crate) type PendingMap = HashMap<String, VecDeque<(String, Frame)>>;
138
139/// Internal type for resubscribe snapshot entries: (destination, id, ack, headers)
140pub(crate) type ResubEntry = (String, String, String, Vec<(String, String)>);
141
142/// Alias for pending receipt map: receipt-id -> oneshot sender to notify when received.
143pub(crate) type PendingReceipts = HashMap<String, oneshot::Sender<()>>;
144
145/// Errors returned by `Connection` operations.
146#[derive(Error, Debug)]
147pub enum ConnError {
148 /// I/O-level error
149 #[error("io error: {0}")]
150 Io(#[from] std::io::Error),
151 /// Protocol-level error
152 #[error("protocol error: {0}")]
153 Protocol(String),
154 /// Receipt timeout error
155 #[error("receipt timeout: no RECEIPT received for '{0}' within timeout")]
156 ReceiptTimeout(String),
157}
158
159/// Represents an ERROR frame received from the STOMP server.
160///
161/// STOMP servers send ERROR frames to indicate protocol violations, authentication
162/// failures, or other server-side errors. After sending an ERROR frame, the server
163/// typically closes the connection.
164///
165/// # Example
166///
167/// ```ignore
168/// use iridium_stomp::ReceivedFrame;
169///
170/// while let Some(received) = conn.next_frame().await {
171/// match received {
172/// ReceivedFrame::Frame(frame) => {
173/// // Normal message processing
174/// }
175/// ReceivedFrame::Error(err) => {
176/// eprintln!("Server error: {}", err.message);
177/// if let Some(body) = &err.body {
178/// eprintln!("Details: {}", body);
179/// }
180/// break;
181/// }
182/// }
183/// }
184/// ```
185#[derive(Debug, Clone, PartialEq, Eq)]
186pub struct ServerError {
187 /// The error message from the `message` header.
188 pub message: String,
189
190 /// The error body, if present. Contains additional error details.
191 pub body: Option<String>,
192
193 /// The receipt-id if this error is in response to a specific frame.
194 pub receipt_id: Option<String>,
195
196 /// The original ERROR frame for access to additional headers.
197 pub frame: Frame,
198}
199
200impl ServerError {
201 /// Create a `ServerError` from an ERROR frame.
202 ///
203 /// This is primarily used internally but is public for testing and
204 /// advanced use cases where you need to construct a `ServerError` manually.
205 pub fn from_frame(frame: Frame) -> Self {
206 let message = frame
207 .get_header("message")
208 .unwrap_or("unknown error")
209 .to_string();
210
211 let body = if frame.body.is_empty() {
212 None
213 } else {
214 String::from_utf8(frame.body.clone()).ok()
215 };
216
217 let receipt_id = frame.get_header("receipt-id").map(|s| s.to_string());
218
219 Self {
220 message,
221 body,
222 receipt_id,
223 frame,
224 }
225 }
226}
227
228impl std::fmt::Display for ServerError {
229 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 write!(f, "STOMP server error: {}", self.message)?;
231 if let Some(body) = &self.body {
232 write!(f, " - {}", body)?;
233 }
234 Ok(())
235 }
236}
237
238impl std::error::Error for ServerError {}
239
240/// The result of receiving a frame from the server.
241///
242/// STOMP servers can send either normal frames (MESSAGE, RECEIPT, etc.) or
243/// ERROR frames indicating a problem. This enum allows callers to handle
244/// both cases with pattern matching.
245///
246/// # Example
247///
248/// ```ignore
249/// use iridium_stomp::ReceivedFrame;
250///
251/// match conn.next_frame().await {
252/// Some(ReceivedFrame::Frame(frame)) => {
253/// println!("Got frame: {}", frame.command);
254/// }
255/// Some(ReceivedFrame::Error(err)) => {
256/// eprintln!("Server error: {}", err);
257/// }
258/// None => {
259/// println!("Connection closed");
260/// }
261/// }
262/// ```
263#[derive(Debug, Clone, PartialEq, Eq)]
264pub enum ReceivedFrame {
265 /// A normal STOMP frame (MESSAGE, RECEIPT, etc.)
266 Frame(Frame),
267 /// An ERROR frame from the server
268 Error(ServerError),
269}
270
271impl ReceivedFrame {
272 /// Returns `true` if this is an error frame.
273 pub fn is_error(&self) -> bool {
274 matches!(self, ReceivedFrame::Error(_))
275 }
276
277 /// Returns `true` if this is a normal frame.
278 pub fn is_frame(&self) -> bool {
279 matches!(self, ReceivedFrame::Frame(_))
280 }
281
282 /// Returns the frame if this is a normal frame, or `None` if it's an error.
283 pub fn into_frame(self) -> Option<Frame> {
284 match self {
285 ReceivedFrame::Frame(f) => Some(f),
286 ReceivedFrame::Error(_) => None,
287 }
288 }
289
290 /// Returns the error if this is an error frame, or `None` if it's a normal frame.
291 pub fn into_error(self) -> Option<ServerError> {
292 match self {
293 ReceivedFrame::Frame(_) => None,
294 ReceivedFrame::Error(e) => Some(e),
295 }
296 }
297}
298
/// Subscription acknowledgement modes defined by STOMP 1.2.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AckMode {
    /// `auto`: no ACK frames are sent by the client.
    Auto,
    /// `client`: an ACK acknowledges the message and every earlier one on
    /// the same subscription (cumulative).
    Client,
    /// `client-individual`: each ACK/NACK applies to a single message.
    ClientIndividual,
}
306
307impl AckMode {
308 fn as_str(&self) -> &'static str {
309 match self {
310 AckMode::Auto => "auto",
311 AckMode::Client => "client",
312 AckMode::ClientIndividual => "client-individual",
313 }
314 }
315}
316
/// Extra settings for the STOMP CONNECT frame.
///
/// Pass this to `Connection::connect_with_options()` to add custom headers,
/// choose the accepted STOMP versions, or set broker-specific values such as
/// the `client-id` needed for durable subscriptions.
///
/// # Validation
///
/// Almost nothing is validated here. Values go to the broker as-is, and a bad
/// configuration is rejected by the broker when connecting. Empty strings are
/// accepted but may trigger broker-specific errors.
///
/// # Custom Headers
///
/// Headers added with `header()` can never replace the core CONNECT headers
/// (`accept-version`, `host`, `login`, `passcode`, `heart-beat`,
/// `client-id`, compared case-insensitively). Such headers are silently
/// dropped. Use the dedicated builder methods for those values instead.
///
/// # Example
///
/// ```ignore
/// use iridium_stomp::{Connection, ConnectOptions};
///
/// let options = ConnectOptions::default()
///     .client_id("my-durable-client")
///     .host("my-vhost")
///     .header("custom-header", "value");
///
/// let conn = Connection::connect_with_options(
///     "localhost:61613",
///     "guest",
///     "guest",
///     "10000,10000",
///     options,
/// ).await?;
/// ```
#[derive(Clone, Debug, Default)]
pub struct ConnectOptions {
    /// Accepted STOMP version list (e.g. `"1.2"` or `"1.0,1.1,1.2"`).
    /// `"1.2"` is used when unset.
    pub accept_version: Option<String>,

    /// Client identifier, needed by some brokers (e.g. ActiveMQ) for durable
    /// subscriptions.
    pub client_id: Option<String>,

    /// Virtual host for the `host` header; `"/"` is used when unset.
    pub host: Option<String>,

    /// Additional CONNECT headers, in insertion order. Entries that collide
    /// with the core CONNECT headers are ignored.
    pub headers: Vec<(String, String)>,
}
371
372impl ConnectOptions {
373 /// Create a new `ConnectOptions` with default values.
374 pub fn new() -> Self {
375 Self::default()
376 }
377
378 /// Set the STOMP version(s) to accept (builder style).
379 ///
380 /// Examples: "1.2", "1.1,1.2", "1.0,1.1,1.2"
381 pub fn accept_version(mut self, version: impl Into<String>) -> Self {
382 self.accept_version = Some(version.into());
383 self
384 }
385
386 /// Set the client ID for durable subscriptions (builder style).
387 ///
388 /// Required by some brokers (e.g., ActiveMQ) for durable topic subscriptions.
389 pub fn client_id(mut self, id: impl Into<String>) -> Self {
390 self.client_id = Some(id.into());
391 self
392 }
393
394 /// Set the virtual host (builder style).
395 ///
396 /// Defaults to "/" if not set.
397 pub fn host(mut self, host: impl Into<String>) -> Self {
398 self.host = Some(host.into());
399 self
400 }
401
402 /// Add a custom header to the CONNECT frame (builder style).
403 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
404 self.headers.push((key.into(), value.into()));
405 self
406 }
407}
408
/// Parse a STOMP `heart-beat` header value of the form `"cx,cy"`.
///
/// Parameters
/// - `header`: the header text from either peer, e.g. `"10000,10000"`, with
///   both values in milliseconds. Whitespace around each field is ignored.
///
/// Returns `(cx, cy)` in milliseconds. A field that is missing or not a
/// valid `u64` becomes `0`; fields after the second are ignored.
pub fn parse_heartbeat_header(header: &str) -> (u64, u64) {
    let mut fields = header
        .split(',')
        .map(|field| field.trim().parse::<u64>().unwrap_or(0));
    let cx = fields.next().unwrap_or(0);
    let cy = fields.next().unwrap_or(0);
    (cx, cy)
}
429
/// Negotiate heartbeat intervals between client and server.
///
/// Implements the STOMP 1.2 heart-beating rule, applied per direction. If
/// the sending side cannot send heartbeats (advertises `0`), or the
/// receiving side does not want them (advertises `0`), heartbeats are
/// disabled in that direction. Otherwise the interval is the maximum of the
/// two advertised values.
///
/// Parameters
/// - `client_out`: client's outgoing interval in milliseconds (how often the
///   client can send heartbeats; `cx` of the client's header).
/// - `client_in`: client's desired incoming interval in milliseconds (how
///   often the client wants to receive heartbeats; `cy` of the client's header).
/// - `server_out`: server's advertised outgoing interval in milliseconds.
/// - `server_in`: server's advertised incoming interval in milliseconds.
///
/// Returns `(outgoing, incoming)`. Each element is `Some(Duration)` when
/// heartbeats are enabled in that direction and `None` when they are disabled.
pub fn negotiate_heartbeats(
    client_out: u64,
    client_in: u64,
    server_out: u64,
    server_in: u64,
) -> (Option<Duration>, Option<Duration>) {
    // A zero on either end disables the direction. Taking the max alone would
    // enable heartbeats that the other peer never agreed to send or expect,
    // e.g. a "0,0" client would start expecting server heartbeats and drop
    // idle connections.
    fn direction(sender_ms: u64, receiver_ms: u64) -> Option<Duration> {
        if sender_ms == 0 || receiver_ms == 0 {
            None
        } else {
            Some(Duration::from_millis(sender_ms.max(receiver_ms)))
        }
    }

    let outgoing = direction(client_out, server_in);
    let incoming = direction(server_out, client_in);
    (outgoing, incoming)
}
465
466/// High-level connection object that manages a single TCP/STOMP connection.
467///
468/// The `Connection` spawns a background task that maintains the TCP transport,
469/// sends/receives STOMP frames using `StompCodec`, negotiates heartbeats, and
470/// performs simple reconnect logic with exponential backoff.
471#[derive(Clone)]
472pub struct Connection {
473 outbound_tx: mpsc::Sender<StompItem>,
474 /// The inbound receiver is shared behind a mutex so the `Connection`
475 /// handle may be cloned and callers can call `next_frame` concurrently.
476 inbound_rx: Arc<Mutex<mpsc::Receiver<Frame>>>,
477 shutdown_tx: broadcast::Sender<()>,
478 /// Map of destination -> list of (subscription id, sender) for dispatching
479 /// inbound MESSAGE frames to subscribers.
480 subscriptions: Arc<Mutex<Subscriptions>>,
481 /// Monotonic counter used to allocate subscription ids.
482 sub_id_counter: Arc<AtomicU64>,
483 /// Pending messages awaiting ACK/NACK from the application.
484 ///
485 /// Organized by subscription id. For `client` ack mode the ACK is
486 /// cumulative: acknowledging message `M` for subscription `S` acknowledges
487 /// all messages previously delivered for `S` up to and including `M`.
488 /// For `client-individual` the ACK/NACK applies only to the single
489 /// message.
490 pending: Arc<Mutex<PendingMap>>,
491 /// Pending receipt confirmations.
492 ///
493 /// When a frame is sent with a `receipt` header, the receipt-id is stored
494 /// here with a oneshot sender. When the server responds with a RECEIPT
495 /// frame, the sender is notified.
496 pending_receipts: Arc<Mutex<PendingReceipts>>,
497}
498
499impl Connection {
500 /// Heartbeat value that disables heartbeats entirely.
501 ///
502 /// Use this when you don't want the client or server to send heartbeats.
503 /// Note that some brokers may still require heartbeats for long-lived connections.
504 ///
505 /// # Example
506 ///
507 /// ```ignore
508 /// let conn = Connection::connect(
509 /// "localhost:61613",
510 /// "guest",
511 /// "guest",
512 /// Connection::NO_HEARTBEAT,
513 /// ).await?;
514 /// ```
515 pub const NO_HEARTBEAT: &'static str = "0,0";
516
517 /// Default heartbeat value: 10 seconds for both send and receive.
518 ///
519 /// This is a reasonable default for most applications. The actual heartbeat
520 /// interval will be negotiated with the server (taking the maximum of client
521 /// and server preferences).
522 ///
523 /// # Example
524 ///
525 /// ```ignore
526 /// let conn = Connection::connect(
527 /// "localhost:61613",
528 /// "guest",
529 /// "guest",
530 /// Connection::DEFAULT_HEARTBEAT,
531 /// ).await?;
532 /// ```
533 pub const DEFAULT_HEARTBEAT: &'static str = "10000,10000";
534
535 /// Establish a connection to the STOMP server at `addr` with the given
536 /// credentials and heartbeat header string (e.g. "10000,10000").
537 ///
538 /// This is a convenience wrapper around `connect_with_options()` that uses
539 /// default options (STOMP 1.2, host="/", no client-id).
540 ///
541 /// Parameters
542 /// - `addr`: TCP address (host:port) of the STOMP server.
543 /// - `login`: login username for STOMP `CONNECT`.
544 /// - `passcode`: passcode for STOMP `CONNECT`.
545 /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
546 /// milliseconds) that will be sent in the `CONNECT` frame.
547 ///
548 /// Returns a `Connection` which provides `send_frame`, `next_frame`, and
549 /// `close` helpers. The detailed connection handling (I/O, heartbeats,
550 /// reconnects) runs on a background task spawned by this method.
551 pub async fn connect(
552 addr: &str,
553 login: &str,
554 passcode: &str,
555 client_hb: &str,
556 ) -> Result<Self, ConnError> {
557 Self::connect_with_options(addr, login, passcode, client_hb, ConnectOptions::default())
558 .await
559 }
560
561 /// Establish a connection to the STOMP server with custom options.
562 ///
563 /// Use this method when you need to set a custom `client-id` (for durable
564 /// subscriptions), specify a virtual host, negotiate different STOMP
565 /// versions, or add custom CONNECT headers.
566 ///
567 /// Parameters
568 /// - `addr`: TCP address (host:port) of the STOMP server.
569 /// - `login`: login username for STOMP `CONNECT`.
570 /// - `passcode`: passcode for STOMP `CONNECT`.
571 /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
572 /// milliseconds) that will be sent in the `CONNECT` frame.
573 /// - `options`: custom connection options (version, host, client-id, etc.).
574 ///
575 /// # Example
576 ///
577 /// ```ignore
578 /// use iridium_stomp::{Connection, ConnectOptions};
579 ///
580 /// // Connect with a client-id for durable subscriptions
581 /// let options = ConnectOptions::default()
582 /// .client_id("my-app-instance-1");
583 ///
584 /// let conn = Connection::connect_with_options(
585 /// "localhost:61613",
586 /// "guest",
587 /// "guest",
588 /// "10000,10000",
589 /// options,
590 /// ).await?;
591 /// ```
592 pub async fn connect_with_options(
593 addr: &str,
594 login: &str,
595 passcode: &str,
596 client_hb: &str,
597 options: ConnectOptions,
598 ) -> Result<Self, ConnError> {
599 let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(32);
600 let (in_tx, in_rx) = mpsc::channel::<Frame>(32);
601 let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
602 let sub_id_counter = Arc::new(AtomicU64::new(1));
603 let (shutdown_tx, _) = broadcast::channel::<()>(1);
604 let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
605 let pending_clone = pending.clone();
606 let pending_receipts: Arc<Mutex<PendingReceipts>> = Arc::new(Mutex::new(HashMap::new()));
607 let pending_receipts_clone = pending_receipts.clone();
608
609 let addr = addr.to_string();
610 let login = login.to_string();
611 let passcode = passcode.to_string();
612 let client_hb = client_hb.to_string();
613
614 // Extract options into owned values for the spawned task
615 let accept_version = options.accept_version.unwrap_or_else(|| "1.2".to_string());
616 let host = options.host.unwrap_or_else(|| "/".to_string());
617 let client_id = options.client_id;
618 let custom_headers = options.headers;
619
620 let shutdown_tx_clone = shutdown_tx.clone();
621 let subscriptions_clone = subscriptions.clone();
622
623 tokio::spawn(async move {
624 let mut backoff_secs: u64 = 1;
625 loop {
626 let mut shutdown_sub = shutdown_tx_clone.subscribe();
627
628 tokio::select! {
629 _ = shutdown_sub.recv() => break,
630 _ = future::ready(()) => {},
631 }
632
633 match TcpStream::connect(&addr).await {
634 Ok(stream) => {
635 let mut framed = Framed::new(stream, StompCodec::new());
636
637 let mut connect = Frame::new("CONNECT");
638 connect = connect.header("accept-version", &accept_version);
639 connect = connect.header("host", &host);
640 connect = connect.header("login", &login);
641 connect = connect.header("passcode", &passcode);
642 connect = connect.header("heart-beat", &client_hb);
643
644 // Add client-id if specified
645 if let Some(ref cid) = client_id {
646 connect = connect.header("client-id", cid);
647 }
648
649 // Add any custom headers, skipping any that would override
650 // critical STOMP CONNECT headers already set above
651 for (key, value) in &custom_headers {
652 let dominated = matches!(
653 key.to_ascii_lowercase().as_str(),
654 "accept-version"
655 | "host"
656 | "login"
657 | "passcode"
658 | "heart-beat"
659 | "client-id"
660 );
661 if !dominated {
662 connect = connect.header(key, value);
663 }
664 }
665
666 if framed.send(StompItem::Frame(connect)).await.is_err() {
667 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
668 backoff_secs = (backoff_secs * 2).min(30);
669 continue;
670 }
671
672 let mut server_heartbeat = "0,0".to_string();
673 loop {
674 tokio::select! {
675 _ = shutdown_sub.recv() => break,
676 item = framed.next() => {
677 match item {
678 Some(Ok(StompItem::Heartbeat)) => {}
679 Some(Ok(StompItem::Frame(f))) => {
680 if f.command == "CONNECTED" {
681 for (k, v) in &f.headers {
682 if k.to_lowercase() == "heart-beat" { server_heartbeat = v.clone(); }
683 }
684 break;
685 }
686 }
687 _ => break,
688 }
689 }
690 }
691 }
692
693 let (cx, cy) = parse_heartbeat_header(&client_hb);
694 let (sx, sy) = parse_heartbeat_header(&server_heartbeat);
695 let (send_interval, recv_interval) = negotiate_heartbeats(cx, cy, sx, sy);
696
697 let last_received = Arc::new(AtomicU64::new(current_millis()));
698 let writer_last_sent = Arc::new(AtomicU64::new(current_millis()));
699
700 let (mut sink, mut stream) = framed.split();
701 let in_tx = in_tx.clone();
702 let subscriptions = subscriptions_clone.clone();
703
704 // Clear pending message map on reconnect — messages that were
705 // outstanding before the disconnect are considered lost and
706 // will be redelivered by the server as appropriate.
707 {
708 let mut p = pending_clone.lock().await;
709 p.clear();
710 }
711
712 // Resubscribe any existing subscriptions after reconnect.
713 // We snapshot the subscription entries while holding the lock
714 // and then issue SUBSCRIBE frames using the sink.
715 let subs_snapshot: Vec<ResubEntry> = {
716 let map = subscriptions.lock().await;
717 let mut v: Vec<ResubEntry> = Vec::new();
718 for (dest, vec) in map.iter() {
719 for entry in vec.iter() {
720 v.push((
721 dest.clone(),
722 entry.id.clone(),
723 entry.ack.clone(),
724 entry.headers.clone(),
725 ));
726 }
727 }
728 v
729 };
730
731 for (dest, id, ack, headers) in subs_snapshot {
732 let mut sf = Frame::new("SUBSCRIBE");
733 sf = sf
734 .header("id", &id)
735 .header("destination", &dest)
736 .header("ack", &ack);
737 for (k, v) in headers {
738 sf = sf.header(&k, &v);
739 }
740 let _ = sink.send(StompItem::Frame(sf)).await;
741 }
742
743 let mut hb_tick = match send_interval {
744 Some(d) => tokio::time::interval(d),
745 None => tokio::time::interval(Duration::from_secs(86400)),
746 };
747 let watchdog_half = recv_interval.map(|d| d / 2);
748
749 backoff_secs = 1;
750
751 'conn: loop {
752 tokio::select! {
753 _ = shutdown_sub.recv() => { let _ = sink.close().await; break 'conn; }
754 maybe = out_rx.recv() => {
755 match maybe {
756 Some(item) => if sink.send(item).await.is_err() { break 'conn } else { writer_last_sent.store(current_millis(), Ordering::SeqCst); }
757 None => break 'conn,
758 }
759 }
760 item = stream.next() => {
761 match item {
762 Some(Ok(StompItem::Heartbeat)) => { last_received.store(current_millis(), Ordering::SeqCst); }
763 Some(Ok(StompItem::Frame(f))) => {
764 last_received.store(current_millis(), Ordering::SeqCst);
765 // Dispatch MESSAGE frames to any matching subscribers.
766 if f.command == "MESSAGE" {
767 // try to find destination, subscription and message-id headers
768 let mut dest_opt: Option<String> = None;
769 let mut sub_opt: Option<String> = None;
770 let mut msg_id_opt: Option<String> = None;
771 for (k, v) in &f.headers {
772 let kl = k.to_lowercase();
773 if kl == "destination" {
774 dest_opt = Some(v.clone());
775 } else if kl == "subscription" {
776 sub_opt = Some(v.clone());
777 } else if kl == "message-id" {
778 msg_id_opt = Some(v.clone());
779 }
780 }
781
782 // Determine whether we need to track this message as pending
783 let mut need_pending = false;
784 if let Some(sub_id) = &sub_opt {
785 let map = subscriptions.lock().await;
786 for (_dest, vec) in map.iter() {
787 for entry in vec.iter() {
788 if &entry.id == sub_id && entry.ack != "auto" {
789 need_pending = true;
790 }
791 }
792 }
793 } else if let Some(dest) = &dest_opt {
794 let map = subscriptions.lock().await;
795 if let Some(vec) = map.get(dest) {
796 for entry in vec.iter() {
797 if entry.ack != "auto" {
798 need_pending = true;
799 break;
800 }
801 }
802 }
803 }
804
805 // If required, add to pending map (per-subscription) before
806 // delivery so ACK/NACK requests from the application can
807 // reference the message. We require a `message-id` header
808 // to track messages; if missing, we cannot support ACK/NACK.
809 if let Some(msg_id) = msg_id_opt.clone().filter(|_| need_pending) {
810 // If the server provided a subscription id in the
811 // MESSAGE, store pending under that subscription.
812 if let Some(sub_id) = &sub_opt {
813 let mut p = pending_clone.lock().await;
814 let q = p
815 .entry(sub_id.clone())
816 .or_insert_with(VecDeque::new);
817 q.push_back((msg_id.clone(), f.clone()));
818 } else if let Some(dest) = &dest_opt {
819 // Destination-based delivery: add the message to
820 // the pending queue for each matching
821 // subscription on that destination.
822 let map = subscriptions.lock().await;
823 if let Some(vec) = map.get(dest) {
824 let mut p = pending_clone.lock().await;
825 for entry in vec.iter() {
826 let q = p
827 .entry(entry.id.clone())
828 .or_insert_with(VecDeque::new);
829 q.push_back((msg_id.clone(), f.clone()));
830 }
831 }
832 }
833 }
834
835 // Deliver to subscribers.
836 if let Some(sub_id) = sub_opt {
837 let mut map = subscriptions.lock().await;
838 for (_dest, vec) in map.iter_mut() {
839 vec.retain(|entry| {
840 if entry.id == sub_id {
841 let _ = entry.sender.try_send(f.clone());
842 true
843 } else {
844 true
845 }
846 });
847 }
848 } else if let Some(dest) = dest_opt {
849 let mut map = subscriptions.lock().await;
850 if let Some(vec) = map.get_mut(&dest) {
851 vec.retain(|entry| entry.sender.try_send(f.clone()).is_ok());
852 }
853 }
854 } else if f.command == "RECEIPT" {
855 // Handle RECEIPT frame: notify any waiting callers
856 if let Some(receipt_id) = f.get_header("receipt-id") {
857 let mut receipts = pending_receipts_clone.lock().await;
858 if let Some(sender) = receipts.remove(receipt_id) {
859 let _ = sender.send(());
860 }
861 }
862 // Don't forward RECEIPT frames to inbound channel
863 continue;
864 }
865
866 let _ = in_tx.send(f).await;
867 }
868 Some(Err(_)) | None => break 'conn,
869 }
870 }
871 _ = hb_tick.tick() => {
872 if let Some(dur) = send_interval {
873 let last = writer_last_sent.load(Ordering::SeqCst);
874 if current_millis().saturating_sub(last) >= dur.as_millis() as u64 {
875 if sink.send(StompItem::Heartbeat).await.is_err() { break 'conn; }
876 writer_last_sent.store(current_millis(), Ordering::SeqCst);
877 }
878 }
879 }
880 _ = async { if let Some(interval) = watchdog_half { tokio::time::sleep(interval).await } else { future::pending::<()>().await } } => {
881 if let Some(recv_dur) = recv_interval {
882 let last = last_received.load(Ordering::SeqCst);
883 if current_millis().saturating_sub(last) > (recv_dur.as_millis() as u64 * 2) {
884 let _ = sink.close().await; break 'conn;
885 }
886 }
887 }
888 }
889 }
890
891 if shutdown_sub.try_recv().is_ok() {
892 break;
893 }
894 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
895 backoff_secs = (backoff_secs * 2).min(30);
896 }
897 Err(_) => {
898 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
899 backoff_secs = (backoff_secs * 2).min(30);
900 }
901 }
902 }
903 });
904
905 Ok(Connection {
906 outbound_tx: out_tx,
907 inbound_rx: Arc::new(Mutex::new(in_rx)),
908 shutdown_tx,
909 subscriptions,
910 sub_id_counter,
911 pending,
912 pending_receipts,
913 })
914 }
915
916 pub async fn send_frame(&self, frame: Frame) -> Result<(), ConnError> {
917 // Send a frame to the background writer task.
918 //
919 // Parameters
920 // - `frame`: ownership of the `Frame` to send. The frame is converted
921 // into a `StompItem::Frame` and sent over the internal mpsc channel.
922 self.outbound_tx
923 .send(StompItem::Frame(frame))
924 .await
925 .map_err(|_| ConnError::Protocol("send channel closed".into()))
926 }
927
928 /// Send a frame with a receipt request and return the receipt ID.
929 ///
930 /// This method adds a unique `receipt` header to the frame and registers
931 /// the receipt ID for tracking. Use `wait_for_receipt()` to wait for the
932 /// server's RECEIPT response.
933 ///
934 /// # Parameters
935 /// - `frame`: the frame to send. A `receipt` header will be added.
936 ///
937 /// # Returns
938 /// The generated receipt ID that can be used with `wait_for_receipt()`.
939 ///
940 /// # Example
941 /// ```ignore
942 /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
943 /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
944 /// ```
945 pub async fn send_frame_with_receipt(&self, frame: Frame) -> Result<String, ConnError> {
946 use std::sync::atomic::AtomicU64;
947 use std::sync::atomic::Ordering::SeqCst;
948
949 // Generate a unique receipt ID using a static counter
950 static RECEIPT_COUNTER: AtomicU64 = AtomicU64::new(1);
951 let receipt_id = format!("rcpt-{}", RECEIPT_COUNTER.fetch_add(1, SeqCst));
952
953 // Create the oneshot channel for notification
954 let (tx, _rx) = oneshot::channel();
955
956 // Register the pending receipt
957 {
958 let mut receipts = self.pending_receipts.lock().await;
959 receipts.insert(receipt_id.clone(), tx);
960 }
961
962 // Add receipt header and send the frame
963 let frame_with_receipt = frame.receipt(&receipt_id);
964 self.send_frame(frame_with_receipt).await?;
965
966 Ok(receipt_id)
967 }
968
969 /// Wait for a receipt confirmation from the server.
970 ///
971 /// This method blocks until the server sends a RECEIPT frame with the
972 /// matching receipt-id, or until the timeout expires.
973 ///
974 /// # Parameters
975 /// - `receipt_id`: the receipt ID returned by `send_frame_with_receipt()`.
976 /// - `timeout`: maximum time to wait for the receipt.
977 ///
978 /// # Returns
979 /// `Ok(())` if the receipt was received, or `Err(ConnError::ReceiptTimeout)`
980 /// if the timeout expired.
981 ///
982 /// # Example
983 /// ```ignore
984 /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
985 /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
986 /// println!("Message confirmed!");
987 /// ```
988 pub async fn wait_for_receipt(
989 &self,
990 receipt_id: &str,
991 timeout: Duration,
992 ) -> Result<(), ConnError> {
993 // Get the receiver for this receipt
994 let rx = {
995 let mut receipts = self.pending_receipts.lock().await;
996 // Re-create the oneshot channel and swap out the sender
997 let (tx, rx) = oneshot::channel();
998 if let Some(old_tx) = receipts.insert(receipt_id.to_string(), tx) {
999 // Drop the old sender - this is expected if called after send_frame_with_receipt
1000 drop(old_tx);
1001 }
1002 rx
1003 };
1004
1005 // Wait for the receipt with timeout
1006 match tokio::time::timeout(timeout, rx).await {
1007 Ok(Ok(())) => Ok(()),
1008 Ok(Err(_)) => {
1009 // Channel was closed without receiving - connection likely dropped
1010 Err(ConnError::Protocol(
1011 "receipt channel closed unexpectedly".into(),
1012 ))
1013 }
1014 Err(_) => {
1015 // Timeout expired - clean up the pending receipt
1016 let mut receipts = self.pending_receipts.lock().await;
1017 receipts.remove(receipt_id);
1018 Err(ConnError::ReceiptTimeout(receipt_id.to_string()))
1019 }
1020 }
1021 }
1022
1023 /// Send a frame and wait for server confirmation via RECEIPT.
1024 ///
1025 /// This is a convenience method that combines `send_frame_with_receipt()`
1026 /// and `wait_for_receipt()`. Use this when you want to ensure a frame
1027 /// was processed by the server before continuing.
1028 ///
1029 /// # Parameters
1030 /// - `frame`: the frame to send.
1031 /// - `timeout`: maximum time to wait for the receipt.
1032 ///
1033 /// # Returns
1034 /// `Ok(())` if the frame was sent and receipt confirmed, or an error if
1035 /// sending failed or the receipt timed out.
1036 ///
1037 /// # Example
1038 /// ```ignore
1039 /// let frame = Frame::new("SEND")
1040 /// .header("destination", "/queue/orders")
1041 /// .set_body(b"order data".to_vec());
1042 ///
1043 /// conn.send_frame_confirmed(frame, Duration::from_secs(5)).await?;
1044 /// println!("Order sent and confirmed!");
1045 /// ```
1046 pub async fn send_frame_confirmed(
1047 &self,
1048 frame: Frame,
1049 timeout: Duration,
1050 ) -> Result<(), ConnError> {
1051 // Generate receipt ID and register before sending
1052 use std::sync::atomic::AtomicU64;
1053 use std::sync::atomic::Ordering::SeqCst;
1054
1055 static RECEIPT_COUNTER: AtomicU64 = AtomicU64::new(1);
1056 let receipt_id = format!("rcpt-{}", RECEIPT_COUNTER.fetch_add(1, SeqCst));
1057
1058 // Create the oneshot channel for notification
1059 let (tx, rx) = oneshot::channel();
1060
1061 // Register the pending receipt before sending
1062 {
1063 let mut receipts = self.pending_receipts.lock().await;
1064 receipts.insert(receipt_id.clone(), tx);
1065 }
1066
1067 // Add receipt header and send the frame
1068 let frame_with_receipt = frame.receipt(&receipt_id);
1069 self.send_frame(frame_with_receipt).await?;
1070
1071 // Wait for the receipt with timeout
1072 match tokio::time::timeout(timeout, rx).await {
1073 Ok(Ok(())) => Ok(()),
1074 Ok(Err(_)) => Err(ConnError::Protocol(
1075 "receipt channel closed unexpectedly".into(),
1076 )),
1077 Err(_) => {
1078 // Timeout expired - clean up
1079 let mut receipts = self.pending_receipts.lock().await;
1080 receipts.remove(&receipt_id);
1081 Err(ConnError::ReceiptTimeout(receipt_id))
1082 }
1083 }
1084 }
1085
1086 /// Subscribe to a destination.
1087 ///
1088 /// Parameters
1089 /// - `destination`: the STOMP destination to subscribe to (e.g. "/queue/foo").
1090 /// - `ack`: acknowledgement mode to request from the server.
1091 ///
1092 /// Returns a tuple `(subscription_id, receiver)` where `subscription_id` is
1093 /// the opaque id assigned locally for this subscription and `receiver` is a
1094 /// `mpsc::Receiver<Frame>` which will yield incoming MESSAGE frames for the
1095 /// destination. The caller should read from the receiver to handle messages.
1096 /// Subscribe to a destination using optional extra headers.
1097 ///
1098 /// This variant accepts additional headers which are stored locally and
1099 /// re-sent on reconnect. Use `subscribe` as a convenience wrapper when no
1100 /// extra headers are needed.
1101 pub async fn subscribe_with_headers(
1102 &self,
1103 destination: &str,
1104 ack: AckMode,
1105 extra_headers: Vec<(String, String)>,
1106 ) -> Result<crate::subscription::Subscription, ConnError> {
1107 let id = self
1108 .sub_id_counter
1109 .fetch_add(1, Ordering::SeqCst)
1110 .to_string();
1111 let (tx, rx) = mpsc::channel::<Frame>(16);
1112 {
1113 let mut map = self.subscriptions.lock().await;
1114 map.entry(destination.to_string())
1115 .or_insert_with(Vec::new)
1116 .push(SubscriptionEntry {
1117 id: id.clone(),
1118 sender: tx.clone(),
1119 ack: ack.as_str().to_string(),
1120 headers: extra_headers.clone(),
1121 });
1122 }
1123
1124 let mut f = Frame::new("SUBSCRIBE");
1125 f = f
1126 .header("id", &id)
1127 .header("destination", destination)
1128 .header("ack", ack.as_str());
1129 for (k, v) in &extra_headers {
1130 f = f.header(k, v);
1131 }
1132 self.outbound_tx
1133 .send(StompItem::Frame(f))
1134 .await
1135 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1136
1137 Ok(crate::subscription::Subscription::new(
1138 id,
1139 destination.to_string(),
1140 rx,
1141 self.clone(),
1142 ))
1143 }
1144
1145 /// Convenience wrapper without extra headers.
1146 pub async fn subscribe(
1147 &self,
1148 destination: &str,
1149 ack: AckMode,
1150 ) -> Result<crate::subscription::Subscription, ConnError> {
1151 self.subscribe_with_headers(destination, ack, Vec::new())
1152 .await
1153 }
1154
1155 /// Subscribe with a typed `SubscriptionOptions` structure.
1156 ///
1157 /// `SubscriptionOptions.headers` are forwarded to the broker and persisted
1158 /// for automatic resubscribe after reconnect. If `durable_queue` is set,
1159 /// it will be used as the actual destination instead of `destination`.
1160 pub async fn subscribe_with_options(
1161 &self,
1162 destination: &str,
1163 ack: AckMode,
1164 options: crate::subscription::SubscriptionOptions,
1165 ) -> Result<crate::subscription::Subscription, ConnError> {
1166 let dest = options
1167 .durable_queue
1168 .as_deref()
1169 .unwrap_or(destination)
1170 .to_string();
1171 self.subscribe_with_headers(&dest, ack, options.headers)
1172 .await
1173 }
1174
1175 /// Unsubscribe a previously created subscription by its local subscription id.
1176 pub async fn unsubscribe(&self, subscription_id: &str) -> Result<(), ConnError> {
1177 let mut found = false;
1178 {
1179 let mut map = self.subscriptions.lock().await;
1180 let mut remove_keys: Vec<String> = Vec::new();
1181 for (dest, vec) in map.iter_mut() {
1182 if let Some(pos) = vec.iter().position(|entry| entry.id == subscription_id) {
1183 vec.remove(pos);
1184 found = true;
1185 }
1186 if vec.is_empty() {
1187 remove_keys.push(dest.clone());
1188 }
1189 }
1190 for k in remove_keys {
1191 map.remove(&k);
1192 }
1193 }
1194
1195 if !found {
1196 return Err(ConnError::Protocol("subscription id not found".into()));
1197 }
1198
1199 let mut f = Frame::new("UNSUBSCRIBE");
1200 f = f.header("id", subscription_id);
1201 self.outbound_tx
1202 .send(StompItem::Frame(f))
1203 .await
1204 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1205
1206 Ok(())
1207 }
1208
1209 /// Acknowledge a message previously received in `client` or
1210 /// `client-individual` ack modes.
1211 ///
1212 /// STOMP ack semantics:
1213 /// - `auto`: server considers message delivered immediately; the client
1214 /// should not ack.
1215 /// - `client`: cumulative acknowledgements. ACKing message `M` for
1216 /// subscription `S` acknowledges all messages delivered to `S` up to
1217 /// and including `M`.
1218 /// - `client-individual`: only the named message is acknowledged.
1219 ///
1220 /// Parameters
1221 /// - `subscription_id`: the local subscription id returned by
1222 /// `Connection::subscribe`. This disambiguates which subscription's
1223 /// pending queue to advance for cumulative ACKs.
1224 /// - `message_id`: the `message-id` header value from the received
1225 /// MESSAGE frame to acknowledge.
1226 ///
1227 /// Behavior
1228 /// - The pending queue for `subscription_id` is searched for `message_id`.
1229 /// If the subscription used `client` ack mode, all pending messages up to
1230 /// and including the matched message are removed. If the subscription
1231 /// used `client-individual`, only the matched message is removed.
1232 /// - An `ACK` frame is sent to the server with `id=<message_id>` and
1233 /// `subscription=<subscription_id>` headers.
1234 #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1235 pub async fn ack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1236 // Remove from the local pending queue according to subscription ack mode.
1237 let mut removed_any = false;
1238 {
1239 let mut p = self.pending.lock().await;
1240 if let Some(queue) = p.get_mut(subscription_id) {
1241 if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1242 // Determine ack mode for this subscription (default to client).
1243 let mut ack_mode = "client".to_string();
1244 {
1245 let map = self.subscriptions.lock().await;
1246 'outer: for (_dest, vec) in map.iter() {
1247 for entry in vec.iter() {
1248 if entry.id == subscription_id {
1249 ack_mode = entry.ack.clone();
1250 break 'outer;
1251 }
1252 }
1253 }
1254 }
1255
1256 if ack_mode == "client" {
1257 // cumulative: remove up to and including pos
1258 for _ in 0..=pos {
1259 queue.pop_front();
1260 removed_any = true;
1261 }
1262 } else if queue.remove(pos).is_some() {
1263 // client-individual: remove only the specific message
1264 removed_any = true;
1265 }
1266
1267 if queue.is_empty() {
1268 p.remove(subscription_id);
1269 }
1270 }
1271 }
1272 }
1273
1274 // Send ACK to server (include subscription header for clarity)
1275 let mut f = Frame::new("ACK");
1276 f = f
1277 .header("id", message_id)
1278 .header("subscription", subscription_id);
1279 self.outbound_tx
1280 .send(StompItem::Frame(f))
1281 .await
1282 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1283
1284 // If message wasn't found locally, still send ACK to server; server
1285 // may ignore or treat it as no-op.
1286 let _ = removed_any;
1287 Ok(())
1288 }
1289
1290 /// Negative-acknowledge a message (NACK).
1291 ///
1292 /// Parameters
1293 /// - `subscription_id`: the local subscription id the message was delivered under.
1294 /// - `message_id`: the `message-id` header value from the received MESSAGE.
1295 ///
1296 /// Behavior
1297 /// - Removes the message from the local pending queue (cumulatively if the
1298 /// subscription used `client` ack mode, otherwise only the single
1299 /// message). Sends a `NACK` frame to the server with `id` and
1300 /// `subscription` headers.
1301 #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1302 pub async fn nack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1303 // Mirror ack removal semantics for pending map.
1304 let mut removed_any = false;
1305 {
1306 let mut p = self.pending.lock().await;
1307 if let Some(queue) = p.get_mut(subscription_id) {
1308 if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1309 let mut ack_mode = "client".to_string();
1310 {
1311 let map = self.subscriptions.lock().await;
1312 'outer2: for (_dest, vec) in map.iter() {
1313 for entry in vec.iter() {
1314 if entry.id == subscription_id {
1315 ack_mode = entry.ack.clone();
1316 break 'outer2;
1317 }
1318 }
1319 }
1320 }
1321
1322 if ack_mode == "client" {
1323 for _ in 0..=pos {
1324 queue.pop_front();
1325 removed_any = true;
1326 }
1327 } else if queue.remove(pos).is_some() {
1328 removed_any = true;
1329 }
1330
1331 if queue.is_empty() {
1332 p.remove(subscription_id);
1333 }
1334 }
1335 }
1336 }
1337
1338 let mut f = Frame::new("NACK");
1339 f = f
1340 .header("id", message_id)
1341 .header("subscription", subscription_id);
1342 self.outbound_tx
1343 .send(StompItem::Frame(f))
1344 .await
1345 .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1346
1347 let _ = removed_any;
1348 Ok(())
1349 }
1350
1351 /// Helper to send a transaction frame (BEGIN, COMMIT, or ABORT).
1352 async fn send_transaction_frame(
1353 &self,
1354 command: &str,
1355 transaction_id: &str,
1356 ) -> Result<(), ConnError> {
1357 let f = Frame::new(command).header("transaction", transaction_id);
1358 self.outbound_tx
1359 .send(StompItem::Frame(f))
1360 .await
1361 .map_err(|_| ConnError::Protocol("send channel closed".into()))
1362 }
1363
1364 /// Begin a transaction.
1365 ///
1366 /// Parameters
1367 /// - `transaction_id`: unique identifier for the transaction. The caller is
1368 /// responsible for ensuring uniqueness within the connection.
1369 ///
1370 /// Behavior
1371 /// - Sends a `BEGIN` frame to the server with `transaction:<transaction_id>`
1372 /// header. Subsequent `SEND`, `ACK`, and `NACK` frames may include this
1373 /// transaction id to group them into the transaction. The transaction must
1374 /// be finalized with either `commit` or `abort`.
1375 pub async fn begin(&self, transaction_id: &str) -> Result<(), ConnError> {
1376 self.send_transaction_frame("BEGIN", transaction_id).await
1377 }
1378
1379 /// Commit a transaction.
1380 ///
1381 /// Parameters
1382 /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1383 ///
1384 /// Behavior
1385 /// - Sends a `COMMIT` frame to the server with `transaction:<transaction_id>`
1386 /// header. All operations within the transaction are applied atomically.
1387 pub async fn commit(&self, transaction_id: &str) -> Result<(), ConnError> {
1388 self.send_transaction_frame("COMMIT", transaction_id).await
1389 }
1390
1391 /// Abort a transaction.
1392 ///
1393 /// Parameters
1394 /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1395 ///
1396 /// Behavior
1397 /// - Sends an `ABORT` frame to the server with `transaction:<transaction_id>`
1398 /// header. All operations within the transaction are discarded.
1399 pub async fn abort(&self, transaction_id: &str) -> Result<(), ConnError> {
1400 self.send_transaction_frame("ABORT", transaction_id).await
1401 }
1402
1403 /// Receive the next frame from the server.
1404 ///
1405 /// Returns `Some(ReceivedFrame::Frame(..))` for normal frames (MESSAGE, etc.),
1406 /// `Some(ReceivedFrame::Error(..))` for ERROR frames, or `None` if the
1407 /// connection has been closed.
1408 ///
1409 /// # Example
1410 ///
1411 /// ```ignore
1412 /// use iridium_stomp::ReceivedFrame;
1413 ///
1414 /// while let Some(received) = conn.next_frame().await {
1415 /// match received {
1416 /// ReceivedFrame::Frame(frame) => {
1417 /// println!("Got {}: {:?}", frame.command, frame.body);
1418 /// }
1419 /// ReceivedFrame::Error(err) => {
1420 /// eprintln!("Server error: {}", err);
1421 /// break;
1422 /// }
1423 /// }
1424 /// }
1425 /// ```
1426 pub async fn next_frame(&self) -> Option<ReceivedFrame> {
1427 let mut rx = self.inbound_rx.lock().await;
1428 let frame = rx.recv().await?;
1429
1430 // Convert ERROR frames to ServerError for better ergonomics
1431 if frame.command == "ERROR" {
1432 Some(ReceivedFrame::Error(ServerError::from_frame(frame)))
1433 } else {
1434 Some(ReceivedFrame::Frame(frame))
1435 }
1436 }
1437
1438 pub async fn close(self) {
1439 // Signal the background task to shutdown by broadcasting on the
1440 // shutdown channel. Consumers may await task termination separately
1441 // if needed.
1442 let _ = self.shutdown_tx.send(());
1443 }
1444}
1445
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn current_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
1453
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Build a MESSAGE frame with the given `message-id` and optional
    /// `subscription` / `destination` headers.
    fn make_message(
        message_id: &str,
        subscription: Option<&str>,
        destination: Option<&str>,
    ) -> Frame {
        let mut f = Frame::new("MESSAGE").header("message-id", message_id);
        if let Some(s) = subscription {
            f = f.header("subscription", s);
        }
        if let Some(d) = destination {
            f = f.header("destination", d);
        }
        f
    }

    /// Build a `Connection` backed by in-memory channels and empty state.
    ///
    /// Returns the connection together with the receiving end of its outbound
    /// channel, so tests can inspect emitted frames. The inbound sender and the
    /// shutdown receiver are dropped when this helper returns; no test here
    /// uses them.
    fn setup_test_connection() -> (Connection, mpsc::Receiver<StompItem>) {
        let (out_tx, out_rx) = mpsc::channel::<StompItem>(8);
        let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
        let (shutdown_tx, _) = broadcast::channel::<()>(1);

        let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
        let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));

        let conn = Connection {
            outbound_tx: out_tx,
            inbound_rx: Arc::new(Mutex::new(in_rx)),
            shutdown_tx,
            subscriptions,
            sub_id_counter: Arc::new(AtomicU64::new(1)),
            pending,
            pending_receipts: Arc::new(Mutex::new(HashMap::new())),
        };

        (conn, out_rx)
    }

    /// Insert a single subscription entry for `destination` directly into the
    /// connection's map, without sending SUBSCRIBE. The returned receiver keeps
    /// the entry's channel open; hold it for the duration of the test.
    async fn register_subscription(
        conn: &Connection,
        destination: &str,
        id: &str,
        ack: &str,
    ) -> mpsc::Receiver<Frame> {
        let (sender, receiver) = mpsc::channel::<Frame>(4);
        conn.subscriptions.lock().await.insert(
            destination.to_string(),
            vec![SubscriptionEntry {
                id: id.to_string(),
                sender,
                ack: ack.to_string(),
                headers: Vec::new(),
            }],
        );
        receiver
    }

    /// Seed the pending queue of `subscription_id` with MESSAGE frames whose
    /// ids are `message_ids`, in order.
    async fn fill_pending(
        conn: &Connection,
        subscription_id: &str,
        destination: &str,
        message_ids: &[&str],
    ) {
        let queue: VecDeque<(String, Frame)> = message_ids
            .iter()
            .map(|&mid| {
                (
                    mid.to_string(),
                    make_message(mid, Some(subscription_id), Some(destination)),
                )
            })
            .collect();
        conn.pending
            .lock()
            .await
            .insert(subscription_id.to_string(), queue);
    }

    /// Receive the next outbound item and return it as a frame, panicking if
    /// nothing was sent or the item is not a frame.
    async fn next_outbound_frame(out_rx: &mut mpsc::Receiver<StompItem>) -> Frame {
        match out_rx.recv().await {
            Some(StompItem::Frame(f)) => f,
            Some(_) => panic!("expected frame"),
            None => panic!("no outbound frame sent"),
        }
    }

    /// Assert `frame` has `expected_command` and a matching `transaction` header.
    fn verify_transaction_frame(frame: Frame, expected_command: &str, expected_tx_id: &str) {
        assert_eq!(frame.command, expected_command);
        assert!(
            frame
                .headers
                .iter()
                .any(|(k, v)| k == "transaction" && v == expected_tx_id),
            "transaction header with id '{}' not found",
            expected_tx_id
        );
    }

    #[tokio::test]
    async fn test_cumulative_ack_removes_prefix() {
        let (conn, mut out_rx) = setup_test_connection();
        let _sub_rx = register_subscription(&conn, "/queue/x", "s1", "client").await;
        fill_pending(&conn, "s1", "/queue/x", &["m1", "m2", "m3"]).await;

        // ack m2 cumulatively: should remove m1 and m2, leaving m3
        conn.ack("s1", "m2").await.expect("ack failed");

        {
            let p = conn.pending.lock().await;
            let q = p.get("s1").expect("missing s1");
            assert_eq!(q.len(), 1);
            assert_eq!(q.front().unwrap().0, "m3");
        }

        assert_eq!(next_outbound_frame(&mut out_rx).await.command, "ACK");
    }

    #[tokio::test]
    async fn test_client_individual_ack_removes_only_one() {
        let (conn, mut out_rx) = setup_test_connection();
        let _sub_rx =
            register_subscription(&conn, "/queue/y", "s2", "client-individual").await;
        fill_pending(&conn, "s2", "/queue/y", &["a", "b", "c"]).await;

        // ack only 'b' individually
        conn.ack("s2", "b").await.expect("ack failed");

        {
            let p = conn.pending.lock().await;
            let q = p.get("s2").expect("missing s2");
            assert_eq!(q.len(), 2);
            assert_eq!(q[0].0, "a");
            assert_eq!(q[1].0, "c");
        }

        assert_eq!(next_outbound_frame(&mut out_rx).await.command, "ACK");
    }

    #[tokio::test]
    async fn test_subscription_receive_delivers_message() {
        // keep the outbound receiver alive so subscribe() can send SUBSCRIBE
        let (conn, _out_rx) = setup_test_connection();

        let subscription = conn
            .subscribe("/queue/test", AckMode::Auto)
            .await
            .expect("subscribe failed");

        // push a message through the sender stored in the subscriptions map
        {
            let map = conn.subscriptions.lock().await;
            let entry = &map.get("/queue/test").expect("missing subscription vec")[0];
            let f = make_message("m1", Some(&entry.id), Some("/queue/test"));
            entry.sender.try_send(f).expect("send to subscription failed");
        }

        let mut rx = subscription.into_receiver();
        let received = rx
            .recv()
            .await
            .expect("no message received on subscription");
        assert_eq!(received.command, "MESSAGE");
        assert!(
            received
                .headers
                .iter()
                .any(|(k, _)| k.to_lowercase() == "message-id"),
            "message-id header missing"
        );
    }

    #[tokio::test]
    async fn test_subscription_ack_removes_pending_and_sends_ack() {
        let (conn, mut out_rx) = setup_test_connection();

        let subscription = conn
            .subscribe("/queue/ack", AckMode::Client)
            .await
            .expect("subscribe failed");
        let sub_id = subscription.id().to_string();

        // drain the SUBSCRIBE frame emitted by subscribe()
        while out_rx.try_recv().is_ok() {}

        fill_pending(&conn, &sub_id, "/queue/ack", &["mid-1"]).await;

        subscription.ack("mid-1").await.expect("ack failed");

        {
            let p = conn.pending.lock().await;
            assert!(p.get(&sub_id).is_none() || p.get(&sub_id).unwrap().is_empty());
        }

        assert_eq!(next_outbound_frame(&mut out_rx).await.command, "ACK");
    }

    #[tokio::test]
    async fn test_begin_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();
        conn.begin("tx1").await.expect("begin failed");
        verify_transaction_frame(next_outbound_frame(&mut out_rx).await, "BEGIN", "tx1");
    }

    #[tokio::test]
    async fn test_commit_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();
        conn.commit("tx1").await.expect("commit failed");
        verify_transaction_frame(next_outbound_frame(&mut out_rx).await, "COMMIT", "tx1");
    }

    #[tokio::test]
    async fn test_abort_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();
        conn.abort("tx1").await.expect("abort failed");
        verify_transaction_frame(next_outbound_frame(&mut out_rx).await, "ABORT", "tx1");
    }
}