Skip to main content

wesichain_graph/
program.rs

1use std::collections::HashMap;
2
3use petgraph::graph::{Graph, NodeIndex};
4use petgraph::visit::EdgeRef;
5
6use crate::graph::GraphNode;
7use crate::StateSchema;
8
9use std::sync::Arc;
10
11pub struct NodeData<S: StateSchema> {
12    pub name: String,
13    pub runnable: Arc<dyn GraphNode<S>>,
14}
15
/// Kind of a directed edge between two nodes of a `GraphProgram`.
///
/// The enum is fieldless, so besides `Copy`/`Eq` it also derives `Hash`.
/// That lets edge kinds be used directly as `HashMap`/`HashSet` keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum EdgeKind {
    /// The only edge kind currently defined.
    Default,
}
20
21pub struct GraphProgram<S: StateSchema> {
22    graph: Graph<NodeData<S>, EdgeKind>,
23    name_to_index: HashMap<String, NodeIndex>,
24}
25
26impl<S: StateSchema> GraphProgram<S> {
27    pub(crate) fn new(
28        graph: Graph<NodeData<S>, EdgeKind>,
29        name_to_index: HashMap<String, NodeIndex>,
30    ) -> Self {
31        Self {
32            graph,
33            name_to_index,
34        }
35    }
36
37    pub fn node_names(&self) -> Vec<String> {
38        self.name_to_index.keys().cloned().collect()
39    }
40
41    pub fn edge_names(&self) -> Vec<(String, String)> {
42        self.graph
43            .edge_references()
44            .filter_map(|edge| {
45                let from = self.graph.node_weight(edge.source())?;
46                let to = self.graph.node_weight(edge.target())?;
47                Some((from.name.clone(), to.name.clone()))
48            })
49            .collect()
50    }
51}