1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use compute_runner_api::{ArtifactSink, ControlPlane, InputSource, LeaseEnvelope, Runner, TaskCtx};
4use rand::rngs::StdRng;
5use rand::SeedableRng;
6use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration as StdDuration, Instant};
10use tokio::sync::Mutex;
11use tokio::time::sleep;
12use tokio_util::sync::CancellationToken;
13use tracing::{error, info, warn};
14use uuid::Uuid;
15
16use crate::{
17 dms::client::DmsClient,
18 heartbeat::{progress_channel, ProgressReceiver, ProgressSender},
19 poller::{jittered_delay_ms, PollerConfig},
20 session::{CapabilitySelector, HeartbeatPolicy, SessionManager},
21};
22
/// Maps capability names to the runner that can execute tasks of that kind.
#[derive(Default)]
pub struct RunnerRegistry {
    // Keyed by `Runner::capability()`; `Arc` so a runner can be held across
    // await points while a lease executes.
    runners: HashMap<String, Arc<dyn Runner>>,
}
28
29impl RunnerRegistry {
30 pub fn new() -> Self {
32 Self {
33 runners: HashMap::new(),
34 }
35 }
36
37 pub fn register<R: Runner + 'static>(mut self, runner: R) -> Self {
39 let key = runner.capability().to_string();
40 self.runners.insert(key, Arc::new(runner));
41 self
42 }
43
44 pub fn get(&self, capability: &str) -> Option<Arc<dyn Runner>> {
46 self.runners.get(capability).cloned()
47 }
48
49 pub fn capabilities(&self) -> Vec<String> {
51 let mut caps: Vec<_> = self.runners.keys().cloned().collect();
52 caps.sort();
53 caps
54 }
55
56 pub async fn run_for_lease(
58 &self,
59 lease: &LeaseEnvelope,
60 input: &dyn InputSource,
61 output: &dyn ArtifactSink,
62 ctrl: &dyn ControlPlane,
63 access_token: &dyn compute_runner_api::runner::AccessTokenProvider,
64 ) -> std::result::Result<(), crate::errors::ExecutorError> {
65 let cap = lease.task.capability.as_str();
66 let runner = self
67 .get(cap)
68 .ok_or_else(|| crate::errors::ExecutorError::NoRunner(cap.to_string()))?;
69 let ctx = TaskCtx {
70 lease,
71 input,
72 output,
73 ctrl,
74 access_token,
75 };
76 runner
77 .run(ctx)
78 .await
79 .map_err(|e| crate::errors::ExecutorError::Runner(e.to_string()))
80 }
81}
82
83pub async fn run_node(cfg: crate::config::NodeConfig, runners: RunnerRegistry) -> Result<()> {
85 let shutdown = CancellationToken::new();
86 let signal_token = shutdown.clone();
87 let signal_task = tokio::spawn(async move {
88 if tokio::signal::ctrl_c().await.is_ok() {
89 signal_token.cancel();
90 }
91 });
92
93 let result = run_node_with_shutdown(cfg, runners, shutdown.clone()).await;
94
95 shutdown.cancel();
96 let _ = signal_task.await;
97
98 result
99}
100
/// Main polling loop: configures SIWE authentication, then repeatedly leases
/// and executes tasks from the DMS until `shutdown` is cancelled.
///
/// Every failure path (missing bearer token, client construction error,
/// failed cycle) backs off with a jittered delay; each backoff is raced
/// against the shutdown token so cancellation is never delayed by a sleep.
pub async fn run_node_with_shutdown(
    cfg: crate::config::NodeConfig,
    runners: RunnerRegistry,
    shutdown: CancellationToken,
) -> Result<()> {
    // SIWE auth must be up before any DMS traffic is attempted.
    let siwe = crate::auth::SiweAfterRegistration::from_config(&cfg)?;
    info!("DDS SIWE authentication configured; waiting for DDS registration callback");
    let siwe_handle = siwe.start().await?;
    info!("DDS SIWE token manager started");

    let poll_cfg = PollerConfig {
        backoff_ms_min: cfg.poll_backoff_ms_min,
        backoff_ms_max: cfg.poll_backoff_ms_max,
    };

    loop {
        if shutdown.is_cancelled() {
            break;
        }

        // Probe for a usable bearer token before constructing a client.
        if let Err(err) = siwe_handle.bearer().await {
            warn!(error = %err, "Failed to obtain SIWE bearer token; backing off");
            let delay_ms = jittered_delay_ms(poll_cfg);
            tokio::select! {
                _ = shutdown.cancelled() => break,
                _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
            }
        }

        // NOTE(review): a fresh DMS client is built every iteration —
        // presumably cheap; confirm `DmsClient::new` performs no network I/O.
        let timeout = StdDuration::from_secs(cfg.request_timeout_secs);
        let dms_client = match crate::dms::client::DmsClient::new(
            cfg.dms_base_url.clone(),
            timeout,
            std::sync::Arc::new(siwe_handle.clone()),
        ) {
            Ok(client) => client,
            Err(err) => {
                warn!(error = %err, "Failed to create DMS client; backing off");
                let delay_ms = jittered_delay_ms(poll_cfg);
                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
                }
            }
        };

        match run_cycle_with_dms(&cfg, &dms_client, &runners).await {
            // A lease was obtained and handled: poll again immediately.
            Ok(true) => {
                continue;
            }
            // No work available: jittered backoff before the next poll.
            Ok(false) => {
                let delay_ms = jittered_delay_ms(poll_cfg);
                info!(delay_ms, "No lease available; backing off before next poll");
                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
                }
            }
            // Cycle errors are non-fatal to the node; log and retry.
            Err(err) => {
                warn!(error = %err, "DMS cycle failed; backing off");
                let delay_ms = jittered_delay_ms(poll_cfg);
                tokio::select! {
                    _ = shutdown.cancelled() => break,
                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
                }
            }
        }
    }

    siwe_handle.shutdown().await;
    info!("Shutdown signal received; exiting run_node loop");

    Ok(())
}
177
178pub fn build_storage_for_lease(lease: &LeaseEnvelope) -> Result<crate::storage::Ports> {
181 let token = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
182 crate::storage::build_ports(lease, token)
183}
184
185pub fn apply_heartbeat_token_update(
188 token: &crate::storage::TokenRef,
189 hb: &crate::dms::types::HeartbeatResponse,
190) {
191 if let Some(new) = hb.access_token.clone() {
192 token.swap(new);
193 }
194}
195
196pub fn merge_heartbeat_into_lease(
198 lease: &mut LeaseEnvelope,
199 hb: &crate::dms::types::HeartbeatResponse,
200) {
201 if let Some(token) = hb.access_token.clone() {
202 lease.access_token = Some(token);
203 }
204 if let Some(expiry) = hb.access_token_expires_at {
205 lease.access_token_expires_at = Some(expiry);
206 }
207 if let Some(expiry) = hb.lease_expires_at {
208 lease.lease_expires_at = Some(expiry);
209 }
210 if let Some(cancel) = hb.cancel {
211 lease.cancel = cancel;
212 }
213 if let Some(status) = hb.status.clone() {
214 lease.status = Some(status);
215 }
216 if let Some(domain_id) = hb.domain_id {
217 lease.domain_id = Some(domain_id);
218 }
219 if let Some(url) = hb.domain_server_url.clone() {
220 lease.domain_server_url = Some(url);
221 }
222 if let Some(task) = hb.task.clone() {
223 lease.task = task;
224 } else {
225 if let Some(task_id) = hb.task_id {
226 lease.task.id = task_id;
227 }
228 if let Some(job_id) = hb.job_id {
229 lease.task.job_id = Some(job_id);
230 }
231 if let Some(attempts) = hb.attempts {
232 lease.task.attempts = Some(attempts);
233 }
234 if let Some(max_attempts) = hb.max_attempts {
235 lease.task.max_attempts = Some(max_attempts);
236 }
237 if let Some(deps_remaining) = hb.deps_remaining {
238 lease.task.deps_remaining = Some(deps_remaining);
239 }
240 }
241}
242
/// Executes one poll → lease → run → report cycle against the DMS.
///
/// Returns `Ok(true)` when a lease was obtained and fully handled (whether
/// the task succeeded, failed, or was cancelled), `Ok(false)` when no lease
/// was available, and `Err` for unrecoverable cycle errors.
pub async fn run_cycle_with_dms(
    _cfg: &crate::config::NodeConfig,
    dms: &DmsClient,
    reg: &RunnerRegistry,
) -> Result<bool> {
    use crate::dms::types::{CompleteTaskRequest, FailTaskRequest, HeartbeatRequest};
    use serde_json::json;

    // NOTE(review): only the first (alphabetically) registered capability is
    // polled per cycle — confirm multi-capability nodes are meant to lease
    // one kind at a time.
    let capabilities = reg.capabilities();
    let capability = capabilities
        .first()
        .cloned()
        .ok_or_else(|| anyhow!("no runners registered"))?;

    let mut lease = match dms.lease_by_capability(&capability).await? {
        Some(lease) => lease,
        None => {
            return Ok(false);
        }
    };
    if lease.access_token.is_none() {
        tracing::warn!(
            "Lease missing access token; storage client will fall back to legacy token flow"
        );
    }

    // Start a heartbeat session for the lease; a snapshot already flagged
    // cancelled means the server revoked the task before we began.
    let selector = CapabilitySelector::new(capabilities.clone());
    let session = SessionManager::new(selector);
    let policy = HeartbeatPolicy::default_policy();
    let mut rng = StdRng::from_entropy();
    let snapshot = session
        .start_session(&lease, Instant::now(), &policy, &mut rng)
        .await
        .map_err(|err| anyhow!("failed to initialise session: {err}"))?;
    if snapshot.cancel() {
        warn!(
            task_id = %snapshot.task_id(),
            "Lease already marked as cancelled; skipping execution"
        );
        return Ok(true);
    }

    // Shared, swappable token so heartbeat refreshes reach storage clients.
    let token_ref = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());

    // Initial empty heartbeat: confirms the lease and may deliver a fresh
    // token / updated task metadata before the runner starts.
    let heartbeat_initial = dms
        .heartbeat(
            lease.task.id,
            &HeartbeatRequest {
                progress: json!({}),
                events: json!({}),
            },
        )
        .await?;
    apply_heartbeat_token_update(&token_ref, &heartbeat_initial);
    merge_heartbeat_into_lease(&mut lease, &heartbeat_initial);
    session
        .apply_heartbeat(
            &heartbeat_initial,
            Some(json!({})),
            Instant::now(),
            &policy,
            &mut rng,
        )
        .await
        .map_err(|err| anyhow!("failed to refresh session after heartbeat: {err}"))?;

    // Build storage ports only after the lease has been refreshed above.
    let ports = crate::storage::build_ports(&lease, token_ref.clone())?;

    // Channel + shared state connecting the runner's control plane to the
    // background heartbeat driver.
    let (progress_tx, progress_rx) = progress_channel();
    let control_state = Arc::new(Mutex::new(ControlState::default()));
    {
        let mut guard = control_state.lock().await;
        guard.progress = json!({});
        guard.events = json!({});
    }

    // `runner_cancel` tells the runner to stop; `heartbeat_shutdown` tells
    // the heartbeat loop to stop once the runner has finished.
    let runner_cancel = CancellationToken::new();
    let heartbeat_shutdown = CancellationToken::new();

    let ctrl = EngineControlPlane::new(
        runner_cancel.clone(),
        progress_tx.clone(),
        control_state.clone(),
    );

    // Prime the channel so the driver has an initial payload.
    progress_tx.update(json!({}), json!({}));

    let heartbeat_driver = HeartbeatDriver::new(
        dms.clone(),
        HeartbeatDriverArgs {
            session: session.clone(),
            policy,
            rng,
            progress_rx,
            state: control_state.clone(),
            token_ref: token_ref.clone(),
            runner_cancel: runner_cancel.clone(),
            shutdown: heartbeat_shutdown.clone(),
            task_id: lease.task.id,
        },
    );
    let heartbeat_handle = tokio::spawn(async move { heartbeat_driver.run().await });

    // Execute the task while heartbeats run in the background.
    let run_res = reg
        .run_for_lease(&lease, &*ports.input, &*ports.output, &ctrl, &token_ref)
        .await;

    // Push the final progress snapshot, then give the heartbeat loop a
    // moment to flush it before shutting it down.
    // NOTE(review): the 200 ms sleep is a best-effort flush window, not a
    // guarantee that the final heartbeat was sent — confirm acceptable.
    {
        let state = control_state.lock().await;
        progress_tx.update(state.progress.clone(), state.events.clone());
    }
    sleep(StdDuration::from_millis(200)).await;

    heartbeat_shutdown.cancel();
    let heartbeat_result = match heartbeat_handle.await {
        Ok(result) => result,
        Err(err) => {
            // Join error (panic/abort in the driver task): treat as a normal
            // completion so the task outcome below is still reported.
            warn!(error = %err, "heartbeat loop task failed");
            HeartbeatLoopResult::Completed
        }
    };

    match heartbeat_result {
        HeartbeatLoopResult::Completed => {}
        // Cancellation / lease loss: abandon the task without reporting
        // completion or failure — the server already knows.
        HeartbeatLoopResult::Cancelled => {
            info!(
                task_id = %lease.task.id,
                "Lease cancelled during execution; skipping completion"
            );
            runner_cancel.cancel();
            return Ok(true);
        }
        HeartbeatLoopResult::LostLease(err) => {
            warn!(
                task_id = %lease.task.id,
                error = %err,
                "Lease lost during heartbeat; abandoning task"
            );
            runner_cancel.cancel();
            return Ok(true);
        }
    }

    // Summarise uploaded artifacts for the completion / failure report.
    let uploaded_artifacts = ports.uploaded_artifacts();
    let artifacts_json: Vec<Value> = uploaded_artifacts
        .iter()
        .map(|artifact| {
            json!({
                "logical_path": artifact.logical_path,
                "name": artifact.name,
                "data_type": artifact.data_type,
                "id": artifact.id,
            })
        })
        .collect();
    let job_info = json!({
        "task_id": lease.task.id,
        "job_id": lease.task.job_id,
        "domain_id": lease.domain_id,
        "capability": lease.task.capability,
    });

    match run_res {
        // Success: report completion with the artifact index.
        Ok(()) => {
            let body = CompleteTaskRequest {
                outputs_index: json!({ "artifacts": artifacts_json.clone() }),
                result: json!({
                    "job": job_info,
                    "artifacts": artifacts_json,
                }),
            };
            dms.complete(lease.task.id, &body).await?;
        }
        // Runner error: report failure, including any artifacts that did
        // get uploaded before the error.
        Err(err) => {
            error!(
                task_id = %lease.task.id,
                job_id = ?lease.task.job_id,
                capability = %lease.task.capability,
                error = %err,
                debug = ?err,
                "Runner execution failed; reporting failure to DMS"
            );
            let body = FailTaskRequest {
                reason: err.to_string(),
                details: json!({
                    "job": job_info,
                    "artifacts": artifacts_json,
                }),
            };
            dms.fail(lease.task.id, &body)
                .await
                .with_context(|| format!("report fail for task {} to DMS", lease.task.id))?;
        }
    }

    Ok(true)
}
449
/// Latest progress/event payloads reported by the runner, shared between the
/// control plane and the heartbeat loop.
#[derive(Default)]
pub struct ControlState {
    // Most recent progress JSON set via `ControlPlane::progress`.
    progress: Value,
    // Most recent event JSON set via `ControlPlane::log_event`.
    events: Value,
}
455
/// `ControlPlane` implementation handed to runners: exposes cancellation and
/// forwards progress/event updates to the heartbeat driver.
struct EngineControlPlane {
    // Cancelled when the task is cancelled server-side or the lease is lost.
    cancel: CancellationToken,
    // Channel feeding the heartbeat driver fresh (progress, events) pairs.
    progress_tx: ProgressSender,
    // Shared snapshot of the latest payloads (also read on TTL heartbeats).
    state: Arc<Mutex<ControlState>>,
}
461
462impl EngineControlPlane {
463 pub fn new(
464 cancel: CancellationToken,
465 progress_tx: ProgressSender,
466 state: Arc<Mutex<ControlState>>,
467 ) -> Self {
468 Self {
469 cancel,
470 progress_tx,
471 state,
472 }
473 }
474}
475
476#[async_trait]
477impl ControlPlane for EngineControlPlane {
478 async fn is_cancelled(&self) -> bool {
479 self.cancel.is_cancelled()
480 }
481
482 async fn progress(&self, value: Value) -> Result<()> {
483 let events = {
484 let mut state = self.state.lock().await;
485 state.progress = value.clone();
486 state.events.clone()
487 };
488 self.progress_tx.update(value, events);
489 Ok(())
490 }
491
492 async fn log_event(&self, fields: Value) -> Result<()> {
493 let progress = {
494 let mut state = self.state.lock().await;
495 state.events = fields.clone();
496 state.progress.clone()
497 };
498 self.progress_tx.update(progress, fields);
499 Ok(())
500 }
501}
502
/// Terminal states of the heartbeat loop.
pub enum HeartbeatLoopResult {
    /// Loop exited normally (shutdown requested, runner done, or the
    /// progress channel closed).
    Completed,
    /// Server requested cancellation of the task.
    Cancelled,
    /// A heartbeat failed or the session rejected the update; the lease is
    /// considered gone.
    LostLease(anyhow::Error),
}
508
/// Abstraction over the heartbeat POST so `HeartbeatDriver` can be driven by
/// a fake transport in tests instead of a live `DmsClient`.
#[async_trait]
pub trait HeartbeatTransport: Send + Sync + Clone + 'static {
    /// Sends one heartbeat for `task_id` and returns the server's response.
    async fn post_heartbeat(
        &self,
        task_id: Uuid,
        body: &crate::dms::types::HeartbeatRequest,
    ) -> Result<crate::dms::types::HeartbeatResponse>;
}
517
/// Production transport: delegates directly to `DmsClient::heartbeat`.
#[async_trait]
impl HeartbeatTransport for DmsClient {
    async fn post_heartbeat(
        &self,
        task_id: Uuid,
        body: &crate::dms::types::HeartbeatRequest,
    ) -> Result<crate::dms::types::HeartbeatResponse> {
        self.heartbeat(task_id, body).await
    }
}
528
/// Constructor bundle for `HeartbeatDriver`, keeping `HeartbeatDriver::new`
/// to two arguments.
pub struct HeartbeatDriverArgs {
    /// Session tracking the lease and heartbeat scheduling state.
    pub session: SessionManager,
    /// Policy controlling heartbeat cadence.
    pub policy: HeartbeatPolicy,
    /// RNG used for heartbeat jitter.
    pub rng: StdRng,
    /// Receives (progress, events) updates from the runner's control plane.
    pub progress_rx: ProgressReceiver,
    /// Shared latest-payload state, read for TTL-driven heartbeats.
    pub state: Arc<Mutex<ControlState>>,
    /// Swappable access token refreshed from heartbeat responses.
    pub token_ref: crate::storage::TokenRef,
    /// Cancelled by the driver to stop the runner (server cancel/lease loss).
    pub runner_cancel: CancellationToken,
    /// Cancelled by the owner to stop the heartbeat loop.
    pub shutdown: CancellationToken,
    /// Task being heartbeated; may be re-keyed by server responses.
    pub task_id: Uuid,
}
540
/// Background loop that keeps a lease alive by posting heartbeats, driven by
/// runner progress updates and the session's TTL deadlines.
pub struct HeartbeatDriver<T>
where
    T: HeartbeatTransport,
{
    // How heartbeats reach the server (real `DmsClient` or a test fake).
    transport: T,
    // Session tracking lease and scheduling state.
    session: SessionManager,
    // Heartbeat cadence policy.
    policy: HeartbeatPolicy,
    // RNG for heartbeat jitter.
    rng: StdRng,
    // Progress/event updates pushed by the runner's control plane.
    progress_rx: ProgressReceiver,
    // Shared latest-payload state, read for TTL-driven heartbeats.
    state: Arc<Mutex<ControlState>>,
    // Swappable access token refreshed from heartbeat responses.
    token_ref: crate::storage::TokenRef,
    // Cancelled by the driver to stop the runner.
    runner_cancel: CancellationToken,
    // Cancelled by the owner to stop this loop.
    shutdown: CancellationToken,
    // Task being heartbeated; may be re-keyed by server responses.
    task_id: Uuid,
    // Last progress payload seen from the runner.
    // NOTE(review): written in `handle_progress` but never read in this
    // file — confirm whether it is still needed.
    last_progress: Value,
}
557
558impl<T> HeartbeatDriver<T>
559where
560 T: HeartbeatTransport,
561{
562 pub fn new(transport: T, args: HeartbeatDriverArgs) -> Self {
563 Self {
564 transport,
565 session: args.session,
566 policy: args.policy,
567 rng: args.rng,
568 progress_rx: args.progress_rx,
569 state: args.state,
570 token_ref: args.token_ref,
571 runner_cancel: args.runner_cancel,
572 shutdown: args.shutdown,
573 task_id: args.task_id,
574 last_progress: Value::default(),
575 }
576 }
577
578 pub async fn run(mut self) -> HeartbeatLoopResult {
579 loop {
580 if self.shutdown.is_cancelled() || self.runner_cancel.is_cancelled() {
581 return HeartbeatLoopResult::Completed;
582 }
583
584 let snapshot = match self.session.snapshot().await {
585 Some(s) => s,
586 None => return HeartbeatLoopResult::Completed,
587 };
588
589 let ttl_delay = snapshot
590 .next_heartbeat_due()
591 .map(|due| due.saturating_duration_since(Instant::now()));
592
593 if let Some(delay) = ttl_delay {
594 tokio::select! {
595 _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
596 progress = self.progress_rx.recv() => {
597 if let Some(data) = progress {
598 if let Some(outcome) = self.handle_progress(data).await {
599 return outcome;
600 }
601 } else {
602 return HeartbeatLoopResult::Completed;
603 }
604 }
605 _ = tokio::time::sleep(delay) => {
606 if let Some(outcome) = self.handle_ttl().await {
607 return outcome;
608 }
609 }
610 }
611 } else {
612 tokio::select! {
613 _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
614 progress = self.progress_rx.recv() => {
615 if let Some(data) = progress {
616 if let Some(outcome) = self.handle_progress(data).await {
617 return outcome;
618 }
619 } else {
620 return HeartbeatLoopResult::Completed;
621 }
622 }
623 }
624 }
625 }
626 }
627
628 async fn handle_progress(
629 &mut self,
630 data: crate::heartbeat::HeartbeatData,
631 ) -> Option<HeartbeatLoopResult> {
632 self.last_progress = data.progress.clone();
633 self.send_and_update(data.progress, data.events).await
634 }
635
636 async fn handle_ttl(&mut self) -> Option<HeartbeatLoopResult> {
637 let (progress, events) = self.snapshot_state().await;
638 self.send_and_update(progress, events).await
639 }
640
641 async fn snapshot_state(&self) -> (Value, Value) {
642 let state = self.state.lock().await;
643 (state.progress.clone(), state.events.clone())
644 }
645
646 async fn send_and_update(
647 &mut self,
648 progress: Value,
649 events: Value,
650 ) -> Option<HeartbeatLoopResult> {
651 let request = crate::dms::types::HeartbeatRequest {
652 progress: progress.clone(),
653 events: events.clone(),
654 };
655
656 match self.transport.post_heartbeat(self.task_id, &request).await {
657 Ok(update) => {
658 apply_heartbeat_token_update(&self.token_ref, &update);
659 if let Some(task) = &update.task {
660 self.task_id = task.id;
661 } else if let Some(task_id) = update.task_id {
662 self.task_id = task_id;
663 }
664 if let Err(err) = self
665 .session
666 .apply_heartbeat(
667 &update,
668 Some(progress.clone()),
669 Instant::now(),
670 &self.policy,
671 &mut self.rng,
672 )
673 .await
674 {
675 return Some(HeartbeatLoopResult::LostLease(anyhow::Error::new(err)));
676 }
677 if update.cancel.unwrap_or(false) {
678 self.runner_cancel.cancel();
679 return Some(HeartbeatLoopResult::Cancelled);
680 }
681 None
682 }
683 Err(err) => {
684 self.runner_cancel.cancel();
685 Some(HeartbeatLoopResult::LostLease(err))
686 }
687 }
688 }
689}