use tracing::{debug, info};
use crate::types::WorkerId;
use super::config::{ExhaustedBehavior, MultiWorkerStrategy, TerminationConfig};
use super::reason::{FailureReason, SuccessReason, TerminationVerdict};
use super::state::CompletionState;
/// A stop request that came from outside the judge (e.g. a user or controller),
/// stamped with the tick at which it was received.
#[derive(Debug, Clone)]
pub struct TerminationRequest {
    /// Free-form explanation supplied by whoever asked for the stop.
    pub reason: String,
    /// Tick that was current when the request was recorded.
    pub tick: u64,
}
/// Tracks run progress and decides when (and with which verdict) the run
/// should stop.
///
/// It combines several inputs:
/// - worker completion notifications;
/// - exploration completion;
/// - a consecutive-error counter;
/// - an optional tick limit;
/// - external stop requests.
#[derive(Debug)]
pub struct TerminationJudge {
    /// Policy: multi-worker strategy, exhaustion behaviour and limits.
    config: TerminationConfig,
    /// Accumulated completion facts and the verdict, once one is reached.
    state: CompletionState,
    /// Set when a stop was requested from outside the judge.
    external_request: Option<TerminationRequest>,
    /// Most recent tick supplied by the caller.
    current_tick: u64,
    /// Number of workers the all-workers strategies wait for.
    worker_count: usize,
    /// Errors reported since the last successful worker completion.
    consecutive_errors: u64,
}
impl TerminationJudge {
    /// Creates a judge with no verdict, no external request, tick 0 and a
    /// zeroed consecutive-error counter.
    ///
    /// `worker_count` is the number of workers the `AllComplete` and
    /// `AllSuccess` strategies wait for before deciding.
    pub fn new(config: TerminationConfig, worker_count: usize) -> Self {
        Self {
            config,
            state: CompletionState::new(),
            external_request: None,
            current_tick: 0,
            worker_count,
            consecutive_errors: 0,
        }
    }

    /// Updates the current tick.
    ///
    /// The tick is used for logging, for stamping worker completions and
    /// termination requests, and for the `max_ticks` limit.
    pub fn set_tick(&mut self, tick: u64) {
        self.current_tick = tick;
    }

    /// Returns the tick last passed to [`set_tick`](Self::set_tick) (0 initially).
    pub fn current_tick(&self) -> u64 {
        self.current_tick
    }

    /// Records that `worker_id` finished, then re-evaluates the termination
    /// conditions.
    ///
    /// A successful completion resets the consecutive-error counter.
    pub fn notify_worker_done(
        &mut self,
        worker_id: WorkerId,
        success: bool,
        message: Option<String>,
    ) {
        info!(
            worker_id = worker_id.0,
            success = success,
            message = ?message,
            tick = self.current_tick,
            "TerminationJudge: worker done notification"
        );
        // The logging macro only borrows `message`, so it can be moved into the
        // state here instead of being cloned.
        self.state
            .record_worker_done(worker_id, success, message, self.current_tick);
        if success {
            self.consecutive_errors = 0;
        }
        self.reevaluate();
    }

    /// Records that exploration finished, then re-evaluates the termination
    /// conditions.
    ///
    /// `exhausted` is forwarded to `CompletionState::mark_exploration_done`.
    /// When the state then reports exhaustion, `config.on_exhausted` decides
    /// the verdict (unless worker completion already decided it).
    pub fn notify_exploration_complete(&mut self, exhausted: bool) {
        info!(
            exhausted = exhausted,
            tick = self.current_tick,
            "TerminationJudge: exploration complete notification"
        );
        self.state.mark_exploration_done(exhausted);
        self.reevaluate();
    }

    /// Records one error.
    ///
    /// Once `config.max_consecutive_errors` consecutive errors have
    /// accumulated (counted since the last successful worker completion), the
    /// run is failed with `FailureReason::MaxErrorsExceeded`.
    ///
    /// An already-reached verdict is never replaced, matching
    /// `reevaluate`. A late error therefore cannot overwrite a decided
    /// outcome, and errors past the limit do not re-set the verdict.
    pub fn notify_error(&mut self) {
        // Saturate instead of overflowing (which would panic in debug builds).
        self.consecutive_errors = self.consecutive_errors.saturating_add(1);
        if let Some(max) = self.config.max_consecutive_errors {
            if self.consecutive_errors >= max && !self.state.has_verdict() {
                info!(
                    errors = self.consecutive_errors,
                    max = max,
                    "TerminationJudge: max consecutive errors exceeded"
                );
                self.state.set_verdict(TerminationVerdict::Failure {
                    reason: FailureReason::MaxErrorsExceeded {
                        count: self.consecutive_errors,
                        limit: max,
                    },
                });
            }
        }
    }

