use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use torsh_core::{Result as TorshResult, TorshError};
/// Cloud execution environments this crate can recognize.
///
/// `CloudProvider::detect` (inherent impl below) identifies AWS, Azure, and
/// GCP from provider-specific environment variables; the remaining variants
/// exist for metadata/labeling purposes and are never returned by `detect`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CloudProvider {
/// Amazon Web Services.
AWS,
/// Microsoft Azure.
Azure,
/// Google Cloud Platform.
GCP,
/// Alibaba Cloud (not auto-detected).
Alibaba,
/// Oracle Cloud (not auto-detected).
Oracle,
/// IBM Cloud (not auto-detected).
IBM,
/// Self-hosted / data-center deployment (not auto-detected).
OnPremise,
/// No recognized cloud environment markers found.
Unknown,
}
impl CloudProvider {
pub fn detect() -> Self {
if std::env::var("AWS_REGION").is_ok() || std::env::var("AWS_DEFAULT_REGION").is_ok() {
return Self::AWS;
}
if std::env::var("AZURE_SUBSCRIPTION_ID").is_ok()
|| std::env::var("AZURE_TENANT_ID").is_ok()
{
return Self::Azure;
}
if std::env::var("GOOGLE_CLOUD_PROJECT").is_ok() || std::env::var("GCP_PROJECT").is_ok() {
return Self::GCP;
}
Self::Unknown
}
}
impl std::fmt::Display for CloudProvider {
    /// Write the provider's full human-readable name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::AWS => "Amazon Web Services",
            Self::Azure => "Microsoft Azure",
            Self::GCP => "Google Cloud Platform",
            Self::Alibaba => "Alibaba Cloud",
            Self::Oracle => "Oracle Cloud",
            Self::IBM => "IBM Cloud",
            Self::OnPremise => "On-Premise",
            Self::Unknown => "Unknown",
        };
        f.write_str(name)
    }
}
/// Metadata describing the instance the current process runs on.
///
/// Populated from environment variables by the `detect_*` constructors below;
/// fields that cannot be discovered that way (GPUs, memory, spot status,
/// availability zone) are left at zero/`None` placeholders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CloudInstanceMetadata {
/// Cloud provider this instance belongs to.
pub provider: CloudProvider,
/// Provider-assigned instance identifier (placeholder string when unknown).
pub instance_id: String,
/// Instance/VM size name, e.g. an EC2 instance type or Azure VM size.
pub instance_type: String,
/// Provider region the instance runs in (provider default when unset).
pub region: String,
/// Availability zone, when known. Never populated by the env-var detectors.
pub availability_zone: Option<String>,
/// Number of attached GPUs. Always 0 from the env-var detectors.
pub gpu_count: usize,
/// GPU model name, when known.
pub gpu_type: Option<String>,
/// Logical CPU count as reported by `num_cpus::get()`.
pub cpu_count: usize,
/// Total memory in GiB. Always 0 from the env-var detectors.
pub memory_gb: usize,
/// Whether this is a spot/preemptible instance. Always false from detectors.
pub is_spot: bool,
/// Free-form extra key/value metadata.
pub metadata: HashMap<String, String>,
}
impl CloudInstanceMetadata {
pub fn detect() -> TorshResult<Self> {
let provider = CloudProvider::detect();
match provider {
CloudProvider::AWS => Self::detect_aws(),
CloudProvider::Azure => Self::detect_azure(),
CloudProvider::GCP => Self::detect_gcp(),
_ => Ok(Self::default_metadata()),
}
}
fn detect_aws() -> TorshResult<Self> {
let instance_id =
std::env::var("AWS_INSTANCE_ID").unwrap_or_else(|_| "i-unknown".to_string());
let instance_type =
std::env::var("AWS_INSTANCE_TYPE").unwrap_or_else(|_| "unknown".to_string());
let region = std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string());
Ok(Self {
provider: CloudProvider::AWS,
instance_id,
instance_type,
region,
availability_zone: None,
gpu_count: 0,
gpu_type: None,
cpu_count: num_cpus::get(),
memory_gb: 0,
is_spot: false,
metadata: HashMap::new(),
})
}
fn detect_azure() -> TorshResult<Self> {
let instance_id =
std::env::var("AZURE_INSTANCE_ID").unwrap_or_else(|_| "vm-unknown".to_string());
let instance_type =
std::env::var("AZURE_VM_SIZE").unwrap_or_else(|_| "unknown".to_string());
let region = std::env::var("AZURE_REGION").unwrap_or_else(|_| "eastus".to_string());
Ok(Self {
provider: CloudProvider::Azure,
instance_id,
instance_type,
region,
availability_zone: None,
gpu_count: 0,
gpu_type: None,
cpu_count: num_cpus::get(),
memory_gb: 0,
is_spot: false,
metadata: HashMap::new(),
})
}
fn detect_gcp() -> TorshResult<Self> {
let instance_id =
std::env::var("GCP_INSTANCE_ID").unwrap_or_else(|_| "instance-unknown".to_string());
let instance_type =
std::env::var("GCP_MACHINE_TYPE").unwrap_or_else(|_| "unknown".to_string());
let region = std::env::var("GCP_REGION").unwrap_or_else(|_| "us-central1".to_string());
Ok(Self {
provider: CloudProvider::GCP,
instance_id,
instance_type,
region,
availability_zone: None,
gpu_count: 0,
gpu_type: None,
cpu_count: num_cpus::get(),
memory_gb: 0,
is_spot: false,
metadata: HashMap::new(),
})
}
fn default_metadata() -> Self {
Self {
provider: CloudProvider::Unknown,
instance_id: "unknown".to_string(),
instance_type: "unknown".to_string(),
region: "unknown".to_string(),
availability_zone: None,
gpu_count: 0,
gpu_type: None,
cpu_count: num_cpus::get(),
memory_gb: 0,
is_spot: false,
metadata: HashMap::new(),
}
}
}
pub mod aws {
use super::*;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SageMakerProfilingConfig {
pub job_name: String,
pub s3_bucket: String,
pub s3_prefix: String,
pub profiling_interval_seconds: u64,
pub detailed_profiling: bool,
pub framework: String,
}
impl Default for SageMakerProfilingConfig {
fn default() -> Self {
Self {
job_name: "training-job".to_string(),
s3_bucket: "sagemaker-profiling".to_string(),
s3_prefix: "profiling-data".to_string(),
profiling_interval_seconds: 60,
detailed_profiling: true,
framework: "pytorch".to_string(),
}
}
}
/// Configuration for profiling a workload running as an ECS task.
///
/// NOTE(review): not constructed or consumed anywhere in this file —
/// presumably used by callers elsewhere; confirm before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ECSTaskProfilingConfig {
/// ECS task definition identifier.
pub task_definition: String,
/// ECS cluster name.
pub cluster: String,
/// Optional ECS service name, when the task runs under a service.
pub service: Option<String>,
/// CloudWatch log group to associate with profiling output.
pub log_group: String,
/// Whether Container Insights collection is enabled.
pub container_insights: bool,
}
pub struct AWSProfiler {
metadata: CloudInstanceMetadata,
sagemaker_config: Option<SageMakerProfilingConfig>,
}
impl AWSProfiler {
pub fn new() -> TorshResult<Self> {
let metadata = CloudInstanceMetadata::detect_aws()?;
Ok(Self {
metadata,
sagemaker_config: None,
})
}
pub fn configure_sagemaker(&mut self, config: SageMakerProfilingConfig) {
self.sagemaker_config = Some(config);
}
pub fn export_to_s3(&self, bucket: &str, prefix: &str) -> TorshResult<String> {
Ok(format!("s3://{}/{}/profiling-data.json", bucket, prefix))
}
pub fn instance_metadata(&self) -> &CloudInstanceMetadata {
&self.metadata
}
}
impl Default for AWSProfiler {
fn default() -> Self {
Self {
metadata: CloudInstanceMetadata::default_metadata(),
sagemaker_config: None,
}
}
}
}
pub mod azure {
use super::*;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AzureMLProfilingConfig {
pub workspace_name: String,
pub resource_group: String,
pub experiment_name: String,
pub storage_account: String,
pub container_name: String,
pub application_insights: bool,
}
impl Default for AzureMLProfilingConfig {
fn default() -> Self {
Self {
workspace_name: "ml-workspace".to_string(),
resource_group: "ml-resources".to_string(),
experiment_name: "training-experiment".to_string(),
storage_account: "mlstorage".to_string(),
container_name: "profiling-data".to_string(),
application_insights: true,
}
}
}
pub struct AzureProfiler {
metadata: CloudInstanceMetadata,
azureml_config: Option<AzureMLProfilingConfig>,
}
impl AzureProfiler {
pub fn new() -> TorshResult<Self> {
let metadata = CloudInstanceMetadata::detect_azure()?;
Ok(Self {
metadata,
azureml_config: None,
})
}
pub fn configure_azureml(&mut self, config: AzureMLProfilingConfig) {
self.azureml_config = Some(config);
}
pub fn export_to_blob_storage(
&self,
container: &str,
blob_name: &str,
) -> TorshResult<String> {
Ok(format!(
"https://{}.blob.core.windows.net/{}/{}",
self.azureml_config
.as_ref()
.map(|c| c.storage_account.as_str())
.unwrap_or("storage"),
container,
blob_name
))
}
pub fn instance_metadata(&self) -> &CloudInstanceMetadata {
&self.metadata
}
}
impl Default for AzureProfiler {
fn default() -> Self {
Self {
metadata: CloudInstanceMetadata::default_metadata(),
azureml_config: None,
}
}
}
}
pub mod gcp {
use super::*;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VertexAIProfilingConfig {
pub project_id: String,
pub location: String,
pub pipeline_id: String,
pub gcs_bucket: String,
pub gcs_prefix: String,
pub cloud_profiler: bool,
pub tensorboard_profiling: bool,
}
impl Default for VertexAIProfilingConfig {
fn default() -> Self {
Self {
project_id: "my-project".to_string(),
location: "us-central1".to_string(),
pipeline_id: "training-pipeline".to_string(),
gcs_bucket: "vertex-profiling".to_string(),
gcs_prefix: "profiling-data".to_string(),
cloud_profiler: true,
tensorboard_profiling: true,
}
}
}
pub struct GCPProfiler {
metadata: CloudInstanceMetadata,
vertex_config: Option<VertexAIProfilingConfig>,
}
impl GCPProfiler {
pub fn new() -> TorshResult<Self> {
let metadata = CloudInstanceMetadata::detect_gcp()?;
Ok(Self {
metadata,
vertex_config: None,
})
}
pub fn configure_vertex_ai(&mut self, config: VertexAIProfilingConfig) {
self.vertex_config = Some(config);
}
pub fn export_to_gcs(&self, bucket: &str, object_name: &str) -> TorshResult<String> {
Ok(format!("gs://{}/{}", bucket, object_name))
}
pub fn instance_metadata(&self) -> &CloudInstanceMetadata {
&self.metadata
}
}
impl Default for GCPProfiler {
fn default() -> Self {
Self {
metadata: CloudInstanceMetadata::default_metadata(),
vertex_config: None,
}
}
}
}
/// Provider-agnostic profiler facade that dispatches on the detected cloud.
pub struct MultiCloudProfiler {
/// Detected provider (also mirrored in `metadata.provider`).
provider: CloudProvider,
/// Instance metadata detected at construction time.
metadata: CloudInstanceMetadata,
}
impl MultiCloudProfiler {
    /// Detect the provider and instance metadata for the current host.
    pub fn new() -> TorshResult<Self> {
        let metadata = CloudInstanceMetadata::detect()?;
        Ok(Self {
            provider: CloudProvider::detect(),
            metadata,
        })
    }
    /// Detected cloud provider.
    pub fn provider(&self) -> CloudProvider {
        self.provider
    }
    /// Detected instance metadata.
    pub fn metadata(&self) -> &CloudInstanceMetadata {
        &self.metadata
    }
    /// Whether the detected provider matches `provider`.
    pub fn is_cloud(&self, provider: CloudProvider) -> bool {
        self.provider == provider
    }
    /// Suggested destination for exported profiling data:
    /// a provider-native object store for AWS/Azure/GCP, otherwise a path
    /// under the system temp directory.
    pub fn recommended_export_destination(&self) -> String {
        match self.provider {
            CloudProvider::AWS => String::from("s3://profiling-bucket/data"),
            CloudProvider::Azure => {
                String::from("https://storage.blob.core.windows.net/profiling/data")
            }
            CloudProvider::GCP => String::from("gs://profiling-bucket/data"),
            _ => {
                let local = std::env::temp_dir().join("profiling-data");
                local.display().to_string()
            }
        }
    }
    /// Rough per-hour cost estimate (USD) keyed off substrings of the
    /// instance type; 0.0 for unrecognized providers.
    pub fn estimated_cost_per_hour(&self) -> f64 {
        let ty = &self.metadata.instance_type;
        match self.provider {
            // "p3" is matched before "p4", preserving the original priority.
            CloudProvider::AWS if ty.contains("p3") => 3.06,
            CloudProvider::AWS if ty.contains("p4") => 7.10,
            CloudProvider::AWS => 0.10,
            CloudProvider::Azure if ty.contains("NC") => 2.50,
            CloudProvider::Azure => 0.10,
            CloudProvider::GCP if ty.contains("a2") => 3.00,
            CloudProvider::GCP => 0.10,
            _ => 0.0,
        }
    }
    /// Key/value tags to attach to exported profiling artifacts.
    pub fn recommended_tags(&self) -> HashMap<String, String> {
        vec![
            ("profiler".to_string(), "torsh".to_string()),
            ("framework".to_string(), "pytorch-compatible".to_string()),
            ("cloud".to_string(), format!("{:?}", self.provider)),
            ("instance".to_string(), self.metadata.instance_id.clone()),
        ]
        .into_iter()
        .collect()
    }
}
impl Default for MultiCloudProfiler {
fn default() -> Self {
Self {
provider: CloudProvider::Unknown,
metadata: CloudInstanceMetadata::default_metadata(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: detection must not panic regardless of environment.
    #[test]
    fn test_cloud_provider_detection() {
        println!("Detected cloud provider: {}", CloudProvider::detect());
    }

    /// Smoke test: metadata detection must not panic; prints when it works.
    #[test]
    fn test_cloud_metadata_detection() {
        if let Ok(meta) = CloudInstanceMetadata::detect() {
            println!("Cloud metadata: {:?}", meta);
        }
    }

    /// Smoke test: exercise every accessor of the multi-cloud profiler.
    #[test]
    fn test_multi_cloud_profiler() {
        if let Ok(profiler) = MultiCloudProfiler::new() {
            println!("Provider: {}", profiler.provider());
            println!(
                "Recommended export: {}",
                profiler.recommended_export_destination()
            );
            println!(
                "Estimated cost: ${:.2}/hour",
                profiler.estimated_cost_per_hour()
            );
            println!("Recommended tags: {:?}", profiler.recommended_tags());
        }
    }

    /// The default AWS profiler carries placeholder (`Unknown`) metadata.
    #[test]
    fn test_aws_profiler() {
        let profiler = aws::AWSProfiler::default();
        println!("AWS profiler created");
        assert_eq!(
            profiler.instance_metadata().provider,
            CloudProvider::Unknown
        );
    }

    /// The default Azure profiler carries placeholder (`Unknown`) metadata.
    #[test]
    fn test_azure_profiler() {
        let profiler = azure::AzureProfiler::default();
        println!("Azure profiler created");
        assert_eq!(
            profiler.instance_metadata().provider,
            CloudProvider::Unknown
        );
    }

    /// The default GCP profiler carries placeholder (`Unknown`) metadata.
    #[test]
    fn test_gcp_profiler() {
        let profiler = gcp::GCPProfiler::default();
        println!("GCP profiler created");
        assert_eq!(
            profiler.instance_metadata().provider,
            CloudProvider::Unknown
        );
    }

    /// Default SageMaker config targets pytorch with detailed profiling on.
    #[test]
    fn test_sagemaker_config() {
        let config = aws::SageMakerProfilingConfig::default();
        assert_eq!(config.framework, "pytorch");
        assert!(config.detailed_profiling);
    }

    /// Default Azure ML config names the workspace and enables App Insights.
    #[test]
    fn test_azureml_config() {
        let config = azure::AzureMLProfilingConfig::default();
        assert_eq!(config.workspace_name, "ml-workspace");
        assert!(config.application_insights);
    }

    /// Default Vertex AI config sits in us-central1 with both backends on.
    #[test]
    fn test_vertex_ai_config() {
        let config = gcp::VertexAIProfilingConfig::default();
        assert_eq!(config.location, "us-central1");
        assert!(config.cloud_profiler);
        assert!(config.tensorboard_profiling);
    }
}