#![allow(dead_code)]
use async_lock::Mutex;
use serde::Deserialize;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use uuid::Uuid;
// URL requested by `do_fetch` (HTTP GET with the `metadata: true` header) to obtain the
// instance-metadata JSON document.
const IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/instance?api-version=2020-06-01";
// Connect timeout applied to the HTTP client built in `do_fetch`.
const IMDS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
// Overall request timeout applied to the HTTP client built in `do_fetch`.
const IMDS_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
// Prefix placed in front of the VM id when a machine id is derived from fetched metadata.
pub(crate) const VM_ID_PREFIX: &str = "vmId_";
// Prefix placed in front of a random v4 UUID when a fallback machine id is generated.
pub(crate) const UUID_PREFIX: &str = "uuid_";
// Process-wide shared service state; populated on the first call to
// `VmMetadataService::get_or_init` and reused by every later call.
static VM_METADATA: OnceLock<Arc<VmMetadataServiceInner>> = OnceLock::new();
/// Deserialized instance-metadata document; only the `compute` section is modelled.
///
/// Because of `#[serde(default)]`, a document without a `compute` key still deserializes
/// successfully and yields a default (all-empty) `ComputeMetadata`.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default)]
pub(crate) struct AzureVmMetadata {
/// Contents of the document's `compute` object, exposed through the accessor methods.
compute: ComputeMetadata,
}
impl AzureVmMetadata {
    /// Value of the compute section's `location` field (empty when it was not present).
    pub(crate) fn location(&self) -> &str {
        self.compute.location.as_str()
    }

    /// Value of the compute section's `sku` field (empty when it was not present).
    pub(crate) fn sku(&self) -> &str {
        self.compute.sku.as_str()
    }

    /// Value of the compute section's `azEnvironment` field (empty when it was not present).
    pub(crate) fn az_environment(&self) -> &str {
        self.compute.az_environment.as_str()
    }

    /// Value of the compute section's `osType` field (empty when it was not present).
    pub(crate) fn os_type(&self) -> &str {
        self.compute.os_type.as_str()
    }

    /// Value of the compute section's `vmSize` field (empty when it was not present).
    pub(crate) fn vm_size(&self) -> &str {
        self.compute.vm_size.as_str()
    }

    /// Value of the compute section's `vmId` field (empty when it was not present).
    pub(crate) fn vm_id(&self) -> &str {
        self.compute.vm_id.as_str()
    }

    /// Builds the machine id as `VM_ID_PREFIX` followed by the VM id.
    ///
    /// Returns an empty string when the VM id itself is empty, so callers can detect
    /// that no id is available.
    pub(crate) fn machine_id(&self) -> String {
        match self.vm_id() {
            "" => String::new(),
            id => [VM_ID_PREFIX, id].concat(),
        }
    }

