god-graph 0.6.0-alpha

A graph-based LLM white-box optimization toolbox: topology validation, Lie group orthogonalization, tensor ring compression
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
//! 拓扑排序算法插件实现
//!
//! 适用于有向无环图 (DAG),返回节点的线性排序

use crate::node::NodeIndex;
use crate::plugins::algorithm::{
    AlgorithmData, AlgorithmResult, GraphAlgorithm, PluginContext, PluginInfo,
};
use crate::vgi::{Capability, GraphType, VgiResult, VirtualGraph};
use std::any::Any;
use std::collections::VecDeque;

/// Outcome of a topological sort.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly
/// (e.g. `assert_eq!(result, TopologicalSortResult::HasCycle)`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TopologicalSortResult {
    /// The sort succeeded; holds the raw node indices in topological order.
    Sorted(Vec<usize>),
    /// The graph contains a cycle, so no topological order exists.
    HasCycle,
}

/// Topological sort algorithm plugin.
///
/// Intended for directed acyclic graphs (DAGs): it produces a linear ordering of
/// the nodes in which every edge points from an earlier node to a later one.
/// If the graph contains a cycle, the result is [`TopologicalSortResult::HasCycle`].
///
/// # Examples
///
/// ```
/// use god_graph::graph::Graph;
/// use god_graph::graph::traits::GraphOps;
/// use god_graph::plugins::algorithms::topological_sort::{
///     TopologicalSortPlugin, TopologicalSortResult,
/// };
///
/// let plugin = TopologicalSortPlugin::new();
///
/// let mut graph = Graph::<i32, f64>::directed();
/// let a = graph.add_node(1).unwrap();
/// let b = graph.add_node(2).unwrap();
/// graph.add_edge(a, b, 1.0).unwrap();
///
/// let result = plugin.sort(&graph).unwrap();
/// assert!(matches!(result, TopologicalSortResult::Sorted(ref order) if order.len() == 2));
/// ```
#[derive(Debug, Clone, Copy)]
pub struct TopologicalSortPlugin;

impl TopologicalSortPlugin {
    /// Creates a new topological sort plugin instance.
    ///
    /// # Examples
    ///
    /// ```
    /// use god_graph::plugins::algorithms::topological_sort::TopologicalSortPlugin;
    ///
    /// let plugin = TopologicalSortPlugin::new();
    /// ```
    pub fn new() -> Self {
        Self
    }

    /// Collects the graph's node indices and builds the inverse lookup table.
    ///
    /// Returns `(ids, id_to_pos)`:
    /// - `ids[pos]` is the raw index of the `pos`-th node yielded by `graph.nodes()`;
    /// - `id_to_pos[id]` is `Some(pos)` for live nodes and `None` for unused ids.
    ///
    /// The table is sized by the largest live index rather than by the node
    /// count. Node indices therefore need not be contiguous: holes in the index
    /// space, which can presumably appear after node removal, do not cause
    /// out-of-bounds indexing.
    fn index_nodes<G>(graph: &G) -> (Vec<usize>, Vec<Option<usize>>)
    where
        G: VirtualGraph + ?Sized,
    {
        let ids: Vec<usize> = graph.nodes().map(|node| node.index().index()).collect();
        let table_len = ids.iter().max().map_or(0, |&max_id| max_id + 1);

        let mut id_to_pos = vec![None; table_len];
        for (pos, &id) in ids.iter().enumerate() {
            id_to_pos[id] = Some(pos);
        }
        (ids, id_to_pos)
    }

    /// Maps a raw node index to its position in `ids`.
    ///
    /// Returns `None` if `id` is not a live node, including ids beyond the table.
    fn position_of(id_to_pos: &[Option<usize>], id: usize) -> Option<usize> {
        id_to_pos.get(id).copied().flatten()
    }

