use tokitai_operator::ir::SemanticGraph;
#[cfg(feature = "graph")]
use god_graph::Graph;
#[cfg(feature = "graph")]
use god_graph::graph::GraphOps;
#[cfg(feature = "graph")]
use crate::BridgeError;
/// Converts `sg` into a directed god-graph, skipping anything that cannot be represented.
///
/// Each `SemanticNode` becomes a node whose payload is its `op_name`. Each input value produced
/// by a *different* `SemanticNode` becomes a `producer -> consumer` edge. The following are
/// dropped silently:
/// - node or edge insertions rejected by god-graph;
/// - inputs with no producing node;
/// - self-edges.
///
/// When several nodes list the same output value id, the last one wins.
///
/// This function never fails. If the underlying best-effort conversion ever reports an error,
/// the error is logged at debug level (when the `tracing` or `observability` feature is enabled)
/// and an empty directed graph is returned. Use `try_semantic_to_god_graph` to receive errors
/// instead.
#[cfg(feature = "graph")]
pub fn semantic_to_god_graph(sg: &SemanticGraph) -> Graph<String, ()> {
    convert_semantic_to_god_graph(sg, ConversionMode::BestEffort).unwrap_or_else(|err| {
        #[cfg(feature = "tracing")]
        tracing::debug!("best-effort SemanticGraph conversion failed unexpectedly: {err}");
        #[cfg(all(feature = "observability", not(feature = "tracing")))]
        log::debug!("best-effort SemanticGraph conversion failed unexpectedly: {err}");
        // Keeps `err` "used" when neither logging feature is compiled in.
        let _ = err;
        Graph::directed()
    })
}
/// Converts `sg` into a directed god-graph, rejecting anything that cannot be represented.
///
/// The node and edge layout matches `semantic_to_god_graph`. The difference is that the first
/// structural problem aborts the conversion instead of being skipped.
///
/// # Errors
///
/// - `BridgeError::Unsupported` if:
///   - a value id is produced by more than one `SemanticNode`;
///   - an input value has no producing `SemanticNode`;
///   - a node consumes its own output.
/// - `BridgeError::Upstream` if god-graph rejects a node or edge insertion.
#[cfg(feature = "graph")]
pub fn try_semantic_to_god_graph(sg: &SemanticGraph) -> Result<Graph<String, ()>, BridgeError> {
    let mode = ConversionMode::Strict;
    convert_semantic_to_god_graph(sg, mode)
}
/// How `convert_semantic_to_god_graph` reacts to structural problems in its input.
#[cfg(feature = "graph")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConversionMode {
    /// Skip whatever cannot be represented and keep converting.
    BestEffort,
    /// Return a `BridgeError` on the first problem encountered.
    Strict,
}

