use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use crate::base::{Broker, ServerState};
use crate::proto::{ServerInfo, TaskMessage, WorkerInfo};
/// Static description of this server process that is attached to every heartbeat.
///
/// The values are copied into a `ServerInfo` (see the `From` impl below) and into
/// each `WorkerInfo` built by `Heartbeat::beat`.
#[derive(Debug, Clone)]
pub struct HeartbeatMeta {
/// Host name reported as `ServerInfo::host` / `WorkerInfo::host`.
pub host: String,
/// Process id reported as `ServerInfo::pid` / `WorkerInfo::pid`.
pub pid: i32,
/// Unique id of this server instance; reported as `server_id`.
pub server_uuid: String,
/// Configured concurrency; reported as `ServerInfo::concurrency`.
pub concurrency: usize,
/// Copied verbatim into `ServerInfo::queues`.
/// NOTE(review): presumably queue name -> priority weight; confirm against the broker.
pub queues: HashMap<String, i32>,
/// Copied verbatim into `ServerInfo::strict_priority`.
pub strict_priority: bool,
/// Converted to a protobuf timestamp and reported as `ServerInfo::start_time`.
pub started: SystemTime,
/// Optional tenant passed (as `Option<&str>`) to the broker's
/// `write_server_state` and `clear_server_state` calls.
pub acl_tenant: Option<String>,
}
/// Builds the `ServerInfo` published by a heartbeat from the static server
/// metadata plus the current number of active workers.
///
/// The status is always reported as `ServerState::Active`.
impl From<(&HeartbeatMeta, i32)> for ServerInfo {
    fn from((meta, active_worker_count): (&HeartbeatMeta, i32)) -> Self {
        ServerInfo {
            host: meta.host.clone(),
            pid: meta.pid,
            server_id: meta.server_uuid.clone(),
            // A plain `as i32` cast wraps silently when the value exceeds
            // `i32::MAX`; saturate so an oversized setting never shows up as a
            // negative or unrelated number.
            concurrency: i32::try_from(meta.concurrency).unwrap_or(i32::MAX),
            queues: meta.queues.clone(),
            strict_priority: meta.strict_priority,
            status: ServerState::Active.as_str().to_string(),
            start_time: Some(prost_types::Timestamp::from(meta.started)),
            active_worker_count,
        }
    }
}
/// Bookkeeping record for one task that a worker is currently processing.
///
/// Entries are stored in `Heartbeat::workers` keyed by `msg.id`.
#[derive(Clone)]
pub struct WorkerInfoEntry {
/// The task being processed; its `id`, `type`, `payload` and `queue` are
/// copied into the `WorkerInfo` sent with each heartbeat.
pub msg: TaskMessage,
/// Reported as `WorkerInfo::start_time`.
pub started: SystemTime,
/// Reported as `WorkerInfo::deadline`.
pub deadline: SystemTime,
}
/// Notification sent to the heartbeat task about a worker's lifecycle.
// The lint is silenced because `Started` carries a whole `WorkerInfoEntry`
// while `Finished` only carries a `String`.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum WorkerEvent {
/// A worker began processing the contained task; the heartbeat task inserts
/// the entry into its worker map under `msg.id`.
Started(WorkerInfoEntry),
/// The task with this id is no longer being processed; the heartbeat task
/// removes the matching entry from its worker map.
Finished(String),
}
/// Periodically writes this server's state (and its active workers) to the broker.
pub struct Heartbeat {
/// Broker used for `write_server_state` and `clear_server_state`.
broker: Arc<dyn Broker>,
/// Period of the heartbeat ticker; `interval * 2` is also passed to
/// `write_server_state` (presumably as the record's TTL -- confirm against the broker).
interval: Duration,
/// Static server description attached to every heartbeat.
meta: HeartbeatMeta,
/// Currently running tasks, keyed by task id.
workers: RwLock<HashMap<String, WorkerInfoEntry>>,
/// Set by `shutdown`; checked by the heartbeat loop on every tick.
shutting_down: Arc<AtomicBool>,
/// Receiving end of the worker-event channel. It is taken (left as `None`)
/// by the first `start` call, which is how duplicate starts are detected.
event_rx: Mutex<Option<mpsc::Receiver<WorkerEvent>>>,
}
/// Cloneable handle used to report worker lifecycle events to a `Heartbeat`.
#[derive(Clone)]
pub struct WorkerEventSender {
/// Sending end of the channel whose receiver is owned by the `Heartbeat`.
event_tx: mpsc::Sender<WorkerEvent>,
}
impl WorkerEventSender {
pub async fn send_started(
&self,
entry: WorkerInfoEntry,
) -> Result<(), mpsc::error::SendError<WorkerEvent>> {
self.event_tx.send(WorkerEvent::Started(entry)).await
}
pub async fn send_finished(
&self,
task_id: String,
) -> Result<(), mpsc::error::SendError<WorkerEvent>> {
self.event_tx.send(WorkerEvent::Finished(task_id)).await
}
}
impl Heartbeat {
    /// Capacity of the bounded worker-event channel created by [`Heartbeat::new`].
    const EVENT_CHANNEL_CAPACITY: usize = 100;

