astro_run_scheduler/
default.rs1use crate::{RunnerMetadata, Scheduler};
2use astro_run::{Context, JobId, JobRunResult, StepId, StepRunResult};
3use parking_lot::Mutex;
4use std::{collections::HashMap, sync::Arc};
5
6#[derive(Clone)]
7pub struct SchedulerState {
8 pub runs_count: HashMap<String, i32>,
10 pub job_runners: HashMap<JobId, String>,
12 pub step_runners: HashMap<StepId, String>,
14}
15
16#[derive(Clone)]
17pub struct DefaultScheduler {
18 pub state: Arc<Mutex<SchedulerState>>,
19}
20
21impl DefaultScheduler {
22 pub fn new() -> Self {
23 Self::default()
24 }
25}
26
27impl Default for DefaultScheduler {
28 fn default() -> Self {
29 Self {
30 state: Arc::new(Mutex::new(SchedulerState {
31 runs_count: HashMap::new(),
32 job_runners: HashMap::new(),
33 step_runners: HashMap::new(),
34 })),
35 }
36 }
37}
38
39#[astro_run::async_trait]
40impl Scheduler for DefaultScheduler {
41 async fn schedule<'a, 'b: 'a>(
42 &'b self,
43 runners: &'a [RunnerMetadata],
44 ctx: &Context,
45 ) -> Option<&'a RunnerMetadata> {
46 log::trace!("Scheduling runners: {:?}", runners);
47 let mut runner: Option<&'a RunnerMetadata> = None;
48
49 let job_id = ctx.command.id.job_id();
50 let container_name = ctx.command.container.clone().map(|c| c.name);
51 let is_runs_on_host = container_name
52 .clone()
53 .map(|c| c.starts_with("host/"))
54 .unwrap_or(false);
55
56 log::trace!("Is runs on host: {}", is_runs_on_host);
57
58 let last_used_id = self.state.lock().job_runners.get(&job_id).cloned();
59
60 if let Some(last_used_id) = last_used_id {
61 runner = runners.iter().find(|r| {
62 if r.id == last_used_id {
63 if is_runs_on_host {
64 let container_name = container_name.clone().unwrap();
65 return r.support_host && container_name == format!("host/{}", r.os)
66 || container_name == format!("host/{}-{}", r.os, r.arch);
67 }
68
69 return true;
70 }
71
72 false
73 });
74
75 log::trace!("Last used runner: {:?}", runner);
76 }
77
78 if runner.is_none() {
79 runner = self.pick_runner(runners, container_name);
80 log::trace!("Picked runner: {:?}", runner);
81 }
82
83 if let Some(runner) = &runner {
84 let mut state = self.state.lock();
85 state
86 .step_runners
87 .insert(ctx.command.id.clone(), runner.id.clone());
88 let runs_count = state.runs_count.entry(runner.id.clone()).or_insert(0);
90 *runs_count += 1;
91
92 if !is_runs_on_host {
93 state.job_runners.insert(job_id, runner.id.clone());
95 }
96
97 log::trace!("Runs count: {:?}", state.runs_count);
98 }
99
100 runner
101 }
102
103 fn on_step_completed(&self, result: StepRunResult) {
104 let mut state = self.state.lock();
105 let step_id = result.id;
106 let runner_id = state.step_runners.get(&step_id).cloned();
107
108 if let Some(runner_id) = runner_id {
109 let runs_count = state
110 .runs_count
111 .entry(runner_id.clone())
112 .and_modify(|c| *c -= 1)
113 .or_insert(0);
114
115 if *runs_count <= 0 {
116 state.runs_count.remove(&runner_id);
117 }
118 }
119
120 state.step_runners.remove(&step_id);
121 }
122
123 fn on_job_completed(&self, result: JobRunResult) {
124 let mut state = self.state.lock();
125 let job_id = result.id;
126
127 state.job_runners.remove(&job_id);
128 }
129}
130
131impl DefaultScheduler {
132 fn pick_runner<'a, 'b: 'a>(
133 &'b self,
134 runners: &'a [RunnerMetadata],
135 container: Option<String>,
136 ) -> Option<&'a RunnerMetadata> {
137 let is_runs_on_host = container
138 .clone()
139 .map(|c| c.starts_with("host/"))
140 .unwrap_or(false);
141
142 if is_runs_on_host {
143 self.pick_host_runner(runners, container.unwrap())
144 } else {
145 self.pick_docker_runner(runners)
146 }
147 }
148
149 fn pick_docker_runner<'a, 'b: 'a>(
150 &'b self,
151 runners: &'a [RunnerMetadata],
152 ) -> Option<&'a RunnerMetadata> {
153 let runs_count = self.state.lock().runs_count.clone();
154 let min_runs = runners
155 .iter()
156 .filter(|r| r.support_docker)
157 .min_by_key(|r| runs_count.get(&r.id).unwrap_or(&0));
158
159 min_runs
160 }
161
162 fn pick_host_runner<'a, 'b: 'a>(
163 &'b self,
164 runners: &'a [RunnerMetadata],
165 container: String,
166 ) -> Option<&'a RunnerMetadata> {
167 runners.iter().filter(|r| r.support_host).find(|r| {
168 container == format!("host/{}", r.os) || container == format!("host/{}-{}", r.os, r.arch)
169 })
170 }
171}