planetary_orchestrator/
lib.rs1use std::future::Future;
4use std::sync::Arc;
5use std::time::Duration;
6
7use axum::Router;
8use axum::extract::Request;
9use axum::extract::State as ExtractState;
10use axum::middleware;
11use axum::middleware::Next;
12use axum::response::IntoResponse;
13use axum::response::Response;
14use axum::routing::delete;
15use axum::routing::get;
16use axum::routing::patch;
17use axum::routing::post;
18use axum::routing::put;
19use axum_extra::TypedHeader;
20use axum_extra::headers::Authorization;
21use axum_extra::headers::authorization::Bearer;
22use bon::Builder;
23use planetary_db::Database;
24use planetary_db::TaskIo;
25use planetary_server::DEFAULT_ADDRESS;
26use planetary_server::DEFAULT_PORT;
27use planetary_server::Error;
28use planetary_server::Json;
29use planetary_server::Path;
30use planetary_server::ServerResponse;
31use secrecy::ExposeSecret;
32use secrecy::SecretString;
33use tes::v1::types::responses::OutputFile;
34use tokio_retry2::RetryError;
35use tokio_retry2::strategy::ExponentialFactorBackoff;
36use tokio_retry2::strategy::MaxInterval;
37use tracing::warn;
38use url::Url;
39
40use crate::orchestrator::Monitor;
41use crate::orchestrator::PreemptibleConfig;
42use crate::orchestrator::TaskOrchestrator;
43use crate::orchestrator::TransporterInfo;
44
45mod orchestrator;
46
47fn retry_durations() -> impl Iterator<Item = Duration> {
52 const INITIAL_DELAY_MILLIS: u64 = 1000;
53 const BASE_FACTOR: f64 = 2.0;
54 const MAX_DURATION: Duration = Duration::from_secs(60);
55 const RETRIES: usize = 5;
56
57 ExponentialFactorBackoff::from_millis(INITIAL_DELAY_MILLIS, BASE_FACTOR)
58 .max_duration(MAX_DURATION)
59 .take(RETRIES)
60}
61
62fn notify_retry(e: &kube::Error, duration: Duration) {
64 warn!(
65 "network operation failed: {e} (retrying after {duration} seconds)",
66 duration = duration.as_secs()
67 );
68}
69
70fn into_retry_error(e: kube::Error) -> RetryError<kube::Error> {
72 match e {
73 kube::Error::Api(kube::core::ErrorResponse { code, .. }) if code >= 500 => {
74 RetryError::transient(e)
75 }
76 kube::Error::HyperError(_)
77 | kube::Error::Service(_)
78 | kube::Error::ReadEvents(_)
79 | kube::Error::HttpError(_)
80 | kube::Error::Discovery(_) => RetryError::transient(e),
81 _ => RetryError::permanent(e),
82 }
83}
84
85async fn auth(
88 axum::extract::State(state): axum::extract::State<Arc<State>>,
89 TypedHeader(authorization): TypedHeader<Authorization<Bearer>>,
90 request: Request,
91 next: Next,
92) -> Response {
93 if authorization.token() != state.service_api_key.expose_secret() {
94 return Error::forbidden().into_response();
95 }
96
97 next.run(request).await
98}
99
100struct State {
102 service_api_key: SecretString,
104 orchestrator: TaskOrchestrator,
106}
107
108#[derive(Clone, Builder)]
110pub struct Server {
111 #[builder(into, default = DEFAULT_ADDRESS)]
113 address: String,
114
115 #[builder(into, default = DEFAULT_PORT)]
117 port: u16,
118
119 #[builder(into)]
121 pod_name: String,
122
123 #[builder(into)]
125 service_url: Url,
126
127 #[builder(into)]
129 service_api_key: SecretString,
130
131 #[builder(name = "shared_database")]
133 database: Arc<dyn Database>,
134
135 #[builder(into)]
137 storage_class: Option<String>,
138
139 #[builder(into)]
143 transporter_image: Option<String>,
144
145 #[builder(into)]
149 tasks_namespace: Option<String>,
150
151 #[builder(into)]
155 transporter_cpu: Option<i32>,
156
157 #[builder(into)]
161 transporter_memory: Option<f64>,
162
163 #[builder(into)]
165 preemptible_node_selector: Option<String>,
166
167 #[builder(into)]
169 preemptible_taint: Option<String>,
170}
171
172impl<S: server_builder::State> ServerBuilder<S> {
173 pub fn database(
178 self,
179 database: impl Database + 'static,
180 ) -> ServerBuilder<server_builder::SetSharedDatabase<S>>
181 where
182 S::SharedDatabase: server_builder::IsUnset,
183 {
184 self.shared_database(Arc::new(database))
185 }
186}
187
188impl Server {
189 pub async fn run<F>(self, shutdown: F) -> anyhow::Result<()>
191 where
192 F: Future<Output = ()> + Send + 'static,
193 {
194 let preemptible_config = match (self.preemptible_node_selector, self.preemptible_taint) {
196 (Some(node_selector), Some(taint)) => {
197 Some(PreemptibleConfig::new(node_selector, taint)?)
198 }
199 (None, None) => None,
200 _ => anyhow::bail!(
201 "preemptive task execution requires both the node selector and taint to be \
202 configured"
203 ),
204 };
205
206 let state = Arc::new(State {
207 service_api_key: self.service_api_key,
208 orchestrator: TaskOrchestrator::new(
209 self.database,
210 self.pod_name,
211 self.service_url,
212 self.tasks_namespace,
213 self.storage_class,
214 TransporterInfo {
215 image: self.transporter_image,
216 cpu: self.transporter_cpu,
217 memory: self.transporter_memory,
218 },
219 preemptible_config,
220 )
221 .await?,
222 });
223
224 let server = planetary_server::Server::builder()
225 .address(self.address)
226 .port(self.port)
227 .routers(bon::vec![
228 Router::new()
229 .route("/v1/tasks/{tes_id}", post(Self::start_task))
230 .route("/v1/tasks/{tes_id}", delete(Self::cancel_task))
231 .route("/v1/tasks/{tes_id}/io", get(Self::get_task_io))
232 .route("/v1/tasks/{tes_id}/outputs", put(Self::put_task_outputs))
233 .route("/v1/pods/{name}", patch(Self::adopt_pod))
234 .layer(middleware::from_fn_with_state(state.clone(), auth))
235 ])
236 .build();
237
238 let monitor = Monitor::spawn(state.clone());
240
241 server.run(state, shutdown).await?;
243
244 monitor.shutdown().await;
246 Ok(())
247 }
248
249 async fn start_task(
253 ExtractState(state): ExtractState<Arc<State>>,
254 Path(tes_id): Path<String>,
255 ) -> ServerResponse<()> {
256 tokio::spawn(async move {
257 state.orchestrator.start_task(&tes_id).await;
258 });
259
260 Ok(())
261 }
262
263 async fn cancel_task(
267 ExtractState(state): ExtractState<Arc<State>>,
268 Path(tes_id): Path<String>,
269 ) -> ServerResponse<()> {
270 tokio::spawn(async move {
271 state.orchestrator.cancel_task(&tes_id).await;
272 });
273
274 Ok(())
275 }
276
277 async fn get_task_io(
281 ExtractState(state): ExtractState<Arc<State>>,
282 Path(tes_id): Path<String>,
283 ) -> ServerResponse<Json<TaskIo>> {
284 Ok(Json(
285 state.orchestrator.database().get_task_io(&tes_id).await?,
286 ))
287 }
288
289 async fn put_task_outputs(
293 ExtractState(state): ExtractState<Arc<State>>,
294 Path(tes_id): Path<String>,
295 Json(files): Json<Vec<OutputFile>>,
296 ) -> ServerResponse<()> {
297 state
298 .orchestrator
299 .database()
300 .update_task_output_files(&tes_id, &files)
301 .await?;
302 Ok(())
303 }
304
305 async fn adopt_pod(
309 ExtractState(state): ExtractState<Arc<State>>,
310 Path(name): Path<String>,
311 ) -> ServerResponse<()> {
312 state.orchestrator.adopt_pod(&name).await?;
313 Ok(())
314 }
315}