1use std::collections::HashMap;
25
26use crate::graph::graph_classifier::GraphValue;
27use crate::pattern::Pattern;
28
/// Direction in which a relationship is being traversed, relative to the
/// relationship's own orientation.
///
/// Handed to a [`TraversalWeight`] together with the relationship so the
/// weight function can price the two directions differently. `Hash` is derived
/// (alongside `Eq`) so directions can key hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraversalDirection {
    /// Along the relationship's orientation — the direction `directed` allows.
    Forward,
    /// Against the relationship's orientation — the direction
    /// `directed_reverse` allows.
    Backward,
}
43
44#[cfg(not(feature = "thread-safe"))]
60pub type TraversalWeight<V> = std::rc::Rc<dyn Fn(&Pattern<V>, TraversalDirection) -> f64>;
61
62#[cfg(feature = "thread-safe")]
63pub type TraversalWeight<V> =
64 std::sync::Arc<dyn Fn(&Pattern<V>, TraversalDirection) -> f64 + Send + Sync>;
65
66#[cfg(not(feature = "thread-safe"))]
72pub fn undirected<V>() -> TraversalWeight<V> {
73 std::rc::Rc::new(|_rel: &Pattern<V>, _dir: TraversalDirection| 1.0)
74}
75
76#[cfg(feature = "thread-safe")]
77pub fn undirected<V: Send + Sync + 'static>() -> TraversalWeight<V> {
78 std::sync::Arc::new(|_rel: &Pattern<V>, _dir: TraversalDirection| 1.0)
79}
80
81#[cfg(not(feature = "thread-safe"))]
83pub fn directed<V>() -> TraversalWeight<V> {
84 std::rc::Rc::new(|_rel: &Pattern<V>, dir: TraversalDirection| match dir {
85 TraversalDirection::Forward => 1.0,
86 TraversalDirection::Backward => f64::INFINITY,
87 })
88}
89
90#[cfg(feature = "thread-safe")]
91pub fn directed<V: Send + Sync + 'static>() -> TraversalWeight<V> {
92 std::sync::Arc::new(|_rel: &Pattern<V>, dir: TraversalDirection| match dir {
93 TraversalDirection::Forward => 1.0,
94 TraversalDirection::Backward => f64::INFINITY,
95 })
96}
97
98#[cfg(not(feature = "thread-safe"))]
100pub fn directed_reverse<V>() -> TraversalWeight<V> {
101 std::rc::Rc::new(|_rel: &Pattern<V>, dir: TraversalDirection| match dir {
102 TraversalDirection::Forward => f64::INFINITY,
103 TraversalDirection::Backward => 1.0,
104 })
105}
106
107#[cfg(feature = "thread-safe")]
108pub fn directed_reverse<V: Send + Sync + 'static>() -> TraversalWeight<V> {
109 std::sync::Arc::new(|_rel: &Pattern<V>, dir: TraversalDirection| match dir {
110 TraversalDirection::Forward => f64::INFINITY,
111 TraversalDirection::Backward => 1.0,
112 })
113}
114
/// Graph access expressed as a record of query closures over [`Pattern`]s
/// (single-threaded build: each closure is shared through an `Rc`).
///
/// Each field is an independent query. Adapters in this module such as
/// `frame_query` and `memoize_incident_rels` build a new `GraphQuery` by
/// wrapping or replacing individual closures of an existing one.
#[cfg(not(feature = "thread-safe"))]
// The field types are long shared-closure signatures; naming each one would
// add aliases without adding clarity.
#[allow(clippy::type_complexity)]
pub struct GraphQuery<V: GraphValue> {
    /// All node patterns.
    pub query_nodes: std::rc::Rc<dyn Fn() -> Vec<Pattern<V>>>,
    /// All relationship patterns.
    pub query_relationships: std::rc::Rc<dyn Fn() -> Vec<Pattern<V>>>,
    /// Relationships incident to the given node.
    pub query_incident_rels: std::rc::Rc<dyn Fn(&Pattern<V>) -> Vec<Pattern<V>>>,
    /// Source endpoint of the given relationship, or `None` if it cannot be
    /// resolved.
    pub query_source: std::rc::Rc<dyn Fn(&Pattern<V>) -> Option<Pattern<V>>>,
    /// Target endpoint of the given relationship, or `None` if it cannot be
    /// resolved.
    pub query_target: std::rc::Rc<dyn Fn(&Pattern<V>) -> Option<Pattern<V>>>,
    /// Number of relationships incident to the given node (the adapters in
    /// this module keep it equal to the length of `query_incident_rels`).
    pub query_degree: std::rc::Rc<dyn Fn(&Pattern<V>) -> usize>,
    /// Node with the given identifier, if any.
    pub query_node_by_id: std::rc::Rc<dyn Fn(&V::Id) -> Option<Pattern<V>>>,
    /// Relationship with the given identifier, if any.
    pub query_relationship_by_id: std::rc::Rc<dyn Fn(&V::Id) -> Option<Pattern<V>>>,
    /// Patterns that contain the given element — presumably its enclosing
    /// patterns; TODO confirm against implementers.
    pub query_containers: std::rc::Rc<dyn Fn(&Pattern<V>) -> Vec<Pattern<V>>>,
}

