1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use compute_runner_api::{ArtifactSink, ControlPlane, InputSource, LeaseEnvelope, Runner, TaskCtx};
4use rand::rngs::StdRng;
5use rand::SeedableRng;
6use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration as StdDuration, Instant};
10use tokio::sync::Mutex;
11use tokio::time::sleep;
12use tokio_util::sync::CancellationToken;
13use tracing::{error, info, warn};
14use uuid::Uuid;
15
16use crate::{
17 dms::client::DmsClient,
18 heartbeat::{progress_channel, ProgressReceiver, ProgressSender},
19 poller::{jittered_delay_ms, PollerConfig},
20 session::{CapabilitySelector, HeartbeatPolicy, SessionManager},
21};
22
23#[derive(Default)]
25pub struct RunnerRegistry {
26 runners: HashMap<String, Arc<dyn Runner>>,
27}
28
29impl RunnerRegistry {
30 pub fn new() -> Self {
32 Self {
33 runners: HashMap::new(),
34 }
35 }
36
37 pub fn register<R: Runner + 'static>(mut self, runner: R) -> Self {
39 let key = runner.capability().to_string();
40 self.runners.insert(key, Arc::new(runner));
41 self
42 }
43
44 pub fn get(&self, capability: &str) -> Option<Arc<dyn Runner>> {
46 self.runners.get(capability).cloned()
47 }
48
49 pub fn capabilities(&self) -> Vec<String> {
51 let mut caps: Vec<_> = self.runners.keys().cloned().collect();
52 caps.sort();
53 caps
54 }
55
56 pub async fn run_for_lease(
58 &self,
59 lease: &LeaseEnvelope,
60 input: &dyn InputSource,
61 output: &dyn ArtifactSink,
62 ctrl: &dyn ControlPlane,
63 ) -> std::result::Result<(), crate::errors::ExecutorError> {
64 let cap = lease.task.capability.as_str();
65 let runner = self
66 .get(cap)
67 .ok_or_else(|| crate::errors::ExecutorError::NoRunner(cap.to_string()))?;
68 let ctx = TaskCtx {
69 lease,
70 input,
71 output,
72 ctrl,
73 };
74 runner
75 .run(ctx)
76 .await
77 .map_err(|e| crate::errors::ExecutorError::Runner(e.to_string()))
78 }
79}
80
81pub async fn run_node(cfg: crate::config::NodeConfig, runners: RunnerRegistry) -> Result<()> {
83 let shutdown = CancellationToken::new();
84 let signal_token = shutdown.clone();
85 let signal_task = tokio::spawn(async move {
86 if tokio::signal::ctrl_c().await.is_ok() {
87 signal_token.cancel();
88 }
89 });
90
91 let result = run_node_with_shutdown(cfg, runners, shutdown.clone()).await;
92
93 shutdown.cancel();
94 let _ = signal_task.await;
95
96 result
97}
98
99pub async fn run_node_with_shutdown(
100 cfg: crate::config::NodeConfig,
101 runners: RunnerRegistry,
102 shutdown: CancellationToken,
103) -> Result<()> {
104 let siwe = crate::auth::SiweAfterRegistration::from_config(&cfg)?;
105 info!("DDS SIWE authentication configured; waiting for DDS registration callback");
106 let siwe_handle = siwe.start().await?;
107 info!("DDS SIWE token manager started");
108
109 let poll_cfg = PollerConfig {
110 backoff_ms_min: cfg.poll_backoff_ms_min,
111 backoff_ms_max: cfg.poll_backoff_ms_max,
112 };
113
114 loop {
115 if shutdown.is_cancelled() {
116 break;
117 }
118
119 let bearer = match siwe_handle.bearer().await {
120 Ok(token) => token,
121 Err(err) => {
122 warn!(error = %err, "Failed to obtain SIWE bearer token; backing off");
123 let delay_ms = jittered_delay_ms(poll_cfg);
124 tokio::select! {
125 _ = shutdown.cancelled() => break,
126 _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
127 }
128 }
129 };
130
131 let timeout = StdDuration::from_secs(cfg.request_timeout_secs);
132 let dms_client = match crate::dms::client::DmsClient::new(
133 cfg.dms_base_url.clone(),
134 timeout,
135 Some(bearer),
136 ) {
137 Ok(client) => client,
138 Err(err) => {
139 warn!(error = %err, "Failed to create DMS client; backing off");
140 let delay_ms = jittered_delay_ms(poll_cfg);
141 tokio::select! {
142 _ = shutdown.cancelled() => break,
143 _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
144 }
145 }
146 };
147
148 match run_cycle_with_dms(&cfg, &dms_client, &runners).await {
149 Ok(true) => {
150 continue;
152 }
153 Ok(false) => {
154 let delay_ms = jittered_delay_ms(poll_cfg);
155 info!(delay_ms, "No lease available; backing off before next poll");
156 tokio::select! {
157 _ = shutdown.cancelled() => break,
158 _ = sleep(StdDuration::from_millis(delay_ms)) => {}
159 }
160 }
161 Err(err) => {
162 warn!(error = %err, "DMS cycle failed; backing off");
163 let delay_ms = jittered_delay_ms(poll_cfg);
164 tokio::select! {
165 _ = shutdown.cancelled() => break,
166 _ = sleep(StdDuration::from_millis(delay_ms)) => {}
167 }
168 }
169 }
170 }
171
172 siwe_handle.shutdown().await;
173 info!("Shutdown signal received; exiting run_node loop");
174
175 Ok(())
176}
177
178pub fn build_storage_for_lease(lease: &LeaseEnvelope) -> Result<crate::storage::Ports> {
181 let token = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
182 crate::storage::build_ports(lease, token)
183}
184
185pub fn apply_heartbeat_token_update(
188 token: &crate::storage::TokenRef,
189 hb: &crate::dms::types::HeartbeatResponse,
190) {
191 if let Some(new) = hb.access_token.clone() {
192 token.swap(new);
193 }
194}
195
196pub async fn run_cycle_with_dms(
199 _cfg: &crate::config::NodeConfig,
200 dms: &DmsClient,
201 reg: &RunnerRegistry,
202) -> Result<bool> {
203 use crate::dms::types::{CompleteTaskRequest, FailTaskRequest, HeartbeatRequest};
204 use serde_json::json;
205
206 let capabilities = reg.capabilities();
207 let capability = capabilities
208 .first()
209 .cloned()
210 .ok_or_else(|| anyhow!("no runners registered"))?;
211
212 let mut lease = match dms.lease_by_capability(&capability).await? {
214 Some(lease) => lease,
215 None => {
216 return Ok(false);
217 }
218 };
219 if lease.access_token.is_none() {
220 tracing::warn!(
221 "Lease missing access token; storage client will fall back to legacy token flow"
222 );
223 }
224
225 let selector = CapabilitySelector::new(capabilities.clone());
227 let session = SessionManager::new(selector);
228 let policy = HeartbeatPolicy::default_policy();
229 let mut rng = StdRng::from_entropy();
230 let snapshot = session
231 .start_session(&lease, Instant::now(), &policy, &mut rng)
232 .await
233 .map_err(|err| anyhow!("failed to initialise session: {err}"))?;
234 if snapshot.cancel() {
235 warn!(
236 task_id = %snapshot.task_id(),
237 "Lease already marked as cancelled; skipping execution"
238 );
239 return Ok(true);
240 }
241
242 let token_ref = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
243
244 let heartbeat_initial = dms
245 .heartbeat(
246 lease.task.id,
247 &HeartbeatRequest {
248 progress: json!({}),
249 events: json!({}),
250 },
251 )
252 .await?;
253 apply_heartbeat_token_update(&token_ref, &heartbeat_initial);
254 session
255 .apply_heartbeat(
256 &heartbeat_initial,
257 Some(json!({})),
258 Instant::now(),
259 &policy,
260 &mut rng,
261 )
262 .await
263 .map_err(|err| anyhow!("failed to refresh session after heartbeat: {err}"))?;
264 lease.access_token = heartbeat_initial.access_token.clone();
265 lease.access_token_expires_at = heartbeat_initial.access_token_expires_at;
266 lease.lease_expires_at = heartbeat_initial.lease_expires_at;
267 lease.cancel = heartbeat_initial.cancel;
268
269 let ports = crate::storage::build_ports(&lease, token_ref.clone())?;
270
271 let (progress_tx, progress_rx) = progress_channel();
272 let control_state = Arc::new(Mutex::new(ControlState::default()));
273 {
274 let mut guard = control_state.lock().await;
275 guard.progress = json!({});
276 guard.events = json!({});
277 }
278
279 let runner_cancel = CancellationToken::new();
280 let heartbeat_shutdown = CancellationToken::new();
281
282 let ctrl = EngineControlPlane::new(
283 runner_cancel.clone(),
284 progress_tx.clone(),
285 control_state.clone(),
286 );
287
288 progress_tx.update(json!({}), json!({}));
290
291 let heartbeat_driver = HeartbeatDriver::new(
292 dms.clone(),
293 HeartbeatDriverArgs {
294 session: session.clone(),
295 policy,
296 rng,
297 progress_rx,
298 state: control_state.clone(),
299 token_ref: token_ref.clone(),
300 runner_cancel: runner_cancel.clone(),
301 shutdown: heartbeat_shutdown.clone(),
302 task_id: lease.task.id,
303 },
304 );
305 let heartbeat_handle = tokio::spawn(async move { heartbeat_driver.run().await });
306
307 let run_res = reg
308 .run_for_lease(&lease, &*ports.input, &*ports.output, &ctrl)
309 .await;
310
311 heartbeat_shutdown.cancel();
312 let heartbeat_result = match heartbeat_handle.await {
313 Ok(result) => result,
314 Err(err) => {
315 warn!(error = %err, "heartbeat loop task failed");
316 HeartbeatLoopResult::Completed
317 }
318 };
319
320 match heartbeat_result {
321 HeartbeatLoopResult::Completed => {}
322 HeartbeatLoopResult::Cancelled => {
323 info!(
324 task_id = %lease.task.id,
325 "Lease cancelled during execution; skipping completion"
326 );
327 runner_cancel.cancel();
328 return Ok(true);
329 }
330 HeartbeatLoopResult::LostLease(err) => {
331 warn!(
332 task_id = %lease.task.id,
333 error = %err,
334 "Lease lost during heartbeat; abandoning task"
335 );
336 runner_cancel.cancel();
337 return Ok(true);
338 }
339 }
340
341 let uploaded_artifacts = ports.uploaded_artifacts();
342 let artifacts_json: Vec<Value> = uploaded_artifacts
343 .iter()
344 .map(|artifact| {
345 json!({
346 "logical_path": artifact.logical_path,
347 "name": artifact.name,
348 "data_type": artifact.data_type,
349 "id": artifact.id,
350 })
351 })
352 .collect();
353 let job_info = json!({
354 "task_id": lease.task.id,
355 "job_id": lease.task.job_id,
356 "domain_id": lease.domain_id,
357 "capability": lease.task.capability,
358 });
359
360 match run_res {
362 Ok(()) => {
363 let body = CompleteTaskRequest {
364 outputs_index: json!({ "artifacts": artifacts_json.clone() }),
365 result: json!({
366 "job": job_info,
367 "artifacts": artifacts_json,
368 }),
369 };
370 dms.complete(lease.task.id, &body).await?;
371 }
372 Err(err) => {
373 error!(
374 task_id = %lease.task.id,
375 job_id = ?lease.task.job_id,
376 capability = %lease.task.capability,
377 error = %err,
378 debug = ?err,
379 "Runner execution failed; reporting failure to DMS"
380 );
381 let body = FailTaskRequest {
382 reason: err.to_string(),
383 details: json!({
384 "job": job_info,
385 "artifacts": artifacts_json,
386 }),
387 };
388 dms.fail(lease.task.id, &body)
389 .await
390 .with_context(|| format!("report fail for task {} to DMS", lease.task.id))?;
391 }
392 }
393
394 Ok(true)
395}
396
397#[derive(Default)]
398pub struct ControlState {
399 progress: Value,
400 events: Value,
401}
402
403struct EngineControlPlane {
404 cancel: CancellationToken,
405 progress_tx: ProgressSender,
406 state: Arc<Mutex<ControlState>>,
407}
408
409impl EngineControlPlane {
410 pub fn new(
411 cancel: CancellationToken,
412 progress_tx: ProgressSender,
413 state: Arc<Mutex<ControlState>>,
414 ) -> Self {
415 Self {
416 cancel,
417 progress_tx,
418 state,
419 }
420 }
421}
422
423#[async_trait]
424impl ControlPlane for EngineControlPlane {
425 async fn is_cancelled(&self) -> bool {
426 self.cancel.is_cancelled()
427 }
428
429 async fn progress(&self, value: Value) -> Result<()> {
430 let events = {
431 let mut state = self.state.lock().await;
432 state.progress = value.clone();
433 state.events.clone()
434 };
435 self.progress_tx.update(value, events);
436 Ok(())
437 }
438
439 async fn log_event(&self, fields: Value) -> Result<()> {
440 let progress = {
441 let mut state = self.state.lock().await;
442 state.events = fields.clone();
443 state.progress.clone()
444 };
445 self.progress_tx.update(progress, fields);
446 Ok(())
447 }
448}
449
450pub enum HeartbeatLoopResult {
451 Completed,
452 Cancelled,
453 LostLease(anyhow::Error),
454}
455
/// Abstraction over how heartbeats are delivered, so the heartbeat driver is
/// not tied to a concrete client type.
#[async_trait]
pub trait HeartbeatTransport: Send + Sync + Clone + 'static {
    /// Sends one heartbeat for `task_id` carrying `body` and returns the
    /// response, or an error when the heartbeat could not be delivered.
    async fn post_heartbeat(
        &self,
        task_id: Uuid,
        body: &crate::dms::types::HeartbeatRequest,
    ) -> Result<crate::dms::types::HeartbeatResponse>;
}
464
465#[async_trait]
466impl HeartbeatTransport for DmsClient {
467 async fn post_heartbeat(
468 &self,
469 task_id: Uuid,
470 body: &crate::dms::types::HeartbeatRequest,
471 ) -> Result<crate::dms::types::HeartbeatResponse> {
472 self.heartbeat(task_id, body).await
473 }
474}
475
/// Everything a [`HeartbeatDriver`] needs besides its transport.
pub struct HeartbeatDriverArgs {
    /// Session that supplies the next heartbeat deadline and receives each
    /// heartbeat response.
    pub session: SessionManager,
    /// Policy passed to the session whenever a heartbeat response is applied.
    pub policy: HeartbeatPolicy,
    /// Random source passed to the session alongside `policy`.
    pub rng: StdRng,
    /// Receiving end of the progress channel; each message triggers a heartbeat.
    pub progress_rx: ProgressReceiver,
    /// Shared progress/event state, read when a heartbeat falls due on its own.
    pub state: Arc<Mutex<ControlState>>,
    /// Token handle updated from every heartbeat response that carries a token.
    pub token_ref: crate::storage::TokenRef,
    /// Token the driver cancels to tell the runner to stop.
    pub runner_cancel: CancellationToken,
    /// Token whose cancellation stops the heartbeat loop.
    pub shutdown: CancellationToken,
    /// Identifier of the task the heartbeats are sent for.
    pub task_id: Uuid,
}
487
488pub struct HeartbeatDriver<T>
489where
490 T: HeartbeatTransport,
491{
492 transport: T,
493 session: SessionManager,
494 policy: HeartbeatPolicy,
495 rng: StdRng,
496 progress_rx: ProgressReceiver,
497 state: Arc<Mutex<ControlState>>,
498 token_ref: crate::storage::TokenRef,
499 runner_cancel: CancellationToken,
500 shutdown: CancellationToken,
501 task_id: Uuid,
502 last_progress: Value,
503}
504
505impl<T> HeartbeatDriver<T>
506where
507 T: HeartbeatTransport,
508{
509 pub fn new(transport: T, args: HeartbeatDriverArgs) -> Self {
510 Self {
511 transport,
512 session: args.session,
513 policy: args.policy,
514 rng: args.rng,
515 progress_rx: args.progress_rx,
516 state: args.state,
517 token_ref: args.token_ref,
518 runner_cancel: args.runner_cancel,
519 shutdown: args.shutdown,
520 task_id: args.task_id,
521 last_progress: Value::default(),
522 }
523 }
524
525 pub async fn run(mut self) -> HeartbeatLoopResult {
526 loop {
527 if self.shutdown.is_cancelled() || self.runner_cancel.is_cancelled() {
528 return HeartbeatLoopResult::Completed;
529 }
530
531 let snapshot = match self.session.snapshot().await {
532 Some(s) => s,
533 None => return HeartbeatLoopResult::Completed,
534 };
535
536 let ttl_delay = snapshot
537 .next_heartbeat_due()
538 .map(|due| due.saturating_duration_since(Instant::now()));
539
540 if let Some(delay) = ttl_delay {
541 tokio::select! {
542 _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
543 progress = self.progress_rx.recv() => {
544 if let Some(data) = progress {
545 if let Some(outcome) = self.handle_progress(data).await {
546 return outcome;
547 }
548 } else {
549 return HeartbeatLoopResult::Completed;
550 }
551 }
552 _ = tokio::time::sleep(delay) => {
553 if let Some(outcome) = self.handle_ttl().await {
554 return outcome;
555 }
556 }
557 }
558 } else {
559 tokio::select! {
560 _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
561 progress = self.progress_rx.recv() => {
562 if let Some(data) = progress {
563 if let Some(outcome) = self.handle_progress(data).await {
564 return outcome;
565 }
566 } else {
567 return HeartbeatLoopResult::Completed;
568 }
569 }
570 }
571 }
572 }
573 }
574
575 async fn handle_progress(
576 &mut self,
577 data: crate::heartbeat::HeartbeatData,
578 ) -> Option<HeartbeatLoopResult> {
579 self.last_progress = data.progress.clone();
580 self.send_and_update(data.progress, data.events).await
581 }
582
583 async fn handle_ttl(&mut self) -> Option<HeartbeatLoopResult> {
584 let (progress, events) = self.snapshot_state().await;
585 self.send_and_update(progress, events).await
586 }
587
588 async fn snapshot_state(&self) -> (Value, Value) {
589 let state = self.state.lock().await;
590 (state.progress.clone(), state.events.clone())
591 }
592
593 async fn send_and_update(
594 &mut self,
595 progress: Value,
596 events: Value,
597 ) -> Option<HeartbeatLoopResult> {
598 let request = crate::dms::types::HeartbeatRequest {
599 progress: progress.clone(),
600 events: events.clone(),
601 };
602
603 match self.transport.post_heartbeat(self.task_id, &request).await {
604 Ok(lease) => {
605 apply_heartbeat_token_update(&self.token_ref, &lease);
606 if let Err(err) = self
607 .session
608 .apply_heartbeat(
609 &lease,
610 Some(progress.clone()),
611 Instant::now(),
612 &self.policy,
613 &mut self.rng,
614 )
615 .await
616 {
617 return Some(HeartbeatLoopResult::LostLease(anyhow::Error::new(err)));
618 }
619 if lease.cancel {
620 self.runner_cancel.cancel();
621 return Some(HeartbeatLoopResult::Cancelled);
622 }
623 None
624 }
625 Err(err) => {
626 self.runner_cancel.cancel();
627 Some(HeartbeatLoopResult::LostLease(err))
628 }
629 }
630 }
631}