1use std::cell::RefCell;
13use std::hash::Hash;
14use std::sync::{Arc, OnceLock};
15
16use arc_swap::ArcSwap;
17use parking_lot::RwLock;
18use rustc_hash::{FxBuildHasher, FxHashMap};
19
20use crate::cch::contract::Topology;
21use crate::cch::customize::{PartialUpdateContext, Shortcuts, WeightMap, Weights};
22use crate::cch::query::Query;
23
24pub mod contract;
25pub mod customize;
26pub mod order;
27pub mod path;
28pub mod query;
29
/// Key type identifying a profile inside [`EngineData::profiles`].
///
/// Blanket-implemented for every hashable, `Copy`, thread-safe `'static` type, so plain
/// integers, fieldless enums or `&'static str` can be used as profile ids directly.
pub trait ProfileId: Eq + Hash + Copy + Send + Sync + 'static {}

impl<T> ProfileId for T where T: Eq + Hash + Copy + Send + Sync + 'static {}
36
thread_local! {
    /// Per-thread query scratch state, reused across `CchEngine::query_path` calls on the
    /// same thread. It starts as `Query::new(0)` and is replaced by `query_path` with a
    /// `Query::new(num_nodes)` whenever its `fw.len()` is smaller than the current CCH's
    /// node count.
    static QUERY_BUFFER: RefCell<Query> = RefCell::new(Query::new(0));
}
40
/// Result of a successful [`CchEngine::query_path`] call.
///
/// Derives `PartialEq` so results can be compared directly (e.g. in tests or caches).
#[derive(Debug, Clone, PartialEq)]
pub struct PathResult {
    /// The `distance` reported by `Cch::query` under the queried profile's weights.
    pub distance: f32,
    /// The path produced by `Cch::query_path` (presumably node ids from source to target —
    /// TODO confirm against `path`/`query` modules).
    pub path: Vec<u32>,
}
49
/// Read-only input graph consumed by [`Cch`], stored in CSR (forward-star) form.
///
/// The edges leaving node `u` occupy the half-open index range
/// `first_out()[u]..first_out()[u + 1]` of [`CchGraph::head`], so `first_out()` must have
/// one more entry than there are nodes.
pub trait CchGraph {
    /// Number of nodes in the graph.
    fn num_nodes(&self) -> usize;
    /// Number of edges in the graph (presumably `head().len()`).
    fn num_edges(&self) -> usize;

    /// CSR offsets: entry `u` is the index in `head()` of node `u`'s first outgoing edge.
    fn first_out(&self) -> &[u32];
    /// Target node of every edge, indexed by edge id.
    fn head(&self) -> &[u32];

    /// Implementor-supplied `x` value of `node_id` (presumably a coordinate).
    fn x(&self, node_id: u32) -> f32;
    /// Implementor-supplied `y` value of `node_id` (presumably a coordinate).
    fn y(&self, node_id: u32) -> f32;

    /// Half-open range of edge ids leaving node `u`.
    ///
    /// # Panics
    /// Panics if `u + 1` is not a valid index into `first_out()`.
    #[inline(always)]
    fn edge_indices(&self, u: usize) -> std::ops::Range<usize> {
        let offsets = self.first_out();
        (offsets[u] as usize)..(offsets[u + 1] as usize)
    }

