use parking_lot::RwLock;
use pgrx::prelude::*;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::OnceLock;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// Computes a normalised global minimum cut of an undirected weighted graph
/// with the Stoer–Wagner algorithm (dense adjacency matrix).
///
/// # Arguments
/// * `num_nodes` - number of vertices; valid vertex ids are `0..num_nodes`.
/// * `edges` - `(u, v, weight)` triples. Parallel edges accumulate. Edges with
///   an endpoint outside `0..num_nodes` are ignored entirely: they contribute
///   neither to the graph nor to the normalisation total.
///
/// # Returns
/// * `f64::INFINITY` when `num_nodes <= 1` or `edges` is empty (there is
///   nothing to cut).
/// * Otherwise `min_cut / total_weight`, capped at `1.0`, where `total_weight`
///   is the summed weight of the edges that were actually added to the graph.
/// * `1.0` when that total is not strictly positive.
///
/// Weights are assumed to be finite and non-negative; other inputs are not
/// validated.
pub fn stoer_wagner_mincut(num_nodes: usize, edges: &[(usize, usize, f64)]) -> f64 {
    if num_nodes <= 1 || edges.is_empty() {
        return f64::INFINITY;
    }

    // Build the symmetric adjacency matrix and, in the same pass, the total
    // weight of the edges that really belong to the graph. Summing over all of
    // `edges` instead would let skipped (out-of-range) edges deflate the ratio.
    let mut adj = vec![vec![0.0; num_nodes]; num_nodes];
    let mut total_weight = 0.0;
    for &(u, v, w) in edges {
        if u < num_nodes && v < num_nodes {
            adj[u][v] += w;
            adj[v][u] += w;
            total_weight += w;
        }
    }

    let mut active = vec![true; num_nodes];
    let mut min_cut = f64::INFINITY;

    // Each phase merges two vertices, so exactly `num_nodes - 1` phases run.
    for remaining in (2..=num_nodes).rev() {
        let (cut_weight, s, t) = minimum_cut_phase(&adj, &active, remaining);
        min_cut = min_cut.min(cut_weight);

        // Contract `t` into `s`: every other active vertex inherits t's edges.
        for i in 0..num_nodes {
            if active[i] && i != s && i != t {
                adj[s][i] += adj[t][i];
                adj[i][s] += adj[i][t];
            }
        }
        active[t] = false;
    }

    if total_weight > 0.0 {
        (min_cut / total_weight).min(1.0)
    } else {
        1.0
    }
}

