use crate::{TorshDistributedError, TorshResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// Root configuration for the Dask integration. Every section is
/// optional; missing sections fall back to the defaults applied by
/// `DaskIntegration` (see `default_config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskConfig {
/// Cluster topology and lifecycle settings.
pub cluster: Option<DaskClusterConfig>,
/// Scheduler process settings.
pub scheduler: Option<DaskSchedulerConfig>,
/// Per-worker process settings.
pub worker: Option<DaskWorkerConfig>,
/// dask.array chunking/backend settings.
pub array: Option<DaskArrayConfig>,
/// dask.dataframe partitioning/shuffle settings.
pub dataframe: Option<DaskDataFrameConfig>,
/// dask.bag partitioning/encoding settings.
pub bag: Option<DaskBagConfig>,
/// dask-ml settings (model selection, preprocessing, models).
pub ml: Option<DaskMLConfig>,
/// Distributed-runtime settings (comm, serialization, client, ...).
pub distributed: Option<DaskDistributedConfig>,
}
/// Settings describing how the Dask cluster itself is launched.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskClusterConfig {
/// Which kind of cluster to run (local, Kubernetes, HPC queue, ...).
pub cluster_type: DaskClusterType,
/// Number of workers; must be > 0 when set (see `validate_config`).
pub n_workers: Option<u32>,
/// Threads per worker; must be > 0 when set.
pub threads_per_worker: Option<u32>,
/// Memory cap per worker, expressed as a string such as "4GB".
pub memory_limit: Option<String>,
/// Use separate processes (true) rather than threads (false).
pub processes: Option<bool>,
/// Address (host:port) for the diagnostics dashboard.
pub dashboard_address: Option<String>,
/// Suppress worker/scheduler log output when true.
pub silence_logs: Option<bool>,
/// Optional TLS/encryption settings.
pub security: Option<DaskSecurityConfig>,
/// Optional adaptive-scaling settings.
pub scaling: Option<DaskScalingConfig>,
}
/// Supported cluster deployment flavors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DaskClusterType {
/// Local in-process cluster.
Local,
/// Local cluster using separate processes.
LocalProcess,
/// Cluster deployed on Kubernetes.
Kubernetes,
/// Cluster submitted through the Slurm queue.
Slurm,
/// Cluster submitted through PBS.
PBS,
/// Cluster submitted through Sun Grid Engine.
SGE,
/// User-provided deployment mechanism.
Custom,
}
/// TLS material and encryption policy for cluster communication.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSecurityConfig {
/// TLS certificate (exact format — path vs. contents — not established here).
pub tls_cert: Option<String>,
/// TLS private key.
pub tls_key: Option<String>,
/// Certificate-authority file used to verify peers.
pub tls_ca_file: Option<String>,
/// Require encrypted connections when true.
pub require_encryption: Option<bool>,
}
/// Adaptive-scaling bounds and thresholds.
/// `minimum` must not exceed `maximum` (enforced by `validate_config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskScalingConfig {
/// Lower bound on the worker count.
pub minimum: Option<u32>,
/// Upper bound on the worker count.
pub maximum: Option<u32>,
/// Target CPU utilization (fraction, e.g. 0.8).
pub target_cpu: Option<f32>,
/// Target memory utilization (fraction).
pub target_memory: Option<f32>,
/// Utilization above which the cluster scales up.
pub scale_up_threshold: Option<f32>,
/// Utilization below which the cluster scales down.
pub scale_down_threshold: Option<f32>,
/// Seconds between scaling evaluations.
pub interval: Option<f32>,
}
/// Scheduler process configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSchedulerConfig {
/// Explicit scheduler address; defaults are applied when absent.
pub address: Option<String>,
/// Scheduler port; must be non-zero when set (see `validate_config`).
pub port: Option<u16>,
/// Dashboard HTTP port.
pub dashboard_port: Option<u16>,
/// Legacy Bokeh dashboard port.
pub bokeh_port: Option<u16>,
/// Seconds before an unresponsive worker is considered lost; must be > 0.
pub worker_timeout: Option<f32>,
/// Seconds of inactivity before the scheduler may shut down.
pub idle_timeout: Option<f32>,
/// Number of task-state transitions to retain in the log.
pub transition_log_length: Option<u32>,
/// Per-task scheduling overhead estimate (seconds).
pub task_duration_overhead: Option<f32>,
/// How many task failures are tolerated before giving up.
pub allowed_failures: Option<u32>,
}
/// Worker process configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskWorkerConfig {
/// Number of worker processes; must be > 0 when set.
pub nworkers: Option<u32>,
/// Threads per worker; must be > 0 when set.
pub nthreads: Option<u32>,
/// Memory cap per worker (e.g. "4GB").
pub memory_limit: Option<String>,
/// Worker port or port range, e.g. "40000:40100".
pub worker_port: Option<String>,
/// Nanny (supervisor) port or port range.
pub nanny_port: Option<String>,
/// Per-worker dashboard port.
pub dashboard_port: Option<u16>,
/// Seconds a worker waits for the scheduler before exiting.
pub death_timeout: Option<f32>,
/// Modules to import in each worker at startup.
pub preload: Option<Vec<String>>,
/// Environment variables for worker processes.
pub env: Option<HashMap<String, String>>,
/// Abstract resources (name -> amount) advertised by each worker.
pub resources: Option<HashMap<String, f32>>,
}
/// dask.array settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskArrayConfig {
/// Target chunk size (e.g. "128MB").
pub chunk_size: Option<String>,
/// Array backend name (e.g. "numpy").
pub backend: Option<String>,
/// Overlap depth for overlapping block operations.
pub overlap: Option<u32>,
/// Boundary handling mode (e.g. "reflect").
pub boundary: Option<String>,
/// Trim overlapping regions after computation when true.
pub trim: Option<bool>,
/// Threshold controlling when rechunking is triggered.
pub rechunk_threshold: Option<f32>,
}
/// dask.dataframe settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskDataFrameConfig {
/// Target partition size (e.g. "128MB").
pub partition_size: Option<String>,
/// Shuffle implementation to use.
pub shuffle_method: Option<DaskShuffleMethod>,
/// Enable the query-planning optimizer when true.
pub query_planning: Option<bool>,
/// DataFrame backend name (e.g. "pandas").
pub backend: Option<String>,
/// Optimize index handling when true.
pub optimize_index: Option<bool>,
}
/// Available shuffle implementations for dataframe operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DaskShuffleMethod {
/// Spill-to-disk shuffle.
Disk,
/// Task-graph-based shuffle.
Tasks,
/// Peer-to-peer shuffle.
P2P,
}
/// dask.bag settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskBagConfig {
/// Partition size in bytes.
pub partition_size: Option<u64>,
/// Compression codec name (e.g. "gzip").
pub compression: Option<String>,
/// Text encoding (e.g. "utf-8").
pub encoding: Option<String>,
/// Line delimiter used when reading text data.
pub linedelimiter: Option<String>,
}
/// dask-ml settings grouped by concern.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLConfig {
/// Cross-validation / hyperparameter-search settings.
pub model_selection: Option<DaskMLModelSelectionConfig>,
/// Feature-preprocessing settings.
pub preprocessing: Option<DaskMLPreprocessingConfig>,
/// Linear-model solver settings.
pub linear_model: Option<DaskMLLinearModelConfig>,
/// Ensemble-model settings.
pub ensemble: Option<DaskMLEnsembleConfig>,
/// Clustering settings.
pub cluster: Option<DaskMLClusterConfig>,
}
/// Model-selection (cross-validation / search) settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLModelSelectionConfig {
/// Number of cross-validation folds.
pub cv_folds: Option<u32>,
/// Scoring metric name (e.g. "accuracy").
pub scoring: Option<String>,
/// Parallel jobs; -1 conventionally means "use all cores".
pub n_jobs: Option<i32>,
/// Hyperparameter search strategy.
pub search_method: Option<DaskMLSearchMethod>,
}
/// Hyperparameter search strategies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DaskMLSearchMethod {
/// Exhaustive grid search.
GridSearch,
/// Random sampling of the parameter space.
RandomSearch,
/// Successive-halving early stopping.
SuccessiveHalving,
/// Hyperband bandit-based search.
Hyperband,
}
/// Preprocessing pipeline component names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLPreprocessingConfig {
/// Scaler name (e.g. "StandardScaler").
pub standardization: Option<String>,
/// Categorical encoder name (e.g. "OneHotEncoder").
pub categorical_encoding: Option<String>,
/// Feature-selection strategy name.
pub feature_selection: Option<String>,
/// Dimensionality-reduction method name (e.g. "PCA").
pub dimensionality_reduction: Option<String>,
}
/// Linear-model hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLLinearModelConfig {
/// Optimizer/solver name (e.g. "lbfgs").
pub solver: Option<String>,
/// Regularization strength.
pub alpha: Option<f32>,
/// Iteration cap.
pub max_iter: Option<u32>,
/// Convergence tolerance.
pub tol: Option<f32>,
}
/// Ensemble-model hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLEnsembleConfig {
/// Number of estimators in the ensemble.
pub n_estimators: Option<u32>,
/// Sample with replacement when true.
pub bootstrap: Option<bool>,
/// RNG seed for reproducibility.
pub random_state: Option<u32>,
/// Track out-of-bag score when true.
pub oob_score: Option<bool>,
}
/// Clustering hyperparameters (k-means-style).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLClusterConfig {
/// Number of clusters.
pub n_clusters: Option<u32>,
/// Initialization scheme (e.g. "k-means++").
pub init: Option<String>,
/// Iteration cap per run.
pub max_iter: Option<u32>,
/// Convergence tolerance.
pub tol: Option<f32>,
/// Number of restarts with different seeds.
pub n_init: Option<u32>,
}
/// Distributed-runtime settings grouped by concern.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskDistributedConfig {
/// Communication-layer settings.
pub comm: Option<DaskCommConfig>,
/// Serialization settings.
pub serialization: Option<DaskSerializationConfig>,
/// Client-session settings.
pub client: Option<DaskClientConfig>,
/// Task-scheduling settings.
pub scheduling: Option<DaskSchedulingConfig>,
/// Diagnostics/monitoring toggles.
pub diagnostics: Option<DaskDiagnosticsConfig>,
}
/// Communication-layer settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskCommConfig {
/// Wire compression codec (e.g. "lz4").
pub compression: Option<String>,
/// Ordered list of serializer names to try.
pub serializers: Option<Vec<String>>,
/// Connection/IO timeout settings.
pub timeouts: Option<DaskTimeoutsConfig>,
/// TCP socket options.
pub tcp: Option<DaskTcpConfig>,
}
/// Timeouts (seconds) for the communication layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskTimeoutsConfig {
/// Seconds allowed to establish a connection.
pub connect: Option<f32>,
/// Seconds allowed for TCP operations.
pub tcp: Option<f32>,
}
/// TCP socket options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskTcpConfig {
/// Enable port reuse (SO_REUSEPORT-style) when true.
pub reuse_port: Option<bool>,
/// TCP_NODELAY-style option when true.
pub no_delay: Option<bool>,
/// Enable TCP keep-alive probes when true.
pub keep_alive: Option<bool>,
}
/// Serialization settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSerializationConfig {
/// Compression codecs to allow.
pub compression: Option<Vec<String>>,
/// Serializers used when none is specified explicitly.
pub default_serializers: Option<Vec<String>>,
/// Python pickle protocol version.
pub pickle_protocol: Option<u32>,
}
/// Client-session settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskClientConfig {
/// Seconds between client heartbeats.
pub heartbeat_interval: Option<f32>,
/// Seconds between scheduler-info refreshes.
pub scheduler_info_interval: Option<f32>,
/// Metadata fields to record per task.
pub task_metadata: Option<Vec<String>>,
/// Register this client as the process-wide default when true.
pub set_as_default: Option<bool>,
}
/// Task-scheduling settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSchedulingConfig {
/// Allow idle workers to steal queued tasks when true.
pub work_stealing: Option<bool>,
/// Seconds between work-stealing checks.
pub work_stealing_interval: Option<f32>,
/// Assumed duration (seconds) for tasks with no history.
pub unknown_task_duration: Option<f32>,
/// Per-task-name duration hints (seconds).
pub default_task_durations: Option<HashMap<String, f32>>,
}
/// Diagnostics/monitoring toggles.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskDiagnosticsConfig {
/// Show a progress bar for computations.
pub progress_bar: Option<bool>,
/// Enable statistical profiling.
pub profile: Option<bool>,
/// Enable memory profiling.
pub memory_profiling: Option<bool>,
/// Record the task stream for the dashboard.
pub task_stream: Option<bool>,
/// Enable resource monitoring.
pub resource_monitor: Option<bool>,
}
/// Mutable runtime counters maintained by `DaskIntegration`.
/// All values start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct DaskStats {
/// Total tasks submitted through `submit_task`.
pub tasks_executed: u64,
/// Cumulative task wall time in seconds.
pub task_execution_time_sec: f64,
/// Workers currently considered connected.
pub workers_connected: u32,
/// Total bytes attributed to submitted tasks.
pub data_transferred_bytes: u64,
/// Task retry count (not updated anywhere in this file).
pub task_retries: u64,
/// Worker failures recorded by `handle_worker_failure`.
pub worker_failures: u64,
/// Memory usage in bytes (not updated anywhere in this file).
pub memory_usage_bytes: u64,
/// CPU utilization fraction (not updated anywhere in this file).
pub cpu_utilization: f64,
/// Network bandwidth (not updated anywhere in this file).
pub network_bandwidth_bytes_per_sec: f64,
/// `task_execution_time_sec / tasks_executed`, maintained by `submit_task`.
pub average_task_duration_sec: f64,
}
/// Stateful wrapper that validates a `DaskConfig`, performs (simulated)
/// cluster setup, and tracks runtime statistics. The methods in this
/// file only validate, log, and update counters — no external Dask
/// process is contacted here.
pub struct DaskIntegration {
/// Configuration supplied at construction.
config: DaskConfig,
/// Runtime counters (tasks, workers, failures, ...).
stats: DaskStats,
/// Set by `initialize`, cleared by `shutdown`.
initialized: bool,
/// Global rank of this process (set by `initialize`).
rank: u32,
/// Total number of processes.
world_size: u32,
/// Rank within the local node.
local_rank: u32,
/// Number of processes on the local node.
local_size: u32,
/// True while a (simulated) client session is active.
client_active: bool,
}
impl DaskIntegration {
/// Builds an uninitialized integration wrapping `config`.
///
/// Topology fields start as a single-process default (rank 0 of a
/// world of 1) until `initialize` is called.
pub fn new(config: DaskConfig) -> Self {
    let stats = DaskStats::default();
    Self {
        config,
        stats,
        initialized: false,
        client_active: false,
        rank: 0,
        local_rank: 0,
        world_size: 1,
        local_size: 1,
    }
}
/// Loads a `DaskConfig` from a JSON file at `path` and wraps it in a
/// new, uninitialized integration.
///
/// # Errors
/// Returns a configuration error when the file cannot be read or its
/// contents fail to parse as `DaskConfig`.
pub fn from_file<P: AsRef<Path>>(path: P) -> TorshResult<Self> {
    let raw = std::fs::read_to_string(path).map_err(|e| {
        TorshDistributedError::configuration_error(format!(
            "Failed to read Dask config file: {}",
            e
        ))
    })?;
    serde_json::from_str::<DaskConfig>(&raw)
        .map(Self::new)
        .map_err(|e| {
            TorshDistributedError::configuration_error(format!(
                "Failed to parse Dask config: {}",
                e
            ))
        })
}
/// Wires up the integration for the given process topology, validating
/// the configuration and running all setup steps.
///
/// Fix: the configuration is now validated *before* any state is
/// mutated, so a failed call no longer leaves the rank/topology fields
/// overwritten while `initialized` stays false.
///
/// # Errors
/// - already initialized
/// - invalid configuration (see `validate_config` for the rules)
pub fn initialize(
    &mut self,
    rank: u32,
    world_size: u32,
    local_rank: u32,
    local_size: u32,
) -> TorshResult<()> {
    if self.initialized {
        return Err(TorshDistributedError::configuration_error(
            "Dask integration already initialized",
        ));
    }
    // Fail fast before touching any state.
    self.validate_config()?;
    self.rank = rank;
    self.world_size = world_size;
    self.local_rank = local_rank;
    self.local_size = local_size;
    self.setup_cluster()?;
    self.setup_scheduler()?;
    self.setup_workers()?;
    self.setup_client()?;
    self.setup_ml()?;
    self.setup_distributed()?;
    self.initialized = true;
    self.client_active = true;
    tracing::info!(
        "Dask integration initialized - rank: {}, world_size: {}, local_rank: {}",
        self.rank,
        self.world_size,
        self.local_rank
    );
    Ok(())
}
/// Checks invariants on the supplied configuration before any setup runs.
///
/// Enforced rules (first violation wins):
/// - cluster `n_workers` and `threads_per_worker` are non-zero when set
/// - scaling `minimum` does not exceed `maximum` when both are set
/// - scheduler `port` is non-zero and `worker_timeout` is positive
/// - worker `nworkers` and `nthreads` are non-zero when set
///
/// # Errors
/// Returns a configuration error naming the violated rule.
fn validate_config(&self) -> TorshResult<()> {
if let Some(ref cluster) = self.config.cluster {
if let Some(n_workers) = cluster.n_workers {
if n_workers == 0 {
return Err(TorshDistributedError::configuration_error(
"Dask cluster n_workers must be greater than 0",
));
}
}
if let Some(threads_per_worker) = cluster.threads_per_worker {
if threads_per_worker == 0 {
return Err(TorshDistributedError::configuration_error(
"Dask cluster threads_per_worker must be greater than 0",
));
}
}
// Bounds are only cross-checked when BOTH ends are present.
if let Some(ref scaling) = cluster.scaling {
if let Some(minimum) = scaling.minimum {
if let Some(maximum) = scaling.maximum {
if minimum > maximum {
return Err(TorshDistributedError::configuration_error(
"Dask scaling minimum workers cannot exceed maximum workers",
));
}
}
}
}
}
if let Some(ref scheduler) = self.config.scheduler {
if let Some(port) = scheduler.port {
if port == 0 {
return Err(TorshDistributedError::configuration_error(
"Dask scheduler port must be greater than 0",
));
}
}
if let Some(worker_timeout) = scheduler.worker_timeout {
if worker_timeout <= 0.0 {
return Err(TorshDistributedError::configuration_error(
"Dask scheduler worker_timeout must be greater than 0",
));
}
}
}
if let Some(ref worker) = self.config.worker {
if let Some(nworkers) = worker.nworkers {
if nworkers == 0 {
return Err(TorshDistributedError::configuration_error(
"Dask worker nworkers must be greater than 0",
));
}
}
if let Some(nthreads) = worker.nthreads {
if nthreads == 0 {
return Err(TorshDistributedError::configuration_error(
"Dask worker nthreads must be greater than 0",
));
}
}
}
Ok(())
}
/// Logs the effective cluster settings with defaults filled in.
/// Purely informational: nothing external is launched here.
fn setup_cluster(&self) -> TorshResult<()> {
if let Some(ref cluster) = self.config.cluster {
tracing::info!("Setting up Dask cluster: {:?}", cluster.cluster_type);
// Defaults mirror `default_config`: 4 workers x 2 threads.
let n_workers = cluster.n_workers.unwrap_or(4);
tracing::debug!("Dask cluster workers: {}", n_workers);
let threads_per_worker = cluster.threads_per_worker.unwrap_or(2);
tracing::debug!("Dask threads per worker: {}", threads_per_worker);
if let Some(ref memory_limit) = cluster.memory_limit {
tracing::debug!("Dask memory limit per worker: {}", memory_limit);
}
let processes = cluster.processes.unwrap_or(false);
tracing::debug!("Dask use processes: {}", processes);
if let Some(ref dashboard_address) = cluster.dashboard_address {
tracing::debug!("Dask dashboard address: {}", dashboard_address);
}
let silence_logs = cluster.silence_logs.unwrap_or(false);
tracing::debug!("Dask silence logs: {}", silence_logs);
if let Some(ref security) = cluster.security {
tracing::debug!("Dask security enabled");
if let Some(ref tls_cert) = security.tls_cert {
tracing::debug!("Dask TLS cert: {}", tls_cert);
}
if let Some(ref tls_key) = security.tls_key {
tracing::debug!("Dask TLS key: {}", tls_key);
}
let require_encryption = security.require_encryption.unwrap_or(false);
tracing::debug!("Dask require encryption: {}", require_encryption);
}
if let Some(ref scaling) = cluster.scaling {
let minimum = scaling.minimum.unwrap_or(1);
// Default ceiling scales with the (possibly defaulted) worker count.
let maximum = scaling.maximum.unwrap_or(n_workers * 2);
tracing::debug!("Dask scaling: {} - {} workers", minimum, maximum);
let target_cpu = scaling.target_cpu.unwrap_or(0.8);
tracing::debug!("Dask target CPU utilization: {}", target_cpu);
let interval = scaling.interval.unwrap_or(30.0);
tracing::debug!("Dask scaling interval: {} seconds", interval);
}
}
Ok(())
}
/// Logs the effective scheduler settings with defaults filled in
/// (port 8786, dashboard 8787, 60s worker timeout, 30min idle timeout,
/// 3 allowed failures). Purely informational.
fn setup_scheduler(&self) -> TorshResult<()> {
if let Some(ref scheduler) = self.config.scheduler {
tracing::info!("Setting up Dask scheduler");
if let Some(ref address) = scheduler.address {
tracing::debug!("Dask scheduler address: {}", address);
}
let port = scheduler.port.unwrap_or(8786);
tracing::debug!("Dask scheduler port: {}", port);
let dashboard_port = scheduler.dashboard_port.unwrap_or(8787);
tracing::debug!("Dask scheduler dashboard port: {}", dashboard_port);
let worker_timeout = scheduler.worker_timeout.unwrap_or(60.0);
tracing::debug!("Dask scheduler worker timeout: {} seconds", worker_timeout);
let idle_timeout = scheduler.idle_timeout.unwrap_or(1800.0);
tracing::debug!("Dask scheduler idle timeout: {} seconds", idle_timeout);
let allowed_failures = scheduler.allowed_failures.unwrap_or(3);
tracing::debug!("Dask scheduler allowed failures: {}", allowed_failures);
}
Ok(())
}
/// Logs worker settings and seeds `stats.workers_connected`.
///
/// NOTE(review): the connected-worker count is seeded from
/// `worker.nworkers` with a fallback of 4 even when no worker section
/// is configured at all — confirm that default is intended.
fn setup_workers(&mut self) -> TorshResult<()> {
if let Some(ref worker) = self.config.worker {
tracing::info!("Setting up Dask workers");
let nworkers = worker.nworkers.unwrap_or(4);
tracing::debug!("Dask number of workers: {}", nworkers);
let nthreads = worker.nthreads.unwrap_or(2);
tracing::debug!("Dask threads per worker: {}", nthreads);
if let Some(ref memory_limit) = worker.memory_limit {
tracing::debug!("Dask worker memory limit: {}", memory_limit);
}
if let Some(ref worker_port) = worker.worker_port {
tracing::debug!("Dask worker port range: {}", worker_port);
}
if let Some(ref nanny_port) = worker.nanny_port {
tracing::debug!("Dask nanny port range: {}", nanny_port);
}
let death_timeout = worker.death_timeout.unwrap_or(60.0);
tracing::debug!("Dask worker death timeout: {} seconds", death_timeout);
if let Some(ref preload) = worker.preload {
tracing::debug!("Dask worker preload modules: {:?}", preload);
}
if let Some(ref env) = worker.env {
tracing::debug!("Dask worker environment variables: {:?}", env);
}
if let Some(ref resources) = worker.resources {
tracing::debug!("Dask worker resources: {:?}", resources);
}
}
// Seed the connected count from config (default 4).
self.stats.workers_connected = self
.config
.worker
.as_ref()
.and_then(|w| w.nworkers)
.unwrap_or(4);
Ok(())
}
/// Logs client-session settings (heartbeat, scheduler-info interval,
/// task metadata, default flag). Informational only; no client
/// connection is opened here.
fn setup_client(&self) -> TorshResult<()> {
if let Some(ref distributed) = self.config.distributed {
if let Some(ref client) = distributed.client {
tracing::info!("Setting up Dask client");
let heartbeat_interval = client.heartbeat_interval.unwrap_or(5.0);
tracing::debug!(
"Dask client heartbeat interval: {} seconds",
heartbeat_interval
);
let scheduler_info_interval = client.scheduler_info_interval.unwrap_or(2.0);
tracing::debug!(
"Dask client scheduler info interval: {} seconds",
scheduler_info_interval
);
if let Some(ref task_metadata) = client.task_metadata {
tracing::debug!("Dask client task metadata: {:?}", task_metadata);
}
let set_as_default = client.set_as_default.unwrap_or(true);
tracing::debug!("Dask client set as default: {}", set_as_default);
}
}
Ok(())
}
/// Logs the configured dask-ml settings section by section.
/// Informational only; no models are constructed here.
fn setup_ml(&self) -> TorshResult<()> {
if let Some(ref ml) = self.config.ml {
tracing::info!("Setting up Dask ML");
if let Some(ref model_selection) = ml.model_selection {
let cv_folds = model_selection.cv_folds.unwrap_or(5);
tracing::debug!("Dask ML cross-validation folds: {}", cv_folds);
if let Some(ref scoring) = model_selection.scoring {
tracing::debug!("Dask ML scoring metric: {}", scoring);
}
if let Some(search_method) = model_selection.search_method {
tracing::debug!("Dask ML search method: {:?}", search_method);
}
}
if let Some(ref preprocessing) = ml.preprocessing {
if let Some(ref standardization) = preprocessing.standardization {
tracing::debug!("Dask ML standardization: {}", standardization);
}
if let Some(ref encoding) = preprocessing.categorical_encoding {
tracing::debug!("Dask ML categorical encoding: {}", encoding);
}
}
if let Some(ref linear_model) = ml.linear_model {
if let Some(ref solver) = linear_model.solver {
tracing::debug!("Dask ML linear model solver: {}", solver);
}
let max_iter = linear_model.max_iter.unwrap_or(1000);
tracing::debug!("Dask ML linear model max iterations: {}", max_iter);
}
if let Some(ref ensemble) = ml.ensemble {
let n_estimators = ensemble.n_estimators.unwrap_or(100);
tracing::debug!("Dask ML ensemble estimators: {}", n_estimators);
let bootstrap = ensemble.bootstrap.unwrap_or(true);
tracing::debug!("Dask ML ensemble bootstrap: {}", bootstrap);
}
if let Some(ref cluster) = ml.cluster {
let n_clusters = cluster.n_clusters.unwrap_or(8);
tracing::debug!("Dask ML clustering clusters: {}", n_clusters);
let max_iter = cluster.max_iter.unwrap_or(300);
tracing::debug!("Dask ML clustering max iterations: {}", max_iter);
}
}
Ok(())
}
/// Logs distributed-runtime settings (comm, serialization, scheduling,
/// diagnostics) with defaults filled in. Informational only.
fn setup_distributed(&self) -> TorshResult<()> {
if let Some(ref distributed) = self.config.distributed {
tracing::info!("Setting up Dask distributed");
if let Some(ref comm) = distributed.comm {
if let Some(ref compression) = comm.compression {
tracing::debug!("Dask communication compression: {}", compression);
}
if let Some(ref serializers) = comm.serializers {
tracing::debug!("Dask communication serializers: {:?}", serializers);
}
if let Some(ref timeouts) = comm.timeouts {
let connect_timeout = timeouts.connect.unwrap_or(10.0);
tracing::debug!(
"Dask communication connect timeout: {} seconds",
connect_timeout
);
let tcp_timeout = timeouts.tcp.unwrap_or(30.0);
tracing::debug!("Dask communication TCP timeout: {} seconds", tcp_timeout);
}
if let Some(ref tcp) = comm.tcp {
let reuse_port = tcp.reuse_port.unwrap_or(false);
tracing::debug!("Dask TCP reuse port: {}", reuse_port);
let no_delay = tcp.no_delay.unwrap_or(true);
tracing::debug!("Dask TCP no delay: {}", no_delay);
let keep_alive = tcp.keep_alive.unwrap_or(false);
tracing::debug!("Dask TCP keep alive: {}", keep_alive);
}
}
if let Some(ref serialization) = distributed.serialization {
if let Some(ref compression) = serialization.compression {
tracing::debug!("Dask serialization compression: {:?}", compression);
}
let pickle_protocol = serialization.pickle_protocol.unwrap_or(4);
tracing::debug!("Dask serialization pickle protocol: {}", pickle_protocol);
}
if let Some(ref scheduling) = distributed.scheduling {
let work_stealing = scheduling.work_stealing.unwrap_or(true);
tracing::debug!("Dask scheduling work stealing: {}", work_stealing);
// The stealing interval is only relevant when stealing is on.
if work_stealing {
let interval = scheduling.work_stealing_interval.unwrap_or(0.1);
tracing::debug!("Dask work stealing interval: {} seconds", interval);
}
let unknown_task_duration = scheduling.unknown_task_duration.unwrap_or(0.5);
tracing::debug!(
"Dask unknown task duration: {} seconds",
unknown_task_duration
);
}
if let Some(ref diagnostics) = distributed.diagnostics {
let progress_bar = diagnostics.progress_bar.unwrap_or(true);
tracing::debug!("Dask diagnostics progress bar: {}", progress_bar);
let profile = diagnostics.profile.unwrap_or(false);
tracing::debug!("Dask diagnostics profile: {}", profile);
let memory_profiling = diagnostics.memory_profiling.unwrap_or(false);
tracing::debug!("Dask diagnostics memory profiling: {}", memory_profiling);
}
}
Ok(())
}
/// Borrow the configuration supplied at construction.
pub fn config(&self) -> &DaskConfig {
&self.config
}
/// Borrow the current runtime statistics.
pub fn stats(&self) -> &DaskStats {
&self.stats
}
/// True once `initialize` has succeeded (cleared by `shutdown`).
pub fn is_initialized(&self) -> bool {
self.initialized
}
/// Global rank set by `initialize` (0 before initialization).
pub fn rank(&self) -> u32 {
self.rank
}
/// Total process count set by `initialize` (1 before initialization).
pub fn world_size(&self) -> u32 {
self.world_size
}
/// Node-local rank set by `initialize`.
pub fn local_rank(&self) -> u32 {
self.local_rank
}
/// Node-local process count set by `initialize`.
pub fn local_size(&self) -> u32 {
self.local_size
}
/// True while the (simulated) client session is active.
pub fn is_client_active(&self) -> bool {
self.client_active
}
/// Records a (simulated) task submission and returns its generated id.
///
/// The id embeds the task name plus the pre-submission task counter,
/// so ids are sequential per integration instance. Stats (task count,
/// bytes transferred, total and average duration) are updated as a
/// side effect.
///
/// # Errors
/// Fails when `initialize` has not been called.
pub fn submit_task(&mut self, task_name: &str, task_size: usize) -> TorshResult<String> {
    if !self.initialized {
        return Err(TorshDistributedError::BackendNotInitialized);
    }
    let started = std::time::Instant::now();
    tracing::debug!("Submitting Dask task: {} ({} bytes)", task_name, task_size);
    // Id uses the counter value *before* this submission is counted.
    let task_id = format!("task_{}_{}", task_name, self.stats.tasks_executed);
    self.stats.tasks_executed += 1;
    self.stats.data_transferred_bytes += task_size as u64;
    let elapsed = started.elapsed().as_secs_f64();
    self.stats.task_execution_time_sec += elapsed;
    self.stats.average_task_duration_sec =
        self.stats.task_execution_time_sec / self.stats.tasks_executed as f64;
    tracing::debug!("Dask task submitted: {} (ID: {})", task_name, task_id);
    Ok(task_id)
}
/// Simulates computing a named Dask collection by submitting a fixed
/// batch of ten 1 KiB tasks through `submit_task`, then logs the total
/// elapsed time.
///
/// # Errors
/// Fails when `initialize` has not been called, or when any underlying
/// `submit_task` call fails.
pub fn compute(&mut self, collection_name: &str) -> TorshResult<()> {
    if !self.initialized {
        return Err(TorshDistributedError::BackendNotInitialized);
    }
    let started = std::time::Instant::now();
    tracing::info!("Computing Dask collection: {}", collection_name);
    for task_index in 0..10 {
        self.submit_task(&format!("{}_task_{}", collection_name, task_index), 1024)?;
    }
    tracing::info!(
        "Dask collection computed: {} in {:.2}s",
        collection_name,
        started.elapsed().as_secs_f64()
    );
    Ok(())
}
/// Scales the (simulated) cluster to `target_workers`, enforcing the
/// configured scaling bounds when present.
///
/// Fix: a target of zero workers is now rejected unconditionally.
/// Previously the bounds were only checked when `cluster.scaling`
/// existed, so `scale_cluster(0)` could silently record an empty
/// cluster on configs without a scaling section.
///
/// # Errors
/// - backend not initialized
/// - `target_workers` is 0 or outside the configured bounds
pub fn scale_cluster(&mut self, target_workers: u32) -> TorshResult<()> {
    if !self.initialized {
        return Err(TorshDistributedError::BackendNotInitialized);
    }
    tracing::info!("Scaling Dask cluster to {} workers", target_workers);
    // An empty cluster is never a valid scaling target.
    if target_workers == 0 {
        return Err(TorshDistributedError::invalid_argument(
            "target_workers",
            "Cannot scale to 0 workers".to_string(),
            "At least 1 worker".to_string(),
        ));
    }
    if let Some(ref cluster) = self.config.cluster {
        if let Some(ref scaling) = cluster.scaling {
            let minimum = scaling.minimum.unwrap_or(1);
            let maximum = scaling.maximum.unwrap_or(100);
            if target_workers < minimum {
                return Err(TorshDistributedError::invalid_argument(
                    "target_workers",
                    format!("Cannot scale below minimum: {}", minimum),
                    format!("At least {} workers", minimum),
                ));
            }
            if target_workers > maximum {
                return Err(TorshDistributedError::invalid_argument(
                    "target_workers",
                    format!("Cannot scale above maximum: {}", maximum),
                    format!("At most {} workers", maximum),
                ));
            }
        }
    }
    self.stats.workers_connected = target_workers;
    tracing::info!("Dask cluster scaled to {} workers", target_workers);
    Ok(())
}
/// Registers a failed worker and, when the pool drops below the
/// configured scaling minimum, auto-scales back up to that minimum.
///
/// # Errors
/// Propagates errors from `scale_cluster` when auto-scaling triggers.
pub fn handle_worker_failure(&mut self, worker_id: u32) -> TorshResult<()> {
    tracing::warn!("Dask worker {} failed", worker_id);
    self.stats.worker_failures += 1;
    // Saturating decrement: never underflow the connected count.
    self.stats.workers_connected = self.stats.workers_connected.saturating_sub(1);
    let floor = self
        .config
        .cluster
        .as_ref()
        .and_then(|c| c.scaling.as_ref())
        .map(|s| s.minimum.unwrap_or(1));
    if let Some(minimum) = floor {
        if self.stats.workers_connected < minimum {
            tracing::info!("Auto-scaling Dask cluster due to worker failure");
            self.scale_cluster(minimum)?;
        }
    }
    Ok(())
}
/// Tears down the (simulated) client session, clearing the active and
/// initialized flags and the worker count. Idempotent: calling this on
/// an inactive integration is a no-op.
pub fn shutdown(&mut self) -> TorshResult<()> {
    if !self.client_active {
        return Ok(());
    }
    tracing::info!("Shutting down Dask integration");
    self.client_active = false;
    self.initialized = false;
    self.stats.workers_connected = 0;
    Ok(())
}
/// A fully-populated local-cluster baseline configuration, also used by
/// `DaskConfig::default()` and as the starting point for the other
/// config builders. The `ml` section is left unset (see
/// `config_with_ml`).
pub fn default_config() -> DaskConfig {
DaskConfig {
// 4 local workers x 2 threads, scaling bounded to 1..=10.
cluster: Some(DaskClusterConfig {
cluster_type: DaskClusterType::Local,
n_workers: Some(4),
threads_per_worker: Some(2),
memory_limit: Some("4GB".to_string()),
processes: Some(false),
dashboard_address: Some("127.0.0.1:8787".to_string()),
silence_logs: Some(false),
security: None,
scaling: Some(DaskScalingConfig {
minimum: Some(1),
maximum: Some(10),
target_cpu: Some(0.8),
target_memory: Some(0.8),
scale_up_threshold: Some(0.8),
scale_down_threshold: Some(0.2),
interval: Some(30.0),
}),
}),
scheduler: Some(DaskSchedulerConfig {
address: None,
port: Some(8786),
dashboard_port: Some(8787),
bokeh_port: Some(8788),
worker_timeout: Some(60.0),
idle_timeout: Some(1800.0),
transition_log_length: Some(100000),
task_duration_overhead: Some(0.1),
allowed_failures: Some(3),
}),
worker: Some(DaskWorkerConfig {
nworkers: Some(4),
nthreads: Some(2),
memory_limit: Some("4GB".to_string()),
worker_port: Some("40000:40100".to_string()),
nanny_port: Some("40100:40200".to_string()),
dashboard_port: Some(8789),
death_timeout: Some(60.0),
preload: None,
env: None,
resources: None,
}),
array: Some(DaskArrayConfig {
chunk_size: Some("128MB".to_string()),
backend: Some("numpy".to_string()),
overlap: Some(0),
boundary: Some("reflect".to_string()),
trim: Some(true),
rechunk_threshold: Some(4.0),
}),
dataframe: Some(DaskDataFrameConfig {
partition_size: Some("128MB".to_string()),
shuffle_method: Some(DaskShuffleMethod::Tasks),
query_planning: Some(false),
backend: Some("pandas".to_string()),
optimize_index: Some(true),
}),
// 134217728 bytes = 128 MiB partitions.
bag: Some(DaskBagConfig {
partition_size: Some(134217728), compression: Some("gzip".to_string()),
encoding: Some("utf-8".to_string()),
linedelimiter: Some("\n".to_string()),
}),
ml: None,
distributed: Some(DaskDistributedConfig {
comm: Some(DaskCommConfig {
compression: Some("lz4".to_string()),
serializers: Some(vec!["pickle".to_string(), "msgpack".to_string()]),
timeouts: Some(DaskTimeoutsConfig {
connect: Some(10.0),
tcp: Some(30.0),
}),
tcp: Some(DaskTcpConfig {
reuse_port: Some(false),
no_delay: Some(true),
keep_alive: Some(false),
}),
}),
serialization: Some(DaskSerializationConfig {
compression: Some(vec!["lz4".to_string(), "zlib".to_string()]),
default_serializers: Some(vec!["pickle".to_string()]),
pickle_protocol: Some(4),
}),
client: Some(DaskClientConfig {
heartbeat_interval: Some(5.0),
scheduler_info_interval: Some(2.0),
task_metadata: Some(vec!["task_name".to_string(), "worker".to_string()]),
set_as_default: Some(true),
}),
scheduling: Some(DaskSchedulingConfig {
work_stealing: Some(true),
work_stealing_interval: Some(0.1),
unknown_task_duration: Some(0.5),
default_task_durations: None,
}),
diagnostics: Some(DaskDiagnosticsConfig {
progress_bar: Some(true),
profile: Some(false),
memory_profiling: Some(false),
task_stream: Some(false),
resource_monitor: Some(false),
}),
}),
}
}
/// `default_config` plus a fully-populated `ml` section with
/// conventional defaults (5-fold CV, random search, StandardScaler,
/// lbfgs linear model, 100-estimator ensemble, 8-cluster k-means).
pub fn config_with_ml() -> DaskConfig {
let mut config = Self::default_config();
config.ml = Some(DaskMLConfig {
model_selection: Some(DaskMLModelSelectionConfig {
cv_folds: Some(5),
scoring: Some("accuracy".to_string()),
n_jobs: Some(-1),
search_method: Some(DaskMLSearchMethod::RandomSearch),
}),
preprocessing: Some(DaskMLPreprocessingConfig {
standardization: Some("StandardScaler".to_string()),
categorical_encoding: Some("OneHotEncoder".to_string()),
feature_selection: Some("SelectKBest".to_string()),
dimensionality_reduction: Some("PCA".to_string()),
}),
linear_model: Some(DaskMLLinearModelConfig {
solver: Some("lbfgs".to_string()),
alpha: Some(1.0),
max_iter: Some(1000),
tol: Some(1e-4),
}),
ensemble: Some(DaskMLEnsembleConfig {
n_estimators: Some(100),
bootstrap: Some(true),
random_state: Some(42),
oob_score: Some(true),
}),
cluster: Some(DaskMLClusterConfig {
n_clusters: Some(8),
init: Some("k-means++".to_string()),
max_iter: Some(300),
tol: Some(1e-4),
n_init: Some(10),
}),
});
config
}
/// Produces a config tuned for a large multi-process cluster: starts
/// from `default_config`, then overrides worker count, per-worker
/// memory, process mode, scaling bounds, comm compression, and
/// work-stealing cadence.
///
/// Fixes: the scaling minimum is floored at 1 (previously
/// `n_workers / 2` yielded a degenerate floor of 0 for `n_workers <= 1`,
/// allowing the cluster to scale to zero workers), and the maximum uses
/// saturating arithmetic so huge worker counts cannot overflow.
pub fn config_with_large_scale(n_workers: u32, memory_per_worker: &str) -> DaskConfig {
    let mut config = Self::default_config();
    if let Some(ref mut cluster) = config.cluster {
        cluster.n_workers = Some(n_workers);
        cluster.memory_limit = Some(memory_per_worker.to_string());
        cluster.processes = Some(true);
        if let Some(ref mut scaling) = cluster.scaling {
            // Allow scaling down to half size, but never below one worker.
            scaling.minimum = Some((n_workers / 2).max(1));
            scaling.maximum = Some(n_workers.saturating_mul(2));
        }
    }
    if let Some(ref mut worker) = config.worker {
        worker.nworkers = Some(n_workers);
        worker.memory_limit = Some(memory_per_worker.to_string());
    }
    if let Some(ref mut distributed) = config.distributed {
        if let Some(ref mut comm) = distributed.comm {
            // Switch compression to zstd for large-scale runs.
            comm.compression = Some("zstd".to_string());
        }
        if let Some(ref mut scheduling) = distributed.scheduling {
            // Check for stealable work less often on large clusters.
            scheduling.work_stealing_interval = Some(0.5);
        }
    }
    config
}
}
/// `DaskConfig::default()` yields the fully-populated baseline from
/// `DaskIntegration::default_config()`, not an all-`None` struct.
impl Default for DaskConfig {
fn default() -> Self {
DaskIntegration::default_config()
}
}
#[cfg(test)]
mod tests {
use super::*;
// End-to-end init with the default config; checks topology getters
// and the seeded worker count.
#[test]
fn test_dask_config_validation() {
let config = DaskIntegration::default_config();
let mut integration = DaskIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
assert!(integration.is_initialized());
assert!(integration.is_client_active());
assert_eq!(integration.rank(), 0);
assert_eq!(integration.world_size(), 4);
assert_eq!(integration.local_rank(), 0);
assert_eq!(integration.stats().workers_connected, 4);
}
// submit_task: ids embed the task name; counters accumulate.
#[test]
fn test_dask_task_submission() {
let config = DaskIntegration::default_config();
let mut integration = DaskIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
let task_id1 = integration.submit_task("compute_gradient", 1024).unwrap();
let task_id2 = integration.submit_task("update_parameters", 2048).unwrap();
assert!(task_id1.contains("compute_gradient"));
assert!(task_id2.contains("update_parameters"));
let stats = integration.stats();
assert_eq!(stats.tasks_executed, 2);
assert!(stats.task_execution_time_sec > 0.0);
assert_eq!(stats.data_transferred_bytes, 3072);
assert!(stats.average_task_duration_sec > 0.0);
}
// compute submits a fixed batch of 10 tasks.
#[test]
fn test_dask_compute_collection() {
let config = DaskIntegration::default_config();
let mut integration = DaskIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
assert!(integration.compute("training_dataset").is_ok());
let stats = integration.stats();
assert_eq!(stats.tasks_executed, 10); assert!(stats.task_execution_time_sec > 0.0);
}
// scale_cluster honors the default scaling bounds (1..=10).
#[test]
fn test_dask_cluster_scaling() {
let config = DaskIntegration::default_config();
let mut integration = DaskIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
assert!(integration.scale_cluster(8).is_ok());
assert_eq!(integration.stats().workers_connected, 8);
assert!(integration.scale_cluster(2).is_ok());
assert_eq!(integration.stats().workers_connected, 2);
assert!(integration.scale_cluster(0).is_err());
assert!(integration.scale_cluster(20).is_err());
}
// Failures decrement the worker count; auto-rescaling keeps the pool
// at the configured minimum (1 for the default config).
#[test]
fn test_dask_worker_failure_handling() {
let config = DaskIntegration::default_config();
let mut integration = DaskIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
assert!(integration.handle_worker_failure(1).is_ok());
assert_eq!(integration.stats().worker_failures, 1);
assert_eq!(integration.stats().workers_connected, 3);
assert!(integration.handle_worker_failure(2).is_ok());
assert!(integration.handle_worker_failure(3).is_ok());
assert_eq!(integration.stats().worker_failures, 3);
assert_eq!(integration.stats().workers_connected, 1); }
// config_with_ml populates every ML section.
#[test]
fn test_dask_ml_config() {
let config = DaskIntegration::config_with_ml();
let mut integration = DaskIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
assert!(integration.config().ml.is_some());
if let Some(ref ml) = integration.config().ml {
assert!(ml.model_selection.is_some());
assert!(ml.preprocessing.is_some());
assert!(ml.linear_model.is_some());
assert!(ml.ensemble.is_some());
assert!(ml.cluster.is_some());
if let Some(ref model_selection) = ml.model_selection {
assert_eq!(model_selection.cv_folds, Some(5));
assert_eq!(
model_selection.search_method,
Some(DaskMLSearchMethod::RandomSearch)
);
}
}
}
// config_with_large_scale derives scaling bounds from n_workers.
#[test]
fn test_dask_large_scale_config() {
let config = DaskIntegration::config_with_large_scale(16, "8GB");
let mut integration = DaskIntegration::new(config);
assert!(integration.initialize(0, 16, 0, 4).is_ok());
if let Some(ref cluster) = integration.config().cluster {
assert_eq!(cluster.n_workers, Some(16));
assert_eq!(cluster.memory_limit, Some("8GB".to_string()));
assert_eq!(cluster.processes, Some(true));
if let Some(ref scaling) = cluster.scaling {
assert_eq!(scaling.minimum, Some(8));
assert_eq!(scaling.maximum, Some(32));
}
}
}
// shutdown clears the active/initialized flags and the worker count.
#[test]
fn test_dask_shutdown() {
let config = DaskIntegration::default_config();
let mut integration = DaskIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_ok());
assert!(integration.is_client_active());
assert!(integration.shutdown().is_ok());
assert!(!integration.is_client_active());
assert!(!integration.is_initialized());
assert_eq!(integration.stats().workers_connected, 0);
}
// Round-trips the ML config through serde_json.
#[test]
fn test_dask_config_serialization() {
let config = DaskIntegration::config_with_ml();
let json = serde_json::to_string(&config).unwrap();
assert!(json.contains("Local"));
assert!(json.contains("accuracy"));
assert!(json.contains("RandomSearch"));
let deserialized: DaskConfig = serde_json::from_str(&json).unwrap();
assert!(deserialized.cluster.is_some());
assert!(deserialized.ml.is_some());
if let Some(cluster) = deserialized.cluster {
assert_eq!(cluster.cluster_type, DaskClusterType::Local);
}
}
// Zero workers must be rejected by validation at initialize time.
#[test]
fn test_dask_invalid_config() {
let mut config = DaskIntegration::default_config();
if let Some(ref mut cluster) = config.cluster {
cluster.n_workers = Some(0); }
let mut integration = DaskIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_err());
}
// scaling minimum > maximum must be rejected by validation.
#[test]
fn test_dask_scaling_validation() {
let mut config = DaskIntegration::default_config();
if let Some(ref mut cluster) = config.cluster {
if let Some(ref mut scaling) = cluster.scaling {
scaling.minimum = Some(10);
scaling.maximum = Some(5); }
}
let mut integration = DaskIntegration::new(config);
assert!(integration.initialize(0, 4, 0, 2).is_err());
}
}
}