use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
/// Value tier attached to a tool through [`ToolValueModel::value_class`].
///
/// Serialized in `snake_case` (`critical`, `supporting`, `optional`,
/// `audit_only`). [`ValueClass::Critical`] is the default. Within this file,
/// only [`ValueClass::AuditOnly`] has defined behavior: it makes
/// [`ToolValueModel::excluded_from_budget`] return `true`. The meaning of the
/// remaining tiers is left to consumers elsewhere.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ValueClass {
    /// The default tier.
    #[default]
    Critical,
    Supporting,
    Optional,
    /// Excluded from budgeting; see [`ToolValueModel::excluded_from_budget`].
    AuditOnly,
}
/// Side-effect classification of a tool.
///
/// Feeds [`SideEffectClass::is_speculatable`], where only
/// [`Pure`](SideEffectClass::Pure) and [`ReadOnly`](SideEffectClass::ReadOnly)
/// qualify. The default is [`Indeterminate`](SideEffectClass::Indeterminate),
/// the conservative choice: a tool with no annotation is never speculated.
/// Serialized in `snake_case` (`pure`, `read_only`, `mutates_local`,
/// `mutates_external`, `indeterminate`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum SideEffectClass {
    Pure,
    ReadOnly,
    MutatesLocal,
    MutatesExternal,
    /// Unknown side effects; the default, and never speculatable.
    #[default]
    Indeterminate,
}
impl SideEffectClass {
    /// Returns `true` only for [`SideEffectClass::Pure`] and
    /// [`SideEffectClass::ReadOnly`]. Every other class, including the default
    /// `Indeterminate`, returns `false`.
    ///
    /// The match lists every variant instead of using a wildcard. Adding a new
    /// variant therefore fails to compile here until someone decides whether
    /// it may be speculated.
    pub fn is_speculatable(&self) -> bool {
        match self {
            Self::Pure | Self::ReadOnly => true,
            Self::MutatesLocal | Self::MutatesExternal | Self::Indeterminate => false,
        }
    }
}
/// A named bundle of output fields with an estimated value.
///
/// Groups are stored in [`ToolValueModel::field_groups`], keyed by group
/// name. Every field is optional when deserializing: `fields` defaults to
/// empty, `estimated_value` to `0.5`, and `default_include` to `true`. These
/// are the same values `FieldGroup::default()` produces.
///
/// Derives `PartialEq`, matching the sibling value types [`CostModel`] and
/// [`FollowUpLink`], so groups can be compared directly (e.g. in round-trip
/// tests).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FieldGroup {
    /// Names of the fields in this group.
    /// NOTE(review): presumably keys of the tool's output; confirm with the consumer.
    #[serde(default)]
    pub fields: Vec<String>,
    /// Estimated value of this group; `0.5` when absent. Not range-checked here.
    #[serde(default = "default_estimated_value")]
    pub estimated_value: f32,
    /// Whether the group is included by default; `true` when absent.
    #[serde(default = "default_include_true")]
    pub default_include: bool,
}
/// Serde default for [`FieldGroup::estimated_value`]: one half.
fn default_estimated_value() -> f32 {
    const HALF: f32 = 0.5;
    HALF
}
/// Serde default for [`FieldGroup::default_include`]: a group whose config
/// omits the key is included.
fn default_include_true() -> bool {
    const INCLUDE_BY_DEFAULT: bool = true;
    INCLUDE_BY_DEFAULT
}
impl Default for FieldGroup {
fn default() -> Self {
Self {
fields: Vec::new(),
estimated_value: default_estimated_value(),
default_include: default_include_true(),
}
}
}
/// Cost estimates for a tool.
/// NOTE(review): presumably per call; confirm with the consumer.
///
/// `typical_kb` is always present and defaults to `1.0` when absent. Every
/// other field is an optional estimate that is omitted from serialized
/// output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CostModel {
    /// Typical output size. The `_kb` suffix indicates kilobytes; the unit is
    /// not enforced here.
    #[serde(default = "default_typical_kb")]
    pub typical_kb: f32,
    /// Maximum output size in the same unit as `typical_kb`, if known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_kb: Option<f32>,
    /// Median (p50) latency, in milliseconds per the field name, if known.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub latency_ms_p50: Option<u32>,
    /// Monetary cost, if known.
    /// NOTE(review): currency assumed to be USD from the name; confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dollars: Option<f32>,
    /// Freshness TTL, in seconds per the `_s` suffix, if known.
    /// NOTE(review): presumably how long a result stays valid; confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub freshness_ttl_s: Option<u32>,
}
/// Serde default for [`CostModel::typical_kb`]: `1.0`.
fn default_typical_kb() -> f32 {
    const ONE_KB: f32 = 1.0;
    ONE_KB
}
impl Default for CostModel {
fn default() -> Self {
Self {
typical_kb: default_typical_kb(),
max_kb: None,
latency_ms_p50: None,
dollars: None,
freshness_ttl_s: None,
}
}
}
/// A predicted follow-up call, stored in [`ToolValueModel::follow_up`].
///
/// `tool` names the follow-up tool and is required when deserializing.
/// `probability` defaults to `0.5` when absent. `projection` and
/// `projection_arg` are optional and omitted from serialized output when
/// `None`.
///
/// NOTE(review): the projection pair appears to map a field of this tool's
/// output (`projection`, e.g. `"path"`) onto an argument name of the
/// follow-up tool (`projection_arg`, e.g. `"file_path"`). Confirm against the
/// consumer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FollowUpLink {
    /// Name of the follow-up tool.
    pub tool: String,
    /// Weight of this link; `0.5` when absent. Not range-checked here.
    #[serde(default = "default_followup_probability")]
    pub probability: f32,
    /// Optional source field name; see the type-level note.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub projection: Option<String>,
    /// Optional target argument name; see the type-level note.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub projection_arg: Option<String>,
}
/// Serde default for [`FollowUpLink::probability`]: even odds.
fn default_followup_probability() -> f32 {
    const EVEN_ODDS: f32 = 0.5;
    EVEN_ODDS
}
impl Default for FollowUpLink {
fn default() -> Self {
Self {
tool: String::new(),
probability: default_followup_probability(),
projection: None,
projection_arg: None,
}
}
}
/// Value, cost, follow-up, and side-effect annotations for one tool.
///
/// Every field has a serde default, so an empty document deserializes
/// field-for-field to `ToolValueModel::default()`. When serializing, the
/// following are omitted:
/// - empty collections and `None` options;
/// - the default [`SideEffectClass::Indeterminate`].
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct ToolValueModel {
    /// Value tier. [`ValueClass::AuditOnly`] excludes the tool from budgeting.
    #[serde(default)]
    pub value_class: ValueClass,
    /// Field groups keyed by group name (e.g. `"must_have"`).
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub field_groups: BTreeMap<String, FieldGroup>,
    /// Cost estimates; `CostModel::default()` when absent.
    #[serde(default)]
    pub cost_model: CostModel,
    /// Predicted follow-up calls.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub follow_up: Vec<FollowUpLink>,
    /// Names this tool invalidates.
    /// NOTE(review): presumably other tools whose cached results go stale; confirm.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub invalidates: Vec<String>,
    /// NOTE(review): by its name, a failure count after which to give up.
    /// Not read in this file; confirm with the consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fail_fast_after_n: Option<u32>,
    /// Side-effect class. It decides [`ToolValueModel::is_speculatable`]
    /// unless `speculate` overrides it.
    #[serde(default, skip_serializing_if = "is_default_side_effect")]
    pub side_effect_class: SideEffectClass,
    /// NOTE(review): presumably the host used as a rate-limiting key.
    /// Not read in this file; confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rate_limit_host: Option<String>,
    /// Explicit speculation override. `Some(_)` wins over `side_effect_class`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub speculate: Option<bool>,
}
/// `skip_serializing_if` predicate for [`ToolValueModel::side_effect_class`].
///
/// Compares against `SideEffectClass::default()` rather than naming a
/// variant. That keeps the value skipped on write identical to the value
/// `#[serde(default)]` restores for a missing key on read.
///
/// If the two ever diverged (e.g. `#[default]` moved to another variant), a
/// skipped class would round-trip as a different class. For example, an
/// `Indeterminate` tool could silently become speculatable.
fn is_default_side_effect(s: &SideEffectClass) -> bool {
    *s == SideEffectClass::default()
}
impl ToolValueModel {
    /// A [`ValueClass::Critical`] model whose cost model has the given
    /// `typical_kb`. Every other field keeps its default.
    pub fn critical_with_size(typical_kb: f32) -> Self {
        let cost_model = CostModel {
            typical_kb,
            ..Default::default()
        };
        Self {
            value_class: ValueClass::Critical,
            cost_model,
            ..Default::default()
        }
    }

