use enact_core::context::{ResourceLimits, TenantContext};
use enact_core::kernel::artifact::{
ArtifactStore, ArtifactType, InMemoryArtifactStore, PutArtifactRequest,
};
use enact_core::kernel::ids::{ExecutionId, StepId};
use enact_core::kernel::{
replay, EnforcementMiddleware, EnforcementPolicy, EventLog, ExecutionAction, ExecutionError,
ExecutionKernel, ExecutionState, StepType, TenantId, ViolationType,
};
/// Builds the tenant context handed to the kernels created in this file.
///
/// The tenant id is always the fixed string `"tenant_test"`.
fn test_tenant() -> TenantContext {
    let tenant_id = TenantId::from("tenant_test");
    TenantContext::new(tenant_id)
}
/// Runs a kernel through start, a single LLM step and completion, records the
/// matching actions in an `EventLog`, and checks that `replay` of that log
/// yields an execution equal to the live one in id, state, output, step count
/// and per-step fields.
#[tokio::test]
async fn test_execute_replay_same_state() {
    // Fixtures shared by the live run and the recorded actions.
    let step_name = "test_step";
    let step_output = Some("output".to_string());
    let final_output = Some("final output".to_string());

    // Phase 1: drive the live kernel.
    let mut kernel = ExecutionKernel::new(test_tenant());
    let exec_id = kernel.execution_id().clone();
    kernel.start().unwrap();
    let step_id = kernel
        .begin_step(StepType::LlmNode, step_name, None)
        .unwrap();
    kernel
        .complete_step(step_id.clone(), step_output.clone(), 100)
        .unwrap();
    kernel.complete(final_output.clone()).unwrap();

    // Phase 2: record the same transitions, in the same order, as actions.
    let mut event_log = EventLog::new();
    let recorded_actions = [
        ExecutionAction::Start,
        ExecutionAction::StepStarted {
            step_id: step_id.clone(),
            parent_step_id: None,
            step_type: StepType::LlmNode,
            name: step_name.to_string(),
            source: None,
        },
        ExecutionAction::StepCompleted {
            step_id: step_id.clone(),
            output: step_output,
            duration_ms: 100,
        },
        ExecutionAction::Complete {
            output: final_output,
        },
    ];
    for action in recorded_actions {
        event_log.append(action);
    }

    // Phase 3: snapshot what the live execution looks like.
    let live = kernel.execution();
    let live_state = live.state;
    let live_output = live.output.clone();
    let live_step_count = live.steps.len();
    let live_step = live.get_step(&step_id).unwrap().clone();

    // Phase 4: rebuild the execution purely from the log and compare.
    let replayed = replay(exec_id.clone(), event_log.into_actions(), None).unwrap();

    assert_eq!(replayed.id.as_str(), exec_id.as_str());
    assert_eq!(replayed.state, live_state);
    assert_eq!(replayed.output, live_output);
    assert_eq!(replayed.steps.len(), live_step_count);

    let rebuilt_step = replayed.get_step(&step_id).unwrap();
    assert_eq!(rebuilt_step.id, live_step.id);
    assert_eq!(rebuilt_step.name, live_step.name);
    assert_eq!(rebuilt_step.output, live_step.output);
    assert_eq!(rebuilt_step.state, live_step.state);

    // The replayed execution must have reached the terminal `Completed` state.
    assert!(replayed.state.is_terminal());
    assert_eq!(replayed.state, ExecutionState::Completed);
}
/// Checks step-limit enforcement and the resulting kernel failure state.
///
/// The test has two independent halves:
///
/// 1. With `max_steps: 2`, the middleware allows the first check, returns
///    `Allowed` or `Warning` after one recorded step, and returns `Blocked`
///    with `ViolationType::StepLimit` after two recorded steps.
/// 2. A separate kernel completes two steps and is then failed by hand with a
///    policy-violation error; it must end in `ExecutionState::Failed` with an
///    error that is fatal and not retryable.
#[tokio::test]
async fn test_policy_violation_blocks_execution() {
    use enact_core::kernel::EnforcementResult;

    let policy = EnforcementPolicy {
        warning_threshold: 80,
        ..Default::default()
    };
    let middleware = EnforcementMiddleware::with_policy(policy);
    let tenant_id = TenantId::new();
    let exec_id = ExecutionId::new();
    let limits = ResourceLimits {
        max_steps: 2,
        max_tokens: 1000,
        max_wall_time_ms: 300_000,
        max_memory_mb: None,
        max_concurrent_executions: None,
    };
    let usage = middleware
        .register_execution(exec_id.clone(), tenant_id)
        .await;

    // No steps recorded yet: the step must be allowed outright.
    let check0 = middleware.check_step_allowed(&exec_id, &limits).await;
    assert!(
        matches!(check0, EnforcementResult::Allowed),
        "First check should be allowed"
    );

    // One of two steps used: either a plain allow or a warning is acceptable.
    usage.record_step();
    let check1 = middleware.check_step_allowed(&exec_id, &limits).await;
    assert!(
        matches!(
            check1,
            EnforcementResult::Allowed | EnforcementResult::Warning(_)
        ),
        "Second check should be Allowed or Warning, got {:?}",
        check1
    );

    // Both steps used: the next step must be blocked by the step limit.
    usage.record_step();
    let check2 = middleware.check_step_allowed(&exec_id, &limits).await;
    match &check2 {
        EnforcementResult::Blocked(violation) => {
            assert_eq!(violation.violation_type, ViolationType::StepLimit);
        }
        other => panic!("Expected Blocked result after 2 steps, got {:?}", other),
    }

    // The kernel below is not wired to the middleware; the test fails it
    // manually with a policy-violation error.
    let mut kernel = ExecutionKernel::new(test_tenant());
    kernel.start().unwrap();
    let step1 = kernel.begin_step(StepType::LlmNode, "step1", None).unwrap();
    kernel
        .complete_step(step1, Some("output1".to_string()), 100)
        .unwrap();
    let step2 = kernel.begin_step(StepType::LlmNode, "step2", None).unwrap();
    kernel
        .complete_step(step2, Some("output2".to_string()), 100)
        .unwrap();

    // The error is not needed afterwards, so it is moved rather than cloned.
    let policy_error = ExecutionError::policy_violation("Step limit exceeded: max 2 steps allowed");
    kernel.fail(policy_error).unwrap();

    assert_eq!(kernel.state(), ExecutionState::Failed);
    assert!(kernel.execution().error.is_some());
    let error = kernel.execution().error.as_ref().unwrap();
    assert!(error.is_fatal());
    assert!(!error.is_retryable());
}
/// Round-trips three text artifacts through an `InMemoryArtifactStore`.
///
/// Two artifacts carry identical bytes under different names and one carries
/// different bytes. The test checks that stored content and name come back
/// unchanged, that the two same-content artifacts get distinct ids, and that
/// fetching the first artifact again returns the same content and name.
///
/// NOTE(review): despite the name, no hash value is read here; only content,
/// names and artifact ids are compared. Confirm whether a hash assertion was
/// intended.
#[tokio::test]
async fn test_artifact_hash_stable() {
    let store = InMemoryArtifactStore::new();
    let exec_id = ExecutionId::new();
    let step_id = StepId::new();

    // Every request in this test is a text artifact for the same execution/step.
    let text_request = |name: &'static str, bytes: Vec<u8>| {
        PutArtifactRequest::new(
            exec_id.clone(),
            step_id.clone(),
            name,
            ArtifactType::Text,
            bytes,
        )
    };

    let content = b"Hello, World! This is a test artifact.".to_vec();

    // First artifact: what goes in must come out.
    let first_put = store
        .put(text_request("test_artifact", content.clone()))
        .await
        .unwrap();
    let first = store.get(&first_put.artifact_id).await.unwrap();
    assert_eq!(first.content, content);
    assert_eq!(first.metadata.name, "test_artifact");

    // Second artifact: same bytes, different name.
    let second_put = store
        .put(text_request("test_artifact_2", content.clone()))
        .await
        .unwrap();
    let second = store.get(&second_put.artifact_id).await.unwrap();
    assert_eq!(first.content, second.content);
    assert_ne!(
        first_put.artifact_id, second_put.artifact_id,
        "Different artifacts should have different IDs"
    );

    // Third artifact: different bytes must stay distinguishable.
    let different_content = b"Different content!".to_vec();
    let third_put = store
        .put(text_request("different_artifact", different_content.clone()))
        .await
        .unwrap();
    let third = store.get(&third_put.artifact_id).await.unwrap();
    assert_ne!(first.content, third.content);
    assert_eq!(third.content, different_content);

    // Re-reading the first artifact after later writes returns the same data.
    let first_again = store.get(&first_put.artifact_id).await.unwrap();
    assert_eq!(first.content, first_again.content);
    assert_eq!(first.metadata.name, first_again.metadata.name);
}