//! aip-sci 0.1.0
//!
//! Affective Interaction Programming (情感交互编程).
//!
//! Documentation: actor-critic interaction director built on `candle_core`.
use crate::director::core::{
    DirectorError, InteractionParams, InteractionState, InteractionStrategy,
    InteractionStrategyTrainer, Trajectory, TrainingResult,
};
use candle_core::{Device, Tensor, DType};
use std::path::Path;
use std::collections::HashMap;

/// Length of the encoded state vector: 8 user traits + 6 environment values
/// + 3 current-emotion values (VAD) + 6 emotion statistics = 23.
const STATE_DIM: usize = 23;
/// Number of interaction parameters emitted by the actor network
/// (one per field of `InteractionParams`).
const ACTION_DIM: usize = 7;
/// Width of the hidden layers in both the actor and critic MLPs.
const HIDDEN_DIM: usize = 128;

/// Element-wise ReLU: `max(x, 0)`, computed against a zero tensor with the
/// same shape, dtype, and device as `x`.
fn relu(x: &Tensor) -> candle_core::Result<Tensor> {
    let zero = x.zeros_like()?;
    x.maximum(&zero)
}

/// Element-wise logistic sigmoid: `1 / (1 + e^(-x))`, mapping any real
/// input into the open interval (0, 1).
fn sigmoid(x: &Tensor) -> candle_core::Result<Tensor> {
    let exp_neg = x.neg()?.exp()?;
    let denom = (exp_neg + 1.0)?;
    denom.recip()
}

/// Actor-critic MLP pair that maps an `InteractionState` to `InteractionParams`.
///
/// Both networks are 3-layer fully-connected MLPs
/// (STATE_DIM -> HIDDEN_DIM -> HIDDEN_DIM -> output). Weight tensors are
/// stored in `(out_features, in_features)` layout; the forward passes
/// multiply by the transpose. The actor's output head has ACTION_DIM units,
/// the critic's a single state-value unit.
pub struct RogueliteDirector {
    // Device all tensors live on (CPU or GPU).
    pub(crate) device: Device,
    // Actor network: encoded state -> raw action activations.
    pub(crate) actor_fc1_weight: Tensor,
    pub(crate) actor_fc1_bias: Tensor,
    pub(crate) actor_fc2_weight: Tensor,
    pub(crate) actor_fc2_bias: Tensor,
    pub(crate) actor_fc3_weight: Tensor,
    pub(crate) actor_fc3_bias: Tensor,
    // Critic network: encoded state -> scalar state value.
    pub(crate) critic_fc1_weight: Tensor,
    pub(crate) critic_fc1_bias: Tensor,
    pub(crate) critic_fc2_weight: Tensor,
    pub(crate) critic_fc2_bias: Tensor,
    pub(crate) critic_fc3_weight: Tensor,
    pub(crate) critic_fc3_bias: Tensor,
}

impl RogueliteDirector {
    /// Builds a director on `device` with every weight and bias zero-initialised.
    ///
    /// With all-zero parameters the actor emits `sigmoid(0) = 0.5` on every
    /// output regardless of input, so a fresh director is a neutral
    /// placeholder until real weights arrive via `load` (or training).
    ///
    /// # Errors
    /// Returns `DirectorError::ModelError` if tensor allocation fails.
    pub fn new(device: Device) -> Result<Self, DirectorError> {
        // Zero-tensor helper shared by all twelve parameters; generic over the
        // shape so it covers both matrix and vector tensors.
        fn z<S: Into<candle_core::Shape>>(
            shape: S,
            device: &Device,
        ) -> Result<Tensor, DirectorError> {
            Tensor::zeros(shape, DType::F32, device)
                .map_err(|e| DirectorError::ModelError(e.to_string()))
        }

        Ok(Self {
            actor_fc1_weight: z((HIDDEN_DIM, STATE_DIM), &device)?,
            actor_fc1_bias: z(HIDDEN_DIM, &device)?,
            actor_fc2_weight: z((HIDDEN_DIM, HIDDEN_DIM), &device)?,
            actor_fc2_bias: z(HIDDEN_DIM, &device)?,
            actor_fc3_weight: z((ACTION_DIM, HIDDEN_DIM), &device)?,
            actor_fc3_bias: z(ACTION_DIM, &device)?,
            critic_fc1_weight: z((HIDDEN_DIM, STATE_DIM), &device)?,
            critic_fc1_bias: z(HIDDEN_DIM, &device)?,
            critic_fc2_weight: z((HIDDEN_DIM, HIDDEN_DIM), &device)?,
            critic_fc2_bias: z(HIDDEN_DIM, &device)?,
            critic_fc3_weight: z((1, HIDDEN_DIM), &device)?,
            critic_fc3_bias: z(1, &device)?,
            device,
        })
    }

