use serde_json::Value;
use zeph_llm::provider::ToolUseRequest;
/// One stage of the schedule returned by [`ToolCallDag::tiers`].
///
/// Outside the cycle fallback, every detected dependency of a call in this
/// tier lies in an earlier tier.
#[derive(Debug, Clone)]
pub(super) struct Tier {
    /// Positions of this tier's calls in the slice that was passed to
    /// [`ToolCallDag::build`].
    pub indices: Vec<usize>,
}
/// Dependency graph over one batch of tool calls.
///
/// Edges are inferred from exact-match references to other calls' ids that
/// appear as string values inside each call's JSON input.
pub(super) struct ToolCallDag {
    /// `deps[i]` holds the indices of the calls whose id appears in call `i`'s
    /// input.
    deps: Vec<Vec<usize>>,
    /// `string_values[i]` holds every string leaf extracted from call `i`'s
    /// input.
    string_values: Vec<Vec<String>>,
    /// Number of tool calls the graph was built from.
    len: usize,
}
impl ToolCallDag {
    /// Builds the dependency graph for one batch of tool calls.
    ///
    /// Call `i` depends on call `j` (with `i != j`) when any string leaf of
    /// call `i`'s JSON input is exactly equal to call `j`'s id. Matching is
    /// whole-string equality, never substring search: an id of `"1"` does not
    /// match `"10"`, but it does match a literal `"1"` argument. Numbers,
    /// booleans and nulls in the input are never compared. Self-references and
    /// empty ids are ignored.
    #[must_use]
    pub fn build(tool_calls: &[ToolUseRequest]) -> Self {
        let len = tool_calls.len();
        let string_values: Vec<Vec<String>> = tool_calls
            .iter()
            .map(|tc| extract_string_values(&tc.input))
            .collect();

        let mut deps: Vec<Vec<usize>> = vec![Vec::new(); len];
        for (i, values) in string_values.iter().enumerate() {
            for (j, tc_j) in tool_calls.iter().enumerate() {
                // A call never depends on itself. An empty id is skipped: it
                // cannot be a meaningful reference, and matching it would make
                // every call carrying an empty-string argument wait on it (two
                // empty-id calls that both carry one would even form a cycle and
                // force the sequential fallback for the whole batch).
                if i == j || tc_j.id.is_empty() {
                    continue;
                }
                if values.iter().any(|v| v == &tc_j.id) {
                    deps[i].push(j);
                }
            }
        }

        Self {
            deps,
            string_values,
            len,
        }
    }

    /// Returns the string leaves extracted from call `idx`'s input, in the
    /// order produced by [`extract_string_values`].
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not less than the number of calls the graph was
    /// built from.
    #[must_use]
    pub(super) fn string_values_for(&self, idx: usize) -> &[String] {
        &self.string_values[idx]
    }

    /// Returns `true` when no call references any other call, i.e. the batch
    /// has no detected dependencies at all.
    #[must_use]
    pub(super) fn is_trivial(&self) -> bool {
        self.deps.iter().all(Vec::is_empty)
    }

    /// Groups the calls into tiers using a level-by-level (Kahn) topological
    /// sort.
    ///
    /// Tier `k` holds the calls whose dependencies all lie in tiers `0..k`, so
    /// running the tiers in order satisfies every detected dependency. The
    /// first tier lists its indices in ascending order; later tiers list them
    /// in the order they became ready.
    ///
    /// Returns an empty vector for an empty batch. If the reference graph
    /// contains a cycle, a warning is logged and a single tier containing every
    /// index in ascending order is returned instead.
    ///
    /// NOTE(review): the warning describes that fallback as "fully sequential",
    /// yet it has the same shape as the result for a fully independent batch —
    /// confirm the caller distinguishes the two cases.
    #[must_use]
    pub fn tiers(&self) -> Vec<Tier> {
        if self.len == 0 {
            return Vec::new();
        }

        // in_degree[i]: dependencies of call i not yet placed in a tier.
        // dependents[j]: calls that depend on call j (reversed edges).
        let mut in_degree: Vec<usize> = vec![0; self.len];
        let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); self.len];
        for (i, node_deps) in self.deps.iter().enumerate() {
            in_degree[i] = node_deps.len();
            for &j in node_deps {
                dependents[j].push(i);
            }
        }

        let mut ready: Vec<usize> = (0..self.len).filter(|&i| in_degree[i] == 0).collect();
        let mut tiers: Vec<Tier> = Vec::new();
        let mut processed = 0_usize;
        while !ready.is_empty() {
            // Everything ready right now forms one tier; the calls it unblocks
            // become the next tier.
            let current = std::mem::take(&mut ready);
            processed += current.len();
            let mut next: Vec<usize> = Vec::new();
            for &node in &current {
                for &dependent in &dependents[node] {
                    in_degree[dependent] -= 1;
                    if in_degree[dependent] == 0 {
                        next.push(dependent);
                    }
                }
            }
            tiers.push(Tier { indices: current });
            ready = next;
        }

        // Calls on, or downstream of, a cycle never reach in-degree zero.
        if processed < self.len {
            tracing::warn!(
                total = self.len,
                processed,
                "tool_call_dag: cycle detected in tool_use_id references — \
                 falling back to fully sequential execution of all tool calls"
            );
            return vec![Tier {
                indices: (0..self.len).collect(),
            }];
        }
        tiers
    }
}
/// Returns every JSON string leaf found anywhere inside `value`.
///
/// The walk is depth-first: array elements are visited in index order and
/// object members in the map's own iteration order. Numbers, booleans and
/// nulls contribute nothing.
pub(super) fn extract_string_values(value: &Value) -> Vec<String> {
    let mut found: Vec<String> = Vec::new();
    collect_strings(value, &mut found);
    found
}

