Skip to main content

queue_runtime/providers/
azure.rs

1//! Azure Service Bus provider implementation.
2//!
3//! This module provides production-ready Azure Service Bus integration with:
4//! - Native session support for ordered message processing
5//! - Connection pooling and sender/receiver caching
6//! - Dead letter queue integration
7//! - Multiple authentication methods (connection string, managed identity, client secret)
8//! - Comprehensive error classification for retry logic
9//!
10//! ## Authentication Methods
11//!
12//! The provider supports four authentication methods:
13//! - **ConnectionString**: Direct connection string with embedded credentials
14//! - **ManagedIdentity**: Azure Managed Identity for serverless environments
15//! - **ClientSecret**: Service principal with tenant/client ID and secret
16//! - **DefaultCredential**: Default Azure credential chain for development
17//!
18//! ## Session Management
19//!
20//! Azure Service Bus provides native session support with:
21//! - Strict FIFO ordering within session boundaries
22//! - Exclusive session locks during processing
23//! - Automatic lock renewal for long-running operations
24//! - Session state storage for stateful processing
25//!
26//! ## Example
27//!
28//! ```no_run
29//! use queue_runtime::{QueueClientFactory, QueueConfig, ProviderConfig, AzureServiceBusConfig, AzureAuthMethod};
30//! use chrono::Duration;
31//!
32//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
33//! let config = QueueConfig {
34//!     provider: ProviderConfig::AzureServiceBus(AzureServiceBusConfig {
35//!         connection_string: Some("Endpoint=sb://...".to_string()),
36//!         namespace: None,
37//!         auth_method: AzureAuthMethod::ConnectionString,
38//!         use_sessions: true,
39//!         session_timeout: Duration::minutes(5),
40//!     }),
41//!     ..Default::default()
42//! };
43//!
44//! let client = QueueClientFactory::create_client(config).await?;
45//! # Ok(())
46//! # }
47//! ```
48
49use crate::client::{QueueProvider, SessionProvider};
50use crate::error::{ConfigurationError, QueueError, SerializationError};
51use crate::message::{
52    Message, MessageId, QueueName, ReceiptHandle, ReceivedMessage, SessionId, Timestamp,
53};
54use crate::provider::{AzureServiceBusConfig, ProviderType, SessionSupport};
55use async_trait::async_trait;
56use azure_core::credentials::Secret as AzureSecret;
57use azure_core::credentials::TokenCredential;
58use azure_identity::{
59    ClientSecretCredential, ClientSecretCredentialOptions, DeveloperToolsCredential,
60    ManagedIdentityCredential,
61};
62use chrono::{Duration, Utc};
63use reqwest::{header, Client as HttpClient, StatusCode};
64use serde::{Deserialize, Serialize};
65use std::collections::{HashMap, HashSet};
66use std::fmt;
67use std::str::FromStr;
68use std::sync::Arc;
69use tokio::sync::RwLock;
70
71#[cfg(test)]
72#[path = "azure_tests.rs"]
73mod tests;
74
75// ============================================================================
76// Shared Auth Helper
77// ============================================================================
78
79/// Acquire a bearer token from an AAD [`TokenCredential`] for the Azure Service Bus scope.
80///
81/// # Errors
82///
83/// Returns [`AzureError::AuthenticationError`] when the credential provider fails.
84async fn get_bearer_token(
85    cred: &(dyn TokenCredential + Send + Sync),
86) -> Result<String, AzureError> {
87    let scopes = &["https://servicebus.azure.net/.default"];
88    let token = cred
89        .get_token(scopes, None)
90        .await
91        .map_err(|e| AzureError::AuthenticationError(format!("Failed to get token: {}", e)))?;
92    // token is an AccessToken (outer struct); .token is its Secret<String> field; .secret() extracts the raw string.
93    Ok(token.token.secret().to_string())
94}
95
96/// Generate a Shared Access Signature (SAS) token for Azure Service Bus.
97///
98/// Parses `SharedAccessKeyName` and `SharedAccessKey` from `conn_str`, then
99/// produces an HMAC-SHA256 signature over `namespace_url` and the expiry.
100/// The resulting token is valid for one hour from the moment of generation.
101///
102/// # Errors
103///
104/// Returns [`AzureError::AuthenticationError`] if the connection string is
105/// missing the required fields or if the key cannot be decoded.
106fn generate_sas_token(namespace_url: &str, conn_str: &str) -> Result<String, AzureError> {
107    let mut key_name = None;
108    let mut key = None;
109
110    for part in conn_str.split(';') {
111        if let Some(value) = part.strip_prefix("SharedAccessKeyName=") {
112            key_name = Some(value.to_string());
113        } else if let Some(value) = part.strip_prefix("SharedAccessKey=") {
114            key = Some(value.to_string());
115        }
116    }
117
118    let key_name = key_name.ok_or_else(|| {
119        AzureError::AuthenticationError(
120            "Missing SharedAccessKeyName in connection string".to_string(),
121        )
122    })?;
123    let key = key.ok_or_else(|| {
124        AzureError::AuthenticationError("Missing SharedAccessKey in connection string".to_string())
125    })?;
126
127    let expiry = (Utc::now() + Duration::hours(1)).timestamp();
128    let string_to_sign = format!("{}\n{}", urlencoding::encode(namespace_url), expiry);
129
130    use base64::{engine::general_purpose::STANDARD, Engine};
131    use hmac::{Hmac, KeyInit, Mac};
132    use sha2::Sha256;
133    type HmacSha256 = Hmac<Sha256>;
134
135    let key_bytes = STANDARD
136        .decode(&key)
137        .map_err(|e| AzureError::AuthenticationError(format!("Invalid SharedAccessKey: {}", e)))?;
138    let mut mac = HmacSha256::new_from_slice(&key_bytes)
139        .map_err(|e| AzureError::AuthenticationError(format!("Failed to create HMAC: {}", e)))?;
140    mac.update(string_to_sign.as_bytes());
141    let signature = STANDARD.encode(mac.finalize().into_bytes());
142
143    Ok(format!(
144        "SharedAccessSignature sr={}&sig={}&se={}&skn={}",
145        urlencoding::encode(namespace_url),
146        urlencoding::encode(&signature),
147        expiry,
148        urlencoding::encode(&key_name)
149    ))
150}
151
152// ============================================================================
153// Authentication Types
154// ============================================================================
155
156/// Authentication method for Azure Service Bus
///
/// `Debug` is implemented by hand elsewhere in this file so that
/// `client_secret` is printed as `<REDACTED>`.
///
/// NOTE(review): the derived `Serialize` still writes `client_secret` in plain
/// text — confirm this is intended wherever the configuration is serialized.
157#[derive(Clone, Serialize, Deserialize)]
158pub enum AzureAuthMethod {
159    /// Connection string with embedded credentials
160    ConnectionString,
161    /// Azure Managed Identity (for serverless environments)
162    ManagedIdentity,
163    /// Service principal with client secret
164    ClientSecret {
        /// Tenant the service principal belongs to; must be non-empty.
165        tenant_id: String,
        /// Client (application) ID of the service principal; must be non-empty.
166        client_id: String,
        /// Secret of the service principal; must be non-empty.
167        client_secret: String,
168    },
169    /// Default Azure credential chain (for development)
170    DefaultCredential,
171}
172
173impl fmt::Debug for AzureAuthMethod {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        match self {
176            Self::ConnectionString => f.debug_struct("ConnectionString").finish(),
177            Self::ManagedIdentity => f.debug_struct("ManagedIdentity").finish(),
178            Self::ClientSecret {
179                tenant_id,
180                client_id,
181                ..
182            } => f
183                .debug_struct("ClientSecret")
184                .field("tenant_id", tenant_id)
185                .field("client_id", client_id)
186                .field("client_secret", &"<REDACTED>")
187                .finish(),
188            Self::DefaultCredential => f.debug_struct("DefaultCredential").finish(),
189        }
190    }
191}
192
193impl fmt::Display for AzureAuthMethod {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        match self {
196            Self::ConnectionString => write!(f, "ConnectionString"),
197            Self::ManagedIdentity => write!(f, "ManagedIdentity"),
198            Self::ClientSecret { .. } => write!(f, "ClientSecret"),
199            Self::DefaultCredential => write!(f, "DefaultCredential"),
200        }
201    }
202}
203
204// ============================================================================
205// Error Types
206// ============================================================================
207
208/// Azure Service Bus specific errors
///
/// Each variant carries a free-form text payload. `is_transient` reports
/// whether a variant is worth retrying and `to_queue_error` converts it into
/// the provider-independent `QueueError`.
209#[derive(Debug, thiserror::Error)]
210pub enum AzureError {
    /// Token acquisition or SAS token generation failed. Not retried.
211    #[error("Authentication failed: {0}")]
212    AuthenticationError(String),
213
    /// The HTTP client could not be built or an HTTP request failed. Retried.
214    #[error("Network error: {0}")]
215    NetworkError(String),
216
    /// Error reported by Service Bus itself. Treated as transient.
217    #[error("Service Bus error: {0}")]
218    ServiceBusError(String),
219
    /// Message lock lost; the payload becomes the `receipt` of
    /// `QueueError::InvalidReceipt`. Not retried.
220    #[error("Message lock lost: {0}")]
221    MessageLockLost(String),
222
    /// Session lock lost; the payload is used as the `session_id` of
    /// `QueueError::SessionNotFound`. Not retried.
223    #[error("Session lock lost: {0}")]
224    SessionLockLost(String),
225
    /// The provider configuration is incomplete or invalid. Not retried.
226    #[error("Invalid configuration: {0}")]
227    ConfigurationError(String),
228
    /// A payload could not be serialized or deserialized. Not retried.
229    #[error("Serialization error: {0}")]
230    SerializationError(String),
231}
232
233impl AzureError {
234    /// Check if error is transient and should be retried
235    pub fn is_transient(&self) -> bool {
236        match self {
237            Self::AuthenticationError(_) => false,
238            Self::NetworkError(_) => true,
239            Self::ServiceBusError(_) => true, // Most Service Bus errors are transient
240            Self::MessageLockLost(_) => false,
241            Self::SessionLockLost(_) => false,
242            Self::ConfigurationError(_) => false,
243            Self::SerializationError(_) => false,
244        }
245    }
246
247    /// Map Azure error to QueueError
248    pub fn to_queue_error(self) -> QueueError {
249        match self {
250            Self::AuthenticationError(msg) => QueueError::AuthenticationFailed { message: msg },
251            Self::NetworkError(msg) => QueueError::ConnectionFailed { message: msg },
252            Self::ServiceBusError(msg) => QueueError::ProviderError {
253                provider: "AzureServiceBus".to_string(),
254                code: "ServiceBusError".to_string(),
255                message: msg,
256            },
257            Self::MessageLockLost(msg) => QueueError::InvalidReceipt { receipt: msg },
258            Self::SessionLockLost(session_id) => QueueError::SessionNotFound { session_id },
259            Self::ConfigurationError(msg) => {
260                QueueError::ConfigurationError(ConfigurationError::Invalid { message: msg })
261            }
262            Self::SerializationError(msg) => QueueError::SerializationError(
263                SerializationError::JsonError(serde_json::Error::io(std::io::Error::new(
264                    std::io::ErrorKind::InvalidData,
265                    msg,
266                ))),
267            ),
268        }
269    }
270}
271
272// ============================================================================
273// Azure Service Bus Provider
274// ============================================================================
275
276/// Azure Service Bus queue provider implementation using REST API
277///
278/// This provider implements the QueueProvider trait using Azure Service Bus REST API.
279/// It supports:
280/// - Multiple authentication methods (connection string, managed identity, service principal)
281/// - HTTP-based message operations (send, receive, complete, abandon, dead-letter)
282/// - Session support for ordered processing
283/// - Lock token management for PeekLock receive mode
284/// - Comprehensive error classification with retry logic
285pub struct AzureServiceBusProvider {
    // Configuration the provider was created from; also the source of the
    // connection string used when generating SAS tokens.
286    config: AzureServiceBusConfig,
    // Shared HTTP client, built in `new` with a 30 second request timeout.
287    http_client: HttpClient,
    // Base URL of the namespace; request URLs are formed as
    // `{namespace_url}/{queue}/...`.
288    namespace_url: String,
    // Token credential for the AAD-based auth methods; `None` for
    // connection-string auth, in which case a SAS token is generated instead.
289    credential: Option<Arc<dyn TokenCredential + Send + Sync>>,
290    // Cached lock tokens: receipt_handle -> (lock_token, queue_name)
291    lock_tokens: Arc<RwLock<HashMap<String, (String, String)>>>,
292}
293
294impl fmt::Debug for AzureServiceBusProvider {
295    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296        f.debug_struct("AzureServiceBusProvider")
297            .field("config", &self.config)
298            .field("namespace_url", &self.namespace_url)
299            .field(
300                "credential",
301                &self.credential.as_ref().map(|_| "<TokenCredential>"),
302            )
303            .field("lock_tokens", &self.lock_tokens)
304            .finish()
305    }
306}
307
308impl AzureServiceBusProvider {
309    /// Create new Azure Service Bus provider
310    ///
311    /// # Arguments
312    ///
313    /// * `config` - Azure Service Bus configuration with authentication details
314    ///
315    /// # Errors
316    ///
317    /// Returns error if:
318    /// - Connection string is invalid
319    /// - Authentication fails
320    /// - Namespace is not accessible
321    ///
322    /// # Example
323    ///
324    /// ```no_run
325    /// use queue_runtime::{AzureServiceBusConfig, AzureAuthMethod};
326    /// use queue_runtime::providers::AzureServiceBusProvider;
327    /// use chrono::Duration;
328    ///
329    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
330    /// let config = AzureServiceBusConfig {
331    ///     connection_string: Some("Endpoint=sb://...".to_string()),
332    ///     namespace: None,
333    ///     auth_method: AzureAuthMethod::ConnectionString,
334    ///     use_sessions: true,
335    ///     session_timeout: Duration::minutes(5),
336    /// };
337    ///
338    /// let provider = AzureServiceBusProvider::new(config).await?;
339    /// # Ok(())
340    /// # }
341    /// ```
342    pub async fn new(config: AzureServiceBusConfig) -> Result<Self, AzureError> {
343        // Validate configuration
344        Self::validate_config(&config)?;
345
346        // Extract namespace URL and setup authentication
347        let (namespace_url, credential) = match &config.auth_method {
348            AzureAuthMethod::ConnectionString => {
349                let conn_str = config.connection_string.as_ref().ok_or_else(|| {
350                    AzureError::ConfigurationError(
351                        "Connection string required for ConnectionString auth".to_string(),
352                    )
353                })?;
354
355                let namespace_url = Self::parse_connection_string_endpoint(conn_str)?;
356                (namespace_url, None)
357            }
358            AzureAuthMethod::ManagedIdentity => {
359                let namespace = config.namespace.as_ref().ok_or_else(|| {
360                    AzureError::ConfigurationError(
361                        "Namespace required for ManagedIdentity auth".to_string(),
362                    )
363                })?;
364
365                let credential = ManagedIdentityCredential::new(None).map_err(|e| {
366                    AzureError::ConfigurationError(format!(
367                        "Failed to create managed identity credential: {}",
368                        e
369                    ))
370                })?;
371                let namespace_url = format!("https://{}.servicebus.windows.net", namespace);
372                (
373                    namespace_url,
374                    Some(credential as Arc<dyn TokenCredential + Send + Sync>),
375                )
376            }
377            AzureAuthMethod::ClientSecret {
378                ref tenant_id,
379                ref client_id,
380                ref client_secret,
381            } => {
382                let namespace = config.namespace.as_ref().ok_or_else(|| {
383                    AzureError::ConfigurationError(
384                        "Namespace required for ClientSecret auth".to_string(),
385                    )
386                })?;
387
388                // Create ClientSecretCredential with new API
389                let credential = ClientSecretCredential::new(
390                    tenant_id,
391                    client_id.clone(),
392                    AzureSecret::from(client_secret.clone()),
393                    None::<ClientSecretCredentialOptions>,
394                )
395                .map_err(|e| {
396                    AzureError::ConfigurationError(format!("Failed to create credential: {}", e))
397                })?;
398                let namespace_url = format!("https://{}.servicebus.windows.net", namespace);
399                (
400                    namespace_url,
401                    Some(credential as Arc<dyn TokenCredential + Send + Sync>),
402                )
403            }
404            AzureAuthMethod::DefaultCredential => {
405                let namespace = config.namespace.as_ref().ok_or_else(|| {
406                    AzureError::ConfigurationError(
407                        "Namespace required for DefaultCredential auth".to_string(),
408                    )
409                })?;
410
411                // Use DeveloperToolsCredential (Azure CLI → azd chain) for local development.
412                // In production workloads, prefer the explicit ManagedIdentity variant.
413                let credential = DeveloperToolsCredential::new(None).map_err(|e| {
414                    AzureError::ConfigurationError(format!(
415                        "Failed to create developer tools credential: {}",
416                        e
417                    ))
418                })?;
419                let namespace_url = format!("https://{}.servicebus.windows.net", namespace);
420                (
421                    namespace_url,
422                    Some(credential as Arc<dyn TokenCredential + Send + Sync>),
423                )
424            }
425        };
426
427        // Create HTTP client
428        let http_client = HttpClient::builder()
429            .timeout(std::time::Duration::from_secs(30))
430            .build()
431            .map_err(|e| {
432                AzureError::NetworkError(format!("Failed to create HTTP client: {}", e))
433            })?;
434
435        Ok(Self {
436            config,
437            http_client,
438            namespace_url,
439            credential,
440            lock_tokens: Arc::new(RwLock::new(HashMap::new())),
441        })
442    }
443
444    /// Parse endpoint from connection string
445    fn parse_connection_string_endpoint(conn_str: &str) -> Result<String, AzureError> {
446        for part in conn_str.split(';') {
447            if let Some(endpoint) = part.strip_prefix("Endpoint=") {
448                return Ok(endpoint.trim_end_matches('/').to_string());
449            }
450        }
451        Err(AzureError::ConfigurationError(
452            "Invalid connection string: missing Endpoint".to_string(),
453        ))
454    }
455
456    /// Validate Azure Service Bus configuration
457    fn validate_config(config: &AzureServiceBusConfig) -> Result<(), AzureError> {
458        match &config.auth_method {
459            AzureAuthMethod::ConnectionString => {
460                if config.connection_string.is_none() {
461                    return Err(AzureError::ConfigurationError(
462                        "Connection string required for ConnectionString auth method".to_string(),
463                    ));
464                }
465            }
466            AzureAuthMethod::ManagedIdentity | AzureAuthMethod::DefaultCredential => {
467                if config.namespace.is_none() {
468                    return Err(AzureError::ConfigurationError(
469                        "Namespace required for ManagedIdentity/DefaultCredential auth".to_string(),
470                    ));
471                }
472            }
473            AzureAuthMethod::ClientSecret {
474                tenant_id,
475                client_id,
476                client_secret,
477            } => {
478                if config.namespace.is_none() {
479                    return Err(AzureError::ConfigurationError(
480                        "Namespace required for ClientSecret auth".to_string(),
481                    ));
482                }
483                if tenant_id.is_empty() || client_id.is_empty() || client_secret.is_empty() {
484                    return Err(AzureError::ConfigurationError(
485                        "Tenant ID, Client ID, and Client Secret required for ClientSecret auth"
486                            .to_string(),
487                    ));
488                }
489            }
490        }
491
492        Ok(())
493    }
494
495    /// Get authentication token for Service Bus operations
496    async fn get_auth_token(&self) -> Result<String, AzureError> {
497        match &self.credential {
498            Some(cred) => get_bearer_token(cred.as_ref()).await,
499            None => {
500                // Connection string auth - parse SharedAccessSignature
501                self.get_sas_token()
502            }
503        }
504    }
505
506    /// Extract SAS token from connection string.
507    ///
508    /// Delegates to the module-level [`generate_sas_token`] helper.
509    fn get_sas_token(&self) -> Result<String, AzureError> {
510        let conn_str = self.config.connection_string.as_ref().ok_or_else(|| {
511            AzureError::AuthenticationError("No connection string available".to_string())
512        })?;
513        generate_sas_token(&self.namespace_url, conn_str)
514    }
515}
516
517// ============================================================================
518// Azure Service Bus REST API Types
519// ============================================================================
520
521/// Message body for sending messages
522#[derive(Debug, Serialize, Deserialize)]
523struct ServiceBusMessageBody {
524    #[serde(rename = "ContentType")]
525    content_type: String,
526    #[serde(rename = "Body")]
527    body: String, // Base64-encoded
528    #[serde(rename = "BrokerProperties")]
529    broker_properties: BrokerProperties,
530}
531
532#[derive(Debug, Serialize, Deserialize)]
533struct BrokerProperties {
534    #[serde(rename = "MessageId")]
535    message_id: String,
536    #[serde(rename = "SessionId", skip_serializing_if = "Option::is_none")]
537    session_id: Option<String>,
538    #[serde(rename = "TimeToLive", skip_serializing_if = "Option::is_none")]
539    time_to_live: Option<u64>,
540}
541
542/// Batch receive response structure
543#[derive(Debug, Deserialize)]
544struct ServiceBusMessageResponse {
545    #[serde(rename = "Body")]
546    body: String,
547    #[serde(rename = "BrokerProperties")]
548    broker_properties: ReceivedBrokerProperties,
549}
550
551#[allow(dead_code)] // Used when receive operations are implemented
552#[derive(Debug, Deserialize)]
553struct ReceivedServiceBusMessage {
554    #[serde(rename = "Body")]
555    body: String,
556    #[serde(rename = "BrokerProperties")]
557    broker_properties: ReceivedBrokerProperties,
558}
559
560#[allow(dead_code)] // Used when receive operations are implemented
561#[derive(Debug, Deserialize)]
562struct ReceivedBrokerProperties {
563    #[serde(rename = "MessageId")]
564    message_id: String,
565    #[serde(rename = "SessionId")]
566    session_id: Option<String>,
567    #[serde(rename = "LockToken")]
568    lock_token: String,
569    #[serde(rename = "DeliveryCount")]
570    delivery_count: u32,
571    #[serde(rename = "EnqueuedTimeUtc")]
572    enqueued_time_utc: String,
573}
574
575// ============================================================================
576// QueueProvider Implementation
577// ============================================================================
578
579#[async_trait]
580impl QueueProvider for AzureServiceBusProvider {
581    async fn send_message(
582        &self,
583        queue: &QueueName,
584        message: &Message,
585    ) -> Result<MessageId, QueueError> {
586        // Generate message ID
587        let message_id = MessageId::new();
588
589        // Serialize message body (it's already Bytes, just base64 encode it)
590        use base64::{engine::general_purpose::STANDARD, Engine};
591        let body_base64 = STANDARD.encode(&message.body);
592
593        // Build broker properties
594        let broker_props = BrokerProperties {
595            message_id: message_id.to_string(),
596            session_id: message.session_id.as_ref().map(|s| s.to_string()),
597            time_to_live: message
598                .time_to_live
599                .as_ref()
600                .map(|ttl| ttl.num_seconds() as u64),
601        };
602
603        // Build URL: {namespace}/{queue}/messages
604        let url = format!("{}/{}/messages", self.namespace_url, queue.as_str());
605
606        // Get auth token
607        let auth_token = self
608            .get_auth_token()
609            .await
610            .map_err(|e| e.to_queue_error())?;
611
612        // Send HTTP POST request
613        let response = self
614            .http_client
615            .post(&url)
616            .header(header::AUTHORIZATION, auth_token)
617            .header(
618                header::CONTENT_TYPE,
619                "application/atom+xml;type=entry;charset=utf-8",
620            )
621            .header(
622                "BrokerProperties",
623                serde_json::to_string(&broker_props).unwrap(),
624            )
625            .body(body_base64)
626            .send()
627            .await
628            .map_err(|e| {
629                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
630            })?;
631
632        // Check response status
633        match response.status() {
634            StatusCode::CREATED | StatusCode::OK => Ok(message_id),
635            status => {
636                let error_body = response.text().await.unwrap_or_default();
637                Err(QueueError::ProviderError {
638                    provider: "AzureServiceBus".to_string(),
639                    code: status.as_str().to_string(),
640                    message: format!("Send failed: {}", error_body),
641                })
642            }
643        }
644    }
645
646    async fn send_messages(
647        &self,
648        queue: &QueueName,
649        messages: &[Message],
650    ) -> Result<Vec<MessageId>, QueueError> {
651        // Azure Service Bus supports batch send (max 100 messages)
652        if messages.len() > 100 {
653            return Err(QueueError::BatchTooLarge {
654                size: messages.len(),
655                max_size: 100,
656            });
657        }
658
659        if messages.is_empty() {
660            return Ok(Vec::new());
661        }
662
663        // Build batch request body - array of messages
664        let mut batch_messages = Vec::with_capacity(messages.len());
665        let mut message_ids = Vec::with_capacity(messages.len());
666
667        use base64::{engine::general_purpose::STANDARD, Engine};
668
669        for message in messages {
670            let message_id = MessageId::new();
671            let body_base64 = STANDARD.encode(&message.body);
672
673            let broker_props = BrokerProperties {
674                message_id: message_id.to_string(),
675                session_id: message.session_id.as_ref().map(|s| s.to_string()),
676                time_to_live: message
677                    .time_to_live
678                    .as_ref()
679                    .map(|ttl| ttl.num_seconds() as u64),
680            };
681
682            batch_messages.push(ServiceBusMessageBody {
683                content_type: "application/octet-stream".to_string(),
684                body: body_base64,
685                broker_properties: broker_props,
686            });
687
688            message_ids.push(message_id);
689        }
690
691        // Build URL: {namespace}/{queue}/messages
692        let url = format!("{}/{}/messages", self.namespace_url, queue.as_str());
693
694        // Get auth token
695        let auth_token = self
696            .get_auth_token()
697            .await
698            .map_err(|e| e.to_queue_error())?;
699
700        // Send batch HTTP POST request with JSON array
701        let response = self
702            .http_client
703            .post(&url)
704            .header(header::AUTHORIZATION, auth_token)
705            .header(header::CONTENT_TYPE, "application/json")
706            .json(&batch_messages)
707            .send()
708            .await
709            .map_err(|e| {
710                AzureError::NetworkError(format!("Batch send HTTP request failed: {}", e))
711                    .to_queue_error()
712            })?;
713
714        // Check response status
715        match response.status() {
716            StatusCode::CREATED | StatusCode::OK => Ok(message_ids),
717            StatusCode::PAYLOAD_TOO_LARGE => Err(QueueError::BatchTooLarge {
718                size: messages.len(),
719                max_size: 100,
720            }),
721            StatusCode::TOO_MANY_REQUESTS => {
722                let retry_after = response
723                    .headers()
724                    .get("Retry-After")
725                    .and_then(|v| v.to_str().ok())
726                    .and_then(|s| s.parse::<u64>().ok())
727                    .unwrap_or(30);
728
729                Err(QueueError::ProviderError {
730                    provider: "AzureServiceBus".to_string(),
731                    code: "ThrottlingError".to_string(),
732                    message: format!("Request throttled, retry after {} seconds", retry_after),
733                })
734            }
735            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
736                let error_body = response.text().await.unwrap_or_default();
737                Err(QueueError::AuthenticationFailed {
738                    message: format!("Authentication failed: {}", error_body),
739                })
740            }
741            status => {
742                let error_body = response.text().await.unwrap_or_default();
743                Err(QueueError::ProviderError {
744                    provider: "AzureServiceBus".to_string(),
745                    code: status.as_str().to_string(),
746                    message: format!("Batch send failed: {}", error_body),
747                })
748            }
749        }
750    }
751
752    async fn receive_message(
753        &self,
754        queue: &QueueName,
755        timeout: Duration,
756    ) -> Result<Option<ReceivedMessage>, QueueError> {
757        // Azure Service Bus receive uses HTTP DELETE with peek-lock
758        // URL: {namespace}/{queue}/messages/head?timeout={seconds}
759        let url = format!(
760            "{}/{}/messages/head?timeout={}",
761            self.namespace_url,
762            queue.as_str(),
763            timeout.num_seconds()
764        );
765
766        // Get auth token
767        let auth_token = self
768            .get_auth_token()
769            .await
770            .map_err(|e| e.to_queue_error())?;
771
772        // Send HTTP DELETE request (peek-lock mode)
773        let response = self
774            .http_client
775            .delete(&url)
776            .header(header::AUTHORIZATION, auth_token)
777            .send()
778            .await
779            .map_err(|e| {
780                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
781            })?;
782
783        // Check response status
784        match response.status() {
785            StatusCode::OK | StatusCode::CREATED => {
786                // Parse BrokerProperties from response header
787                let broker_props = response
788                    .headers()
789                    .get("BrokerProperties")
790                    .and_then(|v| v.to_str().ok())
791                    .and_then(|s| serde_json::from_str::<ReceivedBrokerProperties>(s).ok())
792                    .ok_or_else(|| QueueError::ProviderError {
793                        provider: "AzureServiceBus".to_string(),
794                        code: "InvalidResponse".to_string(),
795                        message: "Missing or invalid BrokerProperties header".to_string(),
796                    })?;
797
798                // Get message body (base64 encoded)
799                let body_base64 = response.text().await.map_err(|e| {
800                    AzureError::NetworkError(format!("Failed to read response body: {}", e))
801                        .to_queue_error()
802                })?;
803
804                // Decode base64 body
805                use base64::{engine::general_purpose::STANDARD, Engine};
806                let body =
807                    STANDARD
808                        .decode(&body_base64)
809                        .map_err(|e| QueueError::ProviderError {
810                            provider: "AzureServiceBus".to_string(),
811                            code: "DecodingError".to_string(),
812                            message: format!("Failed to decode message body: {}", e),
813                        })?;
814
815                // Parse enqueued time
816                let first_delivered_at =
817                    chrono::DateTime::parse_from_rfc3339(&broker_props.enqueued_time_utc)
818                        .map(|dt| Timestamp::from_datetime(dt.with_timezone(&chrono::Utc)))
819                        .unwrap_or_else(|_| Timestamp::now());
820
821                // Create receipt handle combining lock token and queue name
822                // Lock expires in 30 seconds by default (Azure Service Bus default)
823                let expires_at = Timestamp::from_datetime(Utc::now() + Duration::seconds(30));
824                let receipt_str = format!("{}::{}", broker_props.lock_token, queue.as_str());
825                let receipt = ReceiptHandle::new(
826                    receipt_str.clone(),
827                    expires_at,
828                    ProviderType::AzureServiceBus,
829                );
830
831                // Store lock token for later acknowledgment
832                self.lock_tokens.write().await.insert(
833                    receipt_str,
834                    (broker_props.lock_token.clone(), queue.as_str().to_string()),
835                );
836
837                // Parse Azure message ID
838                let message_id = MessageId::from_str(&broker_props.message_id)
839                    .unwrap_or_else(|_| MessageId::new());
840
841                // Create received message
842                let received_message = ReceivedMessage {
843                    message_id,
844                    body: bytes::Bytes::from(body),
845                    attributes: HashMap::new(),
846                    session_id: broker_props.session_id.map(SessionId::new).transpose()?,
847                    correlation_id: None,
848                    receipt_handle: receipt,
849                    delivery_count: broker_props.delivery_count,
850                    first_delivered_at,
851                    delivered_at: Timestamp::now(),
852                };
853
854                Ok(Some(received_message))
855            }
856            StatusCode::NO_CONTENT => {
857                // No messages available
858                Ok(None)
859            }
860            status => {
861                let error_body = response.text().await.unwrap_or_default();
862                Err(QueueError::ProviderError {
863                    provider: "AzureServiceBus".to_string(),
864                    code: status.as_str().to_string(),
865                    message: format!("Receive failed: {}", error_body),
866                })
867            }
868        }
869    }
870
871    async fn receive_messages(
872        &self,
873        queue: &QueueName,
874        max_messages: u32,
875        timeout: Duration,
876    ) -> Result<Vec<ReceivedMessage>, QueueError> {
877        // Azure Service Bus max batch receive is 32 messages
878        if max_messages > 32 {
879            return Err(QueueError::BatchTooLarge {
880                size: max_messages as usize,
881                max_size: 32,
882            });
883        }
884
885        if max_messages == 0 {
886            return Ok(Vec::new());
887        }
888
889        // Build URL with maxMessageCount parameter for batch receive
890        // {namespace}/{queue}/messages/head?timeout={seconds}&maxMessageCount={count}
891        let url = format!(
892            "{}/{}/messages/head?timeout={}&maxMessageCount={}",
893            self.namespace_url,
894            queue.as_str(),
895            timeout.num_seconds(),
896            max_messages
897        );
898
899        // Get auth token
900        let auth_token = self
901            .get_auth_token()
902            .await
903            .map_err(|e| e.to_queue_error())?;
904
905        // Receive messages using HTTP DELETE (PeekLock mode)
906        let response = self
907            .http_client
908            .delete(&url)
909            .header(header::AUTHORIZATION, auth_token)
910            .send()
911            .await
912            .map_err(|e| {
913                AzureError::NetworkError(format!("Batch receive HTTP request failed: {}", e))
914                    .to_queue_error()
915            })?;
916
917        // Parse response
918        match response.status() {
919            StatusCode::OK | StatusCode::CREATED => {
920                // Parse JSON array response
921                let messages_data: Vec<ServiceBusMessageResponse> =
922                    response.json().await.map_err(|e| {
923                        AzureError::SerializationError(format!(
924                            "Failed to parse batch receive response: {}",
925                            e
926                        ))
927                        .to_queue_error()
928                    })?;
929
930                let mut received_messages = Vec::with_capacity(messages_data.len());
931
932                use base64::{engine::general_purpose::STANDARD, Engine};
933
934                for msg_data in messages_data {
935                    let broker_props = msg_data.broker_properties;
936
937                    // Decode base64 body
938                    let body = STANDARD.decode(&msg_data.body).map_err(|e| {
939                        AzureError::SerializationError(format!(
940                            "Failed to decode message body: {}",
941                            e
942                        ))
943                        .to_queue_error()
944                    })?;
945
946                    // Parse enqueued time
947                    let enqueued_time =
948                        chrono::DateTime::parse_from_rfc3339(&broker_props.enqueued_time_utc)
949                            .map_err(|e| {
950                                AzureError::SerializationError(format!(
951                                    "Failed to parse enqueued time: {}",
952                                    e
953                                ))
954                                .to_queue_error()
955                            })?;
956                    let first_delivered_at =
957                        Timestamp::from_datetime(enqueued_time.with_timezone(&Utc));
958
959                    // Create receipt handle with lock expiration (30s default)
960                    let expires_at = Timestamp::from_datetime(Utc::now() + Duration::seconds(30));
961                    let receipt_str = format!("{}::{}", broker_props.lock_token, queue.as_str());
962                    let receipt = ReceiptHandle::new(
963                        receipt_str.clone(),
964                        expires_at,
965                        ProviderType::AzureServiceBus,
966                    );
967
968                    // Store lock token for acknowledgment
969                    self.lock_tokens.write().await.insert(
970                        receipt_str,
971                        (broker_props.lock_token.clone(), queue.as_str().to_string()),
972                    );
973
974                    // Parse Azure message ID
975                    let message_id = MessageId::from_str(&broker_props.message_id)
976                        .unwrap_or_else(|_| MessageId::new());
977
978                    // Create received message
979                    let received_message = ReceivedMessage {
980                        message_id,
981                        body: bytes::Bytes::from(body),
982                        attributes: HashMap::new(),
983                        session_id: broker_props.session_id.map(SessionId::new).transpose()?,
984                        correlation_id: None,
985                        receipt_handle: receipt,
986                        delivery_count: broker_props.delivery_count,
987                        first_delivered_at,
988                        delivered_at: Timestamp::now(),
989                    };
990
991                    received_messages.push(received_message);
992                }
993
994                Ok(received_messages)
995            }
996            StatusCode::NO_CONTENT => {
997                // No messages available
998                Ok(Vec::new())
999            }
1000            StatusCode::TOO_MANY_REQUESTS => {
1001                let retry_after = response
1002                    .headers()
1003                    .get("Retry-After")
1004                    .and_then(|v| v.to_str().ok())
1005                    .and_then(|s| s.parse::<u64>().ok())
1006                    .unwrap_or(30);
1007
1008                Err(QueueError::ProviderError {
1009                    provider: "AzureServiceBus".to_string(),
1010                    code: "ThrottlingError".to_string(),
1011                    message: format!("Request throttled, retry after {} seconds", retry_after),
1012                })
1013            }
1014            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
1015                let error_body = response.text().await.unwrap_or_default();
1016                Err(QueueError::AuthenticationFailed {
1017                    message: format!("Authentication failed: {}", error_body),
1018                })
1019            }
1020            status => {
1021                let error_body = response.text().await.unwrap_or_default();
1022                Err(QueueError::ProviderError {
1023                    provider: "AzureServiceBus".to_string(),
1024                    code: status.as_str().to_string(),
1025                    message: format!("Batch receive failed: {}", error_body),
1026                })
1027            }
1028        }
1029    }
1030
1031    async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1032        // Extract lock token and queue name from receipt handle
1033        let lock_tokens = self.lock_tokens.read().await;
1034        let (lock_token, queue_name) =
1035            lock_tokens
1036                .get(receipt.handle())
1037                .ok_or_else(|| QueueError::InvalidReceipt {
1038                    receipt: receipt.handle().to_string(),
1039                })?;
1040
1041        // Azure Service Bus complete uses HTTP DELETE to {namespace}/{queue}/messages/{messageId}/{lockToken}
1042        let url = format!(
1043            "{}/{}/messages/head/{}",
1044            self.namespace_url,
1045            queue_name,
1046            urlencoding::encode(lock_token)
1047        );
1048
1049        // Get auth token
1050        let auth_token = self
1051            .get_auth_token()
1052            .await
1053            .map_err(|e| e.to_queue_error())?;
1054
1055        // Send HTTP DELETE request
1056        let response = self
1057            .http_client
1058            .delete(&url)
1059            .header(header::AUTHORIZATION, auth_token)
1060            .send()
1061            .await
1062            .map_err(|e| {
1063                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
1064            })?;
1065
1066        // Check response status
1067        match response.status() {
1068            StatusCode::OK | StatusCode::NO_CONTENT => {
1069                // Remove lock token from cache
1070                drop(lock_tokens);
1071                self.lock_tokens.write().await.remove(receipt.handle());
1072                Ok(())
1073            }
1074            StatusCode::GONE | StatusCode::NOT_FOUND => {
1075                // Lock expired or message already processed
1076                Err(QueueError::InvalidReceipt {
1077                    receipt: receipt.handle().to_string(),
1078                })
1079            }
1080            status => {
1081                let error_body = response.text().await.unwrap_or_default();
1082                Err(QueueError::ProviderError {
1083                    provider: "AzureServiceBus".to_string(),
1084                    code: status.as_str().to_string(),
1085                    message: format!("Complete failed: {}", error_body),
1086                })
1087            }
1088        }
1089    }
1090
1091    async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1092        // Extract lock token and queue name from receipt handle
1093        let lock_tokens = self.lock_tokens.read().await;
1094        let (lock_token, queue_name) =
1095            lock_tokens
1096                .get(receipt.handle())
1097                .ok_or_else(|| QueueError::InvalidReceipt {
1098                    receipt: receipt.handle().to_string(),
1099                })?;
1100
1101        // Azure Service Bus abandon uses HTTP PUT to {namespace}/{queue}/messages/{messageId}/{lockToken}
1102        // with empty body to unlock the message
1103        let url = format!(
1104            "{}/{}/messages/head/{}",
1105            self.namespace_url,
1106            queue_name,
1107            urlencoding::encode(lock_token)
1108        );
1109
1110        // Get auth token
1111        let auth_token = self
1112            .get_auth_token()
1113            .await
1114            .map_err(|e| e.to_queue_error())?;
1115
1116        // Send HTTP PUT request with empty body to abandon
1117        let response = self
1118            .http_client
1119            .put(&url)
1120            .header(header::AUTHORIZATION, auth_token)
1121            .header(header::CONTENT_LENGTH, "0")
1122            .send()
1123            .await
1124            .map_err(|e| {
1125                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
1126            })?;
1127
1128        // Check response status
1129        match response.status() {
1130            StatusCode::OK | StatusCode::NO_CONTENT => {
1131                // Remove lock token from cache
1132                drop(lock_tokens);
1133                self.lock_tokens.write().await.remove(receipt.handle());
1134                Ok(())
1135            }
1136            StatusCode::GONE | StatusCode::NOT_FOUND => {
1137                // Lock expired or message already processed
1138                Err(QueueError::InvalidReceipt {
1139                    receipt: receipt.handle().to_string(),
1140                })
1141            }
1142            status => {
1143                let error_body = response.text().await.unwrap_or_default();
1144                Err(QueueError::ProviderError {
1145                    provider: "AzureServiceBus".to_string(),
1146                    code: status.as_str().to_string(),
1147                    message: format!("Abandon failed: {}", error_body),
1148                })
1149            }
1150        }
1151    }
1152
1153    async fn dead_letter_message(
1154        &self,
1155        receipt: &ReceiptHandle,
1156        reason: &str,
1157    ) -> Result<(), QueueError> {
1158        // Extract lock token and queue name from receipt handle
1159        let lock_tokens = self.lock_tokens.read().await;
1160        let (lock_token, queue_name) =
1161            lock_tokens
1162                .get(receipt.handle())
1163                .ok_or_else(|| QueueError::InvalidReceipt {
1164                    receipt: receipt.handle().to_string(),
1165                })?;
1166
1167        // Azure Service Bus dead letter uses HTTP DELETE to {namespace}/{queue}/messages/{messageId}/{lockToken}
1168        // with custom properties in the DeadLetterReason header
1169        let url = format!(
1170            "{}/{}/messages/head/{}/$deadletter",
1171            self.namespace_url,
1172            queue_name,
1173            urlencoding::encode(lock_token)
1174        );
1175
1176        // Get auth token
1177        let auth_token = self
1178            .get_auth_token()
1179            .await
1180            .map_err(|e| e.to_queue_error())?;
1181
1182        // Build dead letter properties as JSON
1183        let properties = serde_json::json!({
1184            "DeadLetterReason": reason,
1185            "DeadLetterErrorDescription": "Message processing failed"
1186        });
1187
1188        // Send HTTP POST request to dead letter
1189        let response = self
1190            .http_client
1191            .post(&url)
1192            .header(header::AUTHORIZATION, auth_token)
1193            .header(header::CONTENT_TYPE, "application/json")
1194            .json(&properties)
1195            .send()
1196            .await
1197            .map_err(|e| {
1198                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
1199            })?;
1200
1201        // Check response status
1202        match response.status() {
1203            StatusCode::OK | StatusCode::NO_CONTENT | StatusCode::CREATED => {
1204                // Remove lock token from cache
1205                drop(lock_tokens);
1206                self.lock_tokens.write().await.remove(receipt.handle());
1207                Ok(())
1208            }
1209            StatusCode::GONE | StatusCode::NOT_FOUND => {
1210                // Lock expired or message already processed
1211                Err(QueueError::InvalidReceipt {
1212                    receipt: receipt.handle().to_string(),
1213                })
1214            }
1215            status => {
1216                let error_body = response.text().await.unwrap_or_default();
1217                Err(QueueError::ProviderError {
1218                    provider: "AzureServiceBus".to_string(),
1219                    code: status.as_str().to_string(),
1220                    message: format!("Dead letter failed: {}", error_body),
1221                })
1222            }
1223        }
1224    }
1225
1226    async fn create_session_client(
1227        &self,
1228        queue: &QueueName,
1229        session_id: Option<SessionId>,
1230    ) -> Result<Box<dyn SessionProvider>, QueueError> {
1231        let resolved_id = match session_id {
1232            Some(id) => id,
1233            None => self.accept_next_available_session(queue).await?,
1234        };
1235
1236        Ok(Box::new(AzureSessionProvider::new(
1237            resolved_id,
1238            queue.clone(),
1239            self.config.session_timeout,
1240            self.http_client.clone(),
1241            self.namespace_url.clone(),
1242            self.config.clone(),
1243            self.credential.clone(),
1244        )))
1245    }
1246
    /// Identifies this provider as `ProviderType::AzureServiceBus`.
    fn provider_type(&self) -> ProviderType {
        ProviderType::AzureServiceBus
    }
