1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use compute_runner_api::{ArtifactSink, ControlPlane, InputSource, LeaseEnvelope, Runner, TaskCtx};
4use rand::rngs::StdRng;
5use rand::SeedableRng;
6use serde_json::Value;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration as StdDuration, Instant};
10use tokio::sync::Mutex;
11use tokio::time::sleep;
12use tokio_util::sync::CancellationToken;
13use tracing::{debug, error, info, warn};
14use uuid::Uuid;
15
16use crate::{
17 dms::client::DmsClient,
18 heartbeat::{progress_channel, ProgressReceiver, ProgressSender},
19 poller::{jittered_delay_ms, PollerConfig},
20 session::{CapabilitySelector, HeartbeatPolicy, SessionManager},
21};
22
23#[derive(Default)]
25pub struct RunnerRegistry {
26 runners: HashMap<String, Arc<dyn Runner>>,
27}
28
29impl RunnerRegistry {
30 pub fn new() -> Self {
32 Self {
33 runners: HashMap::new(),
34 }
35 }
36
37 pub fn register<R: Runner + 'static>(mut self, runner: R) -> Self {
39 let key = runner.capability().to_string();
40 self.runners.insert(key, Arc::new(runner));
41 self
42 }
43
44 pub fn get(&self, capability: &str) -> Option<Arc<dyn Runner>> {
46 self.runners.get(capability).cloned()
47 }
48
49 pub fn capabilities(&self) -> Vec<String> {
51 let mut caps: Vec<_> = self.runners.keys().cloned().collect();
52 caps.sort();
53 caps
54 }
55
56 pub async fn run_for_lease(
58 &self,
59 lease: &LeaseEnvelope,
60 input: &dyn InputSource,
61 output: &dyn ArtifactSink,
62 ctrl: &dyn ControlPlane,
63 access_token: &dyn compute_runner_api::runner::AccessTokenProvider,
64 ) -> std::result::Result<(), crate::errors::ExecutorError> {
65 let cap = lease.task.capability.as_str();
66 let runner = self
67 .get(cap)
68 .ok_or_else(|| crate::errors::ExecutorError::NoRunner(cap.to_string()))?;
69 let ctx = TaskCtx {
70 lease,
71 input,
72 output,
73 ctrl,
74 access_token,
75 };
76 runner
77 .run(ctx)
78 .await
79 .map_err(|e| crate::errors::ExecutorError::Runner(e.to_string()))
80 }
81}
82
83pub async fn run_node(cfg: crate::config::NodeConfig, runners: RunnerRegistry) -> Result<()> {
85 let shutdown = CancellationToken::new();
86 let signal_token = shutdown.clone();
87 let signal_task = tokio::spawn(async move {
88 if tokio::signal::ctrl_c().await.is_ok() {
89 signal_token.cancel();
90 }
91 });
92
93 let result = run_node_with_shutdown(cfg, runners, shutdown.clone()).await;
94
95 shutdown.cancel();
96 let _ = signal_task.await;
97
98 result
99}
100
101pub async fn run_node_with_shutdown(
102 cfg: crate::config::NodeConfig,
103 runners: RunnerRegistry,
104 shutdown: CancellationToken,
105) -> Result<()> {
106 let siwe = crate::auth::SiweAfterRegistration::from_config(&cfg)?;
107 info!("DDS SIWE authentication configured; waiting for DDS registration callback");
108 let siwe_handle = siwe.start().await?;
109 info!("DDS SIWE token manager started");
110
111 let poll_cfg = PollerConfig {
112 backoff_ms_min: cfg.poll_backoff_ms_min,
113 backoff_ms_max: cfg.poll_backoff_ms_max,
114 };
115
116 loop {
117 if shutdown.is_cancelled() {
118 break;
119 }
120
121 if let Err(err) = siwe_handle.bearer().await {
123 warn!(error = %err, "Failed to obtain SIWE bearer token; backing off");
124 let delay_ms = jittered_delay_ms(poll_cfg);
125 tokio::select! {
126 _ = shutdown.cancelled() => break,
127 _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
128 }
129 }
130
131 let timeout = StdDuration::from_secs(cfg.request_timeout_secs);
132 let dms_client = match crate::dms::client::DmsClient::new(
133 cfg.dms_base_url.clone(),
134 timeout,
135 std::sync::Arc::new(siwe_handle.clone()),
136 ) {
137 Ok(client) => client,
138 Err(err) => {
139 warn!(error = %err, "Failed to create DMS client; backing off");
140 let delay_ms = jittered_delay_ms(poll_cfg);
141 tokio::select! {
142 _ = shutdown.cancelled() => break,
143 _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
144 }
145 }
146 };
147
148 match run_cycle_with_dms(&cfg, &dms_client, &runners).await {
149 Ok(true) => {
150 continue;
152 }
153 Ok(false) => {
154 let delay_ms = jittered_delay_ms(poll_cfg);
155 debug!(delay_ms, "No lease available; backing off before next poll");
156 tokio::select! {
157 _ = shutdown.cancelled() => break,
158 _ = sleep(StdDuration::from_millis(delay_ms)) => {}
159 }
160 }
161 Err(err) => {
162 warn!(error = %err, "DMS cycle failed; backing off");
163 let delay_ms = jittered_delay_ms(poll_cfg);
164 tokio::select! {
165 _ = shutdown.cancelled() => break,
166 _ = sleep(StdDuration::from_millis(delay_ms)) => {}
167 }
168 }
169 }
170 }
171
172 siwe_handle.shutdown().await;
173 info!("Shutdown signal received; exiting run_node loop");
174
175 Ok(())
176}
177
178pub fn build_storage_for_lease(lease: &LeaseEnvelope) -> Result<crate::storage::Ports> {
181 let token = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
182 crate::storage::build_ports(lease, token)
183}
184
185pub fn apply_heartbeat_token_update(
188 token: &crate::storage::TokenRef,
189 hb: &crate::dms::types::HeartbeatResponse,
190) {
191 if let Some(new) = hb.access_token.clone() {
192 token.swap(new);
193 }
194}
195
196pub fn merge_heartbeat_into_lease(
198 lease: &mut LeaseEnvelope,
199 hb: &crate::dms::types::HeartbeatResponse,
200) {
201 if let Some(token) = hb.access_token.clone() {
202 lease.access_token = Some(token);
203 }
204 if let Some(expiry) = hb.access_token_expires_at {
205 lease.access_token_expires_at = Some(expiry);
206 }
207 if let Some(expiry) = hb.lease_expires_at {
208 lease.lease_expires_at = Some(expiry);
209 }
210 if let Some(cancel) = hb.cancel {
211 lease.cancel = cancel;
212 }
213 if let Some(status) = hb.status.clone() {
214 lease.status = Some(status);
215 }
216 if let Some(domain_id) = hb.domain_id {
217 lease.domain_id = Some(domain_id);
218 }
219 if let Some(url) = hb.domain_server_url.clone() {
220 lease.domain_server_url = Some(url);
221 }
222 if let Some(task) = hb.task.clone() {
223 lease.task = task;
224 } else {
225 if let Some(task_id) = hb.task_id {
226 lease.task.id = task_id;
227 }
228 if let Some(job_id) = hb.job_id {
229 lease.task.job_id = Some(job_id);
230 }
231 if let Some(attempts) = hb.attempts {
232 lease.task.attempts = Some(attempts);
233 }
234 if let Some(max_attempts) = hb.max_attempts {
235 lease.task.max_attempts = Some(max_attempts);
236 }
237 if let Some(deps_remaining) = hb.deps_remaining {
238 lease.task.deps_remaining = Some(deps_remaining);
239 }
240 }
241}
242
243pub async fn run_cycle_with_dms(
246 _cfg: &crate::config::NodeConfig,
247 dms: &DmsClient,
248 reg: &RunnerRegistry,
249) -> Result<bool> {
250 use crate::dms::types::{CompleteTaskRequest, FailTaskRequest, HeartbeatRequest};
251 use serde_json::json;
252
253 let capabilities = reg.capabilities();
254 let capability = capabilities
255 .first()
256 .cloned()
257 .ok_or_else(|| anyhow!("no runners registered"))?;
258
259 let mut lease = match dms.lease_by_capability(&capability).await? {
261 Some(lease) => lease,
262 None => {
263 return Ok(false);
264 }
265 };
266 if lease.access_token.is_none() {
267 tracing::warn!(
268 "Lease missing access token; storage client will fall back to legacy token flow"
269 );
270 }
271
272 let selector = CapabilitySelector::new(capabilities.clone());
274 let session = SessionManager::new(selector);
275 let policy = HeartbeatPolicy::default_policy();
276 let mut rng = StdRng::from_entropy();
277 let task_id = lease.task.id;
278 let report_setup_failure = |stage: &'static str, err: &anyhow::Error| {
279 let details = json!({
280 "stage": stage,
281 "error": err.to_string(),
282 });
283 async move {
284 let body = FailTaskRequest {
285 reason: "node_setup_failed".into(),
286 details,
287 };
288 dms.fail(task_id, &body).await
289 }
290 };
291
292 let snapshot = match session
293 .start_session(&lease, Instant::now(), &policy, &mut rng)
294 .await
295 {
296 Ok(snapshot) => snapshot,
297 Err(err) => {
298 let original = anyhow!("failed to initialise session: {err}");
299 if let Err(fail_err) = report_setup_failure("start_session", &original).await {
300 warn!(
301 error = %fail_err,
302 task_id = %task_id,
303 "failed to report setup failure"
304 );
305 return Err(original);
306 }
307 return Ok(true);
308 }
309 };
310 if snapshot.cancel() {
311 warn!(
312 task_id = %snapshot.task_id(),
313 "Lease already marked as cancelled; skipping execution"
314 );
315 return Ok(true);
316 }
317
318 let token_ref = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
319
320 let heartbeat_initial = match dms
321 .heartbeat(
322 lease.task.id,
323 &HeartbeatRequest {
324 progress: json!({}),
325 events: Vec::new(),
326 },
327 )
328 .await
329 {
330 Ok(response) => response,
331 Err(err) => {
332 if let Err(fail_err) = report_setup_failure("initial_heartbeat", &err).await {
333 warn!(
334 error = %fail_err,
335 task_id = %task_id,
336 "failed to report setup failure"
337 );
338 return Err(err);
339 }
340 return Ok(true);
341 }
342 };
343 apply_heartbeat_token_update(&token_ref, &heartbeat_initial);
344 merge_heartbeat_into_lease(&mut lease, &heartbeat_initial);
345 session
346 .apply_heartbeat(
347 &heartbeat_initial,
348 Some(json!({})),
349 Instant::now(),
350 &policy,
351 &mut rng,
352 )
353 .await
354 .map_err(|err| anyhow!("failed to refresh session after heartbeat: {err}"))?;
355
356 let ports = match crate::storage::build_ports(&lease, token_ref.clone()) {
357 Ok(ports) => ports,
358 Err(err) => {
359 if let Err(fail_err) = report_setup_failure("build_ports", &err).await {
360 warn!(
361 error = %fail_err,
362 task_id = %task_id,
363 "failed to report setup failure"
364 );
365 return Err(err);
366 }
367 return Ok(true);
368 }
369 };
370
371 let (progress_tx, progress_rx) = progress_channel();
372 let control_state = Arc::new(Mutex::new(ControlState::default()));
373 {
374 let mut guard = control_state.lock().await;
375 guard.progress = json!({});
376 guard.events = Vec::new();
377 }
378
379 let runner_cancel = CancellationToken::new();
380 let heartbeat_shutdown = CancellationToken::new();
381
382 let ctrl = EngineControlPlane::new(
383 runner_cancel.clone(),
384 progress_tx.clone(),
385 control_state.clone(),
386 );
387
388 progress_tx.update(json!({}), Vec::new());
390
391 let heartbeat_driver = HeartbeatDriver::new(
392 dms.clone(),
393 HeartbeatDriverArgs {
394 session: session.clone(),
395 policy,
396 rng,
397 progress_rx,
398 state: control_state.clone(),
399 token_ref: token_ref.clone(),
400 runner_cancel: runner_cancel.clone(),
401 shutdown: heartbeat_shutdown.clone(),
402 task_id: lease.task.id,
403 },
404 );
405 let heartbeat_handle = tokio::spawn(async move { heartbeat_driver.run().await });
406
407 let run_res = reg
408 .run_for_lease(&lease, &*ports.input, &*ports.output, &ctrl, &token_ref)
409 .await;
410
411 {
415 let state = control_state.lock().await;
416 progress_tx.update(state.progress.clone(), state.events.clone());
417 }
418 sleep(StdDuration::from_millis(200)).await;
419
420 heartbeat_shutdown.cancel();
421 let heartbeat_result = match heartbeat_handle.await {
422 Ok(result) => result,
423 Err(err) => {
424 warn!(error = %err, "heartbeat loop task failed");
425 HeartbeatLoopResult::Completed
426 }
427 };
428
429 match heartbeat_result {
430 HeartbeatLoopResult::Completed => {}
431 HeartbeatLoopResult::Cancelled => {
432 info!(
433 task_id = %lease.task.id,
434 "Lease cancelled during execution; skipping completion"
435 );
436 runner_cancel.cancel();
437 return Ok(true);
438 }
439 HeartbeatLoopResult::LostLease(err) => {
440 warn!(
441 task_id = %lease.task.id,
442 error = %err,
443 "Lease lost during heartbeat; abandoning task"
444 );
445 runner_cancel.cancel();
446 return Ok(true);
447 }
448 }
449
450 let uploaded_artifacts = ports.uploaded_artifacts();
451 let artifacts_json: Vec<Value> = uploaded_artifacts
452 .iter()
453 .map(|artifact| {
454 json!({
455 "logical_path": artifact.logical_path,
456 "name": artifact.name,
457 "data_type": artifact.data_type,
458 "id": artifact.id,
459 })
460 })
461 .collect();
462 let output_cids: Vec<String> = uploaded_artifacts
463 .iter()
464 .filter_map(|artifact| artifact.id.clone())
465 .collect();
466 let job_info = json!({
467 "task_id": lease.task.id,
468 "job_id": lease.task.job_id,
469 "domain_id": lease.domain_id,
470 "capability": lease.task.capability,
471 });
472
473 match run_res {
475 Ok(()) => {
476 let body = CompleteTaskRequest {
477 output_cids,
478 meta: json!({
479 "job": job_info,
480 "artifacts": artifacts_json,
481 }),
482 };
483 dms.complete(lease.task.id, &body).await?;
484 }
485 Err(err) => {
486 error!(
487 task_id = %lease.task.id,
488 job_id = ?lease.task.job_id,
489 capability = %lease.task.capability,
490 error = %err,
491 debug = ?err,
492 "Runner execution failed; reporting failure to DMS"
493 );
494 let body = FailTaskRequest {
495 reason: err.to_string(),
496 details: json!({
497 "job": job_info,
498 "artifacts": artifacts_json,
499 }),
500 };
501 dms.fail(lease.task.id, &body)
502 .await
503 .with_context(|| format!("report fail for task {} to DMS", lease.task.id))?;
504 }
505 }
506
507 Ok(true)
508}
509
510#[derive(Default)]
511pub struct ControlState {
512 progress: Value,
513 events: Vec<Value>,
514}
515
516struct EngineControlPlane {
517 cancel: CancellationToken,
518 progress_tx: ProgressSender,
519 state: Arc<Mutex<ControlState>>,
520}
521
522impl EngineControlPlane {
523 pub fn new(
524 cancel: CancellationToken,
525 progress_tx: ProgressSender,
526 state: Arc<Mutex<ControlState>>,
527 ) -> Self {
528 Self {
529 cancel,
530 progress_tx,
531 state,
532 }
533 }
534}
535
536#[async_trait]
537impl ControlPlane for EngineControlPlane {
538 async fn is_cancelled(&self) -> bool {
539 self.cancel.is_cancelled()
540 }
541
542 async fn progress(&self, value: Value) -> Result<()> {
543 let events = {
544 let mut state = self.state.lock().await;
545 state.progress = value.clone();
546 state.events.clone()
547 };
548 self.progress_tx.update(value, events);
549 Ok(())
550 }
551
552 async fn log_event(&self, fields: Value) -> Result<()> {
553 let (progress, events) = {
554 let mut state = self.state.lock().await;
555 state.events.push(fields.clone());
556 (state.progress.clone(), state.events.clone())
557 };
558 self.progress_tx.update(progress, events);
559 Ok(())
560 }
561}
562
563pub enum HeartbeatLoopResult {
564 Completed,
565 Cancelled,
566 LostLease(anyhow::Error),
567}
568
/// Transport used by [`HeartbeatDriver`] to send heartbeats.
///
/// Abstracting the transport lets the driver run against something other than a
/// [`DmsClient`]; this file provides the `DmsClient` implementation.
#[async_trait]
pub trait HeartbeatTransport: Send + Sync + Clone + 'static {
    /// Sends the heartbeat `body` for `task_id` and returns the server's response.
    async fn post_heartbeat(
        &self,
        task_id: Uuid,
        body: &crate::dms::types::HeartbeatRequest,
    ) -> Result<crate::dms::types::HeartbeatResponse>;
}
577
578#[async_trait]
579impl HeartbeatTransport for DmsClient {
580 async fn post_heartbeat(
581 &self,
582 task_id: Uuid,
583 body: &crate::dms::types::HeartbeatRequest,
584 ) -> Result<crate::dms::types::HeartbeatResponse> {
585 self.heartbeat(task_id, body).await
586 }
587}
588
/// Everything a [`HeartbeatDriver`] needs besides its transport.
///
/// `HeartbeatDriver::new` moves each field into the driver unchanged.
pub struct HeartbeatDriverArgs {
    /// Session that supplies snapshots and is refreshed from each heartbeat response.
    pub session: SessionManager,
    /// Policy passed to the session whenever a heartbeat response is applied.
    pub policy: HeartbeatPolicy,
    /// Random source passed to the session whenever a heartbeat response is applied.
    pub rng: StdRng,
    /// Receiving end of the progress channel; each message triggers a heartbeat.
    pub progress_rx: ProgressReceiver,
    /// Shared progress/event state that is snapshotted for each heartbeat.
    pub state: Arc<Mutex<ControlState>>,
    /// Token reference updated when a heartbeat response carries a new access token.
    pub token_ref: crate::storage::TokenRef,
    /// Token the driver cancels to stop the runner.
    pub runner_cancel: CancellationToken,
    /// Token that stops the heartbeat loop when cancelled.
    pub shutdown: CancellationToken,
    /// Task the heartbeats are sent for.
    pub task_id: Uuid,
}
600
601pub struct HeartbeatDriver<T>
602where
603 T: HeartbeatTransport,
604{
605 transport: T,
606 session: SessionManager,
607 policy: HeartbeatPolicy,
608 rng: StdRng,
609 progress_rx: ProgressReceiver,
610 state: Arc<Mutex<ControlState>>,
611 token_ref: crate::storage::TokenRef,
612 runner_cancel: CancellationToken,
613 shutdown: CancellationToken,
614 task_id: Uuid,
615 last_progress: Value,
616}
617
618impl<T> HeartbeatDriver<T>
619where
620 T: HeartbeatTransport,
621{
622 pub fn new(transport: T, args: HeartbeatDriverArgs) -> Self {
623 Self {
624 transport,
625 session: args.session,
626 policy: args.policy,
627 rng: args.rng,
628 progress_rx: args.progress_rx,
629 state: args.state,
630 token_ref: args.token_ref,
631 runner_cancel: args.runner_cancel,
632 shutdown: args.shutdown,
633 task_id: args.task_id,
634 last_progress: Value::default(),
635 }
636 }
637
638 pub async fn run(mut self) -> HeartbeatLoopResult {
639 loop {
640 if self.shutdown.is_cancelled() || self.runner_cancel.is_cancelled() {
641 return HeartbeatLoopResult::Completed;
642 }
643
644 let snapshot = match self.session.snapshot().await {
645 Some(s) => s,
646 None => return HeartbeatLoopResult::Completed,
647 };
648
649 let ttl_delay = snapshot
650 .next_heartbeat_due()
651 .map(|due| due.saturating_duration_since(Instant::now()));
652
653 if let Some(delay) = ttl_delay {
654 tokio::select! {
655 _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
656 progress = self.progress_rx.recv() => {
657 if let Some(data) = progress {
658 if let Some(outcome) = self.handle_progress(data).await {
659 return outcome;
660 }
661 } else {
662 return HeartbeatLoopResult::Completed;
663 }
664 }
665 _ = tokio::time::sleep(delay) => {
666 if let Some(outcome) = self.handle_ttl().await {
667 return outcome;
668 }
669 }
670 }
671 } else {
672 tokio::select! {
673 _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
674 progress = self.progress_rx.recv() => {
675 if let Some(data) = progress {
676 if let Some(outcome) = self.handle_progress(data).await {
677 return outcome;
678 }
679 } else {
680 return HeartbeatLoopResult::Completed;
681 }
682 }
683 }
684 }
685 }
686 }
687
688 async fn handle_progress(
689 &mut self,
690 data: crate::heartbeat::HeartbeatData,
691 ) -> Option<HeartbeatLoopResult> {
692 self.last_progress = data.progress.clone();
693 let (progress, events) = self.snapshot_state().await;
694 self.send_and_update(progress, events).await
695 }
696
697 async fn handle_ttl(&mut self) -> Option<HeartbeatLoopResult> {
698 let (progress, events) = self.snapshot_state().await;
699 self.send_and_update(progress, events).await
700 }
701
702 async fn snapshot_state(&self) -> (Value, Vec<Value>) {
703 let state = self.state.lock().await;
704 (state.progress.clone(), state.events.clone())
705 }
706
707 async fn send_and_update(
708 &mut self,
709 progress: Value,
710 events: Vec<Value>,
711 ) -> Option<HeartbeatLoopResult> {
712 let request = crate::dms::types::HeartbeatRequest {
713 progress: progress.clone(),
714 events: events.clone(),
715 };
716
717 match self.transport.post_heartbeat(self.task_id, &request).await {
718 Ok(update) => {
719 if !events.is_empty() {
720 let mut state = self.state.lock().await;
721 if state.events.len() >= events.len()
722 && state.events[..events.len()] == events[..]
723 {
724 state.events.drain(0..events.len());
725 }
726 }
727 apply_heartbeat_token_update(&self.token_ref, &update);
728 if let Some(task) = &update.task {
729 self.task_id = task.id;
730 } else if let Some(task_id) = update.task_id {
731 self.task_id = task_id;
732 }
733 if let Err(err) = self
734 .session
735 .apply_heartbeat(
736 &update,
737 Some(progress.clone()),
738 Instant::now(),
739 &self.policy,
740 &mut self.rng,
741 )
742 .await
743 {
744 return Some(HeartbeatLoopResult::LostLease(anyhow::Error::new(err)));
745 }
746 if update.cancel.unwrap_or(false) {
747 self.runner_cancel.cancel();
748 return Some(HeartbeatLoopResult::Cancelled);
749 }
750 None
751 }
752 Err(err) => {
753 self.runner_cancel.cancel();
754 Some(HeartbeatLoopResult::LostLease(err))
755 }
756 }
757 }
758}