use serde::{Deserialize, Serialize};
use crate::reasoning::operator::OperatorKind;
/// Complete working state of a task: what is being pursued, under which
/// constraints and inputs, and everything recorded while pursuing it.
///
/// The `TaskState` impl provides consuming builder-style methods (taking and
/// returning `self`) for populating these fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskState {
    /// What the task is trying to achieve.
    pub goal: Goal,
    /// Constraints attached to the task, in insertion order.
    pub constraints: Vec<Constraint>,
    /// Raw input together with known facts and an optional domain.
    pub context: Context,
    /// Current plan, if one has been set.
    pub plan: Option<Plan>,
    /// Executed reasoning steps, in the order they were recorded.
    pub steps: Vec<Step>,
    /// Observations, in the order they were recorded.
    pub observations: Vec<Observation>,
    /// Actions (tool calls or pure reasoning), in the order they were recorded.
    pub actions: Vec<Action>,
    /// Reflections, in the order they were recorded.
    pub reflections: Vec<Reflection>,
    /// Outputs produced so far; an `ArtifactKind::Answer` entry marks the goal
    /// as achieved.
    pub artifacts: Vec<Artifact>,
    /// Iteration and observation-sequence counters.
    pub metadata: StateMetadata,
}
impl TaskState {
    /// Creates a fresh state for `goal` with the raw task `input`.
    ///
    /// The following all start empty:
    /// - constraints, steps, observations, actions, reflections and artifacts.
    ///
    /// No plan is set, and the metadata counters start at their defaults (0).
    #[must_use]
    pub fn new(goal: impl Into<String>, input: impl Into<String>) -> Self {
        Self {
            goal: Goal::new(goal),
            constraints: Vec::new(),
            context: Context::new(input),
            plan: None,
            steps: Vec::new(),
            observations: Vec::new(),
            actions: Vec::new(),
            reflections: Vec::new(),
            artifacts: Vec::new(),
            metadata: StateMetadata::default(),
        }
    }

    /// Appends a constraint to the task.
    #[must_use]
    pub fn with_constraint(mut self, constraint: Constraint) -> Self {
        self.constraints.push(constraint);
        self
    }

    /// Appends a known fact to the task context.
    #[must_use]
    pub fn with_fact(mut self, fact: impl Into<String>) -> Self {
        self.context.facts.push(fact.into());
        self
    }

    /// Sets (or replaces) the task domain in the context.
    #[must_use]
    pub fn with_domain(mut self, domain: impl Into<String>) -> Self {
        self.context.domain = Some(domain.into());
        self
    }

    /// Appends an executed reasoning step.
    #[must_use]
    pub fn record_step(mut self, step: Step) -> Self {
        self.steps.push(step);
        self
    }

    /// Appends an observation.
    ///
    /// The observation's `sequence` is stored exactly as supplied by the caller.
    /// `metadata.observation_sequence` is advanced to the highest sequence seen
    /// so far, so callers can derive the next number from it. It never moves
    /// backwards if observations arrive out of order.
    #[must_use]
    pub fn observe(mut self, observation: Observation) -> Self {
        // Keep the counter in sync with what has actually been recorded.
        // Without this, `observation_sequence` stays at its default forever.
        self.metadata.observation_sequence =
            self.metadata.observation_sequence.max(observation.sequence);
        self.observations.push(observation);
        self
    }

    /// Appends an action (tool call or pure reasoning).
    #[must_use]
    pub fn record_action(mut self, action: Action) -> Self {
        self.actions.push(action);
        self
    }

    /// Appends a reflection.
    ///
    /// Any `revised_plan` it carries is only stored; it is not applied to
    /// `self.plan`.
    #[must_use]
    pub fn reflect(mut self, reflection: Reflection) -> Self {
        self.reflections.push(reflection);
        self
    }

    /// Appends a produced artifact.
    #[must_use]
    pub fn add_artifact(mut self, artifact: Artifact) -> Self {
        self.artifacts.push(artifact);
        self
    }

    /// Sets the plan, replacing any previously set plan.
    #[must_use]
    pub fn with_plan(mut self, plan: Plan) -> Self {
        self.plan = Some(plan);
        self
    }

