use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{Receiver, Sender, channel};
use std::thread::{self, JoinHandle};
use super::steps::{CheckResult, StepId};
/// Identifies one probe run: the step it was launched for plus that step's generation at
/// launch time.
///
/// `ProbeRegistry` only accepts a result whose id still equals the step's current id, so
/// advancing the step's generation makes every older id stale.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ProbeId {
    /// Step the probe belongs to.
    pub step: StepId,
    /// Generation of `step` when the probe was launched.
    pub generation: u64,
}
/// The category of background check a probe performs.
///
/// `EphemeralTest` probes are handled specially by `ProbeRegistry`: they run on a tracked
/// worker that can be cancelled and joined. All other kinds run on detached threads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ProbeKind {
    /// Download probe.
    Download,
    /// Connection test probe.
    ConnectTest,
    /// Cancellable probe run on the registry's tracked ephemeral worker.
    EphemeralTest,
    /// Oracle tool probe.
    OracleTool,
    /// Version check probe.
    VersionCheck,
}
/// Per-step probe state tracked by `ProbeRegistry`.
#[derive(Debug, Clone)]
pub(crate) enum ProbeStatus {
    /// No probe result is held for the step (initial state, and the state after a bump).
    Idle,
    /// A probe of this kind was spawned and its result has not been applied yet.
    Running(ProbeKind),
    /// A probe of the step's current generation finished with `result`.
    Done { kind: ProbeKind, result: CheckResult },
}
/// Message a probe worker thread sends back over the registry's channel when it finishes.
#[derive(Debug, Clone)]
pub(crate) struct ProbeMsg {
    /// Step and generation the worker was spawned under; compared with the registry's
    /// current id to decide whether the result is stale.
    pub probe_id: ProbeId,
    /// Kind of probe that produced `result`.
    pub kind: ProbeKind,
    /// Outcome returned by the probe's work closure.
    pub result: CheckResult,
}
/// Tracks background probes per step and discards results from superseded runs.
///
/// Each step has a generation counter. Spawning or bumping a step advances it, and `apply`
/// only accepts messages whose `ProbeId` matches the step's current generation. Workers
/// report back over an internal mpsc channel that `poll` drains.
pub(crate) struct ProbeRegistry {
// Current generation per step, indexed by `StepId::index()`.
generations: [u64; StepId::COUNT],
// Current probe status per step, indexed by `StepId::index()`.
status: [ProbeStatus; StepId::COUNT],
// Receiving end of the result channel, drained by `poll`.
rx: Receiver<ProbeMsg>,
// Sending end, cloned into every spawned worker thread.
tx: Sender<ProbeMsg>,
// The tracked ephemeral worker, if any, kept so it can be cancelled and joined.
active_ephemeral: Option<EphemeralWorker>,
}
/// Bookkeeping for the worker thread running an `EphemeralTest` probe.
struct EphemeralWorker {
// Step the worker was spawned for.
step: StepId,
// Generation of `step` at spawn time; identifies which run this worker belongs to.
generation: u64,
// Cancellation flag shared with the worker. The worker checks it before running `work`
// and again before sending its result.
cancel: Arc<AtomicBool>,
// Joined when the worker is cancelled, or after its result has been applied.
handle: JoinHandle<()>,
}
impl Clone for ProbeRegistry {
    /// Copies the per-step generations and statuses into a registry that has its own, empty
    /// result channel and no tracked ephemeral worker.
    ///
    /// Threads spawned by the source registry keep sending to the source's channel, so the
    /// clone never receives their results. NOTE(review): a step copied as
    /// `ProbeStatus::Running` therefore stays `Running` in the clone unless something resets
    /// it there. Confirm this is intended.
    fn clone(&self) -> Self {
        let (fresh_tx, fresh_rx) = channel();
        let status = self.status.clone();
        Self {
            generations: self.generations,
            status,
            rx: fresh_rx,
            tx: fresh_tx,
            active_ephemeral: None,
        }
    }
}
impl ProbeRegistry {
    /// Creates a registry with every step at generation 0 and [`ProbeStatus::Idle`], a fresh
    /// result channel, and no ephemeral worker.
    pub(crate) fn new() -> Self {
        let (tx, rx) = channel();
        Self {
            generations: [0; StepId::COUNT],
            status: std::array::from_fn(|_| ProbeStatus::Idle),
            rx,
            tx,
            active_ephemeral: None,
        }
    }

