use dactor::{ClusterDiscovery, DiscoveryError};
use std::fmt;
/// Errors produced while discovering cluster peers through Azure services.
#[derive(Debug)]
pub enum AzureDiscoveryError {
    /// The instance metadata service (IMDS) rejected a request or returned
    /// data that cannot be used for discovery.
    ImdsError(String),
    /// An Azure Resource Manager (ARM) call failed or returned unusable data.
    ArmApiError(String),
    /// The underlying HTTP request failed.
    HttpError(reqwest::Error),
    /// A response body could not be decoded into the expected structure.
    ParseError(String),
    /// The discovery configuration is incomplete or invalid.
    Config(String),
}
impl fmt::Display for AzureDiscoveryError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AzureDiscoveryError::ImdsError(e) => write!(f, "IMDS error: {e}"),
AzureDiscoveryError::ArmApiError(e) => write!(f, "ARM API error: {e}"),
AzureDiscoveryError::HttpError(e) => write!(f, "HTTP error: {e}"),
AzureDiscoveryError::ParseError(e) => write!(f, "parse error: {e}"),
AzureDiscoveryError::Config(e) => write!(f, "configuration error: {e}"),
}
}
}
impl std::error::Error for AzureDiscoveryError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
AzureDiscoveryError::HttpError(e) => Some(e),
_ => None,
}
}
}
impl From<reqwest::Error> for AzureDiscoveryError {
fn from(e: reqwest::Error) -> Self {
AzureDiscoveryError::HttpError(e)
}
}
/// Top-level document returned by the IMDS instance endpoint; only the
/// `compute` section is deserialized.
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct ImdsResponse {
    /// Compute metadata of the VM that issued the request.
    compute: ImdsCompute,
}
/// The subset of the IMDS `compute` section that discovery relies on.
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct ImdsCompute {
    /// Read from the `subscriptionId` key.
    subscription_id: String,
    /// Read from the `resourceGroupName` key.
    resource_group_name: String,
    /// Read from the `vmScaleSetName` key; `None` when the key is absent.
    #[serde(default, rename = "vmScaleSetName")]
    vmss_name: Option<String>,
}
/// One page of an ARM list response.
#[derive(Debug, Clone, serde::Deserialize)]
struct ArmListResponse<T> {
    /// Items contained in this page.
    value: Vec<T>,
    /// URL of the following page (`nextLink`); `None` when the key is absent.
    #[serde(default, rename = "nextLink")]
    next_link: Option<String>,
}
/// A network interface resource as returned by ARM; only `properties` is read.
#[derive(Debug, Clone, serde::Deserialize)]
struct ArmNetworkInterface {
    /// Resource-specific properties of the NIC.
    properties: ArmNicProperties,
}
/// The `properties` object of a network interface.
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct ArmNicProperties {
    /// IP configurations attached to the NIC (`ipConfigurations` key).
    ip_configurations: Vec<ArmIpConfiguration>,
}
/// A single IP configuration entry of a network interface.
#[derive(Debug, Clone, serde::Deserialize)]
struct ArmIpConfiguration {
    /// Properties of this IP configuration, including its private address.
    properties: ArmIpConfigProperties,
}
/// The `properties` object of a single NIC IP configuration.
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct ArmIpConfigProperties {
    /// Private IP address assigned to this IP configuration, if any.
    ///
    /// ARM spells this key `privateIPAddress` (capital "IP"). The container's
    /// `camelCase` rule would look for `privateIpAddress` instead, which never
    /// matches, and because the field is optional the mismatch would silently
    /// yield `None` for every NIC. The key is therefore named explicitly; the
    /// camelCase spelling stays accepted as an alias.
    #[serde(default, rename = "privateIPAddress", alias = "privateIpAddress")]
    private_ip_address: Option<String>,
}
/// Body of a successful managed-identity token response; only the token
/// itself is deserialized.
#[derive(Debug, Clone, serde::Deserialize)]
struct TokenResponse {
    /// Bearer token read from the `access_token` key.
    access_token: String,
}
async fn acquire_managed_identity_token(
client: &reqwest::Client,
) -> Result<String, AzureDiscoveryError> {
let resp = client
.get("http://169.254.169.254/metadata/identity/oauth2/token")
.header("Metadata", "true")
.query(&[
("api-version", "2019-08-01"),
("resource", "https://management.azure.com/"),
])
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(AzureDiscoveryError::ImdsError(format!(
"token request failed ({status}): {body}"
)));
}
let token: TokenResponse = resp
.json()
.await
.map_err(|e| AzureDiscoveryError::ParseError(format!("token response: {e}")))?;
Ok(token.access_token)
}
/// Base URL (scheme and link-local host) of the instance metadata service.
const IMDS_BASE: &str = "http://169.254.169.254";
/// `api-version` sent with IMDS instance-metadata requests.
const IMDS_API_VERSION: &str = "2021-02-01";
/// `api-version` sent with ARM network-interface requests.
const ARM_API_VERSION_NIC: &str = "2023-09-01";
/// `api-version` sent with ARM virtual-machine list requests.
const ARM_API_VERSION_VM: &str = "2023-09-01";
/// Fetches the instance metadata document of the current VM from IMDS.
///
/// # Errors
///
/// * [`AzureDiscoveryError::HttpError`] if the request cannot be sent.
/// * [`AzureDiscoveryError::ImdsError`] on a non-success status.
/// * [`AzureDiscoveryError::ParseError`] if the body cannot be decoded.
async fn query_imds(client: &reqwest::Client) -> Result<ImdsResponse, AzureDiscoveryError> {
    let endpoint = format!("{IMDS_BASE}/metadata/instance");
    let response = client
        .get(endpoint)
        .header("Metadata", "true")
        .query(&[("api-version", IMDS_API_VERSION)])
        .send()
        .await?;

    let status = response.status();
    if status.is_success() {
        return response
            .json()
            .await
            .map_err(|e| AzureDiscoveryError::ParseError(format!("IMDS response: {e}")));
    }

    // Failure path: surface whatever body text is available with the status.
    let body = response.text().await.unwrap_or_default();
    Err(AzureDiscoveryError::ImdsError(format!(
        "IMDS returned {status}: {body}"
    )))
}
/// Returns the subscription id reported by IMDS for the current VM.
///
/// Uses a fresh HTTP client with a 10 second timeout. Any failure of the
/// metadata query is mapped to `None`.
pub async fn current_subscription_id() -> Option<String> {
    let request_timeout = std::time::Duration::from_secs(10);
    let client = reqwest::Client::builder()
        .timeout(request_timeout)
        .build()
        .unwrap_or_default();
    match query_imds(&client).await {
        Ok(metadata) => Some(metadata.compute.subscription_id),
        Err(_) => None,
    }
}
/// Returns the resource group name reported by IMDS for the current VM.
///
/// Uses a fresh HTTP client with a 10 second timeout. Any failure of the
/// metadata query is mapped to `None`.
pub async fn current_resource_group() -> Option<String> {
    let request_timeout = std::time::Duration::from_secs(10);
    let client = reqwest::Client::builder()
        .timeout(request_timeout)
        .build()
        .unwrap_or_default();
    match query_imds(&client).await {
        Ok(metadata) => Some(metadata.compute.resource_group_name),
        Err(_) => None,
    }
}
/// Builds the full IMDS instance-metadata URL, including the `api-version`
/// query parameter.
pub fn imds_instance_url() -> String {
    [IMDS_BASE, "/metadata/instance?api-version=", IMDS_API_VERSION].concat()
}
/// Settings for discovering peers inside a virtual machine scale set.
#[derive(Debug, Clone)]
pub struct VmssDiscoveryConfig {
    /// Port appended to every discovered private IP address.
    pub port: u16,
    /// Whether values that are not set explicitly may be looked up via IMDS.
    pub use_imds: bool,
    /// Explicit subscription id; takes precedence over the IMDS value.
    pub subscription_id: Option<String>,
    /// Explicit resource group; takes precedence over the IMDS value.
    pub resource_group: Option<String>,
    /// Explicit scale set name; takes precedence over the IMDS value.
    pub vmss_name: Option<String>,
}