    /// Encodes an `InteractionState` into a `(1, STATE_DIM)` f32 tensor.
    ///
    /// Layout: 8 user traits, 6 environment values, 3 current-emotion values
    /// (valence/arousal/dominance, or 0.5 placeholders without the `edm`
    /// feature), then 6 emotion statistics. Missing map entries default to a
    /// neutral 0.5.
    fn encode_state(&self, state: &InteractionState) -> candle_core::Result<Tensor> {
        let mut state_vec = Vec::with_capacity(STATE_DIM);

        // 8 user traits, keyed 0..=7; absent traits read as neutral 0.5.
        for i in 0..8u32 {
            state_vec.push(*state.user_traits.get(&i).unwrap_or(&0.5));
        }

        // 6 environment values, keyed 0..=5.
        for i in 0..6u32 {
            state_vec.push(*state.env_state.get(&i).unwrap_or(&0.5));
        }

        // Current emotion (VAD axes). The constant placeholders without the
        // `edm` feature keep STATE_DIM identical across feature combinations.
        #[cfg(feature = "edm")]
        {
            state_vec.push(state.emotion.valence);
            state_vec.push(state.emotion.arousal);
            state_vec.push(state.emotion.dominance);
        }

        #[cfg(not(feature = "edm"))]
        {
            state_vec.extend([0.5, 0.5, 0.5]);
        }

        // Aggregated emotion statistics (mean/std for each VAD axis).
        state_vec.push(state.emotion_stats.valence_mean);
        state_vec.push(state.emotion_stats.valence_std);
        state_vec.push(state.emotion_stats.arousal_mean);
        state_vec.push(state.emotion_stats.arousal_std);
        state_vec.push(state.emotion_stats.dominance_mean);
        state_vec.push(state.emotion_stats.dominance_std);

        Tensor::from_vec(state_vec, (1, STATE_DIM), &self.device)
    }

    /// Actor forward pass: `(1, STATE_DIM)` state -> `(1, ACTION_DIM)`
    /// activations, each squashed into (0, 1) by the final sigmoid.
    ///
    /// Weights are stored `(out, in)`, hence the transpose before `matmul`.
    fn actor_forward(&self, state_tensor: &Tensor) -> candle_core::Result<Tensor> {
        let x = state_tensor.matmul(&self.actor_fc1_weight.t()?)?;
        let x = x.broadcast_add(&self.actor_fc1_bias)?;
        let x = relu(&x)?;

        let x = x.matmul(&self.actor_fc2_weight.t()?)?;
        let x = x.broadcast_add(&self.actor_fc2_bias)?;
        let x = relu(&x)?;

        let x = x.matmul(&self.actor_fc3_weight.t()?)?;
        let x = x.broadcast_add(&self.actor_fc3_bias)?;
        sigmoid(&x)
    }

    /// Critic forward pass: `(1, STATE_DIM)` state -> `(1, 1)` raw state
    /// value. No output activation — the value head is linear.
    fn critic_forward(&self, state_tensor: &Tensor) -> candle_core::Result<Tensor> {
        let x = state_tensor.matmul(&self.critic_fc1_weight.t()?)?;
        let x = x.broadcast_add(&self.critic_fc1_bias)?;
        let x = relu(&x)?;

        let x = x.matmul(&self.critic_fc2_weight.t()?)?;
        let x = x.broadcast_add(&self.critic_fc2_bias)?;
        let x = relu(&x)?;

        let x = x.matmul(&self.critic_fc3_weight.t()?)?;
        x.broadcast_add(&self.critic_fc3_bias)
    }

