use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use torsh_core::{Result as TorshResult, TorshError};
/// Default Kubernetes namespace the operator manages (used by
/// `KubernetesProfilerOperator::default`).
pub const PROFILER_NAMESPACE: &str = "torsh-profiler";
/// Kubernetes-custom-resource-style description of a profiling job.
/// Mirrors the usual K8s object shape: apiVersion/kind/metadata/spec/status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingJob {
    /// API group/version string, e.g. "profiler.torsh.dev/v1".
    pub api_version: String,
    /// Resource kind, e.g. "ProfilingJob".
    pub kind: String,
    /// Name/namespace/labels/annotations; `metadata.name` is the operator's job key.
    pub metadata: ProfilingJobMetadata,
    /// Desired profiling behavior (selector, config, exports, duration).
    pub spec: ProfilingJobSpec,
    /// Observed state; `None` until a status has been recorded.
    pub status: Option<ProfilingJobStatus>,
}
/// Object metadata for a [`ProfilingJob`], modeled on Kubernetes ObjectMeta.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingJobMetadata {
    /// Job name; unique key within the operator's job map.
    pub name: String,
    /// Namespace the job's generated manifests target.
    pub namespace: String,
    /// Arbitrary key/value labels.
    pub labels: HashMap<String, String>,
    /// Arbitrary key/value annotations.
    pub annotations: HashMap<String, String>,
}
/// Desired state of a profiling job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingJobSpec {
    /// Which pods to profile.
    pub selector: PodSelector,
    /// What to profile and with which limits.
    pub profiling_config: ProfilingConfig,
    /// Where to export collected metrics/data.
    pub export_config: ExportConfig,
    /// How long to profile, in seconds.
    pub duration_seconds: u64,
    /// Sampling rate; 1.0 in tests — presumably a 0..=1 fraction, TODO confirm.
    pub sampling_rate: f64,
    /// Whether this job spans multiple workers/pods.
    pub distributed: bool,
}
/// Pod selection criteria, modeled on the Kubernetes LabelSelector shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PodSelector {
    /// Exact label matches (all must match).
    pub match_labels: HashMap<String, String>,
    /// Expression-based requirements (key/operator/values).
    pub match_expressions: Vec<LabelSelectorRequirement>,
}
/// One selector expression, modeled on K8s LabelSelectorRequirement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabelSelectorRequirement {
    /// Label key the requirement applies to.
    pub key: String,
    /// Operator name as a string (e.g. "In"/"NotIn" in K8s) — stringly typed here.
    pub operator: String,
    /// Values the operator compares against.
    pub values: Vec<String>,
}
/// Knobs controlling what the profiler collects and how invasive it may be.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingConfig {
    /// Collect CPU profiles.
    pub enable_cpu: bool,
    /// Collect GPU profiles.
    pub enable_gpu: bool,
    /// Collect memory profiles.
    pub enable_memory: bool,
    /// Collect network profiles.
    pub enable_network: bool,
    /// Collect distributed/multi-worker profiles.
    pub enable_distributed: bool,
    /// Maximum stack-trace depth recorded per sample.
    pub stack_trace_depth: usize,
    /// Overhead budget as a percentage (5.0 by default).
    pub max_overhead_percent: f64,
    /// Free-form extra parameters passed through to the profiler.
    pub custom_params: HashMap<String, String>,
}
impl Default for ProfilingConfig {
fn default() -> Self {
Self {
enable_cpu: true,
enable_gpu: true,
enable_memory: true,
enable_network: false,
enable_distributed: false,
stack_trace_depth: 10,
max_overhead_percent: 5.0,
custom_params: HashMap::new(),
}
}
}
/// Optional export destinations for profiling results; each backend is
/// enabled by supplying its config and disabled with `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportConfig {
    /// Push metrics to a Prometheus Pushgateway.
    pub prometheus: Option<PrometheusExportConfig>,
    /// Publish metrics to AWS CloudWatch.
    pub cloudwatch: Option<CloudWatchExportConfig>,
    /// Update a Grafana dashboard.
    pub grafana: Option<GrafanaExportConfig>,
    /// Upload artifacts to object storage (S3-like).
    pub object_storage: Option<ObjectStorageConfig>,
}
/// Prometheus Pushgateway export settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrometheusExportConfig {
    /// Pushgateway base URL.
    pub pushgateway_url: String,
    /// Prometheus job name to push under.
    pub job_name: String,
    /// Extra labels attached to pushed metrics.
    pub labels: HashMap<String, String>,
    /// Push cadence in seconds.
    pub push_interval_seconds: u64,
}
/// AWS CloudWatch export settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CloudWatchExportConfig {
    /// AWS region to publish to.
    pub region: String,
    /// CloudWatch metric namespace.
    pub namespace: String,
    /// Dimensions attached to every published metric.
    pub dimensions: HashMap<String, String>,
    /// Publish cadence in seconds.
    pub publish_interval_seconds: u64,
}
/// Grafana dashboard export settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrafanaExportConfig {
    /// Grafana API endpoint.
    pub api_url: String,
    /// API token — NOTE(review): stored in plain config and derives
    /// Serialize/Debug, so it can leak into logs/exports; consider a secret ref.
    pub api_token: String,
    /// UID of the dashboard to manage.
    pub dashboard_uid: String,
    /// Whether the dashboard is updated automatically.
    pub auto_update: bool,
}
/// Object-storage upload settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObjectStorageConfig {
    /// Provider identifier (stringly typed, e.g. "s3" — TODO confirm expected values).
    pub provider: String,
    /// Destination bucket.
    pub bucket: String,
    /// Key prefix for uploaded objects.
    pub prefix: String,
    /// Name of the K8s secret holding credentials (indirection, not the secret itself).
    pub credentials_secret: String,
    /// Upload cadence in seconds.
    pub upload_interval_seconds: u64,
}
/// Observed state of a profiling job, modeled on K8s status subresources.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingJobStatus {
    /// Lifecycle phase as a free-form string.
    pub phase: String,
    /// When profiling started; RFC3339-style string expected — TODO confirm format.
    pub start_time: Option<String>,
    /// When profiling finished, if it has.
    pub completion_time: Option<String>,
    /// Number of pods currently/previously profiled.
    pub profiled_pods: usize,
    /// Total profiling events collected.
    pub total_events: u64,
    /// Measured overhead as a percentage.
    pub total_overhead_percent: f64,
    /// Optional human-readable status message.
    pub message: Option<String>,
    /// Condition history, K8s-style.
    pub conditions: Vec<ProfilingCondition>,
}
/// One status condition entry, modeled on the K8s Condition shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingCondition {
    /// Condition type (e.g. "Ready" in K8s conventions) — stringly typed.
    pub condition_type: String,
    /// Condition status string ("True"/"False"/"Unknown" in K8s) — TODO confirm.
    pub status: String,
    /// Timestamp of the last status transition.
    pub last_transition_time: String,
    /// Machine-readable reason.
    pub reason: String,
    /// Human-readable explanation.
    pub message: String,
}
/// In-memory operator state: tracks profiling jobs by name and per-pod
/// profiler instances keyed by "namespace/pod-name".
pub struct KubernetesProfilerOperator {
    // Namespace this operator is scoped to.
    namespace: String,
    // Jobs keyed by `metadata.name`.
    jobs: HashMap<String, ProfilingJob>,
    // Pod profilers keyed by "namespace/name" (see `register_pod`).
    pod_profilers: HashMap<String, PodProfilerInstance>,
}
/// Runtime record for one profiler attached to one pod.
#[derive(Debug, Clone)]
pub struct PodProfilerInstance {
    /// Pod name.
    pub pod_name: String,
    /// Pod namespace; combined with `pod_name` as the operator's map key.
    pub pod_namespace: String,
    /// Node the pod is scheduled on.
    pub node_name: String,
    /// Whether the profiler is currently active (filters `active_pods`).
    pub active: bool,
    /// Events collected so far; summed into the operator's exported state.
    pub events_collected: u64,
    /// UTC time profiling started on this pod.
    pub start_time: chrono::DateTime<chrono::Utc>,
    /// UTC time of the last export, if any has happened.
    pub last_export: Option<chrono::DateTime<chrono::Utc>>,
}
impl KubernetesProfilerOperator {
    /// Create an operator managing profiling resources in `namespace`.
    pub fn new(namespace: String) -> Self {
        Self {
            namespace,
            jobs: HashMap::new(),
            pod_profilers: HashMap::new(),
        }
    }

