use crate::{TorshDistributedError, TorshResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// Top-level configuration for the Ray integration.
///
/// Every section is optional; sections left as `None` are simply skipped
/// during setup (see the `setup_*` methods on `RayIntegration`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayConfig {
    /// Cluster connection / bootstrap settings.
    pub cluster: Option<RayClusterConfig>,
    /// Ray Train (distributed training) settings.
    pub train: Option<RayTrainConfig>,
    /// Ray Tune (hyperparameter search) settings.
    pub tune: Option<RayTuneConfig>,
    /// Ray Serve (model serving) settings.
    pub serve: Option<RayServeConfig>,
    /// Ray Data (dataset ingestion) settings.
    pub data: Option<RayDataConfig>,
    /// Resource requests for this process.
    pub resources: Option<RayResourceConfig>,
    /// Worker restart / health-check settings.
    pub fault_tolerance: Option<RayFaultToleranceConfig>,
}
/// Ray cluster connection and sizing options.
///
/// All fields are optional; defaults are applied in
/// `RayIntegration::setup_ray_cluster`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayClusterConfig {
    /// Address of an existing cluster to connect to; `None` presumably means
    /// a local session — TODO confirm against the Ray bootstrap code.
    pub address: Option<String>,
    pub redis_address: Option<String>,
    /// CPU count for the cluster; validated to be > 0 when set.
    pub num_cpus: Option<u32>,
    pub num_gpus: Option<u32>,
    /// Total memory in GB; validated to be > 0 when set.
    pub memory_gb: Option<f32>,
    /// Object-store memory in GB.
    pub object_store_memory_gb: Option<f32>,
    pub namespace: Option<String>,
    /// Dashboard bind host (defaults to "127.0.0.1" at setup time).
    pub dashboard_host: Option<String>,
    /// Dashboard port (defaults to 8265 at setup time).
    pub dashboard_port: Option<u16>,
    /// Whether to start the dashboard (defaults to true at setup time).
    pub include_dashboard: Option<bool>,
}
/// Ray Train (distributed training) configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayTrainConfig {
    /// Training framework backend.
    pub backend: RayTrainBackend,
    /// Number of training workers; must be > 0 (enforced by validation).
    pub num_workers: u32,
    /// Whether workers use GPUs (defaults to false at setup time).
    pub use_gpu: Option<bool>,
    /// Resource name -> amount requested per worker.
    pub resources_per_worker: Option<HashMap<String, f32>>,
    /// Placement strategy (defaults to `Pack` at setup time).
    pub placement_group_strategy: Option<RayPlacementGroupStrategy>,
    pub scaling_config: Option<RayScalingConfig>,
    pub run_config: Option<RayRunConfig>,
    pub checkpoint_config: Option<RayCheckpointConfig>,
    pub failure_config: Option<RayFailureConfig>,
}
/// Training framework backend used by Ray Train.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayTrainBackend {
    Torch,
    TensorFlow,
    Horovod,
    MPI,
    Custom,
}
/// Placement-group strategy for distributing workers across nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayPlacementGroupStrategy {
    StrictPack,
    Pack,
    StrictSpread,
    Spread,
}
/// Worker scaling overrides for Ray Train.
///
/// When present, `num_workers` here takes precedence over
/// `RayTrainConfig::num_workers` in `to_elastic_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayScalingConfig {
    /// Override for worker count; must be > 0 when set (validated).
    pub num_workers: Option<u32>,
    pub use_gpu: Option<bool>,
    pub resources_per_worker: Option<HashMap<String, f32>>,
    pub placement_group_strategy: Option<RayPlacementGroupStrategy>,
    /// Resources reserved for the trainer process itself.
    pub trainer_resources: Option<HashMap<String, f32>>,
}
/// Experiment run settings (naming, storage, stopping, checkpointing).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayRunConfig {
    /// Experiment name.
    pub name: Option<String>,
    /// Where run artifacts are stored.
    pub storage_path: Option<String>,
    /// Metric-name -> threshold stop criteria — presumably matched against
    /// reported metrics; verify against the runner that consumes this.
    pub stop: Option<HashMap<String, f32>>,
    pub checkpoint_freq: Option<u32>,
    pub keep_checkpoints_num: Option<u32>,
    /// Metric used to score checkpoints for retention.
    pub checkpoint_score_attr: Option<String>,
    pub checkpoint_mode: Option<RayCheckpointMode>,
    pub verbose: Option<u32>,
    pub progress_reporter: Option<RayProgressReporter>,
}
/// Whether the checkpoint score attribute should be maximized or minimized.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayCheckpointMode {
    Max,
    Min,
}
/// Progress-reporting sink for training / tuning runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayProgressReporter {
    Default,
    Json,
    TensorBoard,
    WandB,
    MLflow,
}
/// Checkpoint retention and scoring configuration for Ray Train.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayCheckpointConfig {
    /// How many checkpoints to retain (defaults to 3 at setup time).
    pub num_to_keep: Option<u32>,
    pub checkpoint_frequency: Option<u32>,
    pub checkpoint_at_end: Option<bool>,
    /// Metric used to rank checkpoints.
    pub checkpoint_score_attribute: Option<String>,
    pub checkpoint_mode: Option<RayCheckpointMode>,
}
/// Failure policy for Ray Train workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayFailureConfig {
    /// Maximum tolerated failures (defaults to 3 at setup time).
    pub max_failures: Option<u32>,
    pub failure_handling: Option<RayFailureHandling>,
}
/// What to do when a training worker fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayFailureHandling {
    Restart,
    Ignore,
    Fail,
}
/// Ray Tune (hyperparameter search) configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayTuneConfig {
    pub search_alg: Option<RaySearchAlgorithm>,
    pub scheduler: Option<RayScheduler>,
    /// Number of trials; must be > 0 when set (validated); defaults to 10
    /// in `run_tuning` when unset.
    pub num_samples: Option<u32>,
    /// Must be > 0 when set (validated); defaults to 4 at setup time.
    pub max_concurrent_trials: Option<u32>,
    pub resources_per_trial: Option<HashMap<String, f32>>,
    /// Free-form search space, keyed by parameter name.
    pub param_space: Option<HashMap<String, serde_json::Value>>,
    /// Metric to optimize.
    pub metric: Option<String>,
    /// Optimization direction — presumably "max" or "min"; stored as a free
    /// string and not validated here.
    pub mode: Option<String>,
    /// Wall-clock budget in seconds.
    pub time_budget_s: Option<f32>,
}
/// Hyperparameter search algorithm for Ray Tune.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RaySearchAlgorithm {
    BasicVariant,
    Random,
    Grid,
    BayesOpt,
    Hyperband,
    BOHB,
    PopulationBasedTraining,
}
/// Trial scheduler for Ray Tune.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayScheduler {
    FIFO,
    Hyperband,
    ASHA,
    MedianStopping,
    PopulationBasedTraining,
}
/// Ray Serve (model serving) configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayServeConfig {
    pub http_options: Option<RayServeHttpOptions>,
    pub grpc_options: Option<RayServeGrpcOptions>,
    /// Deployments to register at setup time.
    pub deployments: Option<Vec<RayServeDeploymentConfig>>,
}
/// HTTP listener options for Ray Serve.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayServeHttpOptions {
    /// Bind host (defaults to "127.0.0.1" at setup time).
    pub host: Option<String>,
    /// Bind port (defaults to 8000 at setup time).
    pub port: Option<u16>,
    /// URL prefix for all routes.
    pub root_path: Option<String>,
}
/// gRPC listener options for Ray Serve.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayServeGrpcOptions {
    /// Bind port (defaults to 9000 at setup time).
    pub port: Option<u16>,
    /// Fully-qualified servicer function names to expose.
    pub grpc_servicer_functions: Option<Vec<String>>,
}
/// A single Ray Serve deployment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayServeDeploymentConfig {
    /// Deployment name (required — the only non-optional field).
    pub name: String,
    /// Replica count (defaults to 1 at setup time).
    pub num_replicas: Option<u32>,
    /// Free-form actor options forwarded as JSON values.
    pub ray_actor_options: Option<HashMap<String, serde_json::Value>>,
    /// Free-form user configuration forwarded as JSON values.
    pub user_config: Option<HashMap<String, serde_json::Value>>,
    pub max_concurrent_queries: Option<u32>,
    pub autoscaling_config: Option<RayServeAutoscalingConfig>,
}
/// Autoscaling bounds and tuning knobs for a Serve deployment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayServeAutoscalingConfig {
    /// Lower replica bound (defaults to 1 at setup time).
    pub min_replicas: Option<u32>,
    /// Upper replica bound (defaults to 10 at setup time).
    pub max_replicas: Option<u32>,
    /// Target in-flight requests per replica that drives scaling.
    pub target_num_ongoing_requests_per_replica: Option<f32>,
    /// Metric sampling interval, seconds.
    pub metrics_interval_s: Option<f32>,
    /// Window over which metrics are averaged, seconds.
    pub look_back_period_s: Option<f32>,
    pub smoothing_factor: Option<f32>,
}
/// Ray Data (dataset ingestion) configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayDataConfig {
    pub format: Option<RayDataFormat>,
    /// Read parallelism (defaults to 4 at setup time).
    pub parallelism: Option<u32>,
    /// Batch size (defaults to 32 at setup time).
    pub batch_size: Option<u32>,
    /// Batches to prefetch (defaults to 2 at setup time).
    pub prefetch: Option<u32>,
    /// Whether to shuffle (defaults to false at setup time).
    pub shuffle: Option<bool>,
    /// Shuffle buffer size; only consulted when shuffling is enabled
    /// (defaults to 1000 at setup time).
    pub shuffle_buffer_size: Option<u32>,
}
/// Supported dataset file formats for Ray Data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayDataFormat {
    Parquet,
    CSV,
    JSON,
    Image,
    Text,
    Arrow,
}
/// Resource requests for this process.
///
/// CPU/GPU counts are fractional (`f32`) so partial resources can be
/// requested; memory fields are raw byte counts, unlike the GB-denominated
/// fields in `RayClusterConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayResourceConfig {
    pub num_cpus: Option<f32>,
    pub num_gpus: Option<f32>,
    /// Memory in bytes.
    pub memory: Option<u64>,
    /// Object-store memory in bytes.
    pub object_store_memory: Option<u64>,
    /// Arbitrary named resources (name -> amount).
    pub custom_resources: Option<HashMap<String, f32>>,
}
/// Worker fault-tolerance policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayFaultToleranceConfig {
    /// Max worker restarts before giving up (defaults to 3).
    pub max_restarts: Option<u32>,
    /// Delay between restarts, seconds (defaults to 5.0).
    pub restart_delay_s: Option<f32>,
    /// Health-check interval, seconds (defaults to 10.0).
    pub health_check_interval_s: Option<f32>,
    /// Master switch; treated as enabled when unset (`unwrap_or(true)`).
    pub enabled: Option<bool>,
}
/// Cumulative runtime statistics collected by `RayIntegration`.
#[derive(Debug, Clone, Default)]
pub struct RayStats {
    /// Completed `run_training` invocations.
    pub training_runs: u64,
    /// Total wall-clock time spent in `run_training`, seconds.
    pub training_time_sec: f64,
    /// Total tuning trials executed across all `run_tuning` calls.
    pub tuning_trials: u64,
    /// Total wall-clock time spent in `run_tuning`, seconds.
    pub tuning_time_sec: f64,
    // The fields below are declared for bookkeeping but only
    // worker_failures and restarts are updated in this file.
    pub served_requests: u64,
    pub data_processing_tasks: u64,
    /// Worker failures observed (incremented by `handle_worker_failure`).
    pub worker_failures: u64,
    /// Worker restarts performed (bounded by the configured max_restarts).
    pub restarts: u64,
    pub resource_utilization: f64,
    pub checkpoint_frequency: f64,
}
/// Ray integration driver: holds the configuration, accumulated statistics,
/// and this process's distributed identity (rank / world size).
///
/// NOTE(review): the `run_*` / `handle_worker_failure` methods log and update
/// counters rather than talking to a live Ray runtime — this appears to be a
/// simulation layer; confirm before relying on it for real cluster control.
pub struct RayIntegration {
    config: RayConfig,
    stats: RayStats,
    // Set once `initialize` completes; required by run_training/run_tuning.
    initialized: bool,
    rank: u32,
    world_size: u32,
    local_rank: u32,
    local_size: u32,
    // True between initialize() and shutdown().
    ray_session_active: bool,
}
impl RayIntegration {
    /// Creates an uninitialized integration from the given configuration.
    ///
    /// Rank defaults to 0 and world/local sizes to 1 until `initialize`
    /// is called with the real distributed identity.
    pub fn new(config: RayConfig) -> Self {
        Self {
            config,
            stats: RayStats::default(),
            initialized: false,
            rank: 0,
            world_size: 1,
            local_rank: 0,
            local_size: 1,
            ray_session_active: false,
        }
    }

