mecha10-behavior-patterns 0.1.23

Common behavior patterns for Mecha10 - subsumption, ensemble, and more
Documentation
//! Ensemble pattern for multi-model fusion
//!
//! This module implements ensemble learning patterns where multiple AI models
//! can be combined to make more robust decisions.

use async_trait::async_trait;
use mecha10_behavior_runtime::{BehaviorNode, BoxedBehavior, NodeStatus};
use mecha10_core::Context;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};

/// One member of an [`EnsembleNode`]: a behavior paired with its voting weight.
#[derive(Debug)]
pub struct WeightedModel {
    /// Behavior (model) that is ticked as part of the ensemble.
    pub behavior: BoxedBehavior,
    /// Voting weight; only consulted by [`EnsembleStrategy::WeightedVote`].
    pub weight: f32,
    /// Optional label; within this module it is only used in debug logging.
    pub name: Option<String>,
}

impl WeightedModel {
    /// Wraps `behavior` with the given `weight` and no name.
    pub fn new(behavior: BoxedBehavior, weight: f32) -> Self {
        let name = None;
        Self { behavior, weight, name }
    }

    /// Builder-style setter that attaches a name to this model.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }
}

/// Strategy for combining results from multiple models.
///
/// Serialized in `snake_case` (e.g. `"weighted_vote"`, `"majority"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EnsembleStrategy {
    /// Weighted voting: once every model has finished, succeed if the
    /// weight-normalized share of successes (`success_weight / total_weight`)
    /// is at least the ensemble's threshold (default 0.5).
    WeightedVote,

    /// Conservative: all models must succeed; fails as soon as any model fails.
    Conservative,

    /// Optimistic: any model can succeed; fails only once every model has failed.
    Optimistic,

    /// Majority: more than half of the models must succeed.
    Majority,
}

/// Ensemble node for multi-model fusion.
///
/// On each tick this node ticks every model whose last status is `Running`,
/// one after another in the order they were added (they are awaited
/// sequentially, not concurrently). It then combines the per-model statuses
/// using a configurable [`EnsembleStrategy`].
///
/// A model that has returned `Success` or `Failure` keeps that status and is
/// not ticked again until [`BehaviorNode::reset`] is called. An error from any
/// model's tick is returned immediately, and later models are not ticked that
/// round. An ensemble with no models fails.
///
/// This is useful for:
/// - Combining multiple AI models for robust predictions
/// - Redundancy and fault tolerance
/// - Multi-modal sensor fusion
///
/// # Strategies
///
/// - **WeightedVote**: Once every model has finished, succeeds if the weighted
///   share of successes (`success_weight / total_weight`) is at least the
///   threshold (default 0.5, see [`EnsembleNode::with_threshold`])
/// - **Conservative**: All models must succeed (high precision, low recall)
/// - **Optimistic**: Any model can succeed (high recall, low precision)
/// - **Majority**: More than half must succeed (balanced)
///
/// # Example
///
/// ```rust
/// use mecha10_behavior_patterns::prelude::*;
///
/// # #[derive(Debug)]
/// # struct YoloV8;
/// # #[async_trait]
/// # impl BehaviorNode for YoloV8 {
/// #     async fn tick(&mut self, _ctx: &Context) -> anyhow::Result<NodeStatus> {
/// #         Ok(NodeStatus::Success)
/// #     }
/// # }
/// # #[derive(Debug)]
/// # struct YoloV10;
/// # #[async_trait]
/// # impl BehaviorNode for YoloV10 {
/// #     async fn tick(&mut self, _ctx: &Context) -> anyhow::Result<NodeStatus> {
/// #         Ok(NodeStatus::Success)
/// #     }
/// # }
/// # #[derive(Debug)]
/// # struct CustomDetector;
/// # #[async_trait]
/// # impl BehaviorNode for CustomDetector {
/// #     async fn tick(&mut self, _ctx: &Context) -> anyhow::Result<NodeStatus> {
/// #         Ok(NodeStatus::Success)
/// #     }
/// # }
///
/// let ensemble = EnsembleNode::new(EnsembleStrategy::WeightedVote)
///     .add_model(Box::new(YoloV8), 0.4)
///     .add_model(Box::new(YoloV10), 0.3)
///     .add_model(Box::new(CustomDetector), 0.3);
/// ```
#[derive(Debug)]
pub struct EnsembleNode {
    /// Member models, in insertion order.
    models: Vec<WeightedModel>,
    /// How per-model statuses are combined into the node's status.
    strategy: EnsembleStrategy,
    /// Minimum weighted success share for `WeightedVote` (default 0.5).
    threshold: f32,
    /// Last known status of each model; index `i` belongs to `models[i]`.
    model_statuses: Vec<NodeStatus>,
}

