use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use thiserror::Error;
/// Identifier of a node in a [`BudgetTree`].
///
/// `#[serde(transparent)]` makes it serialize as the bare inner string. The derived `Ord`
/// compares the inner strings; `BudgetTree::serialize` sorts by it to emit nodes in a
/// stable order.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(transparent)]
pub struct BudgetNodeId(pub String);
impl BudgetNodeId {
    /// Builds an id from anything that converts into an owned `String`.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        let raw: String = id.into();
        Self(raw)
    }

    /// Borrows the id's text without allocating.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// Copies a borrowed string into a new [`BudgetNodeId`].
impl From<&str> for BudgetNodeId {
    fn from(s: &str) -> Self {
        Self(s.to_owned())
    }
}
impl From<String> for BudgetNodeId {
fn from(s: String) -> Self {
Self(s)
}
}
/// Displays the raw id text. Formatter options such as width and fill are ignored,
/// because the text is written straight to the formatter.
impl std::fmt::Display for BudgetNodeId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.0.as_str())
    }
}
/// Accounting window associated with a budget node.
///
/// Serialized as an internally tagged object keyed by `"kind"` in snake_case, e.g.
/// `{"kind":"daily"}` or `{"kind":"rolling","seconds":3600}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum BudgetWindow {
    /// One day of 86 400 seconds.
    Daily,
    /// A fixed 30-day window, not a calendar month.
    Monthly,
    /// A window of arbitrary length.
    ///
    /// NOTE(review): despite the name, `bucket_start` aligns this to fixed multiples of
    /// `seconds` rather than sliding it with each timestamp.
    Rolling {
        /// Window length in seconds; `0` maps every timestamp to bucket `0`.
        seconds: u64,
    },
}
impl BudgetWindow {
    /// Length of one window in seconds.
    ///
    /// `Monthly` is a fixed 30-day approximation rather than a calendar month.
    #[must_use]
    pub fn duration_seconds(&self) -> u64 {
        const SECONDS_PER_DAY: u64 = 86_400;
        match *self {
            Self::Rolling { seconds } => seconds,
            Self::Daily => SECONDS_PER_DAY,
            Self::Monthly => 30 * SECONDS_PER_DAY,
        }
    }

    /// Start of the window containing `ts`.
    ///
    /// The result is aligned to a multiple of the window length counted from `0`. A
    /// zero-length window has no buckets, so every timestamp maps to `0`.
    #[must_use]
    pub fn bucket_start(&self, ts: u64) -> u64 {
        // `checked_rem` yields `None` exactly when the window length is zero.
        ts.checked_rem(self.duration_seconds())
            .map_or(0, |offset| ts - offset)
    }
}
/// Per-dimension caps for one budget node.
///
/// A `None` field leaves that dimension uncapped and is omitted from serialized output.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct BudgetLimits {
    /// Cap on spend units, expressed in `currency`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_spend_units: Option<u64>,
    /// Currency code that `max_spend_units` is expressed in.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub currency: Option<String>,
    /// Cap on tokens.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Cap on requests.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_requests: Option<u64>,
    /// Cap on warehouse bytes.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_warehouse_bytes: Option<u64>,
}
/// One node of a [`BudgetTree`]: its caps, its window and an optional parent link.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BudgetNode {
    /// Id of this node; unique within a tree.
    pub id: BudgetNodeId,
    /// Parent node; `None` makes this node a root. Omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parent: Option<BudgetNodeId>,
    /// Caps checked for this node. Everything is uncapped when this is missing from
    /// serialized input.
    #[serde(default)]
    pub limits: BudgetLimits,
    /// Accounting window associated with the node.
    ///
    /// NOTE(review): not consulted by `BudgetTree::evaluate`.
    pub window: BudgetWindow,
    /// A disabled node makes every evaluation of itself or a descendant deny.
    /// Defaults to `true` when missing from serialized input.
    #[serde(default = "default_true")]
    pub enabled: bool,
}
/// Serde default for [`BudgetNode::enabled`]: nodes are enabled unless the input says
/// otherwise.
fn default_true() -> bool {
    true
}
impl BudgetNode {
    /// Creates an enabled root node with no caps for the given window.
    #[must_use]
    pub fn new(id: impl Into<BudgetNodeId>, window: BudgetWindow) -> Self {
        let id: BudgetNodeId = id.into();
        Self {
            enabled: true,
            parent: None,
            limits: BudgetLimits::default(),
            id,
            window,
        }
    }

