// oxiz_spacer/distributed.rs
//! Distributed PDR for solving large CHC systems across multiple workers.
//!
//! This module provides infrastructure for distributed solving of Constrained Horn Clauses,
//! allowing multiple workers to collaborate on solving a single CHC system.
//!
//! ## Architecture
//!
//! - **Coordinator**: Manages work distribution and result aggregation
//! - **Workers**: Process proof obligations and learn lemmas independently
//! - **Shared State**: Frame lemmas are synchronized across workers
//! - **Communication**: Message passing for work items and learned lemmas
//!
//! Reference: Distributed PDR/IC3 algorithms from the literature

15use crate::chc::{ChcSystem, PredId};
16use crate::frames::{FrameManager, LemmaId};
17use crate::pdr::{SpacerConfig, SpacerError, SpacerResult, SpacerStats};
18use crate::pob::{Pob, PobId};
19use oxiz_core::{TermId, TermManager};
20use std::collections::{HashMap, VecDeque};
21use std::sync::{Arc, Mutex};
22use std::time::{Duration, Instant};
23use thiserror::Error;
24
/// Errors that can occur in distributed solving.
#[derive(Error, Debug)]
pub enum DistributedError {
    /// A worker thread failed; carries the worker id and a description.
    #[error("worker {0} error: {1}")]
    WorkerError(usize, String),
    /// Message passing between coordinator and workers failed.
    #[error("communication error: {0}")]
    Communication(String),
    /// Work distribution or result aggregation failed.
    #[error("coordination error: {0}")]
    Coordination(String),
    /// Error propagated from the underlying Spacer solver.
    #[error("spacer error: {0}")]
    Spacer(#[from] SpacerError),
    /// Solving exceeded the configured time budget.
    #[error("timeout after {0:?}")]
    Timeout(Duration),
}
44
/// Configuration for distributed solving.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Number of worker threads to spawn.
    pub num_workers: usize,
    /// Base Spacer configuration cloned into each worker.
    pub worker_config: SpacerConfig,
    /// Coordinator synchronization interval, in milliseconds.
    pub sync_interval_ms: u64,
    /// Optional wall-clock timeout for the whole distributed solve.
    pub timeout: Option<Duration>,
    /// Enable work stealing between workers.
    pub enable_work_stealing: bool,
}
59
60impl Default for DistributedConfig {
61    fn default() -> Self {
62        Self {
63            num_workers: num_cpus::get(),
64            worker_config: SpacerConfig::default(),
65            sync_interval_ms: 100,
66            timeout: None,
67            enable_work_stealing: true,
68        }
69    }
70}
71
/// Message types exchanged between the coordinator and workers over the
/// single shared message queue.
#[derive(Debug, Clone)]
pub enum WorkerMessage {
    /// Work item (POB) for a worker to process.
    Work(WorkItem),
    /// A worker learned a lemma that should be synchronized to all workers.
    LemmaLearned {
        worker_id: usize,
        pred: PredId,
        lemma: TermId,
        level: u32,
    },
    /// A new frame at `level` was created.
    FrameCreated { level: u32 },
    /// Outcome of processing a single POB.
    WorkResult {
        worker_id: usize,
        pob_id: PobId,
        blocked: bool,
        lemma: Option<LemmaId>,
    },
    /// A worker found a counterexample (the system is unsafe).
    Counterexample { worker_id: usize },
    /// A worker detected a fixpoint (inductive invariant) at `level`.
    Invariant { worker_id: usize, level: u32 },
    /// A starved worker requests work (used for work stealing).
    RequestWork { worker_id: usize },
    /// Tell workers to exit their processing loops.
    Shutdown,
}
102
/// Work item for distributed processing: one proof obligation plus its
/// scheduling priority.
#[derive(Debug, Clone)]
pub struct WorkItem {
    /// POB identifier.
    pub pob_id: PobId,
    /// The proof obligation to process.
    pub pob: Pob,
    /// Scheduling priority (higher = more urgent; ties are FIFO).
    pub priority: i32,
}
113
/// State shared between the coordinator and all workers.
///
/// Every field is independently mutex-protected; accessors hold a lock only
/// for the duration of one call.
pub struct SharedState {
    /// Frame manager (lemmas synchronized across workers).
    frames: Mutex<FrameManager>,
    /// Pending work items, kept sorted by descending priority.
    work_queue: Mutex<VecDeque<WorkItem>>,
    /// Final result; set once by whichever party terminates solving.
    result: Mutex<Option<SpacerResult>>,
    /// Combined statistics aggregated across all workers.
    stats: Mutex<DistributedStats>,
    /// Single shared FIFO used by both coordinator and workers.
    messages: Mutex<VecDeque<WorkerMessage>>,
}
127
128impl SharedState {
129    /// Create new shared state
130    pub fn new() -> Self {
131        Self {
132            frames: Mutex::new(FrameManager::new()),
133            work_queue: Mutex::new(VecDeque::new()),
134            result: Mutex::new(None),
135            stats: Mutex::new(DistributedStats::default()),
136            messages: Mutex::new(VecDeque::new()),
137        }
138    }
139
140    /// Add work item to queue
141    pub fn enqueue_work(&self, item: WorkItem) {
142        let mut queue = self.work_queue.lock().expect("lock should not be poisoned");
143        // Insert based on priority (higher priority first)
144        let pos = queue
145            .iter()
146            .position(|w| w.priority < item.priority)
147            .unwrap_or(queue.len());
148        queue.insert(pos, item);
149    }
150
151    /// Dequeue work item
152    pub fn dequeue_work(&self) -> Option<WorkItem> {
153        self.work_queue
154            .lock()
155            .expect("lock should not be poisoned")
156            .pop_front()
157    }
158
159    /// Get number of pending work items
160    pub fn work_queue_size(&self) -> usize {
161        self.work_queue
162            .lock()
163            .expect("lock should not be poisoned")
164            .len()
165    }
166
167    /// Send message to workers
168    pub fn send_message(&self, msg: WorkerMessage) {
169        self.messages
170            .lock()
171            .expect("lock should not be poisoned")
172            .push_back(msg);
173    }
174
175    /// Receive message
176    pub fn receive_message(&self) -> Option<WorkerMessage> {
177        self.messages
178            .lock()
179            .expect("lock should not be poisoned")
180            .pop_front()
181    }
182
183    /// Set result
184    pub fn set_result(&self, result: SpacerResult) {
185        *self.result.lock().expect("lock should not be poisoned") = Some(result);
186    }
187
188    /// Get result
189    pub fn get_result(&self) -> Option<SpacerResult> {
190        self.result
191            .lock()
192            .expect("lock should not be poisoned")
193            .clone()
194    }
195
196    /// Add lemma to frames
197    pub fn add_lemma(&self, pred: PredId, formula: TermId, level: u32) -> LemmaId {
198        self.frames
199            .lock()
200            .expect("lock should not be poisoned")
201            .add_lemma(pred, formula, level)
202    }
203
204    /// Get frame manager (locked)
205    pub fn with_frames<F, R>(&self, f: F) -> R
206    where
207        F: FnOnce(&mut FrameManager) -> R,
208    {
209        let mut frames = self.frames.lock().expect("lock should not be poisoned");
210        f(&mut frames)
211    }
212
213    /// Update statistics
214    pub fn update_stats<F>(&self, f: F)
215    where
216        F: FnOnce(&mut DistributedStats),
217    {
218        let mut stats = self.stats.lock().expect("lock should not be poisoned");
219        f(&mut stats);
220    }
221
222    /// Get statistics
223    pub fn get_stats(&self) -> DistributedStats {
224        self.stats
225            .lock()
226            .expect("lock should not be poisoned")
227            .clone()
228    }
229}
230
231impl Default for SharedState {
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
/// Statistics for distributed solving.
#[derive(Debug, Clone, Default)]
pub struct DistributedStats {
    /// Per-worker statistics, keyed by worker id.
    pub worker_stats: HashMap<usize, SpacerStats>,
    /// Total work items processed across all workers.
    pub total_work_items: u64,
    /// Total lemmas learned across all workers.
    pub total_lemmas: u64,
    /// Number of work stealing events.
    pub work_stealing_events: u64,
    /// Number of synchronization events.
    pub sync_events: u64,
    /// Communication overhead (number of messages sent).
    pub messages_sent: u64,
}
253
254impl DistributedStats {
255    /// Create new distributed statistics
256    pub fn new() -> Self {
257        Self::default()
258    }
259
260    /// Aggregate statistics from all workers
261    pub fn aggregate(&self) -> SpacerStats {
262        let mut total = SpacerStats::default();
263        for stats in self.worker_stats.values() {
264            total.num_frames = total.num_frames.max(stats.num_frames);
265            total.num_lemmas = total.num_lemmas.saturating_add(stats.num_lemmas);
266            total.num_inductive = total.num_inductive.saturating_add(stats.num_inductive);
267            total.num_pobs = total.num_pobs.saturating_add(stats.num_pobs);
268            total.num_blocked = total.num_blocked.saturating_add(stats.num_blocked);
269            total.num_smt_queries = total.num_smt_queries.saturating_add(stats.num_smt_queries);
270            total.num_propagations = total
271                .num_propagations
272                .saturating_add(stats.num_propagations);
273            total.num_subsumed = total.num_subsumed.saturating_add(stats.num_subsumed);
274            total.num_mic_attempts = total
275                .num_mic_attempts
276                .saturating_add(stats.num_mic_attempts);
277            total.num_ctg_strengthenings = total
278                .num_ctg_strengthenings
279                .saturating_add(stats.num_ctg_strengthenings);
280        }
281        total
282    }
283}
284
/// Worker thread state for distributed solving.
pub struct Worker {
    /// This worker's unique id.
    id: usize,
    /// Handle to the state shared with the coordinator and other workers.
    shared: Arc<SharedState>,
    /// Local statistics, published into `shared` when the worker exits.
    stats: SpacerStats,
}
294
295impl Worker {
296    /// Create a new worker
297    pub fn new(id: usize, shared: Arc<SharedState>) -> Self {
298        Self {
299            id,
300            shared,
301            stats: SpacerStats::default(),
302        }
303    }
304
305    /// Run the worker loop
306    pub fn run(
307        &mut self,
308        _terms: &mut TermManager,
309        _system: &ChcSystem,
310        _config: &SpacerConfig,
311    ) -> Result<(), DistributedError> {
312        loop {
313            // Check for shutdown signal
314            if let Some(WorkerMessage::Shutdown) = self.shared.receive_message() {
315                break;
316            }
317
318            // Check if result already found
319            if self.shared.get_result().is_some() {
320                break;
321            }
322
323            // Try to get work
324            if let Some(work_item) = self.shared.dequeue_work() {
325                // Process work item
326                self.process_work_item(work_item)?;
327            } else {
328                // No work available - request work stealing or wait
329                self.shared
330                    .send_message(WorkerMessage::RequestWork { worker_id: self.id });
331                std::thread::sleep(Duration::from_millis(10));
332            }
333        }
334
335        // Update shared statistics
336        self.shared.update_stats(|stats| {
337            stats.worker_stats.insert(self.id, self.stats.clone());
338        });
339
340        Ok(())
341    }
342
343    /// Process a work item
344    fn process_work_item(&mut self, work_item: WorkItem) -> Result<(), DistributedError> {
345        // Process a POB (Proof Obligation)
346        // 1. Try to block the POB using SMT solver
347        // 2. If blocked, learn lemma and add to frames
348        // 3. If not blocked, create child POBs
349        // 4. Send results/lemmas via messages
350
351        self.stats.num_pobs += 1;
352
353        // Basic POB processing logic:
354        // In a full implementation, we would:
355        // - Set up SMT query for the POB
356        // - Check if the state is reachable
357        // - If unreachable, extract a lemma (blocking clause)
358        // - If reachable, generate predecessor POBs
359
360        // For now, implement a simple heuristic:
361        // - Assume half of POBs can be blocked
362        // - Generate a trivial lemma for blocked POBs
363        let blocked = work_item.pob_id.0.is_multiple_of(2);
364
365        // Generate a lemma ID if blocked
366        // In reality, this would be extracted from UNSAT core and added to frames
367        let lemma = if blocked {
368            // In a real implementation, we would:
369            // 1. Add the lemma to the frame manager
370            // 2. Get its LemmaId
371            // For now, just return None as we don't have a real lemma
372            None
373        } else {
374            None
375        };
376
377        // Send the result back to the coordinator
378        self.shared.send_message(WorkerMessage::WorkResult {
379            worker_id: self.id,
380            pob_id: work_item.pob_id,
381            blocked,
382            lemma,
383        });
384
385        if blocked {
386            self.stats.num_blocked += 1;
387        }
388
389        Ok(())
390    }
391}
392
/// Coordinator for distributed solving: seeds the work queue, spawns
/// workers, and monitors progress until a result is found.
#[allow(dead_code)]
pub struct DistributedCoordinator<'a> {
    /// Term manager used to build placeholder terms for initial POBs.
    terms: &'a mut TermManager,
    /// The CHC system being solved.
    system: &'a ChcSystem,
    /// Distributed-solving configuration.
    config: DistributedConfig,
    /// State shared with all worker threads.
    shared: Arc<SharedState>,
    /// Construction time, used by `is_timeout`.
    start_time: Instant,
}
407
impl<'a> DistributedCoordinator<'a> {
    /// Create a new distributed coordinator over `system`, using `terms` for
    /// term construction and `config` for worker/timeout settings.
    pub fn new(
        terms: &'a mut TermManager,
        system: &'a ChcSystem,
        config: DistributedConfig,
    ) -> Self {
        Self {
            terms,
            system,
            config,
            shared: Arc::new(SharedState::new()),
            start_time: Instant::now(),
        }
    }

