use crate::client::{QueueProvider, SessionProvider};
use crate::error::{ConfigurationError, QueueError, SerializationError};
use crate::message::{
Message, MessageId, QueueName, ReceiptHandle, ReceivedMessage, SessionId, Timestamp,
};
use crate::provider::{AzureServiceBusConfig, ProviderType, SessionSupport};
use async_trait::async_trait;
use azure_core::credentials::Secret as AzureSecret;
use azure_core::credentials::TokenCredential;
use azure_identity::{
ClientSecretCredential, ClientSecretCredentialOptions, DeveloperToolsCredential,
ManagedIdentityCredential,
};
use chrono::{Duration, Utc};
use reqwest::{header, Client as HttpClient, StatusCode};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::RwLock;
#[cfg(test)]
#[path = "azure_tests.rs"]
mod tests;
/// Acquires an Azure AD access token for the Service Bus data plane and returns
/// it as a ready-to-use `Authorization` header value (`Bearer <token>`).
///
/// # Errors
/// Returns [`AzureError::AuthenticationError`] if the credential cannot issue a
/// token.
async fn get_bearer_token(
    cred: &(dyn TokenCredential + Send + Sync),
) -> Result<String, AzureError> {
    // `.default` scope for the Service Bus resource.
    let scopes = &["https://servicebus.azure.net/.default"];
    let token = cred
        .get_token(scopes, None)
        .await
        .map_err(|e| AzureError::AuthenticationError(format!("Failed to get token: {}", e)))?;
    // Callers place this string directly into the `Authorization` header, and the
    // Service Bus REST API only accepts Azure AD tokens with the `Bearer ` scheme.
    Ok(format!("Bearer {}", token.token.secret()))
}
/// Builds a Service Bus Shared Access Signature for `namespace_url` from the
/// `SharedAccessKeyName` / `SharedAccessKey` pair found in `conn_str`.
///
/// The token expires one hour from now. The string to sign is
/// `<url-encoded resource URI>\n<expiry>`, and the HMAC-SHA256 key is the UTF-8
/// text of the shared access key exactly as it appears in the connection string.
///
/// # Errors
/// Returns [`AzureError::AuthenticationError`] when either key field is missing
/// from the connection string or the HMAC cannot be initialised.
fn generate_sas_token(namespace_url: &str, conn_str: &str) -> Result<String, AzureError> {
    use base64::{engine::general_purpose::STANDARD, Engine};
    use hmac::{Hmac, KeyInit, Mac};
    use sha2::Sha256;
    type HmacSha256 = Hmac<Sha256>;

    let mut key_name = None;
    let mut key = None;
    for part in conn_str.split(';') {
        // `SharedAccessKeyName=` must be tested first: it shares a prefix with
        // `SharedAccessKey`, but only the latter is followed directly by `=`.
        if let Some(value) = part.strip_prefix("SharedAccessKeyName=") {
            key_name = Some(value.to_string());
        } else if let Some(value) = part.strip_prefix("SharedAccessKey=") {
            key = Some(value.to_string());
        }
    }
    let key_name = key_name.ok_or_else(|| {
        AzureError::AuthenticationError(
            "Missing SharedAccessKeyName in connection string".to_string(),
        )
    })?;
    let key = key.ok_or_else(|| {
        AzureError::AuthenticationError("Missing SharedAccessKey in connection string".to_string())
    })?;

    let expiry = (Utc::now() + Duration::hours(1)).timestamp();
    let encoded_resource = urlencoding::encode(namespace_url);
    let string_to_sign = format!("{}\n{}", encoded_resource, expiry);

    // Service Bus signs with the key's raw text bytes. Base64-decoding it first
    // (as some other Azure services do) yields a signature the broker rejects.
    let mut mac = HmacSha256::new_from_slice(key.as_bytes())
        .map_err(|e| AzureError::AuthenticationError(format!("Failed to create HMAC: {}", e)))?;
    mac.update(string_to_sign.as_bytes());
    let signature = STANDARD.encode(mac.finalize().into_bytes());

    Ok(format!(
        "SharedAccessSignature sr={}&sig={}&se={}&skn={}",
        encoded_resource,
        urlencoding::encode(&signature),
        expiry,
        urlencoding::encode(&key_name)
    ))
}
/// Selects how `AzureServiceBusProvider` authenticates against Service Bus.
///
/// `Debug` and `Display` are implemented by hand and never print
/// `client_secret`. NOTE(review): the derived `Serialize` still writes
/// `client_secret` in plain text — confirm serialized configs are handled as
/// secrets.
#[derive(Clone, Serialize, Deserialize)]
pub enum AzureAuthMethod {
    /// Sign requests with a SAS token derived from
    /// `AzureServiceBusConfig::connection_string`.
    ConnectionString,
    /// Use `ManagedIdentityCredential`; requires `AzureServiceBusConfig::namespace`.
    ManagedIdentity,
    /// Use `ClientSecretCredential`; requires `namespace` and non-empty fields.
    ClientSecret {
        /// Azure AD tenant ID.
        tenant_id: String,
        /// Application (client) ID.
        client_id: String,
        /// Client secret; redacted in `Debug` and omitted from `Display`.
        client_secret: String,
    },
    /// Use `DeveloperToolsCredential`; requires `AzureServiceBusConfig::namespace`.
    DefaultCredential,
}
impl fmt::Debug for AzureAuthMethod {
    // Diagnostic formatting that masks the client secret.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let variant_name = match self {
            Self::ClientSecret {
                tenant_id,
                client_id,
                ..
            } => {
                // Only identifiers are shown; the secret is replaced by a marker.
                return f
                    .debug_struct("ClientSecret")
                    .field("tenant_id", tenant_id)
                    .field("client_id", client_id)
                    .field("client_secret", &"<REDACTED>")
                    .finish();
            }
            Self::ConnectionString => "ConnectionString",
            Self::ManagedIdentity => "ManagedIdentity",
            Self::DefaultCredential => "DefaultCredential",
        };
        // Field-less variants render as their bare name.
        f.debug_struct(variant_name).finish()
    }
}
impl fmt::Display for AzureAuthMethod {
    // User-facing label: just the variant name, never any credential data.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::ConnectionString => "ConnectionString",
            Self::ManagedIdentity => "ManagedIdentity",
            Self::ClientSecret { .. } => "ClientSecret",
            Self::DefaultCredential => "DefaultCredential",
        };
        f.write_str(label)
    }
}
/// Errors raised inside the Azure Service Bus provider.
///
/// Converted to the crate-wide [`QueueError`] via `AzureError::to_queue_error`;
/// `AzureError::is_transient` classifies which variants are worth retrying.
#[derive(Debug, thiserror::Error)]
pub enum AzureError {
    /// Token or SAS acquisition failed.
    #[error("Authentication failed: {0}")]
    AuthenticationError(String),
    /// Transport-level failure (request could not be sent or read).
    #[error("Network error: {0}")]
    NetworkError(String),
    /// Error reported by the Service Bus service.
    #[error("Service Bus error: {0}")]
    ServiceBusError(String),
    /// The lock on a received message is no longer valid.
    #[error("Message lock lost: {0}")]
    MessageLockLost(String),
    /// The session lock is no longer valid; the payload is the session ID.
    #[error("Session lock lost: {0}")]
    SessionLockLost(String),
    /// Invalid or incomplete provider configuration.
    #[error("Invalid configuration: {0}")]
    ConfigurationError(String),
    /// A payload could not be encoded or decoded.
    #[error("Serialization error: {0}")]
    SerializationError(String),
}
impl AzureError {
pub fn is_transient(&self) -> bool {
match self {
Self::AuthenticationError(_) => false,
Self::NetworkError(_) => true,
Self::ServiceBusError(_) => true, Self::MessageLockLost(_) => false,
Self::SessionLockLost(_) => false,
Self::ConfigurationError(_) => false,
Self::SerializationError(_) => false,
}
}
pub fn to_queue_error(self) -> QueueError {
match self {
Self::AuthenticationError(msg) => QueueError::AuthenticationFailed { message: msg },
Self::NetworkError(msg) => QueueError::ConnectionFailed { message: msg },
Self::ServiceBusError(msg) => QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: "ServiceBusError".to_string(),
message: msg,
},
Self::MessageLockLost(msg) => QueueError::InvalidReceipt { receipt: msg },
Self::SessionLockLost(session_id) => QueueError::SessionNotFound { session_id },
Self::ConfigurationError(msg) => {
QueueError::ConfigurationError(ConfigurationError::Invalid { message: msg })
}
Self::SerializationError(msg) => QueueError::SerializationError(
SerializationError::JsonError(serde_json::Error::io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
msg,
))),
),
}
}
}
/// [`QueueProvider`] implementation that talks to Azure Service Bus over HTTP.
pub struct AzureServiceBusProvider {
    /// Configuration the provider was built from.
    config: AzureServiceBusConfig,
    /// Shared HTTP client (built in `new` with a 30-second request timeout).
    http_client: HttpClient,
    /// Base URL of the namespace; entity paths are appended to it.
    namespace_url: String,
    /// Azure AD credential; `None` means requests are signed with SAS tokens
    /// derived from the connection string.
    credential: Option<Arc<dyn TokenCredential + Send + Sync>>,
    /// Outstanding receipts: receipt handle string -> (broker lock token, queue name).
    lock_tokens: Arc<RwLock<HashMap<String, (String, String)>>>,
}
impl fmt::Debug for AzureServiceBusProvider {
    // The credential is an opaque trait object, so only its presence is shown;
    // the HTTP client is left out entirely.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let credential_summary = self.credential.as_ref().map(|_| "<TokenCredential>");
        let mut out = f.debug_struct("AzureServiceBusProvider");
        out.field("config", &self.config)
            .field("namespace_url", &self.namespace_url)
            .field("credential", &credential_summary)
            .field("lock_tokens", &self.lock_tokens);
        out.finish()
    }
}
impl AzureServiceBusProvider {
    /// Builds a provider from `config`.
    ///
    /// Connection-string auth signs every request with a SAS token. All other
    /// auth methods use an Azure AD [`TokenCredential`] and address the namespace
    /// as `https://<namespace>.servicebus.windows.net`.
    ///
    /// # Errors
    /// - [`AzureError::ConfigurationError`] if the config is incomplete, the
    ///   connection string has no `Endpoint`, or a credential cannot be created.
    /// - [`AzureError::NetworkError`] if the HTTP client cannot be built.
    pub async fn new(config: AzureServiceBusConfig) -> Result<Self, AzureError> {
        Self::validate_config(&config)?;
        let (namespace_url, credential) = match &config.auth_method {
            AzureAuthMethod::ConnectionString => {
                let conn_str = config.connection_string.as_ref().ok_or_else(|| {
                    AzureError::ConfigurationError(
                        "Connection string required for ConnectionString auth".to_string(),
                    )
                })?;
                let endpoint = Self::parse_connection_string_endpoint(conn_str)?;
                // Connection strings advertise the AMQP `sb://` endpoint, but every
                // request here is plain HTTPS and reqwest rejects the `sb` scheme.
                (Self::https_endpoint(&endpoint), None)
            }
            AzureAuthMethod::ManagedIdentity => {
                let namespace_url = Self::namespace_url_for(&config, "ManagedIdentity")?;
                let credential = ManagedIdentityCredential::new(None).map_err(|e| {
                    AzureError::ConfigurationError(format!(
                        "Failed to create managed identity credential: {}",
                        e
                    ))
                })?;
                (
                    namespace_url,
                    Some(credential as Arc<dyn TokenCredential + Send + Sync>),
                )
            }
            AzureAuthMethod::ClientSecret {
                tenant_id,
                client_id,
                client_secret,
            } => {
                let namespace_url = Self::namespace_url_for(&config, "ClientSecret")?;
                let credential = ClientSecretCredential::new(
                    tenant_id,
                    client_id.clone(),
                    AzureSecret::from(client_secret.clone()),
                    None::<ClientSecretCredentialOptions>,
                )
                .map_err(|e| {
                    AzureError::ConfigurationError(format!("Failed to create credential: {}", e))
                })?;
                (
                    namespace_url,
                    Some(credential as Arc<dyn TokenCredential + Send + Sync>),
                )
            }
            AzureAuthMethod::DefaultCredential => {
                let namespace_url = Self::namespace_url_for(&config, "DefaultCredential")?;
                let credential = DeveloperToolsCredential::new(None).map_err(|e| {
                    AzureError::ConfigurationError(format!(
                        "Failed to create developer tools credential: {}",
                        e
                    ))
                })?;
                (
                    namespace_url,
                    Some(credential as Arc<dyn TokenCredential + Send + Sync>),
                )
            }
        };
        let http_client = HttpClient::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .map_err(|e| {
                AzureError::NetworkError(format!("Failed to create HTTP client: {}", e))
            })?;
        Ok(Self {
            config,
            http_client,
            namespace_url,
            credential,
            lock_tokens: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Returns `https://<namespace>.servicebus.windows.net` for the Azure AD auth
    /// methods, naming `auth_label` in the error when no namespace is configured.
    fn namespace_url_for(
        config: &AzureServiceBusConfig,
        auth_label: &str,
    ) -> Result<String, AzureError> {
        let namespace = config.namespace.as_ref().ok_or_else(|| {
            AzureError::ConfigurationError(format!("Namespace required for {} auth", auth_label))
        })?;
        Ok(format!("https://{}.servicebus.windows.net", namespace))
    }

    /// Rewrites an AMQP `sb://` endpoint (case-insensitive scheme) to `https://`.
    /// Any other URL is returned unchanged.
    fn https_endpoint(endpoint: &str) -> String {
        const AMQP_SCHEME: &str = "sb://";
        match endpoint.get(..AMQP_SCHEME.len()) {
            // `get` succeeded, so index `AMQP_SCHEME.len()` is a char boundary.
            Some(scheme) if scheme.eq_ignore_ascii_case(AMQP_SCHEME) => {
                format!("https://{}", &endpoint[AMQP_SCHEME.len()..])
            }
            _ => endpoint.to_string(),
        }
    }

    /// Extracts the `Endpoint=` value from a connection string, without a
    /// trailing `/`.
    ///
    /// # Errors
    /// [`AzureError::ConfigurationError`] if no `Endpoint=` segment exists.
    fn parse_connection_string_endpoint(conn_str: &str) -> Result<String, AzureError> {
        for part in conn_str.split(';') {
            if let Some(endpoint) = part.strip_prefix("Endpoint=") {
                return Ok(endpoint.trim_end_matches('/').to_string());
            }
        }
        Err(AzureError::ConfigurationError(
            "Invalid connection string: missing Endpoint".to_string(),
        ))
    }

    /// Checks that `config` carries every field its auth method needs.
    ///
    /// # Errors
    /// [`AzureError::ConfigurationError`] describing the first missing field.
    fn validate_config(config: &AzureServiceBusConfig) -> Result<(), AzureError> {
        match &config.auth_method {
            AzureAuthMethod::ConnectionString => {
                if config.connection_string.is_none() {
                    return Err(AzureError::ConfigurationError(
                        "Connection string required for ConnectionString auth method".to_string(),
                    ));
                }
            }
            AzureAuthMethod::ManagedIdentity | AzureAuthMethod::DefaultCredential => {
                if config.namespace.is_none() {
                    return Err(AzureError::ConfigurationError(
                        "Namespace required for ManagedIdentity/DefaultCredential auth".to_string(),
                    ));
                }
            }
            AzureAuthMethod::ClientSecret {
                tenant_id,
                client_id,
                client_secret,
            } => {
                if config.namespace.is_none() {
                    return Err(AzureError::ConfigurationError(
                        "Namespace required for ClientSecret auth".to_string(),
                    ));
                }
                if tenant_id.is_empty() || client_id.is_empty() || client_secret.is_empty() {
                    return Err(AzureError::ConfigurationError(
                        "Tenant ID, Client ID, and Client Secret required for ClientSecret auth"
                            .to_string(),
                    ));
                }
            }
        }
        Ok(())
    }

    /// Produces the `Authorization` header value: an Azure AD bearer token when a
    /// credential is configured, otherwise a SAS token.
    async fn get_auth_token(&self) -> Result<String, AzureError> {
        match &self.credential {
            Some(cred) => get_bearer_token(cred.as_ref()).await,
            None => self.get_sas_token(),
        }
    }

    /// Generates a SAS token for the namespace from the configured connection string.
    ///
    /// # Errors
    /// [`AzureError::AuthenticationError`] if there is no connection string or the
    /// token cannot be built.
    fn get_sas_token(&self) -> Result<String, AzureError> {
        let conn_str = self.config.connection_string.as_ref().ok_or_else(|| {
            AzureError::AuthenticationError("No connection string available".to_string())
        })?;
        generate_sas_token(&self.namespace_url, conn_str)
    }
}
/// One entry of the JSON array posted by `send_messages`.
#[derive(Debug, Serialize, Deserialize)]
struct ServiceBusMessageBody {
    /// Content type recorded for the entry.
    #[serde(rename = "ContentType")]
    content_type: String,
    /// Message payload, base64-encoded by the sender.
    #[serde(rename = "Body")]
    body: String,
    /// Broker metadata, serialized as a nested object.
    #[serde(rename = "BrokerProperties")]
    broker_properties: BrokerProperties,
}
/// Broker metadata for outgoing messages: sent as the `BrokerProperties` header
/// for single sends and as a nested object in batch sends.
#[derive(Debug, Serialize, Deserialize)]
struct BrokerProperties {
    /// Client-generated message ID.
    #[serde(rename = "MessageId")]
    message_id: String,
    /// Session the message belongs to; omitted when `None`.
    #[serde(rename = "SessionId", skip_serializing_if = "Option::is_none")]
    session_id: Option<String>,
    /// Time-to-live in whole seconds; omitted when `None`.
    #[serde(rename = "TimeToLive", skip_serializing_if = "Option::is_none")]
    time_to_live: Option<u64>,
}
/// One element of the JSON array that `receive_messages` parses from a batch
/// receive response.
#[derive(Debug, Deserialize)]
struct ServiceBusMessageResponse {
    /// Message payload, expected to be base64 text.
    #[serde(rename = "Body")]
    body: String,
    /// Broker metadata of the received message.
    #[serde(rename = "BrokerProperties")]
    broker_properties: ReceivedBrokerProperties,
}
// Same shape as `ServiceBusMessageResponse`.
// NOTE(review): not referenced in the visible part of this file (hence the
// `allow(dead_code)`); remove it if nothing else uses it.
#[allow(dead_code)] #[derive(Debug, Deserialize)]
struct ReceivedServiceBusMessage {
    /// Message payload, expected to be base64 text.
    #[serde(rename = "Body")]
    body: String,
    /// Broker metadata of the received message.
    #[serde(rename = "BrokerProperties")]
    broker_properties: ReceivedBrokerProperties,
}
// Broker metadata of a received message, read from the `BrokerProperties`
// response header (single receive) or the per-message JSON object (batch).
// Unknown keys in the broker's JSON are ignored by serde.
#[allow(dead_code)] #[derive(Debug, Deserialize)]
struct ReceivedBrokerProperties {
    /// Message ID assigned by the sender.
    #[serde(rename = "MessageId")]
    message_id: String,
    /// Session ID, if the message belongs to a session.
    #[serde(rename = "SessionId")]
    session_id: Option<String>,
    /// Lock token used to settle the message; required, so a response without
    /// one fails to deserialize.
    #[serde(rename = "LockToken")]
    lock_token: String,
    /// Number of delivery attempts reported by the broker.
    #[serde(rename = "DeliveryCount")]
    delivery_count: u32,
    /// Enqueue time as sent by the broker. NOTE(review): the Service Bus REST
    /// reference shows this in RFC 1123 form (`Wed, 02 Jul 2014 01:32:27 GMT`),
    /// not RFC 3339 — parse accordingly.
    #[serde(rename = "EnqueuedTimeUtc")]
    enqueued_time_utc: String,
}
#[async_trait]
impl QueueProvider for AzureServiceBusProvider {
async fn send_message(
&self,
queue: &QueueName,
message: &Message,
) -> Result<MessageId, QueueError> {
let message_id = MessageId::new();
use base64::{engine::general_purpose::STANDARD, Engine};
let body_base64 = STANDARD.encode(&message.body);
let broker_props = BrokerProperties {
message_id: message_id.to_string(),
session_id: message.session_id.as_ref().map(|s| s.to_string()),
time_to_live: message
.time_to_live
.as_ref()
.map(|ttl| ttl.num_seconds() as u64),
};
let url = format!("{}/{}/messages", self.namespace_url, queue.as_str());
let auth_token = self
.get_auth_token()
.await
.map_err(|e| e.to_queue_error())?;
let response = self
.http_client
.post(&url)
.header(header::AUTHORIZATION, auth_token)
.header(
header::CONTENT_TYPE,
"application/atom+xml;type=entry;charset=utf-8",
)
.header(
"BrokerProperties",
serde_json::to_string(&broker_props).unwrap(),
)
.body(body_base64)
.send()
.await
.map_err(|e| {
AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
})?;
match response.status() {
StatusCode::CREATED | StatusCode::OK => Ok(message_id),
status => {
let error_body = response.text().await.unwrap_or_default();
Err(QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: status.as_str().to_string(),
message: format!("Send failed: {}", error_body),
})
}
}
}
async fn send_messages(
&self,
queue: &QueueName,
messages: &[Message],
) -> Result<Vec<MessageId>, QueueError> {
if messages.len() > 100 {
return Err(QueueError::BatchTooLarge {
size: messages.len(),
max_size: 100,
});
}
if messages.is_empty() {
return Ok(Vec::new());
}
let mut batch_messages = Vec::with_capacity(messages.len());
let mut message_ids = Vec::with_capacity(messages.len());
use base64::{engine::general_purpose::STANDARD, Engine};
for message in messages {
let message_id = MessageId::new();
let body_base64 = STANDARD.encode(&message.body);
let broker_props = BrokerProperties {
message_id: message_id.to_string(),
session_id: message.session_id.as_ref().map(|s| s.to_string()),
time_to_live: message
.time_to_live
.as_ref()
.map(|ttl| ttl.num_seconds() as u64),
};
batch_messages.push(ServiceBusMessageBody {
content_type: "application/octet-stream".to_string(),
body: body_base64,
broker_properties: broker_props,
});
message_ids.push(message_id);
}
let url = format!("{}/{}/messages", self.namespace_url, queue.as_str());
let auth_token = self
.get_auth_token()
.await
.map_err(|e| e.to_queue_error())?;
let response = self
.http_client
.post(&url)
.header(header::AUTHORIZATION, auth_token)
.header(header::CONTENT_TYPE, "application/json")
.json(&batch_messages)
.send()
.await
.map_err(|e| {
AzureError::NetworkError(format!("Batch send HTTP request failed: {}", e))
.to_queue_error()
})?;
match response.status() {
StatusCode::CREATED | StatusCode::OK => Ok(message_ids),
StatusCode::PAYLOAD_TOO_LARGE => Err(QueueError::BatchTooLarge {
size: messages.len(),
max_size: 100,
}),
StatusCode::TOO_MANY_REQUESTS => {
let retry_after = response
.headers()
.get("Retry-After")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(30);
Err(QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: "ThrottlingError".to_string(),
message: format!("Request throttled, retry after {} seconds", retry_after),
})
}
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
let error_body = response.text().await.unwrap_or_default();
Err(QueueError::AuthenticationFailed {
message: format!("Authentication failed: {}", error_body),
})
}
status => {
let error_body = response.text().await.unwrap_or_default();
Err(QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: status.as_str().to_string(),
message: format!("Batch send failed: {}", error_body),
})
}
}
}
async fn receive_message(
&self,
queue: &QueueName,
timeout: Duration,
) -> Result<Option<ReceivedMessage>, QueueError> {
let url = format!(
"{}/{}/messages/head?timeout={}",
self.namespace_url,
queue.as_str(),
timeout.num_seconds()
);
let auth_token = self
.get_auth_token()
.await
.map_err(|e| e.to_queue_error())?;
let response = self
.http_client
.delete(&url)
.header(header::AUTHORIZATION, auth_token)
.send()
.await
.map_err(|e| {
AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
})?;
match response.status() {
StatusCode::OK | StatusCode::CREATED => {
let broker_props = response
.headers()
.get("BrokerProperties")
.and_then(|v| v.to_str().ok())
.and_then(|s| serde_json::from_str::<ReceivedBrokerProperties>(s).ok())
.ok_or_else(|| QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: "InvalidResponse".to_string(),
message: "Missing or invalid BrokerProperties header".to_string(),
})?;
let body_base64 = response.text().await.map_err(|e| {
AzureError::NetworkError(format!("Failed to read response body: {}", e))
.to_queue_error()
})?;
use base64::{engine::general_purpose::STANDARD, Engine};
let body =
STANDARD
.decode(&body_base64)
.map_err(|e| QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: "DecodingError".to_string(),
message: format!("Failed to decode message body: {}", e),
})?;
let first_delivered_at =
chrono::DateTime::parse_from_rfc3339(&broker_props.enqueued_time_utc)
.map(|dt| Timestamp::from_datetime(dt.with_timezone(&chrono::Utc)))
.unwrap_or_else(|_| Timestamp::now());
let expires_at = Timestamp::from_datetime(Utc::now() + Duration::seconds(30));
let receipt_str = format!("{}::{}", broker_props.lock_token, queue.as_str());
let receipt = ReceiptHandle::new(
receipt_str.clone(),
expires_at,
ProviderType::AzureServiceBus,
);
self.lock_tokens.write().await.insert(
receipt_str,
(broker_props.lock_token.clone(), queue.as_str().to_string()),
);
let message_id = MessageId::from_str(&broker_props.message_id)
.unwrap_or_else(|_| MessageId::new());
let received_message = ReceivedMessage {
message_id,
body: bytes::Bytes::from(body),
attributes: HashMap::new(),
session_id: broker_props.session_id.map(SessionId::new).transpose()?,
correlation_id: None,
receipt_handle: receipt,
delivery_count: broker_props.delivery_count,
first_delivered_at,
delivered_at: Timestamp::now(),
};
Ok(Some(received_message))
}
StatusCode::NO_CONTENT => {
Ok(None)
}
status => {
let error_body = response.text().await.unwrap_or_default();
Err(QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: status.as_str().to_string(),
message: format!("Receive failed: {}", error_body),
})
}
}
}
async fn receive_messages(
&self,
queue: &QueueName,
max_messages: u32,
timeout: Duration,
) -> Result<Vec<ReceivedMessage>, QueueError> {
if max_messages > 32 {
return Err(QueueError::BatchTooLarge {
size: max_messages as usize,
max_size: 32,
});
}
if max_messages == 0 {
return Ok(Vec::new());
}
let url = format!(
"{}/{}/messages/head?timeout={}&maxMessageCount={}",
self.namespace_url,
queue.as_str(),
timeout.num_seconds(),
max_messages
);
let auth_token = self
.get_auth_token()
.await
.map_err(|e| e.to_queue_error())?;
let response = self
.http_client
.delete(&url)
.header(header::AUTHORIZATION, auth_token)
.send()
.await
.map_err(|e| {
AzureError::NetworkError(format!("Batch receive HTTP request failed: {}", e))
.to_queue_error()
})?;
match response.status() {
StatusCode::OK | StatusCode::CREATED => {
let messages_data: Vec<ServiceBusMessageResponse> =
response.json().await.map_err(|e| {
AzureError::SerializationError(format!(
"Failed to parse batch receive response: {}",
e
))
.to_queue_error()
})?;
let mut received_messages = Vec::with_capacity(messages_data.len());
use base64::{engine::general_purpose::STANDARD, Engine};
for msg_data in messages_data {
let broker_props = msg_data.broker_properties;
let body = STANDARD.decode(&msg_data.body).map_err(|e| {
AzureError::SerializationError(format!(
"Failed to decode message body: {}",
e
))
.to_queue_error()
})?;
let enqueued_time =
chrono::DateTime::parse_from_rfc3339(&broker_props.enqueued_time_utc)
.map_err(|e| {
AzureError::SerializationError(format!(
"Failed to parse enqueued time: {}",
e
))
.to_queue_error()
})?;
let first_delivered_at =
Timestamp::from_datetime(enqueued_time.with_timezone(&Utc));
let expires_at = Timestamp::from_datetime(Utc::now() + Duration::seconds(30));
let receipt_str = format!("{}::{}", broker_props.lock_token, queue.as_str());
let receipt = ReceiptHandle::new(
receipt_str.clone(),
expires_at,
ProviderType::AzureServiceBus,
);
self.lock_tokens.write().await.insert(
receipt_str,
(broker_props.lock_token.clone(), queue.as_str().to_string()),
);
let message_id = MessageId::from_str(&broker_props.message_id)
.unwrap_or_else(|_| MessageId::new());
let received_message = ReceivedMessage {
message_id,
body: bytes::Bytes::from(body),
attributes: HashMap::new(),
session_id: broker_props.session_id.map(SessionId::new).transpose()?,
correlation_id: None,
receipt_handle: receipt,
delivery_count: broker_props.delivery_count,
first_delivered_at,
delivered_at: Timestamp::now(),
};
received_messages.push(received_message);
}
Ok(received_messages)
}
StatusCode::NO_CONTENT => {
Ok(Vec::new())
}
StatusCode::TOO_MANY_REQUESTS => {
let retry_after = response
.headers()
.get("Retry-After")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(30);
Err(QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: "ThrottlingError".to_string(),
message: format!("Request throttled, retry after {} seconds", retry_after),
})
}
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
let error_body = response.text().await.unwrap_or_default();
Err(QueueError::AuthenticationFailed {
message: format!("Authentication failed: {}", error_body),
})
}
status => {
let error_body = response.text().await.unwrap_or_default();
Err(QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: status.as_str().to_string(),
message: format!("Batch receive failed: {}", error_body),
})
}
}
}
async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
let lock_tokens = self.lock_tokens.read().await;
let (lock_token, queue_name) =
lock_tokens
.get(receipt.handle())
.ok_or_else(|| QueueError::InvalidReceipt {
receipt: receipt.handle().to_string(),
})?;
let url = format!(
"{}/{}/messages/head/{}",
self.namespace_url,
queue_name,
urlencoding::encode(lock_token)
);
let auth_token = self
.get_auth_token()
.await
.map_err(|e| e.to_queue_error())?;
let response = self
.http_client
.delete(&url)
.header(header::AUTHORIZATION, auth_token)
.send()
.await
.map_err(|e| {
AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
})?;
match response.status() {
StatusCode::OK | StatusCode::NO_CONTENT => {
drop(lock_tokens);
self.lock_tokens.write().await.remove(receipt.handle());
Ok(())
}
StatusCode::GONE | StatusCode::NOT_FOUND => {
Err(QueueError::InvalidReceipt {
receipt: receipt.handle().to_string(),
})
}
status => {
let error_body = response.text().await.unwrap_or_default();
Err(QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: status.as_str().to_string(),
message: format!("Complete failed: {}", error_body),
})
}
}
}
async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
let lock_tokens = self.lock_tokens.read().await;
let (lock_token, queue_name) =
lock_tokens
.get(receipt.handle())
.ok_or_else(|| QueueError::InvalidReceipt {
receipt: receipt.handle().to_string(),
})?;
let url = format!(
"{}/{}/messages/head/{}",
self.namespace_url,
queue_name,
urlencoding::encode(lock_token)
);
let auth_token = self
.get_auth_token()
.await
.map_err(|e| e.to_queue_error())?;
let response = self
.http_client
.put(&url)
.header(header::AUTHORIZATION, auth_token)
.header(header::CONTENT_LENGTH, "0")
.send()
.await
.map_err(|e| {
AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
})?;
match response.status() {
StatusCode::OK | StatusCode::NO_CONTENT => {
drop(lock_tokens);
self.lock_tokens.write().await.remove(receipt.handle());
Ok(())
}
StatusCode::GONE | StatusCode::NOT_FOUND => {
Err(QueueError::InvalidReceipt {
receipt: receipt.handle().to_string(),
})
}
status => {
let error_body = response.text().await.unwrap_or_default();
Err(QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: status.as_str().to_string(),
message: format!("Abandon failed: {}", error_body),
})
}
}
}
async fn dead_letter_message(
&self,
receipt: &ReceiptHandle,
reason: &str,
) -> Result<(), QueueError> {
let lock_tokens = self.lock_tokens.read().await;
let (lock_token, queue_name) =
lock_tokens
.get(receipt.handle())
.ok_or_else(|| QueueError::InvalidReceipt {
receipt: receipt.handle().to_string(),
})?;
let url = format!(
"{}/{}/messages/head/{}/$deadletter",
self.namespace_url,
queue_name,
urlencoding::encode(lock_token)
);
let auth_token = self
.get_auth_token()
.await
.map_err(|e| e.to_queue_error())?;
let properties = serde_json::json!({
"DeadLetterReason": reason,
"DeadLetterErrorDescription": "Message processing failed"
});
let response = self
.http_client
.post(&url)
.header(header::AUTHORIZATION, auth_token)
.header(header::CONTENT_TYPE, "application/json")
.json(&properties)
.send()
.await
.map_err(|e| {
AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
})?;
match response.status() {
StatusCode::OK | StatusCode::NO_CONTENT | StatusCode::CREATED => {
drop(lock_tokens);
self.lock_tokens.write().await.remove(receipt.handle());
Ok(())
}
StatusCode::GONE | StatusCode::NOT_FOUND => {
Err(QueueError::InvalidReceipt {
receipt: receipt.handle().to_string(),
})
}
status => {
let error_body = response.text().await.unwrap_or_default();
Err(QueueError::ProviderError {
provider: "AzureServiceBus".to_string(),
code: status.as_str().to_string(),
message: format!("Dead letter failed: {}", error_body),
})
}
}
}
async fn create_session_client(
&self,
queue: &QueueName,
session_id: Option<SessionId>,
) -> Result<Box<dyn SessionProvider>, QueueError> {
let resolved_id = match session_id {
Some(id) => id,
None => self.accept_next_available_session(queue).await?,
};
Ok(Box::new(AzureSessionProvider::new(
resolved_id,
queue.clone(),
self.config.session_timeout,
self.http_client.clone(),
self.namespace_url.clone(),
self.config.clone(),
self.credential.clone(),
)))
}
fn provider_type(&self) -> ProviderType {
ProviderType::AzureServiceBus
}
fn supports_sessions(&self) -> SessionSupport {
SessionSupport::Native
}
fn supports_batching(&self) -> bool {
true
}
fn max_batch_size(&self) -> u32 {
100 }
}
impl AzureServiceBusProvider {
    /// Asks the broker for the next session on `queue` that has pending messages
    /// and returns its [`SessionId`].
    ///
    /// The first message of that session is peek-locked (POST), only to read its
    /// `SessionId`. It is not returned to the caller. Under peek-lock semantics the
    /// broker should make it available again once its lock expires (with an
    /// increased delivery count). The previous DELETE consumed and discarded it.
    ///
    /// # Errors
    /// - `ProviderError { code: "NoSessionsAvailable" }` on HTTP 204.
    /// - [`QueueError::QueueNotFound`] on HTTP 404.
    /// - [`QueueError::ProviderError`] for missing/invalid `BrokerProperties`, a
    ///   missing or invalid `SessionId`, or any other status.
    async fn accept_next_available_session(
        &self,
        queue: &QueueName,
    ) -> Result<SessionId, QueueError> {
        // The broker wait must be at least one second.
        let timeout_secs = self.config.session_timeout.num_seconds().max(1);
        let url = format!(
            "{}/{}/sessions/$acceptnext/messages/head?timeout={}",
            self.namespace_url,
            queue.as_str(),
            timeout_secs
        );
        let auth_token = self
            .get_auth_token()
            .await
            .map_err(|e| e.to_queue_error())?;
        // POST (peek-lock) rather than DELETE (receive-and-delete): the message is
        // only inspected, so it must not be removed from the queue.
        let response = self
            .http_client
            .post(&url)
            .header(header::AUTHORIZATION, auth_token)
            .header(header::CONTENT_LENGTH, "0")
            .send()
            .await
            .map_err(|e| {
                AzureError::NetworkError(format!("Failed to accept next session: {}", e))
                    .to_queue_error()
            })?;
        match response.status() {
            StatusCode::OK | StatusCode::CREATED => {
                let broker_props = response
                    .headers()
                    .get("BrokerProperties")
                    .and_then(|v| v.to_str().ok())
                    .and_then(|s| serde_json::from_str::<ReceivedBrokerProperties>(s).ok())
                    .ok_or_else(|| QueueError::ProviderError {
                        provider: "AzureServiceBus".to_string(),
                        code: "InvalidResponse".to_string(),
                        message: "Missing BrokerProperties in accept-next-session response"
                            .to_string(),
                    })?;
                let session_id_str =
                    broker_props
                        .session_id
                        .ok_or_else(|| QueueError::ProviderError {
                            provider: "AzureServiceBus".to_string(),
                            code: "NoSessionId".to_string(),
                            message: "Accepted message has no SessionId".to_string(),
                        })?;
                SessionId::new(session_id_str).map_err(|e| QueueError::ProviderError {
                    provider: "AzureServiceBus".to_string(),
                    code: "InvalidSessionId".to_string(),
                    message: format!("Invalid session ID returned by broker: {}", e),
                })
            }
            StatusCode::NO_CONTENT => Err(QueueError::ProviderError {
                provider: "AzureServiceBus".to_string(),
                code: "NoSessionsAvailable".to_string(),
                message: "No sessions with pending messages are available".to_string(),
            }),
            StatusCode::NOT_FOUND => Err(QueueError::QueueNotFound {
                queue_name: queue.to_string(),
            }),
            status => {
                let error_body = response.text().await.unwrap_or_default();
                Err(QueueError::ProviderError {
                    provider: "AzureServiceBus".to_string(),
                    code: status.as_str().to_string(),
                    message: format!("Accept next session failed: {}", error_body),
                })
            }
        }
    }
}
/// [`SessionProvider`] bound to a single Service Bus session on one queue.
pub struct AzureSessionProvider {
    /// Session this provider receives from.
    session_id: SessionId,
    /// Queue that owns the session.
    queue_name: QueueName,
    /// Local estimate of the session lock expiry; starts at creation time plus
    /// the session timeout and is pushed forward by `refresh_session_expiry`.
    session_expires_at: Arc<std::sync::RwLock<Timestamp>>,
    /// HTTP client shared with the parent provider.
    http_client: HttpClient,
    /// Base URL of the namespace.
    namespace_url: String,
    /// Provider configuration (connection string, session timeout, ...).
    config: AzureServiceBusConfig,
    /// Azure AD credential; `None` means SAS tokens from the connection string.
    credential: Option<Arc<dyn TokenCredential + Send + Sync>>,
    /// Lock tokens inserted on receive and removed after a successful complete.
    lock_tokens: Arc<RwLock<HashSet<String>>>,
}
impl fmt::Debug for AzureSessionProvider {
    // Shows only identifying fields. The credential is reduced to a placeholder;
    // config, HTTP client, expiry and lock state are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let credential_summary = self.credential.as_ref().map(|_| "<TokenCredential>");
        let mut out = f.debug_struct("AzureSessionProvider");
        out.field("session_id", &self.session_id)
            .field("queue_name", &self.queue_name)
            .field("namespace_url", &self.namespace_url)
            .field("credential", &credential_summary);
        out.finish()
    }
}
impl AzureSessionProvider {
    /// Creates a provider for `session_id` on `queue_name`.
    ///
    /// The local session expiry starts at now + `session_timeout`, and the set of
    /// outstanding lock tokens starts empty.
    pub fn new(
        session_id: SessionId,
        queue_name: QueueName,
        session_timeout: Duration,
        http_client: HttpClient,
        namespace_url: String,
        config: AzureServiceBusConfig,
        credential: Option<Arc<dyn TokenCredential + Send + Sync>>,
    ) -> Self {
        let initial_expiry = Timestamp::from_datetime(Utc::now() + session_timeout);
        let session_expires_at = Arc::new(std::sync::RwLock::new(initial_expiry));
        let lock_tokens = Arc::new(RwLock::new(HashSet::new()));
        Self {
            session_id,
            queue_name,
            session_expires_at,
            http_client,
            namespace_url,
            config,
            credential,
            lock_tokens,
        }
    }