    /// Returns this node re-linked under `parent`.
    #[must_use]
    pub fn with_parent(self, parent: impl Into<BudgetNodeId>) -> Self {
        Self {
            parent: Some(parent.into()),
            ..self
        }
    }

    /// Returns this node with its caps replaced by `limits`.
    #[must_use]
    pub fn with_limits(self, limits: BudgetLimits) -> Self {
        Self { limits, ..self }
    }

    /// Returns this node marked as disabled.
    #[must_use]
    pub fn disabled(self) -> Self {
        Self {
            enabled: false,
            ..self
        }
    }
}
/// Usage totals across every budget dimension.
///
/// Used both for a draft charge and for the spend already accumulated by a node.
/// Missing fields deserialize as zero / `None`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct AggregateSpend {
    /// Spend amount, in units of `currency`.
    #[serde(default)]
    pub spend_units: u64,
    /// Currency code of `spend_units`, if known. Omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub currency: Option<String>,
    /// Token count.
    #[serde(default)]
    pub tokens: u64,
    /// Request count.
    #[serde(default)]
    pub requests: u64,
    /// Warehouse byte count.
    #[serde(default)]
    pub warehouse_bytes: u64,
}
impl AggregateSpend {
    /// Spend of `units` in `currency`, with every other dimension at zero.
    #[must_use]
    pub fn with_spend(units: u64, currency: impl Into<String>) -> Self {
        Self {
            spend_units: units,
            currency: Some(currency.into()),
            tokens: 0,
            requests: 0,
            warehouse_bytes: 0,
        }
    }

    /// Usage of `tokens` tokens only.
    #[must_use]
    pub fn with_tokens(tokens: u64) -> Self {
        Self {
            spend_units: 0,
            currency: None,
            tokens,
            requests: 0,
            warehouse_bytes: 0,
        }
    }

    /// Usage of `requests` requests only.
    #[must_use]
    pub fn with_requests(requests: u64) -> Self {
        Self {
            spend_units: 0,
            currency: None,
            tokens: 0,
            requests,
            warehouse_bytes: 0,
        }
    }

    /// Usage of `bytes` warehouse bytes only.
    #[must_use]
    pub fn with_warehouse_bytes(bytes: u64) -> Self {
        Self {
            spend_units: 0,
            currency: None,
            tokens: 0,
            requests: 0,
            warehouse_bytes: bytes,
        }
    }

