use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use zeromq::{Socket, SocketSend};
use super::monitor::{AgentState, AgentStatus, SpawnSession, SpawnStats};
/// Topic strings that prefix every published frame.
///
/// `send_message` writes each frame as `"<topic> <json payload>"`, so a
/// subscriber can select one message kind by subscribing to one of these
/// prefixes.
pub mod topics {
/// Full session snapshots (`SessionSnapshot`).
pub const SESSION: &str = "session";
/// Per-agent status changes (`AgentUpdate`).
pub const AGENT: &str = "agent";
/// Captured output lines (`OutputMessage`).
pub const OUTPUT: &str = "output";
/// Wave/task scheduling state (`WaveUpdate`).
pub const WAVE: &str = "wave";
/// Aggregate agent counters (`StatsSnapshot`).
pub const STATS: &str = "stats";
/// Periodic liveness frames (`Heartbeat`) emitted by the publisher loop.
pub const HEARTBEAT: &str = "heartbeat";
}
/// Full point-in-time view of a spawn session, published on `topics::SESSION`.
///
/// Field order is also the JSON key order produced by `serde_json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionSnapshot {
/// Session name. The publisher loop also copies it into subsequent heartbeats.
pub session_name: String,
/// Tag copied from the session by `session_to_snapshot`.
pub tag: String,
/// Terminal value copied from `SpawnSession::terminal` (meaning defined there).
pub terminal: String,
/// Creation time copied verbatim from `SpawnSession::created_at`.
pub created_at: String,
/// Working directory copied from `SpawnSession::working_dir`.
pub working_dir: String,
/// One snapshot per agent in the session.
pub agents: Vec<AgentSnapshot>,
/// Aggregate counters derived from the same session.
pub stats: StatsSnapshot,
/// Build time; `session_to_snapshot` sets it to the current UTC time in RFC 3339.
pub timestamp: String,
}
/// Per-agent view embedded in `SessionSnapshot::agents`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentSnapshot {
/// Identifier of the task the agent works on.
pub task_id: String,
/// Human-readable task title.
pub task_title: String,
/// Window name copied from `AgentState::window_name`.
pub window_name: String,
/// When built via `From<&AgentState>`, the lowercased `Debug` rendering of
/// the agent's `AgentStatus`.
pub status: String,
/// Start time copied verbatim from `AgentState::started_at`.
pub started_at: String,
/// Tag copied from the agent state.
pub tag: String,
}
impl From<&AgentState> for AgentSnapshot {
fn from(agent: &AgentState) -> Self {
Self {
task_id: agent.task_id.clone(),
task_title: agent.task_title.clone(),
window_name: agent.window_name.clone(),
status: format!("{:?}", agent.status).to_lowercase(),
started_at: agent.started_at.clone(),
tag: agent.tag.clone(),
}
}
}
/// Status change for a single agent, published on `topics::AGENT`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentUpdate {
/// Task the agent is working on.
pub task_id: String,
/// New status. `create_agent_update` stores the lowercased `Debug` name.
pub status: String,
/// Prior status in the same format, or `None` when the caller supplied none.
pub previous_status: Option<String>,
/// Build time; `create_agent_update` uses the current UTC time in RFC 3339.
pub timestamp: String,
}
/// Batch of captured output lines for a task, published on `topics::OUTPUT`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputMessage {
/// Task the output belongs to.
pub task_id: String,
/// The captured lines, in the order supplied by the caller.
pub lines: Vec<String>,
/// Equals `lines.len()` when built with `create_output_message`.
pub line_count: usize,
/// Build time; `create_output_message` uses the current UTC time in RFC 3339.
pub timestamp: String,
}
/// Snapshot of the task waves, published on `topics::WAVE`.
///
/// NOTE(review): nothing in this file constructs a `WaveUpdate`. The counter
/// semantics below are inferred from field names; confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WaveUpdate {
/// Waves in the order supplied by the producer.
pub waves: Vec<WaveSnapshot>,
/// Presumably the number of tasks ready to run.
pub ready_count: usize,
/// Presumably the number of tasks currently running.
pub running_count: usize,
/// Presumably the number of finished tasks.
pub done_count: usize,
/// Presumably the number of tasks waiting on dependencies.
pub blocked_count: usize,
/// Build time; format chosen by the producer (other messages here use RFC 3339).
pub timestamp: String,
}
/// One wave of tasks inside a `WaveUpdate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WaveSnapshot {
/// Wave number; whether it is 0- or 1-based is up to the producer (TODO confirm).
pub number: usize,
/// Tasks belonging to this wave.
pub tasks: Vec<TaskSnapshot>,
}
/// Minimal description of one task inside a `WaveSnapshot`.
///
/// NOTE(review): the producer defines the allowed `state` strings and the
/// `complexity` scale; neither is fixed in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskSnapshot {
/// Task identifier.
pub id: String,
/// Human-readable task title.
pub title: String,
/// Task state as a string.
pub state: String,
/// Complexity estimate; units/scale unspecified here.
pub complexity: u32,
}
/// Aggregate agent counters for a session.
///
/// Published on `topics::STATS` and embedded in `SessionSnapshot::stats`. When
/// built via `From<&SpawnStats>`, every counter is copied verbatim.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatsSnapshot {
/// Session the counters belong to.
pub session_name: String,
/// Session tag.
pub tag: String,
/// Copied from `SpawnStats::total_agents`.
pub total_agents: usize,
/// Copied from `SpawnStats::starting`.
pub starting: usize,
/// Copied from `SpawnStats::running`.
pub running: usize,
/// Copied from `SpawnStats::completed`.
pub completed: usize,
/// Copied from `SpawnStats::failed`.
pub failed: usize,
/// Build time; the `From<&SpawnStats>` conversion uses the current UTC time in RFC 3339.
pub timestamp: String,
}
impl From<&SpawnStats> for StatsSnapshot {
fn from(stats: &SpawnStats) -> Self {
Self {
session_name: stats.session_name.clone(),
tag: stats.tag.clone(),
total_agents: stats.total_agents,
starting: stats.starting,
running: stats.running,
completed: stats.completed,
failed: stats.failed,
timestamp: chrono::Utc::now().to_rfc3339(),
}
}
}
/// Liveness frame published on `topics::HEARTBEAT`.
///
/// Field meanings below describe how the publisher loop (`run_publisher`)
/// populates it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Heartbeat {
/// Name from the most recent `FeedMessage::Session`; empty until one is seen.
pub session_name: String,
/// Whole seconds since the publisher loop started.
pub uptime_secs: u64,
/// Non-heartbeat messages sent successfully so far.
pub message_count: u64,
/// Build time in RFC 3339, UTC.
pub timestamp: String,
}
/// Commands sent to the publisher task over its mpsc channel.
///
/// Every variant except `Shutdown` is serialized to JSON and published under
/// its matching topic (see `serialize_message`).
#[derive(Debug)]
pub enum FeedMessage {
/// Published on `topics::SESSION`. It also sets the session name reported
/// in later heartbeats.
Session(SessionSnapshot),
/// Published on `topics::AGENT`.
AgentUpdate(AgentUpdate),
/// Published on `topics::OUTPUT`.
Output(OutputMessage),
/// Published on `topics::WAVE`.
WaveUpdate(WaveUpdate),
/// Published on `topics::STATS`.
Stats(StatsSnapshot),
/// Published on `topics::HEARTBEAT` when sent explicitly through the channel.
Heartbeat(Heartbeat),
/// Stops the publisher loop; nothing is published for it.
Shutdown,
}
/// Settings for the ZMQ publisher started by `start_feed` / `start_feed_sync`.
#[derive(Debug, Clone)]
pub struct FeedConfig {
    /// ZMQ endpoint the PUB socket binds to, e.g. `tcp://*:5555` or `ipc:///tmp/x.sock`.
    pub endpoint: String,
    /// Whether periodic heartbeat frames are published.
    pub heartbeat: bool,
    /// Seconds between heartbeat frames.
    pub heartbeat_interval_secs: u64,
}

