use std::collections::BTreeMap;
use std::fmt;
use std::string::String;
use std::vec::Vec;
use crate::hash::Hash;
/// Versioned, hash-stamped key/value state of an agent.
///
/// All fields are public. Only [`AgentState::set`] and [`AgentState::remove`]
/// keep `version` and `state_hash` in step with `data`; writing to `data`
/// directly leaves both of them untouched.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct AgentState {
/// Mutation counter; 0 for a state built by [`AgentState::initial`].
pub version: u64,
/// Entries keyed by domain name (a `BTreeMap`, so iteration is in key order).
pub data: BTreeMap<String, StateData>,
/// Hash stamp of `data`; `Hash::zero()` until the state is first rehashed.
pub state_hash: Hash,
}
impl AgentState {
    /// Creates an empty state: version 0, no entries, and `Hash::zero()` as
    /// the hash stamp.
    ///
    /// NOTE(review): a state that is emptied again through [`Self::remove`] is
    /// rehashed from the empty map, which may not equal `Hash::zero()` —
    /// confirm that this asymmetry is intended.
    #[must_use]
    pub fn initial() -> Self {
        let data = BTreeMap::new();
        Self {
            version: 0,
            data,
            state_hash: Hash::zero(),
        }
    }

    /// Creates a state whose only entry is `"system"`, holding `run_id` as a
    /// `U64` value.
    ///
    /// The hash is recomputed for that entry, but `version` stays at 0.
    #[must_use]
    pub fn with_run_id(run_id: u64) -> Self {
        let mut seeded = Self::initial();
        let run_entry = StateData::Value(StateValue::U64(run_id));
        seeded.data.insert(String::from("system"), run_entry);
        seeded.rehash();
        seeded
    }

    /// Looks up the entry stored under `domain`, if any.
    #[must_use]
    pub fn get(&self, domain: &str) -> Option<&StateData> {
        self.data.get(domain)
    }

    /// Stores `data` under `domain`, replacing any previous entry.
    ///
    /// Always bumps `version` and recomputes the hash, even when the new
    /// entry equals the one it replaces.
    pub fn set(&mut self, domain: impl Into<String>, data: StateData) {
        let key: String = domain.into();
        self.data.insert(key, data);
        self.version += 1;
        self.rehash();
    }

    /// Removes and returns the entry stored under `domain`.
    ///
    /// `version` and the hash only change when an entry was actually removed;
    /// removing an unknown domain is a no-op that returns `None`.
    pub fn remove(&mut self, domain: &str) -> Option<StateData> {
        let removed = self.data.remove(domain)?;
        self.version += 1;
        self.rehash();
        Some(removed)
    }

    /// Recomputes `state_hash` from `data` via `Hash::from_canonical`.
    fn rehash(&mut self) {
        let fresh = Hash::from_canonical(&self.data);
        self.state_hash = fresh;
    }

