planetary_orchestrator/
lib.rs

1//! Implements the Planetary task orchestrator.
2
3use std::future::Future;
4use std::sync::Arc;
5use std::time::Duration;
6
7use axum::Router;
8use axum::extract::Request;
9use axum::extract::State as ExtractState;
10use axum::middleware;
11use axum::middleware::Next;
12use axum::response::IntoResponse;
13use axum::response::Response;
14use axum::routing::delete;
15use axum::routing::get;
16use axum::routing::patch;
17use axum::routing::post;
18use axum::routing::put;
19use axum_extra::TypedHeader;
20use axum_extra::headers::Authorization;
21use axum_extra::headers::authorization::Bearer;
22use bon::Builder;
23use planetary_db::Database;
24use planetary_db::TaskIo;
25use planetary_server::DEFAULT_ADDRESS;
26use planetary_server::DEFAULT_PORT;
27use planetary_server::Error;
28use planetary_server::Json;
29use planetary_server::Path;
30use planetary_server::ServerResponse;
31use secrecy::ExposeSecret;
32use secrecy::SecretString;
33use tes::v1::types::responses::OutputFile;
34use tokio_retry2::RetryError;
35use tokio_retry2::strategy::ExponentialFactorBackoff;
36use tokio_retry2::strategy::MaxInterval;
37use tracing::warn;
38use url::Url;
39
40use crate::orchestrator::Monitor;
41use crate::orchestrator::PreemptibleConfig;
42use crate::orchestrator::TaskOrchestrator;
43use crate::orchestrator::TransporterInfo;
44
45mod orchestrator;
46
47/// Gets an iterator over the retry durations for network operations.
48///
49/// Retries use an exponential power of 2 backoff, starting at 1 second with
50/// a maximum duration of 60 seconds.
51fn retry_durations() -> impl Iterator<Item = Duration> {
52    const INITIAL_DELAY_MILLIS: u64 = 1000;
53    const BASE_FACTOR: f64 = 2.0;
54    const MAX_DURATION: Duration = Duration::from_secs(60);
55    const RETRIES: usize = 5;
56
57    ExponentialFactorBackoff::from_millis(INITIAL_DELAY_MILLIS, BASE_FACTOR)
58        .max_duration(MAX_DURATION)
59        .take(RETRIES)
60}
61
62/// Helper for notifying that a network operation failed and will be retried.
63fn notify_retry(e: &kube::Error, duration: Duration) {
64    warn!(
65        "network operation failed: {e} (retrying after {duration} seconds)",
66        duration = duration.as_secs()
67    );
68}
69
70/// Converts a Kubernetes error into a retry error.
71fn into_retry_error(e: kube::Error) -> RetryError<kube::Error> {
72    match e {
73        kube::Error::Api(kube::core::ErrorResponse { code, .. }) if code >= 500 => {
74            RetryError::transient(e)
75        }
76        kube::Error::HyperError(_)
77        | kube::Error::Service(_)
78        | kube::Error::ReadEvents(_)
79        | kube::Error::HttpError(_)
80        | kube::Error::Discovery(_) => RetryError::transient(e),
81        _ => RetryError::permanent(e),
82    }
83}
84
85/// Middleware function to perform bearer token auth against the service's API
86/// key.
87async fn auth(
88    axum::extract::State(state): axum::extract::State<Arc<State>>,
89    TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
90    request: Request,
91    next: Next,
92) -> Response {
93    if authorization.token() != state.service_api_key.expose_secret() {
94        return Error::forbidden().into_response();
95    }
96
97    next.run(request).await
98}
99
100/// The state for the server.
101struct State {
102    /// The API key of the service.
103    service_api_key: SecretString,
104    /// The task orchestrator.
105    orchestrator: TaskOrchestrator,
106}
107
108/// The task orchestrator server.
109#[derive(Clone, Builder)]
110pub struct Server {
111    /// The address to bind the server to.
112    #[builder(into, default = DEFAULT_ADDRESS)]
113    address: String,
114
115    /// The port to bind the server to.
116    #[builder(into, default = DEFAULT_PORT)]
117    port: u16,
118
119    /// The pod name of the orchestrator.
120    #[builder(into)]
121    pod_name: String,
122
123    /// The URL of the orchestrator service.
124    #[builder(into)]
125    service_url: Url,
126
127    /// The API key of the orchestrator service.
128    #[builder(into)]
129    service_api_key: SecretString,
130
131    /// The TES database to use for the server.
132    #[builder(name = "shared_database")]
133    database: Arc<dyn Database>,
134
135    /// The Kubernetes storage class to use for tasks.
136    #[builder(into)]
137    storage_class: Option<String>,
138
139    /// The transporter image to use.
140    ///
141    /// Defaults to `stjude-rust-labs/planetary-transporter:latest`.
142    #[builder(into)]
143    transporter_image: Option<String>,
144
145    /// The Kubernetes namespace to use for TES task resources.
146    ///
147    /// Defaults to `planetary-tasks`.
148    #[builder(into)]
149    tasks_namespace: Option<String>,
150
151    /// The number of CPU cores to request for transporter pods.
152    ///
153    /// Defaults to `4` CPU cores.
154    #[builder(into)]
155    transporter_cpu: Option<i32>,
156
157    /// The amount of memory (in GB) to request for transporter pods.
158    ///
159    /// Defaults to `1.07374182` GB (i.e 1 GiB).
160    #[builder(into)]
161    transporter_memory: Option<f64>,
162
163    /// The node selector to apply to preemptible tasks.
164    #[builder(into)]
165    preemptible_node_selector: Option<String>,
166
167    /// The taint to apply to preemptible tasks.
168    #[builder(into)]
169    preemptible_taint: Option<String>,
170}
171
172impl<S: server_builder::State> ServerBuilder<S> {
173    /// The TES database to use for the server.
174    ///
175    /// This is a convenience method for setting the shared database server
176    /// from any type that implements `Database`.
177    pub fn database(
178        self,
179        database: impl Database + 'static,
180    ) -> ServerBuilder<server_builder::SetSharedDatabase<S>>
181    where
182        S::SharedDatabase: server_builder::IsUnset,
183    {
184        self.shared_database(Arc::new(database))
185    }
186}
187
188impl Server {
189    /// Runs the server.
190    pub async fn run<F>(self, shutdown: F) -> anyhow::Result<()>
191    where
192        F: Future<Output = ()> + Send + 'static,
193    {
194        // Build the preemptible config if both fields are present
195        let preemptible_config = match (self.preemptible_node_selector, self.preemptible_taint) {
196            (Some(node_selector), Some(taint)) => {
197                Some(PreemptibleConfig::new(node_selector, taint)?)
198            }
199            (None, None) => None,
200            _ => anyhow::bail!(
201                "preemptive task execution requires both the node selector and taint to be \
202                 configured"
203            ),
204        };
205
206        let state = Arc::new(State {
207            service_api_key: self.service_api_key,
208            orchestrator: TaskOrchestrator::new(
209                self.database,
210                self.pod_name,
211                self.service_url,
212                self.tasks_namespace,
213                self.storage_class,
214                TransporterInfo {
215                    image: self.transporter_image,
216                    cpu: self.transporter_cpu,
217                    memory: self.transporter_memory,
218                },
219                preemptible_config,
220            )
221            .await?,
222        });
223
224        let server = planetary_server::Server::builder()
225            .address(self.address)
226            .port(self.port)
227            .routers(bon::vec![
228                Router::new()
229                    .route("/v1/tasks/{tes_id}", post(Self::start_task))
230                    .route("/v1/tasks/{tes_id}", delete(Self::cancel_task))
231                    .route("/v1/tasks/{tes_id}/io", get(Self::get_task_io))
232                    .route("/v1/tasks/{tes_id}/outputs", put(Self::put_task_outputs))
233                    .route("/v1/pods/{name}", patch(Self::adopt_pod))
234                    .layer(middleware::from_fn_with_state(state.clone(), auth))
235            ])
236            .build();
237
238        // Spawn the monitor
239        let monitor = Monitor::spawn(state.clone());
240
241        // Run the server to completion
242        server.run(state, shutdown).await?;
243
244        // Finally, shutdown the monitor
245        monitor.shutdown().await;
246        Ok(())
247    }
248
249    /// Implements the API endpoint for starting a task.
250    ///
251    /// This endpoint is used by the TES API service.
252    async fn start_task(
253        ExtractState(state): ExtractState<Arc<State>>,
254        Path(tes_id): Path<String>,
255    ) -> ServerResponse<()> {
256        tokio::spawn(async move {
257            state.orchestrator.start_task(&tes_id).await;
258        });
259
260        Ok(())
261    }
262
263    /// Implements the API endpoint for canceling a task.
264    ///
265    /// This endpoint is used by the TES API service.
266    async fn cancel_task(
267        ExtractState(state): ExtractState<Arc<State>>,
268        Path(tes_id): Path<String>,
269    ) -> ServerResponse<()> {
270        tokio::spawn(async move {
271            state.orchestrator.cancel_task(&tes_id).await;
272        });
273
274        Ok(())
275    }
276
277    /// Implements the API endpoint for getting a task's inputs and outputs.
278    ///
279    /// This endpoint is used by the transporter.
280    async fn get_task_io(
281        ExtractState(state): ExtractState<Arc<State>>,
282        Path(tes_id): Path<String>,
283    ) -> ServerResponse<Json<TaskIo>> {
284        Ok(Json(
285            state.orchestrator.database().get_task_io(&tes_id).await?,
286        ))
287    }
288
289    /// Implements the API endpoint for updating a task's output files.
290    ///
291    /// This endpoint is used by the transporter.
292    async fn put_task_outputs(
293        ExtractState(state): ExtractState<Arc<State>>,
294        Path(tes_id): Path<String>,
295        Json(files): Json<Vec<OutputFile>>,
296    ) -> ServerResponse<()> {
297        state
298            .orchestrator
299            .database()
300            .update_task_output_files(&tes_id, &files)
301            .await?;
302        Ok(())
303    }
304
305    /// Implements the API endpoint for adopting a pod to this orchestrator.
306    ///
307    /// This endpoint is used by the monitor.
308    async fn adopt_pod(
309        ExtractState(state): ExtractState<Arc<State>>,
310        Path(name): Path<String>,
311    ) -> ServerResponse<()> {
312        state.orchestrator.adopt_pod(&name).await?;
313        Ok(())
314    }
315}