iridium_stomp/
connection.rs

1use futures::{SinkExt, StreamExt, future};
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Duration;
6use thiserror::Error;
7use tokio::net::TcpStream;
8use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
9use tokio_util::codec::Framed;
10
11use crate::codec::{StompCodec, StompItem};
12use crate::frame::Frame;
13
/// Heartbeat intervals advertised in the STOMP `heart-beat` header.
///
/// Instead of passing a raw `"send,receive"` string around, build one of these
/// and format it with `Display`, which renders exactly the `send_ms,receive_ms`
/// form the protocol expects.
///
/// # Example
///
/// ```
/// use iridium_stomp::Heartbeat;
///
/// // Explicit intervals
/// let hb = Heartbeat::new(5000, 10000);
/// assert_eq!(hb.to_string(), "5000,10000");
///
/// // Predefined configurations
/// assert_eq!(Heartbeat::disabled().to_string(), "0,0");
/// assert_eq!(Heartbeat::default().to_string(), "10000,10000");
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Heartbeat {
    /// Smallest interval, in milliseconds, at which the client can emit
    /// heartbeats. Zero means the client will not send any.
    pub send_ms: u32,

    /// Interval, in milliseconds, at which the client would like to receive
    /// heartbeats. Zero means the client does not want any.
    pub receive_ms: u32,
}

impl Heartbeat {
    /// Builds a configuration from explicit send/receive intervals.
    ///
    /// # Arguments
    ///
    /// * `send_ms` - Smallest interval (ms) at which the client can send heartbeats.
    /// * `receive_ms` - Interval (ms) at which the client wants to receive heartbeats.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    ///
    /// let hb = Heartbeat::new(5000, 10000);
    /// assert_eq!(hb.send_ms, 5000);
    /// assert_eq!(hb.receive_ms, 10000);
    /// ```
    pub fn new(send_ms: u32, receive_ms: u32) -> Self {
        Heartbeat { send_ms, receive_ms }
    }

    /// A configuration that turns heartbeats off in both directions.
    ///
    /// Same as `Heartbeat::new(0, 0)`.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    ///
    /// let hb = Heartbeat::disabled();
    /// assert_eq!(hb.send_ms, 0);
    /// assert_eq!(hb.receive_ms, 0);
    /// assert_eq!(hb.to_string(), "0,0");
    /// ```
    pub fn disabled() -> Self {
        Heartbeat {
            send_ms: 0,
            receive_ms: 0,
        }
    }

    /// A symmetric configuration: the same `interval` for sending and receiving.
    ///
    /// Sub-millisecond precision is truncated. Intervals longer than
    /// `u32::MAX` milliseconds (about 49.7 days) saturate at `u32::MAX` instead
    /// of overflowing.
    ///
    /// # Example
    ///
    /// ```
    /// use iridium_stomp::Heartbeat;
    /// use std::time::Duration;
    ///
    /// let hb = Heartbeat::from_duration(Duration::from_secs(15));
    /// assert_eq!(hb.send_ms, 15000);
    /// assert_eq!(hb.receive_ms, 15000);
    /// ```
    pub fn from_duration(interval: Duration) -> Self {
        // Saturate in u128 first so the narrowing cast below can never truncate.
        let millis = interval.as_millis().min(u128::from(u32::MAX)) as u32;
        Heartbeat {
            send_ms: millis,
            receive_ms: millis,
        }
    }
}

impl Default for Heartbeat {
    /// Ten seconds in both directions.
    fn default() -> Self {
        const TEN_SECONDS_MS: u32 = 10_000;
        Heartbeat::new(TEN_SECONDS_MS, TEN_SECONDS_MS)
    }
}

impl std::fmt::Display for Heartbeat {
    /// Renders the wire form `send_ms,receive_ms`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Heartbeat {
            send_ms,
            receive_ms,
        } = *self;
        write!(f, "{},{}", send_ms, receive_ms)
    }
}
122
123/// Internal subscription entry stored for each destination.
124#[derive(Clone)]
125pub(crate) struct SubscriptionEntry {
126    pub(crate) id: String,
127    pub(crate) sender: mpsc::Sender<Frame>,
128    pub(crate) ack: String,
129    pub(crate) headers: Vec<(String, String)>,
130}
131
132/// Alias for the subscription dispatch map: destination -> list of
133/// `SubscriptionEntry`.
134pub(crate) type Subscriptions = HashMap<String, Vec<SubscriptionEntry>>;
135
136/// Alias for the pending map: subscription_id -> queue of (message-id, Frame).
137pub(crate) type PendingMap = HashMap<String, VecDeque<(String, Frame)>>;
138
139/// Internal type for resubscribe snapshot entries: (destination, id, ack, headers)
140pub(crate) type ResubEntry = (String, String, String, Vec<(String, String)>);
141
142/// Alias for pending receipt map: receipt-id -> oneshot sender to notify when received.
143pub(crate) type PendingReceipts = HashMap<String, oneshot::Sender<()>>;
144
145/// Errors returned by `Connection` operations.
146#[derive(Error, Debug)]
147pub enum ConnError {
148    /// I/O-level error
149    #[error("io error: {0}")]
150    Io(#[from] std::io::Error),
151    /// Protocol-level error
152    #[error("protocol error: {0}")]
153    Protocol(String),
154    /// Receipt timeout error
155    #[error("receipt timeout: no RECEIPT received for '{0}' within timeout")]
156    ReceiptTimeout(String),
157}
158
159/// Represents an ERROR frame received from the STOMP server.
160///
161/// STOMP servers send ERROR frames to indicate protocol violations, authentication
162/// failures, or other server-side errors. After sending an ERROR frame, the server
163/// typically closes the connection.
164///
165/// # Example
166///
167/// ```ignore
168/// use iridium_stomp::ReceivedFrame;
169///
170/// while let Some(received) = conn.next_frame().await {
171///     match received {
172///         ReceivedFrame::Frame(frame) => {
173///             // Normal message processing
174///         }
175///         ReceivedFrame::Error(err) => {
176///             eprintln!("Server error: {}", err.message);
177///             if let Some(body) = &err.body {
178///                 eprintln!("Details: {}", body);
179///             }
180///             break;
181///         }
182///     }
183/// }
184/// ```
185#[derive(Debug, Clone, PartialEq, Eq)]
186pub struct ServerError {
187    /// The error message from the `message` header.
188    pub message: String,
189
190    /// The error body, if present. Contains additional error details.
191    pub body: Option<String>,
192
193    /// The receipt-id if this error is in response to a specific frame.
194    pub receipt_id: Option<String>,
195
196    /// The original ERROR frame for access to additional headers.
197    pub frame: Frame,
198}
199
200impl ServerError {
201    /// Create a `ServerError` from an ERROR frame.
202    ///
203    /// This is primarily used internally but is public for testing and
204    /// advanced use cases where you need to construct a `ServerError` manually.
205    pub fn from_frame(frame: Frame) -> Self {
206        let message = frame
207            .get_header("message")
208            .unwrap_or("unknown error")
209            .to_string();
210
211        let body = if frame.body.is_empty() {
212            None
213        } else {
214            String::from_utf8(frame.body.clone()).ok()
215        };
216
217        let receipt_id = frame.get_header("receipt-id").map(|s| s.to_string());
218
219        Self {
220            message,
221            body,
222            receipt_id,
223            frame,
224        }
225    }
226}
227
228impl std::fmt::Display for ServerError {
229    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        write!(f, "STOMP server error: {}", self.message)?;
231        if let Some(body) = &self.body {
232            write!(f, " - {}", body)?;
233        }
234        Ok(())
235    }
236}
237
238impl std::error::Error for ServerError {}
239
240/// The result of receiving a frame from the server.
241///
242/// STOMP servers can send either normal frames (MESSAGE, RECEIPT, etc.) or
243/// ERROR frames indicating a problem. This enum allows callers to handle
244/// both cases with pattern matching.
245///
246/// # Example
247///
248/// ```ignore
249/// use iridium_stomp::ReceivedFrame;
250///
251/// match conn.next_frame().await {
252///     Some(ReceivedFrame::Frame(frame)) => {
253///         println!("Got frame: {}", frame.command);
254///     }
255///     Some(ReceivedFrame::Error(err)) => {
256///         eprintln!("Server error: {}", err);
257///     }
258///     None => {
259///         println!("Connection closed");
260///     }
261/// }
262/// ```
263#[derive(Debug, Clone, PartialEq, Eq)]
264pub enum ReceivedFrame {
265    /// A normal STOMP frame (MESSAGE, RECEIPT, etc.)
266    Frame(Frame),
267    /// An ERROR frame from the server
268    Error(ServerError),
269}
270
271impl ReceivedFrame {
272    /// Returns `true` if this is an error frame.
273    pub fn is_error(&self) -> bool {
274        matches!(self, ReceivedFrame::Error(_))
275    }
276
277    /// Returns `true` if this is a normal frame.
278    pub fn is_frame(&self) -> bool {
279        matches!(self, ReceivedFrame::Frame(_))
280    }
281
282    /// Returns the frame if this is a normal frame, or `None` if it's an error.
283    pub fn into_frame(self) -> Option<Frame> {
284        match self {
285            ReceivedFrame::Frame(f) => Some(f),
286            ReceivedFrame::Error(_) => None,
287        }
288    }
289
290    /// Returns the error if this is an error frame, or `None` if it's a normal frame.
291    pub fn into_error(self) -> Option<ServerError> {
292        match self {
293            ReceivedFrame::Frame(_) => None,
294            ReceivedFrame::Error(e) => Some(e),
295        }
296    }
297}
298
/// Acknowledgement mode requested in a SUBSCRIBE frame's `ack` header (STOMP 1.2).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AckMode {
    /// `auto`: delivered messages are not tracked for ACK/NACK.
    Auto,
    /// `client`: an ACK is cumulative for the subscription.
    Client,
    /// `client-individual`: each ACK/NACK applies to a single message.
    ClientIndividual,
}