/// Defaults: port 9000, IMDS lookups enabled, nothing set explicitly.
impl Default for VmssDiscoveryConfig {
    fn default() -> Self {
        VmssDiscoveryConfig {
            vmss_name: None,
            resource_group: None,
            subscription_id: None,
            use_imds: true,
            port: 9000,
        }
    }
}
/// Peer discovery based on the network interfaces of a VM scale set.
pub struct VmssDiscovery {
    /// Settings this instance was built with.
    config: VmssDiscoveryConfig,
    /// HTTP client shared by all IMDS and ARM requests of this instance.
    client: reqwest::Client,
}
impl VmssDiscovery {
pub fn builder() -> VmssDiscoveryBuilder {
VmssDiscoveryBuilder {
config: VmssDiscoveryConfig::default(),
}
}
pub fn config(&self) -> &VmssDiscoveryConfig {
&self.config
}
async fn resolve_vmss_info(
&self,
) -> Result<(String, String, String), AzureDiscoveryError> {
if let (Some(sub), Some(rg), Some(vmss)) = (
self.config.subscription_id.clone(),
self.config.resource_group.clone(),
self.config.vmss_name.clone(),
) {
return Ok((sub, rg, vmss));
}
if !self.config.use_imds {
return Err(AzureDiscoveryError::Config(
"use_imds is false but subscription_id, resource_group, or vmss_name is missing"
.to_string(),
));
}
let imds = query_imds(&self.client).await?;
let sub = self
.config
.subscription_id
.clone()
.unwrap_or(imds.compute.subscription_id);
let rg = self
.config
.resource_group
.clone()
.unwrap_or(imds.compute.resource_group_name);
let vmss = self.config.vmss_name.clone().or(imds.compute.vmss_name).ok_or_else(
|| {
AzureDiscoveryError::ImdsError(
"current VM is not part of a VMSS".to_string(),
)
},
)?;
Ok((sub, rg, vmss))
}
pub async fn discover_instances(&self) -> Result<Vec<String>, AzureDiscoveryError> {
let (subscription_id, resource_group, vmss_name) =
self.resolve_vmss_info().await?;
let token = acquire_managed_identity_token(&self.client).await?;
let url = format!(
"https://management.azure.com/subscriptions/{subscription_id}\
/resourceGroups/{resource_group}\
/providers/Microsoft.Compute/virtualMachineScaleSets/{vmss_name}\
/networkInterfaces?api-version={ARM_API_VERSION_NIC}"
);
let mut addresses = Vec::new();
let mut next_url: Option<String> = Some(url);
while let Some(page_url) = next_url.take() {
let resp = self
.client
.get(&page_url)
.bearer_auth(&token)
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(AzureDiscoveryError::ArmApiError(format!(
"list NICs failed ({status}): {body}"
)));
}
let page: ArmListResponse<ArmNetworkInterface> = resp
.json()
.await
.map_err(|e| AzureDiscoveryError::ParseError(format!("NIC list: {e}")))?;
for nic in &page.value {
for ip_config in &nic.properties.ip_configurations {
if let Some(ip) = &ip_config.properties.private_ip_address {
addresses.push(format!("{ip}:{}", self.config.port));
}
}
}
next_url = page.next_link;
}
tracing::debug!(count = addresses.len(), "VMSS discovery complete");
Ok(addresses)
}
}
/// Adapts [`VmssDiscovery::discover_instances`] to the `ClusterDiscovery`
/// trait: addresses become peers and errors are carried over as their
/// display text.
#[async_trait::async_trait]
impl ClusterDiscovery for VmssDiscovery {
    async fn discover(&self) -> Result<Vec<dactor::DiscoveredPeer>, DiscoveryError> {
        match self.discover_instances().await {
            Ok(addresses) => Ok(addresses
                .into_iter()
                .map(dactor::DiscoveredPeer::from_address)
                .collect()),
            Err(failure) => Err(DiscoveryError::new(failure.to_string())),
        }
    }
}
/// Builder for [`VmssDiscovery`].
pub struct VmssDiscoveryBuilder {
    /// Configuration accumulated by the setter calls so far.
    config: VmssDiscoveryConfig,
}
impl VmssDiscoveryBuilder {
pub fn port(mut self, port: u16) -> Self {
self.config.port = port;
self
}
pub fn use_imds(mut self, yes: bool) -> Self {
self.config.use_imds = yes;
self
}
pub fn subscription_id(mut self, id: &str) -> Self {
self.config.subscription_id = Some(id.to_string());
self
}
pub fn resource_group(mut self, rg: &str) -> Self {
self.config.resource_group = Some(rg.to_string());
self
}
pub fn vmss_name(mut self, name: &str) -> Self {
self.config.vmss_name = Some(name.to_string());
self
}
pub fn build(self) -> VmssDiscovery {
VmssDiscovery {
config: self.config,
client: reqwest::Client::builder().timeout(std::time::Duration::from_secs(10)).build().unwrap_or_default(),
}
}
}
/// Settings for discovering peers by a tag on virtual machines.
#[derive(Debug, Clone)]
pub struct AzureTagConfig {
    /// Name of the tag a VM must carry.
    pub tag_key: String,
    /// Value the tag must have for the VM to be selected.
    pub tag_value: String,
    /// Port appended to every discovered private IP address.
    pub port: u16,
    /// Explicit subscription id; when absent it is looked up via IMDS.
    pub subscription_id: Option<String>,
    /// Restricts the VM listing to one resource group when set.
    pub resource_group: Option<String>,
}

