use crate::embedding::cosine_similarity;
use crate::types::{Outcome, ProceduralMemory};
/// Records one more trial and returns the updated reliability estimate.
///
/// Reliability is the posterior mean of a Beta distribution with a uniform
/// `Beta(1, 1)` prior over the success probability, i.e.
/// `(successes + 1) / (trials + 2)`. Only [`Outcome::Success`] counts as a
/// success; `Failure` and `Partial` are both counted as non-successes.
///
/// # Arguments
/// * `current_trial_count` - trials recorded before this one.
/// * `current_success_count` - successes recorded before this one; expected to
///   be `<= current_trial_count`.
/// * `new_outcome` - outcome of the trial being recorded.
///
/// # Returns
/// `(reliability, new_trial_count, new_success_count)`.
///
/// Counts saturate at `u32::MAX` instead of overflowing. Inconsistent input
/// (more successes than trials) is tolerated by treating the failure count as
/// zero rather than underflowing.
pub fn update_reliability(
    current_trial_count: u32,
    current_success_count: u32,
    new_outcome: &Outcome,
) -> (f32, u32, u32) {
    let new_trial = current_trial_count.saturating_add(1);
    let new_success = match new_outcome {
        Outcome::Success => current_success_count.saturating_add(1),
        Outcome::Failure | Outcome::Partial => current_success_count,
    };
    // Saturating: plain `-` would panic in debug (and wrap to ~4e9 failures in
    // release) if a caller ever hands us more successes than trials.
    let new_failure = new_trial.saturating_sub(new_success);
    let alpha = new_success.saturating_add(1) as f32;
    let beta = new_failure.saturating_add(1) as f32;
    let reliability = alpha / (alpha + beta);
    (reliability, new_trial, new_success)
}
/// Decides whether a procedure has had a fair chance and still underperforms.
///
/// Returns `true` only when `trial_count` has reached `min_trials` *and*
/// `reliability` is strictly below `threshold`. Procedures with fewer than
/// `min_trials` trials are never pruned.
pub fn should_prune(reliability: f32, trial_count: u32, min_trials: u32, threshold: f32) -> bool {
    if trial_count < min_trials {
        // Not enough evidence yet to judge this procedure.
        return false;
    }
    reliability < threshold
}
/// Returns the index of the most reliable procedure that meets a threshold.
///
/// Only procedures with `reliability >= reliability_threshold` (inclusive) are
/// considered. Among those, the one with the highest reliability wins; on a tie
/// the later index is returned. Returns `None` when no procedure qualifies,
/// including for an empty slice. A `NaN` reliability never passes the `>=`
/// filter, so the comparison below only ever sees ordinary numbers.
pub fn select_best_procedure(
    procedures: &[ProceduralMemory],
    reliability_threshold: f32,
) -> Option<usize> {
    procedures
        .iter()
        .map(|procedure| procedure.reliability)
        .enumerate()
        .filter(|&(_, score)| score >= reliability_threshold)
        .fold(None, |best: Option<(usize, f32)>, (index, score)| match best {
            // A strictly better incumbent survives; ties go to the later candidate.
            Some((_, best_score)) if best_score > score => best,
            _ => Some((index, score)),
        })
        .map(|(index, _)| index)
}
/// Copies well-established procedures from a source set into a target set.
///
/// A source procedure is transferred when all of the following hold:
/// * its `reliability` is at least `min_reliability`,
/// * its `trial_count` is at least `min_trials`, and
/// * no procedure in `existing_procedures` already has the same `trigger` and
///   `action`.
///
/// Each copy is built with `ProceduralMemory::new` from the source's
/// `namespace_id`, `trigger`, `action`, `outcome` and `context`. It then gets
/// the source reliability multiplied by a fixed discount of `0.7`, with
/// `trial_count` and `success_count` reset to `0`. Output order follows
/// `source_procedures`.
///
/// Notes for callers:
/// * The copy keeps the *source* `namespace_id`; this function takes no target
///   namespace. NOTE(review): presumably the caller re-homes it — confirm.
/// * Duplicates are only checked against `existing_procedures`, not among the
///   source procedures themselves.
/// * [`update_reliability`] computes reliability from counts alone. If it is
///   later applied to the reset counts, the discounted value is discarded.
pub fn transfer_procedures(
    source_procedures: &[ProceduralMemory],
    existing_procedures: &[ProceduralMemory],
    min_reliability: f32,
    min_trials: u32,
) -> Vec<ProceduralMemory> {
    // Fraction of the source reliability a transferred procedure starts with.
    const TRANSFER_DISCOUNT: f32 = 0.7;

    let already_present = |candidate: &ProceduralMemory| {
        existing_procedures
            .iter()
            .any(|known| known.trigger == candidate.trigger && known.action == candidate.action)
    };

    let mut transferred = Vec::new();
    for candidate in source_procedures {
        let proven =
            candidate.reliability >= min_reliability && candidate.trial_count >= min_trials;
        if !proven || already_present(candidate) {
            continue;
        }

        let mut copy = ProceduralMemory::new(
            candidate.namespace_id,
            candidate.trigger.clone(),
            candidate.action.clone(),
            candidate.outcome.clone(),
            candidate.context.clone(),
        );
        copy.reliability = candidate.reliability * TRANSFER_DISCOUNT;
        copy.trial_count = 0;
        copy.success_count = 0;
        transferred.push(copy);
    }
    transferred
}
/// Estimates reliability in `current_context` by weighting past trials by context similarity.
///
/// Starts from a uniform `Beta(1, 1)` prior (`alpha = beta = 1`). Each trial
/// adds a weight of `max(cosine_similarity(current, trial), 0) ^ gamma` to
/// `alpha` on success or to `beta` on failure. Larger `gamma` makes the
/// estimate more local to similar contexts, and contexts with non-positive
/// similarity contribute nothing. Returns the posterior mean
/// `alpha / (alpha + beta)`, which is `0.5` for an empty trial list.
///
/// NOTE(review): behavior for mismatched vector lengths or zero vectors depends
/// on `cosine_similarity`, which is defined elsewhere — confirm it never yields
/// a negative-infinity/NaN case that matters here.
pub fn context_weighted_reliability(
    trials: &[(bool, Vec<f32>)],
    current_context: &[f32],
    gamma: f32,
) -> f32 {
    let (alpha, beta) = trials.iter().fold(
        (1.0f32, 1.0f32),
        |(alpha, beta), (succeeded, trial_context)| {
            // Only positively aligned contexts count; gamma sharpens the falloff.
            let similarity = cosine_similarity(current_context, trial_context).max(0.0);
            let weight = similarity.powf(gamma);
            if *succeeded {
                (alpha + weight, beta)
            } else {
                (alpha, beta + weight)
            }
        },
    );
    alpha / (alpha + beta)
}
/// Chooses a transfer discount factor from how similar two namespaces are.
///
/// `namespace_similarity` is clamped to `[0, 1]` and mapped linearly onto
/// `[0.5, 0.8]`. Dissimilar namespaces keep half of a procedure's reliability;
/// identical ones keep 80%. A `NaN` similarity is treated as `0.0`, the most
/// conservative discount, instead of propagating `NaN` into every reliability
/// the discount is multiplied with.
pub fn adaptive_transfer_discount(namespace_similarity: f32) -> f32 {
    // `f32::clamp` returns NaN for a NaN input, so handle that case explicitly.
    let similarity = if namespace_similarity.is_nan() {
        0.0
    } else {
        namespace_similarity.clamp(0.0, 1.0)
    };
    0.5 + 0.3 * similarity
}
#[cfg(test)]
mod tests {
    //! Unit tests for reliability updates, pruning, selection and transfer.

