Skip to main content

drasi_core/path_solver/
mod.rs

1// Copyright 2025 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![allow(clippy::unwrap_used)]
16#![allow(dead_code)]
17pub mod match_path;
18pub mod solution;
19
20use std::{
21    collections::{BTreeMap, HashMap, HashSet},
22    sync::Arc,
23};
24
25use async_stream::stream;
26use drasi_query_ast::ast::{NodeMatch, RelationMatch};
27use futures::{stream::StreamExt, Stream};
28#[allow(unused_imports)]
29use tokio::{
30    sync::Semaphore,
31    task::{JoinError, JoinHandle},
32};
33
34use self::solution::MatchPathSolution;
35use crate::{
36    evaluation::{context::QueryVariables, EvaluationError},
37    interface::{ElementIndex, ElementResult, ElementStream, IndexError, QueryClock},
38    models::Element,
39};
40
/// Number of permits in the semaphore that gates concurrently running solver
/// tasks when the `parallel_solver` feature is enabled.
#[cfg(feature = "parallel_solver")]
const MAX_CONCURRENT_SOLUTIONS: usize = 10;
43
44#[derive(Debug, Clone)]
45pub struct MatchSolveContext<'a> {
46    pub variables: &'a QueryVariables,
47    pub clock: Arc<dyn QueryClock>,
48}
49
50impl<'a> MatchSolveContext<'a> {
51    pub fn new(variables: &'a QueryVariables, clock: Arc<dyn QueryClock>) -> MatchSolveContext<'a> {
52        MatchSolveContext { variables, clock }
53    }
54}
55
/// Direction in which the solver steps from a solved slot to a neighbouring slot.
enum SolveDirection {
    /// Used when stepping to one of the current slot's `out_slots`.
    Outward,
    /// Used when stepping to one of the current slot's `in_slots`.
    Inward,
}
60
61pub struct MatchPathSolver {
62    element_index: Arc<dyn ElementIndex>,
63}
64
65impl MatchPathSolver {
66    pub fn new(element_index: Arc<dyn ElementIndex>) -> MatchPathSolver {
67        MatchPathSolver { element_index }
68    }
69
70    #[tracing::instrument(skip_all, err, level = "debug")]
71    pub async fn solve(
72        &self,
73        path: Arc<match_path::MatchPath>,
74        anchor_element: Arc<Element>,
75        anchor_slot: usize,
76    ) -> Result<HashMap<u64, solution::MatchPathSolution>, EvaluationError> {
77        let total_slots = path.slots.len();
78        let mut start_solution = MatchPathSolution::new(total_slots, anchor_slot);
79        start_solution.enqueue_slot(anchor_slot, Some(anchor_element));
80
81        let sol_stream =
82            create_solution_stream(start_solution, path.clone(), self.element_index.clone()).await;
83
84        let mut result = HashMap::new();
85        tokio::pin!(sol_stream);
86
87        while let Some(o) = sol_stream.next().await {
88            match o {
89                Ok((hash, solution)) => {
90                    result.insert(hash, solution);
91                }
92                Err(e) => return Err(e),
93            }
94        }
95
96        Ok(result)
97    }
98}
99
100#[allow(dead_code)]
101enum SolutionStreamCommand {
102    Partial(MatchPathSolution),
103    Complete((u64, MatchPathSolution)),
104    Error(EvaluationError),
105    Panic(JoinError),
106    Unsolvable,
107}
108
109async fn create_solution_stream(
110    initial_sol: MatchPathSolution,
111    path: Arc<match_path::MatchPath>,
112    element_index: Arc<dyn ElementIndex>,
113) -> impl Stream<Item = Result<(u64, MatchPathSolution), EvaluationError>> {
114    #[cfg(feature = "parallel_solver")]
115    let permits = Arc::new(Semaphore::new(MAX_CONCURRENT_SOLUTIONS));
116
117    stream! {
118        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel::<SolutionStreamCommand>();
119        cmd_tx.send(SolutionStreamCommand::Partial(initial_sol)).unwrap();
120        let mut inflight = 0;
121
122        #[cfg(feature = "parallel_solver")]
123        let (task_tx, mut task_rx) = tokio::sync::mpsc::unbounded_channel::<JoinHandle<()>>();
124
125        #[cfg(feature = "parallel_solver")]
126        {
127            let cmd_tx2 = cmd_tx.clone();
128            tokio::spawn(async move {
129                while let Some(task) = task_rx.recv().await {
130                    if let Err(err) = task.await {
131                        cmd_tx2.send(SolutionStreamCommand::Panic(err)).unwrap();
132                    }
133                }
134            });
135        }
136
137        while let Some(cmd) = cmd_rx.recv().await {
138            match cmd {
139                SolutionStreamCommand::Partial(solution) => {
140                    inflight += 1;
141                    let path = path.clone();
142                    let element_index = element_index.clone();
143                    let cmd_tx = cmd_tx.clone();
144
145                    #[cfg(not(feature = "parallel_solver"))]
146                    try_complete_solution(solution, path, element_index, cmd_tx).await;
147
148                    #[cfg(feature = "parallel_solver")]
149                    {
150                        let permits = permits.clone();
151                        let task = tokio::spawn(async move {
152                            let _permit = permits.acquire().await.unwrap();
153                            try_complete_solution(solution, path, element_index, cmd_tx).await;
154                        });
155                        task_tx.send(task).unwrap();
156                    }
157                },
158                SolutionStreamCommand::Complete((hash, solution)) => {
159                    inflight -= 1;
160                    yield Ok((hash, solution));
161                },
162                SolutionStreamCommand::Error(e) => {
163                    yield Err(e);
164                    break;
165                },
166                SolutionStreamCommand::Panic(e) => {
167                    panic!("Error in solution task: {e:?}");
168                },
169                SolutionStreamCommand::Unsolvable => {
170                    inflight -= 1;
171                },
172            }
173
174            if inflight == 0 {
175                break;
176            }
177        }
178    }
179}
180
181#[tracing::instrument(skip_all, level = "debug")]
182async fn try_complete_solution(
183    mut solution: MatchPathSolution,
184    path: Arc<match_path::MatchPath>,
185    element_index: Arc<dyn ElementIndex>,
186    cmd_tx: tokio::sync::mpsc::UnboundedSender<SolutionStreamCommand>,
187) {
188    while let Some((slot_num, element)) = solution.slot_cursors.pop_front() {
189        solution.mark_slot_solved(slot_num, element.clone());
190
191        if let Some(hash) = solution.get_solution_signature() {
192            cmd_tx
193                .send(SolutionStreamCommand::Complete((hash, solution)))
194                .unwrap();
195            return;
196        }
197
198        let slot = &path.slots[slot_num];
199        let mut alt_by_slot = HashMap::new();
200
201        for out_slot in &slot.out_slots {
202            if solution.is_slot_solved(*out_slot) {
203                continue;
204            }
205
206            let adjacent_elements = alt_by_slot.entry(*out_slot).or_insert_with(Vec::new);
207            let mut found_adjacent = false;
208
209            if let Some(element) = &element {
210                let mut adjacent_stream = match get_adjacent_elements(
211                    element_index.clone(),
212                    element.clone(),
213                    *out_slot,
214                    SolveDirection::Outward,
215                )
216                .await
217                {
218                    Ok(s) => s,
219                    Err(e) => {
220                        cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
221                        return;
222                    }
223                };
224
225                while let Some(adjacent_element) = adjacent_stream.next().await {
226                    found_adjacent = true;
227                    match adjacent_element {
228                        Ok(adjacent_element) => adjacent_elements.push(Some(adjacent_element)),
229                        Err(e) => {
230                            cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
231                            return;
232                        }
233                    }
234                }
235            }
236
237            if path.slots[*out_slot].optional && !found_adjacent {
238                adjacent_elements.push(None);
239            }
240        }
241
242        for in_slot in &slot.in_slots {
243            if solution.is_slot_solved(*in_slot) {
244                continue;
245            }
246
247            let adjacent_elements = alt_by_slot.entry(*in_slot).or_insert_with(Vec::new);
248            let mut found_adjacent = false;
249
250            if let Some(element) = &element {
251                let mut adjacent_stream = match get_adjacent_elements(
252                    element_index.clone(),
253                    element.clone(),
254                    *in_slot,
255                    SolveDirection::Inward,
256                )
257                .await
258                {
259                    Ok(s) => s,
260                    Err(e) => {
261                        cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
262                        return;
263                    }
264                };
265
266                while let Some(adjacent_element) = adjacent_stream.next().await {
267                    found_adjacent = true;
268                    match adjacent_element {
269                        Ok(adjacent_element) => adjacent_elements.push(Some(adjacent_element)),
270                        Err(e) => {
271                            cmd_tx.send(SolutionStreamCommand::Error(e.into())).unwrap();
272                            return;
273                        }
274                    }
275                }
276            }
277
278            if path.slots[*in_slot].optional && !found_adjacent {
279                adjacent_elements.push(None);
280            }
281        }
282
283        let mut pointers = BTreeMap::new();
284
285        for (slot, adjacent_elements) in &mut alt_by_slot {
286            match adjacent_elements.len() {
287                0 => {}
288                1 => {
289                    solution.enqueue_slot(*slot, adjacent_elements.pop().unwrap());
290                }
291                _ => {
292                    pointers.insert(*slot, 0);
293                }
294            }
295        }
296
297        if pointers.is_empty() {
298            continue;
299        }
300
301        let mut permutations = vec![pointers];
302
303        while let Some(mut p) = permutations.pop() {
304            let mut alt_solution = solution.clone();
305
306            for (slot, pointer) in &p {
307                if let Some(adjacent_element) = alt_by_slot.get(slot).unwrap().get(*pointer) {
308                    alt_solution.enqueue_slot(*slot, adjacent_element.clone());
309                }
310            }
311
312            cmd_tx
313                .send(SolutionStreamCommand::Partial(alt_solution))
314                .unwrap();
315
316            for (slot, pointer) in &mut p {
317                if *pointer < alt_by_slot.get(slot).unwrap().len() - 1 {
318                    *pointer += 1;
319                    permutations.push(p);
320                    break;
321                } else {
322                    *pointer = 0;
323                }
324            }
325        }
326    }
327    cmd_tx.send(SolutionStreamCommand::Unsolvable).unwrap();
328}
329
330#[tracing::instrument(skip_all, err, level = "debug")]
331async fn get_adjacent_elements(
332    element_index: Arc<dyn ElementIndex>,
333    element: Arc<Element>,
334    target_slot: usize,
335    direction: SolveDirection,
336) -> Result<ElementStream, IndexError> {
337    match element.as_ref() {
338        Element::Node {
339            metadata,
340            properties: _,
341        } => match direction {
342            SolveDirection::Outward => Ok(element_index
343                .get_slot_elements_by_inbound(target_slot, &metadata.reference)
344                .await?),
345            SolveDirection::Inward => Ok(element_index
346                .get_slot_elements_by_outbound(target_slot, &metadata.reference)
347                .await?),
348        },
349        Element::Relation {
350            metadata: _,
351            in_node,
352            out_node,
353            properties: _,
354        } => {
355            let adjacent_ref = match direction {
356                SolveDirection::Outward => out_node,
357                SolveDirection::Inward => in_node,
358            };
359
360            match element_index
361                .get_slot_element_by_ref(target_slot, adjacent_ref)
362                .await?
363            {
364                Some(adjacent_element) => Ok(Box::pin(tokio_stream::once::<ElementResult>(Ok(
365                    adjacent_element,
366                )))),
367                None => Ok(Box::pin(tokio_stream::empty::<ElementResult>())),
368            }
369        }
370    }
371}
372
373fn merge_node_match<'b>(
374    mtch: &NodeMatch,
375    slots: &'b mut Vec<match_path::MatchPathSlot>,
376    alias_map: &'b mut HashMap<Arc<str>, usize>,
377    path_index: usize,
378    optional: bool,
379) -> Result<usize, EvaluationError> {
380    match &mtch.annotation.name {
381        Some(alias) => {
382            if let Some(slot_num) = alias_map.get(alias) {
383                slots[*slot_num].optional = optional && slots[*slot_num].optional;
384                slots[*slot_num].paths.insert(path_index);
385                Ok(*slot_num)
386            } else {
387                slots.push(match_path::MatchPathSlot {
388                    spec: match_path::SlotElementSpec::from_node_match(mtch),
389                    in_slots: Vec::new(),
390                    out_slots: Vec::new(),
391                    optional,
392                    paths: HashSet::from([path_index]),
393                });
394                alias_map.insert(alias.clone(), slots.len() - 1);
395                Ok(slots.len() - 1)
396            }
397        }
398        None => {
399            slots.push(match_path::MatchPathSlot {
400                spec: match_path::SlotElementSpec::from_node_match(mtch),
401                in_slots: Vec::new(),
402                out_slots: Vec::new(),
403                optional,
404                paths: HashSet::from([path_index]),
405            });
406            Ok(slots.len() - 1)
407        }
408    }
409}
410
411fn merge_relation_match<'b>(
412    mtch: &RelationMatch,
413    slots: &'b mut Vec<match_path::MatchPathSlot>,
414    alias_map: &'b mut HashMap<Arc<str>, usize>,
415    path_index: usize,
416    optional: bool,
417) -> Result<usize, EvaluationError> {
418    match &mtch.annotation.name {
419        Some(alias) => {
420            if let Some(slot_num) = alias_map.get(alias) {
421                slots[*slot_num].optional = optional && slots[*slot_num].optional;
422                slots[*slot_num].paths.insert(path_index);
423                Ok(*slot_num)
424            } else {
425                slots.push(match_path::MatchPathSlot {
426                    spec: match_path::SlotElementSpec::from_relation_match(mtch),
427                    in_slots: Vec::new(),
428                    out_slots: Vec::new(),
429                    optional,
430                    paths: HashSet::from([path_index]),
431                });
432                alias_map.insert(alias.clone(), slots.len() - 1);
433                Ok(slots.len() - 1)
434            }
435        }
436        None => {
437            slots.push(match_path::MatchPathSlot {
438                spec: match_path::SlotElementSpec::from_relation_match(mtch),
439                in_slots: Vec::new(),
440                out_slots: Vec::new(),
441                optional,
442                paths: HashSet::from([path_index]),
443            });
444            Ok(slots.len() - 1)
445        }
446    }
447}