1250
    /// Reports session support as `SessionSupport::Native`.
    fn supports_sessions(&self) -> SessionSupport {
        SessionSupport::Native
    }
1254
    /// Always `true`: this provider advertises batch support.
    fn supports_batching(&self) -> bool {
        true
    }
1258
    /// Maximum batch size advertised by this provider.
    ///
    /// This is the batch *send* limit; `receive_messages` separately rejects
    /// requests for more than 32 messages.
    fn max_batch_size(&self) -> u32 {
        100 // Azure Service Bus max batch send
    }
1262}
1263
1264impl AzureServiceBusProvider {
1265    /// Accept the next available session by receiving the first available message
1266    /// from the queue and deriving the session ID from its broker properties.
1267    ///
1268    /// The Azure Service Bus REST API does **not** have an atomic
1269    /// "accept-next-session" endpoint (unlike the AMQP SDK). Enumerating
1270    /// sessions via GET and then receiving from one introduces a TOCTOU race:
1271    /// two concurrent consumers can read the same session ID and collide.
1272    ///
1273    /// To avoid the race this implementation calls
1274    /// `DELETE {namespace}/{queue}/sessions/$acceptnext/messages/head`, which is
1275    /// the undocumented but well-established REST shorthand supported by the
1276    /// Azure Service Bus broker for atomically accepting the next available
1277    /// session. The session ID is taken from the `BrokerProperties.SessionId`
1278    /// header in the response.
1279    ///
1280    /// # Errors
1281    ///
1282    /// - `QueueError::ProviderError { code: "NoSessionsAvailable" }` when no
1283    ///   session has pending messages or the timeout expires.
1284    /// - `QueueError::QueueNotFound` when the queue does not exist.
1285    /// - Network or auth errors on failure.
1286    async fn accept_next_available_session(
1287        &self,
1288        queue: &QueueName,
1289    ) -> Result<SessionId, QueueError> {
1290        // `$acceptnext` is the REST equivalent of AcceptNextSessionAsync in the SDK:
1291        // the broker atomically picks and locks the next session with pending messages.
1292        let timeout_secs = self.config.session_timeout.num_seconds().max(1);
1293        let url = format!(
1294            "{}/{}/sessions/$acceptnext/messages/head?timeout={}",
1295            self.namespace_url,
1296            queue.as_str(),
1297            timeout_secs
1298        );
1299
1300        let auth_token = self
1301            .get_auth_token()
1302            .await
1303            .map_err(|e| e.to_queue_error())?;
1304
1305        let response = self
1306            .http_client
1307            .delete(&url)
1308            .header(header::AUTHORIZATION, auth_token)
1309            .send()
1310            .await
1311            .map_err(|e| {
1312                AzureError::NetworkError(format!("Failed to accept next session: {}", e))
1313                    .to_queue_error()
1314            })?;
1315
1316        match response.status() {
1317            StatusCode::OK | StatusCode::CREATED => {
1318                // Parse BrokerProperties from response header to get the session ID.
1319                let broker_props = response
1320                    .headers()
1321                    .get("BrokerProperties")
1322                    .and_then(|v| v.to_str().ok())
1323                    .and_then(|s| serde_json::from_str::<ReceivedBrokerProperties>(s).ok())
1324                    .ok_or_else(|| QueueError::ProviderError {
1325                        provider: "AzureServiceBus".to_string(),
1326                        code: "InvalidResponse".to_string(),
1327                        message: "Missing BrokerProperties in accept-next-session response"
1328                            .to_string(),
1329                    })?;
1330
1331                let session_id_str =
1332                    broker_props
1333                        .session_id
1334                        .ok_or_else(|| QueueError::ProviderError {
1335                            provider: "AzureServiceBus".to_string(),
1336                            code: "NoSessionId".to_string(),
1337                            message: "Accepted message has no SessionId".to_string(),
1338                        })?;
1339
1340                SessionId::new(session_id_str).map_err(|e| QueueError::ProviderError {
1341                    provider: "AzureServiceBus".to_string(),
1342                    code: "InvalidSessionId".to_string(),
1343                    message: format!("Invalid session ID returned by broker: {}", e),
1344                })
1345            }
1346            StatusCode::NO_CONTENT => Err(QueueError::ProviderError {
1347                provider: "AzureServiceBus".to_string(),
1348                code: "NoSessionsAvailable".to_string(),
1349                message: "No sessions with pending messages are available".to_string(),
1350            }),
1351            StatusCode::NOT_FOUND => Err(QueueError::QueueNotFound {
1352                queue_name: queue.to_string(),
1353            }),
1354            status => {
1355                let error_body = response.text().await.unwrap_or_default();
1356                Err(QueueError::ProviderError {
1357                    provider: "AzureServiceBus".to_string(),
1358                    code: status.as_str().to_string(),
1359                    message: format!("Accept next session failed: {}", error_body),
1360                })
1361            }
1362        }
1363    }
1364}
1365
1366// ============================================================================
1367// Azure Session Provider
1368// ============================================================================
1369
/// Azure Service Bus session provider for ordered message processing.
///
/// Implements the [`SessionProvider`] trait using the Azure Service Bus REST API,
/// providing exclusive session-locked access to messages within a single session.
/// All messages within a session are delivered in strict FIFO order.
///
/// ## Session Lifecycle
///
/// 1. Obtain via [`AzureServiceBusProvider::create_session_client`].
/// 2. Call [`receive_message`](SessionProvider::receive_message) to fetch the next message.
/// 3. Process the message, then call [`complete_message`](SessionProvider::complete_message)
///    or [`abandon_message`](SessionProvider::abandon_message).
/// 4. Call [`renew_session_lock`](SessionProvider::renew_session_lock) periodically for
///    long-running processing to prevent session lock expiry.
/// 5. Call [`close_session`](SessionProvider::close_session) when finished.
pub struct AzureSessionProvider {
    /// The session this provider operates on.
    session_id: SessionId,
    /// The queue containing the session.
    queue_name: QueueName,
    /// Locally tracked session lock expiry. Uses `std::sync::RwLock` so the
    /// synchronous `session_expires_at()` trait method can read without async.
    session_expires_at: Arc<std::sync::RwLock<Timestamp>>,
    /// HTTP client used for all REST calls made by this session.
    http_client: HttpClient,
    /// Base URL of the Service Bus namespace; prefix of every request URL.
    namespace_url: String,
    /// Provider configuration: supplies the connection string for SAS token
    /// generation and the `session_timeout` used when refreshing the expiry.
    config: AzureServiceBusConfig,
    /// Token credential for bearer-token auth. When `None`, a SAS token is
    /// generated from the connection string in `config` instead.
    credential: Option<Arc<dyn TokenCredential + Send + Sync>>,
    /// Lock tokens for in-flight messages. The receipt handle IS the lock token
    /// for session messages, so a set suffices (no separate value).
    lock_tokens: Arc<RwLock<HashSet<String>>>,
}
1399
1400impl fmt::Debug for AzureSessionProvider {
1401    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1402        f.debug_struct("AzureSessionProvider")
1403            .field("session_id", &self.session_id)
1404            .field("queue_name", &self.queue_name)
1405            .field("namespace_url", &self.namespace_url)
1406            .field(
1407                "credential",
1408                &self.credential.as_ref().map(|_| "<TokenCredential>"),
1409            )
1410            .finish()
1411    }
1412}
1413
1414impl AzureSessionProvider {
1415    /// Create a new session provider.
1416    ///
1417    /// Normally obtained through [`AzureServiceBusProvider::create_session_client`]
1418    /// rather than constructed directly.
1419    ///
1420    /// # Arguments
1421    ///
1422    /// * `session_id` - The session to operate on.
1423    /// * `queue_name` - The queue containing the session.
1424    /// * `session_timeout` - How long the session lock is expected to be held; used to
1425    ///   compute `session_expires_at` and refreshed on each receive and lock renewal.
1426    /// * `http_client` - Shared HTTP client (cloned from the parent provider).
1427    /// * `namespace_url` - Base URL of the Service Bus namespace.
1428    /// * `config` - Provider configuration (used for SAS token generation).
1429    /// * `credential` - Optional token credential for AAD-based auth.
1430    pub fn new(
1431        session_id: SessionId,
1432        queue_name: QueueName,
1433        session_timeout: Duration,
1434        http_client: HttpClient,
1435        namespace_url: String,
1436        config: AzureServiceBusConfig,
1437        credential: Option<Arc<dyn TokenCredential + Send + Sync>>,
1438    ) -> Self {
1439        let session_expires_at = Timestamp::from_datetime(Utc::now() + session_timeout);
1440        Self {
1441            session_id,
1442            queue_name,
1443            session_expires_at: Arc::new(std::sync::RwLock::new(session_expires_at)),
1444            http_client,
1445            namespace_url,
1446            config,
1447            credential,
1448            lock_tokens: Arc::new(RwLock::new(HashSet::new())),
1449        }
1450    }
1451
1452    /// Get an authentication token for Service Bus REST operations.
1453    ///
1454    /// Delegates to [`get_bearer_token`] for AAD credentials and [`generate_sas_token`] for SAS.
1455    async fn get_auth_token(&self) -> Result<String, AzureError> {
1456        match &self.credential {
1457            Some(cred) => get_bearer_token(cred.as_ref()).await,
1458            None => {
1459                let conn_str = self.config.connection_string.as_ref().ok_or_else(|| {
1460                    AzureError::AuthenticationError("No connection string available".to_string())
1461                })?;
1462                generate_sas_token(&self.namespace_url, conn_str)
1463            }
1464        }
1465    }
1466
1467    /// Refresh the local session expiry to `now + session_timeout`.
1468    fn refresh_session_expiry(&self) {
1469        if let Ok(mut expiry) = self.session_expires_at.write() {
1470            *expiry = Timestamp::from_datetime(Utc::now() + self.config.session_timeout);
1471        }
1472    }
1473}
1474
1475#[async_trait]
1476impl SessionProvider for AzureSessionProvider {
1477    /// Receive the next message from the session using PeekLock mode.
1478    ///
1479    /// Calls `DELETE {namespace}/{queue}/sessions/{sessionId}/messages/head?timeout={t}`.
1480    /// On success the session lock expiry is refreshed and the message lock token is
1481    /// stored internally so that [`complete_message`](Self::complete_message),
1482    /// [`abandon_message`](Self::abandon_message), and
1483    /// [`dead_letter_message`](Self::dead_letter_message) can resolve the token by
1484    /// receipt handle.
1485    ///
1486    /// # Errors
1487    ///
1488    /// - `QueueError::SessionNotFound` – the session no longer exists or the lock expired.
1489    /// - `QueueError::ProviderError` – network or broker error.
1490    async fn receive_message(
1491        &self,
1492        timeout: Duration,
1493    ) -> Result<Option<ReceivedMessage>, QueueError> {
1494        let url = format!(
1495            "{}/{}/sessions/{}/messages/head?timeout={}",
1496            self.namespace_url,
1497            self.queue_name.as_str(),
1498            urlencoding::encode(self.session_id.as_str()),
1499            timeout.num_seconds()
1500        );
1501
1502        let auth_token = self
1503            .get_auth_token()
1504            .await
1505            .map_err(|e| e.to_queue_error())?;
1506
1507        let response = self
1508            .http_client
1509            .delete(&url)
1510            .header(header::AUTHORIZATION, auth_token)
1511            .send()
1512            .await
1513            .map_err(|e| {
1514                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
1515            })?;
1516
1517        match response.status() {
1518            StatusCode::OK | StatusCode::CREATED => {
1519                let broker_props = response
1520                    .headers()
1521                    .get("BrokerProperties")
1522                    .and_then(|v| v.to_str().ok())
1523                    .and_then(|s| serde_json::from_str::<ReceivedBrokerProperties>(s).ok())
1524                    .ok_or_else(|| QueueError::ProviderError {
1525                        provider: "AzureServiceBus".to_string(),
1526                        code: "InvalidResponse".to_string(),
1527                        message: "Missing or invalid BrokerProperties header".to_string(),
1528                    })?;
1529
1530                let body_base64 = response.text().await.map_err(|e| {
1531                    AzureError::NetworkError(format!("Failed to read response body: {}", e))
1532                        .to_queue_error()
1533                })?;
1534
1535                use base64::{engine::general_purpose::STANDARD, Engine};
1536                let body =
1537                    STANDARD
1538                        .decode(&body_base64)
1539                        .map_err(|e| QueueError::ProviderError {
1540                            provider: "AzureServiceBus".to_string(),
1541                            code: "DecodingError".to_string(),
1542                            message: format!("Failed to decode message body: {}", e),
1543                        })?;
1544
1545                let first_delivered_at =
1546                    chrono::DateTime::parse_from_rfc3339(&broker_props.enqueued_time_utc)
1547                        .map(|dt| Timestamp::from_datetime(dt.with_timezone(&chrono::Utc)))
1548                        .unwrap_or_else(|_| Timestamp::now());
1549
1550                // Receipt handle is the lock token; store it for later acknowledgement.
1551                // Use config.session_timeout as the ReceiptHandle local expiry to match
1552                // the session lock duration configured on the provider.
1553                let expires_at = Timestamp::from_datetime(Utc::now() + self.config.session_timeout);
1554                let lock_token = broker_props.lock_token.clone();
1555                let receipt = ReceiptHandle::new(
1556                    lock_token.clone(),
1557                    expires_at,
1558                    ProviderType::AzureServiceBus,
1559                );
1560
1561                self.lock_tokens.write().await.insert(lock_token);
1562
1563                let message_id = MessageId::from_str(&broker_props.message_id)
1564                    .unwrap_or_else(|_| MessageId::new());
1565
1566                // Keep session lock alive.
1567                self.refresh_session_expiry();
1568
1569                Ok(Some(ReceivedMessage {
1570                    message_id,
1571                    body: bytes::Bytes::from(body),
1572                    attributes: HashMap::new(),
1573                    session_id: Some(self.session_id.clone()),
1574                    correlation_id: None,
1575                    receipt_handle: receipt,
1576                    delivery_count: broker_props.delivery_count,
1577                    first_delivered_at,
1578                    delivered_at: Timestamp::now(),
1579                }))
1580            }
1581            StatusCode::NO_CONTENT => Ok(None),
1582            StatusCode::GONE | StatusCode::NOT_FOUND => Err(QueueError::SessionNotFound {
1583                session_id: self.session_id.to_string(),
1584            }),
1585            status => {
1586                let error_body = response.text().await.unwrap_or_default();
1587                Err(QueueError::ProviderError {
1588                    provider: "AzureServiceBus".to_string(),
1589                    code: status.as_str().to_string(),
1590                    message: format!("Session receive failed: {}", error_body),
1591                })
1592            }
1593        }
1594    }
1595
1596    /// Complete (delete) a session message using its lock token.
1597    ///
1598    /// Calls `DELETE {namespace}/{queue}/sessions/{sessionId}/messages/{lockToken}`.
1599    ///
1600    /// # Errors
1601    ///
1602    /// - `QueueError::InvalidReceipt` – receipt not found locally or lock expired on broker.
1603    async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1604        if !self.lock_tokens.read().await.contains(receipt.handle()) {
1605            return Err(QueueError::InvalidReceipt {
1606                receipt: receipt.handle().to_string(),
1607            });
1608        }
1609        let lock_token = receipt.handle().to_string();
1610
1611        let url = format!(
1612            "{}/{}/sessions/{}/messages/{}",
1613            self.namespace_url,
1614            self.queue_name.as_str(),
1615            urlencoding::encode(self.session_id.as_str()),
1616            urlencoding::encode(&lock_token)
1617        );
1618
1619        let auth_token = self
1620            .get_auth_token()
1621            .await
1622            .map_err(|e| e.to_queue_error())?;
1623
1624        let response = self
1625            .http_client
1626            .delete(&url)
1627            .header(header::AUTHORIZATION, auth_token)
1628            .send()
1629            .await
1630            .map_err(|e| {
1631                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
1632            })?;
1633
1634        match response.status() {
1635            StatusCode::OK | StatusCode::NO_CONTENT => {
1636                self.lock_tokens.write().await.remove(receipt.handle());
1637                Ok(())
1638            }
1639            StatusCode::GONE | StatusCode::NOT_FOUND => Err(QueueError::InvalidReceipt {
1640                receipt: receipt.handle().to_string(),
1641            }),
1642            status => {
1643                let error_body = response.text().await.unwrap_or_default();
1644                Err(QueueError::ProviderError {
1645                    provider: "AzureServiceBus".to_string(),
1646                    code: status.as_str().to_string(),
1647                    message: format!("Session complete failed: {}", error_body),
1648                })
1649            }
1650        }
1651    }
1652
1653    /// Abandon a session message and make it available for re-delivery.
1654    ///
1655    /// Calls `PUT {namespace}/{queue}/sessions/{sessionId}/messages/{lockToken}`.
1656    ///
1657    /// # Errors
1658    ///
1659    /// - `QueueError::InvalidReceipt` – receipt not found locally or lock expired.
1660    async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1661        if !self.lock_tokens.read().await.contains(receipt.handle()) {
1662            return Err(QueueError::InvalidReceipt {
1663                receipt: receipt.handle().to_string(),
1664            });
1665        }
1666        let lock_token = receipt.handle().to_string();
1667
1668        let url = format!(
1669            "{}/{}/sessions/{}/messages/{}",
1670            self.namespace_url,
1671            self.queue_name.as_str(),
1672            urlencoding::encode(self.session_id.as_str()),
1673            urlencoding::encode(&lock_token)
1674        );
1675
1676        let auth_token = self
1677            .get_auth_token()
1678            .await
1679            .map_err(|e| e.to_queue_error())?;
1680
1681        let response = self
1682            .http_client
1683            .put(&url)
1684            .header(header::AUTHORIZATION, auth_token)
1685            .header(header::CONTENT_LENGTH, "0")
1686            .send()
1687            .await
1688            .map_err(|e| {
1689                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
1690            })?;
1691
1692        match response.status() {
1693            StatusCode::OK | StatusCode::NO_CONTENT => {
1694                self.lock_tokens.write().await.remove(receipt.handle());
1695                Ok(())
1696            }
1697            StatusCode::GONE | StatusCode::NOT_FOUND => Err(QueueError::InvalidReceipt {
1698                receipt: receipt.handle().to_string(),
1699            }),
1700            status => {
1701                let error_body = response.text().await.unwrap_or_default();
1702                Err(QueueError::ProviderError {
1703                    provider: "AzureServiceBus".to_string(),
1704                    code: status.as_str().to_string(),
1705                    message: format!("Session abandon failed: {}", error_body),
1706                })
1707            }
1708        }
1709    }
1710
1711    /// Dead-letter a session message.
1712    ///
1713    /// Calls `POST {namespace}/{queue}/sessions/{sessionId}/messages/{lockToken}/$deadletter`
1714    /// with a JSON body containing `DeadLetterReason`.
1715    ///
1716    /// # Errors
1717    ///
1718    /// - `QueueError::InvalidReceipt` – receipt not found locally or lock expired.
1719    async fn dead_letter_message(
1720        &self,
1721        receipt: &ReceiptHandle,
1722        reason: &str,
1723    ) -> Result<(), QueueError> {
1724        if !self.lock_tokens.read().await.contains(receipt.handle()) {
1725            return Err(QueueError::InvalidReceipt {
1726                receipt: receipt.handle().to_string(),
1727            });
1728        }
1729        let lock_token = receipt.handle().to_string();
1730
1731        let url = format!(
1732            "{}/{}/sessions/{}/messages/{}/$deadletter",
1733            self.namespace_url,
1734            self.queue_name.as_str(),
1735            urlencoding::encode(self.session_id.as_str()),
1736            urlencoding::encode(&lock_token)
1737        );
1738
1739        let auth_token = self
1740            .get_auth_token()
1741            .await
1742            .map_err(|e| e.to_queue_error())?;
1743
1744        let properties = serde_json::json!({
1745            "DeadLetterReason": reason,
1746            "DeadLetterErrorDescription": "Message processing failed"
1747        });
1748
1749        let response = self
1750            .http_client
1751            .post(&url)
1752            .header(header::AUTHORIZATION, auth_token)
1753            .header(header::CONTENT_TYPE, "application/json")
1754            .json(&properties)
1755            .send()
1756            .await
1757            .map_err(|e| {
1758                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
1759            })?;
1760
1761        match response.status() {
1762            StatusCode::OK | StatusCode::NO_CONTENT | StatusCode::CREATED => {
1763                self.lock_tokens.write().await.remove(receipt.handle());
1764                Ok(())
1765            }
1766            StatusCode::GONE | StatusCode::NOT_FOUND => Err(QueueError::InvalidReceipt {
1767                receipt: receipt.handle().to_string(),
1768            }),
1769            status => {
1770                let error_body = response.text().await.unwrap_or_default();
1771                Err(QueueError::ProviderError {
1772                    provider: "AzureServiceBus".to_string(),
1773                    code: status.as_str().to_string(),
1774                    message: format!("Session dead letter failed: {}", error_body),
1775                })
1776            }
1777        }
1778    }
1779
1780    /// Renew the session lock to extend the exclusive hold on the session.
1781    ///
1782    /// Calls `POST {namespace}/{queue}/sessions/{sessionId}/renewlock`.
1783    /// On success the local `session_expires_at` is refreshed.
1784    ///
1785    /// # Errors
1786    ///
1787    /// - `QueueError::SessionNotFound` – the session lock has already expired.
1788    async fn renew_session_lock(&self) -> Result<(), QueueError> {
1789        let url = format!(
1790            "{}/{}/sessions/{}/renewlock",
1791            self.namespace_url,
1792            self.queue_name.as_str(),
1793            urlencoding::encode(self.session_id.as_str())
1794        );
1795
1796        let auth_token = self
1797            .get_auth_token()
1798            .await
1799            .map_err(|e| e.to_queue_error())?;
1800
1801        let response = self
1802            .http_client
1803            .post(&url)
1804            .header(header::AUTHORIZATION, auth_token)
1805            .header(header::CONTENT_LENGTH, "0")
1806            .send()
1807            .await
1808            .map_err(|e| {
1809                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
1810            })?;
1811
1812        match response.status() {
1813            StatusCode::OK | StatusCode::NO_CONTENT => {
1814                self.refresh_session_expiry();
1815                Ok(())
1816            }
1817            StatusCode::GONE | StatusCode::NOT_FOUND => Err(QueueError::SessionNotFound {
1818                session_id: self.session_id.to_string(),
1819            }),
1820            status => {
1821                let error_body = response.text().await.unwrap_or_default();
1822                Err(QueueError::ProviderError {
1823                    provider: "AzureServiceBus".to_string(),
1824                    code: status.as_str().to_string(),
1825                    message: format!("Session lock renewal failed: {}", error_body),
1826                })
1827            }
1828        }
1829    }
1830
1831    /// Release local session state.
1832    ///
1833    /// Clears all locally cached message lock tokens. The Azure Service Bus
1834    /// REST API has no endpoint to release a session lock before it expires;
1835    /// the broker releases the lock automatically after the session timeout
1836    /// configured on the queue entity (typically 30 s – 5 min). For workloads
1837    /// that need immediate hand-off, configure a shorter session lock duration
1838    /// on the queue entity or use the AMQP-based SDK which supports explicit
1839    /// session release.
1840    async fn close_session(&self) -> Result<(), QueueError> {
1841        self.lock_tokens.write().await.clear();
1842        Ok(())
1843    }
1844
1845    fn session_id(&self) -> &SessionId {
1846        &self.session_id
1847    }
1848
1849    fn session_expires_at(&self) -> Timestamp {
1850        self.session_expires_at
1851            .read()
1852            .map(|guard| *guard)
1853            .unwrap_or_else(|_| {
1854                // Lock is poisoned (a writer panicked). Return an already-expired
1855                // sentinel so callers treat the session as invalid rather than
1856                // silently assuming it just started.
1857                Timestamp::from_datetime(Utc::now() - Duration::seconds(1))
1858            })
1859    }
1860}