    /// Produces the `Authorization` header value: an Azure AD token when a
    /// credential is configured, otherwise a SAS token built from the
    /// connection string.
    async fn get_auth_token(&self) -> Result<String, AzureError> {
        if let Some(cred) = &self.credential {
            return get_bearer_token(cred.as_ref()).await;
        }
        let conn_str = self.config.connection_string.as_ref().ok_or_else(|| {
            AzureError::AuthenticationError("No connection string available".to_string())
        })?;
        generate_sas_token(&self.namespace_url, conn_str)
    }

    /// Pushes the local session expiry to now + `config.session_timeout`.
    /// Best effort: a poisoned lock leaves the previous value in place.
    fn refresh_session_expiry(&self) {
        if let Ok(mut slot) = self.session_expires_at.write() {
            *slot = Timestamp::from_datetime(Utc::now() + self.config.session_timeout);
        }
    }
}
/// Session-scoped queue operations for Azure Service Bus, implemented over HTTP.
///
/// Every request targets `{namespace_url}/{queue}/sessions/{session_id}/...` and is
/// authenticated with the value returned by `get_auth_token`. Lock tokens handed out
/// by `receive_message` are tracked in `lock_tokens`. The settlement calls
/// (`complete_message`, `abandon_message`, `dead_letter_message`) reject receipts
/// whose token is not tracked. They stop tracking a token once the service has either
/// settled it or reported it as gone (404/410).
#[async_trait]
impl SessionProvider for AzureSessionProvider {
    /// Receives the next message of this session using peek-lock semantics.
    ///
    /// Sends `POST .../sessions/{id}/messages/head?timeout={seconds}`, where the
    /// timeout is `timeout` truncated to whole seconds.
    ///
    /// # Returns
    /// * `Ok(Some(msg))` on `200 OK` / `201 Created`:
    ///   - the body is base64-decoded into `msg.body`;
    ///   - the lock token from the `BrokerProperties` header becomes the receipt
    ///     handle and is recorded in `lock_tokens`;
    ///   - the local session expiry is refreshed.
    /// * `Ok(None)` on `204 No Content`.
    ///
    /// # Errors
    /// * `QueueError::SessionNotFound` on `404` / `410`.
    /// * `QueueError::ProviderError` with one of these codes:
    ///   - `InvalidResponse` if `BrokerProperties` is missing or unparsable;
    ///   - `DecodingError` if the body is not valid base64;
    ///   - the HTTP status code for any other status.
    /// * Authentication and network failures, converted via `to_queue_error`.
    async fn receive_message(
        &self,
        timeout: Duration,
    ) -> Result<Option<ReceivedMessage>, QueueError> {
        let url = format!(
            "{}/{}/sessions/{}/messages/head?timeout={}",
            self.namespace_url,
            self.queue_name.as_str(),
            urlencoding::encode(self.session_id.as_str()),
            timeout.num_seconds()
        );
        let auth_token = self
            .get_auth_token()
            .await
            .map_err(|e| e.to_queue_error())?;
        // Use peek-lock (POST), not receive-and-delete (DELETE). The message must stay
        // on the broker until it is explicitly completed, abandoned or dead-lettered
        // with the lock token tracked below. A DELETE would remove it immediately:
        // abandon/dead-letter could never work, and any failure after this point
        // (bad headers, decode errors, handler crashes) would lose the message.
        let response = self
            .http_client
            .post(&url)
            .header(header::AUTHORIZATION, auth_token)
            .header(header::CONTENT_LENGTH, "0")
            .send()
            .await
            .map_err(|e| {
                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
            })?;
        match response.status() {
            StatusCode::OK | StatusCode::CREATED => {
                let broker_props = response
                    .headers()
                    .get("BrokerProperties")
                    .and_then(|v| v.to_str().ok())
                    .and_then(|s| serde_json::from_str::<ReceivedBrokerProperties>(s).ok())
                    .ok_or_else(|| QueueError::ProviderError {
                        provider: "AzureServiceBus".to_string(),
                        code: "InvalidResponse".to_string(),
                        message: "Missing or invalid BrokerProperties header".to_string(),
                    })?;
                let body_base64 = response.text().await.map_err(|e| {
                    AzureError::NetworkError(format!("Failed to read response body: {}", e))
                        .to_queue_error()
                })?;
                use base64::{engine::general_purpose::STANDARD, Engine};
                let body = STANDARD
                    .decode(&body_base64)
                    .map_err(|e| QueueError::ProviderError {
                        provider: "AzureServiceBus".to_string(),
                        code: "DecodingError".to_string(),
                        message: format!("Failed to decode message body: {}", e),
                    })?;
                // Fall back to "now" if the broker's enqueue time is not valid RFC 3339.
                let first_delivered_at =
                    chrono::DateTime::parse_from_rfc3339(&broker_props.enqueued_time_utc)
                        .map(|dt| Timestamp::from_datetime(dt.with_timezone(&chrono::Utc)))
                        .unwrap_or_else(|_| Timestamp::now());
                let expires_at = Timestamp::from_datetime(Utc::now() + self.config.session_timeout);
                let lock_token = broker_props.lock_token.clone();
                let receipt = ReceiptHandle::new(
                    lock_token.clone(),
                    expires_at,
                    ProviderType::AzureServiceBus,
                );
                // Track the token so the settlement methods accept this receipt.
                self.lock_tokens.write().await.insert(lock_token);
                // Fall back to a fresh id if the broker's MessageId does not parse.
                let message_id = MessageId::from_str(&broker_props.message_id)
                    .unwrap_or_else(|_| MessageId::new());
                self.refresh_session_expiry();
                Ok(Some(ReceivedMessage {
                    message_id,
                    body: bytes::Bytes::from(body),
                    attributes: HashMap::new(),
                    session_id: Some(self.session_id.clone()),
                    correlation_id: None,
                    receipt_handle: receipt,
                    delivery_count: broker_props.delivery_count,
                    first_delivered_at,
                    delivered_at: Timestamp::now(),
                }))
            }
            StatusCode::NO_CONTENT => Ok(None),
            StatusCode::GONE | StatusCode::NOT_FOUND => Err(QueueError::SessionNotFound {
                session_id: self.session_id.to_string(),
            }),
            status => {
                let error_body = response.text().await.unwrap_or_default();
                Err(QueueError::ProviderError {
                    provider: "AzureServiceBus".to_string(),
                    code: status.as_str().to_string(),
                    message: format!("Session receive failed: {}", error_body),
                })
            }
        }
    }

