drasi_core/path_solver/
mod.rs

1#![allow(clippy::unwrap_used)]
2// Copyright 2024 The Drasi Authors.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15#![allow(dead_code)]
16pub mod match_path;
17pub mod solution;
18
19use std::{
20    collections::{BTreeMap, HashMap},
21    sync::Arc,
22};
23
24use async_stream::stream;
25use drasi_query_ast::ast::{NodeMatch, RelationMatch};
26use futures::{stream::StreamExt, Stream};
27#[allow(unused_imports)]
28use tokio::{
29    sync::Semaphore,
30    task::{JoinError, JoinHandle},
31};
32
33use self::solution::MatchPathSolution;
34use crate::{
35    evaluation::{context::QueryVariables, EvaluationError},
36    interface::{ElementIndex, ElementResult, ElementStream, IndexError, QueryClock},
37    models::Element,
38};
39
/// Number of permits in the semaphore that bounds how many spawned `try_complete_solution`
/// tasks may run at the same time when the `parallel_solver` feature is enabled.
#[cfg(feature = "parallel_solver")]
const MAX_CONCURRENT_SOLUTIONS: usize = 10;
42
43#[derive(Debug, Clone)]
44pub struct MatchSolveContext<'a> {
45    pub variables: &'a QueryVariables,
46    pub clock: Arc<dyn QueryClock>,
47}
48
49impl<'a> MatchSolveContext<'a> {
50    pub fn new(variables: &'a QueryVariables, clock: Arc<dyn QueryClock>) -> MatchSolveContext<'a> {
51        MatchSolveContext { variables, clock }
52    }
53}
54
/// Which side of the current element to look on when fetching adjacent slot elements.
///
/// For a node, `Outward` queries the index by inbound reference and `Inward` by outbound
/// reference; for a relation, `Outward` follows `out_node` and `Inward` follows `in_node`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SolveDirection {
    Outward,
    Inward,
}
59
60pub struct MatchPathSolver {
61    element_index: Arc<dyn ElementIndex>,
62}
63
64impl MatchPathSolver {
65    pub fn new(element_index: Arc<dyn ElementIndex>) -> MatchPathSolver {
66        MatchPathSolver { element_index }
67    }
68
69    #[tracing::instrument(skip_all, err)]
70    pub async fn solve(
71        &self,
72        path: Arc<match_path::MatchPath>,
73        anchor_element: Arc<Element>,
74        anchor_slot: usize,
75    ) -> Result<HashMap<u64, solution::MatchPathSolution>, EvaluationError> {
76        let total_slots = path.slots.len();
77        let mut start_solution = MatchPathSolution::new(total_slots);
78        start_solution.enqueue_slot(anchor_slot, anchor_element);
79
80        let sol_stream =
81            create_solution_stream(start_solution, path.clone(), self.element_index.clone()).await;
82
83        let mut result = HashMap::new();
84        tokio::pin!(sol_stream);
85
86        while let Some(o) = sol_stream.next().await {
87            match o {
88                Ok((hash, solution)) => {
89                    result.insert(hash, solution);
90                }
91                Err(e) => return Err(e),
92            }
93        }
94
95        Ok(result)
96    }
97}
98
99#[allow(dead_code)]
100enum SolutionStreamCommand {
101    Partial(MatchPathSolution),
102    Complete((u64, MatchPathSolution)),
103    Error(EvaluationError),
104    Panic(JoinError),
105    Unsolvable,
106}
107
108async fn create_solution_stream(
109    initial_sol: MatchPathSolution,
110    path: Arc<match_path::MatchPath>,
111    element_index: Arc<dyn ElementIndex>,
112) -> impl Stream<Item = Result<(u64, MatchPathSolution), EvaluationError>> {
113    #[cfg(feature = "parallel_solver")]
114    let permits = Arc::new(Semaphore::new(MAX_CONCURRENT_SOLUTIONS));
115
116    stream! {
117        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel::<SolutionStreamCommand>();
118        cmd_tx.send(SolutionStreamCommand::Partial(initial_sol)).unwrap();
119        let mut inflight = 0;
120
121        #[cfg(feature = "parallel_solver")]
122        let (task_tx, mut task_rx) = tokio::sync::mpsc::unbounded_channel::<JoinHandle<()>>();
123
124        #[cfg(feature = "parallel_solver")]
125        {
126            let cmd_tx2 = cmd_tx.clone();
127            tokio::spawn(async move {
128                while let Some(task) = task_rx.recv().await {
129                    if let Err(err) = task.await {
130                        cmd_tx2.send(SolutionStreamCommand::Panic(err)).unwrap();
131                    }
132                }
133            });
134        }
135
136        while let Some(cmd) = cmd_rx.recv().await {
137            match cmd {
138                SolutionStreamCommand::Partial(solution) => {
139                    inflight += 1;
140                    let path = path.clone();
141                    let element_index = element_index.clone();
142                    let cmd_tx = cmd_tx.clone();
143
144                    #[cfg(not(feature = "parallel_solver"))]
145                    try_complete_solution(solution, path, element_index, cmd_tx).await;
146
147                    #[cfg(feature = "parallel_solver")]
148                    {
149                        let permits = permits.clone();
150                        let task = tokio::spawn(async move {
151                            let _permit = permits.acquire().await.unwrap();
152                            try_complete_solution(solution, path, element_index, cmd_tx).await;
153                        });
154                        task_tx.send(task).unwrap();
155                    }
156                },
157                SolutionStreamCommand::Complete((hash, solution)) => {
158                    inflight -= 1;
159                    yield Ok((hash, solution));
160                },
161                SolutionStreamCommand::Error(e) => {
162                    yield Err(e);
163                    break;
164                },
165                SolutionStreamCommand::Panic(e) => {
166                    panic!("Error in solution task: {:?}", e);
167                },
168                SolutionStreamCommand::Unsolvable => {
169                    inflight -= 1;
170                },
171            }
172
173            if inflight == 0 {
174                break;
175            }
176        }
177    }
178}
179
180#[tracing::instrument(skip_all)]
181async fn try_complete_solution(
182    mut solution: MatchPathSolution,
183    path: Arc<match_path::MatchPath>,
184    element_index: Arc<dyn ElementIndex>,
185    cmd_tx: tokio::sync::mpsc::UnboundedSender<SolutionStreamCommand>,
186) {
187    while let Some((slot_num, element)) = solution.slot_cursors.pop_front() {
188        solution.mark_slot_solved(slot_num, element.clone());
189
190        if let Some(hash) = solution.get_solution_signature() {
191            cmd_tx
192                .send(SolutionStreamCommand::Complete((hash, solution)))
193                .unwrap();
194            return;
195        }
196
197        let slot = &path.slots[slot_num];
198        let mut alt_by_slot = HashMap::new();
199
200        for out_slot in &slot.out_slots {
201            if solution.is_slot_solved(*out_slot) {
202                continue;
203            }
204            let mut adjacent_stream = match get_adjacent_elements(
205                element_index.clone(),
206                element.clone(),
207                *out_slot,
208                SolveDirection::Outward,
209            )
210            .await
211            {
212                Ok(s) => s,
213                Err(e) => {
214                    cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
215                    return;
216                }
217            };
218            let adjacent_elements = alt_by_slot.entry(*out_slot).or_insert_with(Vec::new);
219
220            while let Some(adjacent_element) = adjacent_stream.next().await {
221                match adjacent_element {
222                    Ok(adjacent_element) => adjacent_elements.push(adjacent_element),
223                    Err(e) => {
224                        cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
225                        return;
226                    }
227                }
228            }
229        }
230
231        for in_slot in &slot.in_slots {
232            if solution.is_slot_solved(*in_slot) {
233                continue;
234            }
235
236            let mut adjacent_stream = match get_adjacent_elements(
237                element_index.clone(),
238                element.clone(),
239                *in_slot,
240                SolveDirection::Inward,
241            )
242            .await
243            {
244                Ok(s) => s,
245                Err(e) => {
246                    cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
247                    return;
248                }
249            };
250            let adjacent_elements = alt_by_slot.entry(*in_slot).or_insert_with(Vec::new);
251
252            while let Some(adjacent_element) = adjacent_stream.next().await {
253                match adjacent_element {
254                    Ok(adjacent_element) => adjacent_elements.push(adjacent_element),
255                    Err(e) => {
256                        cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
257                        return;
258                    }
259                }
260            }
261        }
262
263        let mut pointers = BTreeMap::new();
264
265        for (slot, adjacent_elements) in &mut alt_by_slot {
266            match adjacent_elements.len() {
267                0 => {}
268                1 => {
269                    solution.enqueue_slot(*slot, adjacent_elements.pop().unwrap());
270                }
271                _ => {
272                    pointers.insert(*slot, 0);
273                }
274            }
275        }
276
277        if pointers.is_empty() {
278            continue;
279        }
280
281        let mut permutations = vec![pointers];
282
283        while let Some(mut p) = permutations.pop() {
284            let mut alt_solution = solution.clone();
285
286            for (slot, pointer) in &p {
287                if let Some(adjacent_element) = alt_by_slot.get(slot).unwrap().get(*pointer) {
288                    alt_solution.enqueue_slot(*slot, adjacent_element.clone());
289                }
290            }
291
292            cmd_tx
293                .send(SolutionStreamCommand::Partial(alt_solution))
294                .unwrap();
295
296            for (slot, pointer) in &mut p {
297                if *pointer < alt_by_slot.get(slot).unwrap().len() - 1 {
298                    *pointer += 1;
299                    permutations.push(p);
300                    break;
301                } else {
302                    *pointer = 0;
303                }
304            }
305        }
306    }
307    cmd_tx.send(SolutionStreamCommand::Unsolvable).unwrap();
308}
309
310#[tracing::instrument(skip_all, err)]
311async fn get_adjacent_elements(
312    element_index: Arc<dyn ElementIndex>,
313    element: Arc<Element>,
314    target_slot: usize,
315    direction: SolveDirection,
316) -> Result<ElementStream, IndexError> {
317    match element.as_ref() {
318        Element::Node {
319            metadata,
320            properties: _,
321        } => match direction {
322            SolveDirection::Outward => Ok(element_index
323                .get_slot_elements_by_inbound(target_slot, &metadata.reference)
324                .await?),
325            SolveDirection::Inward => Ok(element_index
326                .get_slot_elements_by_outbound(target_slot, &metadata.reference)
327                .await?),
328        },
329        Element::Relation {
330            metadata: _,
331            in_node,
332            out_node,
333            properties: _,
334        } => {
335            let adjecent_ref = match direction {
336                SolveDirection::Outward => out_node,
337                SolveDirection::Inward => in_node,
338            };
339
340            match element_index
341                .get_slot_element_by_ref(target_slot, adjecent_ref)
342                .await?
343            {
344                Some(adjacent_element) => Ok(Box::pin(tokio_stream::once::<ElementResult>(Ok(
345                    adjacent_element,
346                )))),
347                None => Ok(Box::pin(tokio_stream::empty::<ElementResult>())),
348            }
349        }
350    }
351}
352
353fn merge_node_match<'b>(
354    mtch: &NodeMatch,
355    slots: &'b mut Vec<match_path::MatchPathSlot>,
356    alias_map: &'b mut HashMap<Arc<str>, usize>,
357) -> Result<usize, EvaluationError> {
358    match &mtch.annotation.name {
359        Some(alias) => {
360            if let Some(slot_num) = alias_map.get(alias) {
361                Ok(*slot_num)
362            } else {
363                slots.push(match_path::MatchPathSlot {
364                    spec: match_path::SlotElementSpec::from_node_match(mtch),
365                    in_slots: Vec::new(),
366                    out_slots: Vec::new(),
367                });
368                alias_map.insert(alias.clone(), slots.len() - 1);
369                Ok(slots.len() - 1)
370            }
371        }
372        None => {
373            slots.push(match_path::MatchPathSlot {
374                spec: match_path::SlotElementSpec::from_node_match(mtch),
375                in_slots: Vec::new(),
376                out_slots: Vec::new(),
377            });
378            Ok(slots.len() - 1)
379        }
380    }
381}
382
383fn merge_relation_match<'b>(
384    mtch: &RelationMatch,
385    slots: &'b mut Vec<match_path::MatchPathSlot>,
386    alias_map: &'b mut HashMap<Arc<str>, usize>,
387) -> Result<usize, EvaluationError> {
388    match &mtch.annotation.name {
389        Some(alias) => {
390            if let Some(slot_num) = alias_map.get(alias) {
391                Ok(*slot_num)
392            } else {
393                slots.push(match_path::MatchPathSlot {
394                    spec: match_path::SlotElementSpec::from_relation_match(mtch),
395                    in_slots: Vec::new(),
396                    out_slots: Vec::new(),
397                });
398                alias_map.insert(alias.clone(), slots.len() - 1);
399                Ok(slots.len() - 1)
400            }
401        }
402        None => {
403            slots.push(match_path::MatchPathSlot {
404                spec: match_path::SlotElementSpec::from_relation_match(mtch),
405                in_slots: Vec::new(),
406                out_slots: Vec::new(),
407            });
408            Ok(slots.len() - 1)
409        }
410    }
411}