impl EnsembleNode {
    /// Create a new ensemble node with the given strategy.
    pub fn new(strategy: EnsembleStrategy) -> Self {
        Self {
            models: Vec::new(),
            strategy,
            threshold: 0.5,
            model_statuses: Vec::new(),
        }
    }

    /// Add a model with a weight.
    pub fn add_model(mut self, behavior: BoxedBehavior, weight: f32) -> Self {
        self.models.push(WeightedModel::new(behavior, weight));
        self.model_statuses.push(NodeStatus::Running);
        self
    }

    /// Add a named model.
    pub fn add_named_model(mut self, name: impl Into<String>, behavior: BoxedBehavior, weight: f32) -> Self {
        self.models.push(WeightedModel::new(behavior, weight).with_name(name));
        self.model_statuses.push(NodeStatus::Running);
        self
    }

    /// Set the threshold for weighted voting (default 0.5).
    pub fn with_threshold(mut self, threshold: f32) -> Self {
        self.threshold = threshold;
        self
    }

    /// Get the number of models in the ensemble.
    pub fn model_count(&self) -> usize {
        self.models.len()
    }
}

#[async_trait]
impl BehaviorNode for EnsembleNode {
    /// Ticks each still-running model once, then folds all model statuses with
    /// the configured strategy.
    ///
    /// Models are awaited one after another, not concurrently. An empty
    /// ensemble fails immediately. A model error is returned right away,
    /// and later models are not ticked this round.
    async fn tick(&mut self, ctx: &Context) -> anyhow::Result<NodeStatus> {
        if self.models.is_empty() {
            return Ok(NodeStatus::Failure);
        }

        // Only models that have not finished yet are ticked; finished ones
        // keep their last status until `reset`.
        for (model, slot) in self.models.iter_mut().zip(self.model_statuses.iter_mut()) {
            if !slot.is_running() {
                continue;
            }
            let status = model.behavior.tick(ctx).await?;
            *slot = status;
            debug!(
                "Ensemble: model '{}' status: {}",
                model.name.as_deref().unwrap_or("unnamed"),
                status
            );
        }

        let overall_status = match self.strategy {
            EnsembleStrategy::WeightedVote => self.weighted_vote_combine(),
            EnsembleStrategy::Majority => self.majority_combine(),
            EnsembleStrategy::Optimistic => self.optimistic_combine(),
            EnsembleStrategy::Conservative => self.conservative_combine(),
        };

        info!("Ensemble ({:?}): overall status = {}", self.strategy, overall_status);
        Ok(overall_status)
    }

    /// Resets every model in order and marks each one as `Running` again.
    /// It stops at the first model whose reset fails.
    async fn reset(&mut self) -> anyhow::Result<()> {
        for (model, slot) in self.models.iter_mut().zip(self.model_statuses.iter_mut()) {
            model.behavior.reset().await?;
            *slot = NodeStatus::Running;
        }
        Ok(())
    }

