cloud_terrastodon_azure 0.35.1

Helpers for interacting with Azure for the Cloud Terrastodon project
use cloud_terrastodon_azure_types::AzureTenantId;
use cloud_terrastodon_azure_types::ResourceGraphQueryResponse;
use cloud_terrastodon_command::CacheKey;
use cloud_terrastodon_command::CommandBuilder;
use cloud_terrastodon_command::CommandKind;
use cloud_terrastodon_command::FromCommandOutput;
use cloud_terrastodon_relative_location::RelativeLocation;
use eyre::Context;
use eyre::Result;
#[cfg(debug_assertions)]
use eyre::bail;
use serde::Deserialize;
use serde::Serialize;
#[cfg(debug_assertions)]
use std::collections::HashSet;
use std::future::Future;
use std::panic::Location;
use std::path::PathBuf;
use tracing::debug;

/// Stateful, paginated runner for a single Azure Resource Graph query.
///
/// Each fetch sends the query to the Resource Graph REST endpoint and records
/// the continuation state returned by the service so the next fetch asks for
/// the following batch.
#[derive(Debug)]
pub struct ResourceGraphHelper {
    /// Query text sent verbatim in the request body.
    query: String,
    /// Optional cache key supplied by the caller. When present, each batch is
    /// cached under this key's path joined with the batch index.
    cache_behaviour: Option<CacheKey>,
    /// Tenant passed to the command via `--tenant`.
    tenant_id: AzureTenantId,
    /// Continuation state for the next request: the running total of the
    /// `count` values reported by earlier responses, plus the skip token from
    /// the most recent response. `None` before the first request and after a
    /// response that carried no skip token.
    skip: Option<(u64, String)>,
    /// Number of batches fetched successfully so far; also used as the
    /// per-batch cache sub-path.
    index: usize,
    /// Debug-build-only record of every skip token already sent, used to
    /// detect a token that repeats (which would otherwise loop forever).
    #[cfg(debug_assertions)]
    seen_skip_tokens: HashSet<String>,
}
/// The `options` object of a Resource Graph REST query request body.
///
/// Plain fields are emitted in camelCase; the paging fields carry the
/// `$`-prefixed names and are therefore renamed explicitly.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceGraphQueryRestOptions {
    /// Serialized as `$skip`. This file sends the running total of the
    /// `count` values from earlier responses (0 for the first request).
    #[serde(rename = "$skip")]
    skip: u64,
    /// Serialized as `$top`. This file always sends its batch size here.
    // NOTE(review): presumably the maximum number of rows per response —
    // confirm against the Resource Graph REST reference.
    #[serde(rename = "$top")]
    top: u64,
    /// Serialized as `$skipToken`. The continuation token taken from the
    /// previous response, or `None` for the first request.
    #[serde(rename = "$skipToken")]
    skip_token: Option<String>,
    /// Serialized as `authorizationScopeFilter`.
    authorization_scope_filter: ResourceGraphQueryRestScopeFilterOption,
    /// Serialized as `resultFormat`.
    result_format: QueryRestResultFormat,
}

#[derive(Debug, Serialize, Deserialize)]
pub enum ResourceGraphQueryRestScopeFilterOption {
    AtScopeAboveAndBelow,
}

/// Value sent as `resultFormat` in the request options.
///
/// Variants are serialized in camelCase, so `Table` goes over the wire as
/// `table`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum QueryRestResultFormat {
    /// Serialized as `table`.
    Table,
}

/// JSON body posted to the Resource Graph REST endpoint: the query text plus
/// its paging/format options.
#[derive(Debug, Serialize, Deserialize)]
pub struct ResourceGraphQueryRestBody {
    /// Query text to execute.
    query: String,
    /// Paging, scope-filter and result-format options for this request.
    options: ResourceGraphQueryRestOptions,
}

impl ResourceGraphHelper {
    pub fn new(
        tenant_id: AzureTenantId,
        query: impl Into<String>,
        cache_behaviour: Option<CacheKey>,
    ) -> Self {
        Self {
            query: query.into(),
            cache_behaviour,
            tenant_id,
            skip: None,
            index: 0,
            #[cfg(debug_assertions)]
            seen_skip_tokens: Default::default(),
        }
    }

