use super::{AzureAdConfig, ServiceBusError};
use crate::common::HttpError;
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
/// Base URL of the Azure management-plane (Resource Manager) REST endpoint; every request
/// built in this module starts with it.
const AZURE_MANAGEMENT_URL: &str = "https://management.azure.com";
/// `api-version` query value for the subscription listing endpoint.
const API_VERSION_SUBSCRIPTIONS: &str = "2022-12-01";
/// `api-version` query value for the resource-group listing endpoint.
const API_VERSION_RESOURCE_GROUPS: &str = "2021-04-01";
/// `api-version` query value for every `Microsoft.ServiceBus` endpoint used here
/// (namespace listing, queue listing/properties, and `listKeys`).
const API_VERSION_SERVICE_BUS: &str = "2021-11-01";
/// Client for the Azure management-plane REST calls needed by the Service Bus integration.
///
/// Listing and key-retrieval methods take an explicit bearer token from the caller; the
/// queue-statistics methods instead acquire their own token and resolve subscription,
/// resource group and namespace from the optional [`AzureAdConfig`].
#[derive(Debug, Clone)]
pub struct AzureManagementClient {
// HTTP client used for every outgoing request.
client: reqwest::Client,
// Needed only by operations that obtain their own token (queue statistics); `None` makes
// those operations fail with a configuration error.
azure_ad_config: Option<AzureAdConfig>,
}
/// An Azure subscription as returned by the ARM `GET /subscriptions` listing.
///
/// JSON keys are the camelCase forms of the field names (`subscriptionId`, `displayName`, ...).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct Subscription {
    /// ARM resource id of the subscription (JSON `id`).
    pub id: String,
    /// Subscription GUID (JSON `subscriptionId`).
    pub subscription_id: String,
    /// Human-readable name (JSON `displayName`).
    pub display_name: String,
    /// Subscription state string as reported by Azure (JSON `state`).
    pub state: String,
}
/// An Azure resource group as returned by the ARM resource-group listing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ResourceGroup {
pub id: String,
pub name: String,
pub location: String,
// A missing `tags` key deserializes to an empty map. Note that `serde(default)` only covers
// an absent key; an explicit JSON `null` would still fail to deserialize into a HashMap.
#[serde(default)]
pub tags: std::collections::HashMap<String, String>,
}
/// A Service Bus namespace as listed under `Microsoft.ServiceBus/namespaces`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ServiceBusNamespace {
    /// ARM resource id of the namespace.
    pub id: String,
    /// Namespace name.
    pub name: String,
    /// Azure region string.
    pub location: String,
    /// ARM resource type; the JSON key is `type`, which is a Rust keyword, so it is renamed
    /// explicitly (field-level `rename` overrides the container's `rename_all`).
    #[serde(rename = "type")]
    pub resource_type: String,
    /// Namespace-specific properties block.
    pub properties: NamespaceProperties,
}
/// The `properties` object of a Service Bus namespace resource.
///
/// JSON keys are camelCase (`serviceBusEndpoint`, `status`, `createdAt`); the optional fields
/// deserialize to `None` when absent or `null`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct NamespaceProperties {
    /// Data-plane endpoint URL of the namespace.
    pub service_bus_endpoint: String,
    /// Provisioning/operational status string, if reported.
    pub status: Option<String>,
    /// Creation timestamp as the raw string Azure returns, if reported.
    pub created_at: Option<String>,
}
/// Shared-access keys and connection strings returned by the ARM `listKeys` action.
///
/// Every field is a credential. `Debug` is therefore implemented by hand and prints no field
/// values, so an accidental `{:?}` in a log line or panic message cannot leak the keys.
/// Serialization and deserialization are unaffected (JSON keys are camelCase:
/// `primaryConnectionString`, `secondaryConnectionString`, `primaryKey`, `secondaryKey`).
#[derive(Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct AccessKeys {
    pub primary_connection_string: String,
    pub secondary_connection_string: String,
    pub primary_key: String,
    pub secondary_key: String,
}