    /// Look up a registered job, mapping a miss to the uniform
    /// "ProfilingJob {name} not found" operation error used by all getters.
    fn find_job(&self, job_name: &str) -> TorshResult<&ProfilingJob> {
        self.jobs.get(job_name).ok_or_else(|| {
            TorshError::operation_error(&format!("ProfilingJob {} not found", job_name))
        })
    }

    /// Register a new profiling job, keyed by `metadata.name`.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if a job with the same name
    /// already exists.
    pub fn create_job(&mut self, job: ProfilingJob) -> TorshResult<()> {
        let job_name = job.metadata.name.clone();
        if self.jobs.contains_key(&job_name) {
            return Err(TorshError::InvalidArgument(format!(
                "ProfilingJob {} already exists",
                job_name
            )));
        }
        self.jobs.insert(job_name, job);
        Ok(())
    }

    /// Remove a job by name.
    ///
    /// # Errors
    /// Returns an operation error if no job with that name exists.
    pub fn delete_job(&mut self, job_name: &str) -> TorshResult<()> {
        self.jobs.remove(job_name).ok_or_else(|| {
            TorshError::operation_error(&format!("ProfilingJob {} not found", job_name))
        })?;
        Ok(())
    }

    /// Return a clone of the job's recorded status.
    ///
    /// # Errors
    /// Fails if the job does not exist or its status has not been set yet.
    pub fn get_job_status(&self, job_name: &str) -> TorshResult<ProfilingJobStatus> {
        self.find_job(job_name)?
            .status
            .clone()
            .ok_or_else(|| TorshError::operation_error("Job status not available"))
    }