    /// Builds the `rest` command that POSTs `body` to the Resource Graph
    /// endpoint for this helper's tenant.
    ///
    /// The command is given the caller-supplied cache key, or a default
    /// `az/resource_graph/query` key when none was supplied. The order of the
    /// calls on the builder is kept exactly as-is because it determines the
    /// order of the command-line arguments.
    fn get_command(&self, body: String) -> CommandBuilder {
        const QUERY_URL: &str = "https://management.azure.com/providers/Microsoft.ResourceGraph/resources?api-version=2022-10-01";

        let cache_key = match &self.cache_behaviour {
            Some(key) => key.clone(),
            None => CacheKey::new(PathBuf::from_iter(["az", "resource_graph", "query"])),
        };
        let tenant = self.tenant_id.to_string();

        let mut builder = CommandBuilder::new(CommandKind::CloudTerrastodon);
        builder.args(["rest", "--method", "POST", "--url", QUERY_URL]);
        builder.cache(cache_key);
        // The request body is handed over as a file argument, not inline.
        builder.arg("--body");
        builder.azure_file_arg("body.json", body);
        builder.args(["--tenant", tenant.as_str()]);
        builder.use_cache(self.cache_behaviour.clone());
        builder
    }

    /// Fetches the next batch of results for this query.
    ///
    /// This is a plain `fn` returning a future (not an `async fn`) so that
    /// `Location::caller()` is evaluated here, in the `#[track_caller]`
    /// frame; the captured call site is attached to any error produced.
    ///
    /// # Errors
    /// Propagates any failure from building, running or parsing the request.
    #[track_caller]
    pub fn fetch<T: FromCommandOutput>(
        &mut self,
    ) -> impl Future<Output = Result<Option<ResourceGraphQueryResponse<T>>>> + '_ {
        let call_site = Location::caller();
        self.fetch_from(call_site)
    }

    /// Fetches one batch and advances the pagination state.
    ///
    /// Sends the query with the current `$skip`/`$skipToken`, then:
    /// * increments the batch index, and
    /// * stores the response's skip token (with the updated running count) as
    ///   the continuation for the next call, or clears it when the response
    ///   carries no token.
    ///
    /// `caller` is only used to annotate errors with the originating call site.
    ///
    /// NOTE(review): on success this always returns `Some(..)`; `None` is never
    /// produced. After the final batch the continuation is cleared, so another
    /// call starts again from the first batch. Callers that loop must check the
    /// continuation state themselves, as `collect_all` does.
    ///
    /// # Errors
    /// Fails if the body cannot be serialized, if the command fails, or (debug
    /// builds only) if the same skip token is about to be sent twice. Every
    /// error is wrapped with the caller's location.
    async fn fetch_from<T: FromCommandOutput>(
        &mut self,
        caller: &'static Location<'static>,
    ) -> Result<Option<ResourceGraphQueryResponse<T>>> {
        async {
            // Debug-only guard: a repeated token means the service keeps
            // handing back the same page and we would never terminate.
            #[cfg(debug_assertions)]
            if let Some((_, token)) = &self.skip
                && !self.seen_skip_tokens.insert(token.to_owned())
            {
                bail!("Saw the same skip token twice, infinite loop detected");
            }

            // Previously tried using `az graph query` but hit issues with scopes.
            // We use the REST endpoint so we can pass authorizationScopeFilter.
            let batch_size = 1000;
            let (skip, skip_token) = match &self.skip {
                Some((skip, token)) => (*skip, Some(token.to_owned())),
                None => (0u64, None),
            };
            let body = serde_json::to_string_pretty(&ResourceGraphQueryRestBody {
                query: self.query.clone(),
                options: ResourceGraphQueryRestOptions {
                    skip,
                    top: batch_size,
                    skip_token,
                    authorization_scope_filter:
                        ResourceGraphQueryRestScopeFilterOption::AtScopeAboveAndBelow,
                    result_format: QueryRestResultFormat::Table,
                },
            })?;
            let mut cmd = self.get_command(body);

            // Give every batch its own cache entry, keyed by batch index, so
            // pages of the same query do not overwrite each other.
            if let Some(CacheKey {
                ref path,
                ref valid_for,
            }) = self.cache_behaviour
            {
                cmd.cache(CacheKey {
                    path: path.join(self.index.to_string()),
                    valid_for: *valid_for,
                });
            }

            debug!(
                batch_index=self.index,
                batch_size,
                skip,
                ?self.tenant_id,
                ?self.cache_behaviour,
                "Fetching resource graph batch",
            );

            // Run command
            // TODO: handle throttling
            // https://learn.microsoft.com/en-us/azure/governance/resource-graph/overview#throttling
            // https://learn.microsoft.com/en-us/azure/governance/resource-graph/concepts/guidance-for-throttled-requests
            let results = cmd.run::<ResourceGraphQueryResponse<T>>().await?;

            // Increment index for the next potential query
            self.index += 1;

            // Remember where the next batch starts, or clear the continuation
            // when the service signals there is nothing more to fetch.
            self.skip = results
                .skip_token
                .as_ref()
                .map(|token| (skip + results.count, token.to_owned()));

            eyre::Ok(Some(results))
        }
        .await
        // Build the context message lazily: it is only needed on failure.
        .wrap_err_with(|| {
            format!(
                "ResourceGraphHelper::fetch failed, called from {}",
                RelativeLocation::from(caller)
            )
        })
    }

    /// Fetches every batch of the query and returns all rows in one `Vec`.
    ///
    /// This is a plain `fn` returning a future (not an `async fn`) so that
    /// `Location::caller()` is evaluated here, in the `#[track_caller]`
    /// frame; the captured call site is attached to any error produced.
    ///
    /// # Errors
    /// Propagates the first failure from any batch.
    #[track_caller]
    pub fn collect_all<T: FromCommandOutput>(
        &mut self,
    ) -> impl Future<Output = Result<Vec<T>>> + '_ {
        let call_site = Location::caller();
        self.collect_all_from(call_site)
    }

    /// Repeatedly fetches batches until the continuation state is cleared,
    /// accumulating every row.
    ///
    /// `caller` is forwarded to each fetch and used to annotate errors with
    /// the originating call site.
    ///
    /// # Errors
    /// Stops at, and returns, the first batch failure, wrapped with the
    /// caller's location.
    async fn collect_all_from<T: FromCommandOutput>(
        &mut self,
        caller: &'static Location<'static>,
    ) -> Result<Vec<T>> {
        let result: Result<Vec<T>> = async {
            let mut all_data = Vec::new();
            while let Some(response) = self.fetch_from(caller).await? {
                all_data.extend(response.data);

                // A cleared continuation means the batch just consumed was the
                // last one; fetching again would restart from the first batch.
                if self.skip.is_none() {
                    break;
                }
            }

            debug!(
                total_items=all_data.len(),
                ?self.tenant_id,
                ?self.cache_behaviour,
                "Completed fetching all resource graph data",
            );

            Ok(all_data)
        }
        .await;

        // Build the context message lazily: it is only needed on failure.
        result.wrap_err_with(|| {
            format!(
                "ResourceGraphHelper::collect_all failed, called from {}",
                RelativeLocation::from(caller)
            )
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use std::path::PathBuf;

    /// End-to-end check: collects the names of all resource containers in the
    /// test tenant and expects more than ten non-empty names.
    // NOTE(review): this runs the real query command (or reads its cache), so
    // it presumably needs an authenticated Azure session — confirm.
    #[tokio::test]
    async fn it_works() -> Result<()> {
        /// Shape of one projected result row.
        #[derive(Deserialize)]
        struct Row {
            name: String,
        }

        let tenant_id = crate::get_test_tenant_id().await?;
        let query = r#"
resourcecontainers
| project name
"#;
        let cache_key = CacheKey::new(PathBuf::from_iter([
            "az",
            "resource_graph",
            "resource-container-names",
        ]));

        let mut helper = ResourceGraphHelper::new(tenant_id, query, Some(cache_key));
        let rows = helper.collect_all::<Row>().await?;

        assert!(rows.len() > 10);
        assert!(rows.iter().all(|row| !row.name.is_empty()));
        Ok(())
    }
}