use async_trait::async_trait;
use mecha10_behavior_runtime::{BehaviorNode, BoxedBehavior, NodeStatus};
use mecha10_core::Context;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};
/// A child behavior paired with its voting weight and an optional display name.
#[derive(Debug)]
pub struct WeightedModel {
    /// The wrapped behavior that the ensemble ticks.
    pub behavior: BoxedBehavior,
    /// Voting weight, consulted by the weighted-vote strategy.
    pub weight: f32,
    /// Optional label shown in the ensemble's log messages.
    pub name: Option<String>,
}

impl WeightedModel {
    /// Wraps `behavior` with the given `weight` and no name.
    pub fn new(behavior: BoxedBehavior, weight: f32) -> Self {
        let name = None;
        Self {
            behavior,
            weight,
            name,
        }
    }

    /// Returns this model with `name` attached as its label.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }
}
/// Rule used by `EnsembleNode` to merge the per-model statuses into one overall status.
///
/// Serialized in snake_case (e.g. `"weighted_vote"`). Variant order is kept as-is because
/// non-self-describing serde formats encode variants by index.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EnsembleStrategy {
    /// Waits until every model has finished, then compares the weighted fraction of
    /// successful models against the node's threshold.
    WeightedVote,
    /// Fails as soon as any model fails; succeeds only once every model has succeeded.
    Conservative,
    /// Succeeds as soon as any model succeeds; fails only once every model has failed.
    Optimistic,
    /// Counts each model's outcome once (weights are ignored) and decides when enough
    /// models agree.
    Majority,
}
/// Behavior node that runs several child behaviors side by side and merges their
/// outcomes according to an [`EnsembleStrategy`].
#[derive(Debug)]
pub struct EnsembleNode {
    /// Child behaviors together with their voting weights.
    models: Vec<WeightedModel>,
    /// How the per-model statuses are merged into one result.
    strategy: EnsembleStrategy,
    /// Minimum weighted success fraction required by `EnsembleStrategy::WeightedVote`.
    threshold: f32,
    /// Last observed status of each model; index-aligned with `models`.
    model_statuses: Vec<NodeStatus>,
}

impl EnsembleNode {
    /// Threshold used until [`EnsembleNode::with_threshold`] overrides it.
    const DEFAULT_THRESHOLD: f32 = 0.5;

    /// Creates an empty ensemble that merges results with `strategy`.
    pub fn new(strategy: EnsembleStrategy) -> Self {
        Self {
            strategy,
            threshold: Self::DEFAULT_THRESHOLD,
            models: Vec::new(),
            model_statuses: Vec::new(),
        }
    }

    /// Appends an unnamed model with the given voting `weight`.
    pub fn add_model(self, behavior: BoxedBehavior, weight: f32) -> Self {
        self.push_model(WeightedModel::new(behavior, weight))
    }

    /// Appends a model labelled `name` (shown in log output) with the given voting `weight`.
    pub fn add_named_model(
        self,
        name: impl Into<String>,
        behavior: BoxedBehavior,
        weight: f32,
    ) -> Self {
        self.push_model(WeightedModel::new(behavior, weight).with_name(name))
    }

    /// Overrides the success fraction used by [`EnsembleStrategy::WeightedVote`].
    ///
    /// The value is stored as given; it is not clamped or validated.
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self { threshold, ..self }
    }

    /// Number of models registered in the ensemble.
    pub fn model_count(&self) -> usize {
        self.models.len()
    }

    /// Registers `model` and seeds its status as `Running` so the next tick polls it.
    fn push_model(mut self, model: WeightedModel) -> Self {
        self.models.push(model);
        self.model_statuses.push(NodeStatus::Running);
        self
    }
}
#[async_trait]
impl BehaviorNode for EnsembleNode {
async fn tick(&mut self, ctx: &Context) -> anyhow::Result<NodeStatus> {
if self.models.is_empty() {
return Ok(NodeStatus::Failure);
}
for (i, model) in self.models.iter_mut().enumerate() {
if self.model_statuses[i].is_running() {
let status = model.behavior.tick(ctx).await?;
self.model_statuses[i] = status;
let model_name = model.name.as_deref().unwrap_or("unnamed");
debug!("Ensemble: model '{}' status: {}", model_name, status);
}
}
let overall_status = match self.strategy {
EnsembleStrategy::Conservative => self.conservative_combine(),
EnsembleStrategy::Optimistic => self.optimistic_combine(),
EnsembleStrategy::Majority => self.majority_combine(),
EnsembleStrategy::WeightedVote => self.weighted_vote_combine(),
};
info!("Ensemble ({:?}): overall status = {}", self.strategy, overall_status);
Ok(overall_status)
}
async fn reset(&mut self) -> anyhow::Result<()> {
for (i, model) in self.models.iter_mut().enumerate() {
model.behavior.reset().await?;
self.model_statuses[i] = NodeStatus::Running;
}
Ok(())
}
async fn on_init(&mut self, ctx: &Context) -> anyhow::Result<()> {
for model in &mut self.models {
model.behavior.on_init(ctx).await?;
}
Ok(())
}
async fn on_terminate(&mut self, ctx: &Context) -> anyhow::Result<()> {
for model in &mut self.models {
model.behavior.on_terminate(ctx).await?;
}
Ok(())
}
fn name(&self) -> &str {
"ensemble"
}
}
impl EnsembleNode {
    /// Counts the recorded model statuses in one pass as `(successes, failures, running)`.
    fn tally(&self) -> (usize, usize, usize) {
        self.model_statuses
            .iter()
            .fold((0, 0, 0), |(successes, failures, running), status| match status {
                NodeStatus::Success => (successes + 1, failures, running),
                NodeStatus::Failure => (successes, failures + 1, running),
                NodeStatus::Running => (successes, failures, running + 1),
            })
    }

