use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, SystemTime};
use tokio::sync::{mpsc, Semaphore};
use uuid::Uuid;
use crate::real_time_embedding_pipeline::{
config::PipelineConfig,
types::{UpdateOperation, UpdateBatch, CoordinationState, NodeStatus},
traits::ProcessingPriority,
};
/// Tunables controlling how update operations are coordinated across the
/// worker pool (and, when enabled, across nodes).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinationConfig {
    /// Maximum number of update operations processed concurrently;
    /// sizes the internal `Semaphore` in `UpdateCoordinator::new`.
    pub max_concurrent_updates: usize,
    /// Batching window, copied from `PipelineConfig::batch_timeout_ms`.
    /// NOTE(review): not read anywhere in this file — presumably a flush
    /// window for batch assembly elsewhere; confirm before relying on it.
    pub batch_timeout: Duration,
    /// Interval between heartbeat refreshes of the node's coordination state.
    pub heartbeat_interval: Duration,
    /// How long a node may go without a heartbeat before being treated as
    /// failed. NOTE(review): not enforced in this file — confirm intent.
    pub failure_timeout: Duration,
    /// Upper bound on queued updates.
    /// NOTE(review): not enforced here — the update channel is unbounded.
    pub max_queue_size: usize,
    /// Whether multi-node (distributed) coordination is enabled.
    pub enable_distributed: bool,
}
impl Default for CoordinationConfig {
    /// Conservative single-node defaults: ten in-flight updates, a 100 ms
    /// batching window, 5 s heartbeats, a 30 s failure timeout, a
    /// 10 000-entry queue cap, and distributed coordination disabled.
    fn default() -> Self {
        let batch_window = Duration::from_millis(100);
        let heartbeat = Duration::from_secs(5);
        let failure_window = Duration::from_secs(30);
        Self {
            max_concurrent_updates: 10,
            batch_timeout: batch_window,
            heartbeat_interval: heartbeat,
            failure_timeout: failure_window,
            max_queue_size: 10_000,
            enable_distributed: false,
        }
    }
}
/// Coordinates the flow of update operations through a bounded worker pool,
/// tracking simple counters and per-node coordination state.
pub struct UpdateCoordinator {
    // Effective coordination settings, derived from the pipeline config.
    config: CoordinationConfig,
    // Bounds concurrent update processing to `max_concurrent_updates`.
    semaphore: Arc<Semaphore>,
    // Channel into the dispatch task; `Some(..)` only while running.
    update_sender: Arc<RwLock<Option<mpsc::UnboundedSender<UpdateOperation>>>>,
    // Shared per-node coordination state (heartbeat, load, status).
    state: Arc<RwLock<CoordinationState>>,
    // Flag flipped by start()/stop().
    is_running: AtomicBool,
    // Successfully processed operations.
    // NOTE(review): clone_for_task() copies these atomics by value, so the
    // background worker task updates its own copies — the counters on the
    // original instance never reflect worker activity. Sharing them would
    // require Arc-wrapped fields.
    processed_count: AtomicU64,
    // Operations that returned an error.
    failed_count: AtomicU64,
    // Operations currently being executed.
    active_workers: AtomicU64,
}
impl UpdateCoordinator {
    /// Builds a coordinator from the pipeline-level configuration.
    ///
    /// Concurrency and batch-timeout settings are taken from
    /// `pipeline_config`; everything else uses `CoordinationConfig::default()`.
    /// A fresh node id is generated and, in this single-node mode, the node
    /// starts out as its own leader.
    pub fn new(pipeline_config: &PipelineConfig) -> Result<Self> {
        let config = CoordinationConfig {
            max_concurrent_updates: pipeline_config.max_concurrent_updates,
            batch_timeout: Duration::from_millis(pipeline_config.batch_timeout_ms),
            ..Default::default()
        };
        let node_id = Uuid::new_v4().to_string();
        let state = Arc::new(RwLock::new(CoordinationState {
            node_id: node_id.clone(),
            // Single-node mode: this node is trivially the leader.
            leader_id: Some(node_id),
            status: NodeStatus::Active,
            last_heartbeat: SystemTime::now(),
            load_factor: 0.0,
            active_tasks: 0,
        }));
        Ok(Self {
            semaphore: Arc::new(Semaphore::new(config.max_concurrent_updates)),
            config,
            update_sender: Arc::new(RwLock::new(None)),
            state,
            is_running: AtomicBool::new(false),
            processed_count: AtomicU64::new(0),
            failed_count: AtomicU64::new(0),
            active_workers: AtomicU64::new(0),
        })
    }