    /// Returns the id a result for `step` must carry to be accepted by [`Self::apply`].
    pub(crate) fn current(&self, step: StepId) -> ProbeId {
        ProbeId { step, generation: self.generations[step.index()] }
    }

    /// Invalidates every outstanding probe for `step`.
    ///
    /// First, any ephemeral worker running for `step` is cancelled and joined; this blocks
    /// until its thread exits. Then the step's generation is advanced, so late results from
    /// earlier runs are rejected by [`Self::apply`]. Finally, the step's status is reset to
    /// [`ProbeStatus::Idle`].
    pub(crate) fn bump(&mut self, step: StepId) {
        self.cancel_active_ephemeral_for_step(step);
        self.generations[step.index()] += 1;
        self.status[step.index()] = ProbeStatus::Idle;
    }

    /// Stores `msg` as the step's [`ProbeStatus::Done`] result, if it belongs to the step's
    /// current generation.
    ///
    /// Returns `true` if the status was updated, or `false` if the message was stale and
    /// dropped. Accepting an `EphemeralTest` result also joins the tracked worker that sent it.
    pub(crate) fn apply(&mut self, msg: ProbeMsg) -> bool {
        let ProbeMsg { probe_id, kind, result } = msg;
        if probe_id != self.current(probe_id.step) {
            return false;
        }
        self.status[probe_id.step.index()] = ProbeStatus::Done { kind, result };
        if kind == ProbeKind::EphemeralTest {
            // Sending its result is the worker's last action, so this join does not wait long.
            self.join_completed_ephemeral(probe_id.step, probe_id.generation);
        }
        true
    }

    /// Returns the current probe status of `step`.
    pub(crate) fn status(&self, step: StepId) -> &ProbeStatus {
        &self.status[step.index()]
    }

    /// Runs `work` on a background thread as a probe of kind `kind` for `step`.
    ///
    /// The step is bumped, which invalidates earlier probes, and then marked
    /// [`ProbeStatus::Running`]. The result is delivered over the registry's channel and picked
    /// up by [`Self::poll`]. `EphemeralTest` probes are delegated to [`Self::spawn_ephemeral`]
    /// so they can be cancelled and joined. Every other kind runs on a detached thread that is
    /// never joined.
    pub(crate) fn spawn(
        &mut self,
        step: StepId,
        kind: ProbeKind,
        work: impl FnOnce() -> CheckResult + Send + 'static,
    ) {
        if kind == ProbeKind::EphemeralTest {
            // A plain `work` closure has no use for the cancellation flag.
            self.spawn_ephemeral(step, move |_| work());
            return;
        }
        self.bump(step);
        self.status[step.index()] = ProbeStatus::Running(kind);
        let probe_id = self.current(step);
        let tx = self.tx.clone();
        thread::spawn(move || {
            let result = work();
            // Sending only fails once the registry (and its receiver) is gone; nobody is left
            // to care about the result then.
            let _ = tx.send(ProbeMsg { probe_id, kind, result });
        });
    }

    /// Runs an `EphemeralTest` probe for `step` on a tracked, cancellable worker thread.
    ///
    /// `work` receives the worker's cancellation flag and may poll it to stop early. If
    /// cancellation has been requested by the time `work` returns, its result is not sent.
    ///
    /// Only one ephemeral worker is tracked at a time. Any worker still running, for any step,
    /// is therefore cancelled and joined first (blocking until it exits), and its step is
    /// reset to [`ProbeStatus::Idle`].
    pub(crate) fn spawn_ephemeral(
        &mut self,
        step: StepId,
        work: impl FnOnce(Arc<AtomicBool>) -> CheckResult + Send + 'static,
    ) {
        // Overwriting `active_ephemeral` below would drop the previous worker's `JoinHandle`.
        // That would detach a thread that was never signalled and that `on_quit` could no
        // longer wait for. `bump` only cancels a worker for the *same* step, so cancel
        // whatever is tracked here.
        self.cancel_active_ephemeral();
        self.bump(step);
        self.status[step.index()] = ProbeStatus::Running(ProbeKind::EphemeralTest);
        let probe_id = self.current(step);
        let tx = self.tx.clone();
        let cancel = Arc::new(AtomicBool::new(false));
        let worker_cancel = Arc::clone(&cancel);
        let handle = thread::spawn(move || {
            if worker_cancel.load(Ordering::Acquire) {
                return;
            }
            let result = work(Arc::clone(&worker_cancel));
            // A cancelled run's result is discarded even if `work` ran to completion.
            if !worker_cancel.load(Ordering::Acquire) {
                let _ = tx.send(ProbeMsg { probe_id, kind: ProbeKind::EphemeralTest, result });
            }
        });
        self.active_ephemeral =
            Some(EphemeralWorker { step, generation: probe_id.generation, cancel, handle });
    }

