mecha10_behavior_patterns/
ensemble.rs

1//! Ensemble pattern for multi-model fusion
2//!
3//! This module implements ensemble learning patterns where multiple AI models
4//! can be combined to make more robust decisions.
5
6use async_trait::async_trait;
7use mecha10_behavior_runtime::{BehaviorNode, BoxedBehavior, NodeStatus};
8use mecha10_core::Context;
9use serde::{Deserialize, Serialize};
10use tracing::{debug, info};
11
12/// A weighted model in an ensemble.
13#[derive(Debug)]
14pub struct WeightedModel {
15    /// The behavior/model
16    pub behavior: BoxedBehavior,
17    /// Weight for voting (only used in WeightedVote strategy)
18    pub weight: f32,
19    /// Optional name for debugging
20    pub name: Option<String>,
21}
22
23impl WeightedModel {
24    /// Create a new weighted model.
25    pub fn new(behavior: BoxedBehavior, weight: f32) -> Self {
26        Self {
27            behavior,
28            weight,
29            name: None,
30        }
31    }
32
33    /// Set the name of this model.
34    pub fn with_name(mut self, name: impl Into<String>) -> Self {
35        self.name = Some(name.into());
36        self
37    }
38}
39
40/// Strategy for combining results from multiple models.
41#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub enum EnsembleStrategy {
44    /// Weighted voting: success if weighted sum of successes > threshold (default 0.5)
45    WeightedVote,
46
47    /// Conservative: all models must succeed
48    Conservative,
49
50    /// Optimistic: any model can succeed
51    Optimistic,
52
53    /// Majority: more than half must succeed
54    Majority,
55}
56
57/// Ensemble node for multi-model fusion.
58///
59/// This node runs multiple behaviors/models in parallel and combines their results
60/// using a configurable strategy. This is useful for:
61/// - Combining multiple AI models for robust predictions
62/// - Redundancy and fault tolerance
63/// - Multi-modal sensor fusion
64///
65/// # Strategies
66///
67/// - **WeightedVote**: Models vote with weights, success if total > threshold
68/// - **Conservative**: All models must succeed (high precision, low recall)
69/// - **Optimistic**: Any model can succeed (high recall, low precision)
70/// - **Majority**: More than half must succeed (balanced)
71///
72/// # Example
73///
74/// ```rust
75/// use mecha10_behavior_patterns::prelude::*;
76///
77/// # #[derive(Debug)]
78/// # struct YoloV8;
79/// # #[async_trait]
80/// # impl BehaviorNode for YoloV8 {
81/// #     async fn tick(&mut self, _ctx: &Context) -> anyhow::Result<NodeStatus> {
82/// #         Ok(NodeStatus::Success)
83/// #     }
84/// # }
85/// # #[derive(Debug)]
86/// # struct YoloV10;
87/// # #[async_trait]
88/// # impl BehaviorNode for YoloV10 {
89/// #     async fn tick(&mut self, _ctx: &Context) -> anyhow::Result<NodeStatus> {
90/// #         Ok(NodeStatus::Success)
91/// #     }
92/// # }
93/// # #[derive(Debug)]
94/// # struct CustomDetector;
95/// # #[async_trait]
96/// # impl BehaviorNode for CustomDetector {
97/// #     async fn tick(&mut self, _ctx: &Context) -> anyhow::Result<NodeStatus> {
98/// #         Ok(NodeStatus::Success)
99/// #     }
100/// # }
101///
102/// let ensemble = EnsembleNode::new(EnsembleStrategy::WeightedVote)
103///     .add_model(Box::new(YoloV8), 0.4)
104///     .add_model(Box::new(YoloV10), 0.3)
105///     .add_model(Box::new(CustomDetector), 0.3);
106/// ```
107#[derive(Debug)]
108pub struct EnsembleNode {
109    models: Vec<WeightedModel>,
110    strategy: EnsembleStrategy,
111    threshold: f32,
112    model_statuses: Vec<NodeStatus>,
113}
114
115impl EnsembleNode {
116    /// Create a new ensemble node with the given strategy.
117    pub fn new(strategy: EnsembleStrategy) -> Self {
118        Self {
119            models: Vec::new(),
120            strategy,
121            threshold: 0.5,
122            model_statuses: Vec::new(),
123        }
124    }
125
126    /// Add a model with a weight.
127    pub fn add_model(mut self, behavior: BoxedBehavior, weight: f32) -> Self {
128        self.models.push(WeightedModel::new(behavior, weight));
129        self.model_statuses.push(NodeStatus::Running);
130        self
131    }
132
133    /// Add a named model.
134    pub fn add_named_model(mut self, name: impl Into<String>, behavior: BoxedBehavior, weight: f32) -> Self {
135        self.models.push(WeightedModel::new(behavior, weight).with_name(name));
136        self.model_statuses.push(NodeStatus::Running);
137        self
138    }
139
140    /// Set the threshold for weighted voting (default 0.5).
141    pub fn with_threshold(mut self, threshold: f32) -> Self {
142        self.threshold = threshold;
143        self
144    }
145
146    /// Get the number of models in the ensemble.
147    pub fn model_count(&self) -> usize {
148        self.models.len()
149    }
150}
151
152#[async_trait]
153impl BehaviorNode for EnsembleNode {
154    async fn tick(&mut self, ctx: &Context) -> anyhow::Result<NodeStatus> {
155        if self.models.is_empty() {
156            return Ok(NodeStatus::Failure);
157        }
158
159        // Execute all models in parallel
160        for (i, model) in self.models.iter_mut().enumerate() {
161            if self.model_statuses[i].is_running() {
162                let status = model.behavior.tick(ctx).await?;
163                self.model_statuses[i] = status;
164
165                let model_name = model.name.as_deref().unwrap_or("unnamed");
166                debug!("Ensemble: model '{}' status: {}", model_name, status);
167            }
168        }
169
170        // Combine results based on strategy
171        let overall_status = match self.strategy {
172            EnsembleStrategy::Conservative => self.conservative_combine(),
173            EnsembleStrategy::Optimistic => self.optimistic_combine(),
174            EnsembleStrategy::Majority => self.majority_combine(),
175            EnsembleStrategy::WeightedVote => self.weighted_vote_combine(),
176        };
177
178        info!("Ensemble ({:?}): overall status = {}", self.strategy, overall_status);
179        Ok(overall_status)
180    }
181
182    async fn reset(&mut self) -> anyhow::Result<()> {
183        for (i, model) in self.models.iter_mut().enumerate() {
184            model.behavior.reset().await?;
185            self.model_statuses[i] = NodeStatus::Running;
186        }
187        Ok(())
188    }
189
190    async fn on_init(&mut self, ctx: &Context) -> anyhow::Result<()> {
191        for model in &mut self.models {
192            model.behavior.on_init(ctx).await?;
193        }
194        Ok(())
195    }
196
197    async fn on_terminate(&mut self, ctx: &Context) -> anyhow::Result<()> {
198        for model in &mut self.models {
199            model.behavior.on_terminate(ctx).await?;
200        }
201        Ok(())
202    }
203
204    fn name(&self) -> &str {
205        "ensemble"
206    }
207}
208
209impl EnsembleNode {
210    fn conservative_combine(&self) -> NodeStatus {
211        // All must succeed
212        let failures = self.model_statuses.iter().filter(|s| s.is_failure()).count();
213        let running = self.model_statuses.iter().filter(|s| s.is_running()).count();
214
215        if failures > 0 {
216            NodeStatus::Failure
217        } else if running > 0 {
218            NodeStatus::Running
219        } else {
220            NodeStatus::Success
221        }
222    }
223
224    fn optimistic_combine(&self) -> NodeStatus {
225        // Any can succeed
226        let successes = self.model_statuses.iter().filter(|s| s.is_success()).count();
227        let running = self.model_statuses.iter().filter(|s| s.is_running()).count();
228
229        if successes > 0 {
230            NodeStatus::Success
231        } else if running > 0 {
232            NodeStatus::Running
233        } else {
234            NodeStatus::Failure
235        }
236    }
237
238    fn majority_combine(&self) -> NodeStatus {
239        // More than half must succeed
240        let successes = self.model_statuses.iter().filter(|s| s.is_success()).count();
241        let failures = self.model_statuses.iter().filter(|s| s.is_failure()).count();
242        let running = self.model_statuses.iter().filter(|s| s.is_running()).count();
243
244        let _total_decided = successes + failures;
245        let majority_threshold = (self.models.len() + 1) / 2;
246
247        if successes >= majority_threshold {
248            NodeStatus::Success
249        } else if failures >= majority_threshold {
250            NodeStatus::Failure
251        } else if running > 0 {
252            NodeStatus::Running
253        } else {
254            // Tie - default to failure
255            NodeStatus::Failure
256        }
257    }
258
259    fn weighted_vote_combine(&self) -> NodeStatus {
260        // Weighted sum of successes must exceed threshold
261        let mut success_weight = 0.0;
262        let mut _failure_weight = 0.0;
263        let mut total_weight = 0.0;
264        let mut any_running = false;
265
266        for (model, status) in self.models.iter().zip(&self.model_statuses) {
267            total_weight += model.weight;
268            match status {
269                NodeStatus::Success => success_weight += model.weight,
270                NodeStatus::Failure => _failure_weight += model.weight,
271                NodeStatus::Running => any_running = true,
272            }
273        }
274
275        if any_running {
276            return NodeStatus::Running;
277        }
278
279        let success_ratio = if total_weight > 0.0 {
280            success_weight / total_weight
281        } else {
282            0.0
283        };
284
285        if success_ratio >= self.threshold {
286            NodeStatus::Success
287        } else {
288            NodeStatus::Failure
289        }
290    }
291}
292
293/// Configuration for an ensemble model (for JSON deserialization).
294#[allow(dead_code)] // Will be used when implementing JSON loading
295#[derive(Debug, Clone, Serialize, Deserialize)]
296pub struct EnsembleModelConfig {
297    /// Node type to instantiate
298    pub node: String,
299    /// Weight for this model
300    pub weight: f32,
301    /// Optional name for this model
302    #[serde(skip_serializing_if = "Option::is_none")]
303    pub name: Option<String>,
304    /// Configuration for the node
305    #[serde(default)]
306    pub config: serde_json::Value,
307}