use anyhow::{Result, anyhow};
use crate::{
dsl::{StateRow, Value},
graph::{Graph, NodeId, OwnedGraphId},
traversal::{EdgeCtx, Kernel},
};
/// Per-path traversal state carried by the [`WeightedBudget`] kernel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BudgetState {
    /// Sum of the edge weights accumulated so far (starts at 0).
    pub spent: u64,
}
/// Configuration for a traversal kernel that adds up a per-edge weight and
/// compares the running total against a fixed budget.
#[derive(Debug, Clone)]
pub struct WeightedBudget {
    /// Name of the edge column the weight is read from (as a `u64`).
    pub weight_col: String,
    /// Inclusive upper bound on the accumulated weight.
    pub budget: u64,
    /// Node id that the kernel's `stop` check compares each edge's
    /// destination against.
    pub target: OwnedGraphId,
}
/// [`Kernel`] implementation that tracks the total weight spent along a path
/// and compares it against `self.budget`.
impl Kernel for WeightedBudget {
    type State = BudgetState;

    /// Every traversal starts with nothing spent; the graph and start node
    /// are not consulted.
    fn initial_state(&self, _graph: &Graph, _start: NodeId) -> Self::State {
        BudgetState { spent: 0 }
    }

    /// Returns `Ok(true)` when the edge has a weight in `weight_col` and
    /// adding it to the current `spent` stays within `budget` (inclusive).
    ///
    /// Returns `Ok(false)` when the edge has no weight value, when the sum
    /// exceeds the budget, or when the sum would overflow `u64`.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading the edge column.
    fn visit(&self, cx: &EdgeCtx<'_, Self::State>) -> Result<bool> {
        let Some(weight) = cx.edge_u64(&self.weight_col)? else {
            return Ok(false);
        };
        // `checked_add` rather than `saturating_add`: a saturated sum would
        // clamp to `u64::MAX` and wrongly pass when `budget == u64::MAX`,
        // even though the true total exceeds every representable budget.
        Ok(matches!(
            cx.state().spent.checked_add(weight),
            Some(total) if total <= self.budget
        ))
    }

    /// Produces the state after taking this edge: `spent` plus the edge's
    /// weight, saturating at `u64::MAX`.
    ///
    /// A missing weight is counted as 0 here, whereas `visit` reports
    /// `false` for such an edge.
    /// NOTE(review): presumably this is only called for edges `visit`
    /// accepted, in which case neither the fallback nor the saturation is
    /// reachable — confirm against the traversal driver.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading the edge column.
    fn next_state(&self, cx: &EdgeCtx<'_, Self::State>) -> Result<Self::State> {
        let weight = cx.edge_u64(&self.weight_col)?.unwrap_or(0);
        Ok(BudgetState {
            spent: cx.state().spent.saturating_add(weight),
        })
    }

    /// Returns `Ok(true)` when the edge's destination id equals the
    /// configured `target`; never returns an error itself.
    fn stop(&self, cx: &EdgeCtx<'_, Self::State>) -> Result<bool> {
        Ok(cx.dest_id() == Some(self.target.as_ref()))
    }

    /// Exposes the state as a single `("spent", U64)` column.
    fn state_row(&self, state: &Self::State) -> StateRow {
        vec![("spent".to_string(), Value::U64(state.spent))]
    }
}
impl WeightedBudget {
    /// Builds a [`WeightedBudget`] from a JSON parameter value.
    ///
    /// Keys read from `params`:
    /// - `"weight_col"`: string
    /// - `"budget"`: number representable as `u64`
    /// - `"target"`: either a number representable as `u64` or a string
    ///
    /// # Errors
    ///
    /// Returns an error naming the offending parameter when a key is absent
    /// or has the wrong type. Parameters are checked in the order listed
    /// above, and the first failure is reported.
    fn from_params(params: &serde_json::Value) -> Result<Self> {
        let Some(column) = params
            .get("weight_col")
            .and_then(serde_json::Value::as_str)
        else {
            return Err(anyhow!("weighted_budget: missing string param 'weight_col'"));
        };

        let Some(budget) = params.get("budget").and_then(|raw| raw.as_u64()) else {
            return Err(anyhow!("weighted_budget: missing u64 param 'budget'"));
        };

        let target = match params.get("target") {
            Some(serde_json::Value::String(id)) => OwnedGraphId::Str(id.clone()),
            Some(serde_json::Value::Number(num)) => match num.as_u64() {
                Some(id) => OwnedGraphId::U64(id),
                // Negative or fractional numbers cannot be used as a node id.
                None => {
                    return Err(anyhow!("weighted_budget: 'target' number must be a u64"));
                }
            },
            _ => {
                return Err(anyhow!(
                    "weighted_budget: 'target' must be a u64 or string node id"
                ));
            }
        };

        Ok(Self {
            weight_col: column.to_string(),
            budget,
            target,
        })
    }
}
// Registers this kernel under the name "weighted_budget". The `make` hook
// parses the supplied params into a `WeightedBudget` and hands it to
// `crate::boxed_run`; a parse failure is returned to the caller as-is.
crate::inventory::submit! {
    crate::KernelEntry {
        name: "weighted_budget",
        make: |params| {
            let kernel = WeightedBudget::from_params(params)?;
            Ok(crate::boxed_run(kernel))
        },
    }
}