    /// Field-wise sum that clamps each dimension at `u64::MAX` instead of overflowing.
    ///
    /// The result takes `self`'s currency when set, otherwise `other`'s. Spend units are
    /// added without any currency conversion.
    fn saturating_add(&self, other: &Self) -> Self {
        let currency = self.currency.as_ref().or(other.currency.as_ref()).cloned();
        Self {
            currency,
            spend_units: self.spend_units.saturating_add(other.spend_units),
            tokens: self.tokens.saturating_add(other.tokens),
            requests: self.requests.saturating_add(other.requests),
            warehouse_bytes: self.warehouse_bytes.saturating_add(other.warehouse_bytes),
        }
    }
}
/// Spend accumulated by one node, tagged with the window it belongs to.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct PerWindowSpend {
    /// Start timestamp of the window. Presumably a `BudgetWindow::bucket_start` value;
    /// TODO confirm with callers.
    ///
    /// NOTE(review): not read by `BudgetTree::evaluate`.
    #[serde(default)]
    pub window_start: u64,
    /// Totals accumulated so far.
    #[serde(default)]
    pub current: AggregateSpend,
}
/// Current spend per node, keyed by node id.
///
/// `BudgetTree::evaluate` treats a node missing from the map as having spent nothing.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpendSnapshot {
    #[serde(default)]
    pub per_node: HashMap<BudgetNodeId, PerWindowSpend>,
}
impl SpendSnapshot {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn set(&mut self, id: BudgetNodeId, spend: PerWindowSpend) {
self.per_node.insert(id, spend);
}
#[must_use]
pub fn get(&self, id: &BudgetNodeId) -> Option<&PerWindowSpend> {
self.per_node.get(id)
}
}
/// Why a budget evaluation was denied.
///
/// Serialized as an object tagged by `"reason"` in snake_case, e.g.
/// `{"reason":"node_disabled","node":"org/acme"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "reason", rename_all = "snake_case")]
pub enum BudgetDenyReason {
    /// A node on the path from the evaluated node to the root is disabled.
    NodeDisabled {
        node: BudgetNodeId,
    },
    /// A cap on `node` would be exceeded. `cap` and `would_reach` are pre-formatted
    /// display strings.
    DimensionExceeded {
        node: BudgetNodeId,
        /// Name of the capped dimension, as set by `BudgetTree::evaluate`.
        dimension: String,
        cap: String,
        would_reach: String,
    },
    /// NOTE(review): not produced by any code visible in this file.
    WindowExpired {
        node: BudgetNodeId,
    },
    /// The evaluated node id is not present in the tree.
    UnknownNode {
        node: BudgetNodeId,
    },
}
/// Outcome of `BudgetTree::evaluate`.
///
/// Serialized as an object tagged by `"decision"` in snake_case.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "decision", rename_all = "snake_case")]
pub enum BudgetDecision {
    /// The draft fits within every applicable cap.
    Allow,
    /// The draft is rejected for `reason`.
    Deny {
        reason: BudgetDenyReason,
    },
}
/// Errors from building, validating or deserializing a [`BudgetTree`].
#[derive(Debug, Error, PartialEq, Eq)]
pub enum BudgetError {
    /// Following parent links from `node` loops back on itself.
    #[error("cycle detected while inserting node `{node}` (conflicts with ancestor)")]
    Cycle {
        node: BudgetNodeId,
    },
    /// `node` names a `parent` that is not in the tree.
    #[error("parent `{parent}` of node `{node}` is not present in the tree")]
    MissingParent {
        node: BudgetNodeId,
        parent: BudgetNodeId,
    },
    /// A node with this id already exists.
    #[error("duplicate node id `{node}`")]
    Duplicate {
        node: BudgetNodeId,
    },
    /// The JSON input did not match the expected tree shape; carries the decoder message.
    #[error("invalid serialized tree: {0}")]
    InvalidSerialization(String),
}
/// Hierarchy of budget nodes.
///
/// Evaluating a node also checks every ancestor, so a node's caps constrain everything
/// below it.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BudgetTree {
    /// All nodes, keyed by their own id.
    nodes: HashMap<BudgetNodeId, BudgetNode>,
    /// Parent id → ids of its direct children, in insertion order.
    children: HashMap<BudgetNodeId, Vec<BudgetNodeId>>,
}
impl BudgetTree {
    /// Creates an empty tree.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of nodes in the tree.
    #[must_use]
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` when the tree holds no nodes.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Looks up a node by id.
    #[must_use]
    pub fn get(&self, id: &BudgetNodeId) -> Option<&BudgetNode> {
        self.nodes.get(id)
    }

    /// Adds `node` to the tree and links it under its parent, if it names one.
    ///
    /// # Errors
    ///
    /// - [`BudgetError::Duplicate`] if a node with the same id already exists.
    /// - [`BudgetError::MissingParent`] if the named parent is not in the tree.
    /// - [`BudgetError::Cycle`] if walking up from the parent reaches the new node's id or
    ///   revisits a node.
    pub fn insert(&mut self, node: BudgetNode) -> Result<(), BudgetError> {
        if self.nodes.contains_key(&node.id) {
            return Err(BudgetError::Duplicate {
                node: node.id.clone(),
            });
        }
        if let Some(parent) = &node.parent {
            if !self.nodes.contains_key(parent) {
                return Err(BudgetError::MissingParent {
                    node: node.id.clone(),
                    parent: parent.clone(),
                });
            }
            // The new id is absent and its parent exists, so inserting cannot close a loop.
            // This walk is a defensive check against parent links that already loop.
            let mut visited: HashSet<BudgetNodeId> = HashSet::new();
            let mut cursor = Some(parent.clone());
            while let Some(current) = cursor {
                if current == node.id || !visited.insert(current.clone()) {
                    return Err(BudgetError::Cycle {
                        node: node.id.clone(),
                    });
                }
                cursor = self.parent_of(&current);
            }
            self.children
                .entry(parent.clone())
                .or_default()
                .push(node.id.clone());
        }
        self.nodes.insert(node.id.clone(), node);
        Ok(())
    }

    /// Checks that every parent link resolves and that no parent chain loops.
    ///
    /// `nodes` is a `HashMap`, so when several problems exist, which one is reported is not
    /// deterministic.
    ///
    /// # Errors
    ///
    /// - [`BudgetError::MissingParent`] for a node whose parent is absent.
    /// - [`BudgetError::Cycle`] for a node whose parent chain revisits a node.
    pub fn validate(&self) -> Result<(), BudgetError> {
        for node in self.nodes.values() {
            if let Some(parent) = &node.parent {
                if !self.nodes.contains_key(parent) {
                    return Err(BudgetError::MissingParent {
                        node: node.id.clone(),
                        parent: parent.clone(),
                    });
                }
            }
        }
        for id in self.nodes.keys() {
            let mut visited: HashSet<BudgetNodeId> = HashSet::new();
            let mut cursor = Some(id.clone());
            while let Some(current) = cursor {
                if !visited.insert(current.clone()) {
                    return Err(BudgetError::Cycle { node: id.clone() });
                }
                cursor = self.parent_of(&current);
            }
        }
        Ok(())
    }

    /// Returns `id` followed by its parent, grandparent, … up to the root.
    ///
    /// Returns an empty list when `id` is not in the tree. If the parent chain revisits a
    /// node, the walk stops there instead of looping.
    #[must_use]
    pub fn ancestors(&self, id: &BudgetNodeId) -> Vec<BudgetNodeId> {
        let mut out = Vec::new();
        let mut visited: HashSet<BudgetNodeId> = HashSet::new();
        let mut cursor = self.nodes.contains_key(id).then(|| id.clone());
        while let Some(current) = cursor {
            if !visited.insert(current.clone()) {
                break;
            }
            cursor = self.parent_of(&current);
            out.push(current);
        }
        out
    }

    /// Returns every node below `id` in breadth-first order, excluding `id` itself.
    ///
    /// Returns an empty list when `id` is not in the tree.
    #[must_use]
    pub fn descendants(&self, id: &BudgetNodeId) -> Vec<BudgetNodeId> {
        let mut out = Vec::new();
        if !self.nodes.contains_key(id) {
            return out;
        }
        // A deque keeps each pop O(1); `Vec::remove(0)` made the traversal quadratic.
        let mut queue: std::collections::VecDeque<BudgetNodeId> =
            self.children.get(id).into_iter().flatten().cloned().collect();
        let mut visited: HashSet<BudgetNodeId> = HashSet::new();
        while let Some(current) = queue.pop_front() {
            if !visited.insert(current.clone()) {
                continue;
            }
            if let Some(next) = self.children.get(&current) {
                queue.extend(next.iter().cloned());
            }
            out.push(current);
        }
        out
    }

    /// Decides whether `draft` may be charged to node `id`.
    ///
    /// Checks the node and every ancestor up to the root. For each one, it adds `draft` to
    /// that node's spend in `current`, where a missing entry counts as zero.
    ///
    /// - Returns `Deny` with `UnknownNode` when `id` is not in the tree.
    /// - If several checks fail, the reason reported is from the outermost (closest to the
    ///   root) failing node.
    /// - Within that node, the reason is the last failing check, in the order: disabled,
    ///   spend, tokens, requests, warehouse bytes.
    ///
    /// The spend cap applies unless the cap's currency and the draft's currency are both
    /// set and differ. Amounts are never converted between currencies, so an explicit
    /// mismatch is skipped.
    ///
    /// NOTE(review): `PerWindowSpend::window_start` is not consulted. Callers presumably
    /// supply spend for the active window only.
    #[must_use]
    pub fn evaluate(
        &self,
        id: &BudgetNodeId,
        draft: AggregateSpend,
        current: &SpendSnapshot,
    ) -> BudgetDecision {
        if !self.nodes.contains_key(id) {
            return BudgetDecision::Deny {
                reason: BudgetDenyReason::UnknownNode { node: id.clone() },
            };
        }
        // `ancestors` runs from `id` up to the root, so keeping the last violation reports
        // the outermost offending node.
        let offender = self
            .ancestors(id)
            .iter()
            .filter_map(|node_id| self.nodes.get(node_id))
            .filter_map(|node| Self::node_violation(node, &draft, current))
            .last();
        match offender {
            None => BudgetDecision::Allow,
            Some(reason) => BudgetDecision::Deny { reason },
        }
    }

    /// Encodes the tree as `{"version": 1, "nodes": [...]}` with nodes sorted by id.
    #[must_use]
    pub fn serialize(&self) -> serde_json::Value {
        let mut nodes: Vec<&BudgetNode> = self.nodes.values().collect();
        nodes.sort_unstable_by(|a, b| a.id.cmp(&b.id));
        serde_json::json!({
            "version": 1,
            "nodes": nodes,
        })
    }

    /// Rebuilds a tree from the output of [`BudgetTree::serialize`].
    ///
    /// Nodes may appear in any order. A missing `nodes` field yields an empty tree. Other
    /// top-level fields, including `version`, are not checked.
    ///
    /// # Errors
    ///
    /// - [`BudgetError::InvalidSerialization`] if the JSON does not match the expected
    ///   shape.
    /// - [`BudgetError::Duplicate`] if an id appears more than once.
    /// - [`BudgetError::MissingParent`] if a node's parent does not appear in the input.
    /// - [`BudgetError::Cycle`] if parent links in the input form a loop.
    pub fn deserialize(v: serde_json::Value) -> Result<Self, BudgetError> {
        #[derive(Deserialize)]
        struct Encoded {
            #[serde(default)]
            nodes: Vec<BudgetNode>,
        }
        let enc: Encoded = serde_json::from_value(v)
            .map_err(|e| BudgetError::InvalidSerialization(e.to_string()))?;
        let mut tree = Self::new();
        let mut pending = enc.nodes;
        // Each pass inserts every node whose parent is already present. The rest are
        // retried until a pass makes no progress.
        while !pending.is_empty() {
            let before = pending.len();
            let mut deferred = Vec::new();
            for node in pending {
                let parent_ready = match &node.parent {
                    None => true,
                    Some(p) => tree.nodes.contains_key(p),
                };
                if parent_ready {
                    tree.insert(node)?;
                } else {
                    deferred.push(node);
                }
            }
            if deferred.len() == before {
                return Err(Self::unresolved_parent_error(&deferred));
            }
            pending = deferred;
        }
        tree.validate()?;
        Ok(tree)
    }

    /// Returns the parent id of `id`, or `None` for roots and unknown ids.
    fn parent_of(&self, id: &BudgetNodeId) -> Option<BudgetNodeId> {
        self.nodes.get(id).and_then(|n| n.parent.clone())
    }

    /// Returns why `node` alone would reject `draft` on top of its spend in `current`.
    fn node_violation(
        node: &BudgetNode,
        draft: &AggregateSpend,
        current: &SpendSnapshot,
    ) -> Option<BudgetDenyReason> {
        if !node.enabled {
            return Some(BudgetDenyReason::NodeDisabled {
                node: node.id.clone(),
            });
        }
        let zero = AggregateSpend::default();
        let base = current.per_node.get(&node.id).map_or(&zero, |spend| &spend.current);
        let projected = base.saturating_add(draft);
        let limits = &node.limits;
        let exceeded = |dimension: &str, cap: String, would_reach: String| {
            BudgetDenyReason::DimensionExceeded {
                node: node.id.clone(),
                dimension: dimension.to_string(),
                cap,
                would_reach,
            }
        };

        // Later failing dimensions overwrite earlier ones.
        let mut violation = None;
        if let Some(cap) = limits.max_spend_units {
            let cap_currency = limits.currency.as_deref();
            if Self::currencies_compatible(cap_currency, draft.currency.as_deref())
                && projected.spend_units > cap
            {
                violation = Some(exceeded(
                    "spend",
                    Self::amount_with_currency(cap, cap_currency),
                    Self::amount_with_currency(
                        projected.spend_units,
                        projected.currency.as_deref(),
                    ),
                ));
            }
        }
        let counters = [
            ("tokens", limits.max_tokens, projected.tokens),
            ("requests", limits.max_requests, projected.requests),
            (
                "warehouse_bytes",
                limits.max_warehouse_bytes,
                projected.warehouse_bytes,
            ),
        ];
        for (dimension, cap, reached) in counters {
            if let Some(cap) = cap.filter(|&cap| reached > cap) {
                violation = Some(exceeded(dimension, cap.to_string(), reached.to_string()));
            }
        }
        violation
    }

    /// Whether a spend cap in currency `cap` applies to spend tagged `draft`.
    ///
    /// Returns `true` unless both are set and differ. A missing tag on either side means
    /// there is nothing to conflict with.
    fn currencies_compatible(cap: Option<&str>, draft: Option<&str>) -> bool {
        match (cap, draft) {
            (Some(cap), Some(draft)) => cap == draft,
            _ => true,
        }
    }

    /// Formats `amount` followed by its currency code, e.g. `"120 USD"`, or just the
    /// number when no currency is known.
    fn amount_with_currency(amount: u64, currency: Option<&str>) -> String {
        match currency {
            Some(code) => format!("{amount} {code}"),
            None => amount.to_string(),
        }
    }

    /// Explains why no node in `stuck` could be inserted.
    ///
    /// Follows parent links through the stuck nodes. It returns `Cycle` if the walk loops,
    /// or `MissingParent` for the first node whose parent does not appear in the input.
    fn unresolved_parent_error(stuck: &[BudgetNode]) -> BudgetError {
        let by_id: HashMap<&BudgetNodeId, &BudgetNode> =
            stuck.iter().map(|n| (&n.id, n)).collect();
        let mut seen: HashSet<&BudgetNodeId> = HashSet::new();
        let mut cursor = stuck.first();
        while let Some(node) = cursor {
            if !seen.insert(&node.id) {
                return BudgetError::Cycle {
                    node: node.id.clone(),
                };
            }
            // Parentless nodes are always insertable, so every stuck node names a parent.
            // The empty-id fallback is not expected to be hit.
            let parent = node.parent.clone().unwrap_or_else(|| BudgetNodeId::new(""));
            match by_id.get(&parent) {
                Some(next) => cursor = Some(*next),
                None => {
                    return BudgetError::MissingParent {
                        node: node.id.clone(),
                        parent,
                    }
                }
            }
        }
        BudgetError::InvalidSerialization("no unresolved nodes to report".to_string())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node with `limits`, linked under `parent` when one is given.
    fn leaf(
        id: &str,
        parent: Option<&str>,
        limits: BudgetLimits,
        window: BudgetWindow,
    ) -> BudgetNode {
        let node = BudgetNode::new(id, window).with_limits(limits);
        match parent {
            Some(p) => node.with_parent(p),
            None => node,
        }
    }

    /// Shorthand for an uncapped daily node.
    fn unlimited_daily(id: &str, parent: Option<&str>) -> BudgetNode {
        leaf(id, parent, BudgetLimits::default(), BudgetWindow::Daily)
    }

    #[test]
    fn duplicate_insert_is_rejected() {
        let mut tree = BudgetTree::new();
        tree.insert(unlimited_daily("org/acme", None))
            .expect("insert");
        let second = tree.insert(unlimited_daily("org/acme", None));
        assert!(matches!(second, Err(BudgetError::Duplicate { .. })));
    }

    #[test]
    fn missing_parent_is_rejected() {
        let mut tree = BudgetTree::new();
        let outcome = tree.insert(unlimited_daily("team/x", Some("dept/missing")));
        assert!(matches!(outcome, Err(BudgetError::MissingParent { .. })));
    }

    #[test]
    fn bucket_start_is_window_aligned() {
        let daily = BudgetWindow::Daily;
        for (ts, expected) in [(0, 0), (86_399, 0), (86_400, 86_400)] {
            assert_eq!(daily.bucket_start(ts), expected);
        }
        let hourly = BudgetWindow::Rolling { seconds: 3600 };
        assert_eq!(hourly.bucket_start(7_200 + 5), 7_200);
    }
}