use std::collections::HashMap;
use anyhow::Result;
use async_trait::async_trait;
/// A command delivered by a [`CommandSource`] inside [`ClaimOutcome::Claimed`].
///
/// On deserialization, `render_context` and `attempts` are optional: when they
/// are missing from the payload they fall back to an empty map and `0`.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Command {
    /// Identifier of this command.
    pub command_id: String,
    /// Identifier of the execution this command belongs to.
    pub execution_id: i64,
    /// Name of the step this command runs.
    pub step: String,
    /// Which tool should handle the command (e.g. `"http"`, `"rhai"`).
    pub tool_kind: String,
    /// Tool-specific input payload; its shape presumably depends on
    /// `tool_kind` — confirm against the tool implementations.
    pub input: serde_json::Value,
    /// Additional named values; empty when absent from the payload.
    #[serde(default)]
    pub render_context: HashMap<String, serde_json::Value>,
    /// Attempt counter (presumably prior delivery attempts — confirm);
    /// `0` when absent from the payload.
    #[serde(default)]
    pub attempts: u32,
}
/// What happened when an item was pulled from a [`CommandSource`].
///
/// Every variant — not only `Claimed` — arrives paired with an ack handle in
/// [`Pulled`], so callers are expected to settle each one via `ack`/`nack`.
#[derive(Clone, Debug)]
pub enum ClaimOutcome {
    /// The claim succeeded; carries the command itself.
    Claimed(Command),
    /// The command had already been claimed (presumably by another
    /// consumer — confirm against the real source implementation).
    AlreadyClaimed,
    /// Going by the name, a transient condition worth retrying later;
    /// carries a human-readable reason.
    RetryLater(String),
    /// The claim failed; carries a human-readable reason.
    Failed(String),
}
/// One item yielded by [`CommandSource::next`]: the claim outcome together
/// with the handle used to acknowledge it.
#[derive(Clone, Debug)]
pub struct Pulled<H> {
    /// Result of claiming this item.
    pub outcome: ClaimOutcome,
    /// Handle to hand back to [`CommandSource::ack`] or [`CommandSource::nack`].
    pub ack: H,
}
/// An asynchronous source of claimable commands with explicit acknowledgement.
///
/// Items are pulled with [`next`](CommandSource::next); each pulled item
/// carries a handle that is later passed by value to either
/// [`ack`](CommandSource::ack) or [`nack`](CommandSource::nack).
#[async_trait]
pub trait CommandSource: Send + Sync {
    /// Handle identifying one pulled item; it is moved into `ack`/`nack`.
    type AckHandle: Send + Sync;

    /// Pulls the next item.
    ///
    /// `Ok(None)` means no item was returned. This trait does not pin down
    /// whether that means "exhausted" or "nothing right now" — NOTE(review):
    /// confirm against real implementations.
    async fn next(&mut self) -> Result<Option<Pulled<Self::AckHandle>>>;

    /// Positively acknowledges the item identified by `handle`.
    async fn ack(&self, handle: Self::AckHandle) -> Result<()>;

