use std::collections::BTreeMap;
use serde::Serialize;
use crate::residual::{CorrectionDirection, ResidualEvent};
use crate::stability::StabilityParameters;
/// Free-form string key/value evidence attached to a barrier result and forwarded unchanged to
/// the kernel. Backed by a `BTreeMap`, so keys iterate in sorted order.
pub type Evidence = BTreeMap<String, String>;
/// List of correction directions a barrier suggests; used as the feedback/correction payload of
/// barrier results handed to the kernel.
pub type CorrectionDirectionSet = Vec<CorrectionDirection>;
/// Outcome of evaluating an agent barrier against a candidate state.
///
/// This is the SDK-side representation; [`AgentBarrierResult::into_srbn`] converts it into the
/// kernel's `srbn::BarrierResult` so it can drive the `srbn` stabilization loop.
#[derive(Debug, Clone, PartialEq)]
pub struct AgentBarrierResult {
    /// Whether the barrier accepted the state.
    pub ok: bool,
    /// Score handed to the kernel.
    /// NOTE(review): the kernel policy requires descent, so lower presumably means better, and a
    /// NaN score makes stabilization fail — confirm the exact contract with `srbn`.
    pub score: f64,
    /// Residual events observed while evaluating the barrier.
    /// Not forwarded by [`AgentBarrierResult::into_srbn`]; only available on this SDK type.
    pub residuals: Vec<ResidualEvent>,
    /// Suggested correction directions. The first entry's `instruction` becomes the kernel
    /// feedback string; the full list becomes the kernel correction payload.
    pub feedback: CorrectionDirectionSet,
    /// Free-form key/value evidence forwarded unchanged to the kernel result.
    pub evidence: Evidence,
}

impl AgentBarrierResult {
    /// Creates a result with the given verdict and score and no residuals, feedback, or evidence.
    #[must_use]
    pub fn new(ok: bool, score: f64) -> Self {
        Self {
            ok,
            score,
            residuals: Vec::new(),
            feedback: Vec::new(),
            evidence: Evidence::new(),
        }
    }

    /// Replaces the residual events (builder style).
    #[must_use]
    pub fn with_residuals(mut self, residuals: Vec<ResidualEvent>) -> Self {
        self.residuals = residuals;
        self
    }

    /// Replaces the correction directions (builder style).
    #[must_use]
    pub fn with_feedback(mut self, feedback: CorrectionDirectionSet) -> Self {
        self.feedback = feedback;
        self
    }

    /// Replaces the evidence map (builder style).
    #[must_use]
    pub fn with_evidence(mut self, evidence: Evidence) -> Self {
        self.evidence = evidence;
        self
    }