    /// Maps the actor's ACTION_DIM sigmoid activations (each in (0, 1)) onto
    /// the calibrated parameter ranges:
    /// intensity 0.5..2.0, feedback 0.3..1.5, pace 0.6..1.8, scarcity 0..1,
    /// arousal 0.3..1.0, rhythm 0.8..1.5, challenge -1..1.
    fn tensor_to_params(&self, tensor: &Tensor) -> candle_core::Result<InteractionParams> {
        // Drop the batch dimension; the actor always emits exactly
        // ACTION_DIM values, so direct indexing below cannot go out of bounds.
        let tensor = tensor.squeeze(0)?;
        let values: Vec<f32> = tensor.to_vec1()?;
        Ok(InteractionParams {
            intensity_factor: 0.5 + values[0] * 1.5,
            feedback_intensity: 0.3 + values[1] * 1.2,
            pace_speed: 0.6 + values[2] * 1.2,
            reward_scarcity: values[3],
            env_arousal: 0.3 + values[4] * 0.7,
            rhythm_modulation: 0.8 + values[5] * 0.7,
            challenge_curve: values[6] * 2.0 - 1.0,
        })
    }
}

impl InteractionStrategy for RogueliteDirector {
    fn decide(&self, state: &InteractionState) -> Result<InteractionParams, DirectorError> {
        let state_tensor = self.encode_state(state)
            .map_err(|e| DirectorError::ModelError(e.to_string()))?;
        
        let action_tensor = self.actor_forward(&state_tensor)
            .map_err(|e| DirectorError::ModelError(e.to_string()))?;
        
        let mut params = self.tensor_to_params(&action_tensor)
            .map_err(|e| DirectorError::ModelError(e.to_string()))?;
        params.clamp();
        Ok(params)
    }
    
    fn load(&mut self, path: &Path) -> Result<(), DirectorError> {
        let weights = candle_core::safetensors::load(path, &self.device)
            .map_err(|e| DirectorError::ModelError(e.to_string()))?;
        
        self.actor_fc1_weight = weights.get("actor_fc1_weight")
            .ok_or_else(|| DirectorError::ModelError("actor_fc1_weight not found".into()))?
            .clone();
        self.actor_fc1_bias = weights.get("actor_fc1_bias")
            .ok_or_else(|| DirectorError::ModelError("actor_fc1_bias not found".into()))?
            .clone();
        self.actor_fc2_weight = weights.get("actor_fc2_weight")
            .ok_or_else(|| DirectorError::ModelError("actor_fc2_weight not found".into()))?
            .clone();
        self.actor_fc2_bias = weights.get("actor_fc2_bias")
            .ok_or_else(|| DirectorError::ModelError("actor_fc2_bias not found".into()))?
            .clone();
        self.actor_fc3_weight = weights.get("actor_fc3_weight")
            .ok_or_else(|| DirectorError::ModelError("actor_fc3_weight not found".into()))?
            .clone();
        self.actor_fc3_bias = weights.get("actor_fc3_bias")
            .ok_or_else(|| DirectorError::ModelError("actor_fc3_bias not found".into()))?
            .clone();
        self.critic_fc1_weight = weights.get("critic_fc1_weight")
            .ok_or_else(|| DirectorError::ModelError("critic_fc1_weight not found".into()))?
            .clone();
        self.critic_fc1_bias = weights.get("critic_fc1_bias")
            .ok_or_else(|| DirectorError::ModelError("critic_fc1_bias not found".into()))?
            .clone();
        self.critic_fc2_weight = weights.get("critic_fc2_weight")
            .ok_or_else(|| DirectorError::ModelError("critic_fc2_weight not found".into()))?
            .clone();
        self.critic_fc2_bias = weights.get("critic_fc2_bias")
            .ok_or_else(|| DirectorError::ModelError("critic_fc2_bias not found".into()))?
            .clone();
        self.critic_fc3_weight = weights.get("critic_fc3_weight")
            .ok_or_else(|| DirectorError::ModelError("critic_fc3_weight not found".into()))?
            .clone();
        self.critic_fc3_bias = weights.get("critic_fc3_bias")
            .ok_or_else(|| DirectorError::ModelError("critic_fc3_bias not found".into()))?
            .clone();
        
        Ok(())
    }
}