impl std::fmt::Debug for AccessKeys {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit all fields: they are secrets. Renders as `AccessKeys { .. }`.
        f.debug_struct("AccessKeys").finish_non_exhaustive()
    }
}
/// Envelope of the ARM `GET .../namespaces/{namespace}/queues/{queue}` response.
/// Only `properties` is modeled; serde ignores the other top-level keys.
#[derive(Debug, Deserialize)]
struct QueuePropertiesResponse {
properties: QueueProperties,
}
/// Subset of a queue's `properties` object: only the message-count breakdown is read.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueueProperties {
    /// JSON `countDetails`.
    count_details: CountDetails,
}
/// Message counters from a queue's `countDetails` object.
///
/// Stored as signed integers exactly as received; callers clamp them before exposing `u64`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CountDetails {
    /// JSON `activeMessageCount`.
    active_message_count: i64,
    /// JSON `deadLetterMessageCount`.
    dead_letter_message_count: i64,
}
/// One page of an ARM collection response.
///
/// `value` holds the page's items; `next_link` (JSON `nextLink`) is the URL of the following
/// page and is `None` when the key is absent or `null`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListResponse<T> {
    pub value: Vec<T>,
    pub next_link: Option<String>,
}
impl AzureManagementClient {
pub fn new(client: reqwest::Client) -> Self {
Self {
client,
azure_ad_config: None,
}
}
pub fn with_config(client: reqwest::Client, azure_ad_config: AzureAdConfig) -> Self {
Self {
client,
azure_ad_config: Some(azure_ad_config),
}
}
pub fn from_config(
client: reqwest::Client,
azure_ad_config: AzureAdConfig,
) -> Result<Self, ServiceBusError> {
azure_ad_config.subscription_id()?;
azure_ad_config.resource_group()?;
azure_ad_config.namespace()?;
Ok(Self::with_config(client, azure_ad_config))
}
async fn get_management_api_token(&self) -> Result<String, ServiceBusError> {
match &self.azure_ad_config {
Some(config) => config
.get_azure_ad_token(&self.client)
.await
.map_err(|e| ServiceBusError::AuthenticationError(e.to_string())),
None => Err(ServiceBusError::ConfigurationError(
"Azure AD configuration not available for this operation".to_string(),
)),
}
}
pub async fn list_subscriptions(
&self,
token: &str,
) -> Result<Vec<Subscription>, ServiceBusError> {
self.list_subscriptions_paginated(token, None)
.await
.map(|(subs, _)| subs)
}
async fn list_subscriptions_paginated(
&self,
token: &str,
continuation_token: Option<String>,
) -> Result<(Vec<Subscription>, Option<String>), ServiceBusError> {
let url = match continuation_token {
Some(next_link) => next_link,
None => format!(
"{AZURE_MANAGEMENT_URL}/subscriptions?api-version={API_VERSION_SUBSCRIPTIONS}"
),
};
let client = self.client.clone();
let token = token.to_string();
let request = client
.get(&url)
.header(AUTHORIZATION, format!("Bearer {token}"));
let response = request
.send()
.await
.map_err(|e| ServiceBusError::ConnectionFailed(e.to_string()))?;
if !response.status().is_success() {
return Err(ServiceBusError::from_azure_response(response, "list_subscriptions").await);
}
let list_response: ListResponse<Subscription> = response
.json()
.await
.map_err(|e| ServiceBusError::ConfigurationError(e.to_string()))?;
Ok((list_response.value, list_response.next_link))
}
pub async fn list_all_subscriptions(
&self,
token: &str,
) -> Result<Vec<Subscription>, ServiceBusError> {
let mut all_subscriptions = Vec::new();
let mut continuation_token = None;
loop {
let (mut page_subscriptions, next_token) = self
.list_subscriptions_paginated(token, continuation_token)
.await?;
all_subscriptions.append(&mut page_subscriptions);
match next_token {
Some(token) => continuation_token = Some(token),
None => break,
}
}
Ok(all_subscriptions)
}
pub async fn list_resource_groups(
&self,
token: &str,
subscription_id: &str,
) -> Result<Vec<ResourceGroup>, ServiceBusError> {
self.list_resource_groups_paginated(token, subscription_id, None)
.await
.map(|(groups, _)| groups)
}
pub async fn list_resource_groups_paginated(
&self,
token: &str,
subscription_id: &str,
continuation_token: Option<String>,
) -> Result<(Vec<ResourceGroup>, Option<String>), ServiceBusError> {
let url = match continuation_token {
Some(next_link) => next_link,
None => format!(
"{AZURE_MANAGEMENT_URL}/subscriptions/{subscription_id}/resourcegroups?api-version={API_VERSION_RESOURCE_GROUPS}"
),
};
let request = self
.client
.get(&url)
.header(AUTHORIZATION, format!("Bearer {token}"));
let response = request
.send()
.await
.map_err(|e| ServiceBusError::ConnectionFailed(e.to_string()))?;
if !response.status().is_success() {
return Err(
ServiceBusError::from_azure_response(response, "list_resource_groups").await,
);
}
let list_response: ListResponse<ResourceGroup> = response
.json()
.await
.map_err(|e| ServiceBusError::ConfigurationError(e.to_string()))?;
Ok((list_response.value, list_response.next_link))
}
pub async fn list_all_resource_groups(
&self,
token: &str,
subscription_id: &str,
) -> Result<Vec<ResourceGroup>, ServiceBusError> {
let mut all_groups = Vec::new();
let mut continuation_token = None;
loop {
let (mut page_groups, next_token) = self
.list_resource_groups_paginated(token, subscription_id, continuation_token)
.await?;
all_groups.append(&mut page_groups);
match next_token {
Some(token) => continuation_token = Some(token),
None => break,
}
}
Ok(all_groups)
}
pub async fn list_service_bus_namespaces(
&self,
token: &str,
subscription_id: &str,
) -> Result<Vec<ServiceBusNamespace>, ServiceBusError> {
self.list_service_bus_namespaces_paginated(token, subscription_id, None)
.await
.map(|(namespaces, _)| namespaces)
}
pub async fn list_service_bus_namespaces_paginated(
&self,
token: &str,
subscription_id: &str,
continuation_token: Option<String>,
) -> Result<(Vec<ServiceBusNamespace>, Option<String>), ServiceBusError> {
let url = match continuation_token {
Some(next_link) => next_link,
None => format!(
"{AZURE_MANAGEMENT_URL}/subscriptions/{subscription_id}/providers/Microsoft.ServiceBus/namespaces?api-version={API_VERSION_SERVICE_BUS}"
),
};
let request = self
.client
.get(&url)
.header(AUTHORIZATION, format!("Bearer {token}"));
let response = request
.send()
.await
.map_err(|e| ServiceBusError::ConnectionFailed(e.to_string()))?;
if !response.status().is_success() {
return Err(ServiceBusError::from_azure_response(
response,
"list_service_bus_namespaces",
)
.await);
}
let list_response: ListResponse<ServiceBusNamespace> = response
.json()
.await
.map_err(|e| ServiceBusError::ConfigurationError(e.to_string()))?;
Ok((list_response.value, list_response.next_link))
}
pub async fn list_all_service_bus_namespaces(
&self,
token: &str,
subscription_id: &str,
) -> Result<Vec<ServiceBusNamespace>, ServiceBusError> {
let mut all_namespaces = Vec::new();
let mut continuation_token = None;
loop {
let (mut page_namespaces, next_token) = self
.list_service_bus_namespaces_paginated(token, subscription_id, continuation_token)
.await?;
all_namespaces.append(&mut page_namespaces);
match next_token {
Some(token) => continuation_token = Some(token),
None => break,
}
}
Ok(all_namespaces)
}
pub async fn get_namespace_connection_string(
&self,
token: &str,
subscription_id: &str,
resource_group: &str,
namespace: &str,
) -> Result<String, ServiceBusError> {
let url = format!(
"{AZURE_MANAGEMENT_URL}/subscriptions/{subscription_id}/resourceGroups/{resource_group}/providers/Microsoft.ServiceBus/namespaces/{namespace}/authorizationRules/RootManageSharedAccessKey/listKeys?api-version={API_VERSION_SERVICE_BUS}"
);
let request = self
.client
.post(&url)
.header(AUTHORIZATION, format!("Bearer {token}"))
.header(CONTENT_TYPE, "application/json")
.body("{}");
let response = request
.send()
.await
.map_err(|e| ServiceBusError::ConnectionFailed(e.to_string()))?;
if !response.status().is_success() {
return Err(ServiceBusError::from_azure_response(
response,
"get_namespace_connection_string",
)
.await);
}
let keys: AccessKeys = response
.json()
.await
.map_err(|e| ServiceBusError::ConfigurationError(e.to_string()))?;
Ok(keys.primary_connection_string)
}
pub async fn list_queues(
&self,
token: &str,
subscription_id: &str,
resource_group: &str,
namespace: &str,
) -> Result<Vec<String>, ServiceBusError> {
self.list_queues_paginated(token, subscription_id, resource_group, namespace, None)
.await
.map(|(queues, _)| queues)
}
pub async fn list_queues_paginated(
&self,
token: &str,
subscription_id: &str,
resource_group: &str,
namespace: &str,
continuation_token: Option<String>,
) -> Result<(Vec<String>, Option<String>), ServiceBusError> {
let url = match continuation_token {
Some(next_link) => next_link,
None => format!(
"{AZURE_MANAGEMENT_URL}/subscriptions/{subscription_id}/resourceGroups/{resource_group}/providers/Microsoft.ServiceBus/namespaces/{namespace}/queues?api-version={API_VERSION_SERVICE_BUS}"
),
};
let request = self
.client
.get(&url)
.header(AUTHORIZATION, format!("Bearer {token}"));
let response = request
.send()
.await
.map_err(|e| ServiceBusError::ConnectionFailed(e.to_string()))?;
if !response.status().is_success() {
return Err(ServiceBusError::from_azure_response(response, "list_queues").await);
}
let list_response: ListResponse<serde_json::Value> = response
.json()
.await
.map_err(|e| ServiceBusError::ConfigurationError(e.to_string()))?;
let queue_names: Vec<String> = list_response
.value
.iter()
.filter_map(|queue| queue["name"].as_str().map(|s| s.to_string()))
.collect();
Ok((queue_names, list_response.next_link))
}
pub async fn list_all_queues(
&self,
token: &str,
subscription_id: &str,
resource_group: &str,
namespace: &str,
) -> Result<Vec<String>, ServiceBusError> {
let mut all_queues = Vec::new();
let mut continuation_token = None;
loop {
let (mut page_queues, next_token) = self
.list_queues_paginated(
token,
subscription_id,
resource_group,
namespace,
continuation_token,
)
.await?;
all_queues.append(&mut page_queues);
match next_token {
Some(token) => continuation_token = Some(token),
None => break,
}
}
Ok(all_queues)
}
pub async fn get_queue_message_count(&self, queue_name: &str) -> Result<u64, ServiceBusError> {
let (active_count, _) = self.get_queue_counts(queue_name).await?;
Ok(active_count)
}
pub async fn get_queue_counts(&self, queue_name: &str) -> Result<(u64, u64), ServiceBusError> {
self.get_queue_counts_with_retry(queue_name, 3).await
}
async fn get_queue_counts_with_retry(
&self,
queue_name: &str,
max_retries: u32,
) -> Result<(u64, u64), ServiceBusError> {
let mut last_error = None;
for attempt in 0..=max_retries {
match self.get_queue_counts_internal(queue_name).await {
Ok(result) => return Ok(result),
Err(e) => {
last_error = Some(e);
if let Some(ref err) = last_error {
match err {
ServiceBusError::ConfigurationError(_)
| ServiceBusError::AuthenticationError(_) => {
log::debug!("Non-retryable error, failing immediately: {err}");
return Err(last_error.unwrap());
}
ServiceBusError::InternalError(msg) if msg.contains("404") => {
return Err(HttpError::InvalidResponse {
expected: "2xx status".to_string(),
actual: format!("Queue not found: {queue_name}"),
}
.into());
}
_ => {}
}
}
if attempt < max_retries {
let delay = Duration::from_millis(100 * (2_u64.pow(attempt))); log::debug!(
"Attempt {} failed, retrying in {:?}: {}",
attempt + 1,
delay,
last_error.as_ref().unwrap()
);
tokio::time::sleep(delay).await;
}
}
}
}
Err(last_error.unwrap())
}
async fn get_queue_counts_internal(
&self,
queue_name: &str,
) -> Result<(u64, u64), ServiceBusError> {
log::debug!("Getting queue counts for: {queue_name}");
let config = self.azure_ad_config.as_ref().ok_or_else(|| {
ServiceBusError::ConfigurationError(
"Azure AD configuration required for queue statistics".to_string(),
)
})?;
let subscription_id = config.subscription_id()?;
let resource_group = config.resource_group()?;
let namespace = config.namespace()?;
let access_token = self.get_management_api_token().await?;
let encoded_queue_name = urlencoding::encode(queue_name);
let url = format!(
"{AZURE_MANAGEMENT_URL}/subscriptions/{subscription_id}/resourceGroups/{resource_group}/providers/Microsoft.ServiceBus/namespaces/{namespace}/queues/{encoded_queue_name}?api-version={API_VERSION_SERVICE_BUS}"
);
log::debug!("Requesting queue properties from Azure Management API: {url}");
let request = self
.client
.get(&url)
.header(AUTHORIZATION, format!("Bearer {access_token}"))
.header(CONTENT_TYPE, "application/json");
let response = request
.send()
.await
.map_err(|e| ServiceBusError::ConnectionFailed(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
if status == 404 {
return Err(ServiceBusError::azure_api_error(
"get_queue_counts",
"QueueNotFound",
404,
format!("Queue not found: {queue_name}"),
));
}
return Err(ServiceBusError::from_azure_response(response, "get_queue_counts").await);
}
let response_text = response
.text()
.await
.map_err(|e| ServiceBusError::InternalError(format!("Failed to read response: {e}")))?;
let queue_response: QueuePropertiesResponse = serde_json::from_str(&response_text)
.map_err(|e| {
ServiceBusError::ConfigurationError(format!("Failed to parse JSON: {e}"))
})?;
let active_raw = queue_response.properties.count_details.active_message_count;
let dlq_raw = queue_response
.properties
.count_details
.dead_letter_message_count;
let active = if active_raw < 0 { 0 } else { active_raw as u64 };
let dlq = if dlq_raw < 0 { 0 } else { dlq_raw as u64 };
Ok((active, dlq))
}
}
/// A cached value together with the instant at which it was stored.
#[derive(Debug, Clone)]
struct CacheEntry<T> {
    data: T,
    cached_at: Instant,
}

impl<T> CacheEntry<T> {
    /// Wraps `data`, stamping it with the current instant.
    fn new(data: T) -> Self {
        let cached_at = Instant::now();
        CacheEntry { data, cached_at }
    }