/// Graph access expressed as a record of query closures over [`Pattern`]s
/// (thread-safe build: each closure is shared through an `Arc` and is
/// `Send + Sync`).
///
/// Each field is an independent query. Adapters in this module such as
/// `frame_query` and `memoize_incident_rels` build a new `GraphQuery` by
/// wrapping or replacing individual closures of an existing one.
#[cfg(feature = "thread-safe")]
// The field types are long shared-closure signatures; naming each one would
// add aliases without adding clarity.
#[allow(clippy::type_complexity)]
pub struct GraphQuery<V: GraphValue> {
    /// All node patterns.
    pub query_nodes: std::sync::Arc<dyn Fn() -> Vec<Pattern<V>> + Send + Sync>,
    /// All relationship patterns.
    pub query_relationships: std::sync::Arc<dyn Fn() -> Vec<Pattern<V>> + Send + Sync>,
    /// Relationships incident to the given node.
    pub query_incident_rels: std::sync::Arc<dyn Fn(&Pattern<V>) -> Vec<Pattern<V>> + Send + Sync>,
    /// Source endpoint of the given relationship, or `None` if it cannot be
    /// resolved.
    pub query_source: std::sync::Arc<dyn Fn(&Pattern<V>) -> Option<Pattern<V>> + Send + Sync>,
    /// Target endpoint of the given relationship, or `None` if it cannot be
    /// resolved.
    pub query_target: std::sync::Arc<dyn Fn(&Pattern<V>) -> Option<Pattern<V>> + Send + Sync>,
    /// Number of relationships incident to the given node (the adapters in
    /// this module keep it equal to the length of `query_incident_rels`).
    pub query_degree: std::sync::Arc<dyn Fn(&Pattern<V>) -> usize + Send + Sync>,
    /// Node with the given identifier, if any.
    pub query_node_by_id: std::sync::Arc<dyn Fn(&V::Id) -> Option<Pattern<V>> + Send + Sync>,
    /// Relationship with the given identifier, if any.
    pub query_relationship_by_id:
        std::sync::Arc<dyn Fn(&V::Id) -> Option<Pattern<V>> + Send + Sync>,
    /// Patterns that contain the given element — presumably its enclosing
    /// patterns; TODO confirm against implementers.
    pub query_containers: std::sync::Arc<dyn Fn(&Pattern<V>) -> Vec<Pattern<V>> + Send + Sync>,
}
179
180#[cfg(not(feature = "thread-safe"))]
185impl<V: GraphValue> Clone for GraphQuery<V> {
186 fn clone(&self) -> Self {
187 GraphQuery {
188 query_nodes: std::rc::Rc::clone(&self.query_nodes),
189 query_relationships: std::rc::Rc::clone(&self.query_relationships),
190 query_incident_rels: std::rc::Rc::clone(&self.query_incident_rels),
191 query_source: std::rc::Rc::clone(&self.query_source),
192 query_target: std::rc::Rc::clone(&self.query_target),
193 query_degree: std::rc::Rc::clone(&self.query_degree),
194 query_node_by_id: std::rc::Rc::clone(&self.query_node_by_id),
195 query_relationship_by_id: std::rc::Rc::clone(&self.query_relationship_by_id),
196 query_containers: std::rc::Rc::clone(&self.query_containers),
197 }
198 }
199}
200
201#[cfg(feature = "thread-safe")]
202impl<V: GraphValue> Clone for GraphQuery<V> {
203 fn clone(&self) -> Self {
204 GraphQuery {
205 query_nodes: std::sync::Arc::clone(&self.query_nodes),
206 query_relationships: std::sync::Arc::clone(&self.query_relationships),
207 query_incident_rels: std::sync::Arc::clone(&self.query_incident_rels),
208 query_source: std::sync::Arc::clone(&self.query_source),
209 query_target: std::sync::Arc::clone(&self.query_target),
210 query_degree: std::sync::Arc::clone(&self.query_degree),
211 query_node_by_id: std::sync::Arc::clone(&self.query_node_by_id),
212 query_relationship_by_id: std::sync::Arc::clone(&self.query_relationship_by_id),
213 query_containers: std::sync::Arc::clone(&self.query_containers),
214 }
215 }
216}
217
218#[cfg(not(feature = "thread-safe"))]
238#[allow(clippy::type_complexity)]
239pub fn frame_query<V>(
240 include: std::rc::Rc<dyn Fn(&Pattern<V>) -> bool>,
241 base: GraphQuery<V>,
242) -> GraphQuery<V>
243where
244 V: GraphValue + Clone + 'static,
245{
246 use std::rc::Rc;
247
248 let inc1 = Rc::clone(&include);
249 let query_nodes = Rc::new(move || {
250 (base.query_nodes)()
251 .into_iter()
252 .filter(|n| inc1(n))
253 .collect()
254 });
255
256 let inc2 = Rc::clone(&include);
257 let base_rels = Rc::clone(&base.query_relationships);
258 let query_relationships =
259 Rc::new(move || base_rels().into_iter().filter(|r| inc2(r)).collect());
260
261 let inc3 = Rc::clone(&include);
262 let base_inc = Rc::clone(&base.query_incident_rels);
263 let base_src = Rc::clone(&base.query_source);
264 let base_tgt = Rc::clone(&base.query_target);
265 let query_incident_rels = Rc::new(move |node: &Pattern<V>| {
266 base_inc(node)
267 .into_iter()
268 .filter(|rel| {
269 let src_ok = base_src(rel).as_ref().map(|s| inc3(s)).unwrap_or(false);
270 let tgt_ok = base_tgt(rel).as_ref().map(|t| inc3(t)).unwrap_or(false);
271 src_ok && tgt_ok
272 })
273 .collect()
274 });
275
276 let query_source = Rc::clone(&base.query_source);
277 let query_target = Rc::clone(&base.query_target);
278
279 let inc4 = Rc::clone(&include);
280 let base_inc2 = Rc::clone(&base.query_incident_rels);
281 let base_src2 = Rc::clone(&base.query_source);
282 let base_tgt2 = Rc::clone(&base.query_target);
283 let query_degree = Rc::new(move |node: &Pattern<V>| {
284 base_inc2(node)
285 .into_iter()
286 .filter(|rel| {
287 let src_ok = base_src2(rel).as_ref().map(|s| inc4(s)).unwrap_or(false);
288 let tgt_ok = base_tgt2(rel).as_ref().map(|t| inc4(t)).unwrap_or(false);
289 src_ok && tgt_ok
290 })
291 .count()
292 });
293
294 let inc5 = Rc::clone(&include);
295 let base_nbi = Rc::clone(&base.query_node_by_id);
296 let query_node_by_id = Rc::new(move |id: &V::Id| base_nbi(id).filter(|n| inc5(n)));
297
298 let inc6 = Rc::clone(&include);
299 let base_rbi = Rc::clone(&base.query_relationship_by_id);
300 let query_relationship_by_id = Rc::new(move |id: &V::Id| base_rbi(id).filter(|r| inc6(r)));
301
302 let inc7 = Rc::clone(&include);
303 let base_cont = Rc::clone(&base.query_containers);
304 let query_containers = Rc::new(move |element: &Pattern<V>| {
305 base_cont(element).into_iter().filter(|c| inc7(c)).collect()
306 });
307
308 GraphQuery {
309 query_nodes,
310 query_relationships,
311 query_incident_rels,
312 query_source,
313 query_target,
314 query_degree,
315 query_node_by_id,
316 query_relationship_by_id,
317 query_containers,
318 }
319}
320
/// Restricts `base` to the elements that satisfy `include` (thread-safe
/// build: closures are shared through `Arc` and are `Send + Sync`).
///
/// The returned query is lazy: every closure calls into `base` and filters its
/// answer on each invocation; nothing is cached.
///
/// * `query_nodes`, `query_relationships`, `query_containers`: the base
///   results, keeping only elements for which `include` is true.
/// * `query_node_by_id`, `query_relationship_by_id`: the base lookup, turned
///   into `None` when the element found is not included.
/// * `query_incident_rels`, `query_degree`: the base incident relationships,
///   keeping only those whose source *and* target both resolve and are
///   included. `query_degree` counts exactly that list.
/// * `query_source`, `query_target`: shared with `base`, unfiltered.
///
/// NOTE(review): incident relationships are filtered on their endpoints only,
/// while `query_relationships` / `query_relationship_by_id` are filtered on
/// the relationship itself only. A relationship rejected by `include` can
/// therefore still show up in `query_incident_rels`. Confirm this asymmetry
/// is intended.
#[cfg(feature = "thread-safe")]
#[allow(clippy::type_complexity)]
pub fn frame_query<V>(
    include: std::sync::Arc<dyn Fn(&Pattern<V>) -> bool + Send + Sync>,
    base: GraphQuery<V>,
) -> GraphQuery<V>
where
    V: GraphValue + Clone + Send + Sync + 'static,
    V::Id: Clone + Send + Sync + 'static,
{
    use std::sync::Arc;

    // One predicate shared by `query_incident_rels` and `query_degree`, so the
    // two answers cannot drift apart. `&&` short-circuits: the target lookup
    // is skipped once the source has already failed.
    let endpoints_included: Arc<dyn Fn(&Pattern<V>) -> bool + Send + Sync> = {
        let include = Arc::clone(&include);
        let source_of = Arc::clone(&base.query_source);
        let target_of = Arc::clone(&base.query_target);
        Arc::new(move |rel: &Pattern<V>| {
            matches!(source_of(rel), Some(ref s) if include(s))
                && matches!(target_of(rel), Some(ref t) if include(t))
        })
    };

    let query_nodes = {
        let include = Arc::clone(&include);
        let nodes = Arc::clone(&base.query_nodes);
        Arc::new(move || nodes().into_iter().filter(|n| include(n)).collect())
    };

    let query_relationships = {
        let include = Arc::clone(&include);
        let relationships = Arc::clone(&base.query_relationships);
        Arc::new(move || relationships().into_iter().filter(|r| include(r)).collect())
    };

    let query_incident_rels = {
        let incident = Arc::clone(&base.query_incident_rels);
        let keep = Arc::clone(&endpoints_included);
        Arc::new(move |node: &Pattern<V>| {
            incident(node).into_iter().filter(|rel| keep(rel)).collect()
        })
    };

    let query_degree = {
        let incident = Arc::clone(&base.query_incident_rels);
        let keep = endpoints_included;
        Arc::new(move |node: &Pattern<V>| {
            incident(node).into_iter().filter(|rel| keep(rel)).count()
        })
    };

    let query_node_by_id = {
        let include = Arc::clone(&include);
        let lookup = Arc::clone(&base.query_node_by_id);
        Arc::new(move |id: &V::Id| lookup(id).filter(|n| include(n)))
    };

    let query_relationship_by_id = {
        let include = Arc::clone(&include);
        let lookup = Arc::clone(&base.query_relationship_by_id);
        Arc::new(move |id: &V::Id| lookup(id).filter(|r| include(r)))
    };

    // Last use of `include`: moved into the closure instead of cloned.
    let query_containers = {
        let containers = Arc::clone(&base.query_containers);
        Arc::new(move |element: &Pattern<V>| {
            containers(element).into_iter().filter(|c| include(c)).collect()
        })
    };

    GraphQuery {
        query_nodes,
        query_relationships,
        query_incident_rels,
        query_source: base.query_source,
        query_target: base.query_target,
        query_degree,
        query_node_by_id,
        query_relationship_by_id,
        query_containers,
    }
}
405
406#[cfg(not(feature = "thread-safe"))]
425pub fn memoize_incident_rels<V>(base: GraphQuery<V>) -> GraphQuery<V>
426where
427 V: GraphValue + Clone + 'static,
428 V::Id: Clone + Eq + std::hash::Hash + 'static,
429{
430 use std::rc::Rc;
431
432 let nodes = (base.query_nodes)();
434 let mut cache: HashMap<V::Id, Vec<Pattern<V>>> = HashMap::new();
435 for node in &nodes {
436 let id = node.value.identify().clone();
437 let rels = (base.query_incident_rels)(node);
438 cache.insert(id, rels);
439 }
440 let cache = Rc::new(cache);
441
442 let cache1 = Rc::clone(&cache);
443 let query_incident_rels = Rc::new(move |node: &Pattern<V>| {
444 cache1
445 .get(node.value.identify())
446 .cloned()
447 .unwrap_or_default()
448 });
449
450 let cache2 = Rc::clone(&cache);
451 let query_degree = Rc::new(move |node: &Pattern<V>| {
452 cache2
453 .get(node.value.identify())
454 .map(|v| v.len())
455 .unwrap_or(0)
456 });
457
458 GraphQuery {
459 query_nodes: base.query_nodes,
460 query_relationships: base.query_relationships,
461 query_incident_rels,
462 query_source: base.query_source,
463 query_target: base.query_target,
464 query_degree,
465 query_node_by_id: base.query_node_by_id,
466 query_relationship_by_id: base.query_relationship_by_id,
467 query_containers: base.query_containers,
468 }
469}
470
/// Returns `base` with `query_incident_rels` and `query_degree` answered from
/// a table that is built once, when this function is called (thread-safe
/// build: the table is shared through an `Arc`).
///
/// Every node returned by `base.query_nodes` is asked once for its incident
/// relationships, and the answer is stored under the node's id (if two nodes
/// share an id, the one listed last wins). Afterwards:
/// * `query_incident_rels(n)` returns a clone of the stored list for `n`'s id,
///   or an empty `Vec` when that id is not in the table;
/// * `query_degree(n)` returns the stored list's length, or `0` when absent.
///
/// All other queries are passed through from `base` unchanged. The table is a
/// snapshot: if what `base` reports can change later, these two queries will
/// not see it.
#[cfg(feature = "thread-safe")]
pub fn memoize_incident_rels<V>(base: GraphQuery<V>) -> GraphQuery<V>
where
    V: GraphValue + Clone + Send + Sync + 'static,
    V::Id: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
{
    use std::sync::Arc;

    // One base call per node, keyed by id. `collect` inserts in iteration
    // order, so a repeated id keeps its last entry.
    let table: Arc<HashMap<V::Id, Vec<Pattern<V>>>> = Arc::new(
        (base.query_nodes)()
            .iter()
            .map(|node| {
                let key = node.value.identify().clone();
                (key, (base.query_incident_rels)(node))
            })
            .collect(),
    );

    let shared_table = Arc::clone(&table);
    let incident_rels = Arc::new(move |node: &Pattern<V>| {
        shared_table
            .get(node.value.identify())
            .map_or_else(Vec::new, Vec::clone)
    });

    let degree = Arc::new(move |node: &Pattern<V>| {
        table.get(node.value.identify()).map_or(0, Vec::len)
    });

    GraphQuery {
        query_incident_rels: incident_rels,
        query_degree: degree,
        ..base
    }
}