    /// Loads a `RayConfig` from a JSON file and wraps it in a new integration.
    ///
    /// # Errors
    /// Returns a configuration error if the file cannot be read or the JSON
    /// cannot be parsed.
    pub fn from_file<P: AsRef<Path>>(path: P) -> TorshResult<Self> {
        let content = std::fs::read_to_string(path).map_err(|e| {
            TorshDistributedError::configuration_error(format!(
                "Failed to read Ray config file: {}",
                e
            ))
        })?;
        let config: RayConfig = serde_json::from_str(&content).map_err(|e| {
            TorshDistributedError::configuration_error(format!("Failed to parse Ray config: {}", e))
        })?;
        Ok(Self::new(config))
    }

    /// Validates the configuration, runs every `setup_*` stage, and records
    /// this process's distributed identity.
    ///
    /// # Errors
    /// Fails if called twice, or if `validate_config` rejects the config.
    pub fn initialize(
        &mut self,
        rank: u32,
        world_size: u32,
        local_rank: u32,
        local_size: u32,
    ) -> TorshResult<()> {
        if self.initialized {
            return Err(TorshDistributedError::configuration_error(
                "Ray integration already initialized",
            ));
        }
        self.rank = rank;
        self.world_size = world_size;
        self.local_rank = local_rank;
        self.local_size = local_size;
        self.validate_config()?;
        self.setup_ray_cluster()?;
        self.setup_ray_train()?;
        self.setup_ray_tune()?;
        self.setup_ray_serve()?;
        self.setup_ray_data()?;
        self.setup_fault_tolerance()?;
        self.initialized = true;
        self.ray_session_active = true;
        // Fix: log local_size too — it is set above but was previously
        // omitted from the initialization summary.
        tracing::info!(
            "Ray integration initialized - rank: {}, world_size: {}, local_rank: {}, local_size: {}",
            self.rank,
            self.world_size,
            self.local_rank,
            self.local_size
        );
        Ok(())
    }