    /// Reports whether strictly more than `ttl` has elapsed since the entry was stored.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.cached_at.elapsed();
        age > ttl
    }
}
/// In-memory, TTL-based cache of Azure resource listings and namespace connection strings.
///
/// All entries share one time-to-live (`cache_ttl`). Each keyed map is capped at
/// `max_entries_per_cache`: inserting a *new* key into a full map first evicts the entry with
/// the oldest insertion time. The new entry is always stored, so a cap of 0 behaves like 1.
/// Expired entries are hidden by the `get_cached_*` accessors but stay in memory until
/// [`AzureResourceCache::clean_expired`], [`AzureResourceCache::clear`], or eviction removes them.
#[derive(Debug, Clone)]
pub struct AzureResourceCache {
    subscriptions: Option<CacheEntry<Vec<Subscription>>>,
    // Keyed by subscription id.
    resource_groups: std::collections::HashMap<String, CacheEntry<Vec<ResourceGroup>>>,
    // Keyed by subscription id.
    namespaces: std::collections::HashMap<String, CacheEntry<Vec<ServiceBusNamespace>>>,
    // Keyed by the caller-supplied namespace identifier.
    connection_strings: std::collections::HashMap<String, CacheEntry<String>>,
    cache_ttl: Duration,
    max_entries_per_cache: usize,
}

