1use crate::chc::{ChcSystem, PredId};
16use crate::frames::{FrameManager, LemmaId};
17use crate::pdr::{SpacerConfig, SpacerError, SpacerResult, SpacerStats};
18use crate::pob::{Pob, PobId};
19use oxiz_core::{TermId, TermManager};
20use std::collections::{HashMap, VecDeque};
21use std::sync::{Arc, Mutex};
22use std::time::{Duration, Instant};
23use thiserror::Error;
24
/// Errors produced by the distributed solving machinery.
///
/// Wraps per-worker failures, coordination/communication problems, and
/// propagated [`SpacerError`]s from the underlying solver.
#[derive(Error, Debug)]
pub enum DistributedError {
    /// A specific worker (by index) failed with the given message.
    #[error("worker {0} error: {1}")]
    WorkerError(usize, String),
    /// Message passing between workers/coordinator failed.
    #[error("communication error: {0}")]
    Communication(String),
    /// The coordinator itself failed.
    #[error("coordination error: {0}")]
    Coordination(String),
    /// An error bubbled up from the sequential Spacer engine.
    #[error("spacer error: {0}")]
    Spacer(#[from] SpacerError),
    /// Solving exceeded the configured wall-clock budget.
    #[error("timeout after {0:?}")]
    Timeout(Duration),
}
44
/// Configuration for a distributed Spacer run.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Number of worker threads to spawn.
    pub num_workers: usize,
    /// Configuration handed to each worker's Spacer instance.
    pub worker_config: SpacerConfig,
    /// Coordinator synchronization interval, in milliseconds.
    pub sync_interval_ms: u64,
    /// Optional wall-clock budget for the whole solve; `None` means unbounded.
    pub timeout: Option<Duration>,
    /// Whether idle workers may steal work from others.
    pub enable_work_stealing: bool,
}
59
60impl Default for DistributedConfig {
61 fn default() -> Self {
62 Self {
63 num_workers: num_cpus::get(),
64 worker_config: SpacerConfig::default(),
65 sync_interval_ms: 100,
66 timeout: None,
67 enable_work_stealing: true,
68 }
69 }
70}
71
/// Messages exchanged over the shared message queue between workers and
/// the coordinator.
#[derive(Debug, Clone)]
pub enum WorkerMessage {
    /// A unit of work to be processed.
    Work(WorkItem),
    /// A worker learned a lemma for `pred` at frame `level`.
    LemmaLearned {
        worker_id: usize,
        pred: PredId,
        lemma: TermId,
        level: u32,
    },
    /// A new frame was created at `level`.
    FrameCreated { level: u32 },
    /// Outcome of processing one proof obligation.
    WorkResult {
        worker_id: usize,
        pob_id: PobId,
        blocked: bool,
        lemma: Option<LemmaId>,
    },
    /// A worker found a counterexample (system is unsafe).
    Counterexample { worker_id: usize },
    /// A worker found an inductive invariant at `level` (system is safe).
    Invariant { worker_id: usize, level: u32 },
    /// An idle worker is asking for more work.
    RequestWork { worker_id: usize },
    /// Tells a worker to terminate its loop.
    Shutdown,
}
102
/// A schedulable unit of work: one proof obligation plus its priority.
#[derive(Debug, Clone)]
pub struct WorkItem {
    /// Identifier of the proof obligation.
    pub pob_id: PobId,
    /// The proof obligation itself.
    pub pob: Pob,
    /// Scheduling priority; higher values are dequeued first.
    pub priority: i32,
}
113
/// State shared between the coordinator and all workers.
///
/// Each field is independently mutex-protected; accessors lock only the
/// field they touch, so contention on one structure does not block the
/// others.
pub struct SharedState {
    // Global frame/lemma database shared by all workers.
    frames: Mutex<FrameManager>,
    // Priority-ordered queue of pending work items.
    work_queue: Mutex<VecDeque<WorkItem>>,
    // Final result, once any worker/coordinator determines one.
    result: Mutex<Option<SpacerResult>>,
    // Aggregated statistics across workers.
    stats: Mutex<DistributedStats>,
    // FIFO message queue shared by workers and the coordinator.
    messages: Mutex<VecDeque<WorkerMessage>>,
}
127
128impl SharedState {
129 pub fn new() -> Self {
131 Self {
132 frames: Mutex::new(FrameManager::new()),
133 work_queue: Mutex::new(VecDeque::new()),
134 result: Mutex::new(None),
135 stats: Mutex::new(DistributedStats::default()),
136 messages: Mutex::new(VecDeque::new()),
137 }
138 }
139
140 pub fn enqueue_work(&self, item: WorkItem) {
142 let mut queue = self.work_queue.lock().expect("lock should not be poisoned");
143 let pos = queue
145 .iter()
146 .position(|w| w.priority < item.priority)
147 .unwrap_or(queue.len());
148 queue.insert(pos, item);
149 }
150
151 pub fn dequeue_work(&self) -> Option<WorkItem> {
153 self.work_queue
154 .lock()
155 .expect("lock should not be poisoned")
156 .pop_front()
157 }
158
159 pub fn work_queue_size(&self) -> usize {
161 self.work_queue
162 .lock()
163 .expect("lock should not be poisoned")
164 .len()
165 }
166
167 pub fn send_message(&self, msg: WorkerMessage) {
169 self.messages
170 .lock()
171 .expect("lock should not be poisoned")
172 .push_back(msg);
173 }
174
175 pub fn receive_message(&self) -> Option<WorkerMessage> {
177 self.messages
178 .lock()
179 .expect("lock should not be poisoned")
180 .pop_front()
181 }
182
183 pub fn set_result(&self, result: SpacerResult) {
185 *self.result.lock().expect("lock should not be poisoned") = Some(result);
186 }
187
188 pub fn get_result(&self) -> Option<SpacerResult> {
190 self.result
191 .lock()
192 .expect("lock should not be poisoned")
193 .clone()
194 }
195
196 pub fn add_lemma(&self, pred: PredId, formula: TermId, level: u32) -> LemmaId {
198 self.frames
199 .lock()
200 .expect("lock should not be poisoned")
201 .add_lemma(pred, formula, level)
202 }
203
204 pub fn with_frames<F, R>(&self, f: F) -> R
206 where
207 F: FnOnce(&mut FrameManager) -> R,
208 {
209 let mut frames = self.frames.lock().expect("lock should not be poisoned");
210 f(&mut frames)
211 }
212
213 pub fn update_stats<F>(&self, f: F)
215 where
216 F: FnOnce(&mut DistributedStats),
217 {
218 let mut stats = self.stats.lock().expect("lock should not be poisoned");
219 f(&mut stats);
220 }
221
222 pub fn get_stats(&self) -> DistributedStats {
224 self.stats
225 .lock()
226 .expect("lock should not be poisoned")
227 .clone()
228 }
229}
230
231impl Default for SharedState {
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
/// Statistics aggregated across a distributed run.
#[derive(Debug, Clone, Default)]
pub struct DistributedStats {
    /// Per-worker Spacer statistics, keyed by worker id.
    pub worker_stats: HashMap<usize, SpacerStats>,
    /// Total work items processed across all workers.
    pub total_work_items: u64,
    /// Total lemmas learned across all workers.
    pub total_lemmas: u64,
    /// Number of work-stealing events.
    pub work_stealing_events: u64,
    /// Number of synchronization events (e.g. lemma broadcasts).
    pub sync_events: u64,
    /// Number of messages sent over the shared queue.
    pub messages_sent: u64,
}
253
254impl DistributedStats {
255 pub fn new() -> Self {
257 Self::default()
258 }
259
260 pub fn aggregate(&self) -> SpacerStats {
262 let mut total = SpacerStats::default();
263 for stats in self.worker_stats.values() {
264 total.num_frames = total.num_frames.max(stats.num_frames);
265 total.num_lemmas = total.num_lemmas.saturating_add(stats.num_lemmas);
266 total.num_inductive = total.num_inductive.saturating_add(stats.num_inductive);
267 total.num_pobs = total.num_pobs.saturating_add(stats.num_pobs);
268 total.num_blocked = total.num_blocked.saturating_add(stats.num_blocked);
269 total.num_smt_queries = total.num_smt_queries.saturating_add(stats.num_smt_queries);
270 total.num_propagations = total
271 .num_propagations
272 .saturating_add(stats.num_propagations);
273 total.num_subsumed = total.num_subsumed.saturating_add(stats.num_subsumed);
274 total.num_mic_attempts = total
275 .num_mic_attempts
276 .saturating_add(stats.num_mic_attempts);
277 total.num_ctg_strengthenings = total
278 .num_ctg_strengthenings
279 .saturating_add(stats.num_ctg_strengthenings);
280 }
281 total
282 }
283}
284
/// A single solving worker bound to the shared state.
pub struct Worker {
    // Index of this worker (used in messages and stats keys).
    id: usize,
    // State shared with the coordinator and other workers.
    shared: Arc<SharedState>,
    // Local statistics, published into `shared` when the worker exits.
    stats: SpacerStats,
}
294
295impl Worker {
296 pub fn new(id: usize, shared: Arc<SharedState>) -> Self {
298 Self {
299 id,
300 shared,
301 stats: SpacerStats::default(),
302 }
303 }
304
305 pub fn run(
307 &mut self,
308 _terms: &mut TermManager,
309 _system: &ChcSystem,
310 _config: &SpacerConfig,
311 ) -> Result<(), DistributedError> {
312 loop {
313 if let Some(WorkerMessage::Shutdown) = self.shared.receive_message() {
315 break;
316 }
317
318 if self.shared.get_result().is_some() {
320 break;
321 }
322
323 if let Some(work_item) = self.shared.dequeue_work() {
325 self.process_work_item(work_item)?;
327 } else {
328 self.shared
330 .send_message(WorkerMessage::RequestWork { worker_id: self.id });
331 std::thread::sleep(Duration::from_millis(10));
332 }
333 }
334
335 self.shared.update_stats(|stats| {
337 stats.worker_stats.insert(self.id, self.stats.clone());
338 });
339
340 Ok(())
341 }
342
343 fn process_work_item(&mut self, work_item: WorkItem) -> Result<(), DistributedError> {
345 self.stats.num_pobs += 1;
352
353 let blocked = work_item.pob_id.0.is_multiple_of(2);
364
365 let lemma = if blocked {
368 None
373 } else {
374 None
375 };
376
377 self.shared.send_message(WorkerMessage::WorkResult {
379 worker_id: self.id,
380 pob_id: work_item.pob_id,
381 blocked,
382 lemma,
383 });
384
385 if blocked {
386 self.stats.num_blocked += 1;
387 }
388
389 Ok(())
390 }
391}
392
/// Orchestrates a distributed solve: owns the term manager and CHC system,
/// spawns workers, and supervises them until a result is found.
#[allow(dead_code)]
pub struct DistributedCoordinator<'a> {
    // Term manager used to build initial proof obligations.
    terms: &'a mut TermManager,
    // The CHC system being solved.
    system: &'a ChcSystem,
    // Run configuration (worker count, timeout, sync interval, ...).
    config: DistributedConfig,
    // State shared with all spawned workers.
    shared: Arc<SharedState>,
    // Construction time; used by `is_timeout`.
    start_time: Instant,
}
407
impl<'a> DistributedCoordinator<'a> {
    /// Create a coordinator over `system` with `config`; workers will share
    /// a fresh [`SharedState`].
    pub fn new(
        terms: &'a mut TermManager,
        system: &'a ChcSystem,
        config: DistributedConfig,
    ) -> Self {
        Self {
            terms,
            system,
            config,
            shared: Arc::new(SharedState::new()),
            start_time: Instant::now(),
        }
    }