    /// Returns `true` when both states carry the same hash stamp.
    ///
    /// Only `state_hash` is compared; `version` and `data` are not inspected.
    #[must_use]
    pub fn matches(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.state_hash, &other.state_hash);
        mine == theirs
    }

    /// Returns the current hash stamp.
    #[must_use]
    pub const fn hash(&self) -> Hash {
        self.state_hash
    }

    /// Returns the current mutation counter.
    #[must_use]
    pub const fn version(&self) -> u64 {
        self.version
    }
}
impl Default for AgentState {
fn default() -> Self {
Self::initial()
}
}
/// Payload stored under one domain of an [`AgentState`], and carried by a
/// [`ToolResponse`].
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum StateData {
/// A single scalar value.
Value(StateValue),
/// String-keyed values (a `BTreeMap`, so iteration is in key order).
Map(BTreeMap<String, StateValue>),
/// An ordered sequence of values.
Vec(Vec<StateValue>),
}
/// Scalar value that can be stored in state data or attached to an
/// [`Observation`].
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum StateValue {
/// Owned text.
String(String),
/// Boolean flag.
Bool(bool),
/// Unsigned 64-bit integer.
U64(u64),
/// Signed 64-bit integer.
I64(i64),
/// Raw byte buffer.
Bytes(Vec<u8>),
/// A [`Hash`] value.
Hash(Hash),
/// No value; `ToolResponse::error` uses this as its placeholder payload.
None,
}
impl StateValue {
#[must_use]
pub fn type_name(&self) -> &str {
match self {
StateValue::String(_) => "string",
StateValue::Bool(_) => "bool",
StateValue::U64(_) => "u64",
StateValue::I64(_) => "i64",
StateValue::Bytes(_) => "bytes",
StateValue::Hash(_) => "hash",
StateValue::None => "none",
}
}
}
impl From<String> for StateValue {
fn from(s: String) -> Self {
Self::String(s)
}
}
impl From<&str> for StateValue {
fn from(s: &str) -> Self {
Self::String(s.to_string())
}
}
impl From<bool> for StateValue {
fn from(b: bool) -> Self {
Self::Bool(b)
}
}
impl From<u64> for StateValue {
fn from(n: u64) -> Self {
Self::U64(n)
}
}
impl From<i64> for StateValue {
fn from(n: i64) -> Self {
Self::I64(n)
}
}
impl From<Hash> for StateValue {
fn from(h: Hash) -> Self {
Self::Hash(h)
}
}
/// Record of one state transition: the hashes that went in, the resulting
/// state, and a hash binding them together.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct StateTransition {
/// Hash the transition starts from (presumably the prior state's hash —
/// confirm with callers).
pub from_hash: Hash,
/// State produced by the transition.
pub to_state: AgentState,
/// Hash of the event associated with the transition.
pub event_hash: Hash,
/// Result of `crate::hash::transition_hash` over `from_hash`, `event_hash`
/// and the hash of `to_state`, as computed in `StateTransition::new`.
pub transition_hash: Hash,
}
impl StateTransition {
    /// Builds a transition record and derives its hash.
    ///
    /// `transition_hash` is computed by `crate::hash::transition_hash` from
    /// `from_hash`, `event_hash` and `to_state.hash()`, in that argument
    /// order.
    #[must_use]
    pub fn new(from_hash: Hash, to_state: AgentState, event_hash: Hash) -> Self {
        // Read the target hash before `to_state` is moved into the record.
        let to_hash = to_state.hash();
        Self {
            transition_hash: crate::hash::transition_hash(from_hash, event_hash, to_hash),
            from_hash,
            to_state,
            event_hash,
        }
    }

    /// Returns the hash computed when the record was built.
    #[must_use]
    pub const fn hash(&self) -> Hash {
        self.transition_hash
    }
}
/// An observation passed to [`StateMachine::transition`]: a kind label, a set
/// of named values, and a logical timestamp.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct Observation {
/// Label identifying what kind of observation this is.
pub kind: String,
/// Named fields of the observation; empty after `Observation::new`.
pub data: BTreeMap<String, StateValue>,
/// Logical time attached to the observation by its creator.
pub logical_time: u64,
}
impl Observation {
    /// Creates an observation of the given `kind` at `logical_time` with no
    /// fields.
    #[must_use]
    pub fn new(kind: impl Into<String>, logical_time: u64) -> Self {
        Self {
            kind: kind.into(),
            data: BTreeMap::new(),
            logical_time,
        }
    }

    /// Returns the observation with `value` stored under `key`.
    ///
    /// A later call with the same `key` replaces the earlier value. The
    /// method consumes `self`, so discarding the return value drops the
    /// whole observation — hence `#[must_use]`.
    #[must_use = "`with_field` consumes the observation and returns the updated one"]
    pub fn with_field(mut self, key: impl Into<String>, value: StateValue) -> Self {
        self.data.insert(key.into(), value);
        self
    }
}
/// Decision attached to a [`Transition`].
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum Decision {
/// No action.
None,
/// Invoke a tool as described by the contained [`ToolCall`].
ToolCall(ToolCall),
/// Propose a patch as described by the contained [`PatchProposal`].
PatchProposal(PatchProposal),
/// Several decisions together; the variant nests, so elements may
/// themselves be `Multiple`.
Multiple(Vec<Decision>),
}
/// Description of a tool invocation.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ToolCall {
/// Name of the tool.
pub tool_name: String,
/// Version of the tool, as a string.
pub tool_version: String,
/// Input for the tool, as a string (its format is not defined in this file).
pub input: String,
/// Capability names attached to the call; empty after `ToolCall::new`,
/// appended to by `ToolCall::with_capability`.
pub capabilities: Vec<String>,
}
impl ToolCall {
    /// Creates a tool call with the given name, version and input and an
    /// empty capability list.
    #[must_use]
    pub fn new(
        tool_name: impl Into<String>,
        tool_version: impl Into<String>,
        input: impl Into<String>,
    ) -> Self {
        Self {
            tool_name: tool_name.into(),
            tool_version: tool_version.into(),
            input: input.into(),
            capabilities: Vec::new(),
        }
    }

