aip-sci 0.1.0

Affective Interaction Programming
use crate::edm::core::{EdmError, EmotionDataModel, EmotionState, EmotionDataModelTrainer, TrainingDataset, TrainingResult};
use crate::edm::features::{FeatureExtractor, FEATURE_COUNT};
use candle_core::{Device, Tensor, DType};
use std::collections::HashMap;
use std::path::Path;

/// Intended input window length; not yet referenced (`infer` feeds one step).
#[allow(dead_code)]
const HISTORY_STEPS: usize = 5;
/// Output channels of the Conv1d front end.
const CONV_FILTERS: usize = 32;
/// Hidden size of the LSTM layer.
const LSTM_HIDDEN: usize = 64;
/// Hidden size of the MLP head.
const MLP_HIDDEN: usize = 64;
/// Channel reduction ratio in the squeeze-and-excitation block.
const SE_REDUCTION: usize = 8;

/// Element-wise ReLU: max(x, 0).
fn relu(x: &Tensor) -> candle_core::Result<Tensor> {
    x.maximum(&x.zeros_like()?)
}

/// Element-wise logistic sigmoid: 1 / (1 + exp(-x)).
fn sigmoid(x: &Tensor) -> candle_core::Result<Tensor> {
    (x.neg()?.exp()? + 1.0)?.recip()
}

/// Numerically stable softmax along `dim`. Candle's arithmetic operators do
/// not broadcast implicitly, so the keep-dim max and sum must be combined via
/// the `broadcast_*` variants.
fn softmax(x: &Tensor, dim: usize) -> candle_core::Result<Tensor> {
    let max = x.max_keepdim(dim)?;
    let exp = x.broadcast_sub(&max)?.exp()?;
    let sum = exp.sum_keepdim(dim)?;
    exp.broadcast_div(&sum)
}

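/// Emotion data model: a Conv1d front end with squeeze-and-excitation channel
/// attention, an LSTM over time, single-head self-attention, and a two-layer
/// MLP head that emits a (valence, arousal, dominance) triple in [0, 1].
///
/// A minimal usage sketch, assuming trained weights exist at the given path
/// (the path and feature IDs below are illustrative):
///
/// ```ignore
/// use std::{collections::HashMap, path::Path};
///
/// let device = candle_core::Device::Cpu;
/// let mut model = RogueliteEdm::new(device)?;
/// EmotionDataModel::load(&mut model, Path::new("models/edm.safetensors"))?;
///
/// let features: HashMap<u32, f32> = HashMap::from([(0, 0.5), (1, 0.8)]);
/// let state = model.infer(&features)?;
/// println!("V={} A={} D={}", state.valence, state.arousal, state.dominance);
/// ```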
pub struct RogueliteEdm {
    pub(crate) device: Device,
    pub(crate) conv1_weight: Tensor,
    pub(crate) conv1_bias: Tensor,
    pub(crate) se_fc1_weight: Tensor,
    pub(crate) se_fc1_bias: Tensor,
    pub(crate) se_fc2_weight: Tensor,
    pub(crate) se_fc2_bias: Tensor,
    pub(crate) lstm_weight_ih: Tensor,
    pub(crate) lstm_weight_hh: Tensor,
    pub(crate) lstm_bias: Tensor,
    pub(crate) attn_query_weight: Tensor,
    pub(crate) attn_key_weight: Tensor,
    pub(crate) attn_value_weight: Tensor,
    pub(crate) fc1_weight: Tensor,
    pub(crate) fc1_bias: Tensor,
    pub(crate) fc2_weight: Tensor,
    pub(crate) fc2_bias: Tensor,
}

impl RogueliteEdm {
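    /// Builds the model with every parameter zero-initialized. A fresh model
    /// therefore outputs sigmoid(0) = 0.5 on all three dimensions; trained
    /// weights should be loaded via `load` before meaningful inference.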
    pub fn new(device: Device) -> Result<Self, EdmError> {
        let conv1_weight = Tensor::zeros((CONV_FILTERS, FEATURE_COUNT, 3), DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        let conv1_bias = Tensor::zeros(CONV_FILTERS, DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        
        let se_fc1_weight = Tensor::zeros((CONV_FILTERS / SE_REDUCTION, CONV_FILTERS), DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        let se_fc1_bias = Tensor::zeros(CONV_FILTERS / SE_REDUCTION, DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        let se_fc2_weight = Tensor::zeros((CONV_FILTERS, CONV_FILTERS / SE_REDUCTION), DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        let se_fc2_bias = Tensor::zeros(CONV_FILTERS, DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        
        let lstm_weight_ih = Tensor::zeros((4 * LSTM_HIDDEN, CONV_FILTERS), DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        let lstm_weight_hh = Tensor::zeros((4 * LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        let lstm_bias = Tensor::zeros(4 * LSTM_HIDDEN, DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        
        let attn_query_weight = Tensor::zeros((LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        let attn_key_weight = Tensor::zeros((LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        let attn_value_weight = Tensor::zeros((LSTM_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        
        let fc1_weight = Tensor::zeros((MLP_HIDDEN, LSTM_HIDDEN), DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        let fc1_bias = Tensor::zeros(MLP_HIDDEN, DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        
        let fc2_weight = Tensor::zeros((3, MLP_HIDDEN), DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        let fc2_bias = Tensor::zeros(3, DType::F32, &device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        
        Ok(Self {
            device,
            conv1_weight,
            conv1_bias,
            se_fc1_weight,
            se_fc1_bias,
            se_fc2_weight,
            se_fc2_bias,
            lstm_weight_ih,
            lstm_weight_hh,
            lstm_bias,
            attn_query_weight,
            attn_key_weight,
            attn_value_weight,
            fc1_weight,
            fc1_bias,
            fc2_weight,
            fc2_bias,
        })
    }
    
    fn conv1d_forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
        let (batch, in_ch, seq_len) = input.dims3()?;
        let kernel_size = 3;

        // "Same" padding: one zero step on each side of the time axis.
        let input_padded = input.pad_with_zeros(2, 1, 1)?;

        // Unfold every kernel-sized window along the time axis.
        let mut unfolded = Vec::new();
        for i in 0..seq_len {
            unfolded.push(input_padded.narrow(2, i, kernel_size)?);
        }
        // (batch, in_ch, seq_len, kernel_size)
        let unfolded = Tensor::stack(&unfolded, 2)?;

        // Reorder to (batch, seq_len, in_ch, kernel_size) before flattening so
        // each matmul row holds one window's (channel, tap) values; a direct
        // reshape of the stacked layout would scramble windows for seq_len > 1.
        let unfolded = unfolded
            .permute((0, 2, 1, 3))?
            .contiguous()?
            .reshape((batch * seq_len, in_ch * kernel_size))?;

        // Express the convolution as one matmul against the flattened kernels.
        let weight = self.conv1_weight.reshape((CONV_FILTERS, in_ch * kernel_size))?;
        let output = unfolded.matmul(&weight.t()?)?;

        // Back to (batch, CONV_FILTERS, seq_len) and add the per-channel bias.
        let output = output
            .reshape((batch, seq_len, CONV_FILTERS))?
            .permute((0, 2, 1))?;
        let bias = self.conv1_bias.reshape((1, CONV_FILTERS, 1))?;
        output.broadcast_add(&bias)
    }
    
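    /// Squeeze-and-excitation channel attention: mean-pool over time
    /// ("squeeze"), run a bottleneck MLP ending in a sigmoid ("excitation"),
    /// and re-scale each input channel by the resulting gate.
    /// Input and output are both (batch, channels, seq_len).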
    fn se_attention(&self, input: &Tensor) -> candle_core::Result<Tensor> {
        let (batch, channels, _seq_len) = input.dims3()?;
        
        let squeeze = input.mean(2)?;
        
        let excitation = squeeze.matmul(&self.se_fc1_weight.t()?)?;
        let excitation = excitation.broadcast_add(&self.se_fc1_bias)?;
        let excitation = relu(&excitation)?;
        
        let excitation = excitation.matmul(&self.se_fc2_weight.t()?)?;
        let excitation = excitation.broadcast_add(&self.se_fc2_bias)?;
        let scale = sigmoid(&excitation)?;
        
        let scale = scale.reshape((batch, channels, 1))?;
        input.broadcast_mul(&scale)
    }
    
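    /// Single-layer LSTM unrolled over the time axis. The input, hidden, and
    /// bias projections are computed as one (4 * LSTM_HIDDEN)-wide affine map
    /// and split into the usual i/f/g/o gates. Input is
    /// (seq_len, batch, CONV_FILTERS); output is (seq_len, batch, LSTM_HIDDEN).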
    fn lstm_forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
        let (seq_len, batch, _) = input.dims3()?;
        let mut h = Tensor::zeros((batch, LSTM_HIDDEN), DType::F32, &self.device)?;
        let mut c = Tensor::zeros((batch, LSTM_HIDDEN), DType::F32, &self.device)?;
        
        let mut outputs = Vec::new();
        
        for t in 0..seq_len {
            let x_t = input.narrow(0, t, 1)?.squeeze(0)?;
            
            let gates = x_t.matmul(&self.lstm_weight_ih.t()?)?
                .broadcast_add(&h.matmul(&self.lstm_weight_hh.t()?)?)?
                .broadcast_add(&self.lstm_bias.reshape((1, 4 * LSTM_HIDDEN))?)?;
            
            let gates = gates.chunk(4, 1)?;
            let i = sigmoid(&gates[0])?;
            let f = sigmoid(&gates[1])?;
            let g = gates[2].tanh()?;
            let o = sigmoid(&gates[3])?;
            
            c = ((&f * &c)? + (&i * &g)?)?;
            h = (&o * &c.tanh()?)?;
            
            outputs.push(h.clone());
        }
        
        Tensor::stack(&outputs, 0)
    }
    
    /// Single-head scaled dot-product self-attention over the LSTM outputs.
    fn self_attention(&self, input: &Tensor) -> candle_core::Result<Tensor> {
        let (seq_len, batch, hidden) = input.dims3()?;

        // Reorder (seq, batch, hidden) -> (batch, seq, hidden) before
        // flattening; reshaping the (seq, batch) layout straight into
        // (batch, seq) rows would mix samples whenever batch > 1.
        let input_bsh = input.permute((1, 0, 2))?.contiguous()?;
        let input_flat = input_bsh.reshape((batch * seq_len, hidden))?;

        let q = input_flat.matmul(&self.attn_query_weight)?.reshape((batch, seq_len, hidden))?;
        let k = input_flat.matmul(&self.attn_key_weight)?.reshape((batch, seq_len, hidden))?;
        let v = input_flat.matmul(&self.attn_value_weight)?.reshape((batch, seq_len, hidden))?;

        // Scores scaled by 1/sqrt(d); softmax over the key axis.
        let scale = 1.0 / (hidden as f64).sqrt();
        let scores = (q.matmul(&k.transpose(1, 2)?)? * scale)?;
        let attn_weights = softmax(&scores, 2)?;

        // Weighted sum of values, then back to (seq, batch, hidden).
        let output = attn_weights.matmul(&v)?;
        output.permute((1, 0, 2))?.contiguous()
    }
    
    /// Full pipeline: Conv1d -> ReLU -> SE -> LSTM -> self-attention ->
    /// mean-pool over time -> two-layer MLP -> sigmoid (three outputs in [0, 1]).
    fn forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
        // (batch, FEATURE_COUNT, seq) -> (batch, CONV_FILTERS, seq)
        let x = self.conv1d_forward(input)?;
        let x = relu(&x)?;

        // Channel re-weighting; shape unchanged.
        let x = self.se_attention(&x)?;

        // (batch, CONV_FILTERS, seq) -> (seq, batch, CONV_FILTERS); made
        // contiguous so the per-step matmuls inside the LSTM see plain layouts.
        let x = x.permute((2, 0, 1))?.contiguous()?;

        let lstm_out = self.lstm_forward(&x)?;
        let attn_out = self.self_attention(&lstm_out)?;

        // Mean-pool over the time axis -> (batch, LSTM_HIDDEN).
        let hidden = attn_out.mean(0)?;

        // MLP head: LSTM_HIDDEN -> MLP_HIDDEN -> 3, squashed to [0, 1].
        let x = hidden.matmul(&self.fc1_weight.t()?)?;
        let x = x.broadcast_add(&self.fc1_bias)?;
        let x = relu(&x)?;

        let x = x.matmul(&self.fc2_weight.t()?)?;
        let x = x.broadcast_add(&self.fc2_bias)?;
        sigmoid(&x)
    }
}

impl EmotionDataModel for RogueliteEdm {
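    /// Converts the sparse feature map into a fixed-order vector and runs one
    /// forward pass over a single time step (input shape (1, FEATURE_COUNT, 1)),
    /// so each call scores one snapshot rather than a HISTORY_STEPS-long window.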
    fn infer(&self, features: &HashMap<u32, f32>) -> Result<EmotionState, EdmError> {
        let feature_vec = FeatureExtractor::to_vector(features);
        let input = Tensor::from_vec(feature_vec, (1, FEATURE_COUNT, 1), &self.device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        
        let output = self.forward(&input)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        
        let output = output.squeeze(0)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        
        let values: Vec<f32> = output.to_vec1()
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        
        Ok(EmotionState::new(values[0], values[1], values[2]))
    }
    
    fn load(&mut self, path: &Path) -> Result<(), EdmError> {
        let weights = candle_core::safetensors::load(path, &self.device)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;

        // Look up a tensor by name, with a descriptive error when it is missing.
        let get = |name: &str| -> Result<Tensor, EdmError> {
            weights
                .get(name)
                .cloned()
                .ok_or_else(|| EdmError::ModelError(format!("{name} not found")))
        };

        self.conv1_weight = get("conv1_weight")?;
        self.conv1_bias = get("conv1_bias")?;
        self.se_fc1_weight = get("se_fc1_weight")?;
        self.se_fc1_bias = get("se_fc1_bias")?;
        self.se_fc2_weight = get("se_fc2_weight")?;
        self.se_fc2_bias = get("se_fc2_bias")?;
        self.lstm_weight_ih = get("lstm_weight_ih")?;
        self.lstm_weight_hh = get("lstm_weight_hh")?;
        self.lstm_bias = get("lstm_bias")?;
        self.attn_query_weight = get("attn_query_weight")?;
        self.attn_key_weight = get("attn_key_weight")?;
        self.attn_value_weight = get("attn_value_weight")?;
        self.fc1_weight = get("fc1_weight")?;
        self.fc1_bias = get("fc1_bias")?;
        self.fc2_weight = get("fc2_weight")?;
        self.fc2_bias = get("fc2_bias")?;

        Ok(())
    }
}

impl EmotionDataModelTrainer for RogueliteEdm {
    fn train(&mut self, _dataset: &TrainingDataset) -> Result<TrainingResult, EdmError> {
        todo!("Training implementation")
    }
    
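    /// Serializes every parameter tensor into one safetensors file, keyed by
    /// the same names that `load` looks up.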
    fn save(&self, path: &Path) -> Result<(), EdmError> {
        let weights = HashMap::from([
            ("conv1_weight", self.conv1_weight.clone()),
            ("conv1_bias", self.conv1_bias.clone()),
            ("se_fc1_weight", self.se_fc1_weight.clone()),
            ("se_fc1_bias", self.se_fc1_bias.clone()),
            ("se_fc2_weight", self.se_fc2_weight.clone()),
            ("se_fc2_bias", self.se_fc2_bias.clone()),
            ("lstm_weight_ih", self.lstm_weight_ih.clone()),
            ("lstm_weight_hh", self.lstm_weight_hh.clone()),
            ("lstm_bias", self.lstm_bias.clone()),
            ("attn_query_weight", self.attn_query_weight.clone()),
            ("attn_key_weight", self.attn_key_weight.clone()),
            ("attn_value_weight", self.attn_value_weight.clone()),
            ("fc1_weight", self.fc1_weight.clone()),
            ("fc1_bias", self.fc1_bias.clone()),
            ("fc2_weight", self.fc2_weight.clone()),
            ("fc2_bias", self.fc2_bias.clone()),
        ]);
        
        candle_core::safetensors::save(&weights, path)
            .map_err(|e| EdmError::ModelError(e.to_string()))?;
        
        Ok(())
    }
    
    fn load(&mut self, path: &Path) -> Result<(), EdmError> {
        EmotionDataModel::load(self, path)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    
    #[test]
    fn test_edm_new() {
        let device = Device::Cpu;
        let model = RogueliteEdm::new(device);
        assert!(model.is_ok());
    }
    
    #[test]
    fn test_edm_infer() {
        let device = Device::Cpu;
        let model = RogueliteEdm::new(device).unwrap();

        // Fill every feature slot with a neutral value rather than relying on
        // a hard-coded count that could drift from FEATURE_COUNT.
        let mut features = HashMap::new();
        for i in 0..FEATURE_COUNT as u32 {
            features.insert(i, 0.5);
        }

        let result = model.infer(&features);
        if let Err(ref e) = result {
            eprintln!("Error: {:?}", e);
        }
        assert!(result.is_ok());

        // All three outputs pass through a sigmoid, so each lies in [0, 1].
        let state = result.unwrap();
        assert!(state.valence >= 0.0 && state.valence <= 1.0);
        assert!(state.arousal >= 0.0 && state.arousal <= 1.0);
        assert!(state.dominance >= 0.0 && state.dominance <= 1.0);
    }
    
    #[test]
    fn test_edm_save_load() {
        let device = Device::Cpu;
        let model = RogueliteEdm::new(device.clone()).unwrap();

        // Round-trip through the OS temp directory so the test leaves no
        // artifacts in the working tree.
        let model_path = std::env::temp_dir().join("test_edm_model.safetensors");

        let save_result = model.save(&model_path);
        assert!(save_result.is_ok());

        let mut model2 = RogueliteEdm::new(device).unwrap();
        let load_result = EmotionDataModel::load(&mut model2, &model_path);
        assert!(load_result.is_ok());

        let _ = std::fs::remove_file(&model_path);
    }
}