    /// Increments `metadata.iteration` by one.
    #[must_use]
    pub fn next_iteration(mut self) -> Self {
        self.metadata.iteration += 1;
        self
    }

    /// Returns `true` once at least one `ArtifactKind::Answer` artifact has
    /// been recorded.
    ///
    /// `goal.success_criteria` is not consulted.
    #[must_use]
    pub fn is_goal_achieved(&self) -> bool {
        self.artifacts
            .iter()
            .any(|a| matches!(a.kind, ArtifactKind::Answer))
    }
}
/// The objective of a task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Goal {
    /// Free-text statement of the objective.
    pub description: String,
    /// Free-text criteria attached to the goal, in insertion order.
    pub success_criteria: Vec<String>,
}
impl Goal {
    /// Builds a goal from its description, with no success criteria yet.
    #[must_use]
    pub fn new(description: impl Into<String>) -> Self {
        let description: String = description.into();
        Self {
            success_criteria: Vec::new(),
            description,
        }
    }

    /// Appends one success criterion and returns the updated goal.
    #[must_use]
    pub fn with_criterion(mut self, criterion: impl Into<String>) -> Self {
        let criterion: String = criterion.into();
        self.success_criteria.push(criterion);
        self
    }
}
/// A restriction placed on how the task may be solved.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Constraint {
    /// Category of the constraint.
    pub kind: ConstraintKind,
    /// Free-text statement of the constraint.
    pub description: String,
}
/// Category of a [`Constraint`].
///
/// NOTE(review): how strictly each kind is enforced is decided by consumers of
/// this type, not here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConstraintKind {
    Hard,
    Soft,
    Preference,
}
/// Information available to the task: its raw input plus supporting facts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Context {
    /// Known facts, in insertion order.
    pub facts: Vec<String>,
    /// Optional subject-area label.
    pub domain: Option<String>,
    /// The raw task input.
    pub input: String,
}
impl Context {
    /// Creates a context holding the raw `input`, with no facts and no domain.
    #[must_use]
    pub fn new(input: impl Into<String>) -> Self {
        let input: String = input.into();
        Self {
            input,
            domain: None,
            facts: Vec::new(),
        }
    }
}
/// An ordered list of intended steps together with the reasoning behind it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Plan {
    /// Planned steps, in order.
    pub steps: Vec<PlanStep>,
    /// Why this plan was chosen.
    pub rationale: String,
}
impl Plan {
    /// Creates an empty plan justified by `rationale`.
    #[must_use]
    pub fn new(rationale: impl Into<String>) -> Self {
        Self {
            steps: Vec::new(),
            rationale: rationale.into(),
        }
    }

    /// Appends a new `StepStatus::Pending` step described by `description`.
    ///
    /// The new step's id is one greater than the largest id currently in the
    /// plan, or 0 for an empty plan. For plans built only through this method,
    /// that equals the step's index.
    ///
    /// `steps` is public and may be edited directly (e.g. a step removed).
    /// Deriving the id from `len()` could then hand out an id that is already
    /// in use, so the maximum existing id is used instead.
    #[must_use]
    pub fn with_step(mut self, description: impl Into<String>) -> Self {
        let id = self
            .steps
            .iter()
            .map(|step| step.id.saturating_add(1))
            .max()
            .unwrap_or(0);
        self.steps.push(PlanStep {
            id,
            description: description.into(),
            status: StepStatus::Pending,
        });
        self
    }
}
/// One entry of a [`Plan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanStep {
    /// Identifier of the step within its plan.
    pub id: usize,
    /// Free-text description of what the step should do.
    pub description: String,
    /// Current lifecycle status of the step.
    pub status: StepStatus,
}
/// A reasoning step: an operator applied to an input, with its optional output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Step {
    /// Which reasoning operator this step applied.
    pub operator: OperatorKind,
    /// Input given to the operator.
    pub input: String,
    /// Operator output, if one has been produced.
    pub output: Option<String>,
    /// Current lifecycle status of the step.
    pub status: StepStatus,
}
/// Lifecycle status shared by [`PlanStep`] and [`Step`].
///
/// Newly planned steps start as `Pending`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum StepStatus {
    Pending,
    InProgress,
    Completed,
    Failed,
    Skipped,
}
/// A piece of information observed while working on the task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Observation {
    /// Where the observation came from.
    pub source: ObservationSource,
    /// The observed content.
    pub content: String,
    /// Ordering number supplied by whoever creates the observation.
    pub sequence: u64,
}
/// Origin of an [`Observation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ObservationSource {
    Tool,
    Environment,
    SelfInspection,
}
/// Something done while working on the task: either a tool call or pure
/// reasoning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Action {
    /// Name of the invoked tool; `None` for a reasoning-only action.
    pub tool: Option<String>,
    /// Input passed to the tool or reasoning process.
    pub input: String,
    /// Result of the action, once available.
    pub output: Option<String>,
}
impl Action {
    /// Shared constructor for an action that has no output recorded yet.
    fn without_output(tool: Option<String>, input: String) -> Self {
        Self {
            tool,
            input,
            output: None,
        }
    }

