use crate::{
Action, ActionId, ActionResult, ActionType, Goal, HierarchicalGoalSolver, LearningConfig,
LearningEngine, Observation, PredictiveConfig, PredictiveModel, StateId,
};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
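/// How the agent chooses actions: explore, exploit, pursue goals, or adapt
/// its exploration rate automatically based on recent success.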
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum OperationMode {
Exploration,
Exploitation,
GoalDriven,
#[default]
Adaptive,
}
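/// Top-level configuration for a [`HopeAgent`], bundling the learning and
/// predictive sub-configs with history limits and goal-handling policy.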
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HopeConfig {
pub learning: LearningConfig,
pub predictive: PredictiveConfig,
pub mode: OperationMode,
pub max_observations: usize,
pub max_actions: usize,
pub anomaly_sensitivity: f64,
pub goal_strategy: GoalSelectionStrategy,
pub auto_decompose_goals: bool,
}
impl Default for HopeConfig {
fn default() -> Self {
Self {
learning: LearningConfig::default(),
predictive: PredictiveConfig::default(),
mode: OperationMode::Adaptive,
max_observations: 1000,
max_actions: 1000,
anomaly_sensitivity: 0.7,
goal_strategy: GoalSelectionStrategy::Priority,
auto_decompose_goals: true,
}
}
}
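/// Policy for picking the next active goal when several are executable.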
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GoalSelectionStrategy {
Priority,
Deadline,
Progress,
RoundRobin,
}
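/// Running counters and derived metrics maintained across the agent's lifetime.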
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentStats {
pub total_steps: u64,
pub learning_updates: u64,
pub episodes_completed: u64,
pub goals_achieved: u64,
pub goals_failed: u64,
pub anomalies_detected: u64,
pub current_epsilon: f64,
pub avg_reward: f64,
pub success_rate: f64,
}
impl Default for AgentStats {
fn default() -> Self {
Self {
total_steps: 0,
learning_updates: 0,
episodes_completed: 0,
goals_achieved: 0,
goals_failed: 0,
anomalies_detected: 0,
current_epsilon: 0.1,
avg_reward: 0.0,
success_rate: 0.0,
}
}
}
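/// Snapshot produced by [`HopeAgent::save_state`] and consumed by
/// [`HopeAgent::load_state`]. The learning engine travels as opaque
/// serialized bytes.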
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedState {
pub config: HopeConfig,
pub stats: AgentStats,
pub current_state: Option<StateId>,
pub active_goal: Option<String>,
pub observation_history: Vec<Observation>,
pub action_history: Vec<Action>,
pub learning_state: Vec<u8>,
}
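/// The result of executing an action in the environment, fed back to the
/// agent via [`HopeAgent::learn`].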
#[derive(Debug, Clone)]
pub struct Outcome {
pub action: Action,
pub result: ActionResult,
pub reward: f64,
pub new_observation: Observation,
pub done: bool,
}
impl Outcome {
pub fn new(
action: Action,
result: ActionResult,
reward: f64,
new_observation: Observation,
done: bool,
) -> Self {
Self {
action,
result,
reward,
new_observation,
done,
}
}
}
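/// The top-level HOPE agent, wiring the learning engine, the hierarchical
/// goal solver, and the predictive model behind a single sense/act/learn
/// interface.
///
/// # Example
///
/// A minimal driving loop (a sketch: the environment, rewards, and
/// termination below stand in for whatever system the agent runs against):
///
/// ```ignore
/// let mut agent = HopeAgent::with_default_config();
/// let obs = Observation::sensor("temperature", 21.0);
/// let action = agent.step(obs);
/// // ... execute `action` in the environment ...
/// let result = ActionResult::success(&action.id);
/// let next_obs = Observation::sensor("temperature", 21.5);
/// agent.learn(Outcome::new(action, result, 1.0, next_obs, false));
/// ```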
pub struct HopeAgent {
learning: LearningEngine,
goal_solver: HierarchicalGoalSolver,
predictive: PredictiveModel,
current_state: Option<StateId>,
active_goal: Option<String>,
observation_history: VecDeque<Observation>,
action_history: VecDeque<Action>,
config: HopeConfig,
stats: AgentStats,
episode_reward: f64,
episode_steps: u64,
available_actions: Vec<ActionId>,
}
impl HopeAgent {
pub fn new(config: HopeConfig) -> Self {
let learning = LearningEngine::new(config.learning.clone());
let goal_solver = HierarchicalGoalSolver::new();
let predictive = PredictiveModel::new(config.predictive.clone());
Self {
learning,
goal_solver,
predictive,
current_state: None,
active_goal: None,
observation_history: VecDeque::with_capacity(config.max_observations),
action_history: VecDeque::with_capacity(config.max_actions),
config,
stats: AgentStats::default(),
episode_reward: 0.0,
episode_steps: 0,
available_actions: Self::default_actions(),
}
}
pub fn with_default_config() -> Self {
Self::new(HopeConfig::default())
}
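    /// One perception-action cycle: record the observation, check it for
    /// anomalies, refresh goal selection, and return the chosen action.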
pub fn step(&mut self, observation: Observation) -> Action {
self.stats.total_steps += 1;
self.episode_steps += 1;
self.update_state(&observation);
if self.predictive.is_anomaly(&observation) {
self.stats.anomalies_detected += 1;
self.handle_anomaly(&observation);
}
self.update_goals(&observation);
let action = self.select_action(&observation);
self.record_action(&action);
action
}
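    /// Feed back the outcome of the last action: updates value estimates,
    /// the predictive model, goal progress, and the experience buffer.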
    pub fn learn(&mut self, outcome: Outcome) {
        // learn() only makes sense after at least one step(); ignore the
        // outcome rather than panicking when no state has been observed yet.
        let Some(prev_state) = self.current_state.clone() else {
            return;
        };
        let action_id = ActionId::from_action(&outcome.action);
        let new_state = StateId::from_observation(&outcome.new_observation);
        self.learning.update(
            &prev_state,
            &action_id,
            outcome.reward,
            &new_state,
            None,
            &self.available_actions,
        );
self.stats.learning_updates += 1;
self.episode_reward += outcome.reward;
        // Record the transition for the predictive model, guarding against an
        // empty history instead of unwrapping.
        if let Some(prev_obs) = self.observation_history.back() {
            self.predictive.record_transition(
                prev_obs,
                &outcome.action,
                outcome.reward,
                &outcome.new_observation,
            );
        }
        self.predictive.learn();
if let Some(goal_id) = &self.active_goal {
if outcome.result.success {
if let Some(goal) = self.goal_solver.get_goal_mut(goal_id) {
                    // Heuristic: positive reward nudges progress forward, by
                    // at most 0.5 per update.
                    let progress_delta = (outcome.reward.max(0.0) * 0.1).min(0.5) as f32;
                    let new_progress = (goal.progress + progress_delta).min(1.0);
goal.set_progress(new_progress);
if new_progress >= 1.0 {
self.goal_solver.mark_achieved(goal_id);
self.stats.goals_achieved += 1;
self.active_goal = None;
}
}
} else if outcome.reward < -5.0 {
self.goal_solver
.mark_failed(goal_id, "Action failed with penalty".to_string());
self.stats.goals_failed += 1;
self.active_goal = None;
}
}
let exp = crate::learning::Experience::new(
prev_state.clone(),
action_id,
outcome.reward,
new_state.clone(),
outcome.done,
);
self.learning.add_experience(exp);
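        // Periodically replay a small batch of stored experiences to
        // stabilize learning.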
if self.stats.total_steps % 10 == 0 {
self.learning.replay_batch(32, &self.available_actions);
}
self.current_state = Some(new_state);
self.record_observation(outcome.new_observation);
if outcome.done {
self.end_episode();
}
if self.config.mode == OperationMode::Adaptive {
self.adapt_mode();
}
}
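    /// Register a goal (optionally auto-decomposing it into subgoals) and
    /// activate it if no goal is currently active. Returns the goal's id.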
pub fn set_goal(&mut self, goal: Goal) -> String {
let id = self.goal_solver.add_goal(goal);
if self.config.auto_decompose_goals {
let _ = self.goal_solver.decompose(&id);
}
if self.active_goal.is_none() {
self.goal_solver.activate_goal(&id);
self.active_goal = Some(id.clone());
}
id
}
pub fn get_statistics(&self) -> &AgentStats {
&self.stats
}
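    /// Snapshot the agent for persistence. The learning engine is serialized
    /// to JSON bytes; on failure the snapshot carries an empty buffer.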
pub fn save_state(&self) -> SerializedState {
let learning_state = serde_json::to_vec(&self.learning).unwrap_or_default();
SerializedState {
config: self.config.clone(),
stats: self.stats.clone(),
current_state: self.current_state.clone(),
active_goal: self.active_goal.clone(),
observation_history: self.observation_history.iter().cloned().collect(),
action_history: self.action_history.iter().cloned().collect(),
learning_state,
}
}
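    /// Restore a snapshot produced by [`Self::save_state`]. An unreadable
    /// learning blob is ignored, keeping the current engine.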
pub fn load_state(&mut self, state: SerializedState) {
self.config = state.config;
self.stats = state.stats;
self.current_state = state.current_state;
self.active_goal = state.active_goal;
self.observation_history = state.observation_history.into();
self.action_history = state.action_history.into();
if let Ok(learning) = serde_json::from_slice(&state.learning_state) {
self.learning = learning;
}
}
pub fn mode(&self) -> OperationMode {
self.config.mode
}
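    /// Switch operation mode, applying a fixed exploration rate for the
    /// non-adaptive modes.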
    pub fn set_mode(&mut self, mode: OperationMode) {
        self.config.mode = mode;
        match mode {
            OperationMode::Exploration => self.learning.config_mut().epsilon = 0.5,
            OperationMode::Exploitation => self.learning.config_mut().epsilon = 0.01,
            OperationMode::GoalDriven => self.learning.config_mut().epsilon = 0.1,
            // Adaptive keeps the current epsilon; adapt_mode() tunes it.
            OperationMode::Adaptive => {}
        }
    }
pub fn current_goal(&self) -> Option<&Goal> {
self.active_goal
.as_ref()
.and_then(|id| self.goal_solver.get_goal(id))
}
pub fn active_goals(&self) -> Vec<&Goal> {
self.goal_solver.get_executable_goals()
}
pub fn reset(&mut self) {
self.current_state = None;
self.episode_reward = 0.0;
self.episode_steps = 0;
}
fn update_state(&mut self, observation: &Observation) {
let state_id = StateId::from_observation(observation);
self.current_state = Some(state_id);
self.record_observation(observation.clone());
self.predictive.record(observation);
}
fn handle_anomaly(&mut self, observation: &Observation) {
log::warn!(
"Anomaly detected in observation: {:?}",
observation.obs_type
);
if self.config.mode == OperationMode::Adaptive {
let old_epsilon = self.learning.config().epsilon;
self.learning.config_mut().epsilon = (old_epsilon * 1.5).min(0.5);
}
}
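    /// Pick a new active goal if none is set, then resolve any goal conflicts.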
fn update_goals(&mut self, _observation: &Observation) {
if self.active_goal.is_none() {
let executable = self.goal_solver.get_executable_goals();
if !executable.is_empty() {
let selected = self.select_goal(&executable);
if let Some(goal) = selected {
let goal_id = goal.id.clone();
self.active_goal = Some(goal_id.clone());
self.goal_solver.activate_goal(&goal_id);
}
}
}
let conflicts = self.goal_solver.detect_conflicts();
for conflict in conflicts {
self.goal_solver.resolve_conflict(
&conflict,
crate::hierarchical::ConflictResolution::PrioritizeFirst,
);
}
}
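    /// Apply the configured [`GoalSelectionStrategy`] to a non-empty slice of
    /// executable goals.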
fn select_goal<'a>(&self, goals: &[&'a Goal]) -> Option<&'a Goal> {
if goals.is_empty() {
return None;
}
match self.config.goal_strategy {
GoalSelectionStrategy::Priority => goals.iter().max_by_key(|g| g.priority).copied(),
GoalSelectionStrategy::Deadline => goals
.iter()
.filter(|g| g.deadline.is_some())
.min_by_key(|g| g.deadline.unwrap())
.or_else(|| goals.first())
.copied(),
            GoalSelectionStrategy::Progress => goals
                .iter()
                .max_by(|a, b| a.progress.total_cmp(&b.progress))
                .copied(),
GoalSelectionStrategy::RoundRobin => {
let idx = (self.stats.total_steps as usize) % goals.len();
Some(goals[idx])
}
}
}
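    /// Choose an action for the current observation. Exploitation is greedy;
    /// all other modes use epsilon-greedy selection.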
fn select_action(&mut self, observation: &Observation) -> Action {
let state_id = StateId::from_observation(observation);
let action_id = match self.config.mode {
OperationMode::Exploration => {
self.learning
.get_action_epsilon_greedy(&state_id, &self.available_actions)
}
OperationMode::Exploitation => {
self.learning
.get_best_action(&state_id, &self.available_actions)
}
OperationMode::GoalDriven | OperationMode::Adaptive => {
self.learning
.get_action_epsilon_greedy(&state_id, &self.available_actions)
}
};
if let Some(action_id) = action_id {
self.action_from_id(&action_id)
} else {
Action::noop()
}
}
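    /// Map an [`ActionId`] back to a concrete [`Action`] by substring matching
    /// on its string form; unknown ids become `Custom` actions.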
fn action_from_id(&self, action_id: &ActionId) -> Action {
let action_str = action_id.as_str();
if action_str.contains("SendMessage") {
Action::new(ActionType::SendMessage("default".to_string()))
} else if action_str.contains("StoreData") {
Action::new(ActionType::StoreData("default".to_string()))
} else if action_str.contains("Alert") {
Action::new(ActionType::Alert("default".to_string()))
} else if action_str.contains("Wait") {
Action::new(ActionType::Wait)
} else if action_str.contains("NoOp") {
Action::new(ActionType::NoOp)
} else {
Action::new(ActionType::Custom(action_str.to_string()))
}
}
fn record_observation(&mut self, observation: Observation) {
if self.observation_history.len() >= self.config.max_observations {
self.observation_history.pop_front();
}
self.observation_history.push_back(observation);
}
fn record_action(&mut self, action: &Action) {
if self.action_history.len() >= self.config.max_actions {
self.action_history.pop_front();
}
self.action_history.push_back(action.clone());
}
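    /// Close out an episode: fold the episode reward into the running average
    /// (avg_n = (avg_{n-1} * (n - 1) + r_n) / n) and refresh derived stats.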
fn end_episode(&mut self) {
self.stats.episodes_completed += 1;
self.learning.end_episode();
let total_episodes = self.stats.episodes_completed as f64;
self.stats.avg_reward =
(self.stats.avg_reward * (total_episodes - 1.0) + self.episode_reward) / total_episodes;
let total_goals = self.stats.goals_achieved + self.stats.goals_failed;
if total_goals > 0 {
self.stats.success_rate = self.stats.goals_achieved as f64 / total_goals as f64;
}
self.stats.current_epsilon = self.learning.epsilon();
self.episode_reward = 0.0;
self.episode_steps = 0;
}
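    /// In adaptive mode, nudge epsilon up when goals keep failing and down
    /// when they mostly succeed, clamped to [0.01, 0.5].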
fn adapt_mode(&mut self) {
let success_rate = self.stats.success_rate;
let epsilon = self.learning.epsilon();
if success_rate < 0.3 && epsilon < 0.2 {
self.learning.config_mut().epsilon = (epsilon * 1.1).min(0.5);
} else if success_rate > 0.8 && epsilon > 0.05 {
self.learning.config_mut().epsilon = (epsilon * 0.9).max(0.01);
}
}
fn default_actions() -> Vec<ActionId> {
vec![
ActionId::from_string("NoOp".to_string()),
ActionId::from_string("Wait".to_string()),
ActionId::from_string("SendMessage".to_string()),
ActionId::from_string("StoreData".to_string()),
ActionId::from_string("Alert".to_string()),
]
}
pub fn learning_engine(&self) -> &LearningEngine {
&self.learning
}
pub fn goal_solver(&self) -> &HierarchicalGoalSolver {
&self.goal_solver
}
pub fn predictive_model(&self) -> &PredictiveModel {
&self.predictive
}
pub fn observation_history(&self) -> &VecDeque<Observation> {
&self.observation_history
}
pub fn action_history(&self) -> &VecDeque<Action> {
&self.action_history
}
}
impl Default for HopeAgent {
fn default() -> Self {
Self::with_default_config()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Goal, GoalStatus, Observation, Priority};
#[test]
fn test_hope_agent_creation() {
let agent = HopeAgent::with_default_config();
assert_eq!(agent.stats.total_steps, 0);
assert_eq!(agent.mode(), OperationMode::Adaptive);
}
#[test]
fn test_step_and_learn_cycle() {
let mut agent = HopeAgent::with_default_config();
let obs1 = Observation::sensor("temperature", 20.0);
let action = agent.step(obs1.clone());
        // Any action, including NoOp, is acceptable for a first step.
let obs2 = Observation::sensor("temperature", 21.0);
let result = ActionResult::success(&action.id);
let outcome = Outcome::new(action, result, 1.0, obs2, false);
agent.learn(outcome);
assert_eq!(agent.stats.total_steps, 1);
assert_eq!(agent.stats.learning_updates, 1);
}
#[test]
fn test_goal_integration() {
let mut agent = HopeAgent::with_default_config();
let goal = Goal::maintain("temperature", 20.0..25.0).with_priority(Priority::High);
let goal_id = agent.set_goal(goal);
assert!(agent.goal_solver.get_goal(&goal_id).is_some());
assert_eq!(agent.active_goal, Some(goal_id));
}
#[test]
fn test_mode_switching() {
let mut agent = HopeAgent::with_default_config();
agent.set_mode(OperationMode::Exploration);
assert_eq!(agent.mode(), OperationMode::Exploration);
agent.set_mode(OperationMode::Exploitation);
assert_eq!(agent.mode(), OperationMode::Exploitation);
}
#[test]
fn test_anomaly_detection() {
let mut agent = HopeAgent::with_default_config();
for i in 0..10 {
let obs = Observation::sensor("temp", 20.0 + i as f64);
agent.step(obs);
}
        let anomaly_obs = Observation::sensor("temp", 1000.0);
        let initial_anomalies = agent.stats.anomalies_detected;
        agent.step(anomaly_obs);
        // Detection depends on how much the model has warmed up, so only
        // assert the counter never moves backwards.
        assert!(agent.stats.anomalies_detected >= initial_anomalies);
}
#[test]
fn test_statistics_tracking() {
let mut agent = HopeAgent::with_default_config();
let obs = Observation::sensor("temp", 20.0);
let action = agent.step(obs.clone());
        let outcome = Outcome::new(
            action,
            ActionResult::success("test"),
            5.0,
            obs,
            true,
        );
agent.learn(outcome);
let stats = agent.get_statistics();
assert_eq!(stats.total_steps, 1);
assert_eq!(stats.episodes_completed, 1);
assert!(stats.avg_reward > 0.0);
}
#[test]
fn test_serialization() {
let mut agent = HopeAgent::with_default_config();
let obs = Observation::sensor("temp", 20.0);
agent.step(obs);
let state = agent.save_state();
assert_eq!(state.stats.total_steps, 1);
let mut new_agent = HopeAgent::with_default_config();
new_agent.load_state(state);
assert_eq!(new_agent.stats.total_steps, 1);
}
#[test]
fn test_multiple_episodes() {
let mut agent = HopeAgent::with_default_config();
for episode in 0..3 {
for step in 0..5 {
let obs = Observation::sensor("temp", 20.0 + step as f64);
let action = agent.step(obs.clone());
let result = ActionResult::success(&action.id);
let reward = if step == 4 { 10.0 } else { 1.0 };
let done = step == 4;
let outcome = Outcome::new(action, result, reward, obs, done);
agent.learn(outcome);
}
if episode < 2 {
agent.reset();
}
}
assert_eq!(agent.stats.episodes_completed, 3);
assert!(agent.stats.avg_reward > 0.0);
}
#[test]
fn test_goal_completion() {
let mut agent = HopeAgent::with_default_config();
let goal = Goal::maintain("test", 20.0..25.0);
let goal_id = agent.set_goal(goal);
for _ in 0..10 {
let obs = Observation::sensor("test", 22.0);
let action = agent.step(obs.clone());
            let outcome = Outcome::new(
                action,
                ActionResult::success("test"),
                2.0,
                obs,
                false,
            );
agent.learn(outcome);
}
let goal_status = agent.goal_solver.get_goal(&goal_id).unwrap().status;
assert!(goal_status == GoalStatus::Achieved || goal_status == GoalStatus::Active);
}
#[test]
fn test_exploration_vs_exploitation() {
let mut agent = HopeAgent::with_default_config();
agent.set_mode(OperationMode::Exploration);
let epsilon_explore = agent.learning_engine().epsilon();
agent.set_mode(OperationMode::Exploitation);
let epsilon_exploit = agent.learning_engine().epsilon();
assert!(epsilon_explore > epsilon_exploit);
}
}