    /// Returns the call with `capability` appended to its capability list.
    ///
    /// Capabilities are kept in insertion order and are not de-duplicated.
    /// The method consumes `self`, so discarding the return value drops the
    /// whole call — hence `#[must_use]`.
    #[must_use = "`with_capability` consumes the call and returns the updated one"]
    pub fn with_capability(mut self, capability: impl Into<String>) -> Self {
        self.capabilities.push(capability.into());
        self
    }
}
/// A proposed patch together with the reasoning behind it.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PatchProposal {
/// Kind of patch.
pub patch_type: PatchType,
/// What the patch applies to, as a string.
pub target: String,
/// The patch content, as a string (its format is not defined in this file).
pub patch: String,
/// Explanation for proposing the patch.
pub reasoning: String,
/// Test requirements attached to the proposal; empty after
/// `PatchProposal::new`, which offers no way to fill it other than direct
/// field access.
pub test_requirements: Vec<String>,
}
impl PatchProposal {
    /// Creates a proposal of the given type with its target, patch text and
    /// reasoning; `test_requirements` starts out empty.
    #[must_use]
    pub fn new(
        patch_type: PatchType,
        target: impl Into<String>,
        patch: impl Into<String>,
        reasoning: impl Into<String>,
    ) -> Self {
        // Convert in declaration order, then assemble with field shorthand.
        let target: String = target.into();
        let patch: String = patch.into();
        let reasoning: String = reasoning.into();
        let test_requirements = Vec::new();
        Self {
            patch_type,
            target,
            patch,
            reasoning,
            test_requirements,
        }
    }
}
/// Kind of patch carried by a [`PatchProposal`].
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum PatchType {
/// Patch to a prompt.
Prompt,
/// Patch to a policy.
Policy,
/// Patch to routing.
Routing,
/// Patch to configuration.
Config,
/// Patch to tools.
Tools,
/// Any other kind, described by the contained string.
Other(String),
}
/// Agent behaviour expressed as a state-transition function.
///
/// Implementors must be `Send + Sync`. Both methods take `&self`, so the
/// current state is passed in and the new state is handed back inside the
/// returned [`Transition`].
pub trait StateMachine: Send + Sync {
/// Computes the next [`Transition`] from the current `state`, an
/// `observation`, a slice of `tool_responses` (which may be empty) and the
/// execution `context`.
///
/// # Errors
///
/// Returns a [`StateError`]; which variant and when is up to the
/// implementor.
fn transition(
&self,
state: &AgentState,
observation: &Observation,
tool_responses: &[ToolResponse],
context: &ExecutionContext,
) -> StateResult<Transition>;
/// Returns the state this machine starts from.
fn initial_state(&self) -> AgentState;
}
/// Result alias whose error type is [`StateError`].
pub type StateResult<T> = core::result::Result<T, StateError>;
/// Errors reported through [`StateResult`].
///
/// Each variant's `Display` output is a fixed prefix followed by its payload.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum StateError {
/// A transition between the two named states is not allowed; displayed as
/// `Invalid transition: {from} -> {to}`.
InvalidTransition { from: String, to: String },
/// Displayed as `Missing state: {0}`.
MissingState(String),
/// Displayed as `State corrupted: {0}`.
Corrupted(String),
/// Displayed as `Invalid observation: {0}`.
InvalidObservation(String),
/// Displayed as `Transition failed: {0}`.
TransitionFailed(String),
}
impl fmt::Display for StateError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
StateError::InvalidTransition { from, to } => {
write!(f, "Invalid transition: {} -> {}", from, to)
}
StateError::MissingState(s) => write!(f, "Missing state: {}", s),
StateError::Corrupted(s) => write!(f, "State corrupted: {}", s),
StateError::InvalidObservation(s) => write!(f, "Invalid observation: {}", s),
StateError::TransitionFailed(s) => write!(f, "Transition failed: {}", s),
}
}
}
/// Outcome of [`StateMachine::transition`]: the new state plus the decision
/// taken.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct Transition {
/// The resulting state.
pub state: AgentState,
/// The decision that accompanies the new state.
pub decision: Decision,
/// Hash identifying the transition; `Transition::new` sets it to the hash
/// of `state`.
pub transition_hash: Hash,
}
impl Transition {
    /// Bundles `state` and `decision` into a transition.
    ///
    /// `transition_hash` is a copy of `state.hash()`.
    ///
    /// NOTE(review): `decision` does not contribute to the hash, so two
    /// transitions to the same state with different decisions share a
    /// `transition_hash` — confirm this is intended.
    #[must_use]
    pub fn new(state: AgentState, decision: Decision) -> Self {
        Self {
            // Listed first so the hash is read before `state` is moved in.
            transition_hash: state.hash(),
            state,
            decision,
        }
    }
}
/// Context passed to [`StateMachine::transition`] alongside the state and
/// observation.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ExecutionContext {
/// Logical time of the execution step.
pub logical_time: u64,
/// Identifier of the run.
pub run_id: u64,
/// Optional seed: `None` from `ExecutionContext::new`, `Some` from
/// `ExecutionContext::with_seed`. How it is consumed is not defined in
/// this file.
pub seed: Option<u64>,
}
impl ExecutionContext {
    /// Creates a context without a seed.
    #[must_use]
    pub fn new(logical_time: u64, run_id: u64) -> Self {
        let seed = None;
        Self {
            logical_time,
            run_id,
            seed,
        }
    }