/// Recursively appends the string leaves under `value` to `out`.
fn collect_strings(value: &Value, out: &mut Vec<String>) {
    if let Value::String(text) = value {
        out.push(text.clone());
    } else if let Value::Array(items) = value {
        items.iter().for_each(|item| collect_strings(item, out));
    } else if let Value::Object(fields) = value {
        fields.values().for_each(|field| collect_strings(field, out));
    }
    // Any other variant (number, bool, null) is a non-string leaf: ignored.
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a tool call with a fixed tool name and the given id and input.
    fn make_call(id: &str, input: Value) -> ToolUseRequest {
        ToolUseRequest {
            id: id.to_owned(),
            name: "test_tool".to_owned().into(),
            input,
        }
    }

    /// Flattens tiers into their index lists for easy comparison.
    fn tier_indices(tiers: &[Tier]) -> Vec<Vec<usize>> {
        tiers.iter().map(|t| t.indices.clone()).collect()
    }

    #[test]
    fn dag_empty_is_trivial() {
        let dag = ToolCallDag::build(&[]);
        assert!(dag.is_trivial());
        assert!(dag.tiers().is_empty());
    }

    #[test]
    fn dag_single_call_is_trivial() {
        let calls = [make_call("id0", json!({"x": "hello"}))];
        let dag = ToolCallDag::build(&calls);
        assert!(dag.is_trivial());
        let tiers = dag.tiers();
        assert_eq!(tier_indices(&tiers), vec![vec![0usize]]);
    }

    #[test]
    fn dag_no_dependencies_is_trivial() {
        let calls = [
            make_call("toolu_aaa", json!({"file": "foo.txt"})),
            make_call("toolu_bbb", json!({"cmd": "ls"})),
            make_call("toolu_ccc", json!({"url": "https://example.com"})),
        ];
        let dag = ToolCallDag::build(&calls);
        assert!(dag.is_trivial());
        let tiers = dag.tiers();
        assert_eq!(tiers.len(), 1);
        let mut got = tiers[0].indices.clone();
        got.sort_unstable();
        assert_eq!(got, vec![0usize, 1, 2]);
    }

    #[test]
    fn dag_single_dependency_two_tiers() {
        let calls = [
            make_call("id_a", json!({"x": 1})),
            make_call("id_b", json!({"prerequisite": "id_a"})),
        ];
        let dag = ToolCallDag::build(&calls);
        assert!(!dag.is_trivial());
        let tiers = dag.tiers();
        assert_eq!(tiers.len(), 2);
        assert_eq!(tiers[0].indices, vec![0usize]);
        assert_eq!(tiers[1].indices, vec![1usize]);
    }

    #[test]
    fn dag_linear_chain() {
        let calls = [
            make_call("id_a", json!({})),
            make_call("id_b", json!({"dep": "id_a"})),
            make_call("id_c", json!({"dep": "id_b"})),
        ];
        let dag = ToolCallDag::build(&calls);
        let tiers = dag.tiers();
        assert_eq!(tiers.len(), 3);
        assert_eq!(tiers[0].indices, vec![0usize]);
        assert_eq!(tiers[1].indices, vec![1usize]);
        assert_eq!(tiers[2].indices, vec![2usize]);
    }

    #[test]
    fn dag_diamond_dependency() {
        let calls = [
            make_call("id_a", json!({})),
            make_call("id_b", json!({"source": "id_a"})),
            make_call("id_c", json!({"source": "id_a"})),
            make_call("id_d", json!({"left": "id_b", "right": "id_c"})),
        ];
        let dag = ToolCallDag::build(&calls);
        let tiers = dag.tiers();
        assert_eq!(tiers.len(), 3);
        assert_eq!(tiers[0].indices, vec![0usize]);
        let mut tier_bc = tiers[1].indices.clone();
        tier_bc.sort_unstable();
        assert_eq!(tier_bc, vec![1usize, 2]);
        assert_eq!(tiers[2].indices, vec![3usize]);
    }

    #[test]
    fn dag_cycle_falls_back_to_sequential() {
        let calls = [
            make_call("id_a", json!({"dep": "id_b"})),
            make_call("id_b", json!({"dep": "id_a"})),
        ];
        let dag = ToolCallDag::build(&calls);
        let tiers = dag.tiers();
        assert_eq!(tiers.len(), 1);
        assert_eq!(tiers[0].indices, vec![0usize, 1]);
    }

    #[test]
    fn dag_partial_cycle_with_independent() {
        let calls = [
            make_call("id_a", json!({"dep": "id_b"})),
            make_call("id_b", json!({"dep": "id_a"})),
            make_call("id_c", json!({"x": "hello"})),
        ];
        let dag = ToolCallDag::build(&calls);
        let tiers = dag.tiers();
        assert_eq!(tiers.len(), 1);
        let mut got = tiers[0].indices.clone();
        got.sort_unstable();
        assert_eq!(got, vec![0usize, 1, 2]);
    }

    #[test]
    fn dag_short_id_no_false_positive() {
        let calls = [
            make_call("1", json!({"cmd": "ls"})),
            make_call("2", json!({"retries": "1", "count": "10"})),
        ];
        let dag = ToolCallDag::build(&calls);
        assert!(
            !dag.is_trivial(),
            "exact match on id='1' with value '1' creates a dependency (known limitation)"
        );

        let calls2 = [
            make_call("1", json!({"cmd": "ls"})),
            make_call("2", json!({"count": "10", "flag": "true"})),
        ];
        let dag2 = ToolCallDag::build(&calls2);
        assert!(dag2.is_trivial(), "id='1' must not match substring in '10'");

        let calls3 = [
            make_call("call_abc", json!({"cmd": "ls"})),
            make_call("call_xyz", json!({"path": "call_abcdef/file.txt"})),
        ];
        let dag3 = ToolCallDag::build(&calls3);
        assert!(dag3.is_trivial(), "id must not match as substring of path");
    }

    #[test]
    fn dag_non_string_values_never_match() {
        // Only string leaves are compared against ids, so a numeric `1` does
        // not reference the call whose id is "1".
        let calls = [
            make_call("1", json!({"cmd": "ls"})),
            make_call("2", json!({"retries": 1, "enabled": true, "extra": null})),
        ];
        let dag = ToolCallDag::build(&calls);
        assert!(dag.is_trivial());
    }

    #[test]
    fn dag_nested_json_strings_detected() {
        let calls = [
            make_call("id_src", json!({})),
            make_call("id_dst", json!({"nested": {"deep": ["id_src"]}})),
        ];
        let dag = ToolCallDag::build(&calls);
        assert!(!dag.is_trivial());
        let tiers = dag.tiers();
        assert_eq!(tiers.len(), 2);
        assert_eq!(tiers[0].indices, vec![0usize]);
        assert_eq!(tiers[1].indices, vec![1usize]);
    }

    #[test]
    fn dag_self_reference_ignored() {
        let calls = [make_call("id_self", json!({"ref": "id_self"}))];
        let dag = ToolCallDag::build(&calls);
        assert!(dag.is_trivial());
    }

    #[test]
    fn extract_string_values_collects_nested_strings_in_order() {
        // A single top-level key keeps the expectation independent of the
        // object map's iteration order.
        let value = json!({"args": ["first", 2, {"inner": "second"}, [null, "third"]]});
        assert_eq!(
            extract_string_values(&value),
            vec!["first".to_owned(), "second".to_owned(), "third".to_owned()]
        );
        assert_eq!(extract_string_values(&json!("solo")), vec!["solo".to_owned()]);
        assert!(extract_string_values(&json!(42)).is_empty());
    }

    #[test]
    fn string_values_for_returns_extracted_input_strings() {
        let calls = [
            make_call("id_a", json!({})),
            make_call("id_b", json!({"dep": "id_a"})),
        ];
        let dag = ToolCallDag::build(&calls);
        assert!(dag.string_values_for(0).is_empty());
        assert_eq!(dag.string_values_for(1).to_vec(), vec!["id_a".to_owned()]);
    }
}