1#![allow(clippy::unwrap_used)]
16#![allow(dead_code)]
17pub mod match_path;
18pub mod solution;
19
20use std::{
21 collections::{BTreeMap, HashMap, HashSet},
22 sync::Arc,
23};
24
25use async_stream::stream;
26use drasi_query_ast::ast::{NodeMatch, RelationMatch};
27use futures::{stream::StreamExt, Stream};
28#[allow(unused_imports)]
29use tokio::{
30 sync::Semaphore,
31 task::{JoinError, JoinHandle},
32};
33
34use self::solution::MatchPathSolution;
35use crate::{
36 evaluation::{context::QueryVariables, EvaluationError},
37 interface::{ElementIndex, ElementResult, ElementStream, IndexError, QueryClock},
38 models::Element,
39};
40
/// Maximum number of partial solutions expanded at the same time under the
/// `parallel_solver` feature; it is the permit count of the semaphore created in
/// `create_solution_stream`, which each spawned solver task acquires before it runs.
#[cfg(feature = "parallel_solver")]
const MAX_CONCURRENT_SOLUTIONS: usize = 10;
43
44#[derive(Debug, Clone)]
45pub struct MatchSolveContext<'a> {
46 pub variables: &'a QueryVariables,
47 pub clock: Arc<dyn QueryClock>,
48}
49
50impl<'a> MatchSolveContext<'a> {
51 pub fn new(variables: &'a QueryVariables, clock: Arc<dyn QueryClock>) -> MatchSolveContext<'a> {
52 MatchSolveContext { variables, clock }
53 }
54}
55
/// Which way to step from an already-placed element towards a neighbouring slot.
///
/// In `get_adjacent_elements`, for a node `Outward` queries the index by inbound
/// reference and `Inward` by outbound reference; for a relation `Outward` follows its
/// `out_node` and `Inward` its `in_node`.
///
/// It is a fieldless value type, so it is `Copy`; `Debug` makes it usable in
/// diagnostics and tracing output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SolveDirection {
    /// Towards a slot listed in the current slot's `out_slots`.
    Outward,
    /// Towards a slot listed in the current slot's `in_slots`.
    Inward,
}
60
61pub struct MatchPathSolver {
62 element_index: Arc<dyn ElementIndex>,
63}
64
65impl MatchPathSolver {
66 pub fn new(element_index: Arc<dyn ElementIndex>) -> MatchPathSolver {
67 MatchPathSolver { element_index }
68 }
69
70 #[tracing::instrument(skip_all, err, level = "debug")]
71 pub async fn solve(
72 &self,
73 path: Arc<match_path::MatchPath>,
74 anchor_element: Arc<Element>,
75 anchor_slot: usize,
76 ) -> Result<HashMap<u64, solution::MatchPathSolution>, EvaluationError> {
77 let total_slots = path.slots.len();
78 let mut start_solution = MatchPathSolution::new(total_slots, anchor_slot);
79 start_solution.enqueue_slot(anchor_slot, Some(anchor_element));
80
81 let sol_stream =
82 create_solution_stream(start_solution, path.clone(), self.element_index.clone()).await;
83
84 let mut result = HashMap::new();
85 tokio::pin!(sol_stream);
86
87 while let Some(o) = sol_stream.next().await {
88 match o {
89 Ok((hash, solution)) => {
90 result.insert(hash, solution);
91 }
92 Err(e) => return Err(e),
93 }
94 }
95
96 Ok(result)
97 }
98}
99
/// Messages sent over the command channel that drives `create_solution_stream`.
// `Panic` is only constructed when the `parallel_solver` feature is enabled,
// hence the `dead_code` allowance.
#[allow(dead_code)]
enum SolutionStreamCommand {
    /// A partially solved path that still needs expanding.
    Partial(MatchPathSolution),
    /// A fully solved path together with its solution signature.
    Complete((u64, MatchPathSolution)),
    /// Expansion failed; the stream yields this error and stops.
    Error(EvaluationError),
    /// A spawned solver task failed to join (`parallel_solver` only); the stream re-panics.
    Panic(JoinError),
    /// A partial solution exhausted its slot cursors without completing.
    Unsolvable,
}
108
109async fn create_solution_stream(
110 initial_sol: MatchPathSolution,
111 path: Arc<match_path::MatchPath>,
112 element_index: Arc<dyn ElementIndex>,
113) -> impl Stream<Item = Result<(u64, MatchPathSolution), EvaluationError>> {
114 #[cfg(feature = "parallel_solver")]
115 let permits = Arc::new(Semaphore::new(MAX_CONCURRENT_SOLUTIONS));
116
117 stream! {
118 let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel::<SolutionStreamCommand>();
119 cmd_tx.send(SolutionStreamCommand::Partial(initial_sol)).unwrap();
120 let mut inflight = 0;
121
122 #[cfg(feature = "parallel_solver")]
123 let (task_tx, mut task_rx) = tokio::sync::mpsc::unbounded_channel::<JoinHandle<()>>();
124
125 #[cfg(feature = "parallel_solver")]
126 {
127 let cmd_tx2 = cmd_tx.clone();
128 tokio::spawn(async move {
129 while let Some(task) = task_rx.recv().await {
130 if let Err(err) = task.await {
131 cmd_tx2.send(SolutionStreamCommand::Panic(err)).unwrap();
132 }
133 }
134 });
135 }
136
137 while let Some(cmd) = cmd_rx.recv().await {
138 match cmd {
139 SolutionStreamCommand::Partial(solution) => {
140 inflight += 1;
141 let path = path.clone();
142 let element_index = element_index.clone();
143 let cmd_tx = cmd_tx.clone();
144
145 #[cfg(not(feature = "parallel_solver"))]
146 try_complete_solution(solution, path, element_index, cmd_tx).await;
147
148 #[cfg(feature = "parallel_solver")]
149 {
150 let permits = permits.clone();
151 let task = tokio::spawn(async move {
152 let _permit = permits.acquire().await.unwrap();
153 try_complete_solution(solution, path, element_index, cmd_tx).await;
154 });
155 task_tx.send(task).unwrap();
156 }
157 },
158 SolutionStreamCommand::Complete((hash, solution)) => {
159 inflight -= 1;
160 yield Ok((hash, solution));
161 },
162 SolutionStreamCommand::Error(e) => {
163 yield Err(e);
164 break;
165 },
166 SolutionStreamCommand::Panic(e) => {
167 panic!("Error in solution task: {e:?}");
168 },
169 SolutionStreamCommand::Unsolvable => {
170 inflight -= 1;
171 },
172 }
173
174 if inflight == 0 {
175 break;
176 }
177 }
178 }
179}
180
181#[tracing::instrument(skip_all, level = "debug")]
182async fn try_complete_solution(
183 mut solution: MatchPathSolution,
184 path: Arc<match_path::MatchPath>,
185 element_index: Arc<dyn ElementIndex>,
186 cmd_tx: tokio::sync::mpsc::UnboundedSender<SolutionStreamCommand>,
187) {
188 while let Some((slot_num, element)) = solution.slot_cursors.pop_front() {
189 solution.mark_slot_solved(slot_num, element.clone());
190
191 if let Some(hash) = solution.get_solution_signature() {
192 cmd_tx
193 .send(SolutionStreamCommand::Complete((hash, solution)))
194 .unwrap();
195 return;
196 }
197
198 let slot = &path.slots[slot_num];
199 let mut alt_by_slot = HashMap::new();
200
201 for out_slot in &slot.out_slots {
202 if solution.is_slot_solved(*out_slot) {
203 continue;
204 }
205
206 let adjacent_elements = alt_by_slot.entry(*out_slot).or_insert_with(Vec::new);
207 let mut found_adjacent = false;
208
209 if let Some(element) = &element {
210 let mut adjacent_stream = match get_adjacent_elements(
211 element_index.clone(),
212 element.clone(),
213 *out_slot,
214 SolveDirection::Outward,
215 )
216 .await
217 {
218 Ok(s) => s,
219 Err(e) => {
220 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
221 return;
222 }
223 };
224
225 while let Some(adjacent_element) = adjacent_stream.next().await {
226 found_adjacent = true;
227 match adjacent_element {
228 Ok(adjacent_element) => adjacent_elements.push(Some(adjacent_element)),
229 Err(e) => {
230 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
231 return;
232 }
233 }
234 }
235 }
236
237 if path.slots[*out_slot].optional && !found_adjacent {
238 adjacent_elements.push(None);
239 }
240 }
241
242 for in_slot in &slot.in_slots {
243 if solution.is_slot_solved(*in_slot) {
244 continue;
245 }
246
247 let adjacent_elements = alt_by_slot.entry(*in_slot).or_insert_with(Vec::new);
248 let mut found_adjacent = false;
249
250 if let Some(element) = &element {
251 let mut adjacent_stream = match get_adjacent_elements(
252 element_index.clone(),
253 element.clone(),
254 *in_slot,
255 SolveDirection::Inward,
256 )
257 .await
258 {
259 Ok(s) => s,
260 Err(e) => {
261 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
262 return;
263 }
264 };
265
266 while let Some(adjacent_element) = adjacent_stream.next().await {
267 found_adjacent = true;
268 match adjacent_element {
269 Ok(adjacent_element) => adjacent_elements.push(Some(adjacent_element)),
270 Err(e) => {
271 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
272 return;
273 }
274 }
275 }
276 }
277
278 if path.slots[*in_slot].optional && !found_adjacent {
279 adjacent_elements.push(None);
280 }
281 }
282
283 let mut pointers = BTreeMap::new();
284
285 for (slot, adjacent_elements) in &mut alt_by_slot {
286 match adjacent_elements.len() {
287 0 => {}
288 1 => {
289 solution.enqueue_slot(*slot, adjacent_elements.pop().unwrap());
290 }
291 _ => {
292 pointers.insert(*slot, 0);
293 }
294 }
295 }
296
297 if pointers.is_empty() {
298 continue;
299 }
300
301 let mut permutations = vec![pointers];
302
303 while let Some(mut p) = permutations.pop() {
304 let mut alt_solution = solution.clone();
305
306 for (slot, pointer) in &p {
307 if let Some(adjacent_element) = alt_by_slot.get(slot).unwrap().get(*pointer) {
308 alt_solution.enqueue_slot(*slot, adjacent_element.clone());
309 }
310 }
311
312 cmd_tx
313 .send(SolutionStreamCommand::Partial(alt_solution))
314 .unwrap();
315
316 for (slot, pointer) in &mut p {
317 if *pointer < alt_by_slot.get(slot).unwrap().len() - 1 {
318 *pointer += 1;
319 permutations.push(p);
320 break;
321 } else {
322 *pointer = 0;
323 }
324 }
325 }
326 }
327 cmd_tx.send(SolutionStreamCommand::Unsolvable).unwrap();
328}
329
330#[tracing::instrument(skip_all, err, level = "debug")]
331async fn get_adjacent_elements(
332 element_index: Arc<dyn ElementIndex>,
333 element: Arc<Element>,
334 target_slot: usize,
335 direction: SolveDirection,
336) -> Result<ElementStream, IndexError> {
337 match element.as_ref() {
338 Element::Node {
339 metadata,
340 properties: _,
341 } => match direction {
342 SolveDirection::Outward => Ok(element_index
343 .get_slot_elements_by_inbound(target_slot, &metadata.reference)
344 .await?),
345 SolveDirection::Inward => Ok(element_index
346 .get_slot_elements_by_outbound(target_slot, &metadata.reference)
347 .await?),
348 },
349 Element::Relation {
350 metadata: _,
351 in_node,
352 out_node,
353 properties: _,
354 } => {
355 let adjacent_ref = match direction {
356 SolveDirection::Outward => out_node,
357 SolveDirection::Inward => in_node,
358 };
359
360 match element_index
361 .get_slot_element_by_ref(target_slot, adjacent_ref)
362 .await?
363 {
364 Some(adjacent_element) => Ok(Box::pin(tokio_stream::once::<ElementResult>(Ok(
365 adjacent_element,
366 )))),
367 None => Ok(Box::pin(tokio_stream::empty::<ElementResult>())),
368 }
369 }
370 }
371}
372
373fn merge_node_match<'b>(
374 mtch: &NodeMatch,
375 slots: &'b mut Vec<match_path::MatchPathSlot>,
376 alias_map: &'b mut HashMap<Arc<str>, usize>,
377 path_index: usize,
378 optional: bool,
379) -> Result<usize, EvaluationError> {
380 match &mtch.annotation.name {
381 Some(alias) => {
382 if let Some(slot_num) = alias_map.get(alias) {
383 slots[*slot_num].optional = optional && slots[*slot_num].optional;
384 slots[*slot_num].paths.insert(path_index);
385 Ok(*slot_num)
386 } else {
387 slots.push(match_path::MatchPathSlot {
388 spec: match_path::SlotElementSpec::from_node_match(mtch),
389 in_slots: Vec::new(),
390 out_slots: Vec::new(),
391 optional,
392 paths: HashSet::from([path_index]),
393 });
394 alias_map.insert(alias.clone(), slots.len() - 1);
395 Ok(slots.len() - 1)
396 }
397 }
398 None => {
399 slots.push(match_path::MatchPathSlot {
400 spec: match_path::SlotElementSpec::from_node_match(mtch),
401 in_slots: Vec::new(),
402 out_slots: Vec::new(),
403 optional,
404 paths: HashSet::from([path_index]),
405 });
406 Ok(slots.len() - 1)
407 }
408 }
409}
410
411fn merge_relation_match<'b>(
412 mtch: &RelationMatch,
413 slots: &'b mut Vec<match_path::MatchPathSlot>,
414 alias_map: &'b mut HashMap<Arc<str>, usize>,
415 path_index: usize,
416 optional: bool,
417) -> Result<usize, EvaluationError> {
418 match &mtch.annotation.name {
419 Some(alias) => {
420 if let Some(slot_num) = alias_map.get(alias) {
421 slots[*slot_num].optional = optional && slots[*slot_num].optional;
422 slots[*slot_num].paths.insert(path_index);
423 Ok(*slot_num)
424 } else {
425 slots.push(match_path::MatchPathSlot {
426 spec: match_path::SlotElementSpec::from_relation_match(mtch),
427 in_slots: Vec::new(),
428 out_slots: Vec::new(),
429 optional,
430 paths: HashSet::from([path_index]),
431 });
432 alias_map.insert(alias.clone(), slots.len() - 1);
433 Ok(slots.len() - 1)
434 }
435 }
436 None => {
437 slots.push(match_path::MatchPathSlot {
438 spec: match_path::SlotElementSpec::from_relation_match(mtch),
439 in_slots: Vec::new(),
440 out_slots: Vec::new(),
441 optional,
442 paths: HashSet::from([path_index]),
443 });
444 Ok(slots.len() - 1)
445 }
446 }
447}