1use std::collections::{HashMap, HashSet};
12use std::sync::Arc;
13
14use sqry_core::graph::unified::concurrent::GraphSnapshot;
15use sqry_core::graph::unified::edge::kind::EdgeKind;
16use sqry_core::graph::unified::node::id::NodeId;
17
18use crate::QueryDb;
19use crate::dependency::record_file_dep;
20use crate::query::DerivedQuery;
21
22pub type SccKey = EdgeKind;
29
30pub type SccValue = std::sync::Arc<CachedSccData>;
33
34#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
40pub struct CachedSccData {
41 pub node_to_component: HashMap<NodeId, u32>,
43 pub components: Vec<Vec<NodeId>>,
45 pub edge_kind: EdgeKind,
47}
48
49impl CachedSccData {
50 #[must_use]
52 pub fn component_of(&self, node: NodeId) -> Option<u32> {
53 self.node_to_component.get(&node).copied()
54 }
55
56 #[must_use]
58 pub fn is_in_cycle(&self, node: NodeId) -> bool {
59 self.component_of(node)
60 .map(|idx| {
61 self.components
62 .get(idx as usize)
63 .is_some_and(|c| c.len() > 1)
64 })
65 .unwrap_or(false)
66 }
67
68 #[must_use]
70 pub fn component_count(&self) -> usize {
71 self.components.len()
72 }
73}
74
75pub struct SccQuery;
81
82impl DerivedQuery for SccQuery {
83 type Key = EdgeKind;
84 type Value = Arc<CachedSccData>;
85 const QUERY_TYPE_ID: u32 = crate::queries::type_ids::SCC;
86 const TRACKS_EDGE_REVISION: bool = true;
87
88 fn execute(key: &EdgeKind, _db: &QueryDb, snapshot: &GraphSnapshot) -> Arc<CachedSccData> {
89 for (fid, _seg) in snapshot.file_segments().iter() {
91 record_file_dep(fid);
92 }
93
94 let mut index_counter = 0u32;
97 let mut stack: Vec<NodeId> = Vec::new();
98 let mut on_stack: HashSet<NodeId> = HashSet::new();
99 let mut indices: HashMap<NodeId, u32> = HashMap::new();
100 let mut lowlinks: HashMap<NodeId, u32> = HashMap::new();
101 let mut components: Vec<Vec<NodeId>> = Vec::new();
102
103 let all_nodes: Vec<NodeId> = snapshot
111 .nodes()
112 .iter()
113 .filter(|(_nid, entry)| !entry.is_unified_loser())
114 .map(|(nid, _)| nid)
115 .collect();
116
117 for &start in &all_nodes {
119 if indices.contains_key(&start) {
120 continue;
121 }
122
123 let mut work: Vec<(NodeId, usize)> = vec![(start, 0)];
125 indices.insert(start, index_counter);
126 lowlinks.insert(start, index_counter);
127 index_counter += 1;
128 stack.push(start);
129 on_stack.insert(start);
130
131 while let Some((node, pos)) = work.last_mut() {
132 let neighbors: Vec<NodeId> = snapshot
133 .edges()
134 .edges_from(*node)
135 .iter()
136 .filter(|e| std::mem::discriminant(&e.kind) == std::mem::discriminant(key))
137 .map(|e| e.target)
138 .collect();
139
140 if *pos < neighbors.len() {
141 let neighbor = neighbors[*pos];
142 *pos += 1;
143
144 if let std::collections::hash_map::Entry::Vacant(e) = indices.entry(neighbor) {
145 e.insert(index_counter);
146 lowlinks.insert(neighbor, index_counter);
147 index_counter += 1;
148 stack.push(neighbor);
149 on_stack.insert(neighbor);
150 work.push((neighbor, 0));
151 } else if on_stack.contains(&neighbor) {
152 let node_copy = *node;
153 let neighbor_idx = indices[&neighbor];
154 let current_low = lowlinks[&node_copy];
155 if neighbor_idx < current_low {
156 lowlinks.insert(node_copy, neighbor_idx);
157 }
158 }
159 } else {
160 let node_copy = *node;
162 let node_idx = indices[&node_copy];
163 let node_low = lowlinks[&node_copy];
164
165 if node_low == node_idx {
166 let mut component = Vec::new();
168 while let Some(w) = stack.pop() {
169 on_stack.remove(&w);
170 component.push(w);
171 if w == node_copy {
172 break;
173 }
174 }
175 components.push(component);
176 }
177
178 work.pop();
180 if let Some((parent, _)) = work.last() {
181 let parent_copy = *parent;
182 let parent_low = lowlinks[&parent_copy];
183 if node_low < parent_low {
184 lowlinks.insert(parent_copy, node_low);
185 }
186 }
187 }
188 }
189 }
190
191 let mut node_to_component = HashMap::with_capacity(all_nodes.len());
193 for (idx, component) in components.iter().enumerate() {
194 for &nid in component {
195 node_to_component.insert(nid, idx as u32);
196 }
197 }
198
199 Arc::new(CachedSccData {
200 node_to_component,
201 components,
202 edge_kind: key.clone(),
203 })
204 }
205}
206
/// Serde round-trip tests (through postcard) for the SCC cache key and value types.
#[cfg(test)]
mod serde_roundtrip {
    use super::*;
    use postcard::{from_bytes, to_allocvec};

    #[test]
    fn cached_scc_data_roundtrip() {
        let mut node_to_component = HashMap::new();
        node_to_component.insert(NodeId::new(1, 1), 0u32);
        node_to_component.insert(NodeId::new(2, 1), 0u32);
        node_to_component.insert(NodeId::new(3, 1), 1u32);
        let original = CachedSccData {
            node_to_component,
            components: vec![
                vec![NodeId::new(1, 1), NodeId::new(2, 1)],
                vec![NodeId::new(3, 1)],
            ],
            edge_kind: EdgeKind::Calls {
                argument_count: 0,
                is_async: false,
            },
        };
        let bytes = to_allocvec(&original).expect("serialize failed");
        let decoded: CachedSccData = from_bytes(&bytes).expect("deserialize failed");
        assert_eq!(decoded.components, original.components);
        assert_eq!(decoded.edge_kind, original.edge_kind);
        // Equal sizes plus every original entry present in the decoded map
        // means the maps are identical: nothing was lost *or* invented.
        assert_eq!(
            decoded.node_to_component.len(),
            original.node_to_component.len()
        );
        for (node, comp) in &original.node_to_component {
            assert_eq!(decoded.node_to_component.get(node), Some(comp));
        }
    }

    #[test]
    fn scc_key_roundtrip() {
        let original: SccKey = EdgeKind::Imports {
            alias: None,
            is_wildcard: false,
        };
        let bytes = to_allocvec(&original).expect("serialize failed");
        let decoded: SccKey = from_bytes(&bytes).expect("deserialize failed");
        assert_eq!(decoded, original);
    }

    #[test]
    fn scc_value_roundtrip() {
        let data = CachedSccData {
            node_to_component: HashMap::new(),
            components: vec![],
            edge_kind: EdgeKind::References,
        };
        let original: SccValue = Arc::new(data);
        let bytes = to_allocvec(&original).expect("serialize failed");
        let decoded: SccValue = from_bytes(&bytes).expect("deserialize failed");
        assert_eq!(decoded.components, original.components);
        assert_eq!(decoded.edge_kind, original.edge_kind);
        assert!(decoded.node_to_component.is_empty());
    }
}