    /// Creates a call to the tool named `tool` with the given `input`.
    ///
    /// The output starts as `None`.
    #[must_use]
    pub fn tool_call(tool: impl Into<String>, input: impl Into<String>) -> Self {
        let tool: String = tool.into();
        Self::without_output(Some(tool), input.into())
    }

    /// Creates a reasoning-only action (no tool) with the given `input`.
    ///
    /// The output starts as `None`.
    #[must_use]
    pub fn reasoning(input: impl Into<String>) -> Self {
        Self::without_output(None, input.into())
    }
}
/// A self-assessment recorded during the task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Reflection {
    /// Free-text content of the reflection.
    pub content: String,
    /// Optional revised plan produced by this reflection.
    pub revised_plan: Option<Plan>,
}
/// An output produced while working on the task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Artifact {
    /// What sort of artifact this is.
    pub kind: ArtifactKind,
    /// The artifact's content, stored as a string.
    pub content: String,
}
/// Category of an [`Artifact`].
///
/// `Answer` is the kind that `TaskState::is_goal_achieved` looks for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArtifactKind {
    Code,
    Text,
    Data,
    File,
    Image,
    Answer,
}
/// Bookkeeping counters for a task state. Both start at 0 via `Default`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StateMetadata {
    /// Number of completed iterations.
    pub iteration: usize,
    /// Observation sequence counter.
    pub observation_sequence: u64,
}
#[cfg(test)]
// Tests may unwrap freely: a panic is reported as a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn task_state_new_should_set_goal_and_input() {
        let state = TaskState::new("solve equation", "x + 2 = 5");
        assert_eq!(state.goal.description, "solve equation");
        assert_eq!(state.context.input, "x + 2 = 5");
        assert!(state.constraints.is_empty());
        assert!(state.plan.is_none());
    }

    #[test]
    fn task_state_builder_chain_should_work() {
        let state = TaskState::new("goal", "input")
            .with_constraint(Constraint {
                kind: ConstraintKind::Hard,
                description: "must be numeric".into(),
            })
            .with_fact("x is unknown")
            .with_domain("mathematics")
            .with_plan(Plan::new("algebraic isolation").with_step("subtract 2"));
        assert_eq!(state.constraints.len(), 1);
        assert_eq!(state.context.facts.len(), 1);
        assert_eq!(state.context.domain.as_deref(), Some("mathematics"));
        let plan = state.plan.unwrap();
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].status, StepStatus::Pending);
    }

    #[test]
    fn task_state_record_step_should_accumulate() {
        let state = TaskState::new("g", "i").record_step(Step {
            operator: OperatorKind::Analyze,
            input: "what is this".into(),
            output: Some("a problem".into()),
            status: StepStatus::Completed,
        });
        assert_eq!(state.steps.len(), 1);
        assert_eq!(state.steps[0].operator, OperatorKind::Analyze);
    }

    #[test]
    fn task_state_is_goal_achieved_without_answer_artifact() {
        let state = TaskState::new("g", "i");
        assert!(!state.is_goal_achieved());
    }

    #[test]
    fn task_state_is_goal_achieved_ignores_non_answer_artifacts() {
        let state = TaskState::new("g", "i").add_artifact(Artifact {
            kind: ArtifactKind::Text,
            content: "draft".into(),
        });
        assert!(!state.is_goal_achieved());
    }

    #[test]
    fn task_state_is_goal_achieved_with_answer_artifact() {
        let state = TaskState::new("g", "i").add_artifact(Artifact {
            kind: ArtifactKind::Answer,
            content: "42".into(),
        });
        assert!(state.is_goal_achieved());
    }

    #[test]
    fn task_state_next_iteration_should_increment() {
        let state = TaskState::new("g", "i").next_iteration().next_iteration();
        assert_eq!(state.metadata.iteration, 2);
    }

    /// `observe` stores the observation with the caller-supplied sequence.
    #[test]
    fn observe_should_store_observation_with_supplied_sequence() {
        let state = TaskState::new("g", "i").observe(Observation {
            source: ObservationSource::Tool,
            content: "result".into(),
            sequence: 1,
        });
        assert_eq!(state.observations.len(), 1);
        assert_eq!(state.observations[0].source, ObservationSource::Tool);
        assert_eq!(state.observations[0].content, "result");
        assert_eq!(state.observations[0].sequence, 1);
    }

    #[test]
    fn record_action_and_reflect_should_accumulate() {
        let state = TaskState::new("g", "i")
            .record_action(Action::tool_call("calculator", "2+2"))
            .record_action(Action::reasoning("check result"))
            .reflect(Reflection {
                content: "looks right".into(),
                revised_plan: None,
            });
        assert_eq!(state.actions.len(), 2);
        assert_eq!(state.actions[0].tool.as_deref(), Some("calculator"));
        assert!(state.actions[1].tool.is_none());
        assert_eq!(state.reflections.len(), 1);
        assert!(state.reflections[0].revised_plan.is_none());
    }

    #[test]
    fn with_plan_should_replace_existing_plan() {
        let state = TaskState::new("g", "i")
            .with_plan(Plan::new("first"))
            .with_plan(Plan::new("second"));
        assert_eq!(state.plan.unwrap().rationale, "second");
    }

    #[test]
    fn action_tool_call_and_reasoning() {
        let tool_action = Action::tool_call("calculator", "2+2");
        assert_eq!(tool_action.tool.as_deref(), Some("calculator"));
        assert!(tool_action.output.is_none());
        let reason_action = Action::reasoning("thinking...");
        assert!(reason_action.tool.is_none());
    }

    #[test]
    fn goal_with_criterion_should_append_in_order() {
        let goal = Goal::new("g").with_criterion("a").with_criterion("b");
        assert_eq!(goal.description, "g");
        assert_eq!(
            goal.success_criteria,
            vec!["a".to_string(), "b".to_string()]
        );
    }

    #[test]
    fn context_new_should_start_without_facts_or_domain() {
        let context = Context::new("raw input");
        assert_eq!(context.input, "raw input");
        assert!(context.facts.is_empty());
        assert!(context.domain.is_none());
    }

    #[test]
    fn plan_with_step_should_auto_increment_ids() {
        let plan = Plan::new("rationale")
            .with_step("first")
            .with_step("second")
            .with_step("third");
        assert_eq!(plan.steps[0].id, 0);
        assert_eq!(plan.steps[1].id, 1);
        assert_eq!(plan.steps[2].id, 2);
    }

    #[test]
    fn task_state_should_be_serializable() {
        let state = TaskState::new("solve", "x = 1")
            .with_fact("x is positive")
            .add_artifact(Artifact {
                kind: ArtifactKind::Answer,
                content: "x = 1".into(),
            });
        let json = serde_json::to_string(&state).unwrap();
        let deserialized: TaskState = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.goal.description, "solve");
        assert_eq!(deserialized.context.facts, vec!["x is positive".to_string()]);
        assert_eq!(deserialized.artifacts.len(), 1);
        assert!(deserialized.is_goal_achieved());
    }
}