    /// Creates a heartbeat together with the sender workers use to report events.
    ///
    /// Nothing runs until [`Heartbeat::start`] is called.
    ///
    /// * `broker` - backend that receives the server state.
    /// * `interval` - period between heartbeats.
    /// * `meta` - static description of this server process.
    pub fn new(
        broker: Arc<dyn Broker>,
        interval: Duration,
        meta: HeartbeatMeta,
    ) -> (Self, WorkerEventSender) {
        let (event_tx, event_rx) = mpsc::channel::<WorkerEvent>(Self::EVENT_CHANNEL_CAPACITY);
        let heartbeat = Self {
            broker,
            interval,
            meta,
            workers: Default::default(),
            shutting_down: Arc::new(AtomicBool::new(false)),
            event_rx: Mutex::new(Some(event_rx)),
        };
        (heartbeat, WorkerEventSender { event_tx })
    }

    /// Spawns the heartbeat task and returns its join handle.
    ///
    /// The task writes the server state on every tick of `interval`, keeps the
    /// worker map in sync with incoming [`WorkerEvent`]s, and clears the server
    /// state from the broker once it observes a shutdown request.
    ///
    /// Only the first call does any work: a later call finds the event receiver
    /// already taken, logs a warning and its task returns immediately.
    ///
    /// The shutdown flag is only checked when the ticker fires, so the task may
    /// keep running for up to one `interval` after [`Heartbeat::shutdown`].
    ///
    /// NOTE: `tokio::time::interval` panics for a zero period; with a zero
    /// `interval` that panic happens inside the spawned task.
    pub fn start(self: Arc<Self>) -> JoinHandle<()> {
        tracing::info!("starting heartbeat");
        let this = self;
        tokio::spawn(async move {
            // Taking the receiver out of the `Option` marks the heartbeat as started.
            let Some(mut event_rx) = this.event_rx.lock().await.take() else {
                tracing::warn!("Heartbeat already started, skipping duplicate start");
                return;
            };

            let mut ticker = tokio::time::interval(this.interval);
            loop {
                tokio::select! {
                    _ = ticker.tick() => {
                        if this.shutting_down.load(Ordering::Relaxed) {
                            break;
                        }
                        this.beat().await;
                    }
                    // Once every sender is dropped `recv` yields `None`, the
                    // pattern stops matching and only the ticker branch remains.
                    Some(event) = event_rx.recv() => {
                        match event {
                            WorkerEvent::Started(worker_info) => {
                                let task_id = worker_info.msg.id.clone();
                                this.workers.write().await.insert(task_id.clone(), worker_info);
                                tracing::debug!("Worker started: task_id={}", task_id);
                            }
                            WorkerEvent::Finished(task_id) => {
                                this.workers.write().await.remove(&task_id);
                                tracing::debug!("Worker finished: task_id={}", task_id);
                            }
                        }
                    }
                }
            }

            // Best-effort cleanup: a failure must not abort shutdown, but it is
            // logged (instead of silently dropped) so a stale server record in
            // the broker can be diagnosed.
            if let Err(e) = this
                .broker
                .clear_server_state(
                    &this.meta.host,
                    this.meta.pid,
                    &this.meta.server_uuid,
                    this.meta.acl_tenant.as_deref(),
                )
                .await
            {
                tracing::warn!("Heartbeat failed to clear server state: {}", e);
            }
        })
    }

    /// Writes one heartbeat: the server info plus one `WorkerInfo` per active task.
    ///
    /// Write failures are logged and otherwise ignored so the loop keeps running.
    async fn beat(&self) {
        // Snapshot the workers inside a narrow scope so the read lock is
        // released before the broker call is awaited.
        let (worker_infos, active_worker_count) = {
            let workers = self.workers.read().await;
            let infos: Vec<WorkerInfo> = workers
                .values()
                .map(|w| WorkerInfo {
                    host: self.meta.host.clone(),
                    pid: self.meta.pid,
                    server_id: self.meta.server_uuid.clone(),
                    task_id: w.msg.id.clone(),
                    task_type: w.msg.r#type.clone(),
                    task_payload: w.msg.payload.clone(),
                    queue: w.msg.queue.clone(),
                    start_time: Some(prost_types::Timestamp::from(w.started)),
                    deadline: Some(prost_types::Timestamp::from(w.deadline)),
                })
                .collect();
            // Saturate rather than wrap if the count ever exceeds `i32::MAX`.
            let count = i32::try_from(workers.len()).unwrap_or(i32::MAX);
            (infos, count)
        };

        let info: ServerInfo = (&self.meta, active_worker_count).into();
        if let Err(e) = self
            .broker
            .write_server_state(
                &info,
                worker_infos,
                self.interval * 2,
                self.meta.acl_tenant.as_deref(),
            )
            .await
        {
            tracing::warn!("Heartbeat write failed: {}", e);
        }
    }

    /// Requests shutdown; the running task notices the flag on its next tick.
    pub fn shutdown(&self) {
        self.shutting_down.store(true, Ordering::Relaxed);
    }