    /// Topological sort using Kahn's algorithm (in-degree driven BFS).
    ///
    /// Returns [`TopologicalSortResult::Sorted`] with raw node indices in
    /// topological order. Returns [`TopologicalSortResult::HasCycle`] when not
    /// every node can be emitted, meaning some node lies on or downstream of a
    /// cycle. Neighbors that are not live nodes of `graph` are ignored.
    ///
    /// Every adjacency list is walked twice: once to count in-degrees and once
    /// while draining the queue.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`; the `VgiResult` return type is
    /// kept for interface compatibility.
    pub fn sort<G>(&self, graph: &G) -> VgiResult<TopologicalSortResult>
    where
        G: VirtualGraph + ?Sized,
    {
        let (ids, id_to_pos) = Self::index_nodes(graph);
        let n = ids.len();
        if n == 0 {
            return Ok(TopologicalSortResult::Sorted(Vec::new()));
        }

        // First pass: count incoming edges per node. If `neighbors` yields a
        // target more than once (parallel edges), it is counted once per
        // occurrence here and decremented once per occurrence below, so the two
        // passes stay consistent.
        let mut in_degree = vec![0usize; n];
        for &id in &ids {
            for neighbor in graph.neighbors(NodeIndex::new_public(id)) {
                if let Some(target) = Self::position_of(&id_to_pos, neighbor.index()) {
                    in_degree[target] += 1;
                }
            }
        }

        // Seed the queue with every source (in-degree 0), in node-iteration order.
        let mut queue: VecDeque<usize> = (0..n).filter(|&pos| in_degree[pos] == 0).collect();

        let mut order: Vec<usize> = Vec::with_capacity(n);
        while let Some(pos) = queue.pop_front() {
            let id = ids[pos];
            order.push(id);

            // Second pass: "remove" this node's outgoing edges. Any successor
            // left with no remaining predecessors becomes ready.
            for neighbor in graph.neighbors(NodeIndex::new_public(id)) {
                if let Some(target) = Self::position_of(&id_to_pos, neighbor.index()) {
                    in_degree[target] -= 1;
                    if in_degree[target] == 0 {
                        queue.push_back(target);
                    }
                }
            }
        }

        // Nodes on (or downstream of) a cycle never reach in-degree 0.
        if order.len() < n {
            Ok(TopologicalSortResult::HasCycle)
        } else {
            Ok(TopologicalSortResult::Sorted(order))
        }
    }

    /// Topological sort using depth-first search (reverse post-order).
    ///
    /// Roots are tried in node-iteration order and neighbors in the order
    /// `graph.neighbors` yields them. The search returns
    /// [`TopologicalSortResult::HasCycle`] as soon as it reaches a node that is
    /// still on the current DFS path, which indicates a back edge.
    ///
    /// The traversal is iterative. It uses an explicit heap-allocated stack, so
    /// very deep graphs (e.g. long chains) cannot overflow the thread stack.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`; the `VgiResult` return type is
    /// kept for interface compatibility.
    pub fn sort_dfs<G>(&self, graph: &G) -> VgiResult<TopologicalSortResult>
    where
        G: VirtualGraph + ?Sized,
    {
        /// Per-node DFS colouring.
        #[derive(Clone, Copy, PartialEq, Eq)]
        enum Mark {
            /// Not discovered yet.
            Unvisited,
            /// On the current DFS path; reaching it again means a cycle.
            OnPath,
            /// Fully explored and already emitted in post-order.
            Done,
        }

        let (ids, id_to_pos) = Self::index_nodes(graph);
        let n = ids.len();
        if n == 0 {
            return Ok(TopologicalSortResult::Sorted(Vec::new()));
        }

        // Positions of the live successors of the node at `pos`, in neighbor order.
        let successors = |pos: usize| -> Vec<usize> {
            graph
                .neighbors(NodeIndex::new_public(ids[pos]))
                .filter_map(|neighbor| Self::position_of(&id_to_pos, neighbor.index()))
                .collect()
        };

        let mut marks = vec![Mark::Unvisited; n];
        let mut postorder: Vec<usize> = Vec::with_capacity(n);
        // Explicit DFS stack of frames: (node position, its successors, next successor to try).
        let mut stack: Vec<(usize, Vec<usize>, usize)> = Vec::new();

        for root in 0..n {
            if marks[root] != Mark::Unvisited {
                continue;
            }
            marks[root] = Mark::OnPath;
            stack.push((root, successors(root), 0));

            while let Some((pos, succs, cursor)) = stack.last_mut() {
                match succs.get(*cursor).copied() {
                    Some(next) => {
                        *cursor += 1;
                        let mark = marks[next];
                        match mark {
                            Mark::OnPath => return Ok(TopologicalSortResult::HasCycle),
                            Mark::Unvisited => {
                                marks[next] = Mark::OnPath;
                                stack.push((next, successors(next), 0));
                            }
                            Mark::Done => {}
                        }
                    }
                    None => {
                        // All successors explored: emit the node in post-order.
                        let finished = *pos;
                        stack.pop();
                        marks[finished] = Mark::Done;
                        postorder.push(ids[finished]);
                    }
                }
            }
        }

        // The reverse of a DFS post-order is a topological order.
        postorder.reverse();
        Ok(TopologicalSortResult::Sorted(postorder))
    }