    /// Joins OS type, SKU, VM size and Azure environment (in that order) with `|`.
    ///
    /// Missing values contribute empty segments; the separators are always emitted.
    pub(crate) fn host_env_info(&self) -> String {
        let segments = [
            self.os_type(),
            self.sku(),
            self.vm_size(),
            self.az_environment(),
        ];
        segments.join("|")
    }
}
/// Fields read from the `compute` object of the metadata document.
///
/// Keys are matched in camelCase (`rename_all = "camelCase"`); any key that is absent
/// leaves the corresponding field as an empty string (`#[serde(default)]`).
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default, rename_all = "camelCase")]
struct ComputeMetadata {
/// Value of the `location` key.
location: String,
/// Value of the `sku` key.
sku: String,
/// Value of the `azEnvironment` key.
az_environment: String,
/// Value of the `osType` key.
os_type: String,
/// Value of the `vmSize` key.
vm_size: String,
/// Value of the `vmId` key; used to derive the machine id when non-empty.
vm_id: String,
}
/// Cheaply cloneable snapshot of the process-wide VM metadata state, obtained through
/// `VmMetadataService::get_or_init`.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub(crate) struct VmMetadataService {
/// Fetched metadata, or `None` when the fetch was disabled or failed.
metadata: Option<Arc<AzureVmMetadata>>,
/// Machine id: either `VM_ID_PREFIX` + VM id or `UUID_PREFIX` + a random UUID.
machine_id: Arc<String>,
}
impl VmMetadataService {
pub(crate) async fn get_or_init() -> Self {
let inner = VM_METADATA.get_or_init(|| Arc::new(VmMetadataServiceInner::new()));
let state = inner.ensure_initialized().await;
Self {
metadata: state.metadata.clone(),
machine_id: state
.machine_id
.clone()
.expect("machine_id is always set after initialization"),
}
}
pub(crate) fn metadata(&self) -> Option<&AzureVmMetadata> {
self.metadata.as_deref()
}
pub(crate) fn machine_id(&self) -> &str {
&self.machine_id
}
pub(crate) fn is_on_azure(&self) -> bool {
self.metadata.is_some()
}
}
/// Shared backing store for `VmMetadataService`; a single instance lives in `VM_METADATA`.
#[derive(Debug)]
struct VmMetadataServiceInner {
/// Initialization state, guarded by an async mutex that `ensure_initialized` holds
/// while the one-time initialization runs.
state: Mutex<VmMetadataState>,
}
/// Mutable initialization state; cloned out of the mutex to hand callers a snapshot.
#[derive(Debug, Clone)]
struct VmMetadataState {
/// Metadata stored after a successful fetch; stays `None` otherwise.
metadata: Option<Arc<AzureVmMetadata>>,
/// Machine id; `None` before initialization, always `Some` once `do_init` has run.
machine_id: Option<Arc<String>>,
/// Set to `true` by `do_init` on every path (disabled, success, or failure) so the
/// initialization is attempted only once.
fetch_complete: bool,
}
impl VmMetadataServiceInner {
    /// Creates the service in its "nothing fetched yet" state.
    fn new() -> Self {
        Self {
            state: Mutex::new(VmMetadataState {
                metadata: None,
                machine_id: None,
                fetch_complete: false,
            }),
        }
    }

    /// Returns a snapshot of the state, running the one-time initialization first if it
    /// has not completed yet.
    ///
    /// The async mutex stays locked while `do_init` runs, so callers arriving during
    /// initialization wait for it instead of starting their own.
    async fn ensure_initialized(&self) -> VmMetadataState {
        let mut state = self.state.lock().await;
        if !state.fetch_complete {
            Self::do_init(&mut state).await;
        }
        state.clone()
    }

    /// Populates `state` exactly once: fetches metadata unless disabled, and always
    /// leaves a machine id in place and `fetch_complete` set.
    ///
    /// A fetch failure is not an error for callers; it only means a random fallback
    /// machine id is used and `metadata` stays `None`.
    async fn do_init(state: &mut VmMetadataState) {
        // `var_os` checks for presence only, so a set-but-non-UTF-8 value still disables
        // the fetch instead of being silently treated as "not set".
        if std::env::var_os("COSMOS_DISABLE_IMDS").is_some() {
            tracing::info!("IMDS access disabled via COSMOS_DISABLE_IMDS");
            state.machine_id = Some(Arc::new(Self::generate_fallback_machine_id()));
            state.fetch_complete = true;
            return;
        }
        match Self::do_fetch().await {
            Ok(metadata) => {
                tracing::debug!("Fetched Azure VM metadata: {:?}", metadata);
                // Reuse the metadata type's own id formatting; it yields an empty string
                // when the VM id is missing, in which case a random id is used instead.
                let derived_id = metadata.machine_id();
                let machine_id = if derived_id.is_empty() {
                    Self::generate_fallback_machine_id()
                } else {
                    derived_id
                };
                state.machine_id = Some(Arc::new(machine_id));
                state.metadata = Some(Arc::new(metadata));
            }
            Err(e) => {
                tracing::debug!("Failed to fetch Azure VM metadata (not on Azure?): {}", e);
                state.machine_id = Some(Arc::new(Self::generate_fallback_machine_id()));
            }
        }
        state.fetch_complete = true;
    }