    /// All registered jobs, in arbitrary (hash-map) order.
    pub fn list_jobs(&self) -> Vec<&ProfilingJob> {
        self.jobs.values().collect()
    }

    /// Track a per-pod profiler, keyed by "namespace/name". Re-registering
    /// the same pod replaces the previous instance.
    pub fn register_pod(&mut self, instance: PodProfilerInstance) {
        let key = format!("{}/{}", instance.pod_namespace, instance.pod_name);
        self.pod_profilers.insert(key, instance);
    }

    /// Stop tracking a pod; a no-op if the pod was never registered.
    pub fn unregister_pod(&mut self, pod_namespace: &str, pod_name: &str) {
        let key = format!("{}/{}", pod_namespace, pod_name);
        self.pod_profilers.remove(&key);
    }

    /// Pod profiler instances currently flagged active.
    pub fn active_pods(&self) -> Vec<&PodProfilerInstance> {
        self.pod_profilers.values().filter(|p| p.active).collect()
    }

    /// Render a ConfigMap manifest carrying the job's profiling config as
    /// pretty-printed JSON under the `profiling.json` key.
    ///
    /// # Errors
    /// Fails if the job is unknown or the config cannot be serialized.
    pub fn generate_configmap(&self, job_name: &str) -> TorshResult<String> {
        let job = self.find_job(job_name)?;
        let config_json = serde_json::to_string_pretty(&job.spec.profiling_config)
            .map_err(|e| {
                TorshError::operation_error(&format!("JSON serialization failed: {}", e))
            })?;
        // BUG FIX: the pretty JSON is multi-line. Every line must be indented
        // deeper than the `profiling.json: |` key, otherwise only the first
        // line lands inside the block scalar and the YAML is invalid.
        let indented_json = config_json
            .lines()
            .map(|line| format!("    {}", line))
            .collect::<Vec<_>>()
            .join("\n");
        let configmap = format!(
            r#"apiVersion: v1
kind: ConfigMap
metadata:
  name: {}-config
  namespace: {}
data:
  profiling.json: |
{}
"#,
            job_name, job.metadata.namespace, indented_json
        );
        Ok(configmap)
    }

    /// Render a ClusterIP Service exposing the profiler's metrics (9090)
    /// and HTTP (8080) ports for this job.
    ///
    /// # Errors
    /// Fails if the job is unknown.
    pub fn generate_metrics_service(&self, job_name: &str) -> TorshResult<String> {
        let job = self.find_job(job_name)?;
        let service = format!(
            r#"apiVersion: v1
kind: Service
metadata:
  name: {job}-metrics
  namespace: {ns}
  labels:
    app: torsh-profiler
    job: {job}
spec:
  type: ClusterIP
  selector:
    app: torsh-profiler
    job: {job}
  ports:
    - name: metrics
      port: 9090
      targetPort: 9090
      protocol: TCP
    - name: http
      port: 8080
      targetPort: 8080
      protocol: TCP
"#,
            job = job_name,
            ns = job.metadata.namespace
        );
        Ok(service)
    }

    /// Render a Prometheus-Operator ServiceMonitor scraping the job's
    /// metrics port every 30s.
    ///
    /// # Errors
    /// Fails if the job is unknown.
    pub fn generate_service_monitor(&self, job_name: &str) -> TorshResult<String> {
        let job = self.find_job(job_name)?;
        let service_monitor = format!(
            r#"apiVersion: monitoring.coreos.com/v1
kind: ServiceMonitor
metadata:
  name: {job}-monitor
  namespace: {ns}
  labels:
    app: torsh-profiler
    job: {job}
spec:
  selector:
    matchLabels:
      app: torsh-profiler
      job: {job}
  endpoints:
    - port: metrics
      interval: 30s
      path: /metrics
"#,
            job = job_name,
            ns = job.metadata.namespace
        );
        Ok(service_monitor)
    }