    /// Rejects configurations with zero worker/sample counts or non-positive
    /// memory sizes. Sections that are `None` are skipped.
    fn validate_config(&self) -> TorshResult<()> {
        if let Some(ref cluster) = self.config.cluster {
            if let Some(num_cpus) = cluster.num_cpus {
                if num_cpus == 0 {
                    return Err(TorshDistributedError::configuration_error(
                        "Ray cluster num_cpus must be greater than 0",
                    ));
                }
            }
            if let Some(memory_gb) = cluster.memory_gb {
                if memory_gb <= 0.0 {
                    return Err(TorshDistributedError::configuration_error(
                        "Ray cluster memory_gb must be greater than 0",
                    ));
                }
            }
            // Fix: object_store_memory_gb carries the same "positive size"
            // invariant as memory_gb but was previously unchecked.
            if let Some(object_store_memory_gb) = cluster.object_store_memory_gb {
                if object_store_memory_gb <= 0.0 {
                    return Err(TorshDistributedError::configuration_error(
                        "Ray cluster object_store_memory_gb must be greater than 0",
                    ));
                }
            }
        }
        if let Some(ref train) = self.config.train {
            if train.num_workers == 0 {
                return Err(TorshDistributedError::configuration_error(
                    "Ray Train num_workers must be greater than 0",
                ));
            }
            if let Some(ref scaling) = train.scaling_config {
                if let Some(num_workers) = scaling.num_workers {
                    if num_workers == 0 {
                        return Err(TorshDistributedError::configuration_error(
                            "Ray Train scaling num_workers must be greater than 0",
                        ));
                    }
                }
            }
        }
        if let Some(ref tune) = self.config.tune {
            if let Some(num_samples) = tune.num_samples {
                if num_samples == 0 {
                    return Err(TorshDistributedError::configuration_error(
                        "Ray Tune num_samples must be greater than 0",
                    ));
                }
            }
            if let Some(max_concurrent) = tune.max_concurrent_trials {
                if max_concurrent == 0 {
                    return Err(TorshDistributedError::configuration_error(
                        "Ray Tune max_concurrent_trials must be greater than 0",
                    ));
                }
            }
        }
        Ok(())
    }

