iridium_stomp/
connection.rs

1use futures::{SinkExt, StreamExt, future};
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Duration;
6use thiserror::Error;
7use tokio::net::TcpStream;
8use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
9use tokio_util::codec::Framed;
10
11use crate::codec::{StompCodec, StompItem};
12use crate::frame::Frame;
13
/// Client-side STOMP heartbeat preferences.
///
/// Stores the two intervals a client advertises in the `heart-beat` header as
/// typed integers instead of a hand-assembled string. The `Display`
/// implementation renders them in the wire format `"send_ms,receive_ms"`.
///
/// # Example
///
/// ```
/// use iridium_stomp::Heartbeat;
///
/// // Explicit intervals
/// let hb = Heartbeat::new(5000, 10000);
/// assert_eq!(hb.to_string(), "5000,10000");
///
/// // Ready-made configurations
/// assert_eq!(Heartbeat::disabled().to_string(), "0,0");
/// assert_eq!(Heartbeat::default().to_string(), "10000,10000");
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Heartbeat {
    /// Smallest gap, in milliseconds, between heartbeats the client is able
    /// to send. `0` means the client cannot send heartbeats.
    pub send_ms: u32,

    /// Smallest gap, in milliseconds, between heartbeats the client would
    /// like to receive. `0` means the client does not want heartbeats.
    pub receive_ms: u32,
}

impl Heartbeat {
    /// Build a configuration from explicit send and receive intervals.
    ///
    /// # Arguments
    ///
    /// * `send_ms` - Minimum milliseconds between heartbeats the client can send.
    /// * `receive_ms` - Minimum milliseconds between heartbeats the client wants to receive.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    ///
    /// let hb = Heartbeat::new(5000, 10000);
    /// assert_eq!(hb.send_ms, 5000);
    /// assert_eq!(hb.receive_ms, 10000);
    /// ```
    pub fn new(send_ms: u32, receive_ms: u32) -> Self {
        Heartbeat {
            send_ms,
            receive_ms,
        }
    }

    /// Build a configuration with heartbeats switched off in both directions.
    ///
    /// Same value as `Heartbeat::new(0, 0)`.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    ///
    /// let hb = Heartbeat::disabled();
    /// assert_eq!(hb.send_ms, 0);
    /// assert_eq!(hb.receive_ms, 0);
    /// assert_eq!(hb.to_string(), "0,0");
    /// ```
    pub fn disabled() -> Self {
        Heartbeat {
            send_ms: 0,
            receive_ms: 0,
        }
    }

    /// Build a symmetric configuration (same interval both ways) from a
    /// `Duration`.
    ///
    /// The interval is converted to whole milliseconds. Values that do not
    /// fit in a `u32` (more than 4,294,967,295 ms, roughly 49.7 days) are
    /// clamped to `u32::MAX` rather than wrapping.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    /// use std::time::Duration;
    ///
    /// let hb = Heartbeat::from_duration(Duration::from_secs(15));
    /// assert_eq!(hb.send_ms, 15000);
    /// assert_eq!(hb.receive_ms, 15000);
    /// ```
    pub fn from_duration(interval: Duration) -> Self {
        let total_millis = interval.as_millis();
        // Saturate instead of truncating when the duration exceeds u32 range.
        let clamped = if total_millis > u128::from(u32::MAX) {
            u32::MAX
        } else {
            total_millis as u32
        };
        Heartbeat {
            send_ms: clamped,
            receive_ms: clamped,
        }
    }
}

impl Default for Heartbeat {
    /// The default configuration: 10 seconds in each direction.
    fn default() -> Self {
        Heartbeat {
            send_ms: 10_000,
            receive_ms: 10_000,
        }
    }
}

impl std::fmt::Display for Heartbeat {
    /// Writes the value in `heart-beat` header form: `send_ms,receive_ms`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Heartbeat {
            send_ms,
            receive_ms,
        } = *self;
        write!(f, "{},{}", send_ms, receive_ms)
    }
}
122
123/// Internal subscription entry stored for each destination.
124#[derive(Clone)]
125pub(crate) struct SubscriptionEntry {
126    pub(crate) id: String,
127    pub(crate) sender: mpsc::Sender<Frame>,
128    pub(crate) ack: String,
129    pub(crate) headers: Vec<(String, String)>,
130}
131
132/// Alias for the subscription dispatch map: destination -> list of
133/// `SubscriptionEntry`.
134pub(crate) type Subscriptions = HashMap<String, Vec<SubscriptionEntry>>;
135
136/// Alias for the pending map: subscription_id -> queue of (message-id, Frame).
137pub(crate) type PendingMap = HashMap<String, VecDeque<(String, Frame)>>;
138
139/// Internal type for resubscribe snapshot entries: (destination, id, ack, headers)
140pub(crate) type ResubEntry = (String, String, String, Vec<(String, String)>);
141
142/// Alias for pending receipt map: receipt-id -> oneshot sender to notify when received.
143pub(crate) type PendingReceipts = HashMap<String, oneshot::Sender<()>>;
144
145/// Errors returned by `Connection` operations.
146#[derive(Error, Debug)]
147pub enum ConnError {
148    /// I/O-level error
149    #[error("io error: {0}")]
150    Io(#[from] std::io::Error),
151    /// Protocol-level error
152    #[error("protocol error: {0}")]
153    Protocol(String),
154    /// Receipt timeout error
155    #[error("receipt timeout: no RECEIPT received for '{0}' within timeout")]
156    ReceiptTimeout(String),
157    /// Server rejected the connection (e.g., authentication failure)
158    ///
159    /// This error is returned when the server sends an ERROR frame in response
160    /// to the CONNECT frame. Common causes include invalid credentials,
161    /// unauthorized access, or broker configuration issues.
162    #[error("server rejected connection: {0}")]
163    ServerRejected(ServerError),
164}
165
166/// Represents an ERROR frame received from the STOMP server.
167///
168/// STOMP servers send ERROR frames to indicate protocol violations, authentication
169/// failures, or other server-side errors. After sending an ERROR frame, the server
170/// typically closes the connection.
171///
172/// # Example
173///
174/// ```ignore
175/// use iridium_stomp::ReceivedFrame;
176///
177/// while let Some(received) = conn.next_frame().await {
178///     match received {
179///         ReceivedFrame::Frame(frame) => {
180///             // Normal message processing
181///         }
182///         ReceivedFrame::Error(err) => {
183///             eprintln!("Server error: {}", err.message);
184///             if let Some(body) = &err.body {
185///                 eprintln!("Details: {}", body);
186///             }
187///             break;
188///         }
189///     }
190/// }
191/// ```
192#[derive(Debug, Clone, PartialEq, Eq)]
193pub struct ServerError {
194    /// The error message from the `message` header.
195    pub message: String,
196
197    /// The error body, if present. Contains additional error details.
198    pub body: Option<String>,
199
200    /// The receipt-id if this error is in response to a specific frame.
201    pub receipt_id: Option<String>,
202
203    /// The original ERROR frame for access to additional headers.
204    pub frame: Frame,
205}
206
207impl ServerError {
208    /// Create a `ServerError` from an ERROR frame.
209    ///
210    /// This is primarily used internally but is public for testing and
211    /// advanced use cases where you need to construct a `ServerError` manually.
212    pub fn from_frame(frame: Frame) -> Self {
213        let message = frame
214            .get_header("message")
215            .unwrap_or("unknown error")
216            .to_string();
217
218        let body = if frame.body.is_empty() {
219            None
220        } else {
221            String::from_utf8(frame.body.clone()).ok()
222        };
223
224        let receipt_id = frame.get_header("receipt-id").map(|s| s.to_string());
225
226        Self {
227            message,
228            body,
229            receipt_id,
230            frame,
231        }
232    }
233}
234
235impl std::fmt::Display for ServerError {
236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        write!(f, "STOMP server error: {}", self.message)?;
238        if let Some(body) = &self.body {
239            write!(f, " - {}", body)?;
240        }
241        Ok(())
242    }
243}
244
245impl std::error::Error for ServerError {}
246
247/// The result of receiving a frame from the server.
248///
249/// STOMP servers can send either normal frames (MESSAGE, RECEIPT, etc.) or
250/// ERROR frames indicating a problem. This enum allows callers to handle
251/// both cases with pattern matching.
252///
253/// # Example
254///
255/// ```ignore
256/// use iridium_stomp::ReceivedFrame;
257///
258/// match conn.next_frame().await {
259///     Some(ReceivedFrame::Frame(frame)) => {
260///         println!("Got frame: {}", frame.command);
261///     }
262///     Some(ReceivedFrame::Error(err)) => {
263///         eprintln!("Server error: {}", err);
264///     }
265///     None => {
266///         println!("Connection closed");
267///     }
268/// }
269/// ```
270#[derive(Debug, Clone, PartialEq, Eq)]
271pub enum ReceivedFrame {
272    /// A normal STOMP frame (MESSAGE, RECEIPT, etc.)
273    Frame(Frame),
274    /// An ERROR frame from the server
275    Error(ServerError),
276}
277
278impl ReceivedFrame {
279    /// Returns `true` if this is an error frame.
280    pub fn is_error(&self) -> bool {
281        matches!(self, ReceivedFrame::Error(_))
282    }
283
284    /// Returns `true` if this is a normal frame.
285    pub fn is_frame(&self) -> bool {
286        matches!(self, ReceivedFrame::Frame(_))
287    }
288
289    /// Returns the frame if this is a normal frame, or `None` if it's an error.
290    pub fn into_frame(self) -> Option<Frame> {
291        match self {
292            ReceivedFrame::Frame(f) => Some(f),
293            ReceivedFrame::Error(_) => None,
294        }
295    }
296
297    /// Returns the error if this is an error frame, or `None` if it's a normal frame.
298    pub fn into_error(self) -> Option<ServerError> {
299        match self {
300            ReceivedFrame::Frame(_) => None,
301            ReceivedFrame::Error(e) => Some(e),
302        }
303    }
304}
305
/// The acknowledgement modes a subscription can use under STOMP 1.2.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AckMode {
    /// Sent on the wire as `ack:auto`.
    Auto,
    /// Sent on the wire as `ack:client`.
    Client,
    /// Sent on the wire as `ack:client-individual`.
    ClientIndividual,
}