impl Default for FeedConfig {
    /// TCP port 5555 on every interface, with heartbeats every 5 seconds.
    fn default() -> Self {
        FeedConfig {
            endpoint: String::from("tcp://*:5555"),
            heartbeat: true,
            heartbeat_interval_secs: 5,
        }
    }
}

impl FeedConfig {
    /// Default settings, but bound to `endpoint`.
    fn with_endpoint(endpoint: String) -> Self {
        FeedConfig {
            endpoint,
            ..FeedConfig::default()
        }
    }

    /// Binds TCP `port` on all interfaces (`tcp://*:<port>`).
    pub fn tcp(port: u16) -> Self {
        Self::with_endpoint(format!("tcp://*:{port}"))
    }

    /// Binds the IPC socket at filesystem `path` (`ipc://<path>`).
    pub fn ipc(path: &str) -> Self {
        Self::with_endpoint(format!("ipc://{path}"))
    }

    /// Uses `endpoint` verbatim as the ZMQ endpoint string.
    pub fn from_endpoint(endpoint: &str) -> Self {
        Self::with_endpoint(endpoint.to_owned())
    }
}
/// Cloneable async handle that queues messages for the publisher task.
///
/// Each `publish_*` call waits for channel capacity. If the publisher has
/// already stopped, the message is silently dropped: publishing is best-effort.
#[derive(Clone)]
pub struct FeedHandle {
    tx: mpsc::Sender<FeedMessage>,
}

