use serde::{Deserialize, Serialize};
use super::belief::AffectValence;
use super::gate::GateType;
use super::world::GainMode;
/// How deeply the system is able to intervene in the underlying model.
///
/// Variants are declared in increasing order. The derived `PartialOrd`/`Ord` compare
/// by declaration order, so `TextOnly < LogitAccess < ActivationAccess <
/// ArchitectureIntegration < MultiModel`. Reordering variants changes that ordering.
///
/// Serialized in `snake_case` (e.g. `"text_only"`, `"logit_access"`).
/// The default is [`InterventionDepth::TextOnly`].
#[derive(
    Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum InterventionDepth {
    /// Lowest depth; the `Default` variant.
    #[default]
    TextOnly,
    LogitAccess,
    ActivationAccess,
    ArchitectureIntegration,
    /// Highest depth in the ordering.
    MultiModel,
}
/// Snapshot of the agent's cognitive and affective state.
///
/// All fields are plain values. This type enforces no ranges.
/// `Default` yields a calm, neutral baseline: zero arousal, neutral valence and
/// gain mode, and a body budget of `1.0`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CognitiveState {
    /// Current arousal level (default `0.0`).
    pub arousal: f64,
    /// Affective valence (default `AffectValence::Neutral`).
    pub valence: AffectValence,
    /// Certainty estimate (default `0.5`).
    pub certainty: f64,
    /// Sustained arousal (default `0.0`).
    /// NOTE(review): presumably a smoothed/accumulated arousal — confirm.
    pub sustained_arousal: f64,
    /// Gain mode (default `GainMode::Neutral`).
    pub gain_mode: GainMode,
    /// Body budget (default `1.0`).
    pub body_budget: f64,
    /// Sensory prediction-error signal (default `0.0`). NOTE(review): confirm meaning of "pe".
    pub sensory_pe: f64,
    /// Resource pressure (default `0.0`).
    pub resource_pressure: f64,
    /// Volatility of the `pe` signal (default `0.0`).
    pub pe_volatility: f64,
    /// Confidence attached to the gate classification (default `0.5`).
    pub gate_confidence: f64,
    /// Gate classification (default `GateType::Novel`).
    pub gate_type: GateType,
}
impl Default for CognitiveState {
    /// Calm, neutral baseline state.
    ///
    /// Valence and gain mode are neutral and the gate type is `GateType::Novel`.
    /// Certainty and gate confidence both start at `0.5` and the body budget at `1.0`.
    /// Every other scalar (arousal, sustained arousal, sensory PE, resource
    /// pressure, PE volatility) starts at `0.0`.
    fn default() -> Self {
        let half = 0.5;
        Self {
            valence: AffectValence::Neutral,
            gain_mode: GainMode::Neutral,
            gate_type: GateType::Novel,
            certainty: half,
            gate_confidence: half,
            body_budget: 1.0,
            arousal: 0.0,
            sustained_arousal: 0.0,
            sensory_pe: 0.0,
            resource_pressure: 0.0,
            pe_volatility: 0.0,
        }
    }
}
/// Sampling parameters plus per-token logit biases.
///
/// NOTE(review): how these values combine with any base sampler settings is not
/// visible in this module — confirm with the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingOverride {
    /// Sampling temperature (default `0.5`).
    pub temperature: f64,
    /// Top-p cutoff (default `0.9`).
    pub top_p: f64,
    /// Frequency penalty (default `0.0`).
    pub frequency_penalty: f64,
    /// Presence penalty (default `0.0`).
    pub presence_penalty: f64,
    /// Per-token biases (empty by default).
    pub logit_biases: Vec<LogitBias>,
}
impl Default for SamplingOverride {
    /// Baseline sampling settings: temperature `0.5` and `top_p` `0.9`.
    /// Frequency and presence penalties are zero and there are no logit biases.
    fn default() -> Self {
        let no_penalty = 0.0;
        Self {
            logit_biases: Vec::new(),
            presence_penalty: no_penalty,
            frequency_penalty: no_penalty,
            top_p: 0.9,
            temperature: 0.5,
        }
    }
}
/// A bias for a single token, tagged with a label naming where it came from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogitBias {
    /// Id of the biased token.
    pub token_id: u32,
    /// Bias value (may be negative). NOTE(review): presumably added to the token's
    /// logit — confirm with the consumer.
    pub bias: f64,
    /// Free-form origin label (the tests use `"emotional"`).
    pub source: String,
}
/// Inclusive window of layers, `start_layer..=end_layer`, targeted for modulation.
///
/// `total_layers` records the model depth. `contains` and `modulated_count` do not
/// consult it, and nothing checks that the window lies inside `0..total_layers`.
/// An inverted window (`start_layer > end_layer`) selects no layers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerTarget {
    /// First targeted layer (inclusive).
    pub start_layer: usize,
    /// Last targeted layer (inclusive).
    pub end_layer: usize,
    /// Total number of layers in the model (informational here).
    pub total_layers: usize,
}
impl Default for LayerTarget {
    /// Mid-depth window `25..=38` of a 64-layer model (Falcon Mamba's depth).
    ///
    /// The window spans roughly 40%–60% of the layer stack.
    fn default() -> Self {
        const MODEL_DEPTH: usize = 64;
        const FIRST: usize = 25;
        const LAST: usize = 38;
        Self {
            total_layers: MODEL_DEPTH,
            start_layer: FIRST,
            end_layer: LAST,
        }
    }
}
impl LayerTarget {
    /// Returns `true` if `layer` lies in the inclusive window `start_layer..=end_layer`.
    ///
    /// `total_layers` is not consulted. An inverted window (`start_layer > end_layer`)
    /// contains no layers.
    pub fn contains(&self, layer: usize) -> bool {
        (self.start_layer..=self.end_layer).contains(&layer)
    }