    /// Negatively acknowledges the item identified by `handle`.
    async fn nack(&self, handle: Self::AckHandle) -> Result<()>;
}
/// Unit tests for the command-source types, driven by an in-memory mock.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// In-memory [`CommandSource`] that replays a fixed list of outcomes.
    ///
    /// Each pulled outcome gets a sequential `u64` ack handle starting at 0.
    /// Every `ack`/`nack` call is appended to a shared log that tests can
    /// inspect through [`MockSource::ack_log`].
    pub struct MockSource {
        queue: VecDeque<ClaimOutcome>,
        ack_log: Arc<Mutex<Vec<MockAck>>>,
        next_ack_id: u64,
    }

    /// One recorded acknowledgement call on a [`MockSource`].
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum MockAck {
        Acked(u64),
        Nacked(u64),
    }

    impl MockSource {
        /// Creates a source that yields `outcomes` in order, then `None`.
        pub fn new(outcomes: Vec<ClaimOutcome>) -> Self {
            Self {
                queue: outcomes.into(),
                ack_log: Arc::new(Mutex::new(Vec::new())),
                next_ack_id: 0,
            }
        }

        /// Returns a shared handle to the acknowledgement log.
        pub fn ack_log(&self) -> Arc<Mutex<Vec<MockAck>>> {
            Arc::clone(&self.ack_log)
        }
    }

    #[async_trait]
    impl CommandSource for MockSource {
        type AckHandle = u64;

        async fn next(&mut self) -> Result<Option<Pulled<u64>>> {
            match self.queue.pop_front() {
                None => Ok(None),
                Some(outcome) => {
                    let id = self.next_ack_id;
                    self.next_ack_id += 1;
                    Ok(Some(Pulled { outcome, ack: id }))
                }
            }
        }

        async fn ack(&self, handle: u64) -> Result<()> {
            self.ack_log.lock().await.push(MockAck::Acked(handle));
            Ok(())
        }

        async fn nack(&self, handle: u64) -> Result<()> {
            self.ack_log.lock().await.push(MockAck::Nacked(handle));
            Ok(())
        }
    }

    /// Builds a minimal `http` command whose id is `id`.
    fn dummy_command(id: &str) -> Command {
        Command {
            command_id: id.to_string(),
            execution_id: 12345,
            step: "fetch".to_string(),
            tool_kind: "http".to_string(),
            input: serde_json::json!({"url": "https://example.com"}),
            render_context: HashMap::new(),
            attempts: 0,
        }
    }

    /// Unwraps a `Claimed` outcome, panicking with the actual variant otherwise.
    fn expect_claimed(outcome: ClaimOutcome) -> Command {
        match outcome {
            ClaimOutcome::Claimed(cmd) => cmd,
            other => panic!("expected Claimed, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn empty_source_returns_none() {
        let mut source = MockSource::new(vec![]);
        assert!(source.next().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn next_yields_in_order_and_increments_handles() {
        let mut source = MockSource::new(vec![
            ClaimOutcome::Claimed(dummy_command("a")),
            ClaimOutcome::Claimed(dummy_command("b")),
        ]);
        let first = source.next().await.unwrap().unwrap();
        let second = source.next().await.unwrap().unwrap();
        assert_eq!(first.ack, 0);
        assert_eq!(second.ack, 1);
        // Check BOTH payloads; checking only the first does not prove ordering.
        assert_eq!(expect_claimed(first.outcome).command_id, "a");
        assert_eq!(expect_claimed(second.outcome).command_id, "b");
        // Once drained, the mock reports `None` rather than an error.
        assert!(source.next().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn ack_and_nack_recorded_in_order() {
        let source = MockSource::new(vec![]);
        let log = source.ack_log();
        source.ack(7).await.unwrap();
        source.nack(9).await.unwrap();
        source.ack(11).await.unwrap();
        let log = log.lock().await;
        assert_eq!(
            *log,
            vec![MockAck::Acked(7), MockAck::Nacked(9), MockAck::Acked(11)]
        );
    }

    #[tokio::test]
    async fn already_claimed_outcome_carries_handle() {
        let mut source = MockSource::new(vec![ClaimOutcome::AlreadyClaimed]);
        let log = source.ack_log();
        let pulled = source.next().await.unwrap().unwrap();
        assert!(matches!(pulled.outcome, ClaimOutcome::AlreadyClaimed));
        source.ack(pulled.ack).await.unwrap();
        assert_eq!(*log.lock().await, vec![MockAck::Acked(0)]);
    }

    #[tokio::test]
    async fn retry_later_outcome_carries_error_message() {
        let mut source = MockSource::new(vec![
            ClaimOutcome::RetryLater("overload".to_string()),
        ]);
        let pulled = source.next().await.unwrap().unwrap();
        match pulled.outcome {
            ClaimOutcome::RetryLater(msg) => assert_eq!(msg, "overload"),
            _ => panic!("expected RetryLater"),
        }
    }

    #[tokio::test]
    async fn failed_outcome_carries_error_message() {
        let mut source = MockSource::new(vec![
            ClaimOutcome::Failed("malformed payload".to_string()),
        ]);
        let pulled = source.next().await.unwrap().unwrap();
        match pulled.outcome {
            ClaimOutcome::Failed(msg) => assert_eq!(msg, "malformed payload"),
            _ => panic!("expected Failed"),
        }
    }

    #[test]
    fn command_round_trips_through_serde_with_defaults() {
        let json = serde_json::json!({
            "command_id": "cmd-1",
            "execution_id": 7,
            "step": "s",
            "tool_kind": "http",
            "input": {"url": "https://example.com"},
        });
        let cmd: Command = serde_json::from_value(json).unwrap();
        assert!(cmd.render_context.is_empty());
        assert_eq!(cmd.attempts, 0);

        // Serializing back must emit the defaulted fields explicitly...
        let reserialized = serde_json::to_value(&cmd).unwrap();
        assert_eq!(reserialized["render_context"], serde_json::json!({}));
        assert_eq!(reserialized["attempts"], serde_json::json!(0));

        // ...and that output must deserialize to an equivalent command.
        let again: Command = serde_json::from_value(reserialized).unwrap();
        assert_eq!(again.command_id, cmd.command_id);
        assert_eq!(again.execution_id, cmd.execution_id);
        assert_eq!(again.step, cmd.step);
        assert_eq!(again.tool_kind, cmd.tool_kind);
        assert_eq!(again.input, cmd.input);
        assert_eq!(again.render_context, cmd.render_context);
        assert_eq!(again.attempts, cmd.attempts);
    }

    #[test]
    fn command_round_trips_through_serde_with_full_fields() {
        let json = serde_json::json!({
            "command_id": "cmd-2",
            "execution_id": 12345,
            "step": "process",
            "tool_kind": "rhai",
            "input": {"code": "1 + 1"},
            "render_context": {"workload.region": "us-east-1"},
            "attempts": 3,
        });
        let cmd: Command = serde_json::from_value(json.clone()).unwrap();
        assert_eq!(cmd.attempts, 3);
        assert_eq!(
            cmd.render_context.get("workload.region"),
            Some(&serde_json::json!("us-east-1"))
        );
        // Every field was supplied, so serialization must reproduce the input exactly.
        assert_eq!(serde_json::to_value(&cmd).unwrap(), json);
    }
}