impl AzureResourceCache {
    /// Creates a cache with a 300-second TTL and at most 100 entries per keyed map.
    pub fn new() -> Self {
        Self::with_config(Duration::from_secs(300), 100)
    }

    /// Creates an empty cache with the given TTL and per-map entry cap.
    pub fn with_config(cache_ttl: Duration, max_entries: usize) -> Self {
        Self {
            subscriptions: None,
            resource_groups: std::collections::HashMap::new(),
            namespaces: std::collections::HashMap::new(),
            connection_strings: std::collections::HashMap::new(),
            cache_ttl,
            max_entries_per_cache: max_entries,
        }
    }

    /// Replaces the cached subscription list (a single slot, so no eviction is needed).
    pub fn cache_subscriptions(&mut self, subscriptions: Vec<Subscription>) {
        self.subscriptions = Some(CacheEntry::new(subscriptions));
    }

    /// Caches the resource groups of `subscription_id`, evicting the oldest entry if full.
    pub fn cache_resource_groups(&mut self, subscription_id: String, groups: Vec<ResourceGroup>) {
        Self::insert_bounded(
            &mut self.resource_groups,
            subscription_id,
            groups,
            self.max_entries_per_cache,
        );
    }

    /// Caches the Service Bus namespaces of `subscription_id`, evicting the oldest entry if full.
    pub fn cache_namespaces(
        &mut self,
        subscription_id: String,
        namespaces: Vec<ServiceBusNamespace>,
    ) {
        Self::insert_bounded(
            &mut self.namespaces,
            subscription_id,
            namespaces,
            self.max_entries_per_cache,
        );
    }