impl AckMode {
    /// The literal used for this mode in the `ack` header.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Auto => "auto",
            Self::Client => "client",
            Self::ClientIndividual => "client-individual",
        }
    }
}
323
/// Knobs for tailoring the STOMP CONNECT frame.
///
/// Pass this to `Connection::connect_with_options()` to add custom headers,
/// choose the STOMP versions to offer, or set broker-specific values such as
/// `client-id` for durable subscriptions.
///
/// # Validation
///
/// Only minimal validation is performed. Values are forwarded to the broker
/// unchanged, so an invalid configuration is rejected by the broker when the
/// connection is attempted. Empty strings are accepted here but may trigger
/// broker-specific errors.
///
/// # Custom Headers
///
/// Headers supplied through `header()` cannot replace the critical STOMP
/// headers (`accept-version`, `host`, `login`, `passcode`, `heart-beat`,
/// `client-id`); such entries are silently ignored. Use the dedicated builder
/// methods for those values.
///
/// # Example
///
/// ```ignore
/// use iridium_stomp::{Connection, ConnectOptions};
///
/// let options = ConnectOptions::default()
///     .client_id("my-durable-client")
///     .host("my-vhost")
///     .header("custom-header", "value");
///
/// let conn = Connection::connect_with_options(
///     "localhost:61613",
///     "guest",
///     "guest",
///     "10000,10000",
///     options,
/// ).await?;
/// ```
#[derive(Debug, Clone, Default)]
pub struct ConnectOptions {
    /// STOMP version(s) offered to the server, e.g. "1.2" or "1.0,1.1,1.2".
    /// "1.2" is used when left unset.
    pub accept_version: Option<String>,

    /// Client ID used for durable subscriptions (needed by ActiveMQ and
    /// similar brokers).
    pub client_id: Option<String>,

    /// Value for the virtual `host` header. "/" is used when left unset.
    pub host: Option<String>,

    /// Extra headers to attach to the CONNECT frame, in insertion order.
    /// Entries that would override critical STOMP headers are ignored.
    pub headers: Vec<(String, String)>,
}

impl ConnectOptions {
    /// Start from an empty set of options (identical to `Default::default()`).
    pub fn new() -> Self {
        ConnectOptions {
            accept_version: None,
            client_id: None,
            host: None,
            headers: Vec::new(),
        }
    }

    /// Builder: choose the STOMP version(s) to offer.
    ///
    /// Examples: "1.2", "1.1,1.2", "1.0,1.1,1.2"
    pub fn accept_version(self, version: impl Into<String>) -> Self {
        ConnectOptions {
            accept_version: Some(version.into()),
            ..self
        }
    }

    /// Builder: set the client ID used for durable subscriptions.
    ///
    /// Some brokers (e.g., ActiveMQ) require it for durable topic
    /// subscriptions.
    pub fn client_id(self, id: impl Into<String>) -> Self {
        ConnectOptions {
            client_id: Some(id.into()),
            ..self
        }
    }

    /// Builder: set the virtual host.
    ///
    /// "/" is used when this is never called.
    pub fn host(self, host: impl Into<String>) -> Self {
        ConnectOptions {
            host: Some(host.into()),
            ..self
        }
    }