/// Runs one Stoer–Wagner "minimum cut phase" over the active vertices.
///
/// Starting from the first active vertex, repeatedly adds the not-yet-added
/// active vertex with the largest accumulated connection weight to the set
/// built so far.
///
/// Returns `(cut_of_the_phase, s, t)` where `t` is the vertex added last, `s`
/// the one added just before it, and `cut_of_the_phase` the accumulated weight
/// of `t` at the moment it was added.
///
/// `_remaining` is accepted for signature compatibility but is not used; the
/// number of iterations is derived from `active` itself.
fn minimum_cut_phase(adj: &[Vec<f64>], active: &[bool], _remaining: usize) -> (f64, usize, usize) {
    let n = adj.len();
    let start = active.iter().position(|&a| a).unwrap_or(0);
    let active_count = active.iter().filter(|&&a| a).count();

    let mut in_a = vec![false; n];
    let mut cut_weight = vec![0.0; n];
    let mut last = start;
    let mut second_last = start;
    let mut last_cut = 0.0;

    for _ in 0..active_count {
        // Pick the most tightly connected remaining vertex. On the first
        // iteration all weights are zero, so this selects `start`.
        let mut max_weight = f64::NEG_INFINITY;
        let mut max_vertex = start;
        for i in 0..n {
            if active[i] && !in_a[i] && cut_weight[i] > max_weight {
                max_weight = cut_weight[i];
                max_vertex = i;
            }
        }

        in_a[max_vertex] = true;
        second_last = last;
        last = max_vertex;
        last_cut = cut_weight[max_vertex];

        // Account for the newly added vertex in everyone else's weight.
        for i in 0..n {
            if active[i] && !in_a[i] {
                cut_weight[i] += adj[max_vertex][i];
            }
        }
    }

    (last_cut, second_last, last)
}
/// Integrity level assigned to a collection.
///
/// Variants are declared in order of increasing severity, and the derived
/// `Ord` follows that declaration order. With `rename_all = "snake_case"`
/// serde (de)serialises the variants as `"normal"`, `"stress"`, `"critical"`
/// and `"emergency"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum IntegrityStateType {
/// Discriminant 0; the state a fresh `IntegrityStateRecord` starts in.
Normal = 0,
/// Discriminant 1.
Stress = 1,
/// Discriminant 2.
Critical = 2,
/// Discriminant 3; the most severe level.
Emergency = 3,
}
impl std::fmt::Display for IntegrityStateType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
IntegrityStateType::Normal => write!(f, "normal"),
IntegrityStateType::Stress => write!(f, "stress"),
IntegrityStateType::Critical => write!(f, "critical"),
IntegrityStateType::Emergency => write!(f, "emergency"),
}
}
}
impl IntegrityStateType {
pub fn from_lambda(lambda_cut: f64, threshold_high: f64, threshold_low: f64) -> Self {
if lambda_cut >= threshold_high {
IntegrityStateType::Normal
} else if lambda_cut >= threshold_low {
IntegrityStateType::Stress
} else if lambda_cut >= threshold_low / 2.0 {
IntegrityStateType::Critical
} else {
IntegrityStateType::Emergency
}
}
}
// Re-exports `IntegrityStateType` under the shorter name `IntegrityState`.
pub use IntegrityStateType as IntegrityState;
/// Tunable settings for integrity monitoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntegrityConfig {
/// Seconds the worker waits between sampling passes.
pub sample_interval_secs: u64,
/// Graph rebuild interval in seconds. Not read anywhere in the visible
/// source — NOTE(review): confirm where this is consumed.
pub graph_rebuild_interval_secs: u64,
/// Min-cut values at or above this are classified as `Normal`.
pub threshold_high: f64,
/// Min-cut values at or above this (but below `threshold_high`) are
/// classified as `Stress`; half of this value separates `Critical` from
/// `Emergency`.
pub threshold_low: f64,
/// Not read anywhere in the visible source — presumably a cap on retained
/// events; verify against its consumer.
pub max_events: usize,
/// When true, every sample logs its computed min-cut value.
pub verbose: bool,
}
impl Default for IntegrityConfig {
fn default() -> Self {
Self {
sample_interval_secs: 60,
graph_rebuild_interval_secs: 3600,
threshold_high: 0.7,
threshold_low: 0.3,
max_events: 10000,
verbose: false,
}
}
}
/// Process-wide integrity configuration, created lazily on first access and
/// guarded by a read/write lock.
static INTEGRITY_CONFIG: OnceLock<RwLock<IntegrityConfig>> = OnceLock::new();
/// Returns a copy of the current global integrity configuration, initialising
/// it with `IntegrityConfig::default()` on first use.
pub fn get_integrity_config() -> IntegrityConfig {
    let shared = INTEGRITY_CONFIG.get_or_init(|| RwLock::new(IntegrityConfig::default()));
    let guard = shared.read();
    IntegrityConfig::clone(&guard)
}
/// Replaces the global integrity configuration with `config`.
///
/// The global is initialised with the default first if it does not exist yet,
/// then overwritten under the write lock.
pub fn set_integrity_config(config: IntegrityConfig) {
    let shared = INTEGRITY_CONFIG.get_or_init(|| RwLock::new(IntegrityConfig::default()));
    let mut slot = shared.write();
    *slot = config;
}
/// Latest integrity reading and running counters for a single collection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntegrityStateRecord {
/// Identifier of the collection this record describes.
pub collection_id: i32,
/// Current classified state.
pub state: IntegrityStateType,
/// Most recent min-cut value; starts at 1.0.
pub lambda_cut: f64,
/// Unix time (seconds) of the most recent sample.
pub last_sample_ts: u64,
/// Unix time (seconds) at which `state` last changed.
pub last_state_change_ts: u64,
/// Number of samples applied to this record.
pub sample_count: u64,
/// Number of times `state` has changed.
pub state_change_count: u64,
}
impl IntegrityStateRecord {
    /// Creates a record for `collection_id` in the `Normal` state with a
    /// min-cut of 1.0, zeroed counters, and both timestamps set to the current
    /// Unix time in seconds (0 if the system clock is before the epoch).
    pub fn new(collection_id: i32) -> Self {
        let created_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs());

