1#![allow(clippy::unwrap_used)]
2#![allow(dead_code)]
16pub mod match_path;
17pub mod solution;
18
19use std::{
20 collections::{BTreeMap, HashMap},
21 sync::Arc,
22};
23
24use async_stream::stream;
25use drasi_query_ast::ast::{NodeMatch, RelationMatch};
26use futures::{stream::StreamExt, Stream};
27#[allow(unused_imports)]
28use tokio::{
29 sync::Semaphore,
30 task::{JoinError, JoinHandle},
31};
32
33use self::solution::MatchPathSolution;
34use crate::{
35 evaluation::{context::QueryVariables, EvaluationError},
36 interface::{ElementIndex, ElementResult, ElementStream, IndexError, QueryClock},
37 models::Element,
38};
39
/// Number of permits in the semaphore that gates spawned solution tasks, i.e. how many of them
/// may work on a partial solution at the same time. Only compiled with the `parallel_solver`
/// feature.
#[cfg(feature = "parallel_solver")]
const MAX_CONCURRENT_SOLUTIONS: usize = 10;
42
43#[derive(Debug, Clone)]
44pub struct MatchSolveContext<'a> {
45 pub variables: &'a QueryVariables,
46 pub clock: Arc<dyn QueryClock>,
47}
48
49impl<'a> MatchSolveContext<'a> {
50 pub fn new(variables: &'a QueryVariables, clock: Arc<dyn QueryClock>) -> MatchSolveContext<'a> {
51 MatchSolveContext { variables, clock }
52 }
53}
54
/// Direction in which the solver steps from a solved element to a neighbouring slot.
enum SolveDirection {
    /// Used for slots listed in the current slot's `out_slots`.
    Outward,
    /// Used for slots listed in the current slot's `in_slots`.
    Inward,
}
59
60pub struct MatchPathSolver {
61 element_index: Arc<dyn ElementIndex>,
62}
63
64impl MatchPathSolver {
65 pub fn new(element_index: Arc<dyn ElementIndex>) -> MatchPathSolver {
66 MatchPathSolver { element_index }
67 }
68
69 #[tracing::instrument(skip_all, err)]
70 pub async fn solve(
71 &self,
72 path: Arc<match_path::MatchPath>,
73 anchor_element: Arc<Element>,
74 anchor_slot: usize,
75 ) -> Result<HashMap<u64, solution::MatchPathSolution>, EvaluationError> {
76 let total_slots = path.slots.len();
77 let mut start_solution = MatchPathSolution::new(total_slots);
78 start_solution.enqueue_slot(anchor_slot, anchor_element);
79
80 let sol_stream =
81 create_solution_stream(start_solution, path.clone(), self.element_index.clone()).await;
82
83 let mut result = HashMap::new();
84 tokio::pin!(sol_stream);
85
86 while let Some(o) = sol_stream.next().await {
87 match o {
88 Ok((hash, solution)) => {
89 result.insert(hash, solution);
90 }
91 Err(e) => return Err(e),
92 }
93 }
94
95 Ok(result)
96 }
97}
98
99#[allow(dead_code)]
100enum SolutionStreamCommand {
101 Partial(MatchPathSolution),
102 Complete((u64, MatchPathSolution)),
103 Error(EvaluationError),
104 Panic(JoinError),
105 Unsolvable,
106}
107
108async fn create_solution_stream(
109 initial_sol: MatchPathSolution,
110 path: Arc<match_path::MatchPath>,
111 element_index: Arc<dyn ElementIndex>,
112) -> impl Stream<Item = Result<(u64, MatchPathSolution), EvaluationError>> {
113 #[cfg(feature = "parallel_solver")]
114 let permits = Arc::new(Semaphore::new(MAX_CONCURRENT_SOLUTIONS));
115
116 stream! {
117 let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel::<SolutionStreamCommand>();
118 cmd_tx.send(SolutionStreamCommand::Partial(initial_sol)).unwrap();
119 let mut inflight = 0;
120
121 #[cfg(feature = "parallel_solver")]
122 let (task_tx, mut task_rx) = tokio::sync::mpsc::unbounded_channel::<JoinHandle<()>>();
123
124 #[cfg(feature = "parallel_solver")]
125 {
126 let cmd_tx2 = cmd_tx.clone();
127 tokio::spawn(async move {
128 while let Some(task) = task_rx.recv().await {
129 if let Err(err) = task.await {
130 cmd_tx2.send(SolutionStreamCommand::Panic(err)).unwrap();
131 }
132 }
133 });
134 }
135
136 while let Some(cmd) = cmd_rx.recv().await {
137 match cmd {
138 SolutionStreamCommand::Partial(solution) => {
139 inflight += 1;
140 let path = path.clone();
141 let element_index = element_index.clone();
142 let cmd_tx = cmd_tx.clone();
143
144 #[cfg(not(feature = "parallel_solver"))]
145 try_complete_solution(solution, path, element_index, cmd_tx).await;
146
147 #[cfg(feature = "parallel_solver")]
148 {
149 let permits = permits.clone();
150 let task = tokio::spawn(async move {
151 let _permit = permits.acquire().await.unwrap();
152 try_complete_solution(solution, path, element_index, cmd_tx).await;
153 });
154 task_tx.send(task).unwrap();
155 }
156 },
157 SolutionStreamCommand::Complete((hash, solution)) => {
158 inflight -= 1;
159 yield Ok((hash, solution));
160 },
161 SolutionStreamCommand::Error(e) => {
162 yield Err(e);
163 break;
164 },
165 SolutionStreamCommand::Panic(e) => {
166 panic!("Error in solution task: {:?}", e);
167 },
168 SolutionStreamCommand::Unsolvable => {
169 inflight -= 1;
170 },
171 }
172
173 if inflight == 0 {
174 break;
175 }
176 }
177 }
178}
179
180#[tracing::instrument(skip_all)]
181async fn try_complete_solution(
182 mut solution: MatchPathSolution,
183 path: Arc<match_path::MatchPath>,
184 element_index: Arc<dyn ElementIndex>,
185 cmd_tx: tokio::sync::mpsc::UnboundedSender<SolutionStreamCommand>,
186) {
187 while let Some((slot_num, element)) = solution.slot_cursors.pop_front() {
188 solution.mark_slot_solved(slot_num, element.clone());
189
190 if let Some(hash) = solution.get_solution_signature() {
191 cmd_tx
192 .send(SolutionStreamCommand::Complete((hash, solution)))
193 .unwrap();
194 return;
195 }
196
197 let slot = &path.slots[slot_num];
198 let mut alt_by_slot = HashMap::new();
199
200 for out_slot in &slot.out_slots {
201 if solution.is_slot_solved(*out_slot) {
202 continue;
203 }
204 let mut adjacent_stream = match get_adjacent_elements(
205 element_index.clone(),
206 element.clone(),
207 *out_slot,
208 SolveDirection::Outward,
209 )
210 .await
211 {
212 Ok(s) => s,
213 Err(e) => {
214 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
215 return;
216 }
217 };
218 let adjacent_elements = alt_by_slot.entry(*out_slot).or_insert_with(Vec::new);
219
220 while let Some(adjacent_element) = adjacent_stream.next().await {
221 match adjacent_element {
222 Ok(adjacent_element) => adjacent_elements.push(adjacent_element),
223 Err(e) => {
224 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
225 return;
226 }
227 }
228 }
229 }
230
231 for in_slot in &slot.in_slots {
232 if solution.is_slot_solved(*in_slot) {
233 continue;
234 }
235
236 let mut adjacent_stream = match get_adjacent_elements(
237 element_index.clone(),
238 element.clone(),
239 *in_slot,
240 SolveDirection::Inward,
241 )
242 .await
243 {
244 Ok(s) => s,
245 Err(e) => {
246 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
247 return;
248 }
249 };
250 let adjacent_elements = alt_by_slot.entry(*in_slot).or_insert_with(Vec::new);
251
252 while let Some(adjacent_element) = adjacent_stream.next().await {
253 match adjacent_element {
254 Ok(adjacent_element) => adjacent_elements.push(adjacent_element),
255 Err(e) => {
256 cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
257 return;
258 }
259 }
260 }
261 }
262
263 let mut pointers = BTreeMap::new();
264
265 for (slot, adjacent_elements) in &mut alt_by_slot {
266 match adjacent_elements.len() {
267 0 => {}
268 1 => {
269 solution.enqueue_slot(*slot, adjacent_elements.pop().unwrap());
270 }
271 _ => {
272 pointers.insert(*slot, 0);
273 }
274 }
275 }
276
277 if pointers.is_empty() {
278 continue;
279 }
280
281 let mut permutations = vec![pointers];
282
283 while let Some(mut p) = permutations.pop() {
284 let mut alt_solution = solution.clone();
285
286 for (slot, pointer) in &p {
287 if let Some(adjacent_element) = alt_by_slot.get(slot).unwrap().get(*pointer) {
288 alt_solution.enqueue_slot(*slot, adjacent_element.clone());
289 }
290 }
291
292 cmd_tx
293 .send(SolutionStreamCommand::Partial(alt_solution))
294 .unwrap();
295
296 for (slot, pointer) in &mut p {
297 if *pointer < alt_by_slot.get(slot).unwrap().len() - 1 {
298 *pointer += 1;
299 permutations.push(p);
300 break;
301 } else {
302 *pointer = 0;
303 }
304 }
305 }
306 }
307 cmd_tx.send(SolutionStreamCommand::Unsolvable).unwrap();
308}
309
310#[tracing::instrument(skip_all, err)]
311async fn get_adjacent_elements(
312 element_index: Arc<dyn ElementIndex>,
313 element: Arc<Element>,
314 target_slot: usize,
315 direction: SolveDirection,
316) -> Result<ElementStream, IndexError> {
317 match element.as_ref() {
318 Element::Node {
319 metadata,
320 properties: _,
321 } => match direction {
322 SolveDirection::Outward => Ok(element_index
323 .get_slot_elements_by_inbound(target_slot, &metadata.reference)
324 .await?),
325 SolveDirection::Inward => Ok(element_index
326 .get_slot_elements_by_outbound(target_slot, &metadata.reference)
327 .await?),
328 },
329 Element::Relation {
330 metadata: _,
331 in_node,
332 out_node,
333 properties: _,
334 } => {
335 let adjecent_ref = match direction {
336 SolveDirection::Outward => out_node,
337 SolveDirection::Inward => in_node,
338 };
339
340 match element_index
341 .get_slot_element_by_ref(target_slot, adjecent_ref)
342 .await?
343 {
344 Some(adjacent_element) => Ok(Box::pin(tokio_stream::once::<ElementResult>(Ok(
345 adjacent_element,
346 )))),
347 None => Ok(Box::pin(tokio_stream::empty::<ElementResult>())),
348 }
349 }
350 }
351}
352
353fn merge_node_match<'b>(
354 mtch: &NodeMatch,
355 slots: &'b mut Vec<match_path::MatchPathSlot>,
356 alias_map: &'b mut HashMap<Arc<str>, usize>,
357) -> Result<usize, EvaluationError> {
358 match &mtch.annotation.name {
359 Some(alias) => {
360 if let Some(slot_num) = alias_map.get(alias) {
361 Ok(*slot_num)
362 } else {
363 slots.push(match_path::MatchPathSlot {
364 spec: match_path::SlotElementSpec::from_node_match(mtch),
365 in_slots: Vec::new(),
366 out_slots: Vec::new(),
367 });
368 alias_map.insert(alias.clone(), slots.len() - 1);
369 Ok(slots.len() - 1)
370 }
371 }
372 None => {
373 slots.push(match_path::MatchPathSlot {
374 spec: match_path::SlotElementSpec::from_node_match(mtch),
375 in_slots: Vec::new(),
376 out_slots: Vec::new(),
377 });
378 Ok(slots.len() - 1)
379 }
380 }
381}
382
383fn merge_relation_match<'b>(
384 mtch: &RelationMatch,
385 slots: &'b mut Vec<match_path::MatchPathSlot>,
386 alias_map: &'b mut HashMap<Arc<str>, usize>,
387) -> Result<usize, EvaluationError> {
388 match &mtch.annotation.name {
389 Some(alias) => {
390 if let Some(slot_num) = alias_map.get(alias) {
391 Ok(*slot_num)
392 } else {
393 slots.push(match_path::MatchPathSlot {
394 spec: match_path::SlotElementSpec::from_relation_match(mtch),
395 in_slots: Vec::new(),
396 out_slots: Vec::new(),
397 });
398 alias_map.insert(alias.clone(), slots.len() - 1);
399 Ok(slots.len() - 1)
400 }
401 }
402 None => {
403 slots.push(match_path::MatchPathSlot {
404 spec: match_path::SlotElementSpec::from_relation_match(mtch),
405 in_slots: Vec::new(),
406 out_slots: Vec::new(),
407 });
408 Ok(slots.len() - 1)
409 }
410 }
411}