    /// Logs the effective cluster settings, applying defaults for unset
    /// fields. Purely informational — no cluster is actually started here.
    fn setup_ray_cluster(&self) -> TorshResult<()> {
        if let Some(ref cluster) = self.config.cluster {
            tracing::info!("Setting up Ray cluster");
            if let Some(ref address) = cluster.address {
                tracing::debug!("Ray cluster address: {}", address);
            }
            let num_cpus = cluster.num_cpus.unwrap_or(1);
            tracing::debug!("Ray cluster CPUs: {}", num_cpus);
            let num_gpus = cluster.num_gpus.unwrap_or(0);
            tracing::debug!("Ray cluster GPUs: {}", num_gpus);
            let memory_gb = cluster.memory_gb.unwrap_or(4.0);
            tracing::debug!("Ray cluster memory: {} GB", memory_gb);
            let object_store_memory_gb = cluster.object_store_memory_gb.unwrap_or(2.0);
            tracing::debug!("Ray object store memory: {} GB", object_store_memory_gb);
            if let Some(ref namespace) = cluster.namespace {
                tracing::debug!("Ray namespace: {}", namespace);
            }
            let include_dashboard = cluster.include_dashboard.unwrap_or(true);
            if include_dashboard {
                let default_host = "127.0.0.1".to_string();
                let dashboard_host = cluster.dashboard_host.as_ref().unwrap_or(&default_host);
                let dashboard_port = cluster.dashboard_port.unwrap_or(8265);
                tracing::debug!("Ray dashboard: {}:{}", dashboard_host, dashboard_port);
            }
        }
        Ok(())
    }

    /// Logs the effective Ray Train settings with their defaults.
    fn setup_ray_train(&self) -> TorshResult<()> {
        if let Some(ref train) = self.config.train {
            tracing::info!("Setting up Ray Train");
            tracing::debug!("Ray Train backend: {:?}", train.backend);
            tracing::debug!("Ray Train workers: {}", train.num_workers);
            let use_gpu = train.use_gpu.unwrap_or(false);
            tracing::debug!("Ray Train use GPU: {}", use_gpu);
            if let Some(ref resources) = train.resources_per_worker {
                tracing::debug!("Ray Train resources per worker: {:?}", resources);
            }
            let placement_strategy = train
                .placement_group_strategy
                .unwrap_or(RayPlacementGroupStrategy::Pack);
            tracing::debug!(
                "Ray Train placement group strategy: {:?}",
                placement_strategy
            );
            if let Some(ref scaling) = train.scaling_config {
                tracing::debug!("Ray Train scaling configuration: {:?}", scaling);
            }
            if let Some(ref run_config) = train.run_config {
                if let Some(ref name) = run_config.name {
                    tracing::debug!("Ray Train experiment name: {}", name);
                }
                if let Some(ref storage_path) = run_config.storage_path {
                    tracing::debug!("Ray Train storage path: {}", storage_path);
                }
            }
            if let Some(ref checkpoint) = train.checkpoint_config {
                let num_to_keep = checkpoint.num_to_keep.unwrap_or(3);
                tracing::debug!("Ray Train checkpoints to keep: {}", num_to_keep);
            }
            if let Some(ref failure) = train.failure_config {
                let max_failures = failure.max_failures.unwrap_or(3);
                tracing::debug!("Ray Train max failures: {}", max_failures);
            }
        }
        Ok(())
    }

