use crate::cognition::knowledge::KnowledgeIndex;
use crate::cognition::learning::config::LearningConfig;
use crate::cognition::learning::distill::KnowledgeDistiller;
use crate::cognition::learning::elastic::ElasticMemoryGuard;
use crate::cognition::learning::federated::FederatedAggregator;
use crate::cognition::learning::informal::InformalLearner;
use crate::cognition::learning::meta::MetaAdapter;
use crate::cognition::learning::q_table::{ActionKey, QTable};
use crate::cognition::memory::ColdStore;
use crate::cognition::signal::CognitionSignal;
use crate::types::{AgentSnapshot, ExperienceRecord, LearningMode};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::HashMap;
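
/// Central learning coordinator combining tabular Q-learning with
/// meta-adaptation, knowledge distillation, federated aggregation,
/// elastic memory protection, and informal (co-occurrence based) learning.
/// Which strategies run on a given step or episode is gated by the
/// `LearningMode`s active in `LearningConfig`.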
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningEngine {
config: LearningConfig,
q_table: QTable,
meta: MetaAdapter,
distiller: KnowledgeDistiller,
aggregator: FederatedAggregator,
elastic: ElasticMemoryGuard,
informal: InformalLearner,
    /// Cumulative reward across every recorded step.
    total_reward: f64,
    /// Number of completed episodes.
    episode_count: usize,
    /// Per-episode experience, cleared in `end_of_episode`.
    experience_buffer: Vec<ExperienceRecord>,
}

impl LearningEngine {
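    /// Builds an engine whose sub-learners are all initialised from `config`.
    ///
    /// A minimal usage sketch, assuming `LearningConfig` implements
    /// `Default` (not shown in this module):
    ///
    /// ```ignore
    /// let engine = LearningEngine::new(LearningConfig::default());
    /// assert_eq!(engine.episode_count(), 0);
    /// ```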
pub fn new(config: LearningConfig) -> Self {
let q_table = QTable::new(
config.alpha,
config.gamma,
config.epsilon,
config.epsilon_decay,
config.epsilon_min,
);
let meta = MetaAdapter::new(config.meta_top_k);
let distiller = KnowledgeDistiller::new(config.distill_threshold, config.distill_top_k);
let aggregator = FederatedAggregator::new(config.federated_blend);
let elastic = ElasticMemoryGuard::new(config.elastic_pin_count, config.elastic_lambda);
        // The final argument (0.5) is a hard-coded weight passed straight to
        // `InformalLearner::new`; its exact meaning is defined there.
        let informal = InformalLearner::new(config.distill_threshold, config.pmi_min_count, 0.5);
Self {
config,
q_table,
meta,
distiller,
aggregator,
elastic,
informal,
total_reward: 0.0,
episode_count: 0,
experience_buffer: Vec::new(),
}
}

    /// Returns the configuration this engine was built with.
    pub fn config(&self) -> &LearningConfig {
        &self.config
    }

    /// Read-only access to the underlying Q-table.
    pub fn q_table(&self) -> &QTable {
        &self.q_table
    }

    /// Overrides the current exploration rate, e.g. when resuming training.
    pub fn reset_epsilon(&mut self, epsilon: f64) {
        self.q_table.reset_epsilon(epsilon);
    }

    /// Number of completed episodes.
    pub fn episode_count(&self) -> usize {
        self.episode_count
    }

    /// Cumulative reward across all recorded steps.
    pub fn total_reward(&self) -> f64 {
        self.total_reward
    }
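
    /// Records a single transition. Depending on the active `LearningMode`s
    /// this updates the Q-table, feeds the informal learner, and refreshes
    /// elastic-memory activations; the raw transition is always buffered so
    /// `end_of_episode` can derive per-action reward offsets.
    ///
    /// A hypothetical call site (`signal`, `state_hash`, `chosen_action`,
    /// and `next_state_hash` come from the surrounding cognition loop):
    ///
    /// ```ignore
    /// engine.record_step(&signal, state_hash, chosen_action, next_state_hash);
    /// ```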
pub fn record_step(
&mut self,
signal: &CognitionSignal,
state: u64,
action: ActionKey,
next_state: u64,
) {
        if self.config.is_mode_active(LearningMode::QTable) {
            self.q_table.update(state, action, signal.reward, next_state);
        }
        if self.config.is_mode_active(LearningMode::Informal) {
            self.informal.observe(&signal.observation, signal.reward);
        }
        // Elastic memory only tracks non-empty observations.
        if self.config.is_mode_active(LearningMode::Elastic) && !signal.observation.is_empty() {
            self.elastic.observe_activation(&signal.observation);
        }
        // Buffer the transition regardless of mode so that
        // `compute_episode_offsets` can summarise the episode later.
        self.experience_buffer.push(ExperienceRecord {
            state,
            action,
            reward: signal.reward,
            next_state,
        });
        self.total_reward += signal.reward;
}
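
    /// Chooses the next action for `state`. When meta-adaptation is active
    /// and the Q-table has no learned estimate for this state yet, the
    /// goal-specific offset with the highest value wins; otherwise selection
    /// falls through to the Q-table's epsilon-greedy policy.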
pub fn recommend_action(&mut self, state: u64, goal: &str, step: usize) -> ActionKey {
        if self.config.is_mode_active(LearningMode::MetaAdapt) {
            let offsets = self.meta.adapt(goal);
            if let Some((&best_action, _)) = offsets
                .iter()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
            {
                // Prefer the meta-adapted action only as a cold-start
                // fallback: learned Q-values for this state take priority.
                if self.q_table.best_action(state).is_none() {
                    return best_action;
                }
            }
        }
self.q_table.select_action(state, step)
}
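
    /// Finalises an episode: distills cold storage into the knowledge index,
    /// records per-action reward offsets for `goal`, synthesises informal
    /// knowledge, and decays exploration, each step gated by its mode. The
    /// experience buffer is cleared afterwards and the episode count advances.
    ///
    /// A hypothetical call at the end of a training loop (`cold_store` and
    /// `knowledge_index` come from the surrounding memory system):
    ///
    /// ```ignore
    /// engine.end_of_episode(&cold_store, &mut knowledge_index, "explore", avg_reward);
    /// ```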
pub fn end_of_episode(
&mut self,
cold: &ColdStore,
index: &mut KnowledgeIndex,
goal: &str,
avg_reward: f64,
) {
        if self.config.is_mode_active(LearningMode::Distill) {
            self.distiller.distill(cold, index);
        }
        if self.config.is_mode_active(LearningMode::MetaAdapt) {
            let offsets = self.compute_episode_offsets();
            self.meta.record_episode(goal, offsets, avg_reward);
        }
        if self.config.is_mode_active(LearningMode::Informal) {
            // Hard-coded limits passed through to
            // `InformalLearner::synthesise_into`; their semantics are defined there.
            self.informal.synthesise_into(index, 5, 0.3);
        }
        if self.config.is_mode_active(LearningMode::QTable) {
            self.q_table.decay_epsilon();
        }
        self.experience_buffer.clear();
        self.episode_count += 1;
}
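
    /// Merges a peer's Q-table into this engine's (federated aggregation),
    /// gated by the `Federated` mode.
    ///
    /// A hypothetical exchange between two engines:
    ///
    /// ```ignore
    /// let snapshot = peer.export_snapshot("peer-1");
    /// engine.federate(&snapshot);
    /// ```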
pub fn federate(&mut self, snapshot: &AgentSnapshot) {
if self.config.is_mode_active(LearningMode::Federated) {
self.aggregator.merge(&mut self.q_table, snapshot);
}
}
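
    /// Packages this agent's Q-table and cumulative reward into an
    /// `AgentSnapshot` suitable for passing to [`LearningEngine::federate`]
    /// on another agent.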
pub fn export_snapshot(&self, agent_id: impl Into<String>) -> AgentSnapshot {
AgentSnapshot {
agent_id: agent_id.into(),
q_table: self.q_table.clone(),
total_reward: self.total_reward,
}
}

    /// Read-only access to the federated aggregator.
    pub fn aggregator(&self) -> &FederatedAggregator {
        &self.aggregator
    }

    /// Read-only access to the elastic memory guard.
    pub fn elastic(&self) -> &ElasticMemoryGuard {
        &self.elastic
    }

    /// Read-only access to the informal learner.
    pub fn informal(&self) -> &InformalLearner {
        &self.informal
    }

    /// Read-only access to the meta adapter.
    pub fn meta(&self) -> &MetaAdapter {
        &self.meta
    }

    /// Read-only access to the knowledge distiller.
    pub fn distiller(&self) -> &KnowledgeDistiller {
        &self.distiller
    }
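
    /// Averages buffered rewards per action for the episode that just ended.
    /// The division is safe: an entry exists only after at least one
    /// experience with that action has been recorded, so `count >= 1`.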
fn compute_episode_offsets(&self) -> HashMap<ActionKey, f64> {
let mut totals: HashMap<ActionKey, (f64, usize)> = HashMap::new();
for xp in &self.experience_buffer {
let e = totals.entry(xp.action).or_insert((0.0, 0));
e.0 += xp.reward;
e.1 += 1;
}
totals
.into_iter()
.map(|(action, (total, count))| (action, total / count as f64))
.collect()
}
}