        Self {
            collection_id,
            state: IntegrityStateType::Normal,
            lambda_cut: 1.0,
            last_sample_ts: created_at,
            last_state_change_ts: created_at,
            sample_count: 0,
            state_change_count: 0,
        }
    }

    /// Applies a new min-cut sample.
    ///
    /// Always stores `lambda_cut`, refreshes `last_sample_ts` and bumps
    /// `sample_count`. If the value classifies (via the thresholds in `config`)
    /// to a different state than the current one, the state is switched,
    /// `last_state_change_ts` is refreshed and `state_change_count` is bumped.
    pub fn update_from_mincut(&mut self, lambda_cut: f64, config: &IntegrityConfig) {
        let observed_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs());

        // Per-sample bookkeeping happens regardless of the classification.
        self.lambda_cut = lambda_cut;
        self.last_sample_ts = observed_at;
        self.sample_count += 1;

        let classified =
            IntegrityStateType::from_lambda(lambda_cut, config.threshold_high, config.threshold_low);
        if classified == self.state {
            return;
        }

        self.state = classified;
        self.last_state_change_ts = observed_at;
        self.state_change_count += 1;
    }
}
/// A weighted edge between two node indices.
///
/// Not referenced by any other code in the visible source; the min-cut routine
/// takes plain `(usize, usize, f64)` tuples instead.
#[derive(Debug, Clone)]
pub struct GraphEdge {
/// Index of the first endpoint.
pub source: usize,
/// Index of the second endpoint.
pub target: usize,
/// Edge weight.
pub weight: f64,
}
/// Builds a synthetic edge list over `num_nodes` vertices.
///
/// The result is a path `0-1-…-(n-1)` with weight 1.0 per edge, followed (only
/// when `num_nodes > 3`) by weight-0.5 chords from every third vertex `i` to
/// `(i + 2) % num_nodes`.
fn build_sample_edges(num_nodes: usize) -> Vec<(usize, usize, f64)> {
    let mut edges: Vec<(usize, usize, f64)> = (1..num_nodes)
        .map(|next| (next - 1, next, 1.0))
        .collect();

    if num_nodes <= 3 {
        return edges;
    }

    let chords = (0..num_nodes)
        .step_by(3)
        .map(|from| (from, (from + 2) % num_nodes))
        .filter(|&(from, to)| from != to)
        .map(|(from, to)| (from, to, 0.5));
    edges.extend(chords);

    edges
}
/// Periodically samples registered collections and tracks their integrity
/// state.
pub struct IntegrityWorker {
/// Identifier used in log messages and in `stats()`.
worker_id: u64,
/// Configuration snapshot taken when the worker is constructed.
config: IntegrityConfig,
/// Run flag: set by `run()`, cleared by `stop()`.
running: AtomicBool,
/// Ids of the collections to sample.
collections: RwLock<Vec<i32>>,
/// Created empty and never read or written in the visible source —
/// NOTE(review): presumably reserved for graph rebuild scheduling.
last_rebuild: RwLock<std::collections::HashMap<i32, Instant>>,
/// Per-collection integrity records, keyed by collection id.
states: RwLock<std::collections::HashMap<i32, IntegrityStateRecord>>,
}
impl IntegrityWorker {
    /// Creates a stopped worker with no registered collections, capturing the
    /// global integrity configuration as it is at this moment.
    pub fn new(worker_id: u64) -> Self {
        let config = get_integrity_config();
        Self {
            worker_id,
            config,
            running: AtomicBool::new(false),
            collections: RwLock::new(Vec::new()),
            last_rebuild: RwLock::new(std::collections::HashMap::new()),
            states: RwLock::new(std::collections::HashMap::new()),
        }
    }

    /// Adds `collection_id` to the sampled set and creates a fresh state record
    /// for it. Does nothing if the collection is already registered.
    pub fn register_collection(&self, collection_id: i32) {
        let mut known = self.collections.write();
        if known.contains(&collection_id) {
            return;
        }
        known.push(collection_id);
        self.states
            .write()
            .insert(collection_id, IntegrityStateRecord::new(collection_id));
    }

    /// Removes `collection_id` from the sampled set and drops its state record.
    pub fn unregister_collection(&self, collection_id: i32) {
        let mut known = self.collections.write();
        known.retain(|id| *id != collection_id);
        self.states.write().remove(&collection_id);
    }