impl AckMode {
    /// The literal value written into the `ack` header.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Auto => "auto",
            Self::Client => "client",
            Self::ClientIndividual => "client-individual",
        }
    }
}
316
/// Options for customizing the STOMP CONNECT frame.
///
/// Use this struct with `Connection::connect_with_options()` to set custom
/// headers, specify supported STOMP versions, or configure broker-specific
/// options like `client-id` for durable subscriptions.
///
/// # Validation
///
/// This struct performs minimal validation. Values are passed to the broker
/// as-is, and invalid configurations will be rejected by the broker at
/// connection time. Empty strings are technically accepted but may cause
/// broker-specific errors.
///
/// # Custom Headers
///
/// Custom headers added via `header()` cannot override critical STOMP headers
/// (`accept-version`, `host`, `login`, `passcode`, `heart-beat`, `client-id`).
/// Such headers are silently ignored. Use the dedicated builder methods to
/// set these values.
///
/// # Example
///
/// ```ignore
/// use iridium_stomp::{Connection, ConnectOptions};
///
/// let options = ConnectOptions::default()
///     .client_id("my-durable-client")
///     .host("my-vhost")
///     .header("custom-header", "value");
///
/// let conn = Connection::connect_with_options(
///     "localhost:61613",
///     "guest",
///     "guest",
///     "10000,10000",
///     options,
/// ).await?;
/// ```
#[derive(Debug, Clone, Default)]
pub struct ConnectOptions {
    /// STOMP version(s) to accept (e.g., "1.2" or "1.0,1.1,1.2").
    /// Defaults to "1.2" if not set.
    pub accept_version: Option<String>,

    /// Client ID for durable subscriptions (required by ActiveMQ, etc.).
    pub client_id: Option<String>,

    /// Virtual host header value. Defaults to "/" if not set.
    pub host: Option<String>,

    /// Additional custom headers to include in the CONNECT frame, in insertion
    /// order. Headers that would override critical STOMP headers are ignored.
    pub headers: Vec<(String, String)>,
}

impl ConnectOptions {
    /// Create a new `ConnectOptions` with default values.
    ///
    /// Equivalent to `ConnectOptions::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the STOMP version(s) to accept (builder style).
    ///
    /// Examples: "1.2", "1.1,1.2", "1.0,1.1,1.2"
    ///
    /// The builder is returned by value; discarding it loses the setting,
    /// hence `#[must_use]`.
    #[must_use]
    pub fn accept_version(mut self, version: impl Into<String>) -> Self {
        self.accept_version = Some(version.into());
        self
    }

    /// Set the client ID for durable subscriptions (builder style).
    ///
    /// Required by some brokers (e.g., ActiveMQ) for durable topic subscriptions.
    #[must_use]
    pub fn client_id(mut self, id: impl Into<String>) -> Self {
        self.client_id = Some(id.into());
        self
    }