    /// Run the distributed solving loop.
    ///
    /// Spawns `config.num_workers` threads that pull [`WorkItem`]s from the
    /// shared queue and post [`WorkerMessage::WorkResult`]s, while this
    /// thread acts as coordinator: it drains messages, installs learned
    /// lemmas into the shared frames, and decides termination (result found,
    /// timeout, or persistently idle queue). The worker bodies here are
    /// stubs that sleep briefly and report every POB as blocked.
    ///
    /// # Errors
    /// Returns an error if no result was recorded by the time all workers
    /// have shut down.
    pub fn solve(&mut self) -> Result<SpacerResult, DistributedError> {
        use std::sync::Arc;
        use std::thread;

        use crate::pob::Pob;

        // Seed the queue with a single root proof obligation.
        let initial_pob_data = Pob::new(
            PobId(0),
            PredId(0),
            self.terms.mk_true(),
            0,
            0,
        );

        let initial_work = WorkItem {
            pob_id: PobId(0),
            pob: initial_pob_data,
            priority: 0,
        };

        self.shared.enqueue_work(initial_work);

        // Spawn the worker pool. Each worker loops until it sees a Shutdown
        // message or a final result in the shared state.
        let mut handles = Vec::new();
        for worker_id in 0..self.config.num_workers {
            let shared = Arc::clone(&self.shared);
            let _config = self.config.worker_config.clone();
            let handle = thread::spawn(move || {
                tracing::debug!("Worker {} started", worker_id);

                loop {
                    // NOTE(review): this pops ONE message per iteration and
                    // drops it unless it is Shutdown. Workers and the
                    // coordinator share a single message queue, so a worker
                    // can consume (and lose) a WorkResult meant for the
                    // coordinator — confirm whether this race is acceptable.
                    if let Some(WorkerMessage::Shutdown) = shared.receive_message() {
                        tracing::debug!("Worker {} received shutdown signal", worker_id);
                        break;
                    }

                    // A recorded result means the run is over.
                    if shared.get_result().is_some() {
                        tracing::debug!("Worker {} terminating (result found)", worker_id);
                        break;
                    }

                    let work_item = match shared.dequeue_work() {
                        Some(item) => item,
                        None => {
                            if shared.work_queue_size() == 0 {
                                // Idle: back off briefly before retrying.
                                tracing::debug!("Worker {} idle", worker_id);
                                std::thread::sleep(Duration::from_millis(10));
                                continue;
                            }
                            continue;
                        }
                    };

                    tracing::trace!("Worker {} processing POB {:?}", worker_id, work_item.pob_id);

                    // Placeholder for real POB processing.
                    std::thread::sleep(Duration::from_micros(100));

                    // Report the (stubbed) outcome back to the coordinator.
                    shared.send_message(WorkerMessage::WorkResult {
                        worker_id,
                        pob_id: work_item.pob_id,
                        blocked: true,
                        lemma: None,
                    });

                    shared.update_stats(|stats| {
                        stats.total_work_items += 1;
                        stats.messages_sent += 1;
                    });
                }

                tracing::debug!("Worker {} finished", worker_id);
            });
            handles.push(handle);
        }

        // Coordinator monitor loop: drain messages, check for a result,
        // and terminate on timeout or a persistently empty queue.
        let monitor_start = Instant::now();
        let sync_interval = Duration::from_millis(self.config.sync_interval_ms);
        let mut last_sync = Instant::now();

        loop {
            // Timeout check against the monitor start, not construction time.
            if let Some(timeout) = self.config.timeout
                && monitor_start.elapsed() >= timeout
            {
                tracing::warn!("Distributed solving timed out");
                self.shared.set_result(SpacerResult::Unknown);
                break;
            }

            let mut messages_processed = 0;
            while let Some(msg) = self.shared.receive_message() {
                match msg {
                    WorkerMessage::WorkResult {
                        worker_id,
                        pob_id,
                        blocked,
                        lemma,
                    } => {
                        tracing::trace!(
                            "Worker {} reported result for POB {:?}: blocked={}",
                            worker_id,
                            pob_id,
                            blocked
                        );
                        if let Some(_lemma_id) = lemma {
                            self.shared.update_stats(|stats| {
                                stats.total_lemmas += 1;
                            });
                        }
                    }
                    WorkerMessage::LemmaLearned {
                        pred, lemma, level, ..
                    } => {
                        // Install the lemma into the shared frames so all
                        // workers benefit from it.
                        self.shared.add_lemma(pred, lemma, level);
                        self.shared.update_stats(|stats| {
                            stats.total_lemmas += 1;
                            stats.sync_events += 1;
                        });
                    }
                    WorkerMessage::Counterexample { worker_id } => {
                        tracing::info!("Worker {} found counterexample", worker_id);
                        self.shared.set_result(SpacerResult::Unsafe);
                        // Breaks the message loop; the outer loop then exits
                        // via the get_result() check below.
                        break;
                    }
                    WorkerMessage::Invariant { worker_id, level } => {
                        tracing::info!("Worker {} found invariant at level {}", worker_id, level);
                        self.shared.set_result(SpacerResult::Safe);
                        break;
                    }
                    _ => {}
                }
                messages_processed += 1;
            }

            if self.shared.get_result().is_some() {
                break;
            }

            // If nothing is queued and no messages arrived, allow a 500ms
            // grace period of inactivity before concluding Unknown.
            if self.shared.work_queue_size() == 0 && messages_processed == 0 {
                if last_sync.elapsed() > Duration::from_millis(500) {
                    tracing::debug!("No work remaining, assuming Unknown result");
                    self.shared.set_result(SpacerResult::Unknown);
                    break;
                }
            } else {
                last_sync = Instant::now();
            }

            // Poll at the sync interval, capped at 50ms for responsiveness.
            std::thread::sleep(std::cmp::min(sync_interval, Duration::from_millis(50)));
        }

        // Tell every worker to stop, then wait for all of them to exit.
        for _ in 0..self.config.num_workers {
            self.shared.send_message(WorkerMessage::Shutdown);
        }

        for handle in handles {
            let _ = handle.join();
        }

        self.shared
            .get_result()
            .ok_or_else(|| SpacerError::Internal("no result found".to_string()).into())
    }