impl InteractionStrategyTrainer for RogueliteDirector {
    fn train(&mut self, _trajectories: &[Trajectory]) -> Result<TrainingResult, DirectorError> {
        todo!("PPO training implementation")
    }
    
    fn save(&self, path: &Path) -> Result<(), DirectorError> {
        let weights = HashMap::from([
            ("actor_fc1_weight", self.actor_fc1_weight.clone()),
            ("actor_fc1_bias", self.actor_fc1_bias.clone()),
            ("actor_fc2_weight", self.actor_fc2_weight.clone()),
            ("actor_fc2_bias", self.actor_fc2_bias.clone()),
            ("actor_fc3_weight", self.actor_fc3_weight.clone()),
            ("actor_fc3_bias", self.actor_fc3_bias.clone()),
            ("critic_fc1_weight", self.critic_fc1_weight.clone()),
            ("critic_fc1_bias", self.critic_fc1_bias.clone()),
            ("critic_fc2_weight", self.critic_fc2_weight.clone()),
            ("critic_fc2_bias", self.critic_fc2_bias.clone()),
            ("critic_fc3_weight", self.critic_fc3_weight.clone()),
            ("critic_fc3_bias", self.critic_fc3_bias.clone()),
        ]);
        
        candle_core::safetensors::save(&weights, path)
            .map_err(|e| DirectorError::ModelError(e.to_string()))?;
        
        Ok(())
    }
    
    fn load(&mut self, path: &Path) -> Result<(), DirectorError> {
        InteractionStrategy::load(self, path)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap as StdHashMap;

    /// Builds a neutral (all-0.5) state covering every field the encoder reads.
    fn neutral_state() -> InteractionState {
        let half_map = |n: u32| -> StdHashMap<u32, f32> {
            (0..n).map(|i| (i, 0.5)).collect()
        };
        InteractionState {
            user_traits: half_map(8),
            env_state: half_map(6),
            #[cfg(feature = "edm")]
            emotion: crate::edm::core::EmotionState::new(0.5, 0.5, 0.5),
            emotion_stats: crate::director::core::EmotionStats::default(),
        }
    }

    /// A fresh director should construct without error on CPU.
    #[test]
    fn test_director_new() {
        let model = RogueliteDirector::new(Device::Cpu);
        assert!(model.is_ok());
    }

    /// `decide` must succeed on a neutral state and produce parameters inside
    /// the ranges established by `tensor_to_params`.
    #[test]
    fn test_director_decide() {
        let model = RogueliteDirector::new(Device::Cpu).unwrap();

        let result = model.decide(&neutral_state());
        if let Err(ref e) = result {
            eprintln!("Error: {:?}", e);
        }
        assert!(result.is_ok());

        let params = result.unwrap();
        assert!(params.intensity_factor >= 0.5 && params.intensity_factor <= 2.0);
        assert!(params.feedback_intensity >= 0.3 && params.feedback_intensity <= 1.5);
        assert!(params.pace_speed >= 0.6 && params.pace_speed <= 1.8);
        assert!(params.reward_scarcity >= 0.0 && params.reward_scarcity <= 1.0);
        assert!(params.env_arousal >= 0.3 && params.env_arousal <= 1.0);
        assert!(params.rhythm_modulation >= 0.8 && params.rhythm_modulation <= 1.5);
        assert!(params.challenge_curve >= -1.0 && params.challenge_curve <= 1.0);
    }

    /// Round-trips the weights through a safetensors file. Uses the system
    /// temp directory instead of writing a `models/` folder into the working
    /// tree, and removes the file afterwards.
    #[test]
    fn test_director_save_load() {
        let device = Device::Cpu;
        let model = RogueliteDirector::new(device.clone()).unwrap();

        let model_path = std::env::temp_dir().join("aip_sci_test_director_model.safetensors");

        assert!(model.save(&model_path).is_ok());

        let mut model2 = RogueliteDirector::new(device).unwrap();
        assert!(InteractionStrategy::load(&mut model2, &model_path).is_ok());

        // Best-effort cleanup; ignore failure (e.g. parallel test runs).
        let _ = std::fs::remove_file(&model_path);
    }
}