/// Defaults: empty tag key and value, port 9000, no subscription or
/// resource group.
impl Default for AzureTagConfig {
    fn default() -> Self {
        AzureTagConfig {
            resource_group: None,
            subscription_id: None,
            port: 9000,
            tag_value: String::default(),
            tag_key: String::default(),
        }
    }
}
/// Peer discovery based on a tag carried by virtual machines.
pub struct AzureTagDiscovery {
    /// Settings this instance was built with.
    config: AzureTagConfig,
    /// HTTP client shared by all IMDS and ARM requests of this instance.
    client: reqwest::Client,
}
impl AzureTagDiscovery {
    /// Starts a builder seeded with [`AzureTagConfig::default`].
    pub fn builder() -> AzureTagDiscoveryBuilder {
        AzureTagDiscoveryBuilder {
            config: AzureTagConfig::default(),
        }
    }

    /// Returns the configuration this instance was built with.
    pub fn config(&self) -> &AzureTagConfig {
        &self.config
    }

    /// Returns the configured subscription id, falling back to the one IMDS
    /// reports for the current VM.
    async fn resolve_subscription(&self) -> Result<String, AzureDiscoveryError> {
        if let Some(sub) = &self.config.subscription_id {
            return Ok(sub.clone());
        }
        let imds = query_imds(&self.client).await?;
        Ok(imds.compute.subscription_id)
    }