    /// Logs the effective Ray Tune settings with their defaults.
    fn setup_ray_tune(&self) -> TorshResult<()> {
        if let Some(ref tune) = self.config.tune {
            tracing::info!("Setting up Ray Tune");
            if let Some(search_alg) = tune.search_alg {
                tracing::debug!("Ray Tune search algorithm: {:?}", search_alg);
            }
            if let Some(scheduler) = tune.scheduler {
                tracing::debug!("Ray Tune scheduler: {:?}", scheduler);
            }
            let num_samples = tune.num_samples.unwrap_or(10);
            tracing::debug!("Ray Tune samples: {}", num_samples);
            let max_concurrent = tune.max_concurrent_trials.unwrap_or(4);
            tracing::debug!("Ray Tune max concurrent trials: {}", max_concurrent);
            if let Some(ref resources) = tune.resources_per_trial {
                tracing::debug!("Ray Tune resources per trial: {:?}", resources);
            }
            if let Some(ref metric) = tune.metric {
                tracing::debug!("Ray Tune optimization metric: {}", metric);
            }
            if let Some(ref mode) = tune.mode {
                tracing::debug!("Ray Tune optimization mode: {}", mode);
            }
            if let Some(time_budget) = tune.time_budget_s {
                tracing::debug!("Ray Tune time budget: {} seconds", time_budget);
            }
        }
        Ok(())
    }

    /// Logs the effective Ray Serve settings (HTTP, gRPC, deployments).
    fn setup_ray_serve(&self) -> TorshResult<()> {
        if let Some(ref serve) = self.config.serve {
            tracing::info!("Setting up Ray Serve");
            if let Some(ref http) = serve.http_options {
                let default_host = "127.0.0.1".to_string();
                let host = http.host.as_ref().unwrap_or(&default_host);
                let port = http.port.unwrap_or(8000);
                tracing::debug!("Ray Serve HTTP: {}:{}", host, port);
                if let Some(ref root_path) = http.root_path {
                    tracing::debug!("Ray Serve HTTP root path: {}", root_path);
                }
            }
            if let Some(ref grpc) = serve.grpc_options {
                let port = grpc.port.unwrap_or(9000);
                tracing::debug!("Ray Serve gRPC port: {}", port);
                if let Some(ref functions) = grpc.grpc_servicer_functions {
                    tracing::debug!("Ray Serve gRPC servicer functions: {:?}", functions);
                }
            }
            if let Some(ref deployments) = serve.deployments {
                for deployment in deployments {
                    tracing::debug!("Ray Serve deployment: {}", deployment.name);
                    let num_replicas = deployment.num_replicas.unwrap_or(1);
                    tracing::debug!("  Replicas: {}", num_replicas);
                    if let Some(ref autoscaling) = deployment.autoscaling_config {
                        let min_replicas = autoscaling.min_replicas.unwrap_or(1);
                        let max_replicas = autoscaling.max_replicas.unwrap_or(10);
                        tracing::debug!("  Autoscaling: {} - {}", min_replicas, max_replicas);
                    }
                }
            }
        }
        Ok(())
    }

    /// Logs the effective Ray Data settings; the shuffle buffer size is only
    /// reported when shuffling is enabled.
    fn setup_ray_data(&self) -> TorshResult<()> {
        if let Some(ref data) = self.config.data {
            tracing::info!("Setting up Ray Data");
            if let Some(format) = data.format {
                tracing::debug!("Ray Data format: {:?}", format);
            }
            let parallelism = data.parallelism.unwrap_or(4);
            tracing::debug!("Ray Data parallelism: {}", parallelism);
            let batch_size = data.batch_size.unwrap_or(32);
            tracing::debug!("Ray Data batch size: {}", batch_size);
            let prefetch = data.prefetch.unwrap_or(2);
            tracing::debug!("Ray Data prefetch: {}", prefetch);
            let shuffle = data.shuffle.unwrap_or(false);
            tracing::debug!("Ray Data shuffle: {}", shuffle);
            if shuffle {
                let shuffle_buffer_size = data.shuffle_buffer_size.unwrap_or(1000);
                tracing::debug!("Ray Data shuffle buffer size: {}", shuffle_buffer_size);
            }
        }
        Ok(())
    }

    /// Logs the fault-tolerance policy; restart parameters are only reported
    /// when the policy is enabled (the default when `enabled` is unset).
    fn setup_fault_tolerance(&self) -> TorshResult<()> {
        if let Some(ref fault_tolerance) = self.config.fault_tolerance {
            tracing::info!("Setting up Ray fault tolerance");
            let enabled = fault_tolerance.enabled.unwrap_or(true);
            tracing::debug!("Ray fault tolerance enabled: {}", enabled);
            if enabled {
                let max_restarts = fault_tolerance.max_restarts.unwrap_or(3);
                tracing::debug!("Ray max restarts: {}", max_restarts);
                let restart_delay = fault_tolerance.restart_delay_s.unwrap_or(5.0);
                tracing::debug!("Ray restart delay: {} seconds", restart_delay);
                let health_check_interval = fault_tolerance.health_check_interval_s.unwrap_or(10.0);
                tracing::debug!(
                    "Ray health check interval: {} seconds",
                    health_check_interval
                );
            }
        }
        Ok(())
    }