    /// Requests an external stop with the given `reason`.
    ///
    /// The request is remembered, stamped with the current tick, and an
    /// `ExternalStop` verdict is passed to the completion state.
    pub fn request_terminate(&mut self, reason: impl Into<String>) {
        let reason = reason.into();
        info!(
            reason = %reason,
            tick = self.current_tick,
            "TerminationJudge: external termination requested"
        );
        self.external_request = Some(TerminationRequest {
            reason: reason.clone(),
            tick: self.current_tick,
        });
        self.state
            .set_verdict(TerminationVerdict::ExternalStop { reason });
    }

    /// Returns `true` when the run should stop.
    ///
    /// That is the case when any of the following holds:
    /// - an external stop was requested;
    /// - a verdict has been reached;
    /// - the (non-zero) `max_ticks` limit has been hit.
    pub fn should_terminate(&self) -> bool {
        self.external_request.is_some() || self.state.has_verdict() || self.max_ticks_reached()
    }

    /// Returns `true` when further guidance is pointless.
    ///
    /// That is the case when any of the following holds:
    /// - an external stop was requested;
    /// - a verdict has been reached;
    /// - the environment is done;
    /// - exploration is done.
    ///
    /// Unlike [`should_terminate`](Self::should_terminate), the tick limit is
    /// not considered here.
    pub fn should_skip_guidance(&self) -> bool {
        self.external_request.is_some()
            || self.state.has_verdict()
            || self.state.is_environment_done()
            || self.state.is_exploration_done()
    }

    /// Returns the verdict describing why the run stopped.
    ///
    /// The checks run in this order:
    /// 1. a verdict stored in the completion state;
    /// 2. an external stop request;
    /// 3. a tick-limit timeout, where `partial_success` reports whether any
    ///    worker succeeded.
    ///
    /// If none of these applies (i.e. the caller asked before
    /// [`should_terminate`](Self::should_terminate) returned `true`), an
    /// `InternalError` failure is returned.
    pub fn verdict(&self) -> TerminationVerdict {
        if let Some(verdict) = self.state.verdict() {
            return verdict.clone();
        }
        if let Some(req) = &self.external_request {
            return TerminationVerdict::ExternalStop {
                reason: req.reason.clone(),
            };
        }
        if self.max_ticks_reached() {
            let partial_success = self.state.any_worker_succeeded();
            return TerminationVerdict::Timeout { partial_success };
        }
        TerminationVerdict::Failure {
            reason: FailureReason::InternalError {
                message: "verdict() called without termination condition".to_string(),
            },
        }
    }

    /// Delegates to `CompletionState::is_environment_done`.
    pub fn is_environment_done(&self) -> bool {
        self.state.is_environment_done()
    }

    /// Read-only access to the accumulated completion state.
    pub fn completion_state(&self) -> &CompletionState {
        &self.state
    }

    /// `true` when a non-zero `max_ticks` is configured and has been reached.
    fn max_ticks_reached(&self) -> bool {
        self.config.max_ticks > 0 && self.current_tick >= self.config.max_ticks
    }

    /// `true` when at least `worker_count` workers have reported completion.
    ///
    /// NOTE(review): with `worker_count == 0` this is vacuously true, so the
    /// all-workers strategies decide on the first re-evaluation — confirm
    /// this is intended.
    fn all_workers_completed(&self) -> bool {
        self.state.completed_workers().len() >= self.worker_count
    }

    /// Derives a verdict from the current state if none exists yet.
    ///
    /// Worker completion is consulted first; exploration exhaustion is
    /// consulted only when worker completion is inconclusive.
    fn reevaluate(&mut self) {
        if self.state.has_verdict() {
            return;
        }
        if let Some(verdict) = self.evaluate_worker_completion() {
            debug!(verdict = ?verdict, "TerminationJudge: verdict from worker completion");
            self.state.set_verdict(verdict);
            return;
        }
        if self.state.is_exploration_exhausted() {
            let verdict = self.evaluate_exhaustion();
            debug!(verdict = ?verdict, "TerminationJudge: verdict from exploration exhaustion");
            self.state.set_verdict(verdict);
        }
    }