    /// Creates a context that carries `seed`.
    #[must_use]
    pub fn with_seed(logical_time: u64, run_id: u64, seed: u64) -> Self {
        // Same as `new`, with the seed filled in afterwards.
        let mut context = Self::new(logical_time, run_id);
        context.seed = Some(seed);
        context
    }
}
/// Result of a tool invocation, as passed to [`StateMachine::transition`].
///
/// `success` and `error` are independent public fields; only the
/// `ToolResponse::success` and `ToolResponse::error` constructors guarantee
/// that they agree with each other.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ToolResponse {
/// Name of the tool that produced the response.
pub tool_name: String,
/// Payload of the response; `StateData::Value(StateValue::None)` for
/// responses built with `ToolResponse::error`.
pub data: StateData,
/// Whether the tool call succeeded.
pub success: bool,
/// Error message for failed calls; `None` for responses built with
/// `ToolResponse::success`.
pub error: Option<String>,
}
impl ToolResponse {
    /// Creates a successful response carrying `data`; `error` is `None`.
    #[must_use]
    pub fn success(tool_name: impl Into<String>, data: StateData) -> Self {
        let tool_name: String = tool_name.into();
        Self {
            tool_name,
            data,
            success: true,
            error: None,
        }
    }

    /// Creates a failed response carrying the `error` message.
    ///
    /// The payload is the placeholder `StateData::Value(StateValue::None)`.
    #[must_use]
    pub fn error(tool_name: impl Into<String>, error: impl Into<String>) -> Self {
        let tool_name: String = tool_name.into();
        let message: String = error.into();
        let placeholder = StateData::Value(StateValue::None);
        Self {
            tool_name,
            data: placeholder,
            success: false,
            error: Some(message),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh state has version 0 and holds no entries.
    #[test]
    fn test_agent_state_initial() {
        let fresh = AgentState::initial();
        assert_eq!(fresh.version, 0);
        assert!(fresh.data.is_empty());
    }

    /// `set` makes the entry visible through `get` and bumps the version.
    #[test]
    fn test_agent_state_set_get() {
        let mut state = AgentState::initial();
        let payload = StateData::Value(StateValue::U64(42));
        state.set("test", payload);
        assert!(state.get("test").is_some());
        assert_eq!(state.version, 1);
    }

    /// Two states built by the same sequence of writes share a hash.
    #[test]
    fn test_agent_state_rehash() {
        let build = || {
            let mut state = AgentState::initial();
            state.set("key", StateData::Value(StateValue::String("value".to_string())));
            state
        };
        let (first, second) = (build(), build());
        assert_eq!(first.hash(), second.hash());
    }

    /// The builder keeps kind and time and records the added field.
    #[test]
    fn test_observation() {
        let field = StateValue::U64(42);
        let obs = Observation::new("test_observation", 100).with_field("value", field);
        assert_eq!(obs.kind, "test_observation");
        assert_eq!(obs.logical_time, 100);
        assert_eq!(obs.data.len(), 1);
    }
}