    /// Takes one integrity sample for `collection_id` and updates its record,
    /// creating the record if it does not exist.
    ///
    /// The min-cut is computed over a fixed 10-node synthetic graph from
    /// `build_sample_edges`, independent of `collection_id`. A state transition
    /// is logged. As written, this always returns `Ok(())`.
    fn sample_collection(&self, collection_id: i32) -> Result<(), String> {
        let num_nodes = 10;
        let lambda_cut = stoer_wagner_mincut(num_nodes, &build_sample_edges(num_nodes));

        if self.config.verbose {
            pgrx::log!("Collection {}: lambda_cut={:.4}", collection_id, lambda_cut);
        }

        let mut states = self.states.write();
        let record = states
            .entry(collection_id)
            .or_insert_with(|| IntegrityStateRecord::new(collection_id));

        let before = record.state;
        record.update_from_mincut(lambda_cut, &self.config);
        let after = record.state;

        if after != before {
            pgrx::log!(
                "Integrity state change for collection {}: {} -> {} (lambda={:.4})",
                collection_id,
                before,
                after,
                lambda_cut
            );
        }

        Ok(())
    }

    /// Runs the sampling loop on the calling thread until `stop()` is called.
    ///
    /// Each pass samples a snapshot of the registered collections, then waits
    /// `sample_interval_secs`, polling the run flag every 100 ms so a stop
    /// request is honoured promptly.
    pub fn run(&self) {
        self.running.store(true, Ordering::SeqCst);
        pgrx::log!("Integrity worker {} started", self.worker_id);

        let sample_interval = Duration::from_secs(self.config.sample_interval_secs);
        let poll_step = Duration::from_millis(100);

        while self.running.load(Ordering::SeqCst) {
            // Snapshot the ids so the lock is not held while sampling.
            let snapshot: Vec<i32> = self.collections.read().clone();
            for collection_id in snapshot {
                if !self.running.load(Ordering::SeqCst) {
                    break;
                }
                if let Err(e) = self.sample_collection(collection_id) {
                    pgrx::warning!("Failed to sample collection {}: {}", collection_id, e);
                }
            }

            let deadline = Instant::now() + sample_interval;
            loop {
                if Instant::now() >= deadline || !self.running.load(Ordering::SeqCst) {
                    break;
                }
                std::thread::sleep(poll_step);
            }
        }

        pgrx::log!("Integrity worker {} stopped", self.worker_id);
    }

    /// Clears the run flag so that `run()` exits.
    pub fn stop(&self) {
        self.running.store(false, Ordering::SeqCst);
    }