    /// Applies `config.multi_worker_strategy` to the recorded completions.
    ///
    /// Returns `None` while the strategy cannot decide yet. The
    /// `Conditions` strategy never decides here.
    fn evaluate_worker_completion(&self) -> Option<TerminationVerdict> {
        match self.config.multi_worker_strategy {
            MultiWorkerStrategy::FirstSuccess => {
                self.state
                    .first_success()
                    .map(|(worker_id, result)| TerminationVerdict::Success {
                        reason: SuccessReason::WorkerDone {
                            worker_id: worker_id.0,
                            message: result.message.clone(),
                        },
                    })
            }
            MultiWorkerStrategy::AllComplete => {
                if !self.all_workers_completed() {
                    return None;
                }
                let verdict = if self.state.any_worker_succeeded() {
                    TerminationVerdict::Success {
                        reason: SuccessReason::ConditionsMet,
                    }
                } else {
                    // No single worker is to blame, so worker_id 0 is used
                    // as a placeholder.
                    TerminationVerdict::Failure {
                        reason: FailureReason::WorkerFailed {
                            worker_id: 0,
                            message: Some("All workers failed".to_string()),
                        },
                    }
                };
                Some(verdict)
            }
            MultiWorkerStrategy::AllSuccess => {
                if !self.all_workers_completed() {
                    return None;
                }
                let verdict = if self.state.all_completed_workers_succeeded() {
                    TerminationVerdict::Success {
                        reason: SuccessReason::ConditionsMet,
                    }
                } else {
                    // No single worker is to blame, so worker_id 0 is used
                    // as a placeholder.
                    TerminationVerdict::Failure {
                        reason: FailureReason::WorkerFailed {
                            worker_id: 0,
                            message: Some("Not all workers succeeded".to_string()),
                        },
                    }
                };
                Some(verdict)
            }
            MultiWorkerStrategy::Conditions => None,
        }
    }

    /// Maps `config.on_exhausted` to a verdict once exploration is exhausted.
    ///
    /// `CheckConditions` succeeds only if at least one worker succeeded.
    fn evaluate_exhaustion(&self) -> TerminationVerdict {
        match self.config.on_exhausted {
            ExhaustedBehavior::Fail => TerminationVerdict::Failure {
                reason: FailureReason::ExplorationExhausted,
            },
            ExhaustedBehavior::Success => TerminationVerdict::Success {
                reason: SuccessReason::ExplorationComplete,
            },
            ExhaustedBehavior::CheckConditions => {
                if self.state.any_worker_succeeded() {
                    TerminationVerdict::Success {
                        reason: SuccessReason::ExplorationComplete,
                    }
                } else {
                    TerminationVerdict::Failure {
                        reason: FailureReason::ExplorationExhausted,
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Judge with the default configuration and a single worker.
    fn default_judge() -> TerminationJudge {
        let single_worker = 1;
        TerminationJudge::new(TerminationConfig::default(), single_worker)
    }

    #[test]
    fn test_initial_state() {
        let fresh = default_judge();
        assert!(!fresh.should_terminate());
        assert!(!fresh.should_skip_guidance());
    }

    #[test]
    fn test_worker_done_success() {
        let mut j = default_judge();
        j.set_tick(5);
        let note = Some(String::from("Done!"));
        j.notify_worker_done(WorkerId(0), true, note);
        assert!(j.should_terminate());
        assert!(j.should_skip_guidance());
        let outcome = j.verdict();
        assert!(outcome.is_success());
    }

    #[test]
    fn test_external_termination() {
        let mut j = default_judge();
        j.request_terminate("User requested stop");
        assert!(j.should_terminate());
        assert!(j.should_skip_guidance());
        let outcome = j.verdict();
        assert!(matches!(outcome, TerminationVerdict::ExternalStop { .. }));
    }

    #[test]
    fn test_max_ticks_timeout() {
        let mut j = TerminationJudge::new(TerminationConfig::with_max_ticks(100), 1);
        j.set_tick(100);
        assert!(j.should_terminate());
        let outcome = j.verdict();
        assert!(matches!(outcome, TerminationVerdict::Timeout { .. }));
    }

    #[test]
    fn test_exploration_exhausted() {
        let mut j = default_judge();
        j.notify_exploration_complete(true);
        assert!(j.should_terminate());
        let outcome = j.verdict();
        assert!(matches!(
            outcome,
            TerminationVerdict::Failure {
                reason: FailureReason::ExplorationExhausted
            }
        ));
    }

    #[test]
    fn test_all_success_strategy() {
        let strategy = MultiWorkerStrategy::AllSuccess;
        let cfg = TerminationConfig::default().multi_worker_strategy(strategy);
        let mut j = TerminationJudge::new(cfg, 2);

        j.notify_worker_done(WorkerId(0), true, None);
        assert!(!j.should_terminate(), "one of two workers is not enough");

        j.notify_worker_done(WorkerId(1), true, None);
        assert!(j.should_terminate());
        assert!(j.verdict().is_success());
    }

    #[test]
    fn test_first_success_strategy() {
        let mut j = default_judge();
        j.notify_worker_done(WorkerId(0), true, None);
        assert!(j.should_terminate());
        let outcome = j.verdict();
        assert!(outcome.is_success());
    }
}