mecha10_behavior_patterns/
ensemble.rs1use async_trait::async_trait;
7use mecha10_behavior_runtime::{BehaviorNode, BoxedBehavior, NodeStatus};
8use mecha10_core::Context;
9use serde::{Deserialize, Serialize};
10use tracing::{debug, info};
11
12#[derive(Debug)]
14pub struct WeightedModel {
15 pub behavior: BoxedBehavior,
17 pub weight: f32,
19 pub name: Option<String>,
21}
22
23impl WeightedModel {
24 pub fn new(behavior: BoxedBehavior, weight: f32) -> Self {
26 Self {
27 behavior,
28 weight,
29 name: None,
30 }
31 }
32
33 pub fn with_name(mut self, name: impl Into<String>) -> Self {
35 self.name = Some(name.into());
36 self
37 }
38}
39
40#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub enum EnsembleStrategy {
44 WeightedVote,
46
47 Conservative,
49
50 Optimistic,
52
53 Majority,
55}
56
57#[derive(Debug)]
108pub struct EnsembleNode {
109 models: Vec<WeightedModel>,
110 strategy: EnsembleStrategy,
111 threshold: f32,
112 model_statuses: Vec<NodeStatus>,
113}
114
115impl EnsembleNode {
116 pub fn new(strategy: EnsembleStrategy) -> Self {
118 Self {
119 models: Vec::new(),
120 strategy,
121 threshold: 0.5,
122 model_statuses: Vec::new(),
123 }
124 }
125
126 pub fn add_model(mut self, behavior: BoxedBehavior, weight: f32) -> Self {
128 self.models.push(WeightedModel::new(behavior, weight));
129 self.model_statuses.push(NodeStatus::Running);
130 self
131 }
132
133 pub fn add_named_model(mut self, name: impl Into<String>, behavior: BoxedBehavior, weight: f32) -> Self {
135 self.models.push(WeightedModel::new(behavior, weight).with_name(name));
136 self.model_statuses.push(NodeStatus::Running);
137 self
138 }
139
140 pub fn with_threshold(mut self, threshold: f32) -> Self {
142 self.threshold = threshold;
143 self
144 }
145
146 pub fn model_count(&self) -> usize {
148 self.models.len()
149 }
150}
151
152#[async_trait]
153impl BehaviorNode for EnsembleNode {
154 async fn tick(&mut self, ctx: &Context) -> anyhow::Result<NodeStatus> {
155 if self.models.is_empty() {
156 return Ok(NodeStatus::Failure);
157 }
158
159 for (i, model) in self.models.iter_mut().enumerate() {
161 if self.model_statuses[i].is_running() {
162 let status = model.behavior.tick(ctx).await?;
163 self.model_statuses[i] = status;
164
165 let model_name = model.name.as_deref().unwrap_or("unnamed");
166 debug!("Ensemble: model '{}' status: {}", model_name, status);
167 }
168 }
169
170 let overall_status = match self.strategy {
172 EnsembleStrategy::Conservative => self.conservative_combine(),
173 EnsembleStrategy::Optimistic => self.optimistic_combine(),
174 EnsembleStrategy::Majority => self.majority_combine(),
175 EnsembleStrategy::WeightedVote => self.weighted_vote_combine(),
176 };
177
178 info!("Ensemble ({:?}): overall status = {}", self.strategy, overall_status);
179 Ok(overall_status)
180 }
181
182 async fn reset(&mut self) -> anyhow::Result<()> {
183 for (i, model) in self.models.iter_mut().enumerate() {
184 model.behavior.reset().await?;
185 self.model_statuses[i] = NodeStatus::Running;
186 }
187 Ok(())
188 }
189
190 async fn on_init(&mut self, ctx: &Context) -> anyhow::Result<()> {
191 for model in &mut self.models {
192 model.behavior.on_init(ctx).await?;
193 }
194 Ok(())
195 }
196
197 async fn on_terminate(&mut self, ctx: &Context) -> anyhow::Result<()> {
198 for model in &mut self.models {
199 model.behavior.on_terminate(ctx).await?;
200 }
201 Ok(())
202 }
203
204 fn name(&self) -> &str {
205 "ensemble"
206 }
207}
208
209impl EnsembleNode {
210 fn conservative_combine(&self) -> NodeStatus {
211 let failures = self.model_statuses.iter().filter(|s| s.is_failure()).count();
213 let running = self.model_statuses.iter().filter(|s| s.is_running()).count();
214
215 if failures > 0 {
216 NodeStatus::Failure
217 } else if running > 0 {
218 NodeStatus::Running
219 } else {
220 NodeStatus::Success
221 }
222 }
223
224 fn optimistic_combine(&self) -> NodeStatus {
225 let successes = self.model_statuses.iter().filter(|s| s.is_success()).count();
227 let running = self.model_statuses.iter().filter(|s| s.is_running()).count();
228
229 if successes > 0 {
230 NodeStatus::Success
231 } else if running > 0 {
232 NodeStatus::Running
233 } else {
234 NodeStatus::Failure
235 }
236 }
237
238 fn majority_combine(&self) -> NodeStatus {
239 let successes = self.model_statuses.iter().filter(|s| s.is_success()).count();
241 let failures = self.model_statuses.iter().filter(|s| s.is_failure()).count();
242 let running = self.model_statuses.iter().filter(|s| s.is_running()).count();
243
244 let _total_decided = successes + failures;
245 let majority_threshold = (self.models.len() + 1) / 2;
246
247 if successes >= majority_threshold {
248 NodeStatus::Success
249 } else if failures >= majority_threshold {
250 NodeStatus::Failure
251 } else if running > 0 {
252 NodeStatus::Running
253 } else {
254 NodeStatus::Failure
256 }
257 }
258
259 fn weighted_vote_combine(&self) -> NodeStatus {
260 let mut success_weight = 0.0;
262 let mut _failure_weight = 0.0;
263 let mut total_weight = 0.0;
264 let mut any_running = false;
265
266 for (model, status) in self.models.iter().zip(&self.model_statuses) {
267 total_weight += model.weight;
268 match status {
269 NodeStatus::Success => success_weight += model.weight,
270 NodeStatus::Failure => _failure_weight += model.weight,
271 NodeStatus::Running => any_running = true,
272 }
273 }
274
275 if any_running {
276 return NodeStatus::Running;
277 }
278
279 let success_ratio = if total_weight > 0.0 {
280 success_weight / total_weight
281 } else {
282 0.0
283 };
284
285 if success_ratio >= self.threshold {
286 NodeStatus::Success
287 } else {
288 NodeStatus::Failure
289 }
290 }
291}
292
293#[allow(dead_code)] #[derive(Debug, Clone, Serialize, Deserialize)]
296pub struct EnsembleModelConfig {
297 pub node: String,
299 pub weight: f32,
301 #[serde(skip_serializing_if = "Option::is_none")]
303 pub name: Option<String>,
304 #[serde(default)]
306 pub config: serde_json::Value,
307}