1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use compute_runner_api::{ArtifactSink, ControlPlane, InputSource, LeaseEnvelope, Runner, TaskCtx};
4use rand::rngs::StdRng;
5use rand::SeedableRng;
6use serde_json::Value;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration as StdDuration, Instant};
10use tokio::sync::Mutex;
11use tokio::time::sleep;
12use tokio_util::sync::CancellationToken;
13use tracing::{error, info, warn};
14use uuid::Uuid;
15
16use crate::{
17 dms::client::DmsClient,
18 heartbeat::{progress_channel, ProgressReceiver, ProgressSender},
19 poller::{jittered_delay_ms, PollerConfig},
20 session::{CapabilitySelector, HeartbeatPolicy, SessionManager},
21};
22
/// Registry mapping a capability string to the runner that executes
/// tasks requiring that capability.
#[derive(Default)]
pub struct RunnerRegistry {
    // Keyed by `Runner::capability()`; runners are stored behind `Arc`
    // so lookups can hand out shared handles.
    runners: HashMap<String, Arc<dyn Runner>>,
}
28
29impl RunnerRegistry {
30 pub fn new() -> Self {
32 Self {
33 runners: HashMap::new(),
34 }
35 }
36
37 pub fn register<R: Runner + 'static>(mut self, runner: R) -> Self {
39 let key = runner.capability().to_string();
40 self.runners.insert(key, Arc::new(runner));
41 self
42 }
43
44 pub fn get(&self, capability: &str) -> Option<Arc<dyn Runner>> {
46 self.runners.get(capability).cloned()
47 }
48
49 pub fn capabilities(&self) -> Vec<String> {
51 let mut caps: Vec<_> = self.runners.keys().cloned().collect();
52 caps.sort();
53 caps
54 }
55
56 pub async fn run_for_lease(
58 &self,
59 lease: &LeaseEnvelope,
60 input: &dyn InputSource,
61 output: &dyn ArtifactSink,
62 ctrl: &dyn ControlPlane,
63 access_token: &dyn compute_runner_api::runner::AccessTokenProvider,
64 ) -> std::result::Result<(), crate::errors::ExecutorError> {
65 let cap = lease.task.capability.as_str();
66 let runner = self
67 .get(cap)
68 .ok_or_else(|| crate::errors::ExecutorError::NoRunner(cap.to_string()))?;
69 let ctx = TaskCtx {
70 lease,
71 input,
72 output,
73 ctrl,
74 access_token,
75 };
76 runner
77 .run(ctx)
78 .await
79 .map_err(|e| crate::errors::ExecutorError::Runner(e.to_string()))
80 }
81}
82
83pub async fn run_node(cfg: crate::config::NodeConfig, runners: RunnerRegistry) -> Result<()> {
85 let shutdown = CancellationToken::new();
86 let signal_token = shutdown.clone();
87 let signal_task = tokio::spawn(async move {
88 if tokio::signal::ctrl_c().await.is_ok() {
89 signal_token.cancel();
90 }
91 });
92
93 let result = run_node_with_shutdown(cfg, runners, shutdown.clone()).await;
94
95 shutdown.cancel();
96 let _ = signal_task.await;
97
98 result
99}
100
/// Main polling loop: configures DDS SIWE authentication, then repeatedly
/// polls the DMS for leases until `shutdown` is cancelled.
///
/// Each iteration either handles one lease via [`run_cycle_with_dms`] or
/// backs off with a jittered delay when no lease is available or an error
/// occurred. Per-iteration errors are logged and retried; only setup
/// failures (SIWE config/start) abort with `Err`.
pub async fn run_node_with_shutdown(
    cfg: crate::config::NodeConfig,
    runners: RunnerRegistry,
    shutdown: CancellationToken,
) -> Result<()> {
    let siwe = crate::auth::SiweAfterRegistration::from_config(&cfg)?;
    info!("DDS SIWE authentication configured; waiting for DDS registration callback");
    let siwe_handle = siwe.start().await?;
    info!("DDS SIWE token manager started");

    let poll_cfg = PollerConfig {
        backoff_ms_min: cfg.poll_backoff_ms_min,
        backoff_ms_max: cfg.poll_backoff_ms_max,
    };

    loop {
        if shutdown.is_cancelled() {
            break;
        }

        // A bearer token is required before talking to the DMS; back off
        // (interruptible by shutdown) until one can be obtained.
        if let Err(err) = siwe_handle.bearer().await {
            warn!(error = %err, "Failed to obtain SIWE bearer token; backing off");
            let delay_ms = jittered_delay_ms(poll_cfg);
            tokio::select! {
                _ = shutdown.cancelled() => break,
                _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
            }
        }

        // The DMS client is rebuilt every iteration so it always uses the
        // current SIWE handle and configured request timeout.
        let timeout = StdDuration::from_secs(cfg.request_timeout_secs);
        let dms_client = match crate::dms::client::DmsClient::new(
            cfg.dms_base_url.clone(),
            timeout,
            std::sync::Arc::new(siwe_handle.clone()),
        ) {
            Ok(client) => client,
            Err(err) => {
                warn!(error = %err, "Failed to create DMS client; backing off");
                let delay_ms = jittered_delay_ms(poll_cfg);
                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
                }
            }
        };

        match run_cycle_with_dms(&cfg, &dms_client, &runners).await {
            // A lease was handled; poll again immediately without delay.
            Ok(true) => {
                continue;
            }
            // No lease available; wait a jittered interval before re-polling.
            Ok(false) => {
                let delay_ms = jittered_delay_ms(poll_cfg);
                info!(delay_ms, "No lease available; backing off before next poll");
                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
                }
            }
            // Cycle errors are treated as retriable: log and back off.
            Err(err) => {
                warn!(error = %err, "DMS cycle failed; backing off");
                let delay_ms = jittered_delay_ms(poll_cfg);
                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
                }
            }
        }
    }

    siwe_handle.shutdown().await;
    info!("Shutdown signal received; exiting run_node loop");

    Ok(())
}
177
178pub fn build_storage_for_lease(lease: &LeaseEnvelope) -> Result<crate::storage::Ports> {
181 let token = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
182 crate::storage::build_ports(lease, token)
183}
184
185pub fn apply_heartbeat_token_update(
188 token: &crate::storage::TokenRef,
189 hb: &crate::dms::types::HeartbeatResponse,
190) {
191 if let Some(new) = hb.access_token.clone() {
192 token.swap(new);
193 }
194}
195
196pub fn merge_heartbeat_into_lease(
198 lease: &mut LeaseEnvelope,
199 hb: &crate::dms::types::HeartbeatResponse,
200) {
201 if let Some(token) = hb.access_token.clone() {
202 lease.access_token = Some(token);
203 }
204 if let Some(expiry) = hb.access_token_expires_at {
205 lease.access_token_expires_at = Some(expiry);
206 }
207 if let Some(expiry) = hb.lease_expires_at {
208 lease.lease_expires_at = Some(expiry);
209 }
210 if let Some(cancel) = hb.cancel {
211 lease.cancel = cancel;
212 }
213 if let Some(status) = hb.status.clone() {
214 lease.status = Some(status);
215 }
216 if let Some(domain_id) = hb.domain_id {
217 lease.domain_id = Some(domain_id);
218 }
219 if let Some(url) = hb.domain_server_url.clone() {
220 lease.domain_server_url = Some(url);
221 }
222 if let Some(task) = hb.task.clone() {
223 lease.task = task;
224 } else {
225 if let Some(task_id) = hb.task_id {
226 lease.task.id = task_id;
227 }
228 if let Some(job_id) = hb.job_id {
229 lease.task.job_id = Some(job_id);
230 }
231 if let Some(attempts) = hb.attempts {
232 lease.task.attempts = Some(attempts);
233 }
234 if let Some(max_attempts) = hb.max_attempts {
235 lease.task.max_attempts = Some(max_attempts);
236 }
237 if let Some(deps_remaining) = hb.deps_remaining {
238 lease.task.deps_remaining = Some(deps_remaining);
239 }
240 }
241}
242
/// Executes one full lease lifecycle against the DMS: acquires a lease for
/// the first registered capability, runs the matching runner while a
/// background heartbeat loop keeps the lease alive, then reports
/// completion or failure to the DMS.
///
/// Returns `Ok(true)` when a lease was handled (including cancelled or
/// lost leases) and `Ok(false)` when no lease was available, so the caller
/// can decide whether to back off before polling again.
pub async fn run_cycle_with_dms(
    _cfg: &crate::config::NodeConfig,
    dms: &DmsClient,
    reg: &RunnerRegistry,
) -> Result<bool> {
    use crate::dms::types::{CompleteTaskRequest, FailTaskRequest, HeartbeatRequest};
    use serde_json::json;

    // NOTE(review): only the first capability (sorted order, so the
    // alphabetically smallest) is polled for a lease; other registered
    // capabilities are never polled here — confirm intended.
    let capabilities = reg.capabilities();
    let capability = capabilities
        .first()
        .cloned()
        .ok_or_else(|| anyhow!("no runners registered"))?;

    let mut lease = match dms.lease_by_capability(&capability).await? {
        Some(lease) => lease,
        None => {
            return Ok(false);
        }
    };
    if lease.access_token.is_none() {
        tracing::warn!(
            "Lease missing access token; storage client will fall back to legacy token flow"
        );
    }

    // Initialise the heartbeat session; a lease already marked cancelled
    // is dropped without executing anything.
    let selector = CapabilitySelector::new(capabilities.clone());
    let session = SessionManager::new(selector);
    let policy = HeartbeatPolicy::default_policy();
    let mut rng = StdRng::from_entropy();
    let snapshot = session
        .start_session(&lease, Instant::now(), &policy, &mut rng)
        .await
        .map_err(|err| anyhow!("failed to initialise session: {err}"))?;
    if snapshot.cancel() {
        warn!(
            task_id = %snapshot.task_id(),
            "Lease already marked as cancelled; skipping execution"
        );
        return Ok(true);
    }

    // Shared, swappable access token; heartbeat responses may rotate it
    // while the runner executes.
    let token_ref = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());

    // Immediate first heartbeat: confirms the lease and picks up any
    // token/task updates before the storage ports are built.
    let heartbeat_initial = dms
        .heartbeat(
            lease.task.id,
            &HeartbeatRequest {
                progress: json!({}),
                events: Vec::new(),
            },
        )
        .await?;
    apply_heartbeat_token_update(&token_ref, &heartbeat_initial);
    merge_heartbeat_into_lease(&mut lease, &heartbeat_initial);
    session
        .apply_heartbeat(
            &heartbeat_initial,
            Some(json!({})),
            Instant::now(),
            &policy,
            &mut rng,
        )
        .await
        .map_err(|err| anyhow!("failed to refresh session after heartbeat: {err}"))?;

    let ports = crate::storage::build_ports(&lease, token_ref.clone())?;

    // Channel + shared state wiring the runner's control plane to the
    // background heartbeat driver.
    let (progress_tx, progress_rx) = progress_channel();
    let control_state = Arc::new(Mutex::new(ControlState::default()));
    {
        let mut guard = control_state.lock().await;
        guard.progress = json!({});
        guard.events = Vec::new();
    }

    let runner_cancel = CancellationToken::new();
    let heartbeat_shutdown = CancellationToken::new();

    let ctrl = EngineControlPlane::new(
        runner_cancel.clone(),
        progress_tx.clone(),
        control_state.clone(),
    );

    // Prime the progress channel so the driver has an initial snapshot.
    progress_tx.update(json!({}), Vec::new());

    let heartbeat_driver = HeartbeatDriver::new(
        dms.clone(),
        HeartbeatDriverArgs {
            session: session.clone(),
            policy,
            rng,
            progress_rx,
            state: control_state.clone(),
            token_ref: token_ref.clone(),
            runner_cancel: runner_cancel.clone(),
            shutdown: heartbeat_shutdown.clone(),
            task_id: lease.task.id,
        },
    );
    let heartbeat_handle = tokio::spawn(async move { heartbeat_driver.run().await });

    // Execute the task; heartbeats continue concurrently in the spawned task.
    let run_res = reg
        .run_for_lease(&lease, &*ports.input, &*ports.output, &ctrl, &token_ref)
        .await;

    // Push a final progress/events snapshot, then give the heartbeat loop
    // a short window to deliver it before stopping the loop.
    // NOTE(review): the 200ms sleep is a timing-based flush; the final
    // update could still be dropped under load — confirm acceptable.
    {
        let state = control_state.lock().await;
        progress_tx.update(state.progress.clone(), state.events.clone());
    }
    sleep(StdDuration::from_millis(200)).await;

    heartbeat_shutdown.cancel();
    let heartbeat_result = match heartbeat_handle.await {
        Ok(result) => result,
        Err(err) => {
            // A panicked/aborted heartbeat task is downgraded to a normal
            // completion; the runner outcome below is still reported.
            warn!(error = %err, "heartbeat loop task failed");
            HeartbeatLoopResult::Completed
        }
    };

    match heartbeat_result {
        HeartbeatLoopResult::Completed => {}
        // Cancelled/lost leases are abandoned without a complete/fail call.
        HeartbeatLoopResult::Cancelled => {
            info!(
                task_id = %lease.task.id,
                "Lease cancelled during execution; skipping completion"
            );
            runner_cancel.cancel();
            return Ok(true);
        }
        HeartbeatLoopResult::LostLease(err) => {
            warn!(
                task_id = %lease.task.id,
                error = %err,
                "Lease lost during heartbeat; abandoning task"
            );
            runner_cancel.cancel();
            return Ok(true);
        }
    }

    // Summarise uploaded artifacts for the completion/failure report.
    let uploaded_artifacts = ports.uploaded_artifacts();
    let artifacts_json: Vec<Value> = uploaded_artifacts
        .iter()
        .map(|artifact| {
            json!({
                "logical_path": artifact.logical_path,
                "name": artifact.name,
                "data_type": artifact.data_type,
                "id": artifact.id,
            })
        })
        .collect();
    let output_cids: Vec<String> = uploaded_artifacts
        .iter()
        .filter_map(|artifact| artifact.id.clone())
        .collect();
    let job_info = json!({
        "task_id": lease.task.id,
        "job_id": lease.task.job_id,
        "domain_id": lease.domain_id,
        "capability": lease.task.capability,
    });

    match run_res {
        Ok(()) => {
            let body = CompleteTaskRequest {
                output_cids,
                meta: json!({
                    "job": job_info,
                    "artifacts": artifacts_json,
                }),
            };
            dms.complete(lease.task.id, &body).await?;
        }
        Err(err) => {
            error!(
                task_id = %lease.task.id,
                job_id = ?lease.task.job_id,
                capability = %lease.task.capability,
                error = %err,
                debug = ?err,
                "Runner execution failed; reporting failure to DMS"
            );
            let body = FailTaskRequest {
                reason: err.to_string(),
                details: json!({
                    "job": job_info,
                    "artifacts": artifacts_json,
                }),
            };
            dms.fail(lease.task.id, &body)
                .await
                .with_context(|| format!("report fail for task {} to DMS", lease.task.id))?;
        }
    }

    Ok(true)
}
453
/// Latest progress payload and pending event log, shared between the
/// control plane (writer) and the heartbeat driver (reader/drainer).
#[derive(Default)]
pub struct ControlState {
    // Most recent progress value reported by the runner.
    progress: Value,
    // Events not yet acknowledged by a successful heartbeat; drained by
    // the heartbeat driver after delivery.
    events: Vec<Value>,
}
459
/// `ControlPlane` implementation backed by the engine's cancellation
/// token, progress channel, and shared control state.
struct EngineControlPlane {
    // Cancelled when the lease is cancelled or lost; polled by runners.
    cancel: CancellationToken,
    // Pushes progress/event snapshots to the heartbeat driver.
    progress_tx: ProgressSender,
    // Shared with the heartbeat driver for TTL-triggered snapshots.
    state: Arc<Mutex<ControlState>>,
}
465
466impl EngineControlPlane {
467 pub fn new(
468 cancel: CancellationToken,
469 progress_tx: ProgressSender,
470 state: Arc<Mutex<ControlState>>,
471 ) -> Self {
472 Self {
473 cancel,
474 progress_tx,
475 state,
476 }
477 }
478}
479
480#[async_trait]
481impl ControlPlane for EngineControlPlane {
482 async fn is_cancelled(&self) -> bool {
483 self.cancel.is_cancelled()
484 }
485
486 async fn progress(&self, value: Value) -> Result<()> {
487 let events = {
488 let mut state = self.state.lock().await;
489 state.progress = value.clone();
490 state.events.clone()
491 };
492 self.progress_tx.update(value, events);
493 Ok(())
494 }
495
496 async fn log_event(&self, fields: Value) -> Result<()> {
497 let (progress, events) = {
498 let mut state = self.state.lock().await;
499 state.events.push(fields.clone());
500 (state.progress.clone(), state.events.clone())
501 };
502 self.progress_tx.update(progress, events);
503 Ok(())
504 }
505}
506
/// Terminal outcome of the heartbeat loop.
pub enum HeartbeatLoopResult {
    /// The loop ended normally: shutdown requested, runner cancelled the
    /// token, the session had no snapshot, or the progress channel closed.
    Completed,
    /// A heartbeat response reported the lease as cancelled.
    Cancelled,
    /// A heartbeat post failed or the session rejected the update; the
    /// lease should be abandoned.
    LostLease(anyhow::Error),
}
512
/// Abstraction over the heartbeat endpoint, decoupling `HeartbeatDriver`
/// from the concrete `DmsClient`.
#[async_trait]
pub trait HeartbeatTransport: Send + Sync + Clone + 'static {
    /// Posts one heartbeat for `task_id` and returns the DMS response.
    async fn post_heartbeat(
        &self,
        task_id: Uuid,
        body: &crate::dms::types::HeartbeatRequest,
    ) -> Result<crate::dms::types::HeartbeatResponse>;
}
521
#[async_trait]
impl HeartbeatTransport for DmsClient {
    /// Delegates directly to `DmsClient::heartbeat`.
    async fn post_heartbeat(
        &self,
        task_id: Uuid,
        body: &crate::dms::types::HeartbeatRequest,
    ) -> Result<crate::dms::types::HeartbeatResponse> {
        self.heartbeat(task_id, body).await
    }
}
532
/// Constructor arguments for [`HeartbeatDriver`], bundled so
/// `HeartbeatDriver::new` takes only two parameters.
pub struct HeartbeatDriverArgs {
    /// Session tracker refreshed after every heartbeat response.
    pub session: SessionManager,
    /// Heartbeat scheduling policy passed through to the session.
    pub policy: HeartbeatPolicy,
    /// RNG handed to the session when (re)scheduling heartbeats.
    pub rng: StdRng,
    /// Receives progress snapshots pushed by the runner's control plane.
    pub progress_rx: ProgressReceiver,
    /// Shared progress/event state, also written by `EngineControlPlane`.
    pub state: Arc<Mutex<ControlState>>,
    /// Swappable access token rotated by heartbeat responses.
    pub token_ref: crate::storage::TokenRef,
    /// Cancelled by the driver to stop the runner on lease loss/cancel.
    pub runner_cancel: CancellationToken,
    /// Cancelled by the owner to stop the heartbeat loop.
    pub shutdown: CancellationToken,
    /// Task id to heartbeat against; may be updated from responses.
    pub task_id: Uuid,
}
544
/// Background loop that keeps a lease alive by posting heartbeats on
/// runner progress updates and on the session's TTL schedule.
pub struct HeartbeatDriver<T>
where
    T: HeartbeatTransport,
{
    // Endpoint used to post heartbeats (the real `DmsClient` in production).
    transport: T,
    // Session tracker refreshed after each heartbeat response.
    session: SessionManager,
    // Scheduling policy passed to the session on every refresh.
    policy: HeartbeatPolicy,
    // RNG handed to the session when (re)scheduling heartbeats.
    rng: StdRng,
    // Source of progress snapshots from the runner's control plane.
    progress_rx: ProgressReceiver,
    // Shared progress/event state; events are drained after delivery.
    state: Arc<Mutex<ControlState>>,
    // Swappable access token rotated by heartbeat responses.
    token_ref: crate::storage::TokenRef,
    // Cancelled by this driver when the lease is lost or cancelled.
    runner_cancel: CancellationToken,
    // Cancelled externally to stop the loop.
    shutdown: CancellationToken,
    // Task id to heartbeat against; may change via responses.
    task_id: Uuid,
    // NOTE(review): written in `handle_progress` but never read in this
    // impl — possibly dead state; confirm before relying on it.
    last_progress: Value,
}
561
impl<T> HeartbeatDriver<T>
where
    T: HeartbeatTransport,
{
    /// Builds a driver from its transport plus the bundled arguments.
    pub fn new(transport: T, args: HeartbeatDriverArgs) -> Self {
        Self {
            transport,
            session: args.session,
            policy: args.policy,
            rng: args.rng,
            progress_rx: args.progress_rx,
            state: args.state,
            token_ref: args.token_ref,
            runner_cancel: args.runner_cancel,
            shutdown: args.shutdown,
            task_id: args.task_id,
            // `Value::default()` is JSON null; overwritten on first progress.
            last_progress: Value::default(),
        }
    }

    /// Heartbeat loop: posts a heartbeat whenever the runner reports
    /// progress or the session's TTL deadline elapses, until shutdown,
    /// cancellation, or lease loss ends the loop.
    pub async fn run(mut self) -> HeartbeatLoopResult {
        loop {
            if self.shutdown.is_cancelled() || self.runner_cancel.is_cancelled() {
                return HeartbeatLoopResult::Completed;
            }

            // No session snapshot means there is nothing left to keep alive.
            let snapshot = match self.session.snapshot().await {
                Some(s) => s,
                None => return HeartbeatLoopResult::Completed,
            };

            // Time remaining until the next TTL-mandated heartbeat, if the
            // session has one scheduled.
            let ttl_delay = snapshot
                .next_heartbeat_due()
                .map(|due| due.saturating_duration_since(Instant::now()));

            if let Some(delay) = ttl_delay {
                tokio::select! {
                    _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
                    progress = self.progress_rx.recv() => {
                        if let Some(data) = progress {
                            if let Some(outcome) = self.handle_progress(data).await {
                                return outcome;
                            }
                        } else {
                            // Channel closed: the runner side is gone.
                            return HeartbeatLoopResult::Completed;
                        }
                    }
                    _ = tokio::time::sleep(delay) => {
                        if let Some(outcome) = self.handle_ttl().await {
                            return outcome;
                        }
                    }
                }
            } else {
                // Same as above minus the TTL timer arm: only progress
                // updates (or shutdown) can trigger a heartbeat here.
                tokio::select! {
                    _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
                    progress = self.progress_rx.recv() => {
                        if let Some(data) = progress {
                            if let Some(outcome) = self.handle_progress(data).await {
                                return outcome;
                            }
                        } else {
                            return HeartbeatLoopResult::Completed;
                        }
                    }
                }
            }
        }
    }

    /// Handles a progress update pushed by the runner by sending a
    /// heartbeat with the current shared state.
    // NOTE(review): `last_progress` is recorded but never read in this
    // impl; the heartbeat body is built from `snapshot_state` instead.
    async fn handle_progress(
        &mut self,
        data: crate::heartbeat::HeartbeatData,
    ) -> Option<HeartbeatLoopResult> {
        self.last_progress = data.progress.clone();
        let (progress, events) = self.snapshot_state().await;
        self.send_and_update(progress, events).await
    }

    /// Handles a TTL expiry by sending a keep-alive heartbeat with the
    /// current shared state.
    async fn handle_ttl(&mut self) -> Option<HeartbeatLoopResult> {
        let (progress, events) = self.snapshot_state().await;
        self.send_and_update(progress, events).await
    }

    /// Copies the current progress value and pending events out of the
    /// shared control state.
    async fn snapshot_state(&self) -> (Value, Vec<Value>) {
        let state = self.state.lock().await;
        (state.progress.clone(), state.events.clone())
    }

    /// Posts one heartbeat and applies its response: drains delivered
    /// events, rotates the access token, tracks task-id changes, refreshes
    /// the session, and reacts to cancellation.
    ///
    /// Returns `Some(outcome)` when the loop must stop, `None` to continue.
    async fn send_and_update(
        &mut self,
        progress: Value,
        events: Vec<Value>,
    ) -> Option<HeartbeatLoopResult> {
        let request = crate::dms::types::HeartbeatRequest {
            progress: progress.clone(),
            events: events.clone(),
        };

        match self.transport.post_heartbeat(self.task_id, &request).await {
            Ok(update) => {
                // Drop the events just delivered, but only if they are still
                // an untouched prefix of the shared queue (new events may
                // have been appended concurrently while the post was in
                // flight).
                if !events.is_empty() {
                    let mut state = self.state.lock().await;
                    if state.events.len() >= events.len()
                        && state.events[..events.len()] == events[..]
                    {
                        state.events.drain(0..events.len());
                    }
                }
                apply_heartbeat_token_update(&self.token_ref, &update);
                // The response may hand over a different task; track its id
                // so subsequent heartbeats target the right task. A full
                // `task` payload takes precedence over a bare `task_id`.
                if let Some(task) = &update.task {
                    self.task_id = task.id;
                } else if let Some(task_id) = update.task_id {
                    self.task_id = task_id;
                }
                if let Err(err) = self
                    .session
                    .apply_heartbeat(
                        &update,
                        Some(progress.clone()),
                        Instant::now(),
                        &self.policy,
                        &mut self.rng,
                    )
                    .await
                {
                    return Some(HeartbeatLoopResult::LostLease(anyhow::Error::new(err)));
                }
                if update.cancel.unwrap_or(false) {
                    self.runner_cancel.cancel();
                    return Some(HeartbeatLoopResult::Cancelled);
                }
                None
            }
            // NOTE(review): any single transport failure is treated as a
            // lost lease (no retry) and cancels the runner — confirm this
            // is the intended policy for transient network errors.
            Err(err) => {
                self.runner_cancel.cancel();
                Some(HeartbeatLoopResult::LostLease(err))
            }
        }
    }
}