use serde::{Deserialize, Serialize};
use crate::traits::ExtractionSource;
/// Metric name under which `InferenceBudget::effective_ms_gauge` reports the
/// effective per-commit inference budget (the value of `effective_ms`).
pub const BUDGET_GAUGE_NAME: &str = "mnem_inference_budget_effective_ms";
/// Time and volume caps applied to inference work for a single commit.
///
/// The two latency fields are combined by `effective_ms`, which returns the
/// smaller of them. `validate` rejects a zero value for both latency fields and
/// for `max_phrases_embedded`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct InferenceBudget {
    /// Extraction latency budget in milliseconds; the effective budget never
    /// exceeds this value. Must be non-zero.
    pub extract_latency_budget_ms: u32,
    /// Hard per-commit wall for inference time in milliseconds; the effective
    /// budget never exceeds this value either. Must be non-zero.
    pub max_inference_ms_per_commit: u32,
    /// Cap on the number of phrases embedded. Must be non-zero.
    pub max_phrases_embedded: u32,
    /// Cap on the number of types. `validate` does not check this field.
    /// NOTE(review): confirm whether 0 is a meaningful value here.
    pub max_types: u32,
    /// Per-commit author rate limit. `validate` does not check this field.
    /// NOTE(review): the meaning of 0 is not visible in this file — confirm.
    pub author_rate_limit_per_commit: u32,
}
impl InferenceBudget {
pub const MAX_INFERENCE_MS_PER_COMMIT: u32 = 250;
pub const FALLBACK_EXTRACT_LATENCY_MS: u32 = 500;
#[must_use]
pub const fn conservative() -> Self {
Self {
extract_latency_budget_ms: Self::MAX_INFERENCE_MS_PER_COMMIT,
max_inference_ms_per_commit: Self::MAX_INFERENCE_MS_PER_COMMIT,
max_phrases_embedded: 10_000,
max_types: 8,
author_rate_limit_per_commit: 200,
}
}
#[must_use]
pub const fn effective_ms(&self) -> u32 {
let extract = self.extract_latency_budget_ms;
let hard = self.max_inference_ms_per_commit;
if extract < hard { extract } else { hard }
}
#[must_use]
pub fn effective_ms_gauge(&self) -> (&'static str, f64) {
(BUDGET_GAUGE_NAME, f64::from(self.effective_ms()))
}
pub fn validate(&self) -> Result<(), &'static str> {
if self.max_inference_ms_per_commit == 0 {
return Err("max_inference_ms_per_commit must be > 0");
}
if self.extract_latency_budget_ms == 0 {
return Err("extract_latency_budget_ms must be > 0");
}
if self.max_phrases_embedded == 0 {
return Err("max_phrases_embedded must be > 0");
}
Ok(())
}
}
impl Default for InferenceBudget {
fn default() -> Self {
Self::conservative()
}
}
/// Method used to infer a relation.
///
/// `TypedRelation::new` takes one, and `provenance_label` renders it into an
/// `inferred:`-prefixed provenance tag. The enum is `#[non_exhaustive]`, so
/// code outside this crate must handle variants added later. Variant names are
/// serialized in snake_case.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InferenceMethod {
    /// Provenance label: `inferred:pattern_embedding`.
    PatternEmbedding,
    /// Provenance label: `inferred:cooccurrence_pmi`.
    CooccurrencePmi,
    /// Caller-named method. Its label is `inferred:` followed by the name
    /// verbatim; the name is not checked or escaped.
    Custom(String),
}
impl InferenceMethod {
    /// Builds the provenance tag for this method: `inferred:<method name>`.
    ///
    /// Built-in variants use their snake_case name. `Custom` uses its stored
    /// string as-is, so an empty name yields `inferred:`.
    #[must_use]
    pub fn provenance_label(&self) -> String {
        let method_name: &str = match self {
            Self::PatternEmbedding => "pattern_embedding",
            Self::CooccurrencePmi => "cooccurrence_pmi",
            Self::Custom(custom) => custom.as_str(),
        };
        format!("inferred:{method_name}")
    }
}
/// A typed edge `src --predicate--> dst` with a confidence score and
/// provenance information.
///
/// Only `PartialEq` is derived, not `Eq`, because `confidence` is a float.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TypedRelation {
    /// Source endpoint of the relation.
    pub src: String,
    /// Destination endpoint of the relation.
    pub dst: String,
    /// Relation predicate, e.g. `"knows"`.
    pub predicate: String,
    /// Confidence score. No range check is applied in this file.
    /// NOTE(review): presumably in `[0, 1]` — confirm with producers.
    pub confidence: f32,
    /// Extraction source. `TypedRelation::new` always sets
    /// `ExtractionSource::Statistical`.
    pub source: ExtractionSource,
    /// Provenance tag, e.g. `inferred:pattern_embedding`.
    /// `TypedRelation::new` fills it from `InferenceMethod::provenance_label`.
    pub source_label: String,
}
impl TypedRelation {
    /// Creates a relation tagged as statistically extracted.
    ///
    /// `source` is always `ExtractionSource::Statistical`, and `source_label`
    /// is the provenance label of `method`. `confidence` is stored as given,
    /// without any range check.
    ///
    /// # Arguments
    ///
    /// * `src`, `dst` - endpoints of the relation.
    /// * `predicate` - relation predicate name.
    /// * `confidence` - score stored verbatim.
    /// * `method` - inference method used to build the provenance label.
    #[must_use]
    pub fn new(
        src: impl Into<String>,
        dst: impl Into<String>,
        predicate: impl Into<String>,
        confidence: f32,
        method: &InferenceMethod,
    ) -> Self {
        let (src, dst, predicate) = (src.into(), dst.into(), predicate.into());
        let source_label = method.provenance_label();
        Self {
            src,
            dst,
            predicate,
            confidence,
            source: ExtractionSource::Statistical,
            source_label,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The shipped conservative preset must itself be a valid budget.
    #[test]
    fn conservative_budget_passes_validation() {
        assert_eq!(InferenceBudget::conservative().validate(), Ok(()));
    }

    /// The effective budget tracks the extraction budget while it is below the
    /// hard wall, and is clamped to the wall once it exceeds it.
    #[test]
    fn effective_ms_is_minimum_of_extract_and_hard_wall() {
        let base = InferenceBudget::conservative();

        let tight = InferenceBudget {
            extract_latency_budget_ms: 100,
            ..base
        };
        assert_eq!(tight.effective_ms(), 100);

        let loose = InferenceBudget {
            extract_latency_budget_ms: 1_000,
            ..base
        };
        assert_eq!(
            loose.effective_ms(),
            InferenceBudget::MAX_INFERENCE_MS_PER_COMMIT
        );
    }

    /// Guards the spec-pinned hard wall against accidental edits.
    #[test]
    fn hard_wall_matches_spec_pinned_value() {
        const PINNED_MS: u32 = 250;
        assert_eq!(InferenceBudget::MAX_INFERENCE_MS_PER_COMMIT, PINNED_MS);
    }

    /// The gauge name is stable, and its value mirrors `effective_ms`.
    #[test]
    fn gauge_emits_stable_name_and_effective_value() {
        let budget = InferenceBudget::conservative();
        let (gauge_name, gauge_value) = budget.effective_ms_gauge();
        assert_eq!(gauge_name, "mnem_inference_budget_effective_ms");
        let expected = f64::from(budget.effective_ms());
        assert!((gauge_value - expected).abs() < f64::EPSILON);
    }

    /// Zeroing any one mandatory cap on an otherwise valid budget must make
    /// validation fail.
    #[test]
    fn validate_rejects_zero_caps() {
        let zero_one_cap: [fn(&mut InferenceBudget); 3] = [
            |b: &mut InferenceBudget| b.max_inference_ms_per_commit = 0,
            |b: &mut InferenceBudget| b.extract_latency_budget_ms = 0,
            |b: &mut InferenceBudget| b.max_phrases_embedded = 0,
        ];
        for zero_out in zero_one_cap {
            let mut budget = InferenceBudget::conservative();
            zero_out(&mut budget);
            assert!(budget.validate().is_err());
        }
    }

    /// Every method variant renders its `inferred:`-prefixed label.
    #[test]
    fn inference_method_renders_provenance_label() {
        let cases = [
            (InferenceMethod::PatternEmbedding, "inferred:pattern_embedding"),
            (InferenceMethod::CooccurrencePmi, "inferred:cooccurrence_pmi"),
            (
                InferenceMethod::Custom("my_method".into()),
                "inferred:my_method",
            ),
        ];
        for (method, expected) in cases {
            assert_eq!(method.provenance_label(), expected);
        }
    }

    /// `TypedRelation::new` derives both provenance fields from the method.
    #[test]
    fn typed_relation_auto_tags_provenance_label() {
        let method = InferenceMethod::PatternEmbedding;
        let relation = TypedRelation::new("alice", "bob", "knows", 0.9, &method);
        assert_eq!(relation.source_label, "inferred:pattern_embedding");
        assert_eq!(relation.source, ExtractionSource::Statistical);
    }
}
// Property-based checks for `InferenceBudget`, driven by the `proptest` crate.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    proptest! {
        // For any budget whose caps are all >= 1:
        // - validation succeeds;
        // - the effective budget is bounded by both latency limits and equals
        //   their minimum;
        // - the gauge value mirrors the effective budget.
        #[test]
        fn budget_respected(
            extract_ms in 1u32..10_000,
            hard_ms in 1u32..10_000,
            max_phrases in 1u32..100_000,
            max_types in 1u32..64,
            author_cap in 1u32..10_000,
        ) {
            let b = InferenceBudget {
                extract_latency_budget_ms: extract_ms,
                max_inference_ms_per_commit: hard_ms,
                max_phrases_embedded: max_phrases,
                max_types,
                author_rate_limit_per_commit: author_cap,
            };
            prop_assert!(b.validate().is_ok());
            let eff = b.effective_ms();
            prop_assert!(eff <= extract_ms);
            prop_assert!(eff <= hard_ms);
            prop_assert!(eff == extract_ms.min(hard_ms));
            let (_, val) = b.effective_ms_gauge();
            prop_assert!((val - f64::from(eff)).abs() < f64::EPSILON);
        }
        // The conservative preset's effective budget equals the hard wall.
        // `_n` is an unused input: the assertion does not depend on any
        // generated value, so every generated case checks the same thing.
        #[test]
        fn conservative_default_matches_hard_wall(_n in 0u32..8) {
            let b = InferenceBudget::conservative();
            prop_assert_eq!(
                b.effective_ms(),
                InferenceBudget::MAX_INFERENCE_MS_PER_COMMIT,
            );
        }
    }
}