    /// Lists `ip:port` addresses of all VMs whose tag `tag_key` equals
    /// `tag_value`.
    ///
    /// VMs are listed per resource group when one is configured, otherwise
    /// across the whole subscription; all pages are followed via `nextLink`.
    /// For each matching VM the private IP of its first network interface is
    /// resolved with a separate ARM request.
    ///
    /// Resolution is best effort per VM: if the NIC lookup of a matching VM
    /// fails, that VM is skipped and a warning is logged, so one broken NIC
    /// does not abort the whole discovery.
    ///
    /// # Errors
    ///
    /// * [`AzureDiscoveryError::Config`] if `tag_key` is empty.
    /// * [`AzureDiscoveryError::ArmApiError`] if listing VMs returns a
    ///   non-success status.
    /// * [`AzureDiscoveryError::ParseError`] if a page cannot be decoded.
    /// * Any error from resolving the subscription, acquiring the token, or
    ///   sending a list request.
    pub async fn discover_by_tag(&self) -> Result<Vec<String>, AzureDiscoveryError> {
        if self.config.tag_key.is_empty() {
            return Err(AzureDiscoveryError::Config(
                "tag_key must not be empty".to_string(),
            ));
        }
        let subscription_id = self.resolve_subscription().await?;
        let token = acquire_managed_identity_token(&self.client).await?;
        let base_url = if let Some(rg) = &self.config.resource_group {
            format!(
                "https://management.azure.com/subscriptions/{subscription_id}\
                 /resourceGroups/{rg}\
                 /providers/Microsoft.Compute/virtualMachines\
                 ?api-version={ARM_API_VERSION_VM}"
            )
        } else {
            format!(
                "https://management.azure.com/subscriptions/{subscription_id}\
                 /providers/Microsoft.Compute/virtualMachines\
                 ?api-version={ARM_API_VERSION_VM}"
            )
        };

        let mut addresses = Vec::new();
        let mut next_url: Option<String> = Some(base_url);
        // `take()` clears the cursor; it is only refilled when the page
        // carries a `nextLink`, which ends the loop after the last page.
        while let Some(page_url) = next_url.take() {
            let resp = self
                .client
                .get(&page_url)
                .bearer_auth(&token)
                .send()
                .await?;
            if !resp.status().is_success() {
                let status = resp.status();
                let body = resp.text().await.unwrap_or_default();
                return Err(AzureDiscoveryError::ArmApiError(format!(
                    "list VMs failed ({status}): {body}"
                )));
            }
            let page: ArmListResponse<serde_json::Value> = resp
                .json()
                .await
                .map_err(|e| AzureDiscoveryError::ParseError(format!("VM list: {e}")))?;
            for vm in &page.value {
                // A VM matches only if the tag exists, is a string, and
                // equals the configured value exactly.
                let tags = vm.get("tags").and_then(|t| t.as_object());
                let matches = tags
                    .and_then(|t| t.get(&self.config.tag_key))
                    .and_then(|v| v.as_str())
                    .map(|v| v == self.config.tag_value)
                    .unwrap_or(false);
                if !matches {
                    continue;
                }
                if let Some(nic_id) = vm
                    .pointer("/properties/networkProfile/networkInterfaces/0/id")
                    .and_then(|v| v.as_str())
                {
                    match self.fetch_nic_private_ip(nic_id, &token).await {
                        Ok(ip) => addresses.push(format!("{ip}:{}", self.config.port)),
                        // Keep the best-effort skip, but leave a trace so a
                        // missing peer can be diagnosed.
                        Err(err) => tracing::warn!(
                            nic_id,
                            error = %err,
                            "skipping tagged VM: failed to resolve NIC private IP"
                        ),
                    }
                }
            }
            next_url = page.next_link;
        }

        tracing::debug!(count = addresses.len(), "Azure tag discovery complete");
        Ok(addresses)
    }