    /// Solve the CHC system using distributed workers.
    ///
    /// Flow: seed the work queue with one initial POB, spawn
    /// `config.num_workers` threads, then loop processing messages until a
    /// result is set, the timeout elapses, or the system stays idle. Finally
    /// broadcast `Shutdown` and join all workers.
    ///
    /// NOTE(review): the worker closures below only *simulate* POB
    /// processing (sleep + report "blocked"); real blocking and lemma
    /// learning are still TODO.
    ///
    /// # Errors
    ///
    /// Returns an internal error if no result was recorded when the monitor
    /// loop exits (should not happen: every exit path sets a result first).
    pub fn solve(&mut self) -> Result<SpacerResult, DistributedError> {
        // High-level plan:
        // 1. Initialize frames and initial POBs
        // 2. Spawn worker threads
        // 3. Monitor progress and synchronize state
        // 4. Detect termination (invariant found or counterexample)
        // 5. Aggregate results

        use std::sync::Arc; // redundant with the top-level import; harmless
        use std::thread;

        // Step 1: Initialize work queue with initial POBs.
        // In a real implementation, POBs would be derived from the query.
        use crate::pob::Pob;

        let initial_pob_data = Pob::new(
            PobId(0),
            PredId(0),
            self.terms.mk_true(), // Placeholder post-condition
            0,                    // level
            0,                    // depth
        );

        let initial_work = WorkItem {
            pob_id: PobId(0),
            pob: initial_pob_data,
            priority: 0,
        };

        self.shared.enqueue_work(initial_work);

        // Step 2: Spawn worker threads.
        let mut handles = Vec::new();
        for worker_id in 0..self.config.num_workers {
            let shared = Arc::clone(&self.shared);
            // Cloned per worker but currently unused by the simulated loop.
            let _config = self.config.worker_config.clone();
            let handle = thread::spawn(move || {
                tracing::debug!("Worker {} started", worker_id);

                // Worker loop: process work items until termination.
                loop {
                    // Check for shutdown message.
                    // NOTE(review): this pops one message per iteration and
                    // drops it when it is not `Shutdown`, so worker-bound and
                    // coordinator-bound traffic share one queue and can eat
                    // each other's messages. Consider separate channels.
                    if let Some(WorkerMessage::Shutdown) = shared.receive_message() {
                        tracing::debug!("Worker {} received shutdown signal", worker_id);
                        break;
                    }

                    // Check for termination: a global result already exists.
                    if shared.get_result().is_some() {
                        tracing::debug!("Worker {} terminating (result found)", worker_id);
                        break;
                    }

                    // Try to dequeue work.
                    let work_item = match shared.dequeue_work() {
                        Some(item) => item,
                        None => {
                            // No work available.
                            if shared.work_queue_size() == 0 {
                                // Queue is truly empty; in a full
                                // implementation, check global termination.
                                tracing::debug!("Worker {} idle", worker_id);
                                std::thread::sleep(Duration::from_millis(10));
                                continue;
                            }
                            continue;
                        }
                    };

                    tracing::trace!("Worker {} processing POB {:?}", worker_id, work_item.pob_id);

                    // Process the work item (POB).
                    // A full implementation would:
                    // 1. Check if the POB is blocked by existing lemmas
                    // 2. If not, generate a predecessor POB
                    // 3. Learn and generalize lemmas
                    // 4. Propagate lemmas forward
                    // 5. Detect fixpoints (invariants) or counterexamples

                    // For now, simulate some work.
                    std::thread::sleep(Duration::from_micros(100));

                    // Report result (simulated: every POB is marked blocked).
                    shared.send_message(WorkerMessage::WorkResult {
                        worker_id,
                        pob_id: work_item.pob_id,
                        blocked: true,
                        lemma: None,
                    });

                    // Update shared statistics.
                    shared.update_stats(|stats| {
                        stats.total_work_items += 1;
                        stats.messages_sent += 1;
                    });
                }

                tracing::debug!("Worker {} finished", worker_id);
            });
            handles.push(handle);
        }

        // Step 3: Monitor progress and process worker messages.
        let monitor_start = Instant::now();
        let sync_interval = Duration::from_millis(self.config.sync_interval_ms);
        let mut last_sync = Instant::now();

        loop {
            // Check for timeout (let-chain: only when a timeout is set).
            if let Some(timeout) = self.config.timeout
                && monitor_start.elapsed() >= timeout
            {
                tracing::warn!("Distributed solving timed out");
                self.shared.set_result(SpacerResult::Unknown);
                break;
            }

            // Drain and dispatch all pending messages from workers.
            let mut messages_processed = 0;
            while let Some(msg) = self.shared.receive_message() {
                match msg {
                    WorkerMessage::WorkResult {
                        worker_id,
                        pob_id,
                        blocked,
                        lemma,
                    } => {
                        tracing::trace!(
                            "Worker {} reported result for POB {:?}: blocked={}",
                            worker_id,
                            pob_id,
                            blocked
                        );
                        if let Some(_lemma_id) = lemma {
                            // Lemma was learned.
                            self.shared.update_stats(|stats| {
                                stats.total_lemmas += 1;
                            });
                        }
                    }
                    WorkerMessage::LemmaLearned {
                        pred, lemma, level, ..
                    } => {
                        // Synchronize the lemma across all workers via the
                        // shared frame manager.
                        self.shared.add_lemma(pred, lemma, level);
                        self.shared.update_stats(|stats| {
                            stats.total_lemmas += 1;
                            stats.sync_events += 1;
                        });
                    }
                    WorkerMessage::Counterexample { worker_id } => {
                        tracing::info!("Worker {} found counterexample", worker_id);
                        self.shared.set_result(SpacerResult::Unsafe);
                        // Breaks only the message-drain loop; the outer loop
                        // exits on the result check below.
                        break;
                    }
                    WorkerMessage::Invariant { worker_id, level } => {
                        tracing::info!("Worker {} found invariant at level {}", worker_id, level);
                        self.shared.set_result(SpacerResult::Safe);
                        break;
                    }
                    _ => {}
                }
                messages_processed += 1;
            }

            // Check if a result was found.
            if self.shared.get_result().is_some() {
                break;
            }

            // Check whether all workers are idle (no work in queue).
            if self.shared.work_queue_size() == 0 && messages_processed == 0 {
                // Heuristic: no work and no messages for a while ⇒ done.
                if last_sync.elapsed() > Duration::from_millis(500) {
                    tracing::debug!("No work remaining, assuming Unknown result");
                    self.shared.set_result(SpacerResult::Unknown);
                    break;
                }
            } else {
                last_sync = Instant::now();
            }

            // Sleep briefly before the next check (capped at 50 ms so the
            // coordinator stays responsive even with a long sync interval).
            std::thread::sleep(std::cmp::min(sync_interval, Duration::from_millis(50)));
        }

        // Signal workers to shut down (one message per worker).
        for _ in 0..self.config.num_workers {
            self.shared.send_message(WorkerMessage::Shutdown);
        }

        // Wait for workers to finish; join errors (panicked workers) are
        // deliberately ignored since a result has already been recorded.
        for handle in handles {
            let _ = handle.join();
        }

        // Step 4: Return the result.
        self.shared
            .get_result()
            .ok_or_else(|| SpacerError::Internal("no result found".to_string()).into())
    }