    /// Returns `true` if `graph` is a directed acyclic graph (DAG).
    ///
    /// Runs Kahn's algorithm ([`Self::sort`]) and discards the ordering.
    pub fn is_dag<G>(&self, graph: &G) -> VgiResult<bool>
    where
        G: VirtualGraph + ?Sized,
    {
        Ok(matches!(self.sort(graph)?, TopologicalSortResult::Sorted(_)))
    }
}

impl Default for TopologicalSortPlugin {
    fn default() -> Self {
        Self::new()
    }
}

impl GraphAlgorithm for TopologicalSortPlugin {
    /// Describes the plugin: name, version, description, author, required
    /// capabilities, supported graph types and search tags.
    fn info(&self) -> PluginInfo {
        PluginInfo::new(
            "topological-sort",
            "1.0.0",
            "有向无环图 (DAG) 的拓扑排序算法",
        )
        .with_author("God-Graph Team")
        // NOTE(review): sorting only reads the graph; confirm that requiring
        // `IncrementalUpdate` is intended, since it may exclude read-only backends.
        .with_required_capabilities(&[Capability::IncrementalUpdate])
        .with_supported_graph_types(&[GraphType::Directed])
        .with_tags(&["topological-sort", "dag", "ordering", "scheduling"])
    }

    /// Runs the topological sort on `ctx.graph` and packages the outcome.
    ///
    /// The config key `use_dfs` selects the DFS variant only when it is exactly
    /// `"true"`; any other value, or a missing key, selects Kahn's algorithm.
    /// Progress is reported at 0.1, 0.8 and 1.0.
    ///
    /// On success the result is named `"topological_sort"` and carries the
    /// node order as `AlgorithmData::NodeList`. On a cycle it is named
    /// `"topological_sort_error"` and carries an explanatory string. Both carry
    /// `algorithm` and `method` metadata.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`TopologicalSortPlugin::sort`] or
    /// [`TopologicalSortPlugin::sort_dfs`].
    fn execute<G>(&self, ctx: &mut PluginContext<G>) -> VgiResult<AlgorithmResult>
    where
        G: VirtualGraph + ?Sized,
    {
        let use_dfs = ctx.get_config_or("use_dfs", "false") == "true";

        ctx.report_progress(0.1);

        let result = if use_dfs {
            self.sort_dfs(ctx.graph)?
        } else {
            self.sort(ctx.graph)?
        };

        ctx.report_progress(0.8);

        let algorithm_result = match result {
            TopologicalSortResult::Sorted(order) => {
                // Read the length before `order` is moved into the payload,
                // so the ordering is not copied.
                let node_count = order.len().to_string();
                AlgorithmResult::new("topological_sort", AlgorithmData::NodeList(order))
                    .with_metadata("is_dag", "true")
                    .with_metadata("node_count", node_count)
            }
            TopologicalSortResult::HasCycle => AlgorithmResult::new(
                "topological_sort_error",
                AlgorithmData::String(
                    "Graph contains a cycle, cannot perform topological sort".to_string(),
                ),
            )
            .with_metadata("is_dag", "false")
            .with_metadata("error", "has_cycle"),
        }
        .with_metadata("algorithm", "topological-sort")
        .with_metadata("method", if use_dfs { "dfs" } else { "kahn" });

        ctx.report_progress(1.0);
        Ok(algorithm_result)
    }