    use std::collections::HashMap;

    use uuid::Uuid;

    use super::*;
    use crate::types::Outcome;

    /// Builds a procedure with the given reliability and a history of
    /// `trial_count` trials, all counted as successes.
    fn make_proc(
        trigger: &str,
        action: &str,
        reliability: f32,
        trial_count: u32,
    ) -> ProceduralMemory {
        let mut p = ProceduralMemory::new(
            Uuid::new_v4(),
            trigger,
            action,
            Outcome::Success,
            HashMap::new(),
        );
        p.reliability = reliability;
        p.trial_count = trial_count;
        p.success_count = trial_count;
        p
    }

    /// Feeds `outcome` into `update_reliability` `rounds` times, starting from
    /// an empty history, and returns the final `(reliability, trials, successes)`.
    fn simulate(outcome: &Outcome, rounds: u32) -> (f32, u32, u32) {
        (0..rounds).fold((0.5f32, 0u32, 0u32), |(_, trials, successes), _| {
            update_reliability(trials, successes, outcome)
        })
    }

    #[test]
    fn test_initial_reliability() {
        let (rel, trials, successes) = update_reliability(0, 0, &Outcome::Success);
        assert_eq!(trials, 1);
        assert_eq!(successes, 1);
        assert!((rel - 0.667).abs() < 0.01);
    }