    /// A [`ValueClass::AuditOnly`] model with every other field at its
    /// default.
    pub fn audit_only() -> Self {
        Self {
            value_class: ValueClass::AuditOnly,
            ..Default::default()
        }
    }

    /// `true` exactly when `value_class` is [`ValueClass::AuditOnly`].
    pub fn excluded_from_budget(&self) -> bool {
        self.value_class == ValueClass::AuditOnly
    }

    /// Whether calls to this tool may be issued speculatively.
    ///
    /// An explicit `speculate` override always wins, in either direction.
    /// Otherwise the decision falls back to
    /// [`SideEffectClass::is_speculatable`].
    pub fn is_speculatable(&self) -> bool {
        match self.speculate {
            Some(forced) => forced,
            None => self.side_effect_class.is_speculatable(),
        }
    }
}
// Unit tests for the value-model types. Coverage: defaults, convenience
// constructors, the speculation policy, and the shape of TOML
// (de)serialization (snake_case names, skipped defaults).
#[cfg(test)]
mod tests {
    use super::*;
    // The derived default is a Critical tool, 1.0 typical_kb, no collections,
    // counted against the budget.
    #[test]
    fn default_is_critical_with_one_kb() {
        let m = ToolValueModel::default();
        assert_eq!(m.value_class, ValueClass::Critical);
        assert_eq!(m.cost_model.typical_kb, 1.0);
        assert!(m.field_groups.is_empty());
        assert!(m.follow_up.is_empty());
        assert!(m.invalidates.is_empty());
        assert!(!m.excluded_from_budget());
    }
    // Only AuditOnly is excluded from the budget; Critical is not.
    #[test]
    fn audit_only_is_excluded_from_budget() {
        assert!(ToolValueModel::audit_only().excluded_from_budget());
        assert!(!ToolValueModel::critical_with_size(2.5).excluded_from_budget());
    }
    #[test]
    fn critical_with_size_sets_typical_kb() {
        let m = ToolValueModel::critical_with_size(2.5);
        assert_eq!(m.value_class, ValueClass::Critical);
        assert_eq!(m.cost_model.typical_kb, 2.5);
    }
    // Despite the name, this only deserializes. An empty TOML document must
    // yield the serde defaults for every field checked.
    #[test]
    fn round_trip_via_toml_default() {
        let m: ToolValueModel = toml::from_str("").unwrap();
        assert_eq!(m.value_class, ValueClass::default());
        assert_eq!(m.cost_model, CostModel::default());
    }
    // Serialize a fully populated model, parse it back, and spot-check each
    // field. The exact float comparisons rely on the TOML writer emitting
    // enough digits for each f32 to parse back bit-identically.
    #[test]
    fn round_trip_via_toml_full() {
        let m = ToolValueModel {
            value_class: ValueClass::Supporting,
            field_groups: {
                let mut g = BTreeMap::new();
                g.insert(
                    "must_have".to_string(),
                    FieldGroup {
                        fields: vec!["title".into(), "url".into()],
                        estimated_value: 1.0,
                        default_include: true,
                    },
                );
                g.insert(
                    "nice_to_have".to_string(),
                    FieldGroup {
                        fields: vec!["snippet".into()],
                        estimated_value: 0.3,
                        default_include: false,
                    },
                );
                g
            },
            cost_model: CostModel {
                typical_kb: 3.1,
                max_kb: Some(8.0),
                latency_ms_p50: Some(900),
                dollars: None,
                freshness_ttl_s: Some(3600),
            },
            follow_up: vec![FollowUpLink {
                tool: "WebFetch".into(),
                probability: 0.65,
                projection: Some("url".into()),
                projection_arg: Some("url".into()),
            }],
            invalidates: vec![],
            fail_fast_after_n: Some(2),
            side_effect_class: SideEffectClass::ReadOnly,
            rate_limit_host: Some("example.com".into()),
            speculate: None,
        };
        let s = toml::to_string_pretty(&m).unwrap();
        let back: ToolValueModel = toml::from_str(&s).unwrap();
        assert_eq!(back.value_class, ValueClass::Supporting);
        assert_eq!(back.field_groups.len(), 2);
        assert_eq!(
            back.field_groups.get("must_have").unwrap().fields,
            vec!["title".to_string(), "url".to_string()]
        );
        assert_eq!(back.cost_model.typical_kb, 3.1);
        assert_eq!(back.cost_model.max_kb, Some(8.0));
        assert_eq!(back.follow_up[0].tool, "WebFetch");
        assert_eq!(back.follow_up[0].projection.as_deref(), Some("url"));
        assert_eq!(back.follow_up[0].projection_arg.as_deref(), Some("url"));
        assert_eq!(back.fail_fast_after_n, Some(2));
        assert_eq!(back.side_effect_class, SideEffectClass::ReadOnly);
        assert_eq!(back.rate_limit_host.as_deref(), Some("example.com"));
        // ReadOnly with no `speculate` override is speculatable.
        assert!(back.is_speculatable());
    }
    // Safety property: an unannotated tool defaults to Indeterminate and
    // therefore must not be speculated.
    #[test]
    fn default_side_effect_class_is_indeterminate_and_blocks_speculation() {
        let m = ToolValueModel::default();
        assert_eq!(m.side_effect_class, SideEffectClass::Indeterminate);
        assert!(
            !m.is_speculatable(),
            "Indeterminate must never be speculated"
        );
    }
    #[test]
    fn pure_and_read_only_are_speculatable() {
        let pure = ToolValueModel {
            side_effect_class: SideEffectClass::Pure,
            ..Default::default()
        };
        let ro = ToolValueModel {
            side_effect_class: SideEffectClass::ReadOnly,
            ..Default::default()
        };
        assert!(pure.is_speculatable());
        assert!(ro.is_speculatable());
    }
    // Mutating classes must never be speculated, because speculative
    // execution could repeat their writes.
    #[test]
    fn mutating_classes_block_speculation() {
        for class in [
            SideEffectClass::MutatesLocal,
            SideEffectClass::MutatesExternal,
        ] {
            let m = ToolValueModel {
                side_effect_class: class,
                ..Default::default()
            };
            assert!(
                !m.is_speculatable(),
                "{class:?} must never be speculated — would duplicate writes"
            );
        }
    }
    // The explicit `speculate` override wins in both directions: it can
    // disable a Pure tool and force-enable an Indeterminate one.
    #[test]
    fn speculate_override_wins_over_side_effect_class() {
        let pure_but_disabled = ToolValueModel {
            side_effect_class: SideEffectClass::Pure,
            speculate: Some(false),
            ..Default::default()
        };
        assert!(!pure_but_disabled.is_speculatable());
        let forced_on = ToolValueModel {
            side_effect_class: SideEffectClass::Indeterminate,
            speculate: Some(true),
            ..Default::default()
        };
        assert!(forced_on.is_speculatable());
    }
    // Each non-default class serializes under its snake_case name and parses
    // back to itself. Indeterminate is excluded because it is skipped on
    // serialize (see default_indeterminate_skipped_on_serialise).
    #[test]
    fn side_effect_class_serialises_snake_case() {
        for (class, expected) in [
            (SideEffectClass::Pure, "pure"),
            (SideEffectClass::ReadOnly, "read_only"),
            (SideEffectClass::MutatesLocal, "mutates_local"),
            (SideEffectClass::MutatesExternal, "mutates_external"),
        ] {
            let m = ToolValueModel {
                side_effect_class: class,
                ..Default::default()
            };
            let s = toml::to_string_pretty(&m).unwrap();
            assert!(
                s.contains(&format!("side_effect_class = \"{expected}\"")),
                "expected `{expected}`, got: {s}"
            );
            let back: ToolValueModel = toml::from_str(&s).unwrap();
            assert_eq!(back.side_effect_class, class);
        }
    }
    // A default model serializes without side_effect_class,
    // rate_limit_host or speculate keys.
    #[test]
    fn default_indeterminate_skipped_on_serialise() {
        let m = ToolValueModel::default();
        let s = toml::to_string_pretty(&m).unwrap();
        assert!(
            !s.contains("side_effect_class"),
            "Indeterminate is the default and must be skip_serializing_if'd, got: {s}"
        );
        assert!(!s.contains("rate_limit_host"));
        assert!(!s.contains("speculate"));
    }
    // `projection_arg` can differ from `projection` and survives a TOML
    // round trip.
    #[test]
    fn followup_link_projection_arg_round_trips() {
        let l = FollowUpLink {
            tool: "Read".into(),
            probability: 0.8,
            projection: Some("path".into()),
            projection_arg: Some("file_path".into()),
        };
        let s = toml::to_string_pretty(&l).unwrap();
        let back: FollowUpLink = toml::from_str(&s).unwrap();
        assert_eq!(back.projection_arg.as_deref(), Some("file_path"));
    }
    // Empty collections and `None` options are omitted from a default
    // model's TOML.
    #[test]
    fn empty_optional_fields_are_skipped_on_serialise() {
        let m = ToolValueModel::default();
        let s = toml::to_string_pretty(&m).unwrap();
        assert!(!s.contains("field_groups"));
        assert!(!s.contains("follow_up"));
        assert!(!s.contains("invalidates"));
        assert!(!s.contains("fail_fast_after_n"));
        assert!(!s.contains("max_kb"));
    }
    #[test]
    fn value_class_serialises_snake_case() {
        let m = ToolValueModel {
            value_class: ValueClass::AuditOnly,
            ..Default::default()
        };
        let s = toml::to_string_pretty(&m).unwrap();
        assert!(s.contains("audit_only"), "expected snake_case, got: {s}");
    }
    // The hand-written `Default` must agree with the serde defaults.
    #[test]
    fn field_group_default_estimated_value_is_half() {
        let g = FieldGroup::default();
        assert!((g.estimated_value - 0.5).abs() < 1e-6);
        assert!(g.default_include);
    }
    // With both projection fields `None`, neither key is written, and both
    // come back as `None`.
    #[test]
    fn followup_link_round_trips_without_projection() {
        let l = FollowUpLink {
            tool: "Bash".into(),
            probability: 0.8,
            ..FollowUpLink::default()
        };
        let s = toml::to_string_pretty(&l).unwrap();
        assert!(
            !s.contains("projection"),
            "None should be skipped, got: {s}"
        );
        let back: FollowUpLink = toml::from_str(&s).unwrap();
        assert_eq!(back.tool, "Bash");
        assert!((back.probability - 0.8).abs() < 1e-6);
        assert!(back.projection.is_none());
        assert!(back.projection_arg.is_none());
    }
}