    /// Export a JSON summary of the operator: namespace, job count, active
    /// pod count, and total events collected across all pods.
    ///
    /// # Errors
    /// Fails only if JSON serialization fails.
    pub fn export_state(&self) -> TorshResult<String> {
        #[derive(Serialize)]
        struct OperatorState<'a> {
            namespace: &'a str,
            active_jobs: usize,
            active_pods: usize,
            total_events: u64,
        }
        let total_events: u64 = self
            .pod_profilers
            .values()
            .map(|p| p.events_collected)
            .sum();
        let state = OperatorState {
            namespace: &self.namespace,
            active_jobs: self.jobs.len(),
            active_pods: self.active_pods().len(),
            total_events,
        };
        serde_json::to_string_pretty(&state)
            .map_err(|e| TorshError::operation_error(&format!("JSON export failed: {}", e)))
    }
}
impl Default for KubernetesProfilerOperator {
    /// Operator scoped to the crate-wide [`PROFILER_NAMESPACE`].
    fn default() -> Self {
        let namespace = String::from(PROFILER_NAMESPACE);
        Self::new(namespace)
    }
}
/// Stateless generator for the torsh-profiler Helm chart files
/// (values.yaml, Chart.yaml, deployment template).
pub struct HelmChartGenerator;
impl HelmChartGenerator {
    /// Render a `values.yaml` for the chart, seeded from `job_name` and the
    /// given profiling config. `{{}}` in the template escapes to literal `{}`
    /// (empty YAML mappings) in the output.
    pub fn generate_values_yaml(job_name: &str, config: &ProfilingConfig) -> String {
        format!(
            r#"# ToRSh Profiler Helm Chart Values
replicaCount: 1

image:
  repository: torsh/profiler
  pullPolicy: IfNotPresent
  tag: "latest"

nameOverride: "{name}"
fullnameOverride: "{name}-profiler"

serviceAccount:
  create: true
  name: torsh-profiler

profiling:
  enabled: true
  cpuProfiling: {cpu}
  gpuProfiling: {gpu}
  memoryProfiling: {mem}
  networkProfiling: {net}
  distributedProfiling: {dist}
  stackTraceDepth: {depth}
  maxOverheadPercent: {overhead}

metrics:
  enabled: true
  port: 9090
  serviceMonitor:
    enabled: true
    interval: 30s

resources:
  limits:
    cpu: 500m
    memory: 512Mi
  requests:
    cpu: 100m
    memory: 128Mi

nodeSelector: {{}}
tolerations: []
affinity: {{}}
"#,
            name = job_name,
            cpu = config.enable_cpu,
            gpu = config.enable_gpu,
            mem = config.enable_memory,
            net = config.enable_network,
            dist = config.enable_distributed,
            depth = config.stack_trace_depth,
            overhead = config.max_overhead_percent
        )
    }

    /// Static `Chart.yaml` (Helm v2 apiVersion) for the torsh-profiler chart.
    pub fn generate_chart_yaml() -> String {
        r#"apiVersion: v2
name: torsh-profiler
description: A Helm chart for ToRSh Profiler Operator
type: application
version: 0.1.0
appVersion: "0.1.0-alpha.2"
keywords:
  - machine-learning
  - profiling
  - pytorch
  - rust
maintainers:
  - name: ToRSh Team
    email: torsh@example.com
"#
        .to_string()
    }