    #[test]
    fn test_reliability_increases_with_success() {
        let (r1, _, _) = update_reliability(1, 1, &Outcome::Success);
        let (r2, _, _) = update_reliability(2, 2, &Outcome::Success);
        assert!(r2 > r1);
    }

    #[test]
    fn test_reliability_decreases_with_failure() {
        let (r_success, _, _) = update_reliability(1, 1, &Outcome::Success);
        let (r_failure, _, _) = update_reliability(1, 1, &Outcome::Failure);
        assert!(r_success > r_failure);
    }

    #[test]
    fn test_many_successes_high_reliability() {
        let (rel, trials, successes) = simulate(&Outcome::Success, 20);
        assert_eq!((trials, successes), (20, 20));
        assert!(rel > 0.9);
        // Closed form of the Beta(1, 1) posterior mean: (20 + 1) / (20 + 2).
        assert!((rel - 21.0 / 22.0).abs() < 1e-4);
    }

    #[test]
    fn test_many_failures_low_reliability() {
        let (rel, trials, successes) = simulate(&Outcome::Failure, 20);
        assert_eq!((trials, successes), (20, 0));
        assert!(rel < 0.15);
        // Closed form of the Beta(1, 1) posterior mean: (0 + 1) / (20 + 2).
        assert!((rel - 1.0 / 22.0).abs() < 1e-4);
    }

    #[test]
    fn test_should_prune() {
        assert!(should_prune(0.05, 15, 10, 0.1));
        // Too few trials to judge.
        assert!(!should_prune(0.05, 5, 10, 0.1));
        // Reliable enough to keep.
        assert!(!should_prune(0.5, 15, 10, 0.1));
    }

    #[test]
    fn test_should_prune_boundaries() {
        // Exactly `min_trials` trials is enough evidence.
        assert!(should_prune(0.05, 10, 10, 0.1));
        // The threshold itself is not "below" the threshold.
        assert!(!should_prune(0.1, 15, 10, 0.1));
    }

    #[test]
    fn test_select_best_procedure() {
        let procs = vec![
            make_proc("trigger", "action1", 0.3, 0),
            make_proc("trigger", "action2", 0.8, 0),
        ];
        let best = select_best_procedure(&procs, 0.1);
        assert_eq!(best, Some(1));
    }

    #[test]
    fn test_select_best_procedure_none_above_threshold() {
        let procs = vec![make_proc("trigger", "action1", 0.05, 0)];
        let best = select_best_procedure(&procs, 0.5);
        assert_eq!(best, None);
    }

    #[test]
    fn test_select_best_procedure_empty_and_inclusive_threshold() {
        assert_eq!(select_best_procedure(&[], 0.0), None);
        let procs = vec![make_proc("trigger", "action1", 0.5, 0)];
        assert_eq!(select_best_procedure(&procs, 0.5), Some(0));
    }

    #[test]
    fn test_partial_outcome_not_counted_as_success() {
        let (rel_partial, trials_p, successes_p) = update_reliability(0, 0, &Outcome::Partial);
        let (rel_failure, trials_f, successes_f) = update_reliability(0, 0, &Outcome::Failure);
        assert_eq!(successes_p, successes_f);
        assert_eq!(trials_p, trials_f);
        assert!((rel_partial - rel_failure).abs() < f32::EPSILON);
    }