    /// Fails as soon as any model has failed. Otherwise it keeps running until every model
    /// has finished, then succeeds.
    fn conservative_combine(&self) -> NodeStatus {
        let (_, failures, running) = self.tally();
        if failures > 0 {
            NodeStatus::Failure
        } else if running > 0 {
            NodeStatus::Running
        } else {
            NodeStatus::Success
        }
    }

    /// Succeeds as soon as any model has succeeded. Otherwise it keeps running until every
    /// model has finished, then fails.
    fn optimistic_combine(&self) -> NodeStatus {
        let (successes, _, running) = self.tally();
        if successes > 0 {
            NodeStatus::Success
        } else if running > 0 {
            NodeStatus::Running
        } else {
            NodeStatus::Failure
        }
    }

    /// Decides by strict majority of model count; weights are ignored.
    ///
    /// - Succeeds once more than half of all models have succeeded.
    /// - Fails once more than half have failed.
    /// - Keeps running while neither side has a majority and some model is still running.
    /// - Otherwise fails, e.g. an even split after every model has finished.
    fn majority_combine(&self) -> NodeStatus {
        let (successes, failures, running) = self.tally();
        // "More than half": for an even count, exactly half is a tie, not a majority.
        // With `n / 2 + 1`, success and failure can never both qualify, so the outcome
        // does not depend on which one is checked first.
        let majority = self.models.len() / 2 + 1;
        if successes >= majority {
            NodeStatus::Success
        } else if failures >= majority {
            NodeStatus::Failure
        } else if running > 0 {
            NodeStatus::Running
        } else {
            NodeStatus::Failure
        }
    }

    /// Waits until every model has finished. It then succeeds if the weight of the
    /// successful models divided by the total weight is at least `threshold`.
    ///
    /// A non-positive total weight yields a success fraction of `0.0`. Weights and
    /// `threshold` are used as given, without validation.
    fn weighted_vote_combine(&self) -> NodeStatus {
        let mut success_weight = 0.0_f32;
        let mut total_weight = 0.0_f32;
        for (model, status) in self.models.iter().zip(&self.model_statuses) {
            match status {
                // Any unfinished model means the vote is not complete yet.
                NodeStatus::Running => return NodeStatus::Running,
                NodeStatus::Success => success_weight += model.weight,
                NodeStatus::Failure => {}
            }
            total_weight += model.weight;
        }

        let success_ratio = if total_weight > 0.0 {
            success_weight / total_weight
        } else {
            0.0
        };
        if success_ratio >= self.threshold {
            NodeStatus::Success
        } else {
            NodeStatus::Failure
        }
    }
}
/// Serializable description of a single ensemble member.
///
/// NOTE(review): nothing in this file constructs or consumes this type. Presumably it is used
/// when an `EnsembleNode` is built from configuration elsewhere — confirm.
// `dead_code` is allowed presumably because no code reads these fields yet — TODO confirm
// and remove the allow once the type is wired up.
#[allow(dead_code)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnsembleModelConfig {
    /// Identifies the behavior node for this member (assumed to be a node type or registry
    /// name — confirm against the loader).
    pub node: String,
    /// Voting weight, presumably copied into `WeightedModel::weight`. It has no serde default,
    /// so it must be present when deserializing.
    pub weight: f32,
    /// Optional display name; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Node-specific configuration; falls back to `Value::default()` (JSON `null`) when absent.
    #[serde(default)]
    pub config: serde_json::Value,
}