impl FeedHandle {
    /// Queues `msg`, deliberately ignoring the error returned once the
    /// receiving side has gone away.
    async fn enqueue(&self, msg: FeedMessage) {
        let _ = self.tx.send(msg).await;
    }

    /// Queues a full session snapshot.
    pub async fn publish_session(&self, snapshot: SessionSnapshot) {
        self.enqueue(FeedMessage::Session(snapshot)).await
    }

    /// Queues a single agent status change.
    pub async fn publish_agent_update(&self, update: AgentUpdate) {
        self.enqueue(FeedMessage::AgentUpdate(update)).await
    }

    /// Queues a batch of output lines.
    pub async fn publish_output(&self, output: OutputMessage) {
        self.enqueue(FeedMessage::Output(output)).await
    }

    /// Queues a wave/task state update.
    pub async fn publish_wave_update(&self, update: WaveUpdate) {
        self.enqueue(FeedMessage::WaveUpdate(update)).await
    }

    /// Queues aggregate agent counters.
    pub async fn publish_stats(&self, stats: StatsSnapshot) {
        self.enqueue(FeedMessage::Stats(stats)).await
    }

    /// Asks the publisher loop to stop after draining earlier messages.
    pub async fn shutdown(&self) {
        self.enqueue(FeedMessage::Shutdown).await
    }
}
/// Blocking counterpart of `FeedHandle` for synchronous callers.
///
/// Each `publish_*` call blocks the current thread on `runtime` until the
/// message is queued, then ignores the result, so a closed channel drops the
/// message. Like `tokio::runtime::Handle::block_on`, these methods panic if
/// called from within an async execution context. `try_publish_output` never
/// blocks.
#[derive(Clone)]
pub struct FeedHandleSync {
    tx: mpsc::Sender<FeedMessage>,
    runtime: tokio::runtime::Handle,
}

impl FeedHandleSync {
    /// Wraps a sender plus the runtime handle used to drive blocking sends.
    pub fn new(tx: mpsc::Sender<FeedMessage>, runtime: tokio::runtime::Handle) -> Self {
        Self { tx, runtime }
    }

