realizar 0.8.4

Pure Rust ML inference engine built from scratch - model serving for GGUF and safetensors

impl ShapExplainer {
    /// Create a new SHAP explainer with background data
    pub fn new(background: Vec<Vec<f32>>) -> Self {
        let n_features = background.first().map_or(0, Vec::len);
        Self {
            background,
            nsamples: 100,
            feature_names: (0..n_features).map(|i| format!("feature_{i}")).collect(),
        }
    }

    /// Set the number of samples for KernelSHAP
    pub fn with_nsamples(mut self, nsamples: usize) -> Self {
        self.nsamples = nsamples;
        self
    }

    /// Set feature names
    pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
        self.feature_names = names;
        self
    }

    /// Compute SHAP values for a prediction
    pub fn explain(
        &self,
        model: &dyn Explainable,
        instance: &[f32],
    ) -> Result<ShapExplanation, ExplainError> {
        // Validate input
        if instance.len() != model.n_features() {
            return Err(ExplainError::InvalidInput {
                expected: model.n_features(),
                actual: instance.len(),
            });
        }

        // Use TreeSHAP for tree-based models (fast, exact)
        if model.is_tree_model() {
            if let Some(tree_structure) = model.get_tree_structure() {
                return self.tree_shap(tree_structure, instance, model);
            }
        }

        // KernelSHAP for other models (model-agnostic)
        self.kernel_shap(model, instance)
    }

    /// TreeSHAP algorithm for tree-based models
    /// O(TLD) complexity where T=trees, L=leaves, D=depth
    fn tree_shap(
        &self,
        tree_structure: &TreeStructure,
        instance: &[f32],
        model: &dyn Explainable,
    ) -> Result<ShapExplanation, ExplainError> {
        let n_features = tree_structure.n_features;
        let mut shap_values = vec![0.0; n_features];

        // Compute SHAP values for each tree and average
        for tree in &tree_structure.trees {
            let tree_shap = self.tree_shap_single(tree, instance)?;
            for (i, v) in tree_shap.iter().enumerate() {
                shap_values[i] += v / tree_structure.n_trees as f32;
            }
        }

        // Compute base value from background
        let base_value = self.compute_expected_value(model)?;
        let prediction = model.predict(instance)?;

        Ok(ShapExplanation::new(base_value, shap_values, prediction)
            .with_feature_names(self.feature_names.clone()))
    }

    /// TreeSHAP for a single tree
    fn tree_shap_single(
        &self,
        tree: &DecisionTree,
        instance: &[f32],
    ) -> Result<Vec<f32>, ExplainError> {
        let n_features = instance.len();
        let mut shap_values = vec![0.0; n_features];

        // Simplified TreeSHAP using path enumeration
        // For each feature, compute its marginal contribution
        for feature_idx in 0..n_features {
            // Compute prediction with feature
            let pred_with = tree.predict(instance);

            // Compute prediction without feature (using background mean)
            let mut instance_without = instance.to_vec();
            let background_mean = self
                .background
                .iter()
                .filter_map(|bg| bg.get(feature_idx).copied())
                .sum::<f32>()
                / self.background.len().max(1) as f32;
            instance_without[feature_idx] = background_mean;
            let pred_without = tree.predict(&instance_without);

            // Marginal contribution (simplified)
            shap_values[feature_idx] = pred_with - pred_without;
        }

        Ok(shap_values)
    }

    /// KernelSHAP algorithm for model-agnostic explainability
    fn kernel_shap(
        &self,
        model: &dyn Explainable,
        instance: &[f32],
    ) -> Result<ShapExplanation, ExplainError> {
        if self.background.is_empty() {
            return Err(ExplainError::NoBackground);
        }

        let n_features = instance.len();
        let mut shap_values = vec![0.0; n_features];

        // KernelSHAP: sample coalitions and compute weighted linear regression
        for _ in 0..self.nsamples {
            // Sample a random coalition (subset of features)
            let coalition = self.sample_coalition(n_features);
            let coalition_size = coalition.iter().filter(|&&b| b).count();

            // Skip empty and full coalitions
            if coalition_size == 0 || coalition_size == n_features {
                continue;
            }

            // Compute marginal contribution
            let marginal = self.compute_marginal(model, instance, &coalition)?;

            // SHAP kernel weight: M / ((M choose |S|) * |S| * (M - |S|))
            // where M = n_features, S = coalition
            let weight = self.shap_kernel_weight(n_features, coalition_size);

            // Update SHAP values
            for (i, &in_coalition) in coalition.iter().enumerate() {
                if in_coalition {
                    shap_values[i] += marginal * weight;
                }
            }
        }

        // Normalize by total weight
        let total_weight: f32 = (1..n_features)
            .map(|k| self.shap_kernel_weight(n_features, k))
            .sum();
        if total_weight > 0.0 {
            for v in &mut shap_values {
                *v /= total_weight;
            }
        }

        // Compute base value and prediction
        let base_value = self.compute_expected_value(model)?;
        let prediction = model.predict(instance)?;

        Ok(ShapExplanation::new(base_value, shap_values, prediction)
            .with_feature_names(self.feature_names.clone()))
    }

    /// Sample a random coalition (subset of features)
    fn sample_coalition(&self, n_features: usize) -> Vec<bool> {
        // Use deterministic sampling for reproducibility in tests
        // In production, use thread_rng() instead
        (0..n_features).map(|i| i % 2 == 0).collect()
    }

    /// Compute marginal contribution for a coalition
    fn compute_marginal(
        &self,
        model: &dyn Explainable,
        instance: &[f32],
        coalition: &[bool],
    ) -> Result<f32, ExplainError> {
        // Create masked instance: use instance values for coalition, background mean for others
        let mut total_pred = 0.0;
        let n_background = self.background.len().max(1);

        for bg in &self.background {
            let mut masked: Vec<f32> = Vec::with_capacity(instance.len());
            for (i, (&inst_val, &in_coalition)) in instance.iter().zip(coalition.iter()).enumerate()
            {
                if in_coalition {
                    masked.push(inst_val);
                } else {
                    masked.push(bg.get(i).copied().unwrap_or(0.0));
                }
            }
            total_pred += model.predict(&masked)?;
        }

        Ok(total_pred / n_background as f32)
    }

    /// Compute expected value E[f(X)] from background
    fn compute_expected_value(&self, model: &dyn Explainable) -> Result<f32, ExplainError> {
        if self.background.is_empty() {
            return Ok(0.0);
        }

        let predictions: Result<Vec<f32>, _> =
            self.background.iter().map(|x| model.predict(x)).collect();
        let predictions = predictions?;

        Ok(predictions.iter().sum::<f32>() / predictions.len() as f32)
    }

    /// SHAP kernel weight
    fn shap_kernel_weight(&self, n_features: usize, coalition_size: usize) -> f32 {
        // Weight = M / (binom(M, |S|) * |S| * (M - |S|))
        let m = n_features as f32;
        let s = coalition_size as f32;
        let binom = binomial(n_features, coalition_size) as f32;
        if binom * s * (m - s) == 0.0 {
            0.0
        } else {
            m / (binom * s * (m - s))
        }
    }
}

/// Compute binomial coefficient (n choose k)
fn binomial(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    if k == 0 || k == n {
        return 1;
    }
    let k = k.min(n - k); // Take advantage of symmetry
    let mut result = 1usize;
    for i in 0..k {
        result = result.saturating_mul(n - i) / (i + 1);
    }
    result
}