    /// Completes (settles) a message previously received from this session.
    ///
    /// Sends `DELETE .../sessions/{id}/messages/{lock_token}`.
    ///
    /// # Errors
    /// * `QueueError::InvalidReceipt` in either of these cases:
    ///   - the receipt's token is not tracked locally;
    ///   - the service answers `404` / `410`, after which the token is no longer tracked.
    /// * `QueueError::ProviderError` for any other non-success status.
    /// * Authentication and network failures, converted via `to_queue_error`.
    async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
        if !self.lock_tokens.read().await.contains(receipt.handle()) {
            return Err(QueueError::InvalidReceipt {
                receipt: receipt.handle().to_string(),
            });
        }
        let lock_token = receipt.handle().to_string();
        let url = format!(
            "{}/{}/sessions/{}/messages/{}",
            self.namespace_url,
            self.queue_name.as_str(),
            urlencoding::encode(self.session_id.as_str()),
            urlencoding::encode(&lock_token)
        );
        let auth_token = self
            .get_auth_token()
            .await
            .map_err(|e| e.to_queue_error())?;
        let response = self
            .http_client
            .delete(&url)
            .header(header::AUTHORIZATION, auth_token)
            .send()
            .await
            .map_err(|e| {
                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
            })?;
        match response.status() {
            StatusCode::OK | StatusCode::NO_CONTENT => {
                self.lock_tokens.write().await.remove(receipt.handle());
                Ok(())
            }
            StatusCode::GONE | StatusCode::NOT_FOUND => {
                // The lock expired or the message is already settled. The token can
                // never succeed again, so stop tracking it instead of leaking it.
                self.lock_tokens.write().await.remove(receipt.handle());
                Err(QueueError::InvalidReceipt {
                    receipt: receipt.handle().to_string(),
                })
            }
            status => {
                let error_body = response.text().await.unwrap_or_default();
                Err(QueueError::ProviderError {
                    provider: "AzureServiceBus".to_string(),
                    code: status.as_str().to_string(),
                    message: format!("Session complete failed: {}", error_body),
                })
            }
        }
    }

    /// Abandons a message so it can be redelivered.
    ///
    /// Sends `PUT .../sessions/{id}/messages/{lock_token}` with an empty body.
    ///
    /// # Errors
    /// * `QueueError::InvalidReceipt` in either of these cases:
    ///   - the receipt's token is not tracked locally;
    ///   - the service answers `404` / `410`, after which the token is no longer tracked.
    /// * `QueueError::ProviderError` for any other non-success status.
    /// * Authentication and network failures, converted via `to_queue_error`.
    async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
        if !self.lock_tokens.read().await.contains(receipt.handle()) {
            return Err(QueueError::InvalidReceipt {
                receipt: receipt.handle().to_string(),
            });
        }
        let lock_token = receipt.handle().to_string();
        let url = format!(
            "{}/{}/sessions/{}/messages/{}",
            self.namespace_url,
            self.queue_name.as_str(),
            urlencoding::encode(self.session_id.as_str()),
            urlencoding::encode(&lock_token)
        );
        let auth_token = self
            .get_auth_token()
            .await
            .map_err(|e| e.to_queue_error())?;
        let response = self
            .http_client
            .put(&url)
            .header(header::AUTHORIZATION, auth_token)
            .header(header::CONTENT_LENGTH, "0")
            .send()
            .await
            .map_err(|e| {
                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
            })?;
        match response.status() {
            StatusCode::OK | StatusCode::NO_CONTENT => {
                self.lock_tokens.write().await.remove(receipt.handle());
                Ok(())
            }
            StatusCode::GONE | StatusCode::NOT_FOUND => {
                // The lock is no longer valid on the service side, so forget the token.
                self.lock_tokens.write().await.remove(receipt.handle());
                Err(QueueError::InvalidReceipt {
                    receipt: receipt.handle().to_string(),
                })
            }
            status => {
                let error_body = response.text().await.unwrap_or_default();
                Err(QueueError::ProviderError {
                    provider: "AzureServiceBus".to_string(),
                    code: status.as_str().to_string(),
                    message: format!("Session abandon failed: {}", error_body),
                })
            }
        }
    }

    /// Moves a message to the dead-letter queue.
    ///
    /// Sends `POST .../sessions/{id}/messages/{lock_token}/$deadletter` with a JSON
    /// body carrying `DeadLetterReason` (set to `reason`) and a fixed
    /// `DeadLetterErrorDescription`.
    ///
    /// # Errors
    /// * `QueueError::InvalidReceipt` in either of these cases:
    ///   - the receipt's token is not tracked locally;
    ///   - the service answers `404` / `410`, after which the token is no longer tracked.
    /// * `QueueError::ProviderError` for any other non-success status.
    /// * Authentication and network failures, converted via `to_queue_error`.
    async fn dead_letter_message(
        &self,
        receipt: &ReceiptHandle,
        reason: &str,
    ) -> Result<(), QueueError> {
        if !self.lock_tokens.read().await.contains(receipt.handle()) {
            return Err(QueueError::InvalidReceipt {
                receipt: receipt.handle().to_string(),
            });
        }
        let lock_token = receipt.handle().to_string();
        let url = format!(
            "{}/{}/sessions/{}/messages/{}/$deadletter",
            self.namespace_url,
            self.queue_name.as_str(),
            urlencoding::encode(self.session_id.as_str()),
            urlencoding::encode(&lock_token)
        );
        let auth_token = self
            .get_auth_token()
            .await
            .map_err(|e| e.to_queue_error())?;
        let properties = serde_json::json!({
            "DeadLetterReason": reason,
            "DeadLetterErrorDescription": "Message processing failed"
        });
        let response = self
            .http_client
            .post(&url)
            .header(header::AUTHORIZATION, auth_token)
            .header(header::CONTENT_TYPE, "application/json")
            .json(&properties)
            .send()
            .await
            .map_err(|e| {
                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
            })?;
        match response.status() {
            StatusCode::OK | StatusCode::NO_CONTENT | StatusCode::CREATED => {
                self.lock_tokens.write().await.remove(receipt.handle());
                Ok(())
            }
            StatusCode::GONE | StatusCode::NOT_FOUND => {
                // The lock is no longer valid on the service side, so forget the token.
                self.lock_tokens.write().await.remove(receipt.handle());
                Err(QueueError::InvalidReceipt {
                    receipt: receipt.handle().to_string(),
                })
            }
            status => {
                let error_body = response.text().await.unwrap_or_default();
                Err(QueueError::ProviderError {
                    provider: "AzureServiceBus".to_string(),
                    code: status.as_str().to_string(),
                    message: format!("Session dead letter failed: {}", error_body),
                })
            }
        }
    }

    /// Renews the lock on this session.
    ///
    /// Sends `POST .../sessions/{id}/renewlock` with an empty body. On `200` / `204`
    /// it also refreshes the locally tracked session expiry.
    ///
    /// # Errors
    /// * `QueueError::SessionNotFound` on `404` / `410`.
    /// * `QueueError::ProviderError` for any other non-success status.
    /// * Authentication and network failures, converted via `to_queue_error`.
    async fn renew_session_lock(&self) -> Result<(), QueueError> {
        let url = format!(
            "{}/{}/sessions/{}/renewlock",
            self.namespace_url,
            self.queue_name.as_str(),
            urlencoding::encode(self.session_id.as_str())
        );
        let auth_token = self
            .get_auth_token()
            .await
            .map_err(|e| e.to_queue_error())?;
        let response = self
            .http_client
            .post(&url)
            .header(header::AUTHORIZATION, auth_token)
            .header(header::CONTENT_LENGTH, "0")
            .send()
            .await
            .map_err(|e| {
                AzureError::NetworkError(format!("HTTP request failed: {}", e)).to_queue_error()
            })?;
        match response.status() {
            StatusCode::OK | StatusCode::NO_CONTENT => {
                self.refresh_session_expiry();
                Ok(())
            }
            StatusCode::GONE | StatusCode::NOT_FOUND => Err(QueueError::SessionNotFound {
                session_id: self.session_id.to_string(),
            }),
            status => {
                let error_body = response.text().await.unwrap_or_default();
                Err(QueueError::ProviderError {
                    provider: "AzureServiceBus".to_string(),
                    code: status.as_str().to_string(),
                    message: format!("Session lock renewal failed: {}", error_body),
                })
            }
        }
    }

    /// Closes the session locally by forgetting every tracked lock token.
    ///
    /// No request is sent to the service. Receipts issued before this call are
    /// rejected afterwards with `InvalidReceipt`.
    async fn close_session(&self) -> Result<(), QueueError> {
        self.lock_tokens.write().await.clear();
        Ok(())
    }

    /// Returns the id of the session this provider is bound to.
    fn session_id(&self) -> &SessionId {
        &self.session_id
    }

    /// Returns the locally tracked session expiry.
    ///
    /// If the expiry lock is poisoned, returns a timestamp one second in the past,
    /// so callers treat the session as already expired.
    fn session_expires_at(&self) -> Timestamp {
        self.session_expires_at
            .read()
            .map(|guard| *guard)
            .unwrap_or_else(|_| Timestamp::from_datetime(Utc::now() - Duration::seconds(1)))
    }
}