    /// Builds a machine id from `UUID_PREFIX` and a freshly generated random (v4) UUID.
    fn generate_fallback_machine_id() -> String {
        format!("{}{}", UUID_PREFIX, Uuid::new_v4())
    }

    /// Requests the metadata document from `IMDS_ENDPOINT` and deserializes it.
    ///
    /// # Errors
    ///
    /// Returns an error when the HTTP client cannot be built, the request or body read
    /// fails, the endpoint answers with a non-success status, or the body is not valid
    /// JSON for `AzureVmMetadata`.
    #[cfg(feature = "reqwest")]
    async fn do_fetch() -> azure_core::Result<AzureVmMetadata> {
        let http_client = reqwest::Client::builder()
            // The endpoint is a link-local address; never route it through a
            // system-configured proxy.
            .no_proxy()
            .connect_timeout(IMDS_CONNECT_TIMEOUT)
            .timeout(IMDS_REQUEST_TIMEOUT)
            .build()
            .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Other, e))?;
        let response = http_client
            .get(IMDS_ENDPOINT)
            .header("metadata", "true")
            .send()
            .await
            .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Io, e))?
            // Reject 4xx/5xx responses. Every metadata field is `#[serde(default)]`, so
            // a JSON error body would otherwise deserialize into all-empty metadata and
            // be mistaken for a successful fetch.
            .error_for_status()
            .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Other, e))?;
        let body = response
            .text()
            .await
            .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::Io, e))?;
        let metadata: AzureVmMetadata = serde_json::from_str(&body)?;
        Ok(metadata)
    }

    /// Stand-in used when the `reqwest` feature is disabled; always fails.
    ///
    /// # Errors
    ///
    /// Always returns an error stating that the `reqwest` feature is required.
    #[cfg(not(feature = "reqwest"))]
    async fn do_fetch() -> azure_core::Result<AzureVmMetadata> {
        Err(azure_core::Error::with_message(
            azure_core::error::ErrorKind::Other,
            "IMDS fetch requires the `reqwest` feature",
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` into `AzureVmMetadata`, panicking if it is malformed.
    fn parse(json: &str) -> AzureVmMetadata {
        serde_json::from_str(json).unwrap()
    }

    /// A fully populated `compute` section is exposed field by field, and the machine
    /// id is the prefixed VM id.
    #[test]
    fn azure_vm_metadata_deserialize() {
        let parsed = parse(
            r#"{
"compute": {
"location": "eastus",
"sku": "Standard",
"azEnvironment": "AzurePublicCloud",
"osType": "Linux",
"vmSize": "Standard_D2s_v3",
"vmId": "12345678-1234-1234-1234-123456789012"
}
}"#,
        );
        let expectations = [
            (parsed.location(), "eastus"),
            (parsed.sku(), "Standard"),
            (parsed.az_environment(), "AzurePublicCloud"),
            (parsed.os_type(), "Linux"),
            (parsed.vm_size(), "Standard_D2s_v3"),
            (parsed.vm_id(), "12345678-1234-1234-1234-123456789012"),
        ];
        for (actual, expected) in expectations {
            assert_eq!(actual, expected);
        }
        assert_eq!(
            parsed.machine_id(),
            "vmId_12345678-1234-1234-1234-123456789012"
        );
    }

    /// Default metadata has empty fields and therefore an empty machine id.
    #[test]
    fn azure_vm_metadata_empty() {
        let blank = AzureVmMetadata::default();
        assert_eq!(blank.location(), "");
        assert_eq!(blank.machine_id(), "");
    }

    /// The host environment string lists OS type, SKU, VM size and environment,
    /// separated by `|`.
    #[test]
    fn azure_vm_metadata_host_env_info() {
        let parsed = parse(
            r#"{
"compute": {
"osType": "Linux",
"sku": "18.04-LTS",
"vmSize": "Standard_D2s_v3",
"azEnvironment": "AzurePublicCloud"
}
}"#,
        );
        let summary = parsed.host_env_info();
        assert_eq!(summary, "Linux|18.04-LTS|Standard_D2s_v3|AzurePublicCloud");
    }
}