    /// Caches a namespace connection string, evicting the oldest entry if full.
    pub fn cache_connection_string(&mut self, namespace_id: String, connection_string: String) {
        Self::insert_bounded(
            &mut self.connection_strings,
            namespace_id,
            connection_string,
            self.max_entries_per_cache,
        );
    }

    /// Returns a clone of the cached connection string if present and not expired.
    pub fn get_cached_connection_string(&self, namespace_id: &str) -> Option<String> {
        Self::fresh_clone(self.connection_strings.get(namespace_id), self.cache_ttl)
    }

    /// Returns a clone of the cached subscription list if present and not expired.
    pub fn get_cached_subscriptions(&self) -> Option<Vec<Subscription>> {
        Self::fresh_clone(self.subscriptions.as_ref(), self.cache_ttl)
    }

    /// Returns a clone of the cached resource groups if present and not expired.
    pub fn get_cached_resource_groups(&self, subscription_id: &str) -> Option<Vec<ResourceGroup>> {
        Self::fresh_clone(self.resource_groups.get(subscription_id), self.cache_ttl)
    }

    /// Returns a clone of the cached namespaces if present and not expired.
    pub fn get_cached_namespaces(&self, subscription_id: &str) -> Option<Vec<ServiceBusNamespace>> {
        Self::fresh_clone(self.namespaces.get(subscription_id), self.cache_ttl)
    }

