1#![allow(clippy::unwrap_used)]
2#![allow(dead_code)]
16pub mod match_path;
17pub mod solution;
18
19use std::{
20 collections::{BTreeMap, HashMap, HashSet},
21 sync::Arc,
22};
23
24use async_stream::stream;
25use drasi_query_ast::ast::{NodeMatch, RelationMatch};
26use futures::{stream::StreamExt, Stream};
27#[allow(unused_imports)]
28use tokio::{
29 sync::Semaphore,
30 task::{JoinError, JoinHandle},
31};
32
33use self::solution::MatchPathSolution;
34use crate::{
35 evaluation::{context::QueryVariables, EvaluationError},
36 interface::{ElementIndex, ElementResult, ElementStream, IndexError, QueryClock},
37 models::Element,
38};
39
/// Number of permits in the semaphore that bounds how many spawned solver tasks may run
/// `try_complete_solution` at the same time (only with the `parallel_solver` feature).
#[cfg(feature = "parallel_solver")]
const MAX_CONCURRENT_SOLUTIONS: usize = 10;
42
43#[derive(Debug, Clone)]
44pub struct MatchSolveContext<'a> {
45 pub variables: &'a QueryVariables,
46 pub clock: Arc<dyn QueryClock>,
47}
48
49impl<'a> MatchSolveContext<'a> {
50 pub fn new(variables: &'a QueryVariables, clock: Arc<dyn QueryClock>) -> MatchSolveContext<'a> {
51 MatchSolveContext { variables, clock }
52 }
53}
54
/// Which side of an already-solved slot to expand from.
///
/// `Outward` is used when visiting a slot's `out_slots`, `Inward` when visiting its
/// `in_slots`. It is a plain tag, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SolveDirection {
    /// Expand toward the slot's `out_slots`.
    Outward,
    /// Expand toward the slot's `in_slots`.
    Inward,
}
59
60pub struct MatchPathSolver {
61 element_index: Arc<dyn ElementIndex>,
62}
63
64impl MatchPathSolver {
65 pub fn new(element_index: Arc<dyn ElementIndex>) -> MatchPathSolver {
66 MatchPathSolver { element_index }
67 }
68
69 #[tracing::instrument(skip_all, err, level = "debug")]
70 pub async fn solve(
71 &self,
72 path: Arc<match_path::MatchPath>,
73 anchor_element: Arc<Element>,
74 anchor_slot: usize,
75 ) -> Result<HashMap<u64, solution::MatchPathSolution>, EvaluationError> {
76 let total_slots = path.slots.len();
77 let mut start_solution = MatchPathSolution::new(total_slots, anchor_slot);
78 start_solution.enqueue_slot(anchor_slot, Some(anchor_element));
79
80 let sol_stream =
81 create_solution_stream(start_solution, path.clone(), self.element_index.clone()).await;
82
83 let mut result = HashMap::new();
84 tokio::pin!(sol_stream);
85
86 while let Some(o) = sol_stream.next().await {
87 match o {
88 Ok((hash, solution)) => {
89 result.insert(hash, solution);
90 }
91 Err(e) => return Err(e),
92 }
93 }
94
95 Ok(result)
96 }
97}
98
/// Messages sent from `try_complete_solution` (and, with `parallel_solver`, the task
/// watcher) to the driving loop in `create_solution_stream`.
// `Panic` is only constructed when the `parallel_solver` feature is enabled, hence the allow.
#[allow(dead_code)]
enum SolutionStreamCommand {
    /// A partially solved path that still needs to be expanded.
    Partial(MatchPathSolution),
    /// A completed solution together with its solution signature.
    Complete((u64, MatchPathSolution)),
    /// An index lookup failed; the stream yields this error and stops.
    Error(EvaluationError),
    /// A spawned solver task failed to join (`parallel_solver` only); the stream panics.
    Panic(JoinError),
    /// The partial solution ran out of slots to expand without completing.
    Unsolvable,
}
107
108async fn create_solution_stream(
109 initial_sol: MatchPathSolution,
110 path: Arc<match_path::MatchPath>,
111 element_index: Arc<dyn ElementIndex>,
112) -> impl Stream<Item = Result<(u64, MatchPathSolution), EvaluationError>> {
113 #[cfg(feature = "parallel_solver")]
114 let permits = Arc::new(Semaphore::new(MAX_CONCURRENT_SOLUTIONS));
115
116 stream! {
117 let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel::<SolutionStreamCommand>();
118 cmd_tx.send(SolutionStreamCommand::Partial(initial_sol)).unwrap();
119 let mut inflight = 0;
120
121 #[cfg(feature = "parallel_solver")]
122 let (task_tx, mut task_rx) = tokio::sync::mpsc::unbounded_channel::<JoinHandle<()>>();
123
124 #[cfg(feature = "parallel_solver")]
125 {
126 let cmd_tx2 = cmd_tx.clone();
127 tokio::spawn(async move {
128 while let Some(task) = task_rx.recv().await {
129 if let Err(err) = task.await {
130 cmd_tx2.send(SolutionStreamCommand::Panic(err)).unwrap();
131 }
132 }
133 });
134 }
135
136 while let Some(cmd) = cmd_rx.recv().await {
137 match cmd {
138 SolutionStreamCommand::Partial(solution) => {
139 inflight += 1;
140 let path = path.clone();
141 let element_index = element_index.clone();
142 let cmd_tx = cmd_tx.clone();
143
144 #[cfg(not(feature = "parallel_solver"))]
145 try_complete_solution(solution, path, element_index, cmd_tx).await;
146
147 #[cfg(feature = "parallel_solver")]
148 {
149 let permits = permits.clone();
150 let task = tokio::spawn(async move {
151 let _permit = permits.acquire().await.unwrap();
152 try_complete_solution(solution, path, element_index, cmd_tx).await;
153 });
154 task_tx.send(task).unwrap();
155 }
156 },
157 SolutionStreamCommand::Complete((hash, solution)) => {
158 inflight -= 1;
159 yield Ok((hash, solution));
160 },
161 SolutionStreamCommand::Error(e) => {
162 yield Err(e);
163 break;
164 },
165 SolutionStreamCommand::Panic(e) => {
166 panic!("Error in solution task: {e:?}");
167 },
168 SolutionStreamCommand::Unsolvable => {
169 inflight -= 1;
170 },
171 }
172
173 if inflight == 0 {
174 break;
175 }
176 }
177 }
178}
179
180#[tracing::instrument(skip_all, level = "debug")]
181async fn try_complete_solution(
182 mut solution: MatchPathSolution,
183 path: Arc<match_path::MatchPath>,
184 element_index: Arc<dyn ElementIndex>,
185 cmd_tx: tokio::sync::mpsc::UnboundedSender<SolutionStreamCommand>,
186) {
187 while let Some((slot_num, element)) = solution.slot_cursors.pop_front() {
188 solution.mark_slot_solved(slot_num, element.clone());
189
190 if let Some(hash) = solution.get_solution_signature() {
191 cmd_tx
192 .send(SolutionStreamCommand::Complete((hash, solution)))
193 .unwrap();
194 return;
195 }
196
197 let slot = &path.slots[slot_num];
198 let mut alt_by_slot = HashMap::new();
199
200 for out_slot in &slot.out_slots {
201 if solution.is_slot_solved(*out_slot) {
202 continue;
203 }
204
205 let adjacent_elements = alt_by_slot.entry(*out_slot).or_insert_with(Vec::new);
206 let mut found_adjacent = false;
207
208 if let Some(element) = &element {
209 let mut adjacent_stream = match get_adjacent_elements(
210 element_index.clone(),
211 element.clone(),
212 *out_slot,
213 SolveDirection::Outward,
214 )
215 .await
216 {
217 Ok(s) => s,
218 Err(e) => {
219 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
220 return;
221 }
222 };
223
224 while let Some(adjacent_element) = adjacent_stream.next().await {
225 found_adjacent = true;
226 match adjacent_element {
227 Ok(adjacent_element) => adjacent_elements.push(Some(adjacent_element)),
228 Err(e) => {
229 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
230 return;
231 }
232 }
233 }
234 }
235
236 if path.slots[*out_slot].optional && !found_adjacent {
237 adjacent_elements.push(None);
238 }
239 }
240
241 for in_slot in &slot.in_slots {
242 if solution.is_slot_solved(*in_slot) {
243 continue;
244 }
245
246 let adjacent_elements = alt_by_slot.entry(*in_slot).or_insert_with(Vec::new);
247 let mut found_adjacent = false;
248
249 if let Some(element) = &element {
250 let mut adjacent_stream = match get_adjacent_elements(
251 element_index.clone(),
252 element.clone(),
253 *in_slot,
254 SolveDirection::Inward,
255 )
256 .await
257 {
258 Ok(s) => s,
259 Err(e) => {
260 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
261 return;
262 }
263 };
264
265 while let Some(adjacent_element) = adjacent_stream.next().await {
266 found_adjacent = true;
267 match adjacent_element {
268 Ok(adjacent_element) => adjacent_elements.push(Some(adjacent_element)),
269 Err(e) => {
270 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
271 return;
272 }
273 }
274 }
275 }
276
277 if path.slots[*in_slot].optional && !found_adjacent {
278 adjacent_elements.push(None);
279 }
280 }
281
282 let mut pointers = BTreeMap::new();
283
284 for (slot, adjacent_elements) in &mut alt_by_slot {
285 match adjacent_elements.len() {
286 0 => {}
287 1 => {
288 solution.enqueue_slot(*slot, adjacent_elements.pop().unwrap());
289 }
290 _ => {
291 pointers.insert(*slot, 0);
292 }
293 }
294 }
295
296 if pointers.is_empty() {
297 continue;
298 }
299
300 let mut permutations = vec![pointers];
301
302 while let Some(mut p) = permutations.pop() {
303 let mut alt_solution = solution.clone();
304
305 for (slot, pointer) in &p {
306 if let Some(adjacent_element) = alt_by_slot.get(slot).unwrap().get(*pointer) {
307 alt_solution.enqueue_slot(*slot, adjacent_element.clone());
308 }
309 }
310
311 cmd_tx
312 .send(SolutionStreamCommand::Partial(alt_solution))
313 .unwrap();
314
315 for (slot, pointer) in &mut p {
316 if *pointer < alt_by_slot.get(slot).unwrap().len() - 1 {
317 *pointer += 1;
318 permutations.push(p);
319 break;
320 } else {
321 *pointer = 0;
322 }
323 }
324 }
325 }
326 cmd_tx.send(SolutionStreamCommand::Unsolvable).unwrap();
327}
328
329#[tracing::instrument(skip_all, err, level = "debug")]
330async fn get_adjacent_elements(
331 element_index: Arc<dyn ElementIndex>,
332 element: Arc<Element>,
333 target_slot: usize,
334 direction: SolveDirection,
335) -> Result<ElementStream, IndexError> {
336 match element.as_ref() {
337 Element::Node {
338 metadata,
339 properties: _,
340 } => match direction {
341 SolveDirection::Outward => Ok(element_index
342 .get_slot_elements_by_inbound(target_slot, &metadata.reference)
343 .await?),
344 SolveDirection::Inward => Ok(element_index
345 .get_slot_elements_by_outbound(target_slot, &metadata.reference)
346 .await?),
347 },
348 Element::Relation {
349 metadata: _,
350 in_node,
351 out_node,
352 properties: _,
353 } => {
354 let adjecent_ref = match direction {
355 SolveDirection::Outward => out_node,
356 SolveDirection::Inward => in_node,
357 };
358
359 match element_index
360 .get_slot_element_by_ref(target_slot, adjecent_ref)
361 .await?
362 {
363 Some(adjacent_element) => Ok(Box::pin(tokio_stream::once::<ElementResult>(Ok(
364 adjacent_element,
365 )))),
366 None => Ok(Box::pin(tokio_stream::empty::<ElementResult>())),
367 }
368 }
369 }
370}
371
372fn merge_node_match<'b>(
373 mtch: &NodeMatch,
374 slots: &'b mut Vec<match_path::MatchPathSlot>,
375 alias_map: &'b mut HashMap<Arc<str>, usize>,
376 path_index: usize,
377 optional: bool,
378) -> Result<usize, EvaluationError> {
379 match &mtch.annotation.name {
380 Some(alias) => {
381 if let Some(slot_num) = alias_map.get(alias) {
382 slots[*slot_num].optional = optional && slots[*slot_num].optional;
383 slots[*slot_num].paths.insert(path_index);
384 Ok(*slot_num)
385 } else {
386 slots.push(match_path::MatchPathSlot {
387 spec: match_path::SlotElementSpec::from_node_match(mtch),
388 in_slots: Vec::new(),
389 out_slots: Vec::new(),
390 optional,
391 paths: HashSet::from([path_index]),
392 });
393 alias_map.insert(alias.clone(), slots.len() - 1);
394 Ok(slots.len() - 1)
395 }
396 }
397 None => {
398 slots.push(match_path::MatchPathSlot {
399 spec: match_path::SlotElementSpec::from_node_match(mtch),
400 in_slots: Vec::new(),
401 out_slots: Vec::new(),
402 optional,
403 paths: HashSet::from([path_index]),
404 });
405 Ok(slots.len() - 1)
406 }
407 }
408}
409
410fn merge_relation_match<'b>(
411 mtch: &RelationMatch,
412 slots: &'b mut Vec<match_path::MatchPathSlot>,
413 alias_map: &'b mut HashMap<Arc<str>, usize>,
414 path_index: usize,
415 optional: bool,
416) -> Result<usize, EvaluationError> {
417 match &mtch.annotation.name {
418 Some(alias) => {
419 if let Some(slot_num) = alias_map.get(alias) {
420 slots[*slot_num].optional = optional && slots[*slot_num].optional;
421 slots[*slot_num].paths.insert(path_index);
422 Ok(*slot_num)
423 } else {
424 slots.push(match_path::MatchPathSlot {
425 spec: match_path::SlotElementSpec::from_relation_match(mtch),
426 in_slots: Vec::new(),
427 out_slots: Vec::new(),
428 optional,
429 paths: HashSet::from([path_index]),
430 });
431 alias_map.insert(alias.clone(), slots.len() - 1);
432 Ok(slots.len() - 1)
433 }
434 }
435 None => {
436 slots.push(match_path::MatchPathSlot {
437 spec: match_path::SlotElementSpec::from_relation_match(mtch),
438 in_slots: Vec::new(),
439 out_slots: Vec::new(),
440 optional,
441 paths: HashSet::from([path_index]),
442 });
443 Ok(slots.len() - 1)
444 }
445 }
446}