    /// Derives an elastic-training config from the train section, if present.
    ///
    /// `min_workers` comes from the scaling override (falling back to the
    /// train worker count); `max_workers` is fixed at twice that. Returns
    /// `Ok(None)` when no train section is configured.
    pub fn to_elastic_config(&self) -> TorshResult<Option<crate::fault_tolerance::ElasticConfig>> {
        if let Some(ref train) = self.config.train {
            use crate::fault_tolerance::ElasticConfig;
            let min_workers = if let Some(ref scaling) = train.scaling_config {
                scaling.num_workers.unwrap_or(train.num_workers)
            } else {
                train.num_workers
            };
            let max_workers = min_workers * 2;
            let config = ElasticConfig {
                min_workers: min_workers as usize,
                max_workers: max_workers as usize,
                scaling_timeout: std::time::Duration::from_secs(300),
                scaling_check_interval: std::time::Duration::from_secs(30),
                enable_elastic_scheduling: true,
                rendezvous_backend: "etcd".to_string(),
                rendezvous_endpoint: "localhost:2379".to_string(),
            };
            Ok(Some(config))
        } else {
            Ok(None)
        }
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &RayConfig {
        &self.config
    }

    /// Returns the accumulated runtime statistics.
    pub fn stats(&self) -> &RayStats {
        &self.stats
    }

    /// Whether `initialize` has completed successfully.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Global rank of this process.
    pub fn rank(&self) -> u32 {
        self.rank
    }

    /// Total number of processes in the job.
    pub fn world_size(&self) -> u32 {
        self.world_size
    }

    /// Rank of this process on its node.
    pub fn local_rank(&self) -> u32 {
        self.local_rank
    }

    /// Number of processes on this node.
    pub fn local_size(&self) -> u32 {
        self.local_size
    }

    /// Whether a Ray session is considered active (between init and shutdown).
    pub fn is_ray_session_active(&self) -> bool {
        self.ray_session_active
    }

    /// Runs a (simulated) training job for `num_epochs`, updating the
    /// training statistics. Every 10th epoch exercises the worker-failure
    /// path when fault tolerance is configured.
    ///
    /// # Errors
    /// Fails if the integration has not been initialized, or if the
    /// simulated failure handling exhausts its restart budget.
    pub fn run_training(&mut self, train_func_name: &str, num_epochs: u32) -> TorshResult<()> {
        if !self.initialized {
            return Err(TorshDistributedError::BackendNotInitialized);
        }
        let start_time = std::time::Instant::now();
        tracing::info!(
            "Running Ray Train: {} for {} epochs",
            train_func_name,
            num_epochs
        );
        for epoch in 1..=num_epochs {
            tracing::debug!("Ray Train epoch {}/{}", epoch, num_epochs);
            if epoch % 10 == 0 && self.config.fault_tolerance.is_some() {
                self.handle_worker_failure()?;
            }
        }
        self.stats.training_runs += 1;
        self.stats.training_time_sec += start_time.elapsed().as_secs_f64();
        tracing::info!("Ray Train completed: {}", train_func_name);
        Ok(())
    }

    /// Runs a (simulated) tuning job; the trial count comes from the tune
    /// section's `num_samples`, defaulting to 10.
    ///
    /// # Errors
    /// Fails if the integration has not been initialized.
    pub fn run_tuning(&mut self, tune_config_name: &str) -> TorshResult<()> {
        if !self.initialized {
            return Err(TorshDistributedError::BackendNotInitialized);
        }
        let start_time = std::time::Instant::now();
        let num_trials = self
            .config
            .tune
            .as_ref()
            .and_then(|t| t.num_samples)
            .unwrap_or(10);
        tracing::info!(
            "Running Ray Tune: {} with {} trials",
            tune_config_name,
            num_trials
        );
        for trial in 1..=num_trials {
            tracing::debug!("Ray Tune trial {}/{}", trial, num_trials);
            self.stats.tuning_trials += 1;
        }
        self.stats.tuning_time_sec += start_time.elapsed().as_secs_f64();
        tracing::info!("Ray Tune completed: {}", tune_config_name);
        Ok(())
    }

    /// Simulates a worker failure: always counts the failure, then — when
    /// fault tolerance is enabled — performs a restart while the restart
    /// budget (`max_restarts`, default 3) lasts.
    ///
    /// # Errors
    /// Returns a process-failure error once the restart budget is exhausted.
    fn handle_worker_failure(&mut self) -> TorshResult<()> {
        tracing::warn!("Simulating Ray worker failure");
        self.stats.worker_failures += 1;
        if let Some(ref fault_tolerance) = self.config.fault_tolerance {
            if fault_tolerance.enabled.unwrap_or(true) {
                let max_restarts = fault_tolerance.max_restarts.unwrap_or(3);
                if self.stats.restarts < max_restarts as u64 {
                    tracing::info!("Restarting failed Ray worker");
                    self.stats.restarts += 1;
                    let restart_delay = fault_tolerance.restart_delay_s.unwrap_or(5.0);
                    tracing::debug!("Ray restart delay: {} seconds", restart_delay);
                } else {
                    return Err(TorshDistributedError::process_failure(
                        self.rank,
                        "ray_worker",
                        "Maximum restart attempts exceeded",
                    ));
                }
            }
        }
        Ok(())
    }

    /// Marks the session inactive and uninitialized. Idempotent; statistics
    /// and the recorded rank/world size are intentionally preserved.
    pub fn shutdown(&mut self) -> TorshResult<()> {
        if self.ray_session_active {
            tracing::info!("Shutting down Ray integration");
            self.ray_session_active = false;
            self.initialized = false;
        }
        Ok(())
    }

    /// A sensible local default: cluster, train (Torch, 4 CPU workers),
    /// data, resources and fault tolerance populated; tune and serve unset.
    pub fn default_config() -> RayConfig {
        RayConfig {
            cluster: Some(RayClusterConfig {
                address: None,
                redis_address: None,
                num_cpus: Some(4),
                num_gpus: Some(0),
                memory_gb: Some(8.0),
                object_store_memory_gb: Some(2.0),
                namespace: None,
                dashboard_host: Some("127.0.0.1".to_string()),
                dashboard_port: Some(8265),
                include_dashboard: Some(true),
            }),
            train: Some(RayTrainConfig {
                backend: RayTrainBackend::Torch,
                num_workers: 4,
                use_gpu: Some(false),
                resources_per_worker: None,
                placement_group_strategy: Some(RayPlacementGroupStrategy::Pack),
                scaling_config: None,
                run_config: None,
                checkpoint_config: None,
                failure_config: Some(RayFailureConfig {
                    max_failures: Some(3),
                    failure_handling: Some(RayFailureHandling::Restart),
                }),
            }),
            tune: None,
            serve: None,
            data: Some(RayDataConfig {
                format: Some(RayDataFormat::Parquet),
                parallelism: Some(4),
                batch_size: Some(32),
                prefetch: Some(2),
                shuffle: Some(false),
                shuffle_buffer_size: Some(1000),
            }),
            resources: Some(RayResourceConfig {
                num_cpus: Some(4.0),
                num_gpus: Some(0.0),
                // Byte counts: 8 GiB main memory, 2 GiB object store.
                memory: Some(8 * 1024 * 1024 * 1024),
                object_store_memory: Some(2 * 1024 * 1024 * 1024),
                custom_resources: None,
            }),
            fault_tolerance: Some(RayFaultToleranceConfig {
                max_restarts: Some(3),
                restart_delay_s: Some(5.0),
                health_check_interval_s: Some(10.0),
                enabled: Some(true),
            }),
        }
    }

    /// Default config plus a tune section using the given sample count and
    /// search algorithm (ASHA scheduler, 4 concurrent trials, 1 CPU each,
    /// maximizing "accuracy" with a 1-hour budget).
    pub fn config_with_tune(num_samples: u32, search_alg: RaySearchAlgorithm) -> RayConfig {
        let mut config = Self::default_config();
        config.tune = Some(RayTuneConfig {
            search_alg: Some(search_alg),
            scheduler: Some(RayScheduler::ASHA),
            num_samples: Some(num_samples),
            max_concurrent_trials: Some(4),
            resources_per_trial: Some([("cpu".to_string(), 1.0)].into_iter().collect()),
            param_space: None,
            metric: Some("accuracy".to_string()),
            mode: Some("max".to_string()),
            time_budget_s: Some(3600.0),
        });
        config
    }

    /// Default config plus a serve section exposing one "model" deployment
    /// with the given replica count and autoscaling up to twice that.
    pub fn config_with_serve(num_replicas: u32) -> RayConfig {
        let mut config = Self::default_config();
        config.serve = Some(RayServeConfig {
            http_options: Some(RayServeHttpOptions {
                host: Some("0.0.0.0".to_string()),
                port: Some(8000),
                root_path: None,
            }),
            grpc_options: None,
            deployments: Some(vec![RayServeDeploymentConfig {
                name: "model".to_string(),
                num_replicas: Some(num_replicas),
                ray_actor_options: Some(
                    [(
                        "num_cpus".to_string(),
                        serde_json::Value::Number(serde_json::Number::from(1)),
                    )]
                    .into_iter()
                    .collect(),
                ),
                user_config: None,
                max_concurrent_queries: Some(100),
                autoscaling_config: Some(RayServeAutoscalingConfig {
                    min_replicas: Some(1),
                    max_replicas: Some(num_replicas * 2),
                    target_num_ongoing_requests_per_replica: Some(10.0),
                    metrics_interval_s: Some(10.0),
                    look_back_period_s: Some(30.0),
                    smoothing_factor: Some(1.0),
                }),
            }]),
        });
        config
    }
}
// The default RayConfig is exactly `RayIntegration::default_config()`:
// cluster/train/data/resources/fault-tolerance populated, tune/serve unset.
impl Default for RayConfig {
    fn default() -> Self {
        RayIntegration::default_config()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // initialize() succeeds on the default config and records the
    // distributed identity it was given.
    #[test]
    fn test_ray_config_validation() {
        let config = RayIntegration::default_config();
        let mut integration = RayIntegration::new(config);
        assert!(integration.initialize(0, 4, 0, 2).is_ok());
        assert!(integration.is_initialized());
        assert!(integration.is_ray_session_active());
        assert_eq!(integration.rank(), 0);
        assert_eq!(integration.world_size(), 4);
        assert_eq!(integration.local_rank(), 0);
    }

    // Each run_training() call bumps training_runs and accumulates time.
    #[test]
    fn test_ray_training_simulation() {
        let config = RayIntegration::default_config();
        let mut integration = RayIntegration::new(config);
        assert!(integration.initialize(0, 4, 0, 2).is_ok());
        assert!(integration.run_training("my_train_func", 5).is_ok());
        assert!(integration.run_training("another_train_func", 3).is_ok());
        let stats = integration.stats();
        assert_eq!(stats.training_runs, 2);
        assert!(stats.training_time_sec >= 0.0);
    }

    // run_tuning() executes exactly num_samples trials.
    #[test]
    fn test_ray_tuning_simulation() {
        let config = RayIntegration::config_with_tune(20, RaySearchAlgorithm::BayesOpt);
        let mut integration = RayIntegration::new(config);
        assert!(integration.initialize(0, 4, 0, 2).is_ok());
        assert!(integration.run_tuning("hyperparameter_search").is_ok());
        let stats = integration.stats();
        assert_eq!(stats.tuning_trials, 20);
        // NOTE(review): strict `> 0.0` on elapsed time could in principle be
        // flaky on a coarse clock; `>= 0.0` would be the safer assertion.
        assert!(stats.tuning_time_sec > 0.0);
    }

    // to_elastic_config() uses train.num_workers (4) as the minimum and
    // doubles it for the maximum.
    #[test]
    fn test_ray_elastic_config_conversion() {
        let config = RayIntegration::default_config();
        let mut integration = RayIntegration::new(config);
        assert!(integration.initialize(0, 4, 0, 2).is_ok());
        let elastic_config = integration.to_elastic_config().unwrap();
        assert!(elastic_config.is_some());
        if let Some(config) = elastic_config {
            assert_eq!(config.min_workers, 4);
            assert_eq!(config.max_workers, 8);
            assert!(config.enable_elastic_scheduling);
            assert_eq!(config.rendezvous_backend, "etcd");
        }
    }

    // The default config allows 3 restarts; the 4th simulated failure
    // exceeds the budget and must error.
    #[test]
    fn test_ray_worker_failure_handling() {
        let config = RayIntegration::default_config();
        let mut integration = RayIntegration::new(config);
        assert!(integration.initialize(0, 4, 0, 2).is_ok());
        assert!(integration.handle_worker_failure().is_ok());
        assert!(integration.handle_worker_failure().is_ok());
        assert!(integration.handle_worker_failure().is_ok());
        let stats = integration.stats();
        assert_eq!(stats.worker_failures, 3);
        assert_eq!(stats.restarts, 3);
        assert!(integration.handle_worker_failure().is_err());
    }

    // shutdown() clears both the session-active and initialized flags.
    #[test]
    fn test_ray_shutdown() {
        let config = RayIntegration::default_config();
        let mut integration = RayIntegration::new(config);
        assert!(integration.initialize(0, 4, 0, 2).is_ok());
        assert!(integration.is_ray_session_active());
        assert!(integration.shutdown().is_ok());
        assert!(!integration.is_ray_session_active());
        assert!(!integration.is_initialized());
    }

    // config_with_serve() produces a single "model" deployment with the
    // requested replica count.
    #[test]
    fn test_ray_serve_config() {
        let config = RayIntegration::config_with_serve(4);
        let mut integration = RayIntegration::new(config);
        assert!(integration.initialize(0, 4, 0, 2).is_ok());
        assert!(integration.config().serve.is_some());
        if let Some(ref serve) = integration.config().serve {
            assert!(serve.http_options.is_some());
            assert!(serve.deployments.is_some());
            if let Some(ref deployments) = serve.deployments {
                assert_eq!(deployments.len(), 1);
                assert_eq!(deployments[0].name, "model");
                assert_eq!(deployments[0].num_replicas, Some(4));
            }
        }
    }

    // RayConfig round-trips through serde_json without losing tune fields.
    #[test]
    fn test_ray_config_serialization() {
        let config = RayIntegration::config_with_tune(10, RaySearchAlgorithm::Random);
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("Random"));
        assert!(json.contains("ASHA"));
        assert!(json.contains("accuracy"));
        let deserialized: RayConfig = serde_json::from_str(&json).unwrap();
        assert!(deserialized.tune.is_some());
        if let Some(tune) = deserialized.tune {
            assert_eq!(tune.search_alg, Some(RaySearchAlgorithm::Random));
            assert_eq!(tune.scheduler, Some(RayScheduler::ASHA));
            assert_eq!(tune.num_samples, Some(10));
        }
    }

    // A zero worker count must be rejected by validate_config via initialize.
    #[test]
    fn test_ray_invalid_config() {
        let mut config = RayIntegration::default_config();
        if let Some(ref mut train) = config.train {
            train.num_workers = 0;
        }
        let mut integration = RayIntegration::new(config);
        assert!(integration.initialize(0, 4, 0, 2).is_err());
    }
}