    /// True when the configured timeout (if any) has elapsed since this
    /// coordinator was constructed.
    #[allow(dead_code)]
    fn is_timeout(&self) -> bool {
        if let Some(timeout) = self.config.timeout {
            self.start_time.elapsed() >= timeout
        } else {
            false
        }
    }
}
637
/// Minimal stand-in for the `num_cpus` crate, built on the standard library.
mod num_cpus {
    /// Number of logical CPUs, falling back to 4 when detection fails.
    pub fn get() -> usize {
        match std::thread::available_parallelism() {
            Ok(n) => n.get(),
            Err(_) => 4,
        }
    }
}
647
#[cfg(test)]
mod tests {
    use super::*;

    /// Higher-priority items must be dequeued before lower-priority ones.
    #[test]
    fn test_shared_state_work_queue() {
        let state = SharedState::new();

        // Enqueue work items with different priorities.
        state.enqueue_work(WorkItem {
            pob_id: PobId(0),
            pob: Pob::new(PobId(0), PredId(0), TermId(0), 0, 0),
            priority: 10,
        });

        state.enqueue_work(WorkItem {
            pob_id: PobId(1),
            pob: Pob::new(PobId(1), PredId(0), TermId(1), 0, 0),
            priority: 20, // Higher priority
        });

        // Should dequeue the higher-priority item first.
        let work = state.dequeue_work().expect("test operation should succeed");
        assert_eq!(work.pob_id, PobId(1));
        assert_eq!(work.priority, 20);
    }

    /// Aggregation: `num_frames` takes the max across workers, while the
    /// other counters are summed.
    #[test]
    fn test_distributed_stats_aggregate() {
        let mut stats = DistributedStats::new();

        stats.worker_stats.insert(
            0,
            SpacerStats {
                num_frames: 5,
                num_lemmas: 10,
                num_pobs: 20,
                ..Default::default()
            },
        );

        stats.worker_stats.insert(
            1,
            SpacerStats {
                num_frames: 7, // Max should be 7
                num_lemmas: 15,
                num_pobs: 25,
                ..Default::default()
            },
        );

        let aggregated = stats.aggregate();
        assert_eq!(aggregated.num_frames, 7); // Max of 5 and 7
        assert_eq!(aggregated.num_lemmas, 25); // Sum: 10 + 15
        assert_eq!(aggregated.num_pobs, 45); // Sum: 20 + 25
    }
}