    /// Fetches a network interface by its ARM resource id and returns the
    /// private IP address of its first IP configuration.
    ///
    /// `nic_id` is appended directly to the ARM host, so it is expected to be
    /// a resource id path beginning with `/`.
    ///
    /// # Errors
    ///
    /// * [`AzureDiscoveryError::HttpError`] if the request cannot be sent.
    /// * [`AzureDiscoveryError::ArmApiError`] on a non-success status, or if
    ///   the first IP configuration is missing or has no private address.
    /// * [`AzureDiscoveryError::ParseError`] if the body cannot be decoded.
    async fn fetch_nic_private_ip(
        &self,
        nic_id: &str,
        token: &str,
    ) -> Result<String, AzureDiscoveryError> {
        let url = format!(
            "https://management.azure.com{nic_id}?api-version={ARM_API_VERSION_NIC}"
        );
        let resp = self
            .client
            .get(&url)
            .bearer_auth(token)
            .send()
            .await?;
        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(AzureDiscoveryError::ArmApiError(format!(
                "get NIC failed ({status}): {body}"
            )));
        }
        let nic: ArmNetworkInterface = resp
            .json()
            .await
            .map_err(|e| AzureDiscoveryError::ParseError(format!("NIC details: {e}")))?;
        nic.properties
            .ip_configurations
            .first()
            .and_then(|c| c.properties.private_ip_address.clone())
            .ok_or_else(|| {
                AzureDiscoveryError::ArmApiError(
                    "NIC has no private IP configuration".to_string(),
                )
            })
    }
}
/// Adapts [`AzureTagDiscovery::discover_by_tag`] to the `ClusterDiscovery`
/// trait: addresses become peers and errors are carried over as their
/// display text.
#[async_trait::async_trait]
impl ClusterDiscovery for AzureTagDiscovery {
    async fn discover(&self) -> Result<Vec<dactor::DiscoveredPeer>, DiscoveryError> {
        match self.discover_by_tag().await {
            Ok(addresses) => Ok(addresses
                .into_iter()
                .map(dactor::DiscoveredPeer::from_address)
                .collect()),
            Err(failure) => Err(DiscoveryError::new(failure.to_string())),
        }
    }
}
/// Builder for [`AzureTagDiscovery`].
pub struct AzureTagDiscoveryBuilder {
    /// Configuration accumulated by the setter calls so far.
    config: AzureTagConfig,
}
impl AzureTagDiscoveryBuilder {
pub fn tag_key(mut self, key: &str) -> Self {
self.config.tag_key = key.to_string();
self
}
pub fn tag_value(mut self, value: &str) -> Self {
self.config.tag_value = value.to_string();
self
}
pub fn port(mut self, port: u16) -> Self {
self.config.port = port;
self
}
pub fn subscription_id(mut self, id: &str) -> Self {
self.config.subscription_id = Some(id.to_string());
self
}
pub fn resource_group(mut self, rg: &str) -> Self {
self.config.resource_group = Some(rg.to_string());
self
}
pub fn build(self) -> AzureTagDiscovery {
AzureTagDiscovery {
config: self.config,
client: reqwest::Client::builder().timeout(std::time::Duration::from_secs(10)).build().unwrap_or_default(),
}
}
}
/// Reads the `DACTOR_VM_IP` environment variable.
///
/// Returns `None` when the variable is unset or is not valid Unicode.
pub fn vm_private_ip() -> Option<String> {
    const VARIABLE: &str = "DACTOR_VM_IP";
    std::env::var(VARIABLE).ok()
}
/// Reads the `AZURE_SUBSCRIPTION_ID` environment variable.
///
/// Returns `None` when the variable is unset or is not valid Unicode.
pub fn subscription_id() -> Option<String> {
    const VARIABLE: &str = "AZURE_SUBSCRIPTION_ID";
    std::env::var(VARIABLE).ok()
}
/// Reads the `AZURE_RESOURCE_GROUP` environment variable.
///
/// Returns `None` when the variable is unset or is not valid Unicode.
pub fn resource_group() -> Option<String> {
    const VARIABLE: &str = "AZURE_RESOURCE_GROUP";
    std::env::var(VARIABLE).ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every VMSS builder setter must show up in the built configuration.
    #[test]
    fn vmss_builder_creates_valid_config() {
        let discovery = VmssDiscovery::builder()
            .port(8080)
            .use_imds(false)
            .subscription_id("sub-123")
            .resource_group("my-rg")
            .vmss_name("my-vmss")
            .build();
        let cfg = discovery.config();
        assert_eq!(cfg.port, 8080);
        assert!(!cfg.use_imds);
        assert_eq!(cfg.subscription_id.as_deref(), Some("sub-123"));
        assert_eq!(cfg.resource_group.as_deref(), Some("my-rg"));
        assert_eq!(cfg.vmss_name.as_deref(), Some("my-vmss"));
    }

    /// An untouched VMSS builder yields the default configuration.
    #[test]
    fn vmss_builder_default_values() {
        let discovery = VmssDiscovery::builder().build();
        let cfg = discovery.config();
        assert_eq!(cfg.port, 9000);
        assert!(cfg.use_imds);
        assert!(cfg.subscription_id.is_none());
        assert!(cfg.resource_group.is_none());
        assert!(cfg.vmss_name.is_none());
    }

    /// `VmssDiscoveryConfig::default` uses port 9000 with IMDS enabled.
    #[test]
    fn vmss_default_config() {
        let cfg = VmssDiscoveryConfig::default();
        assert_eq!(cfg.port, 9000);
        assert!(cfg.use_imds);
        assert!(cfg.subscription_id.is_none());
        assert!(cfg.resource_group.is_none());
        assert!(cfg.vmss_name.is_none());
    }

    /// Every tag builder setter must show up in the built configuration.
    #[test]
    fn tag_builder_creates_valid_config() {
        let discovery = AzureTagDiscovery::builder()
            .tag_key("dactor-cluster")
            .tag_value("production")
            .port(7000)
            .subscription_id("sub-456")
            .resource_group("prod-rg")
            .build();
        let cfg = discovery.config();
        assert_eq!(cfg.tag_key, "dactor-cluster");
        assert_eq!(cfg.tag_value, "production");
        assert_eq!(cfg.port, 7000);
        assert_eq!(cfg.subscription_id.as_deref(), Some("sub-456"));
        assert_eq!(cfg.resource_group.as_deref(), Some("prod-rg"));
    }

    /// Settings the tag builder was not given keep their defaults.
    #[test]
    fn tag_builder_default_values() {
        let discovery = AzureTagDiscovery::builder()
            .tag_key("cluster")
            .tag_value("dev")
            .build();
        let cfg = discovery.config();
        assert_eq!(cfg.port, 9000);
        assert!(cfg.subscription_id.is_none());
        assert!(cfg.resource_group.is_none());
    }

    /// `AzureTagConfig::default` has empty tag fields and port 9000.
    #[test]
    fn tag_default_config() {
        let cfg = AzureTagConfig::default();
        assert!(cfg.tag_key.is_empty());
        assert!(cfg.tag_value.is_empty());
        assert_eq!(cfg.port, 9000);
        assert!(cfg.subscription_id.is_none());
        assert!(cfg.resource_group.is_none());
    }

    /// Without `DACTOR_VM_IP` in the environment there is no VM IP.
    #[test]
    fn vm_private_ip_returns_none_outside_azure() {
        std::env::remove_var("DACTOR_VM_IP");
        assert!(vm_private_ip().is_none());
    }

    /// Without `AZURE_SUBSCRIPTION_ID` there is no subscription id.
    #[test]
    fn subscription_id_returns_none_outside_azure() {
        std::env::remove_var("AZURE_SUBSCRIPTION_ID");
        assert!(subscription_id().is_none());
    }

    /// Without `AZURE_RESOURCE_GROUP` there is no resource group.
    #[test]
    fn resource_group_returns_none_outside_azure() {
        std::env::remove_var("AZURE_RESOURCE_GROUP");
        assert!(resource_group().is_none());
    }

    /// IMDS errors render with the `IMDS error:` prefix.
    #[test]
    fn error_display_imds() {
        let rendered = AzureDiscoveryError::ImdsError("timeout".to_string()).to_string();
        assert_eq!(rendered, "IMDS error: timeout");
    }

    /// ARM errors render with the `ARM API error:` prefix.
    #[test]
    fn error_display_arm_api() {
        let rendered =
            AzureDiscoveryError::ArmApiError("403 forbidden".to_string()).to_string();
        assert_eq!(rendered, "ARM API error: 403 forbidden");
    }

    /// Parse errors render with the `parse error:` prefix.
    #[test]
    fn error_display_parse() {
        let rendered =
            AzureDiscoveryError::ParseError("invalid json".to_string()).to_string();
        assert_eq!(rendered, "parse error: invalid json");
    }

    /// Configuration errors render with the `configuration error:` prefix.
    #[test]
    fn error_display_config() {
        let rendered =
            AzureDiscoveryError::Config("missing subscription".to_string()).to_string();
        assert_eq!(rendered, "configuration error: missing subscription");
    }

    /// The IMDS instance URL targets the metadata host and pins the API version.
    #[test]
    fn imds_url_contains_api_version() {
        let url = imds_instance_url();
        assert!(url.starts_with("http://169.254.169.254/metadata/instance"));
        assert!(url.contains("api-version=2021-02-01"));
    }
}