    /// Number of layers in the inclusive window, or `0` when the window is inverted.
    ///
    /// `total_layers` is not consulted. The count may therefore include indices at or
    /// beyond `total_layers` if the window extends past the model depth.
    pub fn modulated_count(&self) -> usize {
        match self.end_layer.checked_sub(self.start_layer) {
            Some(span) => span + 1,
            None => 0,
        }
    }
}
/// Identifies which signal a `DeltaModulation` is attributed to.
///
/// Serialized in `snake_case` (e.g. `"gain_mode"`, `"body_budget"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DeltaModulationSource {
    GainMode,
    BodyBudget,
    Volatility,
    Arousal,
    /// NOTE(review): presumably a mix of several of the other sources — confirm.
    Combined,
}
/// A gain factor for a window of layers, tagged with its source signal.
///
/// `Default` is a gain of `1.0`, which the tests treat as pass-through, over the
/// default `LayerTarget` window, attributed to `DeltaModulationSource::GainMode`.
/// NOTE(review): how `gain_factor` is applied (e.g. multiplicatively) is not
/// visible here — confirm with the forward-pass code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaModulation {
    /// Gain to apply; `1.0` by default.
    pub gain_factor: f64,
    /// Layers the gain targets.
    pub target: LayerTarget,
    /// Signal the gain was derived from.
    pub source: DeltaModulationSource,
}
impl Default for DeltaModulation {
    /// Neutral modulation with a gain factor of `1.0` (pass-through).
    ///
    /// It uses the default `LayerTarget` window and is attributed to
    /// `DeltaModulationSource::GainMode`.
    fn default() -> Self {
        let target = LayerTarget::default();
        let source = DeltaModulationSource::GainMode;
        Self {
            gain_factor: 1.0,
            target,
            source,
        }
    }
}
/// Output of a forward pass, together with a record of any modulation applied.
///
/// Use `ForwardResult::from_logits` to build an unmodulated result.
#[derive(Debug, Clone)]
pub struct ForwardResult {
    /// Output logits.
    pub logits: Vec<f32>,
    /// Whether modulation was applied during the pass.
    pub modulation_applied: bool,
    /// Indices of the layers that were modulated (empty when none were).
    pub modulated_layers: Vec<usize>,
    /// Gain factor that was applied (`1.0` for an unmodulated result).
    pub applied_gain_factor: f64,
    /// Optional gate diagnostic. NOTE(review): semantics not visible in this module.
    pub gate_delta_gain: Option<f64>,
    /// Optional gate diagnostic. NOTE(review): semantics not visible in this module.
    pub gate_alpha: Option<f64>,
    /// Hidden-state statistics, when they were computed.
    pub hs_stats: Option<HiddenStateStats>,
}
impl ForwardResult {
    /// Wraps raw `logits` in a result that records no modulation.
    ///
    /// No layers are marked as modulated and the applied gain factor is `1.0`.
    /// All optional diagnostics (`gate_delta_gain`, `gate_alpha`, `hs_stats`) are `None`.
    pub fn from_logits(logits: Vec<f32>) -> Self {
        let (gate_delta_gain, gate_alpha, hs_stats) = (None, None, None);
        Self {
            logits,
            modulated_layers: Vec::new(),
            modulation_applied: false,
            applied_gain_factor: 1.0,
            gate_delta_gain,
            gate_alpha,
            hs_stats,
        }
    }
}
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct HiddenStateStats {
pub state_churn: f64,
pub state_magnitude: f64,
pub valid: bool,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CognitiveSignals {
pub conservation: f64,
pub salience: f64,
pub confidence: f64,
pub strategy: Option<crate::types::world::ResponseStrategy>,
pub gain_mode: GainMode,
pub valence: AffectValence,
pub recent_quality: f64,
pub rpe: f64,
}
#[cfg(test)]
mod tests {
    // Unit tests for the default values, ordering, serde representation and
    // `LayerTarget` window arithmetic defined in this module.
    use super::*;

    #[test]
    fn cognitive_state_default_is_calm_neutral() {
        let state = CognitiveState::default();
        assert_eq!(state.arousal, 0.0);
        assert_eq!(state.valence, AffectValence::Neutral);
        assert_eq!(state.body_budget, 1.0);
        assert_eq!(state.gain_mode, GainMode::Neutral);
        assert_eq!(state.sensory_pe, 0.0);
        assert_eq!(state.resource_pressure, 0.0);
        // Remaining scalar fields, previously unchecked.
        assert_eq!(state.certainty, 0.5);
        assert_eq!(state.sustained_arousal, 0.0);
        assert_eq!(state.pe_volatility, 0.0);
        assert_eq!(state.gate_confidence, 0.5);
    }

    #[test]
    fn sampling_override_default_is_neutral() {
        let s = SamplingOverride::default();
        assert_eq!(s.temperature, 0.5);
        assert_eq!(s.top_p, 0.9);
        assert_eq!(s.frequency_penalty, 0.0);
        assert_eq!(s.presence_penalty, 0.0);
        assert!(s.logit_biases.is_empty());
    }

    #[test]
    fn intervention_depth_ordering() {
        assert!(InterventionDepth::TextOnly < InterventionDepth::LogitAccess);
        assert!(InterventionDepth::LogitAccess < InterventionDepth::ActivationAccess);
        assert!(InterventionDepth::ActivationAccess < InterventionDepth::ArchitectureIntegration);
        assert!(InterventionDepth::ArchitectureIntegration < InterventionDepth::MultiModel);
    }

    #[test]
    fn intervention_depth_default_is_text_only() {
        assert_eq!(InterventionDepth::default(), InterventionDepth::TextOnly);
    }

    #[test]
    fn intervention_depth_serializes_snake_case() {
        let json = serde_json::to_string(&InterventionDepth::ArchitectureIntegration).unwrap();
        assert_eq!(json, "\"architecture_integration\"");
        let restored: InterventionDepth = serde_json::from_str("\"logit_access\"").unwrap();
        assert_eq!(restored, InterventionDepth::LogitAccess);
    }

    #[test]
    fn delta_modulation_source_serializes_snake_case() {
        let json = serde_json::to_string(&DeltaModulationSource::BodyBudget).unwrap();
        assert_eq!(json, "\"body_budget\"");
        let restored: DeltaModulationSource = serde_json::from_str("\"gain_mode\"").unwrap();
        assert_eq!(restored, DeltaModulationSource::GainMode);
    }

    #[test]
    fn logit_bias_construction() {
        let bias = LogitBias {
            token_id: 42,
            bias: -2.5,
            source: "emotional".into(),
        };
        assert_eq!(bias.token_id, 42);
        assert_eq!(bias.bias, -2.5);
        assert_eq!(bias.source, "emotional");
    }

    #[test]
    fn delta_modulation_default_is_no_modulation() {
        let dm = DeltaModulation::default();
        assert_eq!(dm.gain_factor, 1.0, "Default should be pass-through");
        assert_eq!(dm.source, DeltaModulationSource::GainMode);
    }

    #[test]
    fn layer_target_default_targets_midrange() {
        let target = LayerTarget::default();
        assert_eq!(target.total_layers, 64, "Default is Falcon Mamba 64 layers");
        assert!(target.start_layer >= 25, "Should target ~40% depth");
        assert!(target.end_layer <= 40, "Should target ~60% depth");
    }

    #[test]
    fn layer_target_contains() {
        let target = LayerTarget {
            start_layer: 10,
            end_layer: 20,
            total_layers: 64,
        };
        assert!(!target.contains(9));
        assert!(target.contains(10));
        assert!(target.contains(15));
        assert!(target.contains(20));
        assert!(!target.contains(21));
    }

    #[test]
    fn layer_target_modulated_count() {
        let target = LayerTarget {
            start_layer: 10,
            end_layer: 20,
            total_layers: 64,
        };
        assert_eq!(target.modulated_count(), 11);
    }

    #[test]
    fn layer_target_single_layer_window() {
        let target = LayerTarget {
            start_layer: 7,
            end_layer: 7,
            total_layers: 64,
        };
        assert_eq!(target.modulated_count(), 1);
        assert!(target.contains(7));
        assert!(!target.contains(6));
        assert!(!target.contains(8));
    }

    #[test]
    fn layer_target_inverted_window_is_empty() {
        // Exercises the `start_layer > end_layer` branch of `modulated_count`.
        let target = LayerTarget {
            start_layer: 20,
            end_layer: 10,
            total_layers: 64,
        };
        assert_eq!(target.modulated_count(), 0);
        assert!(!target.contains(10));
        assert!(!target.contains(15));
        assert!(!target.contains(20));
    }

    #[test]
    fn forward_result_from_logits_is_unmodulated() {
        let result = ForwardResult::from_logits(vec![1.0, 2.0, 3.0]);
        assert!(!result.modulation_applied);
        assert!(result.modulated_layers.is_empty());
        assert_eq!(result.applied_gain_factor, 1.0);
        assert_eq!(result.logits.len(), 3);
        assert!(result.gate_delta_gain.is_none());
        assert!(result.gate_alpha.is_none());
        assert!(result.hs_stats.is_none());
    }

    #[test]
    fn hidden_state_stats_default_is_invalid() {
        let stats = HiddenStateStats::default();
        assert_eq!(stats.state_churn, 0.0);
        assert_eq!(stats.state_magnitude, 0.0);
        assert!(!stats.valid, "Default stats should be invalid (no previous hs)");
    }

    #[test]
    fn hidden_state_stats_serde_round_trip() {
        let stats = HiddenStateStats {
            state_churn: 0.73,
            state_magnitude: 2.45,
            valid: true,
        };
        let json = serde_json::to_string(&stats).unwrap();
        let restored: HiddenStateStats = serde_json::from_str(&json).unwrap();
        assert!((restored.state_churn - stats.state_churn).abs() < 1e-10);
        assert!((restored.state_magnitude - stats.state_magnitude).abs() < 1e-10);
        assert_eq!(restored.valid, stats.valid);
    }
}