    /// Set the virtual host (builder style).
    ///
    /// Defaults to "/" if not set.
    #[must_use]
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.host = Some(host.into());
        self
    }

    /// Add a custom header to the CONNECT frame (builder style).
    ///
    /// Headers are appended in call order. Keys naming a critical STOMP header
    /// (see the type-level docs) are ignored when the CONNECT frame is built.
    #[must_use]
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.push((key.into(), value.into()));
        self
    }
}
408
/// Parse a STOMP `heart-beat` header value of the form `"cx,cy"`.
///
/// Parameters
/// - `header`: the raw header string from the client or server, e.g.
///   `"10000,10000"`. Both numbers are milliseconds.
///
/// Returns `(cx, cy)`. Surrounding whitespace around each number is ignored.
/// A field that is missing or does not parse as an unsigned integer becomes
/// `0`, and anything after the second comma-separated field is ignored.
pub fn parse_heartbeat_header(header: &str) -> (u64, u64) {
    let mut fields = header
        .split(',')
        .map(|raw| raw.trim().parse::<u64>().unwrap_or(0));
    let cx = fields.next().unwrap_or(0);
    let cy = fields.next().unwrap_or(0);
    (cx, cy)
}
429
/// Negotiate heartbeat intervals between client and server.
///
/// Parameters
/// - `client_out`: client's outgoing heartbeat interval in milliseconds
///   (`cx`: how often the client can send heartbeats; 0 = cannot send).
/// - `client_in`: client's desired incoming heartbeat interval in
///   milliseconds (`cy`: how often the client wants heartbeats; 0 = none).
/// - `server_out`: server's advertised outgoing interval in milliseconds (`sx`).
/// - `server_in`: server's advertised incoming interval in milliseconds (`sy`).
///
/// Returns `(outgoing, incoming)`. Each element is `None` when heartbeats are
/// disabled in that direction and `Some(interval)` otherwise.
///
/// Per STOMP 1.2, a direction is disabled when EITHER side advertises `0` for
/// it: the sender cannot send, or the receiver does not want to receive.
/// Otherwise the interval is the maximum of the two values. For example, a
/// client advertising `"0,0"` gets no heartbeats even if the server asks for
/// them.
pub fn negotiate_heartbeats(
    client_out: u64,
    client_in: u64,
    server_out: u64,
    server_in: u64,
) -> (Option<Duration>, Option<Duration>) {
    /// Negotiates one direction given the sender's and receiver's values.
    fn direction(sender_ms: u64, receiver_ms: u64) -> Option<Duration> {
        if sender_ms == 0 || receiver_ms == 0 {
            None
        } else {
            Some(Duration::from_millis(sender_ms.max(receiver_ms)))
        }
    }

    // client -> server uses (cx, sy); server -> client uses (sx, cy).
    let outgoing = direction(client_out, server_in);
    let incoming = direction(server_out, client_in);
    (outgoing, incoming)
}
465
466/// High-level connection object that manages a single TCP/STOMP connection.
467///
468/// The `Connection` spawns a background task that maintains the TCP transport,
469/// sends/receives STOMP frames using `StompCodec`, negotiates heartbeats, and
470/// performs simple reconnect logic with exponential backoff.
471#[derive(Clone)]
472pub struct Connection {
473    outbound_tx: mpsc::Sender<StompItem>,
474    /// The inbound receiver is shared behind a mutex so the `Connection`
475    /// handle may be cloned and callers can call `next_frame` concurrently.
476    inbound_rx: Arc<Mutex<mpsc::Receiver<Frame>>>,
477    shutdown_tx: broadcast::Sender<()>,
478    /// Map of destination -> list of (subscription id, sender) for dispatching
479    /// inbound MESSAGE frames to subscribers.
480    subscriptions: Arc<Mutex<Subscriptions>>,
481    /// Monotonic counter used to allocate subscription ids.
482    sub_id_counter: Arc<AtomicU64>,
483    /// Pending messages awaiting ACK/NACK from the application.
484    ///
485    /// Organized by subscription id. For `client` ack mode the ACK is
486    /// cumulative: acknowledging message `M` for subscription `S` acknowledges
487    /// all messages previously delivered for `S` up to and including `M`.
488    /// For `client-individual` the ACK/NACK applies only to the single
489    /// message.
490    pending: Arc<Mutex<PendingMap>>,
491    /// Pending receipt confirmations.
492    ///
493    /// When a frame is sent with a `receipt` header, the receipt-id is stored
494    /// here with a oneshot sender. When the server responds with a RECEIPT
495    /// frame, the sender is notified.
496    pending_receipts: Arc<Mutex<PendingReceipts>>,
497}
498
499impl Connection {
500    /// Heartbeat value that disables heartbeats entirely.
501    ///
502    /// Use this when you don't want the client or server to send heartbeats.
503    /// Note that some brokers may still require heartbeats for long-lived connections.
504    ///
505    /// # Example
506    ///
507    /// ```ignore
508    /// let conn = Connection::connect(
509    ///     "localhost:61613",
510    ///     "guest",
511    ///     "guest",
512    ///     Connection::NO_HEARTBEAT,
513    /// ).await?;
514    /// ```
515    pub const NO_HEARTBEAT: &'static str = "0,0";
516
517    /// Default heartbeat value: 10 seconds for both send and receive.
518    ///
519    /// This is a reasonable default for most applications. The actual heartbeat
520    /// interval will be negotiated with the server (taking the maximum of client
521    /// and server preferences).
522    ///
523    /// # Example
524    ///
525    /// ```ignore
526    /// let conn = Connection::connect(
527    ///     "localhost:61613",
528    ///     "guest",
529    ///     "guest",
530    ///     Connection::DEFAULT_HEARTBEAT,
531    /// ).await?;
532    /// ```
533    pub const DEFAULT_HEARTBEAT: &'static str = "10000,10000";
534
535    /// Establish a connection to the STOMP server at `addr` with the given
536    /// credentials and heartbeat header string (e.g. "10000,10000").
537    ///
538    /// This is a convenience wrapper around `connect_with_options()` that uses
539    /// default options (STOMP 1.2, host="/", no client-id).
540    ///
541    /// Parameters
542    /// - `addr`: TCP address (host:port) of the STOMP server.
543    /// - `login`: login username for STOMP `CONNECT`.
544    /// - `passcode`: passcode for STOMP `CONNECT`.
545    /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
546    ///   milliseconds) that will be sent in the `CONNECT` frame.
547    ///
548    /// Returns a `Connection` which provides `send_frame`, `next_frame`, and
549    /// `close` helpers. The detailed connection handling (I/O, heartbeats,
550    /// reconnects) runs on a background task spawned by this method.
551    pub async fn connect(
552        addr: &str,
553        login: &str,
554        passcode: &str,
555        client_hb: &str,
556    ) -> Result<Self, ConnError> {
557        Self::connect_with_options(addr, login, passcode, client_hb, ConnectOptions::default())
558            .await
559    }
560
561    /// Establish a connection to the STOMP server with custom options.
562    ///
563    /// Use this method when you need to set a custom `client-id` (for durable
564    /// subscriptions), specify a virtual host, negotiate different STOMP
565    /// versions, or add custom CONNECT headers.
566    ///
567    /// Parameters
568    /// - `addr`: TCP address (host:port) of the STOMP server.
569    /// - `login`: login username for STOMP `CONNECT`.
570    /// - `passcode`: passcode for STOMP `CONNECT`.
571    /// - `client_hb`: client's `heart-beat` header value ("cx,cy" in
572    ///   milliseconds) that will be sent in the `CONNECT` frame.
573    /// - `options`: custom connection options (version, host, client-id, etc.).
574    ///
575    /// # Example
576    ///
577    /// ```ignore
578    /// use iridium_stomp::{Connection, ConnectOptions};
579    ///
580    /// // Connect with a client-id for durable subscriptions
581    /// let options = ConnectOptions::default()
582    ///     .client_id("my-app-instance-1");
583    ///
584    /// let conn = Connection::connect_with_options(
585    ///     "localhost:61613",
586    ///     "guest",
587    ///     "guest",
588    ///     "10000,10000",
589    ///     options,
590    /// ).await?;
591    /// ```
592    pub async fn connect_with_options(
593        addr: &str,
594        login: &str,
595        passcode: &str,
596        client_hb: &str,
597        options: ConnectOptions,
598    ) -> Result<Self, ConnError> {
599        let (out_tx, mut out_rx) = mpsc::channel::<StompItem>(32);
600        let (in_tx, in_rx) = mpsc::channel::<Frame>(32);
601        let subscriptions: Arc<Mutex<Subscriptions>> = Arc::new(Mutex::new(HashMap::new()));
602        let sub_id_counter = Arc::new(AtomicU64::new(1));
603        let (shutdown_tx, _) = broadcast::channel::<()>(1);
604        let pending: Arc<Mutex<PendingMap>> = Arc::new(Mutex::new(HashMap::new()));
605        let pending_clone = pending.clone();
606        let pending_receipts: Arc<Mutex<PendingReceipts>> = Arc::new(Mutex::new(HashMap::new()));
607        let pending_receipts_clone = pending_receipts.clone();
608
609        let addr = addr.to_string();
610        let login = login.to_string();
611        let passcode = passcode.to_string();
612        let client_hb = client_hb.to_string();
613
614        // Extract options into owned values for the spawned task
615        let accept_version = options.accept_version.unwrap_or_else(|| "1.2".to_string());
616        let host = options.host.unwrap_or_else(|| "/".to_string());
617        let client_id = options.client_id;
618        let custom_headers = options.headers;
619
620        let shutdown_tx_clone = shutdown_tx.clone();
621        let subscriptions_clone = subscriptions.clone();
622
623        tokio::spawn(async move {
624            let mut backoff_secs: u64 = 1;
625            loop {
626                let mut shutdown_sub = shutdown_tx_clone.subscribe();
627
628                tokio::select! {
629                    _ = shutdown_sub.recv() => break,
630                    _ = future::ready(()) => {},
631                }
632
633                match TcpStream::connect(&addr).await {
634                    Ok(stream) => {
635                        let mut framed = Framed::new(stream, StompCodec::new());
636
637                        let mut connect = Frame::new("CONNECT");
638                        connect = connect.header("accept-version", &accept_version);
639                        connect = connect.header("host", &host);
640                        connect = connect.header("login", &login);
641                        connect = connect.header("passcode", &passcode);
642                        connect = connect.header("heart-beat", &client_hb);
643
644                        // Add client-id if specified
645                        if let Some(ref cid) = client_id {
646                            connect = connect.header("client-id", cid);
647                        }
648
649                        // Add any custom headers, skipping any that would override
650                        // critical STOMP CONNECT headers already set above
651                        for (key, value) in &custom_headers {
652                            let dominated = matches!(
653                                key.to_ascii_lowercase().as_str(),
654                                "accept-version"
655                                    | "host"
656                                    | "login"
657                                    | "passcode"
658                                    | "heart-beat"
659                                    | "client-id"
660                            );
661                            if !dominated {
662                                connect = connect.header(key, value);
663                            }
664                        }
665
666                        if framed.send(StompItem::Frame(connect)).await.is_err() {
667                            tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
668                            backoff_secs = (backoff_secs * 2).min(30);
669                            continue;
670                        }
671
672                        let mut server_heartbeat = "0,0".to_string();
673                        loop {
674                            tokio::select! {
675                                _ = shutdown_sub.recv() => break,
676                                item = framed.next() => {
677                                    match item {
678                                        Some(Ok(StompItem::Heartbeat)) => {}
679                                        Some(Ok(StompItem::Frame(f))) => {
680                                            if f.command == "CONNECTED" {
681                                                for (k, v) in &f.headers {
682                                                    if k.to_lowercase() == "heart-beat" { server_heartbeat = v.clone(); }
683                                                }
684                                                break;
685                                            }
686                                        }
687                                        _ => break,
688                                    }
689                                }
690                            }
691                        }
692
693                        let (cx, cy) = parse_heartbeat_header(&client_hb);
694                        let (sx, sy) = parse_heartbeat_header(&server_heartbeat);
695                        let (send_interval, recv_interval) = negotiate_heartbeats(cx, cy, sx, sy);
696
697                        let last_received = Arc::new(AtomicU64::new(current_millis()));
698                        let writer_last_sent = Arc::new(AtomicU64::new(current_millis()));
699
700                        let (mut sink, mut stream) = framed.split();
701                        let in_tx = in_tx.clone();
702                        let subscriptions = subscriptions_clone.clone();
703
704                        // Clear pending message map on reconnect — messages that were
705                        // outstanding before the disconnect are considered lost and
706                        // will be redelivered by the server as appropriate.
707                        {
708                            let mut p = pending_clone.lock().await;
709                            p.clear();
710                        }
711
712                        // Resubscribe any existing subscriptions after reconnect.
713                        // We snapshot the subscription entries while holding the lock
714                        // and then issue SUBSCRIBE frames using the sink.
715                        let subs_snapshot: Vec<ResubEntry> = {
716                            let map = subscriptions.lock().await;
717                            let mut v: Vec<ResubEntry> = Vec::new();
718                            for (dest, vec) in map.iter() {
719                                for entry in vec.iter() {
720                                    v.push((
721                                        dest.clone(),
722                                        entry.id.clone(),
723                                        entry.ack.clone(),
724                                        entry.headers.clone(),
725                                    ));
726                                }
727                            }
728                            v
729                        };
730
731                        for (dest, id, ack, headers) in subs_snapshot {
732                            let mut sf = Frame::new("SUBSCRIBE");
733                            sf = sf
734                                .header("id", &id)
735                                .header("destination", &dest)
736                                .header("ack", &ack);
737                            for (k, v) in headers {
738                                sf = sf.header(&k, &v);
739                            }
740                            let _ = sink.send(StompItem::Frame(sf)).await;
741                        }
742
743                        let mut hb_tick = match send_interval {
744                            Some(d) => tokio::time::interval(d),
745                            None => tokio::time::interval(Duration::from_secs(86400)),
746                        };
747                        let watchdog_half = recv_interval.map(|d| d / 2);
748
749                        backoff_secs = 1;
750
751                        'conn: loop {
752                            tokio::select! {
753                                _ = shutdown_sub.recv() => { let _ = sink.close().await; break 'conn; }
754                                maybe = out_rx.recv() => {
755                                    match maybe {
756                                        Some(item) => if sink.send(item).await.is_err() { break 'conn } else { writer_last_sent.store(current_millis(), Ordering::SeqCst); }
757                                        None => break 'conn,
758                                    }
759                                }
760                                item = stream.next() => {
761                                    match item {
762                                        Some(Ok(StompItem::Heartbeat)) => { last_received.store(current_millis(), Ordering::SeqCst); }
763                                        Some(Ok(StompItem::Frame(f))) => {
764                                            last_received.store(current_millis(), Ordering::SeqCst);
765                                            // Dispatch MESSAGE frames to any matching subscribers.
766                                            if f.command == "MESSAGE" {
767                                                // try to find destination, subscription and message-id headers
768                                                let mut dest_opt: Option<String> = None;
769                                                let mut sub_opt: Option<String> = None;
770                                                let mut msg_id_opt: Option<String> = None;
771                                                for (k, v) in &f.headers {
772                                                    let kl = k.to_lowercase();
773                                                    if kl == "destination" {
774                                                        dest_opt = Some(v.clone());
775                                                    } else if kl == "subscription" {
776                                                        sub_opt = Some(v.clone());
777                                                    } else if kl == "message-id" {
778                                                        msg_id_opt = Some(v.clone());
779                                                    }
780                                                }
781
782                                                // Determine whether we need to track this message as pending
783                                                let mut need_pending = false;
784                                                if let Some(sub_id) = &sub_opt {
785                                                    let map = subscriptions.lock().await;
786                                                    for (_dest, vec) in map.iter() {
787                                                        for entry in vec.iter() {
788                                                            if &entry.id == sub_id && entry.ack != "auto" {
789                                                                need_pending = true;
790                                                            }
791                                                        }
792                                                    }
793                                                } else if let Some(dest) = &dest_opt {
794                                                    let map = subscriptions.lock().await;
795                                                    if let Some(vec) = map.get(dest) {
796                                                        for entry in vec.iter() {
797                                                            if entry.ack != "auto" {
798                                                                need_pending = true;
799                                                                break;
800                                                            }
801                                                        }
802                                                    }
803                                                }
804
805                                                // If required, add to pending map (per-subscription) before
806                                                // delivery so ACK/NACK requests from the application can
807                                                // reference the message. We require a `message-id` header
808                                                // to track messages; if missing, we cannot support ACK/NACK.
809                                                if let Some(msg_id) = msg_id_opt.clone().filter(|_| need_pending) {
810                                                    // If the server provided a subscription id in the
811                                                    // MESSAGE, store pending under that subscription.
812                                                    if let Some(sub_id) = &sub_opt {
813                                                        let mut p = pending_clone.lock().await;
814                                                        let q = p
815                                                            .entry(sub_id.clone())
816                                                            .or_insert_with(VecDeque::new);
817                                                        q.push_back((msg_id.clone(), f.clone()));
818                                                    } else if let Some(dest) = &dest_opt {
819                                                        // Destination-based delivery: add the message to
820                                                        // the pending queue for each matching
821                                                        // subscription on that destination.
822                                                        let map = subscriptions.lock().await;
823                                                        if let Some(vec) = map.get(dest) {
824                                                            let mut p = pending_clone.lock().await;
825                                                            for entry in vec.iter() {
826                                                                let q = p
827                                                                    .entry(entry.id.clone())
828                                                                    .or_insert_with(VecDeque::new);
829                                                                q.push_back((msg_id.clone(), f.clone()));
830                                                            }
831                                                        }
832                                                    }
833                                                }
834
835                                                // Deliver to subscribers.
836                                                if let Some(sub_id) = sub_opt {
837                                                    let mut map = subscriptions.lock().await;
838                                                    for (_dest, vec) in map.iter_mut() {
839                                                        vec.retain(|entry| {
840                                                            if entry.id == sub_id {
841                                                                let _ = entry.sender.try_send(f.clone());
842                                                                true
843                                                            } else {
844                                                                true
845                                                            }
846                                                        });
847                                                    }
848                                                } else if let Some(dest) = dest_opt {
849                                                    let mut map = subscriptions.lock().await;
850                                                    if let Some(vec) = map.get_mut(&dest) {
851                                                        vec.retain(|entry| entry.sender.try_send(f.clone()).is_ok());
852                                                    }
853                                                }
854                                            } else if f.command == "RECEIPT" {
855                                                // Handle RECEIPT frame: notify any waiting callers
856                                                if let Some(receipt_id) = f.get_header("receipt-id") {
857                                                    let mut receipts = pending_receipts_clone.lock().await;
858                                                    if let Some(sender) = receipts.remove(receipt_id) {
859                                                        let _ = sender.send(());
860                                                    }
861                                                }
862                                                // Don't forward RECEIPT frames to inbound channel
863                                                continue;
864                                            }
865
866                                            let _ = in_tx.send(f).await;
867                                        }
868                                        Some(Err(_)) | None => break 'conn,
869                                    }
870                                }
871                                _ = hb_tick.tick() => {
872                                    if let Some(dur) = send_interval {
873                                        let last = writer_last_sent.load(Ordering::SeqCst);
874                                        if current_millis().saturating_sub(last) >= dur.as_millis() as u64 {
875                                            if sink.send(StompItem::Heartbeat).await.is_err() { break 'conn; }
876                                            writer_last_sent.store(current_millis(), Ordering::SeqCst);
877                                        }
878                                    }
879                                }
880                                _ = async { if let Some(interval) = watchdog_half { tokio::time::sleep(interval).await } else { future::pending::<()>().await } } => {
881                                    if let Some(recv_dur) = recv_interval {
882                                        let last = last_received.load(Ordering::SeqCst);
883                                        if current_millis().saturating_sub(last) > (recv_dur.as_millis() as u64 * 2) {
884                                            let _ = sink.close().await; break 'conn;
885                                        }
886                                    }
887                                }
888                            }
889                        }
890
891                        if shutdown_sub.try_recv().is_ok() {
892                            break;
893                        }
894                        tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
895                        backoff_secs = (backoff_secs * 2).min(30);
896                    }
897                    Err(_) => {
898                        tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
899                        backoff_secs = (backoff_secs * 2).min(30);
900                    }
901                }
902            }
903        });
904
905        Ok(Connection {
906            outbound_tx: out_tx,
907            inbound_rx: Arc::new(Mutex::new(in_rx)),
908            shutdown_tx,
909            subscriptions,
910            sub_id_counter,
911            pending,
912            pending_receipts,
913        })
914    }
915
916    pub async fn send_frame(&self, frame: Frame) -> Result<(), ConnError> {
917        // Send a frame to the background writer task.
918        //
919        // Parameters
920        // - `frame`: ownership of the `Frame` to send. The frame is converted
921        //   into a `StompItem::Frame` and sent over the internal mpsc channel.
922        self.outbound_tx
923            .send(StompItem::Frame(frame))
924            .await
925            .map_err(|_| ConnError::Protocol("send channel closed".into()))
926    }
927
928    /// Send a frame with a receipt request and return the receipt ID.
929    ///
930    /// This method adds a unique `receipt` header to the frame and registers
931    /// the receipt ID for tracking. Use `wait_for_receipt()` to wait for the
932    /// server's RECEIPT response.
933    ///
934    /// # Parameters
935    /// - `frame`: the frame to send. A `receipt` header will be added.
936    ///
937    /// # Returns
938    /// The generated receipt ID that can be used with `wait_for_receipt()`.
939    ///
940    /// # Example
941    /// ```ignore
942    /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
943    /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
944    /// ```
945    pub async fn send_frame_with_receipt(&self, frame: Frame) -> Result<String, ConnError> {
946        use std::sync::atomic::AtomicU64;
947        use std::sync::atomic::Ordering::SeqCst;
948
949        // Generate a unique receipt ID using a static counter
950        static RECEIPT_COUNTER: AtomicU64 = AtomicU64::new(1);
951        let receipt_id = format!("rcpt-{}", RECEIPT_COUNTER.fetch_add(1, SeqCst));
952
953        // Create the oneshot channel for notification
954        let (tx, _rx) = oneshot::channel();
955
956        // Register the pending receipt
957        {
958            let mut receipts = self.pending_receipts.lock().await;
959            receipts.insert(receipt_id.clone(), tx);
960        }
961
962        // Add receipt header and send the frame
963        let frame_with_receipt = frame.receipt(&receipt_id);
964        self.send_frame(frame_with_receipt).await?;
965
966        Ok(receipt_id)
967    }
968
969    /// Wait for a receipt confirmation from the server.
970    ///
971    /// This method blocks until the server sends a RECEIPT frame with the
972    /// matching receipt-id, or until the timeout expires.
973    ///
974    /// # Parameters
975    /// - `receipt_id`: the receipt ID returned by `send_frame_with_receipt()`.
976    /// - `timeout`: maximum time to wait for the receipt.
977    ///
978    /// # Returns
979    /// `Ok(())` if the receipt was received, or `Err(ConnError::ReceiptTimeout)`
980    /// if the timeout expired.
981    ///
982    /// # Example
983    /// ```ignore
984    /// let receipt_id = conn.send_frame_with_receipt(frame).await?;
985    /// conn.wait_for_receipt(&receipt_id, Duration::from_secs(5)).await?;
986    /// println!("Message confirmed!");
987    /// ```
988    pub async fn wait_for_receipt(
989        &self,
990        receipt_id: &str,
991        timeout: Duration,
992    ) -> Result<(), ConnError> {
993        // Get the receiver for this receipt
994        let rx = {
995            let mut receipts = self.pending_receipts.lock().await;
996            // Re-create the oneshot channel and swap out the sender
997            let (tx, rx) = oneshot::channel();
998            if let Some(old_tx) = receipts.insert(receipt_id.to_string(), tx) {
999                // Drop the old sender - this is expected if called after send_frame_with_receipt
1000                drop(old_tx);
1001            }
1002            rx
1003        };
1004
1005        // Wait for the receipt with timeout
1006        match tokio::time::timeout(timeout, rx).await {
1007            Ok(Ok(())) => Ok(()),
1008            Ok(Err(_)) => {
1009                // Channel was closed without receiving - connection likely dropped
1010                Err(ConnError::Protocol(
1011                    "receipt channel closed unexpectedly".into(),
1012                ))
1013            }
1014            Err(_) => {
1015                // Timeout expired - clean up the pending receipt
1016                let mut receipts = self.pending_receipts.lock().await;
1017                receipts.remove(receipt_id);
1018                Err(ConnError::ReceiptTimeout(receipt_id.to_string()))
1019            }
1020        }
1021    }
1022
1023    /// Send a frame and wait for server confirmation via RECEIPT.
1024    ///
1025    /// This is a convenience method that combines `send_frame_with_receipt()`
1026    /// and `wait_for_receipt()`. Use this when you want to ensure a frame
1027    /// was processed by the server before continuing.
1028    ///
1029    /// # Parameters
1030    /// - `frame`: the frame to send.
1031    /// - `timeout`: maximum time to wait for the receipt.
1032    ///
1033    /// # Returns
1034    /// `Ok(())` if the frame was sent and receipt confirmed, or an error if
1035    /// sending failed or the receipt timed out.
1036    ///
1037    /// # Example
1038    /// ```ignore
1039    /// let frame = Frame::new("SEND")
1040    ///     .header("destination", "/queue/orders")
1041    ///     .set_body(b"order data".to_vec());
1042    ///
1043    /// conn.send_frame_confirmed(frame, Duration::from_secs(5)).await?;
1044    /// println!("Order sent and confirmed!");
1045    /// ```
1046    pub async fn send_frame_confirmed(
1047        &self,
1048        frame: Frame,
1049        timeout: Duration,
1050    ) -> Result<(), ConnError> {
1051        // Generate receipt ID and register before sending
1052        use std::sync::atomic::AtomicU64;
1053        use std::sync::atomic::Ordering::SeqCst;
1054
1055        static RECEIPT_COUNTER: AtomicU64 = AtomicU64::new(1);
1056        let receipt_id = format!("rcpt-{}", RECEIPT_COUNTER.fetch_add(1, SeqCst));
1057
1058        // Create the oneshot channel for notification
1059        let (tx, rx) = oneshot::channel();
1060
1061        // Register the pending receipt before sending
1062        {
1063            let mut receipts = self.pending_receipts.lock().await;
1064            receipts.insert(receipt_id.clone(), tx);
1065        }
1066
1067        // Add receipt header and send the frame
1068        let frame_with_receipt = frame.receipt(&receipt_id);
1069        self.send_frame(frame_with_receipt).await?;
1070
1071        // Wait for the receipt with timeout
1072        match tokio::time::timeout(timeout, rx).await {
1073            Ok(Ok(())) => Ok(()),
1074            Ok(Err(_)) => Err(ConnError::Protocol(
1075                "receipt channel closed unexpectedly".into(),
1076            )),
1077            Err(_) => {
1078                // Timeout expired - clean up
1079                let mut receipts = self.pending_receipts.lock().await;
1080                receipts.remove(&receipt_id);
1081                Err(ConnError::ReceiptTimeout(receipt_id))
1082            }
1083        }
1084    }
1085
1086    /// Subscribe to a destination.
1087    ///
1088    /// Parameters
1089    /// - `destination`: the STOMP destination to subscribe to (e.g. "/queue/foo").
1090    /// - `ack`: acknowledgement mode to request from the server.
1091    ///
1092    /// Returns a tuple `(subscription_id, receiver)` where `subscription_id` is
1093    /// the opaque id assigned locally for this subscription and `receiver` is a
1094    /// `mpsc::Receiver<Frame>` which will yield incoming MESSAGE frames for the
1095    /// destination. The caller should read from the receiver to handle messages.
1096    /// Subscribe to a destination using optional extra headers.
1097    ///
1098    /// This variant accepts additional headers which are stored locally and
1099    /// re-sent on reconnect. Use `subscribe` as a convenience wrapper when no
1100    /// extra headers are needed.
1101    pub async fn subscribe_with_headers(
1102        &self,
1103        destination: &str,
1104        ack: AckMode,
1105        extra_headers: Vec<(String, String)>,
1106    ) -> Result<crate::subscription::Subscription, ConnError> {
1107        let id = self
1108            .sub_id_counter
1109            .fetch_add(1, Ordering::SeqCst)
1110            .to_string();
1111        let (tx, rx) = mpsc::channel::<Frame>(16);
1112        {
1113            let mut map = self.subscriptions.lock().await;
1114            map.entry(destination.to_string())
1115                .or_insert_with(Vec::new)
1116                .push(SubscriptionEntry {
1117                    id: id.clone(),
1118                    sender: tx.clone(),
1119                    ack: ack.as_str().to_string(),
1120                    headers: extra_headers.clone(),
1121                });
1122        }
1123
1124        let mut f = Frame::new("SUBSCRIBE");
1125        f = f
1126            .header("id", &id)
1127            .header("destination", destination)
1128            .header("ack", ack.as_str());
1129        for (k, v) in &extra_headers {
1130            f = f.header(k, v);
1131        }
1132        self.outbound_tx
1133            .send(StompItem::Frame(f))
1134            .await
1135            .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1136
1137        Ok(crate::subscription::Subscription::new(
1138            id,
1139            destination.to_string(),
1140            rx,
1141            self.clone(),
1142        ))
1143    }
1144
1145    /// Convenience wrapper without extra headers.
1146    pub async fn subscribe(
1147        &self,
1148        destination: &str,
1149        ack: AckMode,
1150    ) -> Result<crate::subscription::Subscription, ConnError> {
1151        self.subscribe_with_headers(destination, ack, Vec::new())
1152            .await
1153    }
1154
1155    /// Subscribe with a typed `SubscriptionOptions` structure.
1156    ///
1157    /// `SubscriptionOptions.headers` are forwarded to the broker and persisted
1158    /// for automatic resubscribe after reconnect. If `durable_queue` is set,
1159    /// it will be used as the actual destination instead of `destination`.
1160    pub async fn subscribe_with_options(
1161        &self,
1162        destination: &str,
1163        ack: AckMode,
1164        options: crate::subscription::SubscriptionOptions,
1165    ) -> Result<crate::subscription::Subscription, ConnError> {
1166        let dest = options
1167            .durable_queue
1168            .as_deref()
1169            .unwrap_or(destination)
1170            .to_string();
1171        self.subscribe_with_headers(&dest, ack, options.headers)
1172            .await
1173    }
1174
1175    /// Unsubscribe a previously created subscription by its local subscription id.
1176    pub async fn unsubscribe(&self, subscription_id: &str) -> Result<(), ConnError> {
1177        let mut found = false;
1178        {
1179            let mut map = self.subscriptions.lock().await;
1180            let mut remove_keys: Vec<String> = Vec::new();
1181            for (dest, vec) in map.iter_mut() {
1182                if let Some(pos) = vec.iter().position(|entry| entry.id == subscription_id) {
1183                    vec.remove(pos);
1184                    found = true;
1185                }
1186                if vec.is_empty() {
1187                    remove_keys.push(dest.clone());
1188                }
1189            }
1190            for k in remove_keys {
1191                map.remove(&k);
1192            }
1193        }
1194
1195        if !found {
1196            return Err(ConnError::Protocol("subscription id not found".into()));
1197        }
1198
1199        let mut f = Frame::new("UNSUBSCRIBE");
1200        f = f.header("id", subscription_id);
1201        self.outbound_tx
1202            .send(StompItem::Frame(f))
1203            .await
1204            .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1205
1206        Ok(())
1207    }
1208
1209    /// Acknowledge a message previously received in `client` or
1210    /// `client-individual` ack modes.
1211    ///
1212    /// STOMP ack semantics:
1213    /// - `auto`: server considers message delivered immediately; the client
1214    ///   should not ack.
1215    /// - `client`: cumulative acknowledgements. ACKing message `M` for
1216    ///   subscription `S` acknowledges all messages delivered to `S` up to
1217    ///   and including `M`.
1218    /// - `client-individual`: only the named message is acknowledged.
1219    ///
1220    /// Parameters
1221    /// - `subscription_id`: the local subscription id returned by
1222    ///   `Connection::subscribe`. This disambiguates which subscription's
1223    ///   pending queue to advance for cumulative ACKs.
1224    /// - `message_id`: the `message-id` header value from the received
1225    ///   MESSAGE frame to acknowledge.
1226    ///
1227    /// Behavior
1228    /// - The pending queue for `subscription_id` is searched for `message_id`.
1229    ///   If the subscription used `client` ack mode, all pending messages up to
1230    ///   and including the matched message are removed. If the subscription
1231    ///   used `client-individual`, only the matched message is removed.
1232    /// - An `ACK` frame is sent to the server with `id=<message_id>` and
1233    ///   `subscription=<subscription_id>` headers.
1234    #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1235    pub async fn ack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1236        // Remove from the local pending queue according to subscription ack mode.
1237        let mut removed_any = false;
1238        {
1239            let mut p = self.pending.lock().await;
1240            if let Some(queue) = p.get_mut(subscription_id) {
1241                if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1242                    // Determine ack mode for this subscription (default to client).
1243                    let mut ack_mode = "client".to_string();
1244                    {
1245                        let map = self.subscriptions.lock().await;
1246                        'outer: for (_dest, vec) in map.iter() {
1247                            for entry in vec.iter() {
1248                                if entry.id == subscription_id {
1249                                    ack_mode = entry.ack.clone();
1250                                    break 'outer;
1251                                }
1252                            }
1253                        }
1254                    }
1255
1256                    if ack_mode == "client" {
1257                        // cumulative: remove up to and including pos
1258                        for _ in 0..=pos {
1259                            queue.pop_front();
1260                            removed_any = true;
1261                        }
1262                    } else if queue.remove(pos).is_some() {
1263                        // client-individual: remove only the specific message
1264                        removed_any = true;
1265                    }
1266
1267                    if queue.is_empty() {
1268                        p.remove(subscription_id);
1269                    }
1270                }
1271            }
1272        }
1273
1274        // Send ACK to server (include subscription header for clarity)
1275        let mut f = Frame::new("ACK");
1276        f = f
1277            .header("id", message_id)
1278            .header("subscription", subscription_id);
1279        self.outbound_tx
1280            .send(StompItem::Frame(f))
1281            .await
1282            .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1283
1284        // If message wasn't found locally, still send ACK to server; server
1285        // may ignore or treat it as no-op.
1286        let _ = removed_any;
1287        Ok(())
1288    }
1289
1290    /// Negative-acknowledge a message (NACK).
1291    ///
1292    /// Parameters
1293    /// - `subscription_id`: the local subscription id the message was delivered under.
1294    /// - `message_id`: the `message-id` header value from the received MESSAGE.
1295    ///
1296    /// Behavior
1297    /// - Removes the message from the local pending queue (cumulatively if the
1298    ///   subscription used `client` ack mode, otherwise only the single
1299    ///   message). Sends a `NACK` frame to the server with `id` and
1300    ///   `subscription` headers.
1301    #[allow(clippy::collapsible_if, clippy::collapsible_else_if)]
1302    pub async fn nack(&self, subscription_id: &str, message_id: &str) -> Result<(), ConnError> {
1303        // Mirror ack removal semantics for pending map.
1304        let mut removed_any = false;
1305        {
1306            let mut p = self.pending.lock().await;
1307            if let Some(queue) = p.get_mut(subscription_id) {
1308                if let Some(pos) = queue.iter().position(|(mid, _)| mid == message_id) {
1309                    let mut ack_mode = "client".to_string();
1310                    {
1311                        let map = self.subscriptions.lock().await;
1312                        'outer2: for (_dest, vec) in map.iter() {
1313                            for entry in vec.iter() {
1314                                if entry.id == subscription_id {
1315                                    ack_mode = entry.ack.clone();
1316                                    break 'outer2;
1317                                }
1318                            }
1319                        }
1320                    }
1321
1322                    if ack_mode == "client" {
1323                        for _ in 0..=pos {
1324                            queue.pop_front();
1325                            removed_any = true;
1326                        }
1327                    } else if queue.remove(pos).is_some() {
1328                        removed_any = true;
1329                    }
1330
1331                    if queue.is_empty() {
1332                        p.remove(subscription_id);
1333                    }
1334                }
1335            }
1336        }
1337
1338        let mut f = Frame::new("NACK");
1339        f = f
1340            .header("id", message_id)
1341            .header("subscription", subscription_id);
1342        self.outbound_tx
1343            .send(StompItem::Frame(f))
1344            .await
1345            .map_err(|_| ConnError::Protocol("send channel closed".into()))?;
1346
1347        let _ = removed_any;
1348        Ok(())
1349    }
1350
1351    /// Helper to send a transaction frame (BEGIN, COMMIT, or ABORT).
1352    async fn send_transaction_frame(
1353        &self,
1354        command: &str,
1355        transaction_id: &str,
1356    ) -> Result<(), ConnError> {
1357        let f = Frame::new(command).header("transaction", transaction_id);
1358        self.outbound_tx
1359            .send(StompItem::Frame(f))
1360            .await
1361            .map_err(|_| ConnError::Protocol("send channel closed".into()))
1362    }
1363
1364    /// Begin a transaction.
1365    ///
1366    /// Parameters
1367    /// - `transaction_id`: unique identifier for the transaction. The caller is
1368    ///   responsible for ensuring uniqueness within the connection.
1369    ///
1370    /// Behavior
1371    /// - Sends a `BEGIN` frame to the server with `transaction:<transaction_id>`
1372    ///   header. Subsequent `SEND`, `ACK`, and `NACK` frames may include this
1373    ///   transaction id to group them into the transaction. The transaction must
1374    ///   be finalized with either `commit` or `abort`.
1375    pub async fn begin(&self, transaction_id: &str) -> Result<(), ConnError> {
1376        self.send_transaction_frame("BEGIN", transaction_id).await
1377    }
1378
1379    /// Commit a transaction.
1380    ///
1381    /// Parameters
1382    /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1383    ///
1384    /// Behavior
1385    /// - Sends a `COMMIT` frame to the server with `transaction:<transaction_id>`
1386    ///   header. All operations within the transaction are applied atomically.
1387    pub async fn commit(&self, transaction_id: &str) -> Result<(), ConnError> {
1388        self.send_transaction_frame("COMMIT", transaction_id).await
1389    }
1390
1391    /// Abort a transaction.
1392    ///
1393    /// Parameters
1394    /// - `transaction_id`: the transaction identifier previously passed to `begin`.
1395    ///
1396    /// Behavior
1397    /// - Sends an `ABORT` frame to the server with `transaction:<transaction_id>`
1398    ///   header. All operations within the transaction are discarded.
1399    pub async fn abort(&self, transaction_id: &str) -> Result<(), ConnError> {
1400        self.send_transaction_frame("ABORT", transaction_id).await
1401    }
1402
1403    /// Receive the next frame from the server.
1404    ///
1405    /// Returns `Some(ReceivedFrame::Frame(..))` for normal frames (MESSAGE, etc.),
1406    /// `Some(ReceivedFrame::Error(..))` for ERROR frames, or `None` if the
1407    /// connection has been closed.
1408    ///
1409    /// # Example
1410    ///
1411    /// ```ignore
1412    /// use iridium_stomp::ReceivedFrame;
1413    ///
1414    /// while let Some(received) = conn.next_frame().await {
1415    ///     match received {
1416    ///         ReceivedFrame::Frame(frame) => {
1417    ///             println!("Got {}: {:?}", frame.command, frame.body);
1418    ///         }
1419    ///         ReceivedFrame::Error(err) => {
1420    ///             eprintln!("Server error: {}", err);
1421    ///             break;
1422    ///         }
1423    ///     }
1424    /// }
1425    /// ```
1426    pub async fn next_frame(&self) -> Option<ReceivedFrame> {
1427        let mut rx = self.inbound_rx.lock().await;
1428        let frame = rx.recv().await?;
1429
1430        // Convert ERROR frames to ServerError for better ergonomics
1431        if frame.command == "ERROR" {
1432            Some(ReceivedFrame::Error(ServerError::from_frame(frame)))
1433        } else {
1434            Some(ReceivedFrame::Frame(frame))
1435        }
1436    }
1437
1438    pub async fn close(self) {
1439        // Signal the background task to shutdown by broadcasting on the
1440        // shutdown channel. Consumers may await task termination separately
1441        // if needed.
1442        let _ = self.shutdown_tx.send(());
1443    }
1444}
1445
/// Milliseconds on a monotonic clock, anchored to the Unix epoch.
///
/// Heartbeat bookkeeping (`last_received`, `writer_last_sent`) only compares
/// differences between readings. The value is the wall-clock time in ms at the
/// first call plus the `Instant` time elapsed since then. It therefore stays
/// epoch-scaled but never jumps when the system clock is stepped (e.g. by an
/// NTP correction). A forward wall-clock jump could otherwise make the
/// receive watchdog drop a healthy connection, and a backward jump could
/// stall outgoing heartbeats. The anchor falls back to 0 if the system clock
/// reads earlier than 1970.
fn current_millis() -> u64 {
    use std::sync::OnceLock;
    use std::time::{Instant, SystemTime, UNIX_EPOCH};

    static ANCHOR: OnceLock<(Instant, u64)> = OnceLock::new();
    let (start, epoch_ms) = *ANCHOR.get_or_init(|| {
        let epoch_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);
        (Instant::now(), epoch_ms)
    });
    epoch_ms.saturating_add(start.elapsed().as_millis() as u64)
}
1453
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Build a MESSAGE frame carrying `message-id`, plus optional `subscription`
    /// and `destination` headers.
    fn make_message(
        message_id: &str,
        subscription: Option<&str>,
        destination: Option<&str>,
    ) -> Frame {
        let mut f = Frame::new("MESSAGE").header("message-id", message_id);
        if let Some(s) = subscription {
            f = f.header("subscription", s);
        }
        if let Some(d) = destination {
            f = f.header("destination", d);
        }
        f
    }

    /// Build a connection wired to in-memory channels.
    ///
    /// Returns the connection together with the receiver for everything it
    /// sends outbound. Bind the receiver to a named variable (e.g. `_out_rx`,
    /// not `_`) so the outbound channel stays open for operations that send
    /// frames.
    fn setup_test_connection() -> (Connection, mpsc::Receiver<StompItem>) {
        let (out_tx, out_rx) = mpsc::channel::<StompItem>(8);
        // The inbound sender is dropped when this function returns, so
        // `next_frame` on this connection yields `None`. None of these tests
        // read inbound frames.
        let (_in_tx, in_rx) = mpsc::channel::<Frame>(8);
        let (shutdown_tx, _) = broadcast::channel::<()>(1);

        let conn = Connection {
            outbound_tx: out_tx,
            inbound_rx: Arc::new(Mutex::new(in_rx)),
            shutdown_tx,
            subscriptions: Arc::new(Mutex::new(HashMap::new())),
            sub_id_counter: Arc::new(AtomicU64::new(1)),
            pending: Arc::new(Mutex::new(HashMap::new())),
            pending_receipts: Arc::new(Mutex::new(HashMap::new())),
        };

        (conn, out_rx)
    }

    /// Register a single subscription entry for `destination` directly in the
    /// connection's map, bypassing `subscribe()`.
    ///
    /// Returns the entry's receiver. Keep it alive for the test so the entry's
    /// sender is not closed.
    async fn insert_subscription(
        conn: &Connection,
        destination: &str,
        id: &str,
        ack: &str,
    ) -> mpsc::Receiver<Frame> {
        let (sender, receiver) = mpsc::channel::<Frame>(4);
        conn.subscriptions.lock().await.insert(
            destination.to_string(),
            vec![SubscriptionEntry {
                id: id.to_string(),
                sender,
                ack: ack.to_string(),
                headers: Vec::new(),
            }],
        );
        receiver
    }

    /// Replace the pending-ack queue for `sub_id` with one MESSAGE per id, in order.
    async fn seed_pending(conn: &Connection, sub_id: &str, destination: &str, ids: &[&str]) {
        let queue: VecDeque<_> = ids
            .iter()
            .map(|&id| {
                (
                    id.to_string(),
                    make_message(id, Some(sub_id), Some(destination)),
                )
            })
            .collect();
        conn.pending.lock().await.insert(sub_id.to_string(), queue);
    }

    /// Message ids currently pending for `sub_id`, front to back.
    ///
    /// Returns `None` if there is no queue for that subscription at all.
    async fn pending_ids(conn: &Connection, sub_id: &str) -> Option<Vec<String>> {
        let pending = conn.pending.lock().await;
        pending
            .get(sub_id)
            .map(|q| q.iter().map(|(id, _)| id.clone()).collect())
    }

    /// Await the next outbound item and return it as a frame.
    ///
    /// Panics if a non-frame item arrives or the outbound channel is closed.
    async fn next_outbound_frame(out_rx: &mut mpsc::Receiver<StompItem>) -> Frame {
        match out_rx.recv().await {
            Some(StompItem::Frame(f)) => f,
            Some(_) => panic!("expected frame"),
            None => panic!("no outbound frame sent"),
        }
    }

    /// Assert `frame` has `expected_command` and a `transaction` header equal to
    /// `expected_tx_id`.
    fn verify_transaction_frame(frame: Frame, expected_command: &str, expected_tx_id: &str) {
        assert_eq!(frame.command, expected_command);
        assert!(
            frame
                .headers
                .iter()
                .any(|(k, v)| k == "transaction" && v == expected_tx_id),
            "transaction header with id '{}' not found",
            expected_tx_id
        );
    }

    #[tokio::test]
    async fn test_cumulative_ack_removes_prefix() {
        let (conn, mut out_rx) = setup_test_connection();
        // s1 uses cumulative (`client`) acknowledgement
        let _sub_rx = insert_subscription(&conn, "/queue/x", "s1", "client").await;
        seed_pending(&conn, "s1", "/queue/x", &["m1", "m2", "m3"]).await;

        // acking m2 cumulatively should remove m1 and m2, leaving m3
        conn.ack("s1", "m2").await.expect("ack failed");

        assert_eq!(pending_ids(&conn, "s1").await.expect("missing s1"), ["m3"]);
        assert_eq!(next_outbound_frame(&mut out_rx).await.command, "ACK");
    }

    #[tokio::test]
    async fn test_client_individual_ack_removes_only_one() {
        let (conn, mut out_rx) = setup_test_connection();
        // s2 acknowledges each message individually
        let _sub_rx = insert_subscription(&conn, "/queue/y", "s2", "client-individual").await;
        seed_pending(&conn, "s2", "/queue/y", &["a", "b", "c"]).await;

        // acking only 'b' should leave 'a' and 'c' in order
        conn.ack("s2", "b").await.expect("ack failed");

        assert_eq!(
            pending_ids(&conn, "s2").await.expect("missing s2"),
            ["a", "c"]
        );
        assert_eq!(next_outbound_frame(&mut out_rx).await.command, "ACK");
    }

    #[tokio::test]
    async fn test_subscription_receive_delivers_message() {
        // keep the outbound receiver alive so subscribe() can emit SUBSCRIBE
        let (conn, _out_rx) = setup_test_connection();

        let subscription = conn
            .subscribe("/queue/test", AckMode::Auto)
            .await
            .expect("subscribe failed");

        // push a message straight into the sender that subscribe() registered
        {
            let map = conn.subscriptions.lock().await;
            let entry = &map.get("/queue/test").expect("missing subscription vec")[0];
            let f = make_message("m1", Some(&entry.id), Some("/queue/test"));
            entry
                .sender
                .try_send(f)
                .expect("send to subscription failed");
        }

        let mut rx = subscription.into_receiver();
        let received = rx
            .recv()
            .await
            .expect("no message received on subscription");
        assert_eq!(received.command, "MESSAGE");
        assert!(
            received
                .headers
                .iter()
                .any(|(k, _)| k.eq_ignore_ascii_case("message-id")),
            "message-id header missing"
        );
    }

    #[tokio::test]
    async fn test_subscription_ack_removes_pending_and_sends_ack() {
        let (conn, mut out_rx) = setup_test_connection();

        let subscription = conn
            .subscribe("/queue/ack", AckMode::Client)
            .await
            .expect("subscribe failed");
        let sub_id = subscription.id().to_string();

        // discard the SUBSCRIBE frame (and anything else) emitted by subscribe()
        while out_rx.try_recv().is_ok() {}

        seed_pending(&conn, &sub_id, "/queue/ack", &["mid-1"]).await;

        subscription.ack("mid-1").await.expect("ack failed");

        // the queue may be removed entirely or left empty; either is acceptable
        let remaining = pending_ids(&conn, &sub_id).await.unwrap_or_default();
        assert!(remaining.is_empty(), "message still pending: {remaining:?}");

        assert_eq!(next_outbound_frame(&mut out_rx).await.command, "ACK");
    }

    #[tokio::test]
    async fn test_begin_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();

        conn.begin("tx1").await.expect("begin failed");

        verify_transaction_frame(next_outbound_frame(&mut out_rx).await, "BEGIN", "tx1");
    }

    #[tokio::test]
    async fn test_commit_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();

        conn.commit("tx1").await.expect("commit failed");

        verify_transaction_frame(next_outbound_frame(&mut out_rx).await, "COMMIT", "tx1");
    }

    #[tokio::test]
    async fn test_abort_transaction_sends_frame() {
        let (conn, mut out_rx) = setup_test_connection();

        conn.abort("tx1").await.expect("abort failed");

        verify_transaction_frame(next_outbound_frame(&mut out_rx).await, "ABORT", "tx1");
    }
}