    /// Builder: append one custom header to the CONNECT frame. Repeated keys
    /// are kept as separate entries.
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let pair = (key.into(), value.into());
        self.headers.push(pair);
        self
    }
}
415
/// Decode a STOMP `heart-beat` header value of the form `"cx,cy"`.
///
/// Parameters
/// - `header`: the raw header text sent by either peer (for example
///   "10000,10000"). Both numbers are in milliseconds.
///
/// Returns `(cx, cy)`. Surrounding whitespace in each field is ignored. A
/// field that is missing or does not parse as an unsigned integer yields `0`,
/// and anything after the second comma-separated field is ignored.
pub fn parse_heartbeat_header(header: &str) -> (u64, u64) {
    // Lazily turn each comma-separated field into a number, treating
    // unparsable text as 0.
    let mut fields = header
        .split(',')
        .map(|field| field.trim().parse::<u64>().unwrap_or(0));

    let cx = fields.next().unwrap_or(0);
    let cy = fields.next().unwrap_or(0);
    (cx, cy)
}
436
/// Negotiate heartbeat intervals between client and server.
///
/// Parameters
/// - `client_out`: smallest interval, in milliseconds, at which the client
///   can send heartbeats (`0` = the client cannot send heartbeats).
/// - `client_in`: interval, in milliseconds, at which the client would like
///   to receive heartbeats (`0` = the client does not want heartbeats).
/// - `server_out`: smallest interval, in milliseconds, at which the server
///   can send heartbeats (`0` = the server cannot send heartbeats).
/// - `server_in`: interval, in milliseconds, at which the server would like
///   to receive heartbeats (`0` = the server does not want heartbeats).
///
/// Returns `(outgoing, incoming)` from the client's point of view. Following
/// the STOMP 1.2 heart-beating rules, a direction is disabled (`None`) when
/// *either* the sending side cannot send (`0`) or the receiving side does not
/// want heartbeats (`0`). Otherwise the interval is the maximum of the two
/// values.
pub fn negotiate_heartbeats(
    client_out: u64,
    client_in: u64,
    server_out: u64,
    server_in: u64,
) -> (Option<Duration>, Option<Duration>) {
    // Resolve one direction: `sender_ms` is what the sending peer can
    // guarantee, `receiver_ms` is what the receiving peer asked for. A zero on
    // either side means no heartbeats flow in this direction at all.
    fn negotiate_direction(sender_ms: u64, receiver_ms: u64) -> Option<Duration> {
        if sender_ms == 0 || receiver_ms == 0 {
            None
        } else {
            Some(Duration::from_millis(sender_ms.max(receiver_ms)))
        }
    }

    let outgoing = negotiate_direction(client_out, server_in);
    let incoming = negotiate_direction(server_out, client_in);
    (outgoing, incoming)
}
472
473/// High-level connection object that manages a single TCP/STOMP connection.
474///
475/// The `Connection` spawns a background task that maintains the TCP transport,
476/// sends/receives STOMP frames using `StompCodec`, negotiates heartbeats, and
477/// performs simple reconnect logic with exponential backoff.
478#[derive(Clone)]
479pub struct Connection {
480    outbound_tx: mpsc::Sender<StompItem>,
481    /// The inbound receiver is shared behind a mutex so the `Connection`
482    /// handle may be cloned and callers can call `next_frame` concurrently.
483    inbound_rx: Arc<Mutex<mpsc::Receiver<Frame>>>,
484    shutdown_tx: broadcast::Sender<()>,
485    /// Map of destination -> list of (subscription id, sender) for dispatching
486    /// inbound MESSAGE frames to subscribers.
487    subscriptions: Arc<Mutex<Subscriptions>>,
488    /// Monotonic counter used to allocate subscription ids.
489    sub_id_counter: Arc<AtomicU64>,
490    /// Pending messages awaiting ACK/NACK from the application.
491    ///
492    /// Organized by subscription id. For `client` ack mode the ACK is
493    /// cumulative: acknowledging message `M` for subscription `S` acknowledges
494    /// all messages previously delivered for `S` up to and including `M`.
495    /// For `client-individual` the ACK/NACK applies only to the single
496    /// message.
497    pending: Arc<Mutex<PendingMap>>,
498    /// Pending receipt confirmations.
499    ///
500    /// When a frame is sent with a `receipt` header, the receipt-id is stored
501    /// here with a oneshot sender. When the server responds with a RECEIPT
502    /// frame, the sender is notified.
503    pending_receipts: Arc<Mutex<PendingReceipts>>,
504}
505
506impl Connection {
/// `heart-beat` header value that turns heartbeats off in both directions.
///
/// Pass this when neither the client nor the server should send
/// heartbeats. Be aware that some brokers may still require heartbeats on
/// long-lived connections.
///
/// # Example
///
/// ```ignore
/// let conn = Connection::connect(
///     "localhost:61613",
///     "guest",
///     "guest",
///     Connection::NO_HEARTBEAT,
/// ).await?;
/// ```
pub const NO_HEARTBEAT: &'static str = "0,0";
523
/// Default `heart-beat` header value: 10 seconds for both send and receive.
///
/// A sensible starting point for most applications. These are only the
/// client's preferences; the intervals actually used are negotiated
/// against the values the server advertises when the connection is made.
///
/// # Example
///
/// ```ignore
/// let conn = Connection::connect(
///     "localhost:61613",
///     "guest",
///     "guest",
///     Connection::DEFAULT_HEARTBEAT,
/// ).await?;
/// ```
pub const DEFAULT_HEARTBEAT: &'static str = "10000,10000";
541
542    /// Establish a connection to the STOMP server at `addr` with the given
543    /// credentials and heartbeat header string (e.g. "10000,10000").
544    ///
545    /// This is a convenience wrapper around `connect_with_options()` that uses
546    /// default options (STOMP 1.2, host="/", no client-id).
547    ///
548    /// Parameters
549    /// - `addr`: TCP address (host:port) of the STOMP server.
550    /// - `login`: login username for STOMP `CONNECT`.
551    /// - `passcode`: passcode for STOMP `CONNECT`.
552    /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
553    ///   milliseconds) that will be sent in the `CONNECT` frame.
554    ///
555    /// Returns a `Connection` which provides `send_frame`, `next_frame`, and
556    /// `close` helpers. The detailed connection handling (I/O, heartbeats,
557    /// reconnects) runs on a background task spawned by this method.
558    pub async fn connect(
559        addr: &str,
560        login: &str,
561        passcode: &str,
562        client_hb: &str,
563    ) -> Result<Self, ConnError> {
564        Self::connect_with_options(addr, login, passcode, client_hb, ConnectOptions::default())
565            .await
566    }
567
568    /// Establish a connection to the STOMP server with custom options.
569    ///
570    /// Use this method when you need to set a custom `client-id` (for durable
571    /// subscriptions), specify a virtual host, negotiate different STOMP
572    /// versions, or add custom CONNECT headers.
573    ///
574    /// Parameters
575    /// - `addr`: TCP address (host:port) of the STOMP server.
576    /// - `login`: login username for STOMP `CONNECT`.
577    /// - `passcode`: passcode for STOMP `CONNECT`.
578    /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
579    ///   milliseconds) that will be sent in the `CONNECT` frame.
580    /// - `options`: custom connection options (version, host, client-id, etc.).
581    ///
582    /// # Errors
583    ///
584    /// Returns an error if:
585    /// - The TCP connection cannot be established (`ConnError::Io`)
586    /// - The server rejects the connection, e.g., due to invalid credentials
587    ///   (`ConnError::ServerRejected`)
588    /// - The server closes the connection without responding (`ConnError::Protocol`)
589    ///
590    /// # Example
591    ///
592    /// ```ignore
593    /// use iridium_stomp::{Connection, ConnectOptions};
594    ///
595    /// // Connect with a client-id for durable subscriptions
596    /// let options = ConnectOptions::default()
597    ///     .client_id("my-app-instance-1");
598    ///
599    /// let conn = Connection::connect_with_options(
600    ///     "localhost:61613",
601    ///     "guest",
602    ///     "guest",
603    ///     "10000,10000",
604    ///     options,
605    /// ).await?;
606    /// ```
607    pub async fn connect_with_options(
608        addr: &str,
609        login: &str,
610        passcode: &str,
611        client_hb: &str,
612        options: ConnectOptions,
613    ) -> Result<Self, ConnError> {
614        let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(32);
615        let (in_tx, in_rx) = mpsc::channel::<Frame>(32);
616        let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
617        let sub_id_counter = Arc::new(AtomicU64::new(1));
618        let (shutdown_tx, _) = broadcast::channel::<()>(1);
619        let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
620        let pending_clone = pending.clone();
621        let pending_receipts: Arc<Mutex<PendingReceipts>> = Arc::new(Mutex::new(HashMap::new()));
622        let pending_receipts_clone = pending_receipts.clone();
623
624        let addr = addr.to_string();
625        let login = login.to_string();
626        let passcode = passcode.to_string();
627        let client_hb = client_hb.to_string();
628
629        // Extract options into owned values for the spawned task
630        let accept_version = options.accept_version.unwrap_or_else(|| "1.2".to_string());
631        let host = options.host.unwrap_or_else(|| "/".to_string());
632        let client_id = options.client_id;
633        let custom_headers = options.headers;
634
635        // Perform initial connection and STOMP handshake before spawning background task.
636        // This ensures authentication errors are returned to the caller immediately.
637        let stream = TcpStream::connect(&addr).await?;
638        let mut framed = Framed::new(stream, StompCodec::new());
639
640        // Build and send CONNECT frame
641        let connect = Self::build_connect_frame(
642            &accept_version,
643            &host,
644            &login,
645            &passcode,
646            &client_hb,
647            &client_id,
648            &custom_headers,
649        );
650
651        framed
652            .send(StompItem::Frame(connect))
653            .await
654            .map_err(|e| ConnError::Io(std::io::Error::other(e)))?;
655
656        // Wait for CONNECTED or ERROR response
657        let server_heartbeat = Self::await_connected_response(&mut framed).await?;
658
659        // Calculate heartbeat intervals
660        let (cx, cy) = parse_heartbeat_header(&client_hb);
661        let (sx, sy) = parse_heartbeat_header(&server_heartbeat);
662        let (send_interval, recv_interval) = negotiate_heartbeats(cx, cy, sx, sy);
663
664        // Now spawn background task for ongoing I/O and reconnection
665        let shutdown_tx_clone = shutdown_tx.clone();
666        let subscriptions_clone = subscriptions.clone();
667
668        tokio::spawn(async move {
669            let mut backoff_secs: u64 = 1;
670
671            // Use the already-established connection for the first iteration
672            let mut current_framed = Some(framed);
673            let mut current_send_interval = send_interval;
674            let mut current_recv_interval = recv_interval;
675
676            loop {
677                let mut shutdown_sub = shutdown_tx_clone.subscribe();
678
679                // Check for shutdown before attempting connection
680                tokio::select! {
681                    biased;
682                    _ = shutdown_sub.recv() => break,
683                    _ = future::ready(()) => {},
684                }
685
686                // Either use existing connection or establish new one (reconnect)
687                let framed = if let Some(f) = current_framed.take() {
688                    f
689                } else {
690                    // Reconnection attempt
691                    match TcpStream::connect(&addr).await {
692                        Ok(stream) => {
693                            let mut framed = Framed::new(stream, StompCodec::new());
694
695                            let connect = Self::build_connect_frame(
696                                &accept_version,
697                                &host,
698                                &login,
699                                &passcode,
700                                &client_hb,
701                                &client_id,
702                                &custom_headers,
703                            );
704
705                            if framed.send(StompItem::Frame(connect)).await.is_err() {
706                                tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
707                                backoff_secs = (backoff_secs * 2).min(30);
708                                continue;
709                            }
710
711                            // Wait for CONNECTED (on reconnect, silently retry on ERROR)
712                            match Self::await_connected_response(&mut framed).await {
713                                Ok(server_hb) => {
714                                    let (cx, cy) = parse_heartbeat_header(&client_hb);
715                                    let (sx, sy) = parse_heartbeat_header(&server_hb);
716                                    let (si, ri) = negotiate_heartbeats(cx, cy, sx, sy);
717                                    current_send_interval = si;
718                                    current_recv_interval = ri;
719                                    framed
720                                }
721                                Err(_) => {
722                                    // Reconnect failed (auth error or other), retry with backoff
723                                    tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
724                                    backoff_secs = (backoff_secs * 2).min(30);
725                                    continue;
726                                }
727                            }
728                        }
729                        Err(_) => {
730                            tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
731                            backoff_secs = (backoff_secs * 2).min(30);
732                            continue;
733                        }
734                    }
735                };
736
737                let (send_interval, recv_interval) = (current_send_interval, current_recv_interval);
738
739                let last_received = Arc::new(AtomicU64::new(current_millis()));
740                let writer_last_sent = Arc::new(AtomicU64::new(current_millis()));
741
742                let (mut sink, mut stream) = framed.split();
743                let in_tx = in_tx.clone();
744                let subscriptions = subscriptions_clone.clone();
745
746                // Clear pending message map on reconnect — messages that were
747                // outstanding before the disconnect are considered lost and
748                // will be redelivered by the server as appropriate.
749                {
750                    let mut p = pending_clone.lock().await;
751                    p.clear();
752                }
753
754                // Resubscribe any existing subscriptions after reconnect.
755                // We snapshot the subscription entries while holding the lock
756                // and then issue SUBSCRIBE frames using the sink.
757                let subs_snapshot: Vec<ResubEntry> = {
758                    let map = subscriptions.lock().await;
759                    let mut v: Vec<ResubEntry> = Vec::new();
760                    for (dest, vec) in map.iter() {
761                        for entry in vec.iter() {
762                            v.push((
763                                dest.clone(),
764                                entry.id.clone(),
765                                entry.ack.clone(),
766                                entry.headers.clone(),
767                            ));
768                        }
769                    }
770                    v
771                };
772
773                for (dest, id, ack, headers) in subs_snapshot {
774                    let mut sf = Frame::new("SUBSCRIBE");
775                    sf = sf
776                        .header("id", &id)
777                        .header("destination", &dest)
778                        .header("ack", &ack);
779                    for (k, v) in headers {
780                        sf = sf.header(&k, &v);
781                    }
782                    let _ = sink.send(StompItem::Frame(sf)).await;
783                }
784
785                let mut hb_tick = match send_interval {
786                    Some(d) => tokio::time::interval(d),
787                    None => tokio::time::interval(Duration::from_secs(86400)),
788                };
789                let watchdog_half = recv_interval.map(|d| d / 2);
790
791                backoff_secs = 1;
792
793                'conn: loop {
794                    tokio::select! {
795                        _ = shutdown_sub.recv() => { let _ = sink.close().await; break 'conn; }
796                        maybe = out_rx.recv() => {
797                            match maybe {
798                                Some(item) => if sink.send(item).await.is_err() { break 'conn } else { writer_last_sent.store(current_millis(), Ordering::SeqCst); }
799                                None => break 'conn,
800                            }
801                        }
802                        item = stream.next() => {
803                            match item {
804                                Some(Ok(StompItem::Heartbeat)) => { last_received.store(current_millis(), Ordering::SeqCst); }
805                                Some(Ok(StompItem::Frame(f))) => {
806                                    last_received.store(current_millis(), Ordering::SeqCst);
807                                    // Dispatch MESSAGE frames to any matching subscribers.
808                                    if f.command == "MESSAGE" {
809                                        // try to find destination, subscription and message-id headers
810                                        let mut dest_opt: Option<String> = None;
811                                        let mut sub_opt: Option<String> = None;
812                                        let mut msg_id_opt: Option<String> = None;
813                                        for (k, v) in &f.headers {
814                                            let kl = k.to_lowercase();
815                                            if kl == "destination" {
816                                                dest_opt = Some(v.clone());
817                                            } else if kl == "subscription" {
818                                                sub_opt = Some(v.clone());
819                                            } else if kl == "message-id" {
820                                                msg_id_opt = Some(v.clone());
821                                            }
822                                        }
823
824                                        // Determine whether we need to track this message as pending
825                                        let mut need_pending = false;
826                                        if let Some(sub_id) = &sub_opt {
827                                            let map = subscriptions.lock().await;
828                                            for (_dest, vec) in map.iter() {
829                                                for entry in vec.iter() {
830                                                    if &entry.id == sub_id && entry.ack != "auto" {
831                                                        need_pending = true;
832                                                    }
833                                                }
834                                            }
835                                        } else if let Some(dest) = &dest_opt {
836                                            let map = subscriptions.lock().await;
837                                            if let Some(vec) = map.get(dest) {
838                                                for entry in vec.iter() {
839                                                    if entry.ack != "auto" {
840                                                        need_pending = true;
841                                                        break;
842                                                    }
843                                                }
844                                            }
845                                        }
846
847                                        // If required, add to pending map (per-subscription) before
848                                        // delivery so ACK/NACK requests from the application can
849                                        // reference the message. We require a `message-id` header
850                                        // to track messages; if missing, we cannot support ACK/NACK.
851                                        if let Some(msg_id) = msg_id_opt.clone().filter(|_| need_pending) {
852                                            // If the server provided a subscription id in the
853                                            // MESSAGE, store pending under that subscription.
854                                            if let Some(sub_id) = &sub_opt {
855                                                let mut p = pending_clone.lock().await;
856                                                let q = p
857                                                    .entry(sub_id.clone())
858                                                    .or_insert_with(VecDeque::new);
859                                                q.push_back((msg_id.clone(), f.clone()));
860                                            } else if let Some(dest) = &dest_opt {
861                                                // Destination-based delivery: add the message to
862                                                // the pending queue for each matching
863                                                // subscription on that destination.
864                                                let map = subscriptions.lock().await;
865                                                if let Some(vec) = map.get(dest) {
866                                                    let mut p = pending_clone.lock().await;
867                                                    for entry in vec.iter() {
868                                                        let q = p
869                                                            .entry(entry.id.clone())
870                                                            .or_insert_with(VecDeque::new);
871                                                        q.push_back((msg_id.clone(), f.clone()));
872                                                    }
873                                                }
874                                            }
875                                        }
876
877                                        // Deliver to subscribers.
878                                        if let Some(sub_id) = sub_opt {
879                                            let mut map = subscriptions.lock().await;
880                                            for (_dest, vec) in map.iter_mut() {
881                                                vec.retain(|entry| {
882                                                    if entry.id == sub_id {
883                                                        let _ = entry.sender.try_send(f.clone());
884                                                        true
885                                                    } else {
886                                                        true
887                                                    }
888                                                });
889                                            }
890                                        } else if let Some(dest) = dest_opt {
891                                            let mut map = subscriptions.lock().await;
892                                            if let Some(vec) = map.get_mut(&dest) {
893                                                vec.retain(|entry| entry.sender.try_send(f.clone()).is_ok());
894                                            }
895                                        }
896                                    } else if f.command == "RECEIPT" {
897                                        // Handle RECEIPT frame: notify any waiting callers
898                                        if let Some(receipt_id) = f.get_header("receipt-id") {
899                                            let mut receipts = pending_receipts_clone.lock().await;
900                                            if let Some(sender) = receipts.remove(receipt_id) {
901                                                let _ = sender.send(());
902                                            }
903                                        }
904                                        // Don't forward RECEIPT frames to inbound channel
905                                        continue;
906                                    }
907
908                                    let _ = in_tx.send(f).await;
909                                }
910                                Some(Err(_)) | None => break 'conn,
911                            }
912                        }
913                        _ = hb_tick.tick() => {
914                            if let Some(dur) = send_interval {
915                                let last = writer_last_sent.load(Ordering::SeqCst);
916                                if current_millis().saturating_sub(last) >= dur.as_millis() as u64 {
917                                    if sink.send(StompItem::Heartbeat).await.is_err() { break 'conn; }
918                                    writer_last_sent.store(current_millis(), Ordering::SeqCst);
919                                }
920                            }
921                        }
922                        _ = async { if let Some(interval) = watchdog_half { tokio::time::sleep(interval).await } else { future::pending::<()>().await } } => {
923                            if let Some(recv_dur) = recv_interval {
924                                let last = last_received.load(Ordering::SeqCst);
925                                if current_millis().saturating_sub(last) > (recv_dur.as_millis() as u64 * 2) {
926                                    let _ = sink.close().await; break 'conn;
927                                }
928                            }
929                        }
930                    }
931                }
932
933                if shutdown_sub.try_recv().is_ok() {
934                    break;
935                }
936                tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
937                backoff_secs = (backoff_secs * 2).min(30);
938            }
939        });
940
941        Ok(Connection {
942            outbound_tx: out_tx,
943            inbound_rx: Arc::new(Mutex::new(in_rx)),
944            shutdown_tx,
945            subscriptions,
946            sub_id_counter,
947            pending,
948            pending_receipts,
949        })
950    }
951
952    /// Build a CONNECT frame with all specified headers.
953    fn build_connect_frame(
954        accept_version: &str,
955        host: &str,
956        login: &str,
957        passcode: &str,
958        heartbeat: &str,
959        client_id: &Option<String>,
960        custom_headers: &[(String, String)],
961    ) -> Frame {
962        let mut connect = Frame::new("CONNECT")
963            .header("accept-version", accept_version)
964            .header("host", host)
965            .header("login", login)
966            .header("passcode", passcode)
967            .header("heart-beat", heartbeat);
968
969        if let Some(id) = client_id {
970            connect = connect.header("client-id", id);
971        }
972
973        // Reserved headers that custom_headers cannot override
974        let reserved = [
975            "accept-version",
976            "host",
977            "login",
978            "passcode",
979            "heart-beat",
980            "client-id",
981        ];
982
983        for (k, v) in custom_headers {
984            if !reserved.contains(&k.to_lowercase().as_str()) {
985                connect = connect.header(k, v);
986            }
987        }
988
989        connect
990    }
991
992    /// Wait for CONNECTED or ERROR response from the server.
993    ///
994    /// Returns the server's heartbeat header value on success, or an error
995    /// if the server sends an ERROR frame or closes the connection.
996    async fn await_connected_response(
997        framed: &mut Framed<TcpStream, StompCodec>,
998    ) -> Result<String, ConnError> {
999        loop {
1000            match framed.next().await {
1001                Some(Ok(StompItem::Frame(f))) => {
1002                    if f.command == "CONNECTED" {
1003                        // Extract heartbeat from server
1004                        let server_hb = f.get_header("heart-beat").unwrap_or("0,0").to_string();
1005                        return Ok(server_hb);
1006                    } else if f.command == "ERROR" {
1007                        // Server rejected connection (e.g., invalid credentials)
1008                        return Err(ConnError::ServerRejected(ServerError::from_frame(f)));
1009                    }
1010                    // Ignore other frames during CONNECT phase
1011                }
1012                Some(Ok(StompItem::Heartbeat)) => {
1013                    // Ignore heartbeats during handshake
1014                    continue;
1015                }
1016                Some(Err(e)) => {
1017                    return Err(ConnError::Io(e));
1018                }
1019                None => {
1020                    return Err(ConnError::Protocol(
1021                        "connection closed before CONNECTED received".to_string(),
1022                    ));
1023                }
1024            }
1025        }
1026    }
1027
1028    pub async fn send_frame(&self, frame: Frame) -> Result<(), ConnError> {
1029        // Send a frame to the background writer task.
1030        //
1031        // Parameters
1032        // - `frame`: ownership of the `Frame` to send. The frame is converted
1033        //   into a `StompItem::Frame` and sent over the internal mpsc channel.
1034        self.outbound_tx
1035            .send(StompItem::Frame(frame))
1036            .await
1037            .map_err(|_| ConnError::Protocol("send channel closed".into()))
1038    }
1039
1040    /// Send a frame with a receipt request and return the receipt ID.
1041    ///
1042    /// This method adds a unique `receipt` header to the frame and registers
1043    /// the receipt ID for tracking. Use `wait_for_receipt()` to wait for the
1044    /// server's RECEIPT response.
1045    ///
1046    /// # Parameters
1047    /// - `frame`: the frame to send. A `receipt` header will be added.
1048    ///
1049    /// # Returns
1050    /// The generated receipt ID that can be used with `wait_for_receipt()`.
1051    ///
1052    /// # Example
1053    /// ```ignore
1054    /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
1055    /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
1056    /// ```
1057    pub async fn send_frame_with_receipt(&self, frame: Frame) -> Result<String, ConnError> {
1058        use std::sync::atomic::AtomicU64;
1059        use std::sync::atomic::Ordering::SeqCst;
1060
1061        // Generate a unique receipt ID using a static counter
1062        static RECEIPT_COUNTER: AtomicU64 = AtomicU64::new(1);
1063        let receipt_id = format!("rcpt-{}", RECEIPT_COUNTER.fetch_add(1, SeqCst));
1064
1065        // Create the oneshot channel for notification
1066        let (tx, _rx) = oneshot::channel();
1067
1068        // Register the pending receipt
1069        {
1070            let mut receipts = self.pending_receipts.lock().await;
1071            receipts.insert(receipt_id.clone(), tx);
1072        }
1073
1074        // Add receipt header and send the frame
1075        let frame_with_receipt = frame.receipt(&receipt_id);
1076        self.send_frame(frame_with_receipt).await?;
1077
1078        Ok(receipt_id)
1079    }
1080
1081    /// Wait for a receipt confirmation from the server.
1082    ///
1083    /// This method blocks until the server sends a RECEIPT frame with the
1084    /// matching receipt-id, or until the timeout expires.
1085    ///
1086    /// # Parameters
1087    /// - `receipt_id`: the receipt ID returned by `send_frame_with_receipt()`.
1088    /// - `timeout`: maximum time to wait for the receipt.
1089    ///
1090    /// # Returns
1091    /// `Ok(())` if the receipt was received, or `Err(ConnError::ReceiptTimeout)`
1092    /// if the timeout expired.
1093    ///
1094    /// # Example
1095    /// ```ignore
1096    /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
1097    /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
1098    /// println!("Message confirmed!");
1099    /// ```
1100    pub async fn wait_for_receipt(
1101        &self,
1102        receipt_id: &str,
1103        timeout: Duration,
1104    ) -> Result<(), ConnError> {
1105        // Get the receiver for this receipt
1106        let rx = {
1107            let mut receipts = self.pending_receipts.lock().await;
1108            // Re-create the oneshot channel and swap out the sender
1109            let (tx, rx) = oneshot::channel();
1110            if let Some(old_tx) = receipts.insert(receipt_id.to_string(), tx) {
1111                // Drop the old sender - this is expected if called after send_frame_with_receipt
1112                drop(old_tx);
1113            }
1114            rx
1115        };
1116
1117        // Wait for the receipt with timeout
1118        match tokio::time::timeout(timeout, rx).await {
1119            Ok(Ok(())) => Ok(()),
1120            Ok(Err(_)) => {
1121                // Channel was closed without receiving - connection likely dropped
1122                Err(ConnError::Protocol(
1123                    "receipt channel closed unexpectedly".into(),
1124                ))
1125            }
1126            Err(_) => {
1127                // Timeout expired - clean up the pending receipt
1128                let mut receipts = self.pending_receipts.lock().await;
1129                receipts.remove(receipt_id);
1130                Err(ConnError::ReceiptTimeout(receipt_id.to_string()))
1131            }
1132        }
1133    }
1134
1135    /// Send a frame and wait for server confirmation via RECEIPT.
1136    ///
1137    /// This is a convenience method that combines `send_frame_with_receipt()`
1138    /// and `wait_for_receipt()`. Use this when you want to ensure a frame
1139    /// was processed by the server before continuing.
1140    ///
1141    /// # Parameters
1142    /// - `frame`: the frame to send.
1143    /// - `timeout`: maximum time to wait for the receipt.
1144    ///
1145    /// # Returns
1146    /// `Ok(())` if the frame was sent and receipt confirmed, or an error if
1147    /// sending failed or the receipt timed out.
1148    ///
1149    /// # Example
1150    /// ```ignore
1151    /// let frame = Frame::new("SEND")
1152    ///     .header("destination", "/queue/orders")
1153    ///     .set_body(b"order data".to_vec());
1154    ///
1155    /// conn.send_frame_confirmed(frame, Duration::from_secs(5)).await?;
1156    /// println!("Order sent and confirmed!");
1157    /// ```
1158    pub async fn send_frame_confirmed(
1159        &self,
1160        frame: Frame,
1161        timeout: Duration,
1162    ) -> Result<(), ConnError> {
1163        // Generate receipt ID and register before sending
1164        use std::sync::atomic::AtomicU64;
1165        use std::sync::atomic::Ordering::SeqCst;
1166
1167        static RECEIPT_COUNTER: AtomicU64 = AtomicU64::new(1);
1168        let receipt_id = format!("rcpt-{}", RECEIPT_COUNTER.fetch_add(1, SeqCst));
1169
1170        // Create the oneshot channel for notification
1171        let (tx, rx) = oneshot::channel();
1172
1173        // Register the pending receipt before sending
1174        {
1175            let mut receipts = self.pending_receipts.lock().await;
1176            receipts.insert(receipt_id.clone(), tx);
1177        }
1178
1179        // Add receipt header and send the frame
1180        let frame_with_receipt = frame.receipt(&receipt_id);
1181        self.send_frame(frame_with_receipt).await?;
1182
1183        // Wait for the receipt with timeout
1184        match tokio::time::timeout(timeout, rx).await {
1185            Ok(Ok(())) => Ok(()),
1186            Ok(Err(_)) => Err(ConnError::Protocol(
1187                "receipt channel closed unexpectedly".into(),
1188            )),
1189            Err(_) => {
1190                // Timeout expired - clean up
1191                let mut receipts = self.pending_receipts.lock().await;
1192                receipts.remove(&receipt_id);
1193                Err(ConnError::ReceiptTimeout(receipt_id))
1194            }
1195        }
1196    }
1197
1198    /// Subscribe to a destination.
1199    ///
1200    /// Parameters
1201    /// - `destination`: the STOMP destination to subscribe to (e.g. "/queue/foo").
1202    /// - `ack`: acknowledgement mode to request from the server.
1203    ///
1204    /// Returns a tuple `(subscription_id, receiver)` where `subscription_id` is
1205    /// the opaque id assigned locally for this subscription and `receiver` is a
1206    /// `mpsc::Receiver<Frame>` which will yield incoming MESSAGE frames for the
1207    /// destination. The caller should read from the receiver to handle messages.
1208    /// Subscribe to a destination using optional extra headers.
1209    ///
1210    /// This variant accepts additional headers which are stored locally and
1211    /// re-sent on reconnect. Use `subscribe` as a convenience wrapper when no
1212    /// extra headers are needed.
1213    pub async fn subscribe_with_headers(
1214        &self,
1215        destination: &str,
1216        ack: AckMode,
1217        extra_headers: Vec<(String, String)>,
1218    ) -> Result<crate::subscription::Subscription, ConnError> {
1219        let id = self
1220            .sub_id_counter
1221            .fetch_add(1, Ordering::SeqCst)
1222            .to_string();
1223        let (tx, rx) = mpsc::channel::<Frame>(16);
1224        {
1225            let mut map = self.subscriptions.lock().await;
1226            map.entry(destination.to_string())
1227                .or_insert_with(Vec::new)
1228                .push(SubscriptionEntry {
1229                    id: id.clone(),
1230                    sender: tx.clone(),
1231                    ack: ack.as_str().to_string(),
1232                    headers: extra_headers.clone(),
1233                });
1234        }
1235
1236        let mut f = Frame::new("SUBSCRIBE");
1237        f = f
1238            .header("id", &id)
1239            .header("destination", destination)
1240            .header("ack", ack.as_str());
1241        for (k, v) in &extra_headers {
1242            f = f.header(k, v);
1243        }
1244        self.outbound_tx
1245            .send(StompItem::Frame(f))
1246            .await
1247            .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1248
1249        Ok(crate::subscription::Subscription::new(
1250            id,
1251            destination.to_string(),
1252            rx,
1253            self.clone(),
1254        ))
1255    }
1256
1257    /// Convenience wrapper without extra headers.
1258    pub async fn subscribe(
1259        &self,
1260        destination: &str,
1261        ack: AckMode,
1262    ) -> Result<crate::subscription::Subscription, ConnError> {
1263        self.subscribe_with_headers(destination, ack, Vec::new())
1264            .await
1265    }
1266
1267    /// Subscribe with a typed `SubscriptionOptions` structure.
1268    ///
1269    /// `SubscriptionOptions.headers` are forwarded to the broker and persisted
1270    /// for automatic resubscribe after reconnect. If `durable_queue` is set,
1271    /// it will be used as the actual destination instead of `destination`.
1272    pub async fn subscribe_with_options(
1273        &self,
1274        destination: &str,
1275        ack: AckMode,
1276        options: crate::subscription::SubscriptionOptions,
1277    ) -> Result<crate::subscription::Subscription, ConnError> {
1278        let dest = options
1279            .durable_queue
1280            .as_deref()
1281            .unwrap_or(destination)
1282            .to_string();
1283        self.subscribe_with_headers(&dest, ack, options.headers)
1284            .await
1285    }
1286
1287    /// Unsubscribe a previously created subscription by its local subscription id.
1288    pub async fn unsubscribe(&self, subscription_id: &str) -> Result<(), ConnError> {
1289        let mut found = false;
1290        {
1291            let mut map = self.subscriptions.lock().await;
1292            let mut remove_keys: Vec<String> = Vec::new();
1293            for (dest, vec) in map.iter_mut() {
1294                if let Some(pos) = vec.iter().position(|entry| entry.id == subscription_id) {
1295                    vec.remove(pos);
1296                    found = true;
1297                }
1298                if vec.is_empty() {
1299                    remove_keys.push(dest.clone());
1300                }
1301            }
1302            for k in remove_keys {
1303                map.remove(&k);
1304            }
1305        }
1306
1307        if !found {
1308            return Err(ConnError::Protocol("subscription id not found".into()));
1309        }
1310
1311        let mut f = Frame::new("UNSUBSCRIBE");
1312        f = f.header("id", subscription_id);
1313        self.outbound_tx
1314            .send(StompItem::Frame(f))
1315            .await
1316            .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1317
1318        Ok(())
1319    }
1320
1321    /// Acknowledge a message previously received in `client` or
1322    /// `client-individual` ack modes.
1323    ///
1324    /// STOMP ack semantics:
1325    /// - `auto`: server considers message delivered immediately; the client
1326    ///   should not ack.
1327    /// - `client`: cumulative acknowledgements. ACKing message `M` for
1328    ///   subscription `S` acknowledges all messages delivered to `S` up to
1329    ///   and including `M`.
1330    /// - `client-individual`: only the named message is acknowledged.
1331    ///
1332    /// Parameters
1333    /// - `subscription_id`: the local subscription id returned by
1334    ///   `Connection::subscribe`. This disambiguates which subscription's
1335    ///   pending queue to advance for cumulative ACKs.
1336    /// - `message_id`: the `message-id` header value from the received
1337    ///   MESSAGE frame to acknowledge.
1338    ///
1339    /// Behavior
1340    /// - The pending queue for `subscription_id` is searched for `message_id`.
1341    ///   If the subscription used `client` ack mode, all pending messages up to
1342    ///   and including the matched message are removed. If the subscription
1343    ///   used `client-individual`, only the matched message is removed.
1344    /// - An `ACK` frame is sent to the server with `id=<message_id>` and
1345    ///   `subscription=<subscription_id>` headers.
1346    #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1347    pub async fn ack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1348        // Remove from the local pending queue according to subscription ack mode.
1349        let mut removed_any = false;
1350        {
1351            let mut p = self.pending.lock().await;
1352            if let Some(queue) = p.get_mut(subscription_id) {
1353                if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1354                    // Determine ack mode for this subscription (default to client).
1355                    let mut ack_mode = "client".to_string();
1356                    {
1357                        let map = self.subscriptions.lock().await;
1358                        'outer: for (_dest, vec) in map.iter() {
1359                            for entry in vec.iter() {
1360                                if entry.id == subscription_id {
1361                                    ack_mode = entry.ack.clone();
1362                                    break 'outer;
1363                                }
1364                            }
1365                        }
1366                    }
1367
1368                    if ack_mode == "client" {
1369                        // cumulative: remove up to and including pos
1370                        for _ in 0..=pos {
1371                            queue.pop_front();
1372                            removed_any = true;
1373                        }
1374                    } else if queue.remove(pos).is_some() {
1375                        // client-individual: remove only the specific message
1376                        removed_any = true;
1377                    }
1378
1379                    if queue.is_empty() {
1380                        p.remove(subscription_id);
1381                    }
1382                }
1383            }
1384        }
1385
1386        // Send ACK to server (include subscription header for clarity)
1387        let mut f = Frame::new("ACK");
1388        f = f
1389            .header("id", message_id)
1390            .header("subscription", subscription_id);
1391        self.outbound_tx
1392            .send(StompItem::Frame(f))
1393            .await
1394            .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1395
1396        // If message wasn't found locally, still send ACK to server; server
1397        // may ignore or treat it as no-op.
1398        let _ = removed_any;
1399        Ok(())
1400    }
1401
1402    /// Negative-acknowledge a message (NACK).
1403    ///
1404    /// Parameters
1405    /// - `subscription_id`: the local subscription id the message was delivered under.
1406    /// - `message_id`: the `message-id` header value from the received MESSAGE.
1407    ///
1408    /// Behavior
1409    /// - Removes the message from the local pending queue (cumulatively if the
1410    ///   subscription used `client` ack mode, otherwise only the single
1411    ///   message). Sends a `NACK` frame to the server with `id` and
1412    ///   `subscription` headers.
1413    #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1414    pub async fn nack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1415        // Mirror ack removal semantics for pending map.
1416        let mut removed_any = false;
1417        {
1418            let mut p = self.pending.lock().await;
1419            if let Some(queue) = p.get_mut(subscription_id) {
1420                if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1421                    let mut ack_mode = "client".to_string();
1422                    {
1423                        let map = self.subscriptions.lock().await;
1424                        'outer2: for (_dest, vec) in map.iter() {
1425                            for entry in vec.iter() {
1426                                if entry.id == subscription_id {
1427                                    ack_mode = entry.ack.clone();
1428                                    break 'outer2;
1429                                }
1430                            }
1431                        }
1432                    }
1433
1434                    if ack_mode == "client" {
1435                        for _ in 0..=pos {
1436                            queue.pop_front();
1437                            removed_any = true;
1438                        }
1439                    } else if queue.remove(pos).is_some() {
1440                        removed_any = true;
1441                    }
1442
1443                    if queue.is_empty() {
1444                        p.remove(subscription_id);
1445                    }
1446                }
1447            }
1448        }
1449
1450        let mut f = Frame::new("NACK");
1451        f = f
1452            .header("id", message_id)
1453            .header("subscription", subscription_id);
1454        self.outbound_tx
1455            .send(StompItem::Frame(f))
1456            .await
1457            .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1458
1459        let _ = removed_any;
1460        Ok(())
1461    }
1462
1463    /// Helper to send a transaction frame (BEGIN, COMMIT, or ABORT).
1464    async fn send_transaction_frame(
1465        &self,
1466        command: &str,
1467        transaction_id: &str,
1468    ) -> Result<(), ConnError> {
1469        let f = Frame::new(command).header("transaction", transaction_id);
1470        self.outbound_tx
1471            .send(StompItem::Frame(f))
1472            .await
1473            .map_err(|_| ConnError::Protocol("send channel closed".into()))
1474    }
1475
1476    /// Begin a transaction.
1477    ///
1478    /// Parameters
1479    /// - `transaction_id`: unique identifier for the transaction. The caller is
1480    ///   responsible for ensuring uniqueness within the connection.
1481    ///
1482    /// Behavior
1483    /// - Sends a `BEGIN` frame to the server with `transaction:<transaction_id>`
1484    ///   header. Subsequent `SEND`, `ACK`, and `NACK` frames may include this
1485    ///   transaction id to group them into the transaction. The transaction must
1486    ///   be finalized with either `commit` or `abort`.
1487    pub async fn begin(&self, transaction_id: &str) -> Result<(), ConnError> {
1488        self.send_transaction_frame("BEGIN", transaction_id).await
1489    }
1490
1491    /// Commit a transaction.
1492    ///
1493    /// Parameters
1494    /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1495    ///
1496    /// Behavior
1497    /// - Sends a `COMMIT` frame to the server with `transaction:<transaction_id>`
1498    ///   header. All operations within the transaction are applied atomically.
1499    pub async fn commit(&self, transaction_id: &str) -> Result<(), ConnError> {
1500        self.send_transaction_frame("COMMIT", transaction_id).await
1501    }
1502
1503    /// Abort a transaction.
1504    ///
1505    /// Parameters
1506    /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1507    ///
1508    /// Behavior
1509    /// - Sends an `ABORT` frame to the server with `transaction:<transaction_id>`
1510    ///   header. All operations within the transaction are discarded.
1511    pub async fn abort(&self, transaction_id: &str) -> Result<(), ConnError> {
1512        self.send_transaction_frame("ABORT", transaction_id).await
1513    }
1514
1515    /// Receive the next frame from the server.
1516    ///
1517    /// Returns `Some(ReceivedFrame::Frame(..))` for normal frames (MESSAGE, etc.),
1518    /// `Some(ReceivedFrame::Error(..))` for ERROR frames, or `None` if the
1519    /// connection has been closed.
1520    ///
1521    /// # Example
1522    ///
1523    /// ```ignore
1524    /// use iridium_stomp::ReceivedFrame;
1525    ///
1526    /// while let Some(received) = conn.next_frame().await {
1527    ///     match received {
1528    ///         ReceivedFrame::Frame(frame) => {
1529    ///             println!("Got {}: {:?}", frame.command, frame.body);
1530    ///         }
1531    ///         ReceivedFrame::Error(err) => {
1532    ///             eprintln!("Server error: {}", err);
1533    ///             break;
1534    ///         }
1535    ///     }
1536    /// }
1537    /// ```
1538    pub async fn next_frame(&self) -> Option<ReceivedFrame> {
1539        let mut rx = self.inbound_rx.lock().await;
1540        let frame = rx.recv().await?;
1541
1542        // Convert ERROR frames to ServerError for better ergonomics
1543        if frame.command == "ERROR" {
1544            Some(ReceivedFrame::Error(ServerError::from_frame(frame)))
1545        } else {
1546            Some(ReceivedFrame::Frame(frame))
1547        }
1548    }
1549
1550    pub async fn close(self) {
1551        // Signal the background task to shutdown by broadcasting on the
1552        // shutdown channel. Consumers may await task termination separately
1553        // if needed.
1554        let _ = self.shutdown_tx.send(());
1555    }
1556}
1557
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// Returns 0 when the system clock reports a time before the epoch.
fn current_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // `as_millis` yields a u128; the value is narrowed to u64 here.
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
1565
1566#[cfg(test)]
1567mod tests {
1568    use super::*;
1569    use tokio::sync::mpsc;
1570
1571    // Helper to build a MESSAGE frame with given message-id and subscription/destination headers
1572    fn make_message(
1573        message_id: &str,
1574        subscription: Option<&str>,
1575        destination: Option<&str>,
1576    ) -> Frame {
1577        let mut f = Frame::new("MESSAGE");
1578        f = f.header("message-id", message_id);
1579        if let Some(s) = subscription {
1580            f = f.header("subscription", s);
1581        }
1582        if let Some(d) = destination {
1583            f = f.header("destination", d);
1584        }
1585        f
1586    }
1587
1588    #[tokio::test]
1589    async fn test_cumulative_ack_removes_prefix() {
1590        // setup channels
1591        let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(8);
1592        let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
1593        let (shutdown_tx, _) = broadcast::channel::<()>(1);
1594
1595        let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
1596        let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
1597
1598        let sub_id_counter = Arc::new(AtomicU64::new(1));
1599
1600        // create a subscription entry s1 with client (cumulative) ack
1601        let (sub_sender, _sub_rx) = mpsc::channel::<Frame>(4);
1602        {
1603            let mut map = subscriptions.lock().await;
1604            map.insert(
1605                "/queue/x".to_string(),
1606                vec![SubscriptionEntry {
1607                    id: "s1".to_string(),
1608                    sender: sub_sender,
1609                    ack: "client".to_string(),
1610                    headers: Vec::new(),
1611                }],
1612            );
1613        }
1614
1615        // fill pending queue for s1: m1,m2,m3
1616        {
1617            let mut p = pending.lock().await;
1618            let mut q = VecDeque::new();
1619            q.push_back((
1620                "m1".to_string(),
1621                make_message("m1", Some("s1"), Some("/queue/x")),
1622            ));
1623            q.push_back((
1624                "m2".to_string(),
1625                make_message("m2", Some("s1"), Some("/queue/x")),
1626            ));
1627            q.push_back((
1628                "m3".to_string(),
1629                make_message("m3", Some("s1"), Some("/queue/x")),
1630            ));
1631            p.insert("s1".to_string(), q);
1632        }
1633
1634        let conn = Connection {
1635            outbound_tx: out_tx,
1636            inbound_rx: Arc::new(Mutex::new(in_rx)),
1637            shutdown_tx,
1638            subscriptions: subscriptions.clone(),
1639            sub_id_counter,
1640            pending: pending.clone(),
1641            pending_receipts: Arc::new(Mutex::new(HashMap::new())),
1642        };
1643
1644        // ack m2 cumulatively: should remove m1 and m2, leaving m3
1645        conn.ack("s1", "m2").await.expect("ack failed");
1646
1647        // verify pending for s1 contains only m3
1648        {
1649            let p = pending.lock().await;
1650            let q = p.get("s1").expect("missing s1");
1651            assert_eq!(q.len(), 1);
1652            assert_eq!(q.front().unwrap().0, "m3");
1653        }
1654
1655        // verify an ACK frame was emitted
1656        if let Some(item) = out_rx.recv().await {
1657            match item {
1658                StompItem::Frame(f) => assert_eq!(f.command, "ACK"),
1659                _ => panic!("expected frame"),
1660            }
1661        } else {
1662            panic!("no outbound frame sent")
1663        }
1664    }
1665
1666    #[tokio::test]
1667    async fn test_client_individual_ack_removes_only_one() {
1668        // setup channels
1669        let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(8);
1670        let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
1671        let (shutdown_tx, _) = broadcast::channel::<()>(1);
1672
1673        let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
1674        let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
1675
1676        let sub_id_counter = Arc::new(AtomicU64::new(1));
1677
1678        // create a subscription entry s2 with client-individual ack
1679        let (sub_sender, _sub_rx) = mpsc::channel::<Frame>(4);
1680        {
1681            let mut map = subscriptions.lock().await;
1682            map.insert(
1683                "/queue/y".to_string(),
1684                vec![SubscriptionEntry {
1685                    id: "s2".to_string(),
1686                    sender: sub_sender,
1687                    ack: "client-individual".to_string(),
1688                    headers: Vec::new(),
1689                }],
1690            );
1691        }
1692
1693        // fill pending queue for s2: a,b,c
1694        {
1695            let mut p = pending.lock().await;
1696            let mut q = VecDeque::new();
1697            q.push_back((
1698                "a".to_string(),
1699                make_message("a", Some("s2"), Some("/queue/y")),
1700            ));
1701            q.push_back((
1702                "b".to_string(),
1703                make_message("b", Some("s2"), Some("/queue/y")),
1704            ));
1705            q.push_back((
1706                "c".to_string(),
1707                make_message("c", Some("s2"), Some("/queue/y")),
1708            ));
1709            p.insert("s2".to_string(), q);
1710        }
1711
1712        let conn = Connection {
1713            outbound_tx: out_tx,
1714            inbound_rx: Arc::new(Mutex::new(in_rx)),
1715            shutdown_tx,
1716            subscriptions: subscriptions.clone(),
1717            sub_id_counter,
1718            pending: pending.clone(),
1719            pending_receipts: Arc::new(Mutex::new(HashMap::new())),
1720        };
1721
1722        // ack only 'b' individually
1723        conn.ack("s2", "b").await.expect("ack failed");
1724
1725        // verify pending for s2 contains a and c
1726        {
1727            let p = pending.lock().await;
1728            let q = p.get("s2").expect("missing s2");
1729            assert_eq!(q.len(), 2);
1730            assert_eq!(q[0].0, "a");
1731            assert_eq!(q[1].0, "c");
1732        }
1733
1734        // verify an ACK frame was emitted
1735        if let Some(item) = out_rx.recv().await {
1736            match item {
1737                StompItem::Frame(f) => assert_eq!(f.command, "ACK"),
1738                _ => panic!("expected frame"),
1739            }
1740        } else {
1741            panic!("no outbound frame sent")
1742        }
1743    }
1744
1745    #[tokio::test]
1746    async fn test_subscription_receive_delivers_message() {
1747        // setup channels
1748        let (out_tx, _out_rx) = mpsc::channel::<StompItem>(8);
1749        let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
1750        let (shutdown_tx, _) = broadcast::channel::<()>(1);
1751
1752        let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
1753        let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
1754
1755        let sub_id_counter = Arc::new(AtomicU64::new(1));
1756
1757        let conn = Connection {
1758            outbound_tx: out_tx,
1759            inbound_rx: Arc::new(Mutex::new(in_rx)),
1760            shutdown_tx,
1761            subscriptions: subscriptions.clone(),
1762            sub_id_counter,
1763            pending: pending.clone(),
1764            pending_receipts: Arc::new(Mutex::new(HashMap::new())),
1765        };
1766
1767        // subscribe
1768        let subscription = conn
1769            .subscribe("/queue/test", AckMode::Auto)
1770            .await
1771            .expect("subscribe failed");
1772
1773        // find the sender stored in the subscriptions map and push a message
1774        {
1775            let map = conn.subscriptions.lock().await;
1776            let vec = map.get("/queue/test").expect("missing subscription vec");
1777            let sender = &vec[0].sender;
1778            let f = make_message("m1", Some(&vec[0].id), Some("/queue/test"));
1779            sender.try_send(f).expect("send to subscription failed");
1780        }
1781
1782        // consume from the subscription receiver
1783        let mut rx = subscription.into_receiver();
1784        if let Some(received) = rx.recv().await {
1785            assert_eq!(received.command, "MESSAGE");
1786            // message-id header should be present
1787            let mut found = false;
1788            for (k, _v) in &received.headers {
1789                if k.to_lowercase() == "message-id" {
1790                    found = true;
1791                    break;
1792                }
1793            }
1794            assert!(found, "message-id header missing");
1795        } else {
1796            panic!("no message received on subscription")
1797        }
1798    }
1799
1800    #[tokio::test]
1801    async fn test_subscription_ack_removes_pending_and_sends_ack() {
1802        // setup channels
1803        let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(8);
1804        let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
1805        let (shutdown_tx, _) = broadcast::channel::<()>(1);
1806
1807        let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
1808        let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
1809
1810        let sub_id_counter = Arc::new(AtomicU64::new(1));
1811
1812        let conn = Connection {
1813            outbound_tx: out_tx,
1814            inbound_rx: Arc::new(Mutex::new(in_rx)),
1815            shutdown_tx,
1816            subscriptions: subscriptions.clone(),
1817            sub_id_counter,
1818            pending: pending.clone(),
1819            pending_receipts: Arc::new(Mutex::new(HashMap::new())),
1820        };
1821
1822        // subscribe with client ack
1823        let subscription = conn
1824            .subscribe("/queue/ack", AckMode::Client)
1825            .await
1826            .expect("subscribe failed");
1827
1828        let sub_id = subscription.id().to_string();
1829
1830        // drain any initial outbound frames (SUBSCRIBE) emitted by subscribe()
1831        while out_rx.try_recv().is_ok() {}
1832
1833        // populate pending queue for this subscription
1834        {
1835            let mut p = conn.pending.lock().await;
1836            let mut q = VecDeque::new();
1837            q.push_back((
1838                "mid-1".to_string(),
1839                make_message("mid-1", Some(&sub_id), Some("/queue/ack")),
1840            ));
1841            p.insert(sub_id.clone(), q);
1842        }
1843
1844        // ack the message via the subscription helper
1845        subscription.ack("mid-1").await.expect("ack failed");
1846
1847        // ensure pending queue no longer contains the message
1848        {
1849            let p = conn.pending.lock().await;
1850            assert!(p.get(&sub_id).is_none() || p.get(&sub_id).unwrap().is_empty());
1851        }
1852
1853        // verify an ACK frame was emitted
1854        if let Some(item) = out_rx.recv().await {
1855            match item {
1856                StompItem::Frame(f) => assert_eq!(f.command, "ACK"),
1857                _ => panic!("expected frame"),
1858            }
1859        } else {
1860            panic!("no outbound frame sent")
1861        }
1862    }
1863
1864    // Helper function to create a test connection and output receiver
1865    fn setup_test_connection() -> (Connection, mpsc::Receiver<StompItem>) {
1866        let (out_tx, out_rx) = mpsc::channel::<StompItem>(8);
1867        let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
1868        let (shutdown_tx, _) = broadcast::channel::<()>(1);
1869
1870        let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
1871        let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
1872        let sub_id_counter = Arc::new(AtomicU64::new(1));
1873
1874        let conn = Connection {
1875            outbound_tx: out_tx,
1876            inbound_rx: Arc::new(Mutex::new(in_rx)),
1877            shutdown_tx,
1878            subscriptions,
1879            sub_id_counter,
1880            pending,
1881            pending_receipts: Arc::new(Mutex::new(HashMap::new())),
1882        };
1883
1884        (conn, out_rx)
1885    }
1886
1887    // Helper function to verify a frame with a transaction header
1888    fn verify_transaction_frame(frame: Frame, expected_command: &str, expected_tx_id: &str) {
1889        assert_eq!(frame.command, expected_command);
1890        assert!(
1891            frame
1892                .headers
1893                .iter()
1894                .any(|(k, v)| k == "transaction" && v == expected_tx_id),
1895            "transaction header with id '{}' not found",
1896            expected_tx_id
1897        );
1898    }
1899
1900    #[tokio::test]
1901    async fn test_begin_transaction_sends_frame() {
1902        let (conn, mut out_rx) = setup_test_connection();
1903
1904        conn.begin("tx1").await.expect("begin failed");
1905
1906        // verify BEGIN frame was emitted
1907        if let Some(StompItem::Frame(f)) = out_rx.recv().await {
1908            verify_transaction_frame(f, "BEGIN", "tx1");
1909        } else {
1910            panic!("no outbound frame sent")
1911        }
1912    }
1913
1914    #[tokio::test]
1915    async fn test_commit_transaction_sends_frame() {
1916        let (conn, mut out_rx) = setup_test_connection();
1917
1918        conn.commit("tx1").await.expect("commit failed");
1919
1920        // verify COMMIT frame was emitted
1921        if let Some(StompItem::Frame(f)) = out_rx.recv().await {
1922            verify_transaction_frame(f, "COMMIT", "tx1");
1923        } else {
1924            panic!("no outbound frame sent")
1925        }
1926    }
1927
1928    #[tokio::test]
1929    async fn test_abort_transaction_sends_frame() {
1930        let (conn, mut out_rx) = setup_test_connection();
1931
1932        conn.abort("tx1").await.expect("abort failed");
1933
1934        // verify ABORT frame was emitted
1935        if let Some(StompItem::Frame(f)) = out_rx.recv().await {
1936            verify_transaction_frame(f, "ABORT", "tx1");
1937        } else {
1938            panic!("no outbound frame sent")
1939        }
1940    }
1941}