    /// Returns the current value of the run flag.
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::SeqCst)
    }

    /// Returns a copy of the record for `collection_id`, if one exists.
    pub fn get_state(&self, collection_id: i32) -> Option<IntegrityStateRecord> {
        let states = self.states.read();
        states.get(&collection_id).cloned()
    }

    /// Returns a copy of every state record, keyed by collection id.
    pub fn get_all_states(&self) -> std::collections::HashMap<i32, IntegrityStateRecord> {
        let states = self.states.read();
        (*states).clone()
    }

    /// Builds a JSON summary of the worker: id, run flag, per-collection state
    /// and the sampling interval and thresholds in effect.
    pub fn stats(&self) -> serde_json::Value {
        let states = self.states.read();

        let mut collections = Vec::with_capacity(states.len());
        for (id, record) in states.iter() {
            collections.push(serde_json::json!({
                "collection_id": id,
                "state": record.state.to_string(),
                "lambda_cut": record.lambda_cut,
                "sample_count": record.sample_count,
                "state_change_count": record.state_change_count,
            }));
        }

        serde_json::json!({
            "worker_id": self.worker_id,
            "running": self.is_running(),
            "collection_count": states.len(),
            "collections": collections,
            "config": {
                "sample_interval_secs": self.config.sample_interval_secs,
                "threshold_high": self.config.threshold_high,
                "threshold_low": self.config.threshold_low,
            }
        })
    }
}
/// Process-wide integrity worker, created lazily on first access.
static INTEGRITY_WORKER: OnceLock<IntegrityWorker> = OnceLock::new();
/// Returns the process-wide integrity worker, constructing it with worker id 1
/// on first use.
pub fn get_integrity_worker() -> &'static IntegrityWorker {
    const DEFAULT_WORKER_ID: u64 = 1;
    INTEGRITY_WORKER.get_or_init(|| IntegrityWorker::new(DEFAULT_WORKER_ID))
}
/// SQL-callable: returns the global worker's `stats()` summary as JSONB.
#[pg_extern]
pub fn ruvector_integrity_worker_status() -> pgrx::JsonB {
    let summary = get_integrity_worker().stats();
    pgrx::JsonB(summary)
}
/// SQL-callable: registers `collection_id` with the global worker and returns
/// a JSONB acknowledgement.
#[pg_extern]
pub fn ruvector_integrity_register(collection_id: i32) -> pgrx::JsonB {
    get_integrity_worker().register_collection(collection_id);

    let response = serde_json::json!({
        "success": true,
        "collection_id": collection_id,
        "registered": true,
    });
    pgrx::JsonB(response)
}
/// SQL-callable: removes `collection_id` from the global worker and returns a
/// JSONB acknowledgement.
#[pg_extern]
pub fn ruvector_integrity_unregister(collection_id: i32) -> pgrx::JsonB {
    get_integrity_worker().unregister_collection(collection_id);

    let response = serde_json::json!({
        "success": true,
        "collection_id": collection_id,
        "registered": false,
    });
    pgrx::JsonB(response)
}
/// SQL-callable: takes one immediate integrity sample for `collection_id`.
///
/// The collection is registered first if it is not already. On success the
/// JSONB result carries the collection's state, min-cut and sample count
/// (`"state"` is null if no record is found); on failure it carries the error
/// message.
#[pg_extern]
pub fn ruvector_integrity_sample(collection_id: i32) -> pgrx::JsonB {
    let worker = get_integrity_worker();
    worker.register_collection(collection_id);

    if let Err(e) = worker.sample_collection(collection_id) {
        return pgrx::JsonB(serde_json::json!({
            "success": false,
            "error": e,
        }));
    }

    let state = match worker.get_state(collection_id) {
        Some(record) => serde_json::json!({
            "state": record.state.to_string(),
            "lambda_cut": record.lambda_cut,
            "sample_count": record.sample_count,
        }),
        None => serde_json::Value::Null,
    };

    pgrx::JsonB(serde_json::json!({
        "success": true,
        "collection_id": collection_id,
        "state": state,
    }))
}
/// Unit tests for configuration defaults, state classification, the sample
/// graph builder, the min-cut routine and worker registration.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_integrity_config_default() {
        let config = IntegrityConfig::default();
        assert_eq!(config.sample_interval_secs, 60);
        assert!((config.threshold_high - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_integrity_state_type() {
        assert_eq!(
            IntegrityStateType::from_lambda(0.8, 0.7, 0.3),
            IntegrityStateType::Normal
        );
        assert_eq!(
            IntegrityStateType::from_lambda(0.5, 0.7, 0.3),
            IntegrityStateType::Stress
        );
        assert_eq!(
            IntegrityStateType::from_lambda(0.2, 0.7, 0.3),
            IntegrityStateType::Critical
        );
        // Below half of the low threshold is the most severe level.
        assert_eq!(
            IntegrityStateType::from_lambda(0.1, 0.7, 0.3),
            IntegrityStateType::Emergency
        );
    }

    #[test]
    fn test_integrity_state_display() {
        assert_eq!(IntegrityStateType::Normal.to_string(), "normal");
        assert_eq!(IntegrityStateType::Emergency.to_string(), "emergency");
    }

    #[test]
    fn test_integrity_state_record() {
        let state = IntegrityStateRecord::new(1);
        assert_eq!(state.collection_id, 1);
        assert_eq!(state.state, IntegrityStateType::Normal);
        assert_eq!(state.sample_count, 0);
    }

    #[test]
    fn test_build_sample_edges() {
        let edges = build_sample_edges(5);
        assert!(!edges.is_empty());
        // Four path edges plus chords starting at vertices 0 and 3.
        assert_eq!(edges.len(), 6);
        assert!(build_sample_edges(0).is_empty());
    }

    #[test]
    fn test_stoer_wagner_mincut() {
        // Path 0-1-2 with unit weights: min cut 1 out of total weight 2.
        let edges = [(0, 1, 1.0), (1, 2, 1.0)];
        let cut = stoer_wagner_mincut(3, &edges);
        assert!((cut - 0.5).abs() < 1e-12);

        // Trivial inputs have nothing to cut.
        assert!(stoer_wagner_mincut(1, &edges).is_infinite());
        assert!(stoer_wagner_mincut(3, &[]).is_infinite());
    }

    #[test]
    fn test_integrity_worker_registration() {
        let worker = IntegrityWorker::new(1);
        worker.register_collection(42);
        assert!(worker.get_state(42).is_some());
        worker.unregister_collection(42);
        assert!(worker.get_state(42).is_none());
    }
}