// scirs2_integrate/distributed/solver.rs

//! Distributed ODE solver implementation
//!
//! This module provides the main distributed ODE solver that coordinates
//! work distribution across compute nodes for large-scale integration problems.

use crate::common::IntegrateFloat;
use crate::distributed::checkpointing::{
    Checkpoint, CheckpointConfig, CheckpointGlobalState, CheckpointManager,
    FaultToleranceCoordinator, RecoveryAction,
};
use crate::distributed::communication::{BoundaryExchanger, Communicator, MessageChannel};
use crate::distributed::load_balancing::{ChunkDistributor, LoadBalancer, LoadBalancerConfig};
use crate::distributed::node::{ComputeNode, NodeManager};
use crate::distributed::types::{
    BoundaryData, ChunkId, ChunkResult, ChunkResultStatus, DistributedConfig, DistributedError,
    DistributedMetrics, DistributedResult, FaultToleranceMode, JobId, NodeId, NodeInfo, NodeStatus,
    WorkChunk,
};
use crate::error::{IntegrateError, IntegrateResult};
use crate::ode::types::{ODEMethod, ODEOptions};
use scirs2_core::ndarray::{array, Array1, ArrayView1};
use std::collections::{HashMap, HashSet, VecDeque};
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

28/// Distributed ODE solver
29pub struct DistributedODESolver<F: IntegrateFloat> {
30    /// Node manager
31    node_manager: Arc<NodeManager>,
32    /// Load balancer
33    load_balancer: Arc<LoadBalancer<F>>,
34    /// Checkpoint manager
35    checkpoint_manager: Arc<CheckpointManager<F>>,
36    /// Fault tolerance coordinator
37    fault_coordinator: Arc<FaultToleranceCoordinator<F>>,
38    /// Message channels per node
39    channels: RwLock<HashMap<NodeId, Arc<MessageChannel<F>>>>,
40    /// Boundary exchanger
41    boundary_exchanger: Arc<BoundaryExchanger<F>>,
42    /// Configuration
43    config: DistributedConfig<F>,
44    /// Next job ID
45    next_job_id: AtomicU64,
46    /// Shutdown flag
47    shutdown: AtomicBool,
48    /// Active jobs
49    active_jobs: RwLock<HashMap<JobId, JobState<F>>>,
50    /// Metrics
51    metrics: Mutex<DistributedMetrics>,
52}
53
54/// State of an active job
55struct JobState<F: IntegrateFloat> {
56    /// Job ID
57    job_id: JobId,
58    /// Time span
59    t_span: (F, F),
60    /// Initial state
61    initial_state: Array1<F>,
62    /// Total chunks
63    total_chunks: usize,
64    /// Completed chunks
65    completed_chunks: Vec<ChunkResult<F>>,
66    /// Pending chunks
67    pending_chunks: Vec<ChunkId>,
68    /// In-progress chunks
69    in_progress_chunks: HashMap<ChunkId, NodeId>,
70    /// Chunk ordering for assembly
71    chunk_order: Vec<ChunkId>,
72    /// Start time
73    start_time: Instant,
74    /// Last checkpoint time
75    last_checkpoint: Option<Instant>,
76    /// Chunks since last checkpoint
77    chunks_since_checkpoint: usize,
78}
79
80impl<F: IntegrateFloat> DistributedODESolver<F> {
81    /// Create a new distributed ODE solver
82    pub fn new(config: DistributedConfig<F>) -> DistributedResult<Self> {
83        let node_manager = Arc::new(NodeManager::new(config.heartbeat_interval));
84
85        let load_balancer = Arc::new(LoadBalancer::new(
86            config.load_balancing,
87            LoadBalancerConfig::default(),
88        ));
89
90        let checkpoint_path = {
91            let mut p = std::env::temp_dir();
92            p.push("scirs_checkpoints");
93            p
94        };
95        let checkpoint_config = CheckpointConfig {
96            persist_to_disk: config.checkpointing_enabled,
97            interval_chunks: config.checkpoint_interval,
98            ..Default::default()
99        };
100
101        let checkpoint_manager =
102            Arc::new(CheckpointManager::new(checkpoint_path, checkpoint_config)?);
103
104        let fault_coordinator = Arc::new(FaultToleranceCoordinator::new(
105            Arc::clone(&checkpoint_manager),
106            config.fault_tolerance,
107        ));
108
109        let boundary_exchanger = Arc::new(BoundaryExchanger::new(config.communication_timeout));
110
111        Ok(Self {
112            node_manager,
113            load_balancer,
114            checkpoint_manager,
115            fault_coordinator,
116            channels: RwLock::new(HashMap::new()),
117            boundary_exchanger,
118            config,
119            next_job_id: AtomicU64::new(1),
120            shutdown: AtomicBool::new(false),
121            active_jobs: RwLock::new(HashMap::new()),
122            metrics: Mutex::new(DistributedMetrics::default()),
123        })
124    }
125
126    /// Register a compute node
127    pub fn register_node(&self, node: NodeInfo) -> DistributedResult<()> {
128        let node_id = node.id;
129
130        // Register with node manager
131        self.node_manager
132            .register_node(node.address, node.capabilities.clone())?;
133
134        // Register with load balancer
135        self.load_balancer.register_node(node_id)?;
136
137        // Create message channel
138        let channel = Arc::new(MessageChannel::new(self.config.communication_timeout));
139        if let Ok(mut channels) = self.channels.write() {
140            channels.insert(node_id, channel);
141        }
142
143        Ok(())
144    }
145
146    /// Deregister a compute node
147    pub fn deregister_node(&self, node_id: NodeId) -> DistributedResult<()> {
148        self.node_manager.deregister_node(node_id)?;
149        self.load_balancer.deregister_node(node_id)?;
150
151        if let Ok(mut channels) = self.channels.write() {
152            channels.remove(&node_id);
153        }
154
155        Ok(())
156    }
157
158    /// Solve an ODE problem distributedly
159    pub fn solve<Func>(
160        &self,
161        f: Func,
162        t_span: (F, F),
163        y0: Array1<F>,
164        options: Option<ODEOptions<F>>,
165    ) -> IntegrateResult<DistributedODEResult<F>>
166    where
167        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
168    {
169        let start_time = Instant::now();
170
171        // Get available nodes
172        let available_nodes = self.node_manager.get_available_nodes();
173        if available_nodes.is_empty() {
174            return Err(IntegrateError::ComputationError(
175                "No compute nodes available".to_string(),
176            ));
177        }
178
179        // Create job
180        let job_id = JobId::new(self.next_job_id.fetch_add(1, Ordering::SeqCst));
181
182        // Calculate number of chunks based on nodes
183        let num_chunks = (available_nodes.len() * self.config.chunks_per_node).max(1);
184
185        // Create chunk distributor and generate chunks
186        let distributor = ChunkDistributor::new(job_id);
187        let chunks = distributor.create_chunks(t_span, y0.clone(), num_chunks);
188
189        // Initialize job state
190        let chunk_order: Vec<ChunkId> = chunks.iter().map(|c| c.id).collect();
191        let pending_chunks = chunk_order.clone();
192
193        let job_state = JobState {
194            job_id,
195            t_span,
196            initial_state: y0.clone(),
197            total_chunks: num_chunks,
198            completed_chunks: Vec::new(),
199            pending_chunks,
200            in_progress_chunks: HashMap::new(),
201            chunk_order,
202            start_time,
203            last_checkpoint: None,
204            chunks_since_checkpoint: 0,
205        };
206
207        // Register job
208        if let Ok(mut jobs) = self.active_jobs.write() {
209            jobs.insert(job_id, job_state);
210        }
211
212        // Distribute initial work
213        self.distribute_chunks(job_id, chunks, &available_nodes, &f)?;
214
215        // Wait for completion
216        let result = self.wait_for_completion(job_id, &f)?;
217
218        // Update metrics
219        if let Ok(mut metrics) = self.metrics.lock() {
220            metrics.total_processing_time += start_time.elapsed();
221        }
222
223        // Cleanup job
224        if let Ok(mut jobs) = self.active_jobs.write() {
225            jobs.remove(&job_id);
226        }
227
228        Ok(result)
229    }
230
231    /// Distribute chunks to nodes
232    fn distribute_chunks<Func>(
233        &self,
234        job_id: JobId,
235        chunks: Vec<WorkChunk<F>>,
236        nodes: &[NodeInfo],
237        f: &Func,
238    ) -> IntegrateResult<()>
239    where
240        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
241    {
242        for chunk in chunks {
243            let node_id = self
244                .load_balancer
245                .assign_chunk(&chunk, nodes)
246                .map_err(|e| IntegrateError::ComputationError(e.to_string()))?;
247
248            // Record assignment
249            if let Ok(mut jobs) = self.active_jobs.write() {
250                if let Some(job) = jobs.get_mut(&job_id) {
251                    job.pending_chunks.retain(|id| *id != chunk.id);
252                    job.in_progress_chunks.insert(chunk.id, node_id);
253                }
254            }
255
256            // In a real implementation, this would send the chunk over the network
257            // For now, we simulate local processing
258        }
259
260        Ok(())
261    }
262
263    /// Wait for job completion
264    fn wait_for_completion<Func>(
265        &self,
266        job_id: JobId,
267        f: &Func,
268    ) -> IntegrateResult<DistributedODEResult<F>>
269    where
270        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
271    {
272        let timeout = Duration::from_secs(3600); // 1 hour timeout
273        let deadline = Instant::now() + timeout;
274
275        loop {
276            if Instant::now() > deadline {
277                return Err(IntegrateError::ConvergenceError(
278                    "Distributed solve timeout".to_string(),
279                ));
280            }
281
282            // Check completion
283            let (is_complete, needs_processing) = {
284                let jobs = self.active_jobs.read().map_err(|_| {
285                    IntegrateError::ComputationError("Failed to read job state".to_string())
286                })?;
287
288                if let Some(job) = jobs.get(&job_id) {
289                    let complete =
290                        job.pending_chunks.is_empty() && job.in_progress_chunks.is_empty();
291                    let needs = !job.in_progress_chunks.is_empty();
292                    (complete, needs)
293                } else {
294                    return Err(IntegrateError::ComputationError(
295                        "Job not found".to_string(),
296                    ));
297                }
298            };
299
300            if is_complete {
301                break;
302            }
303
304            if needs_processing {
305                // Simulate processing chunks
306                self.process_pending_chunks(job_id, f)?;
307            }
308
309            std::thread::sleep(Duration::from_millis(10));
310        }
311
312        // Assemble result
313        self.assemble_result(job_id)
314    }
315
316    /// Process pending chunks (simulation for local testing)
317    ///
318    /// Chunks are processed sequentially in chunk_order so that each chunk
319    /// can use the final state from the previous chunk as its initial state.
320    fn process_pending_chunks<Func>(&self, job_id: JobId, f: &Func) -> IntegrateResult<()>
321    where
322        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
323    {
324        // Get ordered list of in-progress chunk IDs and their assigned nodes
325        let ordered_chunks: Vec<(ChunkId, NodeId, usize)> = {
326            let jobs = self.active_jobs.read().map_err(|_| {
327                IntegrateError::ComputationError("Failed to read job state".to_string())
328            })?;
329
330            if let Some(job) = jobs.get(&job_id) {
331                let mut items: Vec<(ChunkId, NodeId, usize)> = job
332                    .in_progress_chunks
333                    .iter()
334                    .map(|(chunk_id, node_id)| {
335                        let idx = job
336                            .chunk_order
337                            .iter()
338                            .position(|id| id == chunk_id)
339                            .unwrap_or(0);
340                        (*chunk_id, *node_id, idx)
341                    })
342                    .collect();
343                // Sort by chunk order index so we process them sequentially
344                items.sort_by_key(|&(_, _, idx)| idx);
345                items
346            } else {
347                Vec::new()
348            }
349        };
350
351        // Process each chunk one at a time, in order
352        for (chunk_id, node_id, idx) in ordered_chunks {
353            // Build the work chunk with correct initial state from completed chunks
354            let chunk = {
355                let jobs = self.active_jobs.read().map_err(|_| {
356                    IntegrateError::ComputationError("Failed to read job state".to_string())
357                })?;
358                let job = jobs
359                    .get(&job_id)
360                    .ok_or_else(|| IntegrateError::ComputationError("Job not found".to_string()))?;
361
362                let (t_start, t_end) = job.t_span;
363                let dt = (t_end - t_start) / F::from(job.total_chunks).unwrap_or(F::one());
364
365                let chunk_t_start = t_start + dt * F::from(idx).unwrap_or(F::zero());
366                let chunk_t_end = if idx == job.total_chunks - 1 {
367                    t_end
368                } else {
369                    t_start + dt * F::from(idx + 1).unwrap_or(F::one())
370                };
371
372                // Get initial state from previous chunk result or job initial state
373                let initial_state = if idx == 0 {
374                    job.initial_state.clone()
375                } else {
376                    let prev_chunk_id = job.chunk_order.get(idx - 1).ok_or_else(|| {
377                        IntegrateError::ComputationError(
378                            "Previous chunk not found in order".to_string(),
379                        )
380                    })?;
381                    job.completed_chunks
382                        .iter()
383                        .find(|r| r.chunk_id == *prev_chunk_id)
384                        .map(|r| r.final_state.clone())
385                        .unwrap_or_else(|| job.initial_state.clone())
386                };
387
388                WorkChunk::new(
389                    chunk_id,
390                    job_id,
391                    (chunk_t_start, chunk_t_end),
392                    initial_state,
393                )
394            };
395
396            let result = self.process_single_chunk(&chunk, node_id, f)?;
397
398            // Update job state
399            if let Ok(mut jobs) = self.active_jobs.write() {
400                if let Some(job) = jobs.get_mut(&job_id) {
401                    job.in_progress_chunks.remove(&chunk_id);
402                    job.completed_chunks.push(result);
403                    job.chunks_since_checkpoint += 1;
404
405                    // Check if checkpoint is needed
406                    if self.config.checkpointing_enabled
407                        && self
408                            .checkpoint_manager
409                            .should_checkpoint(job.chunks_since_checkpoint)
410                    {
411                        let global_state = CheckpointGlobalState {
412                            iteration: 0,
413                            chunks_completed: job.completed_chunks.len(),
414                            chunks_remaining: job.pending_chunks.len()
415                                + job.in_progress_chunks.len(),
416                            current_time: F::zero(),
417                            error_estimate: F::zero(),
418                        };
419
420                        let _ = self.checkpoint_manager.create_checkpoint(
421                            job_id,
422                            job.completed_chunks.clone(),
423                            job.in_progress_chunks.keys().cloned().collect(),
424                            global_state,
425                        );
426
427                        job.chunks_since_checkpoint = 0;
428                        job.last_checkpoint = Some(Instant::now());
429                    }
430                }
431            }
432
433            // Update load balancer
434            let processing_time = Duration::from_millis(10); // Simulated
435            self.load_balancer.report_completion(
436                node_id,
437                chunk.estimated_cost,
438                processing_time,
439                true,
440            );
441        }
442
443        Ok(())
444    }
445
446    /// Process a single chunk using local ODE solver
447    fn process_single_chunk<Func>(
448        &self,
449        chunk: &WorkChunk<F>,
450        node_id: NodeId,
451        f: &Func,
452    ) -> IntegrateResult<ChunkResult<F>>
453    where
454        Func: Fn(F, ArrayView1<F>) -> Array1<F> + Send + Sync + Clone + 'static,
455    {
456        let start_time = Instant::now();
457
458        // Use RK4 for simplicity
459        let (t_start, t_end) = chunk.time_interval;
460        let mut t = t_start;
461        let mut y = chunk.initial_state.clone();
462
463        let n_steps = 100;
464        let h = (t_end - t_start) / F::from(n_steps).unwrap_or(F::one());
465
466        let mut time_points = vec![t_start];
467        let mut states = vec![y.clone()];
468
469        for _ in 0..n_steps {
470            // RK4 step
471            let k1 = f(t, y.view());
472            let k2 = f(
473                t + h / F::from(2.0).unwrap_or(F::one()),
474                (&y + &(&k1 * h / F::from(2.0).unwrap_or(F::one()))).view(),
475            );
476            let k3 = f(
477                t + h / F::from(2.0).unwrap_or(F::one()),
478                (&y + &(&k2 * h / F::from(2.0).unwrap_or(F::one()))).view(),
479            );
480            let k4 = f(t + h, (&y + &(&k3 * h)).view());
481
482            y = &y
483                + &((&k1
484                    + &(&k2 * F::from(2.0).unwrap_or(F::one()))
485                    + &(&k3 * F::from(2.0).unwrap_or(F::one()))
486                    + &k4)
487                    * h
488                    / F::from(6.0).unwrap_or(F::one()));
489            t += h;
490
491            time_points.push(t);
492            states.push(y.clone());
493        }
494
495        let final_state = y.clone();
496        let final_derivative = Some(f(t, y.view()));
497
498        Ok(ChunkResult {
499            chunk_id: chunk.id,
500            node_id,
501            time_points,
502            states,
503            final_state,
504            final_derivative,
505            error_estimate: F::from(1e-6).unwrap_or(F::epsilon()),
506            processing_time: start_time.elapsed(),
507            memory_used: 0,
508            status: ChunkResultStatus::Success,
509        })
510    }
511
512    /// Assemble final result from completed chunks
513    fn assemble_result(&self, job_id: JobId) -> IntegrateResult<DistributedODEResult<F>> {
514        let jobs = self.active_jobs.read().map_err(|_| {
515            IntegrateError::ComputationError("Failed to read job state".to_string())
516        })?;
517
518        let job = jobs
519            .get(&job_id)
520            .ok_or_else(|| IntegrateError::ComputationError("Job not found".to_string()))?;
521
522        // Sort results by chunk order
523        let mut sorted_results: Vec<_> = job.completed_chunks.clone();
524        sorted_results.sort_by_key(|r| {
525            job.chunk_order
526                .iter()
527                .position(|id| *id == r.chunk_id)
528                .unwrap_or(usize::MAX)
529        });
530
531        // Concatenate time points and states
532        let mut t_all = Vec::new();
533        let mut y_all = Vec::new();
534
535        for (i, result) in sorted_results.iter().enumerate() {
536            let skip_first = if i > 0 { 1 } else { 0 };
537            t_all.extend(result.time_points.iter().skip(skip_first).cloned());
538            y_all.extend(result.states.iter().skip(skip_first).cloned());
539        }
540
541        let total_time = job.start_time.elapsed();
542
543        // Get metrics
544        let metrics = self.metrics.lock().map(|m| m.clone()).unwrap_or_default();
545
546        Ok(DistributedODEResult {
547            t: t_all,
548            y: y_all,
549            job_id,
550            chunks_processed: job.completed_chunks.len(),
551            nodes_used: job
552                .completed_chunks
553                .iter()
554                .map(|r| r.node_id)
555                .collect::<std::collections::HashSet<_>>()
556                .len(),
557            total_time,
558            metrics,
559        })
560    }
561
562    /// Shutdown the solver
563    pub fn shutdown(&self) {
564        self.shutdown.store(true, Ordering::Relaxed);
565        self.node_manager.stop_health_monitoring();
566    }
567
568    /// Get solver metrics
569    pub fn get_metrics(&self) -> DistributedMetrics {
570        self.metrics.lock().map(|m| m.clone()).unwrap_or_default()
571    }
572}
573
574/// Result of a distributed ODE solve
575#[derive(Debug, Clone)]
576pub struct DistributedODEResult<F: IntegrateFloat> {
577    /// Time points
578    pub t: Vec<F>,
579    /// Solution states
580    pub y: Vec<Array1<F>>,
581    /// Job ID
582    pub job_id: JobId,
583    /// Number of chunks processed
584    pub chunks_processed: usize,
585    /// Number of nodes used
586    pub nodes_used: usize,
587    /// Total time taken
588    pub total_time: Duration,
589    /// Distributed metrics
590    pub metrics: DistributedMetrics,
591}
592
593impl<F: IntegrateFloat> DistributedODEResult<F> {
594    /// Get final state
595    pub fn final_state(&self) -> Option<&Array1<F>> {
596        self.y.last()
597    }
598
599    /// Get state at specific index
600    pub fn state_at(&self, index: usize) -> Option<&Array1<F>> {
601        self.y.get(index)
602    }
603
604    /// Get number of time points
605    pub fn len(&self) -> usize {
606        self.t.len()
607    }
608
609    /// Check if result is empty
610    pub fn is_empty(&self) -> bool {
611        self.t.is_empty()
612    }
613
614    /// Interpolate solution at a given time
615    pub fn interpolate(&self, t_target: F) -> Option<Array1<F>> {
616        if self.t.is_empty() {
617            return None;
618        }
619
620        // Find bracketing points
621        let mut left_idx = 0;
622        for (i, &t) in self.t.iter().enumerate() {
623            if t <= t_target {
624                left_idx = i;
625            } else {
626                break;
627            }
628        }
629
630        let right_idx = (left_idx + 1).min(self.t.len() - 1);
631
632        if left_idx == right_idx {
633            return self.y.get(left_idx).cloned();
634        }
635
636        // Linear interpolation
637        let t_left = self.t[left_idx];
638        let t_right = self.t[right_idx];
639        let dt = t_right - t_left;
640
641        if dt.abs() < F::epsilon() {
642            return self.y.get(left_idx).cloned();
643        }
644
645        let alpha = (t_target - t_left) / dt;
646        let y_left = &self.y[left_idx];
647        let y_right = &self.y[right_idx];
648
649        Some(y_left * (F::one() - alpha) + y_right * alpha)
650    }
651}
652
653/// Builder for distributed ODE solver
654pub struct DistributedODESolverBuilder<F: IntegrateFloat> {
655    config: DistributedConfig<F>,
656}
657
658impl<F: IntegrateFloat> DistributedODESolverBuilder<F> {
659    /// Create a new builder with default configuration
660    pub fn new() -> Self {
661        Self {
662            config: DistributedConfig::default(),
663        }
664    }
665
666    /// Set tolerance
667    pub fn tolerance(mut self, tol: F) -> Self {
668        self.config.tolerance = tol;
669        self
670    }
671
672    /// Set chunks per node
673    pub fn chunks_per_node(mut self, n: usize) -> Self {
674        self.config.chunks_per_node = n;
675        self
676    }
677
678    /// Enable checkpointing
679    pub fn with_checkpointing(mut self, interval: usize) -> Self {
680        self.config.checkpointing_enabled = true;
681        self.config.checkpoint_interval = interval;
682        self
683    }
684
685    /// Set fault tolerance mode
686    pub fn fault_tolerance(mut self, mode: FaultToleranceMode) -> Self {
687        self.config.fault_tolerance = mode;
688        self
689    }
690
691    /// Set communication timeout
692    pub fn timeout(mut self, timeout: Duration) -> Self {
693        self.config.communication_timeout = timeout;
694        self
695    }
696
697    /// Build the solver
698    pub fn build(self) -> DistributedResult<DistributedODESolver<F>> {
699        DistributedODESolver::new(self.config)
700    }
701}
702
703impl<F: IntegrateFloat> Default for DistributedODESolverBuilder<F> {
704    fn default() -> Self {
705        Self::new()
706    }
707}
708
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::types::NodeCapabilities;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Build an available loopback node whose port is derived from `id`.
    fn create_test_node(id: u64) -> NodeInfo {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080 + id as u16);
        let mut info = NodeInfo::new(NodeId::new(id), addr);
        info.capabilities = NodeCapabilities::default();
        info.status = NodeStatus::Available;
        info
    }

    #[test]
    fn test_distributed_solver_creation() {
        let config = DistributedConfig::<f64>::default();
        let solver = DistributedODESolver::new(config);
        assert!(solver.is_ok());
    }

    #[test]
    fn test_distributed_solver_node_registration() {
        let config = DistributedConfig::<f64>::default();
        let solver = DistributedODESolver::new(config).expect("Failed to create solver");

        let node = create_test_node(1);
        let result = solver.register_node(node);
        assert!(result.is_ok());
    }

    #[test]
    fn test_distributed_solve_simple_ode() {
        let config = DistributedConfig::<f64>::default();
        let solver = DistributedODESolver::new(config).expect("Failed to create solver");

        // Register test nodes
        for i in 0..2 {
            let node = create_test_node(i);
            solver.register_node(node).expect("Failed to register node");
        }

        // Solve y' = -y, y(0) = 1
        let f = |_t: f64, y: ArrayView1<f64>| array![-y[0]];
        let y0 = array![1.0];

        let result = solver.solve(f, (0.0, 1.0), y0, None);
        assert!(result.is_ok());

        let result = result.expect("Solve failed");
        assert!(!result.t.is_empty());
        assert!(!result.y.is_empty());

        // Final value should be close to e^(-1)
        let expected = (-1.0_f64).exp();
        let actual = result.final_state().expect("No final state")[0];
        assert!((actual - expected).abs() < 0.01);
    }

    #[test]
    fn test_distributed_result_interpolation() {
        let result = DistributedODEResult::<f64> {
            t: vec![0.0, 0.5, 1.0],
            y: vec![array![1.0], array![0.6], array![0.4]],
            job_id: JobId::new(1),
            chunks_processed: 1,
            nodes_used: 1,
            total_time: Duration::from_secs(1),
            metrics: DistributedMetrics::default(),
        };

        // Midway between y(0) = 1.0 and y(0.5) = 0.6
        let interpolated = result.interpolate(0.25).expect("Interpolation failed");
        assert!((interpolated[0] - 0.8_f64).abs() < 0.01_f64);
    }

    #[test]
    fn test_solver_builder() {
        let solver = DistributedODESolverBuilder::<f64>::new()
            .tolerance(1e-8)
            .chunks_per_node(8)
            .with_checkpointing(5)
            .fault_tolerance(FaultToleranceMode::Standard)
            .timeout(Duration::from_secs(60))
            .build();

        assert!(solver.is_ok());
    }
}