    /// Static `templates/deployment.yaml` Go-template for the chart. The
    /// `{{ ... }}` directives are Helm template syntax, passed through
    /// verbatim (plain raw string, not `format!`).
    pub fn generate_deployment_template() -> String {
        r#"apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "torsh-profiler.fullname" . }}
  labels:
    {{- include "torsh-profiler.labels" . | nindent 4 }}
spec:
  replicas: {{ .Values.replicaCount }}
  selector:
    matchLabels:
      {{- include "torsh-profiler.selectorLabels" . | nindent 6 }}
  template:
    metadata:
      labels:
        {{- include "torsh-profiler.selectorLabels" . | nindent 8 }}
    spec:
      serviceAccountName: {{ include "torsh-profiler.serviceAccountName" . }}
      containers:
        - name: profiler
          image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
          imagePullPolicy: {{ .Values.image.pullPolicy }}
          ports:
            - name: metrics
              containerPort: 9090
              protocol: TCP
            - name: http
              containerPort: 8080
              protocol: TCP
          env:
            - name: PROFILER_NAMESPACE
              valueFrom:
                fieldRef:
                  fieldPath: metadata.namespace
            - name: POD_NAME
              valueFrom:
                fieldRef:
                  fieldPath: metadata.name
            - name: NODE_NAME
              valueFrom:
                fieldRef:
                  fieldPath: spec.nodeName
          resources:
            {{- toYaml .Values.resources | nindent 12 }}
      {{- with .Values.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.tolerations }}
      tolerations:
        {{- toYaml . | nindent 8 }}
      {{- end }}
"#
        .to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the minimal ProfilingJob used across tests; previously this
    /// ~28-line literal was duplicated in three separate test functions.
    fn sample_job(name: &str) -> ProfilingJob {
        ProfilingJob {
            api_version: "profiler.torsh.dev/v1".to_string(),
            kind: "ProfilingJob".to_string(),
            metadata: ProfilingJobMetadata {
                name: name.to_string(),
                namespace: "default".to_string(),
                labels: HashMap::new(),
                annotations: HashMap::new(),
            },
            spec: ProfilingJobSpec {
                selector: PodSelector {
                    match_labels: HashMap::new(),
                    match_expressions: vec![],
                },
                profiling_config: ProfilingConfig::default(),
                export_config: ExportConfig {
                    prometheus: None,
                    cloudwatch: None,
                    grafana: None,
                    object_storage: None,
                },
                duration_seconds: 3600,
                sampling_rate: 1.0,
                distributed: false,
            },
            status: None,
        }
    }

    #[test]
    fn test_operator_creation() {
        let operator = KubernetesProfilerOperator::new("default".to_string());
        assert_eq!(operator.namespace, "default");
        assert_eq!(operator.jobs.len(), 0);
        assert_eq!(operator.pod_profilers.len(), 0);
    }

    #[test]
    fn test_job_creation() {
        let mut operator = KubernetesProfilerOperator::default();
        let mut job = sample_job("test-job");
        // Original test selected pods labeled app=training; keep that.
        job.spec
            .selector
            .match_labels
            .insert("app".to_string(), "training".to_string());
        operator.create_job(job).unwrap();
        assert_eq!(operator.jobs.len(), 1);
        assert!(operator.jobs.contains_key("test-job"));
        // Duplicate names must be rejected.
        assert!(operator.create_job(sample_job("test-job")).is_err());
    }

    #[test]
    fn test_pod_registration() {
        let mut operator = KubernetesProfilerOperator::default();
        let instance = PodProfilerInstance {
            pod_name: "training-pod-1".to_string(),
            pod_namespace: "default".to_string(),
            node_name: "node-1".to_string(),
            active: true,
            events_collected: 0,
            start_time: chrono::Utc::now(),
            last_export: None,
        };
        operator.register_pod(instance);
        assert_eq!(operator.pod_profilers.len(), 1);
        assert_eq!(operator.active_pods().len(), 1);
    }

    #[test]
    fn test_configmap_generation() {
        let mut operator = KubernetesProfilerOperator::default();
        operator.create_job(sample_job("test-job")).unwrap();
        let configmap = operator.generate_configmap("test-job").unwrap();
        assert!(configmap.contains("kind: ConfigMap"));
        assert!(configmap.contains("test-job-config"));
    }

    #[test]
    fn test_service_generation() {
        let mut operator = KubernetesProfilerOperator::default();
        operator.create_job(sample_job("test-job")).unwrap();
        let service = operator.generate_metrics_service("test-job").unwrap();
        assert!(service.contains("kind: Service"));
        assert!(service.contains("test-job-metrics"));
    }

    #[test]
    fn test_helm_chart_generation() {
        let values =
            HelmChartGenerator::generate_values_yaml("my-job", &ProfilingConfig::default());
        assert!(values.contains("my-job"));
        assert!(values.contains("cpuProfiling: true"));
        let chart = HelmChartGenerator::generate_chart_yaml();
        assert!(chart.contains("torsh-profiler"));
        let deployment = HelmChartGenerator::generate_deployment_template();
        assert!(deployment.contains("Deployment"));
    }

    #[test]
    fn test_state_export() {
        let operator = KubernetesProfilerOperator::default();
        let state = operator.export_state().unwrap();
        assert!(state.contains("namespace"));
        assert!(state.contains("active_jobs"));
    }
}