1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use compute_runner_api::{ArtifactSink, ControlPlane, InputSource, LeaseEnvelope, Runner, TaskCtx};
4use rand::rngs::StdRng;
5use rand::SeedableRng;
6use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration as StdDuration, Instant};
10use tokio::sync::Mutex;
11use tokio::time::sleep;
12use tokio_util::sync::CancellationToken;
13use tracing::{error, info, warn};
14use uuid::Uuid;
15
16use crate::{
17 dms::client::DmsClient,
18 heartbeat::{progress_channel, ProgressReceiver, ProgressSender},
19 poller::{jittered_delay_ms, PollerConfig},
20 session::{CapabilitySelector, HeartbeatPolicy, SessionManager},
21};
22
23#[derive(Default)]
25pub struct RunnerRegistry {
26 runners: HashMap<String, Arc<dyn Runner>>,
27}
28
29impl RunnerRegistry {
30 pub fn new() -> Self {
32 Self {
33 runners: HashMap::new(),
34 }
35 }
36
37 pub fn register<R: Runner + 'static>(mut self, runner: R) -> Self {
39 let key = runner.capability().to_string();
40 self.runners.insert(key, Arc::new(runner));
41 self
42 }
43
44 pub fn get(&self, capability: &str) -> Option<Arc<dyn Runner>> {
46 self.runners.get(capability).cloned()
47 }
48
49 pub fn capabilities(&self) -> Vec<String> {
51 let mut caps: Vec<_> = self.runners.keys().cloned().collect();
52 caps.sort();
53 caps
54 }
55
56 pub async fn run_for_lease(
58 &self,
59 lease: &LeaseEnvelope,
60 input: &dyn InputSource,
61 output: &dyn ArtifactSink,
62 ctrl: &dyn ControlPlane,
63 ) -> std::result::Result<(), crate::errors::ExecutorError> {
64 let cap = lease.task.capability.as_str();
65 let runner = self
66 .get(cap)
67 .ok_or_else(|| crate::errors::ExecutorError::NoRunner(cap.to_string()))?;
68 let ctx = TaskCtx {
69 lease,
70 input,
71 output,
72 ctrl,
73 };
74 runner
75 .run(ctx)
76 .await
77 .map_err(|e| crate::errors::ExecutorError::Runner(e.to_string()))
78 }
79}
80
81pub async fn run_node(cfg: crate::config::NodeConfig, runners: RunnerRegistry) -> Result<()> {
83 let shutdown = CancellationToken::new();
84 let signal_token = shutdown.clone();
85 let signal_task = tokio::spawn(async move {
86 if tokio::signal::ctrl_c().await.is_ok() {
87 signal_token.cancel();
88 }
89 });
90
91 let result = run_node_with_shutdown(cfg, runners, shutdown.clone()).await;
92
93 shutdown.cancel();
94 let _ = signal_task.await;
95
96 result
97}
98
99pub async fn run_node_with_shutdown(
100 cfg: crate::config::NodeConfig,
101 runners: RunnerRegistry,
102 shutdown: CancellationToken,
103) -> Result<()> {
104 let siwe = crate::auth::SiweAfterRegistration::from_config(&cfg)?;
105 info!("DDS SIWE authentication configured; waiting for DDS registration callback");
106 let siwe_handle = siwe.start().await?;
107 info!("DDS SIWE token manager started");
108
109 let poll_cfg = PollerConfig {
110 backoff_ms_min: cfg.poll_backoff_ms_min,
111 backoff_ms_max: cfg.poll_backoff_ms_max,
112 };
113
114 loop {
115 if shutdown.is_cancelled() {
116 break;
117 }
118
119 if let Err(err) = siwe_handle.bearer().await {
121 warn!(error = %err, "Failed to obtain SIWE bearer token; backing off");
122 let delay_ms = jittered_delay_ms(poll_cfg);
123 tokio::select! {
124 _ = shutdown.cancelled() => break,
125 _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
126 }
127 }
128
129 let timeout = StdDuration::from_secs(cfg.request_timeout_secs);
130 let dms_client = match crate::dms::client::DmsClient::new(
131 cfg.dms_base_url.clone(),
132 timeout,
133 std::sync::Arc::new(siwe_handle.clone()),
134 ) {
135 Ok(client) => client,
136 Err(err) => {
137 warn!(error = %err, "Failed to create DMS client; backing off");
138 let delay_ms = jittered_delay_ms(poll_cfg);
139 tokio::select! {
140 _ = shutdown.cancelled() => break,
141 _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
142 }
143 }
144 };
145
146 match run_cycle_with_dms(&cfg, &dms_client, &runners).await {
147 Ok(true) => {
148 continue;
150 }
151 Ok(false) => {
152 let delay_ms = jittered_delay_ms(poll_cfg);
153 info!(delay_ms, "No lease available; backing off before next poll");
154 tokio::select! {
155 _ = shutdown.cancelled() => break,
156 _ = sleep(StdDuration::from_millis(delay_ms)) => {}
157 }
158 }
159 Err(err) => {
160 warn!(error = %err, "DMS cycle failed; backing off");
161 let delay_ms = jittered_delay_ms(poll_cfg);
162 tokio::select! {
163 _ = shutdown.cancelled() => break,
164 _ = sleep(StdDuration::from_millis(delay_ms)) => {}
165 }
166 }
167 }
168 }
169
170 siwe_handle.shutdown().await;
171 info!("Shutdown signal received; exiting run_node loop");
172
173 Ok(())
174}
175
176pub fn build_storage_for_lease(lease: &LeaseEnvelope) -> Result<crate::storage::Ports> {
179 let token = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
180 crate::storage::build_ports(lease, token)
181}
182
183pub fn apply_heartbeat_token_update(
186 token: &crate::storage::TokenRef,
187 hb: &crate::dms::types::HeartbeatResponse,
188) {
189 if let Some(new) = hb.access_token.clone() {
190 token.swap(new);
191 }
192}
193
194pub fn merge_heartbeat_into_lease(
196 lease: &mut LeaseEnvelope,
197 hb: &crate::dms::types::HeartbeatResponse,
198) {
199 if let Some(token) = hb.access_token.clone() {
200 lease.access_token = Some(token);
201 }
202 if let Some(expiry) = hb.access_token_expires_at {
203 lease.access_token_expires_at = Some(expiry);
204 }
205 if let Some(expiry) = hb.lease_expires_at {
206 lease.lease_expires_at = Some(expiry);
207 }
208 if let Some(cancel) = hb.cancel {
209 lease.cancel = cancel;
210 }
211 if let Some(status) = hb.status.clone() {
212 lease.status = Some(status);
213 }
214 if let Some(domain_id) = hb.domain_id {
215 lease.domain_id = Some(domain_id);
216 }
217 if let Some(url) = hb.domain_server_url.clone() {
218 lease.domain_server_url = Some(url);
219 }
220 if let Some(task) = hb.task.clone() {
221 lease.task = task;
222 } else {
223 if let Some(task_id) = hb.task_id {
224 lease.task.id = task_id;
225 }
226 if let Some(job_id) = hb.job_id {
227 lease.task.job_id = Some(job_id);
228 }
229 if let Some(attempts) = hb.attempts {
230 lease.task.attempts = Some(attempts);
231 }
232 if let Some(max_attempts) = hb.max_attempts {
233 lease.task.max_attempts = Some(max_attempts);
234 }
235 if let Some(deps_remaining) = hb.deps_remaining {
236 lease.task.deps_remaining = Some(deps_remaining);
237 }
238 }
239}
240
241pub async fn run_cycle_with_dms(
244 _cfg: &crate::config::NodeConfig,
245 dms: &DmsClient,
246 reg: &RunnerRegistry,
247) -> Result<bool> {
248 use crate::dms::types::{CompleteTaskRequest, FailTaskRequest, HeartbeatRequest};
249 use serde_json::json;
250
251 let capabilities = reg.capabilities();
252 let capability = capabilities
253 .first()
254 .cloned()
255 .ok_or_else(|| anyhow!("no runners registered"))?;
256
257 let mut lease = match dms.lease_by_capability(&capability).await? {
259 Some(lease) => lease,
260 None => {
261 return Ok(false);
262 }
263 };
264 if lease.access_token.is_none() {
265 tracing::warn!(
266 "Lease missing access token; storage client will fall back to legacy token flow"
267 );
268 }
269
270 let selector = CapabilitySelector::new(capabilities.clone());
272 let session = SessionManager::new(selector);
273 let policy = HeartbeatPolicy::default_policy();
274 let mut rng = StdRng::from_entropy();
275 let snapshot = session
276 .start_session(&lease, Instant::now(), &policy, &mut rng)
277 .await
278 .map_err(|err| anyhow!("failed to initialise session: {err}"))?;
279 if snapshot.cancel() {
280 warn!(
281 task_id = %snapshot.task_id(),
282 "Lease already marked as cancelled; skipping execution"
283 );
284 return Ok(true);
285 }
286
287 let token_ref = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
288
289 let heartbeat_initial = dms
290 .heartbeat(
291 lease.task.id,
292 &HeartbeatRequest {
293 progress: json!({}),
294 events: json!({}),
295 },
296 )
297 .await?;
298 apply_heartbeat_token_update(&token_ref, &heartbeat_initial);
299 merge_heartbeat_into_lease(&mut lease, &heartbeat_initial);
300 session
301 .apply_heartbeat(
302 &heartbeat_initial,
303 Some(json!({})),
304 Instant::now(),
305 &policy,
306 &mut rng,
307 )
308 .await
309 .map_err(|err| anyhow!("failed to refresh session after heartbeat: {err}"))?;
310
311 let ports = crate::storage::build_ports(&lease, token_ref.clone())?;
312
313 let (progress_tx, progress_rx) = progress_channel();
314 let control_state = Arc::new(Mutex::new(ControlState::default()));
315 {
316 let mut guard = control_state.lock().await;
317 guard.progress = json!({});
318 guard.events = json!({});
319 }
320
321 let runner_cancel = CancellationToken::new();
322 let heartbeat_shutdown = CancellationToken::new();
323
324 let ctrl = EngineControlPlane::new(
325 runner_cancel.clone(),
326 progress_tx.clone(),
327 control_state.clone(),
328 );
329
330 progress_tx.update(json!({}), json!({}));
332
333 let heartbeat_driver = HeartbeatDriver::new(
334 dms.clone(),
335 HeartbeatDriverArgs {
336 session: session.clone(),
337 policy,
338 rng,
339 progress_rx,
340 state: control_state.clone(),
341 token_ref: token_ref.clone(),
342 runner_cancel: runner_cancel.clone(),
343 shutdown: heartbeat_shutdown.clone(),
344 task_id: lease.task.id,
345 },
346 );
347 let heartbeat_handle = tokio::spawn(async move { heartbeat_driver.run().await });
348
349 let run_res = reg
350 .run_for_lease(&lease, &*ports.input, &*ports.output, &ctrl)
351 .await;
352
353 heartbeat_shutdown.cancel();
354 let heartbeat_result = match heartbeat_handle.await {
355 Ok(result) => result,
356 Err(err) => {
357 warn!(error = %err, "heartbeat loop task failed");
358 HeartbeatLoopResult::Completed
359 }
360 };
361
362 match heartbeat_result {
363 HeartbeatLoopResult::Completed => {}
364 HeartbeatLoopResult::Cancelled => {
365 info!(
366 task_id = %lease.task.id,
367 "Lease cancelled during execution; skipping completion"
368 );
369 runner_cancel.cancel();
370 return Ok(true);
371 }
372 HeartbeatLoopResult::LostLease(err) => {
373 warn!(
374 task_id = %lease.task.id,
375 error = %err,
376 "Lease lost during heartbeat; abandoning task"
377 );
378 runner_cancel.cancel();
379 return Ok(true);
380 }
381 }
382
383 let uploaded_artifacts = ports.uploaded_artifacts();
384 let artifacts_json: Vec<Value> = uploaded_artifacts
385 .iter()
386 .map(|artifact| {
387 json!({
388 "logical_path": artifact.logical_path,
389 "name": artifact.name,
390 "data_type": artifact.data_type,
391 "id": artifact.id,
392 })
393 })
394 .collect();
395 let job_info = json!({
396 "task_id": lease.task.id,
397 "job_id": lease.task.job_id,
398 "domain_id": lease.domain_id,
399 "capability": lease.task.capability,
400 });
401
402 match run_res {
404 Ok(()) => {
405 let body = CompleteTaskRequest {
406 outputs_index: json!({ "artifacts": artifacts_json.clone() }),
407 result: json!({
408 "job": job_info,
409 "artifacts": artifacts_json,
410 }),
411 };
412 dms.complete(lease.task.id, &body).await?;
413 }
414 Err(err) => {
415 error!(
416 task_id = %lease.task.id,
417 job_id = ?lease.task.job_id,
418 capability = %lease.task.capability,
419 error = %err,
420 debug = ?err,
421 "Runner execution failed; reporting failure to DMS"
422 );
423 let body = FailTaskRequest {
424 reason: err.to_string(),
425 details: json!({
426 "job": job_info,
427 "artifacts": artifacts_json,
428 }),
429 };
430 dms.fail(lease.task.id, &body)
431 .await
432 .with_context(|| format!("report fail for task {} to DMS", lease.task.id))?;
433 }
434 }
435
436 Ok(true)
437}
438
439#[derive(Default)]
440pub struct ControlState {
441 progress: Value,
442 events: Value,
443}
444
445struct EngineControlPlane {
446 cancel: CancellationToken,
447 progress_tx: ProgressSender,
448 state: Arc<Mutex<ControlState>>,
449}
450
451impl EngineControlPlane {
452 pub fn new(
453 cancel: CancellationToken,
454 progress_tx: ProgressSender,
455 state: Arc<Mutex<ControlState>>,
456 ) -> Self {
457 Self {
458 cancel,
459 progress_tx,
460 state,
461 }
462 }
463}
464
465#[async_trait]
466impl ControlPlane for EngineControlPlane {
467 async fn is_cancelled(&self) -> bool {
468 self.cancel.is_cancelled()
469 }
470
471 async fn progress(&self, value: Value) -> Result<()> {
472 let events = {
473 let mut state = self.state.lock().await;
474 state.progress = value.clone();
475 state.events.clone()
476 };
477 self.progress_tx.update(value, events);
478 Ok(())
479 }
480
481 async fn log_event(&self, fields: Value) -> Result<()> {
482 let progress = {
483 let mut state = self.state.lock().await;
484 state.events = fields.clone();
485 state.progress.clone()
486 };
487 self.progress_tx.update(progress, fields);
488 Ok(())
489 }
490}
491
492pub enum HeartbeatLoopResult {
493 Completed,
494 Cancelled,
495 LostLease(anyhow::Error),
496}
497
498#[async_trait]
499pub trait HeartbeatTransport: Send + Sync + Clone + 'static {
500 async fn post_heartbeat(
501 &self,
502 task_id: Uuid,
503 body: &crate::dms::types::HeartbeatRequest,
504 ) -> Result<crate::dms::types::HeartbeatResponse>;
505}
506
507#[async_trait]
508impl HeartbeatTransport for DmsClient {
509 async fn post_heartbeat(
510 &self,
511 task_id: Uuid,
512 body: &crate::dms::types::HeartbeatRequest,
513 ) -> Result<crate::dms::types::HeartbeatResponse> {
514 self.heartbeat(task_id, body).await
515 }
516}
517
518pub struct HeartbeatDriverArgs {
519 pub session: SessionManager,
520 pub policy: HeartbeatPolicy,
521 pub rng: StdRng,
522 pub progress_rx: ProgressReceiver,
523 pub state: Arc<Mutex<ControlState>>,
524 pub token_ref: crate::storage::TokenRef,
525 pub runner_cancel: CancellationToken,
526 pub shutdown: CancellationToken,
527 pub task_id: Uuid,
528}
529
530pub struct HeartbeatDriver<T>
531where
532 T: HeartbeatTransport,
533{
534 transport: T,
535 session: SessionManager,
536 policy: HeartbeatPolicy,
537 rng: StdRng,
538 progress_rx: ProgressReceiver,
539 state: Arc<Mutex<ControlState>>,
540 token_ref: crate::storage::TokenRef,
541 runner_cancel: CancellationToken,
542 shutdown: CancellationToken,
543 task_id: Uuid,
544 last_progress: Value,
545}
546
547impl<T> HeartbeatDriver<T>
548where
549 T: HeartbeatTransport,
550{
551 pub fn new(transport: T, args: HeartbeatDriverArgs) -> Self {
552 Self {
553 transport,
554 session: args.session,
555 policy: args.policy,
556 rng: args.rng,
557 progress_rx: args.progress_rx,
558 state: args.state,
559 token_ref: args.token_ref,
560 runner_cancel: args.runner_cancel,
561 shutdown: args.shutdown,
562 task_id: args.task_id,
563 last_progress: Value::default(),
564 }
565 }
566
567 pub async fn run(mut self) -> HeartbeatLoopResult {
568 loop {
569 if self.shutdown.is_cancelled() || self.runner_cancel.is_cancelled() {
570 return HeartbeatLoopResult::Completed;
571 }
572
573 let snapshot = match self.session.snapshot().await {
574 Some(s) => s,
575 None => return HeartbeatLoopResult::Completed,
576 };
577
578 let ttl_delay = snapshot
579 .next_heartbeat_due()
580 .map(|due| due.saturating_duration_since(Instant::now()));
581
582 if let Some(delay) = ttl_delay {
583 tokio::select! {
584 _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
585 progress = self.progress_rx.recv() => {
586 if let Some(data) = progress {
587 if let Some(outcome) = self.handle_progress(data).await {
588 return outcome;
589 }
590 } else {
591 return HeartbeatLoopResult::Completed;
592 }
593 }
594 _ = tokio::time::sleep(delay) => {
595 if let Some(outcome) = self.handle_ttl().await {
596 return outcome;
597 }
598 }
599 }
600 } else {
601 tokio::select! {
602 _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
603 progress = self.progress_rx.recv() => {
604 if let Some(data) = progress {
605 if let Some(outcome) = self.handle_progress(data).await {
606 return outcome;
607 }
608 } else {
609 return HeartbeatLoopResult::Completed;
610 }
611 }
612 }
613 }
614 }
615 }
616
617 async fn handle_progress(
618 &mut self,
619 data: crate::heartbeat::HeartbeatData,
620 ) -> Option<HeartbeatLoopResult> {
621 self.last_progress = data.progress.clone();
622 self.send_and_update(data.progress, data.events).await
623 }
624
625 async fn handle_ttl(&mut self) -> Option<HeartbeatLoopResult> {
626 let (progress, events) = self.snapshot_state().await;
627 self.send_and_update(progress, events).await
628 }
629
630 async fn snapshot_state(&self) -> (Value, Value) {
631 let state = self.state.lock().await;
632 (state.progress.clone(), state.events.clone())
633 }
634
635 async fn send_and_update(
636 &mut self,
637 progress: Value,
638 events: Value,
639 ) -> Option<HeartbeatLoopResult> {
640 let request = crate::dms::types::HeartbeatRequest {
641 progress: progress.clone(),
642 events: events.clone(),
643 };
644
645 match self.transport.post_heartbeat(self.task_id, &request).await {
646 Ok(update) => {
647 apply_heartbeat_token_update(&self.token_ref, &update);
648 if let Some(task) = &update.task {
649 self.task_id = task.id;
650 } else if let Some(task_id) = update.task_id {
651 self.task_id = task_id;
652 }
653 if let Err(err) = self
654 .session
655 .apply_heartbeat(
656 &update,
657 Some(progress.clone()),
658 Instant::now(),
659 &self.policy,
660 &mut self.rng,
661 )
662 .await
663 {
664 return Some(HeartbeatLoopResult::LostLease(anyhow::Error::new(err)));
665 }
666 if update.cancel.unwrap_or(false) {
667 self.runner_cancel.cancel();
668 return Some(HeartbeatLoopResult::Cancelled);
669 }
670 None
671 }
672 Err(err) => {
673 self.runner_cancel.cancel();
674 Some(HeartbeatLoopResult::LostLease(err))
675 }
676 }
677 }
678}