    /// Iterates over the target nodes of all edges leaving `u`.
    #[inline(always)]
    fn neighbors(&self, u: u32) -> impl Iterator<Item = u32> + '_ {
        self.head()[self.edge_indices(u as usize)].iter().copied()
    }
}
86
/// One published snapshot of engine state, held in `CchEngine::data`.
///
/// Topology changes replace the whole snapshot via `ArcSwap::store`. Weight updates instead
/// mutate individual profiles in place through their `RwLock`s.
pub struct EngineData<P: ProfileId> {
    /// Contraction hierarchy shared by every profile in this snapshot.
    pub cch: Arc<Cch>,
    /// Customized state for each profile id; each entry is locked independently.
    pub profiles: FxHashMap<P, Arc<RwLock<ProfileData>>>,
    /// Partial-update context for `cch`, built lazily by the first
    /// `CchEngine::update_weights_partial` call. `update_topology` / `rebuild_topology`
    /// build it eagerly for the new snapshot only if the previous snapshot had one.
    pub update_ctx: OnceLock<Arc<PartialUpdateContext>>,
}
98
/// Multi-profile CCH routing engine.
///
/// Cloning only clones the outer `Arc`. All clones therefore share the same snapshot slot,
/// and a topology swap performed through one clone is visible through every other.
#[derive(Clone)]
pub struct CchEngine<P: ProfileId> {
    /// Current engine snapshot. Queries and weight updates `load()` it; topology updates
    /// `store()` a freshly built one.
    pub data: Arc<ArcSwap<EngineData<P>>>,
}
109
110impl<P: ProfileId> CchEngine<P> {
111 pub fn new<'a, G: CchGraph + Sync>(graph: &G, profile_weights: impl IntoIterator<Item = (&'a P, &'a Vec<f32>)>) -> Self {
116 let cch = Arc::new(Cch::new(graph));
117 let mut profiles = FxHashMap::default();
118 for (pid, weights) in profile_weights {
119 let (w, s) = cch.customize(&cch.weight_map, &cch.scheduler, weights);
120 profiles.insert(
121 *pid,
122 Arc::new(RwLock::new(ProfileData {
123 input_weights: weights.clone(),
124 weights: w,
125 shortcuts: s,
126 })),
127 );
128 }
129 Self {
130 data: Arc::new(ArcSwap::from_pointee(EngineData {
131 cch,
132 profiles,
133 update_ctx: OnceLock::new(),
134 })),
135 }
136 }
137
138 pub fn update_weights(&self, profile_id: P, new_weights: &[f32]) {
143 let current = self.data.load();
144 let Some(profile) = current.profiles.get(&profile_id) else {
145 return;
146 };
147 let mut profile = profile.write();
148 current.cch.recustomize_profile(&mut profile, new_weights);
149 }
150
151 pub fn update_weights_partial(&self, profile_id: P, updates: &[(usize, f32)]) {
156 if updates.is_empty() {
157 return;
158 }
159
160 let current = self.data.load();
161 let Some(profile) = current.profiles.get(&profile_id) else {
162 return;
163 };
164 let update_ctx = current.update_ctx.get_or_init(|| Arc::new(current.cch.build_partial_update_context())).clone();
165
166 let mut profile = profile.write();
167 let ProfileData { input_weights, weights, shortcuts } = &mut *profile;
168 current.cch.customize_partial(¤t.cch.weight_map, update_ctx.as_ref(), input_weights, weights, shortcuts, updates);
169 }
170
171 pub fn update_topology<G: CchGraph + Sync>(&self, graph: &G, modified_old_nodes: &[u32], new_nodes: &[u32], profile_weights: &FxHashMap<P, Vec<f32>>) {
177 let current = self.data.load();
178 let mut new_cch_inner = (*current.cch).clone();
179
180 let start_rank = new_cch_inner.update_order(graph, modified_old_nodes, new_nodes);
181 let new_topology = new_cch_inner.recontract(graph, start_rank, ¤t.cch.topology);
182
183 new_cch_inner.weight_map = WeightMap::build(graph, &new_cch_inner.ranks, &new_topology);
184 new_cch_inner.scheduler = new_topology.build_scheduler();
185 new_cch_inner.topology = new_topology;
186
187 let cch = Arc::new(new_cch_inner);
188 let mut profiles = FxHashMap::with_capacity_and_hasher(profile_weights.len(), FxBuildHasher);
189
190 for (pid, weights) in profile_weights {
191 let (w, s) = cch.customize(&cch.weight_map, &cch.scheduler, weights);
192 profiles.insert(
193 *pid,
194 Arc::new(RwLock::new(ProfileData {
195 input_weights: weights.clone(),
196 weights: w,
197 shortcuts: s,
198 })),
199 );
200 }
201
202 let update_ctx = OnceLock::new();
203 if current.update_ctx.get().is_some() {
204 let _ = update_ctx.set(Arc::new(cch.build_partial_update_context()));
205 }
206
207 self.data.store(Arc::new(EngineData { cch, profiles, update_ctx }));
208 }
209
210 pub fn rebuild_topology<G: CchGraph + Sync>(&self, graph: &G, profile_weights: &FxHashMap<P, Vec<f32>>) {
215 let current = self.data.load();
216 let new_engine = Self::new(graph, profile_weights);
217 let next = new_engine.data.load_full();
218 if current.update_ctx.get().is_some() {
219 let _ = next.update_ctx.set(Arc::new(next.cch.build_partial_update_context()));
220 }
221 self.data.store(next);
222 }
223
224 pub fn query_path(&self, profile: P, from: u32, to: u32) -> Option<PathResult> {
231 let current = self.data.load();
232 let profile_data = current.profiles.get(&profile)?;
233 QUERY_BUFFER.with(|q| {
234 let mut server = q.borrow_mut();
235 let num_nodes = current.cch.ranks.len();
236 if server.fw.len() < num_nodes {
237 *server = Query::new(num_nodes);
238 }
239 let profile_data = profile_data.read();
240 let result = current.cch.query(&mut server, &profile_data.weights, from, to)?;
241 let path = current.cch.query_path(&server, &profile_data.shortcuts, &result);
242 Some(PathResult { distance: result.distance, path })
243 })
244 }
245}
246
/// Customization state of a single profile.
#[derive(Clone)]
pub struct ProfileData {
    /// Input edge weights last applied to this profile. Full updates overwrite them, and
    /// they are passed mutably to `Cch::customize_partial` for partial updates.
    pub input_weights: Vec<f32>,
    /// Customized weights produced by `Cch::customize`; passed to `Cch::query`.
    pub weights: Weights,
    /// Shortcut data produced alongside `weights`; passed to `Cch::query_path` when the
    /// result path is reconstructed.
    pub shortcuts: Shortcuts,
}
257
/// Metric-independent part of a customizable contraction hierarchy.
///
/// A single instance is shared by every profile of an `EngineData` snapshot.
#[derive(Clone)]
pub struct Cch {
    /// Contracted topology produced by `Cch::contract` (or `recontract` on updates).
    pub topology: Topology,
    /// Built by `WeightMap::build` from the input graph, `ranks` and `topology`.
    /// Presumably maps input edges onto CCH edges — TODO confirm in `customize`.
    pub weight_map: WeightMap,
    /// Customization schedule from `Topology::build_scheduler`.
    /// Presumably batches of independent work items — TODO confirm.
    pub scheduler: Vec<Vec<u32>>,