    /// True once the configured timeout (if any) has elapsed since this
    /// coordinator was constructed.
    #[allow(dead_code)]
    fn is_timeout(&self) -> bool {
        if let Some(timeout) = self.config.timeout {
            self.start_time.elapsed() >= timeout
        } else {
            false
        }
    }
}
637
/// Minimal stand-in for the `num_cpus` crate, backed by the standard library.
mod num_cpus {
    /// Number of available hardware threads; falls back to 4 when the
    /// platform cannot report parallelism.
    pub fn get() -> usize {
        match std::thread::available_parallelism() {
            Ok(n) => n.get(),
            Err(_) => 4,
        }
    }
}
647
#[cfg(test)]
mod tests {
    use super::*;

    /// Higher-priority items must come out of the shared queue first.
    #[test]
    fn test_shared_state_work_queue() {
        let state = SharedState::new();

        let low_priority = WorkItem {
            pob_id: PobId(0),
            pob: Pob::new(PobId(0), PredId(0), TermId(0), 0, 0),
            priority: 10,
        };
        let high_priority = WorkItem {
            pob_id: PobId(1),
            pob: Pob::new(PobId(1), PredId(0), TermId(1), 0, 0),
            priority: 20,
        };

        state.enqueue_work(low_priority);
        state.enqueue_work(high_priority);

        let first = state.dequeue_work().expect("test operation should succeed");
        assert_eq!(first.pob_id, PobId(1));
        assert_eq!(first.priority, 20);
    }

    /// Aggregation takes the max of frame counts and sums the other counters.
    #[test]
    fn test_distributed_stats_aggregate() {
        let mut stats = DistributedStats::new();

        stats.worker_stats.insert(
            0,
            SpacerStats {
                num_frames: 5,
                num_lemmas: 10,
                num_pobs: 20,
                ..Default::default()
            },
        );
        stats.worker_stats.insert(
            1,
            SpacerStats {
                num_frames: 7,
                num_lemmas: 15,
                num_pobs: 25,
                ..Default::default()
            },
        );

        let aggregated = stats.aggregate();
        assert_eq!(aggregated.num_frames, 7);
        assert_eq!(aggregated.num_lemmas, 25);
        assert_eq!(aggregated.num_pobs, 45);
    }
}
704}