    /// Returns `true` once shutdown has been requested.
    ///
    /// This reads the same flag that [`Heartbeat::shutdown`] sets, so it does
    /// not by itself indicate that the spawned task has already exited.
    pub fn is_done(&self) -> bool {
        self.shutting_down.load(Ordering::Relaxed)
    }

    /// Returns the number of tasks currently tracked as active.
    pub async fn active_worker_count(&self) -> usize {
        self.workers.read().await.len()
    }
}
impl crate::components::ComponentLifecycle for Heartbeat {
fn start(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
Heartbeat::start(self)
}
fn shutdown(&self) {
Heartbeat::shutdown(self)
}
fn is_done(&self) -> bool {
Heartbeat::is_done(self)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TaskMessage` with fixed test values and the given id.
    fn create_test_task_message(id: &str) -> TaskMessage {
        TaskMessage {
            id: id.to_string(),
            r#type: "test:task".to_string(),
            payload: b"test payload".to_vec(),
            queue: "default".to_string(),
            retry: 3,
            retried: 0,
            timeout: 3600,
            deadline: 0,
            ..Default::default()
        }
    }

    /// A `WorkerInfoEntry` exposes the fields of the message it wraps.
    #[test]
    fn test_worker_info_entry_creation() {
        let now = SystemTime::now();
        let entry = WorkerInfoEntry {
            msg: create_test_task_message("task-1"),
            started: now,
            deadline: now + Duration::from_secs(3600),
        };

        assert_eq!(entry.msg.id, "task-1");
        assert_eq!(entry.msg.r#type, "test:task");
        assert_eq!(entry.msg.queue, "default");
    }

    /// `HeartbeatMeta` keeps the values it was constructed with.
    #[test]
    fn test_heartbeat_meta_creation() {
        let meta = HeartbeatMeta {
            host: "localhost".to_string(),
            pid: 1234,
            server_uuid: "test-uuid".to_string(),
            concurrency: 10,
            queues: HashMap::from([
                ("default".to_string(), 1),
                ("critical".to_string(), 6),
            ]),
            strict_priority: false,
            started: SystemTime::now(),
            acl_tenant: None,
        };

        assert_eq!(meta.host, "localhost");
        assert_eq!(meta.pid, 1234);
        assert_eq!(meta.concurrency, 10);
        assert!(!meta.strict_priority);
        assert_eq!(meta.queues.len(), 2);
    }

    /// Converting `(&HeartbeatMeta, count)` copies every field into `ServerInfo`.
    #[test]
    fn test_server_info_from_heartbeat_meta() {
        let meta = HeartbeatMeta {
            host: "test-host".to_string(),
            pid: 5678,
            server_uuid: "server-uuid".to_string(),
            concurrency: 5,
            queues: HashMap::from([("default".to_string(), 1)]),
            strict_priority: true,
            started: SystemTime::now(),
            acl_tenant: None,
        };

        let server_info = ServerInfo::from((&meta, 3));

        assert_eq!(server_info.host, "test-host");
        assert_eq!(server_info.pid, 5678);
        assert_eq!(server_info.server_id, "server-uuid");
        assert_eq!(server_info.concurrency, 5);
        assert!(server_info.strict_priority);
        assert_eq!(server_info.active_worker_count, 3);
        assert_eq!(server_info.status, "active");
    }

    /// Started and Finished events arrive on the channel in the order sent.
    #[tokio::test]
    async fn test_worker_event_sender() {
        let (event_tx, mut event_rx) = mpsc::channel::<WorkerEvent>(10);
        let sender = WorkerEventSender { event_tx };

        let entry = WorkerInfoEntry {
            msg: create_test_task_message("task-123"),
            started: SystemTime::now(),
            deadline: SystemTime::now() + Duration::from_secs(3600),
        };
        sender.send_started(entry).await.unwrap();

        let WorkerEvent::Started(received) = event_rx.recv().await.unwrap() else {
            panic!("Expected Started event");
        };
        assert_eq!(received.msg.id, "task-123");

        sender.send_finished("task-456".to_string()).await.unwrap();

        let WorkerEvent::Finished(task_id) = event_rx.recv().await.unwrap() else {
            panic!("Expected Finished event");
        };
        assert_eq!(task_id, "task-456");
    }

    /// Clones of a sender feed the same channel.
    #[tokio::test]
    async fn test_worker_event_sender_clone() {
        let (event_tx, mut event_rx) = mpsc::channel::<WorkerEvent>(10);
        let original = WorkerEventSender { event_tx };
        let cloned = original.clone();

        original.send_finished("task-1".to_string()).await.unwrap();
        cloned.send_finished("task-2".to_string()).await.unwrap();

        let first = event_rx.recv().await.unwrap();
        let second = event_rx.recv().await.unwrap();

        let (WorkerEvent::Finished(id1), WorkerEvent::Finished(id2)) = (first, second) else {
            panic!("Expected two Finished events");
        };
        assert_eq!(id1, "task-1");
        assert_eq!(id2, "task-2");
    }
}