    /// `ranks[node]` is the contraction rank of `node`.
    pub ranks: Vec<u32>,
    /// Inverse permutation of `ranks`: `order[rank]` is the node with that rank.
    pub order: Vec<u32>,
}
276
277impl Cch {
278 pub fn new<G: CchGraph + Sync>(graph: &G) -> Self {
283 let ranks = Self::get_metis_order(graph);
284
285 let mut order = vec![0; graph.num_nodes()];
286 for (node_id, &rank) in ranks.iter().enumerate() {
287 order[rank as usize] = node_id as u32;
288 }
289
290 let topology = Self::contract(&ranks, graph);
291 let weight_map = WeightMap::build(graph, &ranks, &topology);
292 let scheduler = topology.build_scheduler();
293
294 Self {
295 topology,
296 weight_map,
297 scheduler,
298 ranks,
299 order,
300 }
301 }
302
303 pub fn build_profile(&self, original_weights: &[f32]) -> ProfileData {
308 let (weights, shortcuts) = self.customize(&self.weight_map, &self.scheduler, original_weights);
309 ProfileData {
310 input_weights: original_weights.to_vec(),
311 weights,
312 shortcuts,
313 }
314 }
315
316 pub fn recustomize_profile(&self, profile: &mut ProfileData, new_weights: &[f32]) {
318 profile.input_weights.clear();
319 profile.input_weights.extend_from_slice(new_weights);
320 let (weights, shortcuts) = self.customize(&self.weight_map, &self.scheduler, &profile.input_weights);
321 profile.weights = weights;
322 profile.shortcuts = shortcuts;
323 }
324
325 pub fn customize_profile_partial(&self, profile: &mut ProfileData, updates: &[(usize, f32)]) {
327 let update_ctx = self.build_partial_update_context();
328 self.customize_profile_partial_with_context(profile, &update_ctx, updates);
329 }
330
331 pub fn customize_profile_partial_with_context(&self, profile: &mut ProfileData, update_ctx: &PartialUpdateContext, updates: &[(usize, f32)]) {
333 self.customize_partial(&self.weight_map, update_ctx, &mut profile.input_weights, &mut profile.weights, &mut profile.shortcuts, updates);
334 }
335
336 pub fn update_order<G: CchGraph + Sync>(&mut self, graph: &G, modified_old_nodes: &[u32], new_nodes: &[u32]) -> u32 {
341 let old_num = self.ranks.len();
342 let new_num = graph.num_nodes();
343
344 self.ranks.resize(new_num, 0);
345 self.order.resize(new_num, 0);
346
347 for (i, &u) in new_nodes.iter().enumerate() {
348 let rank = (old_num + i) as u32;
349 self.ranks[u as usize] = rank;
350 self.order[rank as usize] = u;
351 }
352
353 let mut min_dirty_rank = old_num as u32;
354 for &u in modified_old_nodes {
355 let r = self.ranks[u as usize];
356 if r < min_dirty_rank {
357 min_dirty_rank = r;
358 }
359 }
360
361 min_dirty_rank
362 }
363
364 pub fn get_order(&self) -> &[u32] { &self.order }
366
367 pub fn get_ranks(&self) -> &[u32] { &self.ranks }
369}