    /// Converts this result into a kernel `srbn::BarrierResult` named `name`.
    ///
    /// - The feedback string is the `instruction` of the first correction direction, or the
    ///   type's default (empty) when there is no feedback.
    /// - The correction payload is `Some(all directions)` when any exist, `None` otherwise.
    /// - `ok`, `score`, and `evidence` are passed through unchanged; `residuals` are dropped.
    ///
    /// # Errors
    ///
    /// Returns whatever `srbn::BarrierResult::new` returns.
    /// NOTE(review): it presumably validates `score` (a NaN score makes stabilization fail) —
    /// confirm against `srbn`.
    pub fn into_srbn(
        self,
        name: impl Into<String>,
    ) -> srbn::SrbnResult<srbn::BarrierResult<CorrectionDirectionSet, Evidence>> {
        let feedback = self
            .feedback
            .first()
            .map(|d| d.instruction.clone())
            .unwrap_or_default();
        // An empty direction list means "no correction", not "an empty correction".
        let correction = Some(self.feedback).filter(|directions| !directions.is_empty());
        srbn::BarrierResult::new(
            name,
            self.ok,
            self.score,
            feedback,
            correction,
            self.evidence,
        )
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum AgentStabilizationStatus {
Stable,
Descended,
Stopped,
Exhausted,
Degraded,
ReplanRequired,
}
impl From<srbn::Status> for AgentStabilizationStatus {
fn from(status: srbn::Status) -> Self {
match status {
srbn::Status::Stable => AgentStabilizationStatus::Stable,
srbn::Status::Stopped => AgentStabilizationStatus::Stopped,
srbn::Status::Exhausted => AgentStabilizationStatus::Exhausted,
}
}
}
/// Builds the kernel `srbn::Policy` used by the agent stabilization loop.
///
/// - `max_attempts` is clamped to at least 1, so passing 0 still yields a one-attempt budget.
/// - `score_tolerance` comes from `params.energy_tolerance`.
/// - Descent is required, with `min_descent` taken from `params.rho_gate`.
/// - When no descent occurs, the kernel is told to stop (`OnNoDescent::Stop`).
pub fn policy(params: &StabilityParameters, max_attempts: usize) -> srbn::Policy {
    let attempts = if max_attempts == 0 { 1 } else { max_attempts };
    let tolerance = params.energy_tolerance;
    let descent_floor = params.rho_gate;
    srbn::Policy {
        max_attempts: attempts,
        require_descent: true,
        on_no_descent: srbn::OnNoDescent::Stop,
        score_tolerance: tolerance,
        min_descent: descent_floor,
    }
}
/// Runs the `srbn` stabilization loop driven by an agent-level barrier.
///
/// - Each [`AgentBarrierResult`] returned by `barrier` is converted with
///   [`AgentBarrierResult::into_srbn`] under the name `"agent-barrier"`.
/// - `updater` receives the converted kernel barrier result and produces the next state.
/// - The kernel policy is built by [`policy`] from `params` and `max_attempts`.
///
/// # Errors
///
/// Any error returned by `srbn::stabilize` is converted into the crate error type.
/// The module's tests show that a NaN barrier score surfaces here as an error.
pub fn stabilize<State, B, U>(
    initial: State,
    mut barrier: B,
    updater: U,
    params: &StabilityParameters,
    max_attempts: usize,
) -> crate::error::Result<srbn::StabilizationResult<State, CorrectionDirectionSet, Evidence>>
where
    State: Clone,
    B: FnMut(&State) -> AgentBarrierResult,
    U: FnMut(
        State,
        &srbn::BarrierResult<CorrectionDirectionSet, Evidence>,
    ) -> srbn::SrbnResult<State>,
{
    let kernel_policy = policy(params, max_attempts);
    // Adapt the agent barrier to the kernel's barrier signature.
    let kernel_barrier =
        move |candidate: &State| barrier(candidate).into_srbn("agent-barrier");
    Ok(srbn::stabilize(initial, kernel_barrier, updater, kernel_policy)?)
}
pub fn trace_to_json<State>(
result: &srbn::StabilizationResult<State, CorrectionDirectionSet, Evidence>,
) -> crate::error::Result<String>
where
State: Serialize,
{
srbn_serde::stabilization_result_json(result)
.map_err(|e| crate::error::SdkError::Kernel(format!("trace serialization failed: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every kernel status maps to the agent variant of the same name.
    #[test]
    fn maps_kernel_status() {
        assert_eq!(
            AgentStabilizationStatus::from(srbn::Status::Stable),
            AgentStabilizationStatus::Stable
        );
        assert_eq!(
            AgentStabilizationStatus::from(srbn::Status::Stopped),
            AgentStabilizationStatus::Stopped
        );
        assert_eq!(
            AgentStabilizationStatus::from(srbn::Status::Exhausted),
            AgentStabilizationStatus::Exhausted
        );
    }

    /// A freshly constructed result carries no residuals, feedback, or evidence.
    #[test]
    fn new_result_starts_empty() {
        let r = AgentBarrierResult::new(true, 1.5);
        assert!(r.ok);
        assert_eq!(r.score, 1.5);
        assert!(r.residuals.is_empty());
        assert!(r.feedback.is_empty());
        assert!(r.evidence.is_empty());
    }

    /// A zero attempt budget is clamped to one, and descent is always required.
    #[test]
    fn policy_clamps_zero_attempts_to_one() {
        let params = StabilityParameters::measured(0.5, 0.0);
        let p = policy(&params, 0);
        assert_eq!(p.max_attempts, 1);
        assert!(p.require_descent);
        assert_eq!(policy(&params, 7).max_attempts, 7);
    }

    /// A strictly descending quadratic energy reaches the accepted state 0.
    #[test]
    fn stabilizes_descending_energy_through_kernel() {
        let params = StabilityParameters::measured(0.5, 0.0);
        let barrier = |state: &i64| {
            let v = (*state as f64) * (*state as f64);
            AgentBarrierResult::new(*state == 0, v)
        };
        let updater = |state: i64, _b: &srbn::BarrierResult<CorrectionDirectionSet, Evidence>| {
            Ok(state - state.signum())
        };
        let result = stabilize(3, barrier, updater, &params, 10).unwrap();
        assert_eq!(result.status, srbn::Status::Stable);
        assert_eq!(result.state, 0);
        let json = trace_to_json(&result).unwrap();
        assert!(json.contains("\"status\":\"stable\""));
        assert!(json.contains("\"attempts\""));
    }

    /// A NaN barrier score is rejected rather than silently compared.
    #[test]
    fn rejects_non_finite_score_at_barrier() {
        let params = StabilityParameters::measured(0.5, 0.0);
        let barrier = |_state: &i64| AgentBarrierResult::new(false, f64::NAN);
        let updater =
            |state: i64, _b: &srbn::BarrierResult<CorrectionDirectionSet, Evidence>| Ok(state);
        assert!(stabilize(1, barrier, updater, &params, 3).is_err());
    }
}