// astro_run_scheduler/default.rs

1use crate::{RunnerMetadata, Scheduler};
2use astro_run::{Context, JobId, JobRunResult, StepId, StepRunResult};
3use parking_lot::Mutex;
4use std::{collections::HashMap, sync::Arc};
5
6#[derive(Clone)]
7pub struct SchedulerState {
8  /// Runner ID -> Job runs count
9  pub runs_count: HashMap<String, i32>,
10  /// Job ID -> Runner ID
11  pub job_runners: HashMap<JobId, String>,
12  /// Step ID -> Runner ID
13  pub step_runners: HashMap<StepId, String>,
14}
15
16#[derive(Clone)]
17pub struct DefaultScheduler {
18  pub state: Arc<Mutex<SchedulerState>>,
19}
20
21impl DefaultScheduler {
22  pub fn new() -> Self {
23    Self::default()
24  }
25}
26
27impl Default for DefaultScheduler {
28  fn default() -> Self {
29    Self {
30      state: Arc::new(Mutex::new(SchedulerState {
31        runs_count: HashMap::new(),
32        job_runners: HashMap::new(),
33        step_runners: HashMap::new(),
34      })),
35    }
36  }
37}
38
39#[astro_run::async_trait]
40impl Scheduler for DefaultScheduler {
41  async fn schedule<'a, 'b: 'a>(
42    &'b self,
43    runners: &'a [RunnerMetadata],
44    ctx: &Context,
45  ) -> Option<&'a RunnerMetadata> {
46    log::trace!("Scheduling runners: {:?}", runners);
47    let mut runner: Option<&'a RunnerMetadata> = None;
48
49    let job_id = ctx.command.id.job_id();
50    let container_name = ctx.command.container.clone().map(|c| c.name);
51    let is_runs_on_host = container_name
52      .clone()
53      .map(|c| c.starts_with("host/"))
54      .unwrap_or(false);
55
56    log::trace!("Is runs on host: {}", is_runs_on_host);
57
58    let last_used_id = self.state.lock().job_runners.get(&job_id).cloned();
59
60    if let Some(last_used_id) = last_used_id {
61      runner = runners.iter().find(|r| {
62        if r.id == last_used_id {
63          if is_runs_on_host {
64            let container_name = container_name.clone().unwrap();
65            return r.support_host && container_name == format!("host/{}", r.os)
66              || container_name == format!("host/{}-{}", r.os, r.arch);
67          }
68
69          return true;
70        }
71
72        false
73      });
74
75      log::trace!("Last used runner: {:?}", runner);
76    }
77
78    if runner.is_none() {
79      runner = self.pick_runner(runners, container_name);
80      log::trace!("Picked runner: {:?}", runner);
81    }
82
83    if let Some(runner) = &runner {
84      let mut state = self.state.lock();
85      state
86        .step_runners
87        .insert(ctx.command.id.clone(), runner.id.clone());
88      // Update runs count
89      let runs_count = state.runs_count.entry(runner.id.clone()).or_insert(0);
90      *runs_count += 1;
91
92      if !is_runs_on_host {
93        // Update job runner
94        state.job_runners.insert(job_id, runner.id.clone());
95      }
96
97      log::trace!("Runs count: {:?}", state.runs_count);
98    }
99
100    runner
101  }
102
103  fn on_step_completed(&self, result: StepRunResult) {
104    let mut state = self.state.lock();
105    let step_id = result.id;
106    let runner_id = state.step_runners.get(&step_id).cloned();
107
108    if let Some(runner_id) = runner_id {
109      let runs_count = state
110        .runs_count
111        .entry(runner_id.clone())
112        .and_modify(|c| *c -= 1)
113        .or_insert(0);
114
115      if *runs_count <= 0 {
116        state.runs_count.remove(&runner_id);
117      }
118    }
119
120    state.step_runners.remove(&step_id);
121  }
122
123  fn on_job_completed(&self, result: JobRunResult) {
124    let mut state = self.state.lock();
125    let job_id = result.id;
126
127    state.job_runners.remove(&job_id);
128  }
129}
130
131impl DefaultScheduler {
132  fn pick_runner<'a, 'b: 'a>(
133    &'b self,
134    runners: &'a [RunnerMetadata],
135    container: Option<String>,
136  ) -> Option<&'a RunnerMetadata> {
137    let is_runs_on_host = container
138      .clone()
139      .map(|c| c.starts_with("host/"))
140      .unwrap_or(false);
141
142    if is_runs_on_host {
143      self.pick_host_runner(runners, container.unwrap())
144    } else {
145      self.pick_docker_runner(runners)
146    }
147  }
148
149  fn pick_docker_runner<'a, 'b: 'a>(
150    &'b self,
151    runners: &'a [RunnerMetadata],
152  ) -> Option<&'a RunnerMetadata> {
153    let runs_count = self.state.lock().runs_count.clone();
154    let min_runs = runners
155      .iter()
156      .filter(|r| r.support_docker)
157      .min_by_key(|r| runs_count.get(&r.id).unwrap_or(&0));
158
159    min_runs
160  }
161
162  fn pick_host_runner<'a, 'b: 'a>(
163    &'b self,
164    runners: &'a [RunnerMetadata],
165    container: String,
166  ) -> Option<&'a RunnerMetadata> {
167    runners.iter().filter(|r| r.support_host).find(|r| {
168      container == format!("host/{}", r.os) || container == format!("host/{}-{}", r.os, r.arch)
169    })
170  }
171}