    /// Drains every pending worker message without blocking, applying each one.
    ///
    /// Returns `(step, kind, result)` for each message that was accepted, in arrival order.
    /// Stale messages are dropped silently.
    pub(crate) fn poll(&mut self) -> Vec<(StepId, ProbeKind, CheckResult)> {
        let mut applied = Vec::new();
        while let Ok(msg) = self.rx.try_recv() {
            let (step, kind) = (msg.probe_id.step, msg.kind);
            if !self.apply(msg) {
                continue;
            }
            // `apply` just stored the result as `Done`. Cloning it from there means stale
            // messages are dropped without an extra clone.
            if let ProbeStatus::Done { result, .. } = &self.status[step.index()] {
                applied.push((step, kind, result.clone()));
            }
        }
        applied
    }

    /// Cancels outstanding probe work and resets every running step to [`ProbeStatus::Idle`].
    ///
    /// The tracked ephemeral worker is cancelled and joined. Other probe threads are detached
    /// and are not waited for. Generations are left unchanged.
    pub(crate) fn on_quit(&mut self) {
        self.cancel_active_ephemeral();
        for status in &mut self.status {
            if let ProbeStatus::Running(kind) = status {
                match kind {
                    ProbeKind::Download
                    | ProbeKind::ConnectTest
                    | ProbeKind::OracleTool
                    | ProbeKind::VersionCheck => {},
                    ProbeKind::EphemeralTest => {
                        // Defensive: covers an ephemeral step still marked running after the
                        // cancel above.
                        rag_rat_core::index::ai::abort_active_provisioning();
                    },
                }
                *status = ProbeStatus::Idle;
            }
        }
    }

    /// Cancels and joins the tracked ephemeral worker, but only if it belongs to `step`.
    fn cancel_active_ephemeral_for_step(&mut self, step: StepId) {
        if self.active_ephemeral.as_ref().is_some_and(|worker| worker.step == step) {
            self.cancel_active_ephemeral();
        }
    }

    /// Cancels and joins the tracked ephemeral worker, if any.
    ///
    /// Sets the worker's cancel flag, calls `abort_active_provisioning` (presumably so a
    /// worker blocked in provisioning returns promptly; confirm), then blocks until the thread
    /// exits. The worker's step is reset to [`ProbeStatus::Idle`] only if that step has not
    /// been bumped since the worker was spawned.
    fn cancel_active_ephemeral(&mut self) {
        let Some(worker) = self.active_ephemeral.take() else { return };
        worker.cancel.store(true, Ordering::Release);
        rag_rat_core::index::ai::abort_active_provisioning();
        let _ = worker.handle.join();
        if self.generations[worker.step.index()] == worker.generation {
            self.status[worker.step.index()] = ProbeStatus::Idle;
        }
    }

