assemble_freight/core/
task_resolver.rs

1use crate::core::ConstructionError;
2
3use crate::consts::EXEC_GRAPH_LOG_LEVEL;
4use assemble_core::identifier::{ProjectId, TaskId};
5use assemble_core::project::buildable::Buildable;
6use assemble_core::project::error::ProjectResult;
7use assemble_core::project::requests::TaskRequests;
8use assemble_core::project::{GetProjectId, Project};
9use assemble_core::task::task_container::TaskContainer;
10use assemble_core::task::{FullTask, TaskOrderingKind};
11use colored::Colorize;
12use petgraph::prelude::*;
13
14use assemble_core::dependencies::project_dependency::ProjectDependencyPlugin;
15use assemble_core::error::PayloadError;
16use assemble_core::prelude::ProjectError;
17use assemble_core::project::finder::{
18    ProjectFinder, ProjectPathBuf, TaskFinder, TaskPath, TaskPathBuf,
19};
20use assemble_core::project::shared::SharedProject;
21use assemble_core::startup::execution_graph::{ExecutionGraph, SharedAnyTask};
22use itertools::Itertools;
23use parking_lot::RwLock;
24use std::collections::{HashMap, HashSet, VecDeque};
25use std::fmt::Debug;
26use std::sync::Arc;
27
28/// Resolves tasks
29pub struct TaskResolver {
30    project: SharedProject,
31}
32
33impl TaskResolver {
34    /// Create a new instance of a task resolver for a project
35    pub fn new(project: &SharedProject) -> Self {
36        Self {
37            project: project.clone(),
38        }
39    }
40
41    pub fn find_task(
42        &self,
43        task_id: &TaskId,
44    ) -> Result<Box<dyn FullTask>, PayloadError<ConstructionError>> {
45        let project_id = task_id.project_id();
46        match project_id {
47            None => {
48                panic!("task {} has no parent", task_id);
49            }
50            Some(project) => {
51                let mut ptr = self.project.clone();
52                let mut iter = project.iter();
53                let first = iter.next().unwrap();
54                if ptr.project_id() != first {
55                    return Err(
56                        ConstructionError::ProjectError(ProjectError::NoSharedProjectSet).into(),
57                    );
58                }
59                for id in iter {
60                    ptr = ptr.get_subproject(id).map_err(PayloadError::into)?;
61                }
62
63                let config_info = ptr
64                    .get_task(task_id)
65                    .map_err(PayloadError::into)?
66                    .resolve_shared(&self.project)
67                    .map_err(PayloadError::into)?;
68
69                Ok(config_info)
70            }
71        }
72    }
73
74    /// Create a task resolver using the given set of tasks as a starting point. Not all tasks
75    /// registered to the project will be added to the tasks,
76    /// just the ones that are required for the specified tasks to be ran.
77    ///
78    /// # Error
79    /// Will return Err() if any of the [`ExecutionGraph`](ExecutionGraph) rules are invalidated.
80    ///
81    /// # Example
82    /// ```rust
83    /// # use assemble_core::Project;
84    /// use assemble_core::defaults::tasks::Empty;
85    /// # let mut project = Project::temp(None);
86    /// project.register_task::<Empty>("task1").unwrap();
87    /// project.register_task::<Empty>("task2").unwrap().configure_with(|task, _| {
88    ///     task.depends_on("task1");
89    ///     Ok(())
90    /// }).unwrap();
91    /// ```
92    pub fn to_execution_graph(
93        self,
94        tasks: TaskRequests,
95    ) -> Result<ExecutionGraph, PayloadError<ConstructionError>> {
96        let mut task_id_graph = TaskIdentifierGraph::new();
97
98        let mut task_queue: VecDeque<TaskId> = VecDeque::new();
99        let requested = tasks.requested_tasks().to_vec();
100        task_queue.extend(requested);
101        log!(
102            EXEC_GRAPH_LOG_LEVEL,
103            "task queue at start: {:?}",
104            task_queue
105        );
106
107        let mut visited = HashSet::new();
108
109        while let Some(task_id) = task_queue.pop_front() {
110            if visited.contains(&task_id) {
111                log!(
112                    EXEC_GRAPH_LOG_LEVEL,
113                    "task {task_id} already visited, skipping..."
114                );
115                continue;
116            }
117
118            if !task_id_graph.contains_id(&task_id) {
119                log!(EXEC_GRAPH_LOG_LEVEL, "adding {} to task graph", task_id);
120                task_id_graph.add_id(task_id.clone());
121            }
122            visited.insert(task_id.clone());
123
124            let config_info = self.find_task(&task_id)?;
125
126            log!(
127                EXEC_GRAPH_LOG_LEVEL,
128                "got configured info: {:#?}",
129                config_info
130            );
131            for ordering in config_info.ordering() {
132                let buildable = ordering.buildable();
133                let dependencies = self
134                    .project
135                    .with(|p| buildable.get_dependencies(p))
136                    .map_err(PayloadError::into)?;
137
138                log!(
139                    EXEC_GRAPH_LOG_LEVEL,
140                    "[{:^20}] adding dependencies from {:?} -> {:#?}",
141                    task_id.to_string().italic(),
142                    buildable,
143                    dependencies
144                );
145
146                for next_id in dependencies {
147                    if !task_id_graph.contains_id(&next_id) {
148                        log!(EXEC_GRAPH_LOG_LEVEL, "adding {} to task graph", task_id);
149                        task_id_graph.add_id(next_id.clone());
150                    }
151
152                    log!(
153                        EXEC_GRAPH_LOG_LEVEL,
154                        "creating task dependency from {} to {} with kind {:?}",
155                        task_id,
156                        next_id,
157                        ordering.ordering_kind()
158                    );
159
160                    log!(EXEC_GRAPH_LOG_LEVEL, "adding {} to resolve queue", next_id);
161                    task_queue.push_back(next_id.clone());
162                    task_id_graph.add_task_ordering(
163                        task_id.clone(),
164                        next_id.clone(),
165                        *ordering.ordering_kind(),
166                    );
167                    log!(EXEC_GRAPH_LOG_LEVEL, "task_id_graph: {:#?}", task_id_graph);
168                }
169            }
170        }
171        debug!("Attempting to create execution graph.");
172        let execution_graph = task_id_graph.map_with(self.project.clone())?;
173        Ok(ExecutionGraph::new(execution_graph, tasks))
174    }
175}
176
177// impl Debug for ExecutionGraph {
178//     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
179//         f.debug_struct("ExecutionGraph")
180//             .field("requested_tasks", &self.requested_tasks)
181//             .finish_non_exhaustive()
182//     }
183// }
184#[derive(Debug)]
185struct TaskIdentifierGraph {
186    graph: DiGraph<TaskId, TaskOrderingKind>,
187    index_to_id: HashMap<TaskId, NodeIndex>,
188}
189
190impl TaskIdentifierGraph {
191    fn new() -> Self {
192        Self {
193            graph: DiGraph::new(),
194            index_to_id: HashMap::new(),
195        }
196    }
197
198    fn add_id(&mut self, id: TaskId) {
199        let index = self.graph.add_node(id.clone());
200        self.index_to_id.insert(id, index);
201    }
202
203    fn contains_id(&self, id: &TaskId) -> bool {
204        self.index_to_id.contains_key(id)
205    }
206
207    fn add_task_ordering(
208        &mut self,
209        from_id: TaskId,
210        to_id: TaskId,
211        dependency_type: TaskOrderingKind,
212    ) {
213        let from = self.index_to_id[&from_id];
214        let to = self.index_to_id[&to_id];
215        self.graph.add_edge(from, to, dependency_type);
216    }
217
218    fn map_with(
219        self,
220        project: SharedProject,
221    ) -> Result<DiGraph<SharedAnyTask, TaskOrderingKind>, PayloadError<ConstructionError>> {
222        trace!("creating digraph from TaskIdentifierGraph");
223        let input = self.graph;
224
225        let mut mapping = Vec::new();
226
227        let finder = ProjectFinder::new(&project);
228
229        for node in input.node_indices() {
230            let id = &input[node];
231            let project: ProjectPathBuf = id.project_id().unwrap().into();
232
233            let project = finder
234                .find(&project)
235                .unwrap_or_else(|| panic!("no project found for name {:?}", project));
236
237            let mut task = project.get_task(id).map_err(PayloadError::into)?;
238            let task = task.resolve_shared(&project).map_err(PayloadError::into)?;
239            mapping.push((task, node));
240        }
241
242        let mut output: DiGraph<SharedAnyTask, TaskOrderingKind> =
243            DiGraph::with_capacity(input.node_count(), input.edge_count());
244        let mut output_mapping = HashMap::new();
245
246        for (exec, index) in mapping {
247            let output_index = output.add_node(Arc::new(RwLock::new(exec)));
248            output_mapping.insert(index, output_index);
249        }
250
251        for old_index in input.node_indices() {
252            let new_index_from = output_mapping[&old_index];
253            for outgoing in input.edges(old_index) {
254                let weight = *outgoing.weight();
255                let new_index_to = output_mapping[&outgoing.target()];
256                output.add_edge(new_index_from, new_index_to, weight);
257            }
258        }
259        Ok(output)
260    }
261}