    /// Blocks until `msg` is queued. A send error means the publisher has
    /// stopped; the message is then dropped on purpose (best-effort feed).
    fn send_blocking(&self, msg: FeedMessage) {
        // `Handle::block_on` does not require a `'static` future, so the sender
        // can be borrowed rather than cloned for every message.
        let _ = self.runtime.block_on(self.tx.send(msg));
    }

    /// Queues a full session snapshot, blocking while the channel is full.
    pub fn publish_session(&self, snapshot: SessionSnapshot) {
        self.send_blocking(FeedMessage::Session(snapshot));
    }

    /// Queues a single agent status change, blocking while the channel is full.
    pub fn publish_agent_update(&self, update: AgentUpdate) {
        self.send_blocking(FeedMessage::AgentUpdate(update));
    }

    /// Queues a batch of output lines, blocking while the channel is full.
    pub fn publish_output(&self, output: OutputMessage) {
        self.send_blocking(FeedMessage::Output(output));
    }

    /// Queues a wave/task state update, blocking while the channel is full.
    pub fn publish_wave_update(&self, update: WaveUpdate) {
        self.send_blocking(FeedMessage::WaveUpdate(update));
    }

    /// Queues aggregate agent counters, blocking while the channel is full.
    pub fn publish_stats(&self, stats: StatsSnapshot) {
        self.send_blocking(FeedMessage::Stats(stats));
    }

    /// Non-blocking variant of `publish_output`: drops `output` if the channel
    /// is full or closed.
    pub fn try_publish_output(&self, output: OutputMessage) {
        let _ = self.tx.try_send(FeedMessage::Output(output));
    }