    /// Joins the tracked ephemeral worker if it produced the result for (`step`,
    /// `generation`); otherwise leaves it tracked.
    fn join_completed_ephemeral(&mut self, step: StepId, generation: u64) {
        let Some(worker) = self.active_ephemeral.take() else { return };
        if worker.step == step && worker.generation == generation {
            let _ = worker.handle.join();
        } else {
            self.active_ephemeral = Some(worker);
        }
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// A result carrying an outdated generation is rejected; one carrying the current
    /// generation is accepted.
    #[test]
    fn stale_probe_result_is_dropped() {
        let mut registry = ProbeRegistry::new();
        let before_bump = registry.current(StepId::Embedding);
        registry.bump(StepId::Embedding);

        let stale_accepted = registry.apply(ProbeMsg {
            probe_id: before_bump,
            kind: ProbeKind::ConnectTest,
            result: CheckResult::ok(),
        });
        assert!(!stale_accepted);
        assert!(matches!(registry.status(StepId::Embedding), ProbeStatus::Idle));

        let after_bump = registry.current(StepId::Embedding);
        let fresh_accepted = registry.apply(ProbeMsg {
            probe_id: after_bump,
            kind: ProbeKind::ConnectTest,
            result: CheckResult::ok(),
        });
        assert!(fresh_accepted);
    }

    /// A spawned non-ephemeral probe eventually shows up as `Done` after polling.
    #[test]
    fn spawn_then_poll_applies_result() {
        let mut registry = ProbeRegistry::new();
        registry.spawn(StepId::Integration, ProbeKind::VersionCheck, CheckResult::ok);
        for _attempt in 0..100 {
            registry.poll();
            let finished =
                matches!(registry.status(StepId::Integration), ProbeStatus::Done { .. });
            if finished {
                break;
            }
            std::thread::sleep(Duration::from_millis(5));
        }
        assert!(matches!(
            registry.status(StepId::Integration),
            ProbeStatus::Done { kind: ProbeKind::VersionCheck, .. }
        ));
    }

    /// Bumping a step while its ephemeral worker runs leaves the step idle and yields no
    /// result from the cancelled worker.
    #[test]
    fn bump_cancels_active_ephemeral_before_clearing_status() {
        let mut registry = ProbeRegistry::new();
        let (started_tx, started_rx) = channel();
        registry.spawn(StepId::Embedding, ProbeKind::EphemeralTest, move || {
            let _ = started_tx.send(());
            std::thread::sleep(Duration::from_millis(20));
            CheckResult::ok()
        });
        started_rx.recv().unwrap();
        assert!(matches!(
            registry.status(StepId::Embedding),
            ProbeStatus::Running(ProbeKind::EphemeralTest)
        ));

        registry.bump(StepId::Embedding);

        assert!(matches!(registry.status(StepId::Embedding), ProbeStatus::Idle));
        assert!(registry.poll().is_empty());
    }

    /// `bump` raises the worker's cancel flag before joining it, so a worker that spins on the
    /// flag is able to exit.
    #[test]
    fn bump_signals_active_ephemeral_worker_before_joining() {
        let mut registry = ProbeRegistry::new();
        let (started_tx, started_rx) = channel();
        let observed_cancel = Arc::new(AtomicBool::new(false));
        let worker_observed_cancel = Arc::clone(&observed_cancel);
        registry.spawn_ephemeral(StepId::Embedding, move |cancel| {
            let _ = started_tx.send(());
            while !cancel.load(Ordering::Acquire) {
                std::thread::sleep(Duration::from_millis(1));
            }
            worker_observed_cancel.store(true, Ordering::Release);
            CheckResult::ok()
        });
        started_rx.recv().unwrap();

        registry.bump(StepId::Embedding);

        assert!(observed_cancel.load(Ordering::Acquire));
        assert!(matches!(registry.status(StepId::Embedding), ProbeStatus::Idle));
        assert!(registry.poll().is_empty());
    }

    /// `on_quit` blocks until the tracked ephemeral worker has finished.
    #[test]
    fn on_quit_waits_for_active_ephemeral_worker() {
        let mut registry = ProbeRegistry::new();
        let (started_tx, started_rx) = channel();
        let finished = Arc::new(AtomicBool::new(false));
        let worker_finished = Arc::clone(&finished);
        registry.spawn(StepId::Embedding, ProbeKind::EphemeralTest, move || {
            let _ = started_tx.send(());
            std::thread::sleep(Duration::from_millis(20));
            worker_finished.store(true, Ordering::Release);
            CheckResult::ok()
        });
        started_rx.recv().unwrap();

        registry.on_quit();

        assert!(finished.load(Ordering::Acquire));
        assert!(matches!(registry.status(StepId::Embedding), ProbeStatus::Idle));
    }
}