    /// Forwards `on_init` to every model in order; it stops at the first error.
    async fn on_init(&mut self, ctx: &Context) -> anyhow::Result<()> {
        for model in self.models.iter_mut() {
            model.behavior.on_init(ctx).await?;
        }
        Ok(())
    }

    /// Forwards `on_terminate` to every model in order; it stops at the first error.
    async fn on_terminate(&mut self, ctx: &Context) -> anyhow::Result<()> {
        for model in self.models.iter_mut() {
            model.behavior.on_terminate(ctx).await?;
        }
        Ok(())
    }

    fn name(&self) -> &str {
        "ensemble"
    }
}

impl EnsembleNode {
    /// All models must succeed: fails as soon as any model fails, succeeds
    /// only once every model has succeeded, otherwise keeps running.
    fn conservative_combine(&self) -> NodeStatus {
        let failures = self.model_statuses.iter().filter(|s| s.is_failure()).count();
        let running = self.model_statuses.iter().filter(|s| s.is_running()).count();

        if failures > 0 {
            NodeStatus::Failure
        } else if running > 0 {
            NodeStatus::Running
        } else {
            NodeStatus::Success
        }
    }

    /// Any model may succeed: succeeds as soon as one model succeeds, fails
    /// only once every model has failed, otherwise keeps running.
    fn optimistic_combine(&self) -> NodeStatus {
        let successes = self.model_statuses.iter().filter(|s| s.is_success()).count();
        let running = self.model_statuses.iter().filter(|s| s.is_running()).count();

        if successes > 0 {
            NodeStatus::Success
        } else if running > 0 {
            NodeStatus::Running
        } else {
            NodeStatus::Failure
        }
    }

    /// Strictly more than half of the models must succeed.
    ///
    /// Succeeds as soon as a strict majority has succeeded. Fails as soon as a
    /// strict majority is no longer reachable, even if every still-running
    /// model were to succeed; an even split therefore fails. Otherwise it keeps
    /// running.
    fn majority_combine(&self) -> NodeStatus {
        let successes = self.model_statuses.iter().filter(|s| s.is_success()).count();
        let running = self.model_statuses.iter().filter(|s| s.is_running()).count();

        // Smallest count that is strictly more than half: 1 of 1, 2 of 3, 3 of 4.
        let needed = self.model_statuses.len() / 2 + 1;

        if successes >= needed {
            NodeStatus::Success
        } else if successes + running < needed {
            NodeStatus::Failure
        } else {
            NodeStatus::Running
        }
    }

    /// Weighted vote: waits until no model is running, then succeeds if
    /// `success_weight / total_weight >= threshold`.
    ///
    /// A non-positive total weight is treated as a success share of 0.0.
    fn weighted_vote_combine(&self) -> NodeStatus {
        let mut success_weight = 0.0_f32;
        let mut total_weight = 0.0_f32;
        let mut any_running = false;

        for (model, status) in self.models.iter().zip(&self.model_statuses) {
            total_weight += model.weight;
            match status {
                NodeStatus::Success => success_weight += model.weight,
                NodeStatus::Failure => {}
                NodeStatus::Running => any_running = true,
            }
        }

        if any_running {
            return NodeStatus::Running;
        }

        let success_ratio = if total_weight > 0.0 {
            success_weight / total_weight
        } else {
            0.0
        };

        if success_ratio >= self.threshold {
            NodeStatus::Success
        } else {
            NodeStatus::Failure
        }
    }
}

/// Serializable description of a single ensemble member, intended for
/// building an ensemble from JSON configuration.
#[allow(dead_code)] // Will be used when implementing JSON loading
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnsembleModelConfig {
    /// Name of the node type to instantiate for this member.
    pub node: String,
    /// Voting weight assigned to the instantiated model.
    pub weight: f32,
    /// Optional label; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Node-specific settings; falls back to `serde_json::Value::default()` when absent.
    #[serde(default)]
    pub config: serde_json::Value,
}