    /// Returns `self` as `&dyn Any`, presumably so callers can downcast a
    /// type-erased plugin back to `TopologicalSortPlugin`.
    fn as_any(&self) -> &dyn Any {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Graph;
    use crate::graph::traits::GraphOps;

    /// Edges of `create_dag` as raw node indices (A = 0, B = 1, C = 2, D = 3).
    /// Assumes nodes receive indices 0.. in insertion order.
    const DAG_EDGES: [(usize, usize); 4] = [(0, 1), (0, 2), (1, 3), (2, 3)];

    /// Builds the diamond-shaped task-dependency DAG:
    /// `A -> B -> D` and `A -> C -> D`.
    fn create_dag() -> Graph<String, ()> {
        let mut graph = Graph::<String, ()>::directed();

        let a = graph.add_node("A".to_string()).unwrap();
        let b = graph.add_node("B".to_string()).unwrap();
        let c = graph.add_node("C".to_string()).unwrap();
        let d = graph.add_node("D".to_string()).unwrap();

        graph.add_edge(a, b, ()).unwrap();
        graph.add_edge(a, c, ()).unwrap();
        graph.add_edge(b, d, ()).unwrap();
        graph.add_edge(c, d, ()).unwrap();

        graph
    }

    /// Builds the directed 3-cycle `A -> B -> C -> A`.
    fn create_cycle() -> Graph<String, ()> {
        let mut graph = Graph::<String, ()>::directed();

        let a = graph.add_node("A".to_string()).unwrap();
        let b = graph.add_node("B".to_string()).unwrap();
        let c = graph.add_node("C".to_string()).unwrap();

        graph.add_edge(a, b, ()).unwrap();
        graph.add_edge(b, c, ()).unwrap();
        graph.add_edge(c, a, ()).unwrap();

        graph
    }

    /// Unwraps a `Sorted` result and checks that it holds `expected_len` nodes
    /// and places the source of every edge in `edges` before its target.
    fn assert_valid_order(
        result: TopologicalSortResult,
        expected_len: usize,
        edges: &[(usize, usize)],
    ) {
        let order = match result {
            TopologicalSortResult::Sorted(order) => order,
            TopologicalSortResult::HasCycle => panic!("DAG should not have cycle"),
        };
        assert_eq!(order.len(), expected_len);

        let position = |id: usize| {
            order
                .iter()
                .position(|&x| x == id)
                .unwrap_or_else(|| panic!("node {} missing from {:?}", id, order))
        };
        for &(from, to) in edges {
            assert!(
                position(from) < position(to),
                "edge {} -> {} violated in {:?}",
                from,
                to,
                order
            );
        }
    }

    #[test]
    fn test_topological_sort_dag() {
        let plugin = TopologicalSortPlugin::new();
        assert_valid_order(plugin.sort(&create_dag()).unwrap(), 4, &DAG_EDGES);
    }

    #[test]
    fn test_topological_sort_dfs() {
        let plugin = TopologicalSortPlugin::new();
        assert_valid_order(plugin.sort_dfs(&create_dag()).unwrap(), 4, &DAG_EDGES);
    }

    #[test]
    fn test_topological_sort_with_cycle() {
        let plugin = TopologicalSortPlugin::new();
        let result = plugin.sort(&create_cycle()).unwrap();
        assert!(
            matches!(result, TopologicalSortResult::HasCycle),
            "Graph with cycle should not be sortable"
        );
    }

    #[test]
    fn test_topological_sort_dfs_with_cycle() {
        let plugin = TopologicalSortPlugin::new();
        let result = plugin.sort_dfs(&create_cycle()).unwrap();
        assert!(
            matches!(result, TopologicalSortResult::HasCycle),
            "Graph with cycle should not be sortable via DFS"
        );
    }

    #[test]
    fn test_topological_sort_empty_graph() {
        let graph = Graph::<String, ()>::directed();
        let plugin = TopologicalSortPlugin::new();

        for result in [plugin.sort(&graph).unwrap(), plugin.sort_dfs(&graph).unwrap()] {
            match result {
                TopologicalSortResult::Sorted(order) => assert!(order.is_empty()),
                TopologicalSortResult::HasCycle => panic!("Empty graph should not have cycle"),
            }
        }
    }

    #[test]
    fn test_topological_sort_disconnected() {
        // A, B, C with a single edge C -> A; B is isolated.
        let mut graph = Graph::<String, ()>::directed();
        let a = graph.add_node("A".to_string()).unwrap();
        let _b = graph.add_node("B".to_string()).unwrap();
        let c = graph.add_node("C".to_string()).unwrap();
        graph.add_edge(c, a, ()).unwrap();

        let plugin = TopologicalSortPlugin::new();
        assert_valid_order(plugin.sort(&graph).unwrap(), 3, &[(2, 0)]);
        assert_valid_order(plugin.sort_dfs(&graph).unwrap(), 3, &[(2, 0)]);
    }

    #[test]
    fn test_topological_sort_is_dag() {
        let plugin = TopologicalSortPlugin::new();

        assert!(plugin.is_dag(&create_dag()).unwrap());

        let mut cyclic = Graph::<String, ()>::directed();
        let a = cyclic.add_node("A".to_string()).unwrap();
        let b = cyclic.add_node("B".to_string()).unwrap();
        cyclic.add_edge(a, b, ()).unwrap();
        cyclic.add_edge(b, a, ()).unwrap();

        assert!(!plugin.is_dag(&cyclic).unwrap());
    }

    #[test]
    fn test_topological_sort_plugin_info() {
        let plugin = TopologicalSortPlugin::new();
        let info = plugin.info();

        assert_eq!(info.name, "topological-sort");
        assert_eq!(info.version, "1.0.0");
        assert!(info.tags.contains(&"topological-sort".to_string()));
    }
}