    /// Returns a clone of the underlying sender for use from async code.
    pub fn sender(&self) -> mpsc::Sender<FeedMessage> {
        self.tx.clone()
    }
}
/// Binds a ZMQ PUB socket to `config.endpoint` and spawns the publisher loop
/// on the current Tokio runtime.
///
/// Returns a [`FeedHandle`] backed by a channel of capacity 1000, plus the
/// endpoint string exactly as configured (not a resolved bound address).
///
/// # Errors
/// Fails if the socket cannot be bound to `config.endpoint`.
///
/// # Panics
/// `tokio::spawn` panics if this future is not being driven by a Tokio runtime.
pub async fn start_feed(config: FeedConfig) -> Result<(FeedHandle, String)> {
    let (tx, rx) = mpsc::channel::<FeedMessage>(1000);
    let mut socket = zeromq::PubSocket::new();
    socket
        .bind(&config.endpoint)
        .await
        // Lazy: only format the message when binding actually fails.
        .with_context(|| format!("Failed to bind ZMQ socket to {}", config.endpoint))?;
    let endpoint = config.endpoint.clone();
    // Detached: the loop ends on `Shutdown` or when every sender is dropped.
    tokio::spawn(run_publisher(socket, rx, config));
    Ok((FeedHandle { tx }, endpoint))
}
/// Synchronous variant of `start_feed` for callers without a Tokio runtime.
///
/// Builds a dedicated single-worker runtime and binds the PUB socket on it. It
/// then runs the publisher loop on a background OS thread that owns the runtime.
/// Returns a [`FeedHandleSync`] backed by a channel of capacity 1000, plus the
/// endpoint string exactly as configured.
///
/// # Errors
/// Fails if the runtime cannot be built, the socket cannot be bound, or the
/// publisher thread cannot be spawned.
///
/// # Panics
/// Panics if called from within an async execution context, because it blocks
/// on the new runtime while binding the socket.
pub fn start_feed_sync(config: FeedConfig) -> Result<(FeedHandleSync, String)> {
    let rt = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .context("Failed to create tokio runtime for feed")?;
    let handle = rt.handle().clone();
    let (tx, rx) = mpsc::channel::<FeedMessage>(1000);
    let endpoint = config.endpoint.clone();

    // Bind inside the runtime so any tasks the socket spawns land on it. The
    // async block only borrows `endpoint`, so no extra clone is needed.
    let socket = handle.block_on(async {
        let mut socket = zeromq::PubSocket::new();
        socket
            .bind(&endpoint)
            .await
            .with_context(|| format!("Failed to bind ZMQ socket to {}", endpoint))?;
        Ok::<_, anyhow::Error>(socket)
    })?;

    // The thread owns `rt` and drops it once the publisher loop returns. A spawn
    // failure is reported as an error instead of panicking like `thread::spawn`.
    std::thread::Builder::new()
        .name("feed-publisher".to_string())
        .spawn(move || rt.block_on(run_publisher(socket, rx, config)))
        .context("Failed to spawn feed publisher thread")?;

    Ok((FeedHandleSync::new(tx, handle), endpoint))
}
/// Publishes every message received on `rx` to `socket` until a
/// `FeedMessage::Shutdown` arrives or all senders have been dropped.
///
/// When `config.heartbeat` is set, a [`Heartbeat`] is also published every
/// `config.heartbeat_interval_secs` seconds, clamped to at least one second.
/// The first heartbeat goes out immediately. Send failures are logged to
/// stderr and never stop the loop. The heartbeat's `message_count` counts only
/// successfully sent non-heartbeat messages.
async fn run_publisher(
    mut socket: zeromq::PubSocket,
    mut rx: mpsc::Receiver<FeedMessage>,
    config: FeedConfig,
) {
    let start_time = std::time::Instant::now();
    let mut message_count: u64 = 0;
    // Name of the most recent session snapshot; empty until one is published.
    let mut session_name = String::new();
    // `tokio::time::interval` panics on a zero period, and the timer is created
    // even when heartbeats are disabled. Clamp so `heartbeat_interval_secs: 0`
    // cannot crash the (detached) publisher task.
    let heartbeat_interval =
        tokio::time::Duration::from_secs(config.heartbeat_interval_secs.max(1));
    let mut heartbeat_timer = tokio::time::interval(heartbeat_interval);

    loop {
        tokio::select! {
            msg = rx.recv() => {
                match msg {
                    // Explicit shutdown, or every sender dropped: stop publishing.
                    Some(FeedMessage::Shutdown) | None => break,
                    Some(msg) => {
                        if let FeedMessage::Session(ref s) = msg {
                            session_name.clone_from(&s.session_name);
                        }
                        let (topic, payload) = serialize_message(&msg);
                        match send_message(&mut socket, &topic, &payload).await {
                            Ok(()) => message_count += 1,
                            Err(e) => eprintln!("Feed send error: {}", e),
                        }
                    }
                }
            }
            _ = heartbeat_timer.tick(), if config.heartbeat => {
                let heartbeat = Heartbeat {
                    session_name: session_name.clone(),
                    uptime_secs: start_time.elapsed().as_secs(),
                    message_count,
                    timestamp: chrono::Utc::now().to_rfc3339(),
                };
                let payload = serde_json::to_string(&heartbeat).unwrap_or_default();
                if let Err(e) = send_message(&mut socket, topics::HEARTBEAT, &payload).await {
                    eprintln!("Heartbeat send error: {}", e);
                }
            }
        }
    }
}
fn serialize_message(msg: &FeedMessage) -> (String, String) {
match msg {
FeedMessage::Session(s) => (
topics::SESSION.to_string(),
serde_json::to_string(s).unwrap_or_default(),
),
FeedMessage::AgentUpdate(u) => (
topics::AGENT.to_string(),
serde_json::to_string(u).unwrap_or_default(),
),
FeedMessage::Output(o) => (
topics::OUTPUT.to_string(),
serde_json::to_string(o).unwrap_or_default(),
),
FeedMessage::WaveUpdate(w) => (
topics::WAVE.to_string(),
serde_json::to_string(w).unwrap_or_default(),
),
FeedMessage::Stats(s) => (
topics::STATS.to_string(),
serde_json::to_string(s).unwrap_or_default(),
),
FeedMessage::Heartbeat(h) => (
topics::HEARTBEAT.to_string(),
serde_json::to_string(h).unwrap_or_default(),
),
FeedMessage::Shutdown => ("shutdown".to_string(), "{}".to_string()),
}
}
/// Publishes one frame of the form `"<topic> <payload>"` on `socket`.
///
/// # Errors
/// Returns an error if the underlying ZMQ send fails.
async fn send_message(
    socket: &mut zeromq::PubSocket,
    topic: &str,
    payload: &str,
) -> Result<()> {
    let mut frame = String::with_capacity(topic.len() + 1 + payload.len());
    frame.push_str(topic);
    frame.push(' ');
    frame.push_str(payload);
    socket
        .send(frame.into())
        .await
        .context("Failed to send ZMQ message")
}
pub fn session_to_snapshot(session: &SpawnSession) -> SessionSnapshot {
let stats = SpawnStats::from(session);
SessionSnapshot {
session_name: session.session_name.clone(),
tag: session.tag.clone(),
terminal: session.terminal.clone(),
created_at: session.created_at.clone(),
working_dir: session.working_dir.clone(),
agents: session.agents.iter().map(AgentSnapshot::from).collect(),
stats: StatsSnapshot::from(&stats),
timestamp: chrono::Utc::now().to_rfc3339(),
}
}
/// Wraps `lines` of output for `task_id`.
///
/// Records the line count and stamps the message with the current UTC time in
/// RFC 3339 format.
pub fn create_output_message(task_id: &str, lines: Vec<String>) -> OutputMessage {
    OutputMessage {
        // Evaluated before `lines` is moved into the struct below.
        line_count: lines.len(),
        task_id: task_id.to_owned(),
        lines,
        timestamp: chrono::Utc::now().to_rfc3339(),
    }
}
/// Builds an [`AgentUpdate`] for `task_id`.
///
/// Both statuses are rendered as their lowercased `Debug` representation, and
/// `previous_status` is `None` when `previous` is `None`. The update is stamped
/// with the current UTC time in RFC 3339 format.
pub fn create_agent_update(
    task_id: &str,
    status: &AgentStatus,
    previous: Option<&AgentStatus>,
) -> AgentUpdate {
    let label = |s: &AgentStatus| format!("{s:?}").to_lowercase();
    AgentUpdate {
        task_id: task_id.to_owned(),
        status: label(status),
        previous_status: previous.map(label),
        timestamp: chrono::Utc::now().to_rfc3339(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_feed_config_tcp() {
        assert_eq!(FeedConfig::tcp(5555).endpoint, "tcp://*:5555");
    }

    #[test]
    fn test_feed_config_ipc() {
        assert_eq!(
            FeedConfig::ipc("/tmp/scud.sock").endpoint,
            "ipc:///tmp/scud.sock"
        );
    }

    #[test]
    fn test_serialize_stats() {
        let snapshot = StatsSnapshot {
            session_name: "test".to_string(),
            tag: "auth".to_string(),
            total_agents: 5,
            starting: 1,
            running: 2,
            completed: 2,
            failed: 0,
            timestamp: "2024-01-01T00:00:00Z".to_string(),
        };
        let encoded = serde_json::to_string(&snapshot).unwrap();
        for needle in ["test", "auth"] {
            assert!(encoded.contains(needle), "missing {needle:?} in {encoded}");
        }
    }

    #[test]
    fn test_serialize_output() {
        let message = OutputMessage {
            task_id: "auth:1".to_string(),
            lines: vec!["line1".to_string(), "line2".to_string()],
            line_count: 2,
            timestamp: "2024-01-01T00:00:00Z".to_string(),
        };
        let encoded = serde_json::to_string(&message).unwrap();
        for needle in ["auth:1", "line1"] {
            assert!(encoded.contains(needle), "missing {needle:?} in {encoded}");
        }
    }
}