#[cfg(feature = "graph")]
impl ConversionMode {
    /// Returns `true` only for `ConversionMode::Strict`.
    fn is_strict(self) -> bool {
        match self {
            Self::Strict => true,
            Self::BestEffort => false,
        }
    }
}
/// Shared implementation behind `semantic_to_god_graph` and `try_semantic_to_god_graph`.
///
/// Produces a directed `Graph<String, ()>`:
/// - one node per `SemanticNode`, whose payload is its `op_name`;
/// - one edge `producer -> consumer` for each entry of a node's `inputs` whose value id appears
///   in another node's `output_ids`.
///
/// Work happens in three passes:
/// 1. insert the nodes;
/// 2. map each produced value id to its producing node;
/// 3. insert the edges.
///
/// In `ConversionMode::Strict` the first problem aborts the conversion. In
/// `ConversionMode::BestEffort` the offending node or edge is skipped, and a later producer of a
/// duplicated value id overwrites the earlier one.
///
/// # Errors
///
/// Strict mode only:
/// - `BridgeError::Upstream` if god-graph rejects a node or edge insertion, or a node index is
///   missing for an edge endpoint.
/// - `BridgeError::Unsupported` if:
///   - a value id has more than one producer;
///   - an input value has no producing `SemanticNode`;
///   - a node consumes its own output.
///
/// Best-effort mode never returns `Err`.
#[cfg(feature = "graph")]
fn convert_semantic_to_god_graph(
    sg: &SemanticGraph,
    mode: ConversionMode,
) -> Result<Graph<String, ()>, BridgeError> {
    let strict = mode.is_strict();
    let semantic_nodes = sg.nodes();

    #[cfg(feature = "tracing")]
    tracing::trace!(
        "op_dag_graph::semantic_to_god_graph: entry ({} SemanticNodes)",
        semantic_nodes.len()
    );
    #[cfg(all(feature = "observability", not(feature = "tracing")))]
    log::trace!(
        "op_dag_graph::semantic_to_god_graph: entry ({} SemanticNodes)",
        semantic_nodes.len()
    );

    // Pass 1: one god-graph node per SemanticNode. `slot[i]` is `None` only when the insertion
    // failed in best-effort mode.
    let mut out_graph: Graph<String, ()> = Graph::directed();
    let mut slot: Vec<Option<god_graph::NodeIndex>> = Vec::with_capacity(semantic_nodes.len());
    for (i, node) in semantic_nodes.iter().enumerate() {
        let inserted = match out_graph.add_node(node.op_name.clone()) {
            Ok(ni) => Some(ni),
            Err(err) => {
                if strict {
                    return Err(BridgeError::Upstream(format!(
                        "failed to add god-graph node for SemanticNode {i} (op_name={:?}): {err}",
                        node.op_name
                    )));
                }
                None
            }
        };
        slot.push(inserted);
    }

    // Pass 2: value id -> index of the SemanticNode that produces it. On duplicates the
    // later producer overwrites the earlier one (only reachable in best-effort mode).
    let mut value_producer: std::collections::HashMap<usize, usize> =
        std::collections::HashMap::new();
    for (i, node) in semantic_nodes.iter().enumerate() {
        for &out_id in &node.output_ids {
            let Some(previous_idx) = value_producer.insert(out_id, i) else {
                continue;
            };
            if strict {
                return Err(BridgeError::Unsupported(format!(
                    "value id {out_id} is produced by multiple SemanticNodes: {previous_idx} and {i}"
                )));
            }
        }
    }

    // Pass 3: producer -> consumer edges, one per resolvable input value.
    for (i, node) in semantic_nodes.iter().enumerate() {
        let this_ni = match (slot[i], strict) {
            (Some(ni), _) => ni,
            (None, true) => {
                return Err(BridgeError::Upstream(format!(
                    "missing god-graph node index for SemanticNode {i} (op_name={:?})",
                    node.op_name
                )));
            }
            (None, false) => continue,
        };
        for &input_value_id in &node.inputs {
            let producer_idx = match value_producer.get(&input_value_id) {
                Some(&idx) => idx,
                None if strict => {
                    return Err(BridgeError::Unsupported(missing_producer_message(
                        sg,
                        i,
                        &node.op_name,
                        input_value_id,
                    )));
                }
                None => continue,
            };
            if producer_idx == i {
                if strict {
                    return Err(BridgeError::Unsupported(format!(
                        "SemanticNode {i} (op_name={:?}) input value id {input_value_id} is produced by the same node; strict conversion rejects self-edges",
                        node.op_name
                    )));
                }
                continue;
            }
            let producer_ni = match (slot[producer_idx], strict) {
                (Some(ni), _) => ni,
                (None, true) => {
                    return Err(BridgeError::Upstream(format!(
                        "missing god-graph node index for producer SemanticNode {producer_idx} while adding edge to SemanticNode {i}"
                    )));
                }
                (None, false) => continue,
            };
            match out_graph.add_edge(producer_ni, this_ni, ()) {
                Err(err) if strict => {
                    return Err(BridgeError::Upstream(format!(
                        "failed to add god-graph edge for input value id {input_value_id} from SemanticNode {producer_idx} to SemanticNode {i}: {err}"
                    )));
                }
                _ => {}
            }
        }
    }

    #[cfg(feature = "tracing")]
    {
        use god_graph::graph::traits::GraphBase;
        tracing::debug!(
            "op_dag_graph::semantic_to_god_graph: exit ({} nodes, {} edges)",
            out_graph.node_count(),
            out_graph.edge_count()
        );
    }
    #[cfg(all(feature = "observability", not(feature = "tracing")))]
    {
        use god_graph::graph::traits::GraphBase;
        log::debug!(
            "op_dag_graph::semantic_to_god_graph: exit ({} nodes, {} edges)",
            out_graph.node_count(),
            out_graph.edge_count()
        );
    }
    Ok(out_graph)
}
/// Builds the strict-mode error text for an input value that no `SemanticNode` produces.
///
/// The value is described as an "external graph input" when `sg.value(input_value_id)` finds
/// it, and as an "unknown value" otherwise.
#[cfg(feature = "graph")]
fn missing_producer_message(
    sg: &SemanticGraph,
    node_idx: usize,
    op_name: &str,
    input_value_id: usize,
) -> String {
    let value_kind = match sg.value(input_value_id) {
        Some(_) => "external graph input",
        None => "unknown value",
    };
    format!(
        "SemanticNode {node_idx} (op_name={op_name:?}) input value id {input_value_id} has no SemanticNode producer ({value_kind}); strict conversion cannot represent this dependency in Graph<String, ()>"
    )
}
/// Returns the number of `SemanticNode`s in `sg`.
pub fn semantic_node_count(sg: &SemanticGraph) -> usize {
    let nodes = sg.nodes();
    nodes.len()
}
/// Counts the `SemanticNode`s in `sg` whose `op_name` is exactly `"input"`.
pub fn semantic_input_count(sg: &SemanticGraph) -> usize {
    sg.nodes()
        .iter()
        .map(|node| usize::from(node.op_name == "input"))
        .sum()
}
/// Lists the downstream consumers of every node whose `op_name` equals `target_op_name`.
///
/// Returns one `(node_index, consumer_op_names)` pair per matching node, in `sg.nodes()` order.
/// `node_index` is the node's position in `sg.nodes()`. `consumer_op_names` holds the distinct
/// `op_name`s of the *other* nodes that list any of the node's `output_ids` among their
/// `inputs`. The names are deduplicated and sorted, because they are collected through a
/// `BTreeSet`. A node that consumes its own output is not reported as its own consumer.
#[cfg(all(feature = "graph", feature = "operator"))]
pub fn op_dag_neighbors(sg: &SemanticGraph, target_op_name: &str) -> Vec<(usize, Vec<String>)> {
    use std::collections::{BTreeSet, HashMap};

    let nodes = sg.nodes();

    // Index consumers by value id once. Each lookup below is then O(1), instead of rescanning
    // every node's inputs for every output of every matching node.
    let mut consumers_of: HashMap<usize, Vec<(usize, &String)>> = HashMap::new();
    for (j, other) in nodes.iter().enumerate() {
        for &value_id in &other.inputs {
            consumers_of
                .entry(value_id)
                .or_default()
                .push((j, &other.op_name));
        }
    }

    let mut out: Vec<(usize, Vec<String>)> = Vec::new();
    for (i, node) in nodes.iter().enumerate() {
        if node.op_name != target_op_name {
            continue;
        }
        let mut consumers: BTreeSet<String> = BTreeSet::new();
        for out_id in &node.output_ids {
            let Some(readers) = consumers_of.get(out_id) else {
                continue;
            };
            for &(j, op_name) in readers {
                // A node reading its own output is not its own neighbour.
                if j != i {
                    consumers.insert(op_name.clone());
                }
            }
        }
        out.push((i, consumers.into_iter().collect()));
    }
    out
}
/// Ranks operator names by PageRank over the dependency graph of `sg`.
///
/// Steps:
/// 1. Build the best-effort god-graph view via `semantic_to_god_graph`.
/// 2. Run PageRank with damping factor `0.85` for `20` iterations.
/// 3. Sum the scores of all nodes that share an `op_name`. A node absent from the score map
///    contributes `0.0`.
///
/// The result is sorted by descending aggregated score, compared with `f64::total_cmp`. Equal
/// scores are ordered by ascending op name, so the output order is deterministic across runs.
#[cfg(all(feature = "graph", feature = "operator"))]
pub fn op_dag_pagerank(sg: &SemanticGraph) -> Vec<(String, f64)> {
    use god_graph::algorithms::centrality::pagerank;
    use god_graph::graph::traits::GraphQuery;
    #[cfg(feature = "tracing")]
    tracing::trace!(
        "op_dag_graph::op_dag_pagerank: entry ({} SemanticNodes)",
        sg.nodes().len()
    );
    #[cfg(all(feature = "observability", not(feature = "tracing")))]
    log::trace!(
        "op_dag_graph::op_dag_pagerank: entry ({} SemanticNodes)",
        sg.nodes().len()
    );
    let g: Graph<String, ()> = semantic_to_god_graph(sg);
    let scores = pagerank(&g, 0.85, 20);
    let mut aggregated: std::collections::HashMap<String, f64> = std::collections::HashMap::new();
    for node in g.nodes() {
        let score = scores.get(&node.index()).copied().unwrap_or(0.0);
        *aggregated.entry(node.data().clone()).or_insert(0.0) += score;
    }
    let mut out: Vec<(String, f64)> = aggregated.into_iter().collect();
    // `aggregated` is a HashMap, whose iteration order is randomized per process. Without the
    // name tie-break, ops with equal scores would come out in a different order on each run.
    // `total_cmp` is a total order, so `sort_by` cannot hit an inconsistent comparator even if
    // a NaN score ever appears.
    out.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    out
}
/// Returns the op names along a shortest dependency path between two operators in `sg`.
///
/// The search runs on the best-effort god-graph view from `semantic_to_god_graph`. That view's
/// edges point from a producing node to each node that consumes one of its outputs.
///
/// Every node whose payload equals `from_op_name` is a source, and every node whose payload
/// equals `to_op_name` is a goal. A multi-source breadth-first search therefore yields a path
/// with the fewest edges over all source/goal pairs. The path is returned as node payloads,
/// source first and goal last. A node matching both names yields a one-element path.
///
/// Returns `None` when:
/// - the graph is empty;
/// - no node matches either name;
/// - no goal is reachable from any source.
#[cfg(all(feature = "graph", feature = "operator"))]
pub fn op_dag_shortest_path(
    sg: &SemanticGraph,
    from_op_name: &str,
    to_op_name: &str,
) -> Option<Vec<String>> {
    use god_graph::graph::traits::{GraphBase, GraphQuery};
    use god_graph::node::NodeIndex;
    use std::collections::VecDeque;
    #[cfg(feature = "tracing")]
    tracing::trace!(
        "op_dag_graph::op_dag_shortest_path: entry (from={from_op_name:?}, to={to_op_name:?})"
    );
    #[cfg(all(feature = "observability", not(feature = "tracing")))]
    log::trace!(
        "op_dag_graph::op_dag_shortest_path: entry (from={from_op_name:?}, to={to_op_name:?})"
    );
    let g: Graph<String, ()> = semantic_to_god_graph(sg);
    if g.node_count() == 0 {
        return None;
    }
    let sources: Vec<NodeIndex> = g
        .nodes()
        .filter(|n| n.data() == from_op_name)
        .map(|n| n.index())
        .collect();
    let goals: Vec<NodeIndex> = g
        .nodes()
        .filter(|n| n.data() == to_op_name)
        .map(|n| n.index())
        .collect();
    if sources.is_empty() || goals.is_empty() {
        return None;
    }

    // The per-node tables are indexed by `NodeIndex::index()`. Size them from the largest index
    // actually present, rather than assuming indices are dense in `0..node_count()`.
    let slots = g
        .nodes()
        .map(|n| n.index().index() + 1)
        .max()
        .unwrap_or(0)
        .max(g.node_count());

    // O(1) goal test per dequeued node. Calling `goals.contains` here would be O(|goals|) each time.
    let mut is_goal = vec![false; slots];
    for goal in &goals {
        is_goal[goal.index()] = true;
    }

    // `parent[v]` is v's BFS predecessor. Sources point at themselves, which marks them visited
    // and tells the path reconstruction below where to stop.
    let mut parent = vec![None::<NodeIndex>; slots];
    let mut queue: VecDeque<NodeIndex> = VecDeque::with_capacity(slots);
    for &s in &sources {
        parent[s.index()] = Some(s);
        queue.push_back(s);
    }

    let mut found: Option<NodeIndex> = None;
    while let Some(node) = queue.pop_front() {
        if is_goal[node.index()] {
            found = Some(node);
            break;
        }
        for nb in g.neighbors(node) {
            if parent[nb.index()].is_none() {
                parent[nb.index()] = Some(node);
                queue.push_back(nb);
            }
        }
    }
    let goal = found?;

    // Walk predecessors back to the source (its own parent), then flip to source -> goal order.
    let mut path: Vec<NodeIndex> = Vec::new();
    let mut cur = goal;
    loop {
        path.push(cur);
        match parent[cur.index()] {
            Some(prev) if prev != cur => cur = prev,
            _ => break,
        }
    }
    path.reverse();
    Some(
        path.iter()
            .map(|ni| g.get_node(*ni).cloned().unwrap_or_default())
            .collect(),
    )
}