use std::sync::Arc;
use anyhow::{Context, Result};
use async_nats::jetstream::consumer::DeliverPolicy;
use async_nats::jetstream::consumer::pull::Config as PullConfig;
use futures::StreamExt;
use kanade_shared::kv::STREAM_EXEC;
use kanade_shared::wire::Command;
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
use crate::commands::{DedupCache, handle_command};
/// Builds the durable JetStream consumer name for this PC's replay loop.
///
/// The name is `agent_replay_` followed by `pc_id`. Every character of `pc_id` that
/// is not an ASCII letter, ASCII digit, `_` or `-` is replaced by a single `_`.
/// For example, `"PC.001"` becomes `"agent_replay_PC_001"`. This presumably keeps
/// characters NATS rejects in consumer names (such as `.`, `*`, `>` and whitespace)
/// out of the name.
fn consumer_name(pc_id: &str) -> String {
    const PREFIX: &str = "agent_replay_";

    let mut name = String::with_capacity(PREFIX.len() + pc_id.len());
    name.push_str(PREFIX);
    name.extend(pc_id.chars().map(|ch| match ch {
        'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-' => ch,
        _ => '_',
    }));
    name
}
pub fn spawn(
client: async_nats::Client,
pc_id: String,
dedup: Arc<Mutex<DedupCache>>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
if let Err(e) = run(client, pc_id, dedup).await {
error!(error = ?e, "command-replay loop exited with error");
}
})
}
async fn run(
client: async_nats::Client,
pc_id: String,
dedup: Arc<Mutex<DedupCache>>,
) -> Result<()> {
let jetstream = async_nats::jetstream::new(client.clone());
let stream = jetstream
.get_stream(STREAM_EXEC)
.await
.with_context(|| format!("get stream {STREAM_EXEC}"))?;
let name = consumer_name(&pc_id);
let consumer = stream
.get_or_create_consumer(
&name,
PullConfig {
durable_name: Some(name.clone()),
ack_policy: async_nats::jetstream::consumer::AckPolicy::Explicit,
deliver_policy: DeliverPolicy::LastPerSubject,
filter_subject: "commands.>".into(),
..Default::default()
},
)
.await
.with_context(|| format!("ensure consumer {name}"))?;
info!(
stream = STREAM_EXEC,
consumer = %name,
pc_id = %pc_id,
"command-replay consumer ready",
);
let script_current = jetstream
.get_key_value(kanade_shared::kv::BUCKET_SCRIPT_CURRENT)
.await
.ok();
let script_status = jetstream
.get_key_value(kanade_shared::kv::BUCKET_SCRIPT_STATUS)
.await
.ok();
let mut messages = consumer.messages().await.context("messages stream")?;
while let Some(msg) = messages.next().await {
let msg = match msg {
Ok(m) => m,
Err(e) => {
warn!(error = %e, "replay consumer error");
continue;
}
};
let _ = msg.ack().await;
let cmd: Command = match serde_json::from_slice(&msg.payload) {
Ok(c) => c,
Err(e) => {
warn!(error = %e, subject = %msg.subject, "deserialize replay command");
continue;
}
};
if !is_for_me(&msg.subject, &pc_id) {
debug!(subject = %msg.subject, "replay msg not for this pc; dropping");
continue;
}
if !dedup.lock().await.insert(cmd.request_id.clone()) {
debug!(
request_id = %cmd.request_id,
"replay dedup: already seen via core sub or earlier replay",
);
continue;
}
let client_for_task = client.clone();
let pc_for_task = pc_id.clone();
let cur = script_current.clone();
let sta = script_status.clone();
info!(
cmd_id = %cmd.id,
request_id = %cmd.request_id,
subject = %msg.subject,
"replay: handling missed command",
);
tokio::spawn(async move {
if let Err(e) = handle_command(client_for_task, pc_for_task, cmd, cur, sta).await {
error!(error = %e, "replay command handler failed");
}
});
}
Ok(())
}
/// Decides whether a replayed command published on `subject` targets this agent.
///
/// - The broadcast subject (`kanade_shared::subject::COMMANDS_ALL`) always matches.
/// - `commands.pc.<id>` matches only when `<id>` equals `my_pc_id` exactly.
/// - Any `commands.group.*` subject is accepted. Group membership is presumably
///   checked further downstream (TODO confirm, e.g. in `handle_command`).
/// - Every other subject is rejected.
fn is_for_me(subject: &str, my_pc_id: &str) -> bool {
    if subject == kanade_shared::subject::COMMANDS_ALL {
        return true;
    }
    match subject.strip_prefix("commands.pc.") {
        Some(target_pc) => target_pc == my_pc_id,
        None => subject.starts_with("commands.group."),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn commands_all_matches_anyone() {
        assert!(is_for_me("commands.all", "minipc-01"));
        assert!(is_for_me("commands.all", "anything"));
    }

    #[test]
    fn commands_pc_matches_only_owner() {
        assert!(is_for_me("commands.pc.minipc-01", "minipc-01"));
        assert!(!is_for_me("commands.pc.minipc-02", "minipc-01"));
    }

    /// The per-PC subject must match the whole id, not just a prefix of it.
    #[test]
    fn commands_pc_requires_exact_id() {
        assert!(!is_for_me("commands.pc.minipc-01x", "minipc-01"));
        assert!(!is_for_me("commands.pc.minipc-01.extra", "minipc-01"));
        assert!(!is_for_me("commands.pc.", "minipc-01"));
    }

    #[test]
    fn commands_group_always_accepted() {
        assert!(is_for_me("commands.group.canary", "minipc-01"));
    }

    #[test]
    fn unknown_subject_dropped() {
        assert!(!is_for_me("commands.weird", "minipc-01"));
        assert!(!is_for_me("results.x", "minipc-01"));
    }

    #[test]
    fn consumer_name_sanitises_pc_id() {
        assert_eq!(consumer_name("MINIPC-01"), "agent_replay_MINIPC-01");
        assert_eq!(consumer_name("PC.001"), "agent_replay_PC_001");
        assert_eq!(
            consumer_name("host with space"),
            "agent_replay_host_with_space"
        );
    }

    /// NATS wildcard characters and non-ASCII characters are each replaced by a
    /// single `_`. An empty id still yields a usable name.
    #[test]
    fn consumer_name_handles_special_and_empty_ids() {
        assert_eq!(consumer_name("a*b>c"), "agent_replay_a_b_c");
        assert_eq!(consumer_name("pç"), "agent_replay_p_");
        assert_eq!(consumer_name("pc_01-a"), "agent_replay_pc_01-a");
        assert_eq!(consumer_name(""), "agent_replay_");
    }
}