    #[test]
    fn test_uninformative_prior_at_zero_trials() {
        let (rel, trials, successes) = update_reliability(0, 0, &Outcome::Failure);
        assert_eq!(trials, 1);
        assert_eq!(successes, 0);
        assert!((rel - 0.333).abs() < 0.01);
    }

    #[test]
    fn test_transfer_high_reliability_only() {
        let source = vec![
            make_proc("on_error", "retry", 0.9, 20),
            make_proc("on_timeout", "backoff", 0.3, 20),
        ];
        let existing: Vec<ProceduralMemory> = vec![];
        let result = transfer_procedures(&source, &existing, 0.7, 5);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].trigger, "on_error");
        assert_eq!(result[0].action, "retry");
    }

    #[test]
    fn test_transfer_applies_discount() {
        let source = vec![make_proc("on_error", "retry", 0.9, 20)];
        let existing: Vec<ProceduralMemory> = vec![];
        let result = transfer_procedures(&source, &existing, 0.5, 5);
        assert_eq!(result.len(), 1);
        // 0.9 source reliability * 0.7 transfer discount.
        assert!((result[0].reliability - 0.63).abs() < 0.01);
    }

    #[test]
    fn test_transfer_skips_duplicates() {
        let source = vec![make_proc("on_error", "retry", 0.9, 20)];
        let existing = vec![make_proc("on_error", "retry", 0.5, 5)];
        let result = transfer_procedures(&source, &existing, 0.5, 5);
        assert!(result.is_empty());
    }

    #[test]
    fn test_transfer_resets_trial_count() {
        let source = vec![make_proc("on_error", "retry", 0.9, 20)];
        let existing: Vec<ProceduralMemory> = vec![];
        let result = transfer_procedures(&source, &existing, 0.5, 5);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].trial_count, 0);
        assert_eq!(result[0].success_count, 0);
    }

    #[test]
    fn test_transfer_respects_min_trials() {
        let source = vec![
            // Reliable but under-tested: must not transfer.
            make_proc("on_error", "retry", 0.9, 3),
            make_proc("on_timeout", "backoff", 0.85, 10),
        ];
        let existing: Vec<ProceduralMemory> = vec![];
        let result = transfer_procedures(&source, &existing, 0.7, 5);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].trigger, "on_timeout");
    }

    #[test]
    fn test_transfer_empty_source() {
        let existing = vec![make_proc("on_error", "retry", 0.5, 5)];
        assert!(transfer_procedures(&[], &existing, 0.0, 0).is_empty());
    }

    #[test]
    fn test_context_weighted_reliability() {
        let trials = vec![
            (true, vec![1.0, 0.0, 0.0]),
            (true, vec![1.0, 0.0, 0.0]),
            (false, vec![0.0, 1.0, 0.0]),
        ];
        let ctx_a = vec![1.0, 0.0, 0.0];
        let ctx_b = vec![0.0, 1.0, 0.0];
        let rel_a = context_weighted_reliability(&trials, &ctx_a, 2.0);
        let rel_b = context_weighted_reliability(&trials, &ctx_b, 2.0);
        assert!(
            rel_a > rel_b,
            "Should be more reliable in context A: {rel_a} vs {rel_b}"
        );
    }

    #[test]
    fn test_context_weighted_reliability_empty_history_is_prior() {
        let rel = context_weighted_reliability(&[], &[1.0, 0.0, 0.0], 2.0);
        assert!((rel - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_adaptive_transfer_discount() {
        let similar = adaptive_transfer_discount(0.9);
        let different = adaptive_transfer_discount(0.2);
        assert!(similar > different);
        assert!(similar <= 0.8);
        assert!(different >= 0.5);
    }

    #[test]
    fn test_adaptive_transfer_discount_clamps_out_of_range() {
        assert!((adaptive_transfer_discount(-1.0) - 0.5).abs() < 1e-6);
        assert!((adaptive_transfer_discount(2.0) - 0.8).abs() < 1e-6);
    }
}