    /// Starts the dispatch and heartbeat background tasks.
    ///
    /// # Errors
    /// Returns an error if the coordinator is already running or the sender
    /// lock is poisoned.
    pub async fn start(&self) -> Result<()> {
        // compare_exchange makes the "already running" check atomic; the
        // previous load-then-store sequence allowed two concurrent start()
        // calls to both succeed.
        if self
            .is_running
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            return Err(anyhow::anyhow!("Update coordinator is already running"));
        }
        let (sender, mut receiver) = mpsc::unbounded_channel::<UpdateOperation>();
        {
            let mut sender_guard = match self.update_sender.write() {
                Ok(guard) => guard,
                Err(_) => {
                    // Roll the flag back so a failed start doesn't leave the
                    // coordinator permanently "running".
                    self.is_running.store(false, Ordering::Release);
                    return Err(anyhow::anyhow!("Failed to acquire sender lock"));
                }
            };
            *sender_guard = Some(sender);
        }
        // NOTE(review): clone_for_task() snapshots the atomic counters, so
        // the worker below increments its own copies — get_statistics() on
        // this instance will not see them. Fixing that requires Arc-wrapped
        // counter fields (a struct change).
        let coordinator = self.clone_for_task();
        tokio::spawn(async move {
            // Drains until stop() drops the sender, which closes the channel.
            while let Some(operation) = receiver.recv().await {
                if let Err(e) = coordinator.process_update_operation(operation).await {
                    eprintln!("Error processing update operation: {}", e);
                    coordinator.failed_count.fetch_add(1, Ordering::Relaxed);
                }
            }
        });
        self.start_heartbeat_task().await;
        Ok(())
    }

    /// Stops the coordinator: closes the update channel, waits for in-flight
    /// work on this instance, and marks the node as `Leaving` (which also
    /// terminates the heartbeat task).
    ///
    /// # Errors
    /// Returns an error if the sender or state lock is poisoned.
    pub async fn stop(&self) -> Result<()> {
        self.is_running.store(false, Ordering::Release);
        {
            let mut sender_guard = self.update_sender.write()
                .map_err(|_| anyhow::anyhow!("Failed to acquire sender lock"))?;
            // Dropping the sender closes the channel; the dispatch task's
            // recv() then returns None and the task exits.
            *sender_guard = None;
        }
        // NOTE(review): the worker task operates on a clone_for_task() copy,
        // so this instance's active_workers is never incremented by it and
        // this loop returns immediately. Kept for the direct-call path.
        while self.active_workers.load(Ordering::Acquire) > 0 {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        {
            let mut state = self.state.write()
                .map_err(|_| anyhow::anyhow!("Failed to acquire state lock"))?;
            // Leaving the Active state is also the heartbeat task's stop signal.
            state.status = NodeStatus::Leaving;
        }
        Ok(())
    }

    /// Enqueues a single update operation for asynchronous processing.
    ///
    /// # Errors
    /// Returns an error if the coordinator is not running, not initialized,
    /// or the channel/lock is unavailable.
    pub async fn submit_update(&self, operation: UpdateOperation) -> Result<()> {
        if !self.is_running.load(Ordering::Acquire) {
            return Err(anyhow::anyhow!("Update coordinator is not running"));
        }
        let sender_guard = self.update_sender.read()
            .map_err(|_| anyhow::anyhow!("Failed to acquire sender lock"))?;
        if let Some(sender) = sender_guard.as_ref() {
            sender.send(operation)
                .map_err(|_| anyhow::anyhow!("Failed to send update operation"))?;
        } else {
            return Err(anyhow::anyhow!("Update coordinator not properly initialized"));
        }
        Ok(())
    }

    /// Enqueues every operation in `batch`, stopping at the first failure.
    pub async fn submit_batch(&self, batch: UpdateBatch) -> Result<()> {
        for operation in batch.operations {
            self.submit_update(operation).await?;
        }
        Ok(())
    }

    /// Returns a point-in-time snapshot of the coordinator's counters.
    /// See the NOTE on `start()` regarding counter visibility.
    pub fn get_statistics(&self) -> CoordinationStatistics {
        CoordinationStatistics {
            processed_count: self.processed_count.load(Ordering::Acquire),
            failed_count: self.failed_count.load(Ordering::Acquire),
            active_workers: self.active_workers.load(Ordering::Acquire) as usize,
            queue_size: self.get_queue_size(),
            is_running: self.is_running.load(Ordering::Acquire),
        }
    }

    /// Returns a clone of the current coordination state.
    ///
    /// # Errors
    /// Returns an error if the state lock is poisoned.
    pub fn get_state(&self) -> Result<CoordinationState> {
        let state = self.state.read()
            .map_err(|_| anyhow::anyhow!("Failed to acquire state lock"))?;
        Ok(state.clone())
    }

    /// Whether `start()` has been called without a subsequent `stop()`.
    pub fn is_running(&self) -> bool {
        self.is_running.load(Ordering::Acquire)
    }

    /// Builds a copy of this coordinator for use inside a spawned task.
    ///
    /// Arc-backed fields (semaphore, sender, state) are shared, but the
    /// atomic counters and the running flag are copied BY VALUE: the clone
    /// and the original diverge from this point on. Callers must not expect
    /// counter updates on the clone to be visible on the original.
    fn clone_for_task(&self) -> Self {
        Self {
            config: self.config.clone(),
            semaphore: self.semaphore.clone(),
            update_sender: self.update_sender.clone(),
            state: self.state.clone(),
            is_running: AtomicBool::new(self.is_running.load(Ordering::Acquire)),
            processed_count: AtomicU64::new(self.processed_count.load(Ordering::Acquire)),
            failed_count: AtomicU64::new(self.failed_count.load(Ordering::Acquire)),
            active_workers: AtomicU64::new(self.active_workers.load(Ordering::Acquire)),
        }
    }

    /// Executes one operation under the concurrency semaphore, maintaining
    /// the active-worker and processed counters. Failure counting is done by
    /// the caller (the dispatch loop in `start()`).
    async fn process_update_operation(&self, operation: UpdateOperation) -> Result<()> {
        let permit = self.semaphore.acquire().await
            .map_err(|_| anyhow::anyhow!("Failed to acquire processing permit"))?;
        self.active_workers.fetch_add(1, Ordering::Relaxed);
        let result = self.execute_update_operation(operation).await;
        self.active_workers.fetch_sub(1, Ordering::Relaxed);
        if result.is_ok() {
            self.processed_count.fetch_add(1, Ordering::Relaxed);
        }
        drop(permit);
        result
    }

    /// Placeholder execution of a single operation: simulates work with a
    /// short sleep and logs the outcome. Batches are executed sequentially.
    async fn execute_update_operation(&self, operation: UpdateOperation) -> Result<()> {
        match operation {
            UpdateOperation::Insert { id, .. } => {
                tokio::time::sleep(Duration::from_millis(1)).await;
                println!("Processed insert for id: {}", id);
            }
            UpdateOperation::Update { id, version, .. } => {
                tokio::time::sleep(Duration::from_millis(1)).await;
                println!("Processed update for id: {} (version: {:?})", id, version);
            }
            UpdateOperation::Delete { id } => {
                tokio::time::sleep(Duration::from_millis(1)).await;
                println!("Processed delete for id: {}", id);
            }
            UpdateOperation::Batch { operations } => {
                for op in operations {
                    // Async recursion requires a boxed future; a direct
                    // `self.execute_update_operation(op).await` would make
                    // this future infinitely sized and fail to compile.
                    Box::pin(self.execute_update_operation(op)).await?;
                }
            }
        }
        Ok(())
    }

    /// Spawns a background task that refreshes this node's heartbeat while
    /// its coordination state is `Active`.
    ///
    /// `AtomicBool` is not `Clone`, so `is_running` cannot move into the
    /// task; instead the loop exits once `stop()` flips the node status away
    /// from `Active` (or the state lock is poisoned).
    async fn start_heartbeat_task(&self) {
        let state = self.state.clone();
        let heartbeat_interval = self.config.heartbeat_interval;
        tokio::spawn(async move {
            loop {
                // Scope the guard so it is dropped before the await below —
                // a std RwLock guard must not be held across an await point.
                match state.write() {
                    Ok(mut state_guard) => {
                        if !matches!(state_guard.status, NodeStatus::Active) {
                            break;
                        }
                        state_guard.last_heartbeat = SystemTime::now();
                        // Placeholder load metric until real measurement exists.
                        state_guard.load_factor = 0.5;
                    }
                    Err(_) => break,
                }
                tokio::time::sleep(heartbeat_interval).await;
            }
        });
    }

    /// Pending-queue length. The unbounded sender exposes no cheap length,
    /// so this is a hard-coded placeholder.
    fn get_queue_size(&self) -> usize {
        0
    }
}
/// Point-in-time snapshot of coordinator counters, as returned by
/// `UpdateCoordinator::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinationStatistics {
    /// Number of operations that completed successfully.
    pub processed_count: u64,
    /// Number of operations that returned an error.
    pub failed_count: u64,
    /// Operations currently being executed.
    pub active_workers: usize,
    /// Pending queue length (currently hard-coded to 0 — see
    /// `UpdateCoordinator::get_queue_size`).
    pub queue_size: usize,
    /// Whether the coordinator was running when the snapshot was taken.
    pub is_running: bool,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::real_time_embedding_pipeline::config::PipelineConfig;

    /// Construction from a default pipeline config must succeed.
    #[tokio::test]
    async fn test_coordinator_creation() {
        let pipeline_config = PipelineConfig::default();
        assert!(UpdateCoordinator::new(&pipeline_config).is_ok());
    }

    /// The running flag must track start()/stop() transitions.
    #[tokio::test]
    async fn test_coordinator_start_stop() -> Result<()> {
        let pipeline_config = PipelineConfig::default();
        let coordinator = UpdateCoordinator::new(&pipeline_config)?;
        assert!(!coordinator.is_running());
        assert!(coordinator.start().await.is_ok());
        assert!(coordinator.is_running());
        assert!(coordinator.stop().await.is_ok());
        Ok(())
    }

    /// A single insert submitted to a running coordinator is accepted.
    #[tokio::test]
    async fn test_update_submission() -> Result<()> {
        let pipeline_config = PipelineConfig::default();
        let coordinator = UpdateCoordinator::new(&pipeline_config)?;
        coordinator.start().await?;
        let insert = UpdateOperation::Insert {
            id: "test_id".to_string(),
            content: "test_content".to_string(),
        };
        assert!(coordinator.submit_update(insert).await.is_ok());
        coordinator.stop().await?;
        Ok(())
    }
}