    /// Returns `true` only when nothing is stored at all.
    ///
    /// Expired entries that have not yet been cleaned still count as stored.
    pub fn is_empty(&self) -> bool {
        self.subscriptions.is_none()
            && self.resource_groups.is_empty()
            && self.namespaces.is_empty()
            && self.connection_strings.is_empty()
    }

    /// Removes every cached entry, expired or not.
    pub fn clear(&mut self) {
        self.subscriptions = None;
        self.resource_groups.clear();
        self.namespaces.clear();
        self.connection_strings.clear();
    }

    /// Drops every entry whose TTL has elapsed.
    pub fn clean_expired(&mut self) {
        let ttl = self.cache_ttl;
        if matches!(&self.subscriptions, Some(entry) if entry.is_expired(ttl)) {
            self.subscriptions = None;
        }
        self.resource_groups.retain(|_, entry| !entry.is_expired(ttl));
        self.namespaces.retain(|_, entry| !entry.is_expired(ttl));
        self.connection_strings
            .retain(|_, entry| !entry.is_expired(ttl));
    }

    /// Inserts `data` under `key`. When `key` is new and the map already holds
    /// `max_entries` items, the oldest entry is evicted first. Replacing an existing key never
    /// evicts anything.
    fn insert_bounded<T>(
        cache: &mut std::collections::HashMap<String, CacheEntry<T>>,
        key: String,
        data: T,
        max_entries: usize,
    ) {
        if cache.len() >= max_entries && !cache.contains_key(&key) {
            if let Some(oldest_key) = Self::find_oldest_entry(cache) {
                cache.remove(&oldest_key);
            }
        }
        cache.insert(key, CacheEntry::new(data));
    }

    /// Clones the entry's data if the entry exists and has not outlived `ttl`.
    fn fresh_clone<T: Clone>(entry: Option<&CacheEntry<T>>, ttl: Duration) -> Option<T> {
        entry
            .filter(|entry| !entry.is_expired(ttl))
            .map(|entry| entry.data.clone())
    }

    /// Returns the key with the earliest insertion time, or `None` for an empty map.
    fn find_oldest_entry<T>(
        cache: &std::collections::HashMap<String, CacheEntry<T>>,
    ) -> Option<String> {
        cache
            .iter()
            .min_by_key(|(_, entry)| entry.cached_at)
            .map(|(key, _)| key.clone())
    }
}

impl Default for AzureResourceCache {
    /// Same as [`AzureResourceCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Plain settings bundle for queue-statistics behavior.
///
/// Values are stored exactly as given. How they are interpreted is up to the consumers of this
/// type, which are not visible in this module.
#[derive(Debug, Clone)]
pub struct StatisticsConfig {
    /// Toggle for showing statistics (exact meaning defined by callers).
    pub display_enabled: bool,
    /// Cache lifetime, presumably in seconds as the name suggests — confirm at call sites.
    pub cache_ttl_seconds: u64,
    /// Presumably selects the Azure Management API as the statistics source — confirm at call sites.
    pub use_management_api: bool,
}

impl StatisticsConfig {
    /// Builds a config directly from its three field values.
    pub fn new(display_enabled: bool, cache_ttl_seconds: u64, use_management_api: bool) -> Self {
        StatisticsConfig {
            use_management_api,
            cache_ttl_seconds,
            display_enabled,
        }
    }
}