Skip to main content

posemesh_compute_node/
engine.rs

1use anyhow::{anyhow, Context, Result};
2use async_trait::async_trait;
3use compute_runner_api::{ArtifactSink, ControlPlane, InputSource, LeaseEnvelope, Runner, TaskCtx};
4use rand::rngs::StdRng;
5use rand::SeedableRng;
6use serde_json::Value;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration as StdDuration, Instant};
10use tokio::sync::Mutex;
11use tokio::time::sleep;
12use tokio_util::sync::CancellationToken;
13use tracing::{debug, error, info, warn};
14use uuid::Uuid;
15
16use crate::{
17    dms::client::DmsClient,
18    heartbeat::{progress_channel, ProgressReceiver, ProgressSender},
19    poller::{jittered_delay_ms, PollerConfig},
20    session::{CapabilitySelector, HeartbeatPolicy, SessionManager},
21};
22
23/// Registry mapping capability strings to runner instances.
24#[derive(Default)]
25pub struct RunnerRegistry {
26    runners: HashMap<String, Arc<dyn Runner>>,
27}
28
29impl RunnerRegistry {
30    /// Create an empty registry.
31    pub fn new() -> Self {
32        Self {
33            runners: HashMap::new(),
34        }
35    }
36
37    /// Register a runner by its capability. Last registration wins on duplicates.
38    pub fn register<R: Runner + 'static>(mut self, runner: R) -> Self {
39        let key = runner.capability().to_string();
40        self.runners.insert(key, Arc::new(runner));
41        self
42    }
43
44    /// Retrieve a runner by capability.
45    pub fn get(&self, capability: &str) -> Option<Arc<dyn Runner>> {
46        self.runners.get(capability).cloned()
47    }
48
49    /// Snapshot of registered capability strings.
50    pub fn capabilities(&self) -> Vec<String> {
51        let mut caps: Vec<_> = self.runners.keys().cloned().collect();
52        caps.sort();
53        caps
54    }
55
56    /// Dispatch task to the appropriate runner based on `lease.task.capability`.
57    pub async fn run_for_lease(
58        &self,
59        lease: &LeaseEnvelope,
60        input: &dyn InputSource,
61        output: &dyn ArtifactSink,
62        ctrl: &dyn ControlPlane,
63        access_token: &dyn compute_runner_api::runner::AccessTokenProvider,
64    ) -> std::result::Result<(), crate::errors::ExecutorError> {
65        let cap = lease.task.capability.as_str();
66        let runner = self
67            .get(cap)
68            .ok_or_else(|| crate::errors::ExecutorError::NoRunner(cap.to_string()))?;
69        let ctx = TaskCtx {
70            lease,
71            input,
72            output,
73            ctrl,
74            access_token,
75        };
76        runner
77            .run(ctx)
78            .await
79            .map_err(|e| crate::errors::ExecutorError::Runner(e.to_string()))
80    }
81}
82
83/// Run the node main loop. Networking and storage are wired in later prompts.
84pub async fn run_node(cfg: crate::config::NodeConfig, runners: RunnerRegistry) -> Result<()> {
85    let shutdown = CancellationToken::new();
86    let signal_token = shutdown.clone();
87    let signal_task = tokio::spawn(async move {
88        if tokio::signal::ctrl_c().await.is_ok() {
89            signal_token.cancel();
90        }
91    });
92
93    let result = run_node_with_shutdown(cfg, runners, shutdown.clone()).await;
94
95    shutdown.cancel();
96    let _ = signal_task.await;
97
98    result
99}
100
101pub async fn run_node_with_shutdown(
102    cfg: crate::config::NodeConfig,
103    runners: RunnerRegistry,
104    shutdown: CancellationToken,
105) -> Result<()> {
106    let siwe = crate::auth::SiweAfterRegistration::from_config(&cfg)?;
107    info!("DDS SIWE authentication configured; waiting for DDS registration");
108    let siwe_handle = siwe.start().await?;
109    info!("DDS SIWE token manager started");
110
111    let poll_cfg = PollerConfig {
112        backoff_ms_min: cfg.poll_backoff_ms_min,
113        backoff_ms_max: cfg.poll_backoff_ms_max,
114    };
115
116    loop {
117        if shutdown.is_cancelled() {
118            break;
119        }
120
121        // Ensure SIWE token is available before attempting DMS operations
122        if let Err(err) = siwe_handle.bearer().await {
123            warn!(error = %err, "Failed to obtain SIWE bearer token; backing off");
124            let delay_ms = jittered_delay_ms(poll_cfg);
125            tokio::select! {
126                _ = shutdown.cancelled() => break,
127                _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
128            }
129        }
130
131        let timeout = StdDuration::from_secs(cfg.request_timeout_secs);
132        let dms_client = match crate::dms::client::DmsClient::new(
133            cfg.dms_base_url.clone(),
134            timeout,
135            std::sync::Arc::new(siwe_handle.clone()),
136        ) {
137            Ok(client) => client,
138            Err(err) => {
139                warn!(error = %err, "Failed to create DMS client; backing off");
140                let delay_ms = jittered_delay_ms(poll_cfg);
141                tokio::select! {
142                    _ = shutdown.cancelled() => break,
143                    _ = sleep(StdDuration::from_millis(delay_ms)) => continue,
144                }
145            }
146        };
147
148        match run_cycle_with_dms(&cfg, &dms_client, &runners).await {
149            Ok(true) => {
150                // Successful task execution; immediately attempt next poll.
151                continue;
152            }
153            Ok(false) => {
154                let delay_ms = jittered_delay_ms(poll_cfg);
155                debug!(delay_ms, "No lease available; backing off before next poll");
156                tokio::select! {
157                    _ = shutdown.cancelled() => break,
158                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
159                }
160            }
161            Err(err) => {
162                warn!(error = %err, "DMS cycle failed; backing off");
163                let delay_ms = jittered_delay_ms(poll_cfg);
164                tokio::select! {
165                    _ = shutdown.cancelled() => break,
166                    _ = sleep(StdDuration::from_millis(delay_ms)) => {}
167                }
168            }
169        }
170    }
171
172    siwe_handle.shutdown().await;
173    info!("Shutdown signal received; exiting run_node loop");
174
175    Ok(())
176}
177
178/// Build storage ports (input/output) for a given lease by constructing a TokenRef
179/// from the lease's access token and delegating to storage::build_ports.
180pub fn build_storage_for_lease(lease: &LeaseEnvelope) -> Result<crate::storage::Ports> {
181    let token = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
182    crate::storage::build_ports(lease, token)
183}
184
185/// Apply heartbeat token refresh: if HeartbeatResponse carries a new access token,
186/// swap it into the provided TokenRef so subsequent storage requests use it.
187pub fn apply_heartbeat_token_update(
188    token: &crate::storage::TokenRef,
189    hb: &crate::dms::types::HeartbeatResponse,
190) {
191    if let Some(new) = hb.access_token.clone() {
192        token.swap(new);
193    }
194}
195
196/// Merge fields from a heartbeat response into the cached lease.
197pub fn merge_heartbeat_into_lease(
198    lease: &mut LeaseEnvelope,
199    hb: &crate::dms::types::HeartbeatResponse,
200) {
201    if let Some(token) = hb.access_token.clone() {
202        lease.access_token = Some(token);
203    }
204    if let Some(expiry) = hb.access_token_expires_at {
205        lease.access_token_expires_at = Some(expiry);
206    }
207    if let Some(expiry) = hb.lease_expires_at {
208        lease.lease_expires_at = Some(expiry);
209    }
210    if let Some(cancel) = hb.cancel {
211        lease.cancel = cancel;
212    }
213    if let Some(status) = hb.status.clone() {
214        lease.status = Some(status);
215    }
216    if let Some(domain_id) = hb.domain_id {
217        lease.domain_id = Some(domain_id);
218    }
219    if let Some(url) = hb.domain_server_url.clone() {
220        lease.domain_server_url = Some(url);
221    }
222    if let Some(task) = hb.task.clone() {
223        lease.task = task;
224    } else {
225        if let Some(task_id) = hb.task_id {
226            lease.task.id = task_id;
227        }
228        if let Some(job_id) = hb.job_id {
229            lease.task.job_id = Some(job_id);
230        }
231        if let Some(attempts) = hb.attempts {
232            lease.task.attempts = Some(attempts);
233        }
234        if let Some(max_attempts) = hb.max_attempts {
235            lease.task.max_attempts = Some(max_attempts);
236        }
237        if let Some(deps_remaining) = hb.deps_remaining {
238            lease.task.deps_remaining = Some(deps_remaining);
239        }
240    }
241}
242
243/// Run a single poll→run→complete/fail cycle using DMS client and the runner registry.
244/// This is a minimal integration used by tests; `run_node` wiring remains separate.
245pub async fn run_cycle_with_dms(
246    cfg: &crate::config::NodeConfig,
247    dms: &DmsClient,
248    reg: &RunnerRegistry,
249) -> Result<bool> {
250    use crate::dms::types::{CompleteTaskRequest, FailTaskRequest, HeartbeatRequest};
251    use serde_json::json;
252
253    let capabilities = reg.capabilities();
254    let capability = capabilities
255        .first()
256        .cloned()
257        .ok_or_else(|| anyhow!("no runners registered"))?;
258
259    // Lease a task from DMS
260    let mut lease = match dms.lease_by_capability(&capability).await? {
261        Some(lease) => lease,
262        None => {
263            return Ok(false);
264        }
265    };
266    if lease.access_token.is_none() {
267        tracing::warn!(
268            "Lease missing access token; storage client will fall back to legacy token flow"
269        );
270    }
271
272    // Initialise session state for heartbeats and token rotation.
273    let selector = CapabilitySelector::new(capabilities.clone());
274    let session = SessionManager::new(selector);
275    let policy = HeartbeatPolicy::new(cfg.heartbeat_min_ratio, cfg.heartbeat_max_ratio);
276    let mut rng = StdRng::from_entropy();
277    let task_id = lease.task.id;
278    let report_setup_failure = |stage: &'static str, err: &anyhow::Error| {
279        let details = json!({
280            "stage": stage,
281            "error": err.to_string(),
282        });
283        async move {
284            let body = FailTaskRequest {
285                reason: "node_setup_failed".into(),
286                details,
287            };
288            dms.fail(task_id, &body).await
289        }
290    };
291
292    let snapshot = match session
293        .start_session(&lease, Instant::now(), &policy, &mut rng)
294        .await
295    {
296        Ok(snapshot) => snapshot,
297        Err(err) => {
298            let original = anyhow!("failed to initialise session: {err}");
299            if let Err(fail_err) = report_setup_failure("start_session", &original).await {
300                warn!(
301                    error = %fail_err,
302                    task_id = %task_id,
303                    "failed to report setup failure"
304                );
305                return Err(original);
306            }
307            return Ok(true);
308        }
309    };
310    if snapshot.cancel() {
311        warn!(
312            task_id = %snapshot.task_id(),
313            "Lease already marked as cancelled; skipping execution"
314        );
315        return Ok(true);
316    }
317
318    let token_ref = crate::storage::TokenRef::new(lease.access_token.clone().unwrap_or_default());
319
320    let heartbeat_initial = match dms
321        .heartbeat(
322            lease.task.id,
323            &HeartbeatRequest {
324                progress: json!({}),
325                events: Vec::new(),
326            },
327        )
328        .await
329    {
330        Ok(response) => response,
331        Err(err) => {
332            if let Err(fail_err) = report_setup_failure("initial_heartbeat", &err).await {
333                warn!(
334                    error = %fail_err,
335                    task_id = %task_id,
336                    "failed to report setup failure"
337                );
338                return Err(err);
339            }
340            return Ok(true);
341        }
342    };
343    apply_heartbeat_token_update(&token_ref, &heartbeat_initial);
344    merge_heartbeat_into_lease(&mut lease, &heartbeat_initial);
345    session
346        .apply_heartbeat(
347            &heartbeat_initial,
348            Some(json!({})),
349            Instant::now(),
350            &policy,
351            &mut rng,
352        )
353        .await
354        .map_err(|err| anyhow!("failed to refresh session after heartbeat: {err}"))?;
355
356    let ports = match crate::storage::build_ports(&lease, token_ref.clone()) {
357        Ok(ports) => ports,
358        Err(err) => {
359            if let Err(fail_err) = report_setup_failure("build_ports", &err).await {
360                warn!(
361                    error = %fail_err,
362                    task_id = %task_id,
363                    "failed to report setup failure"
364                );
365                return Err(err);
366            }
367            return Ok(true);
368        }
369    };
370
371    let (progress_tx, progress_rx) = progress_channel();
372    let control_state = Arc::new(Mutex::new(ControlState::default()));
373    {
374        let mut guard = control_state.lock().await;
375        guard.progress = json!({});
376        guard.events = Vec::new();
377    }
378
379    let runner_cancel = CancellationToken::new();
380    let heartbeat_shutdown = CancellationToken::new();
381
382    let ctrl = EngineControlPlane::new(
383        runner_cancel.clone(),
384        progress_tx.clone(),
385        control_state.clone(),
386    );
387
388    // Trigger an immediate heartbeat once the loop starts to refresh tokens.
389    progress_tx.update(json!({}), Vec::new());
390
391    let heartbeat_driver = HeartbeatDriver::new(
392        dms.clone(),
393        HeartbeatDriverArgs {
394            session: session.clone(),
395            policy,
396            rng,
397            progress_rx,
398            state: control_state.clone(),
399            token_ref: token_ref.clone(),
400            runner_cancel: runner_cancel.clone(),
401            shutdown: heartbeat_shutdown.clone(),
402            task_id: lease.task.id,
403        },
404    );
405    let heartbeat_handle = tokio::spawn(async move { heartbeat_driver.run().await });
406
407    let run_res = reg
408        .run_for_lease(&lease, &*ports.input, &*ports.output, &ctrl, &token_ref)
409        .await;
410
411    // Re-broadcast the latest progress/events so the heartbeat loop can flush
412    // them before shutdown. Without this, very short tasks may complete before
413    // the final heartbeat is delivered, leaving stale progress in DMS.
414    {
415        let state = control_state.lock().await;
416        progress_tx.update(state.progress.clone(), state.events.clone());
417    }
418    sleep(StdDuration::from_millis(200)).await;
419
420    heartbeat_shutdown.cancel();
421    let heartbeat_result = match heartbeat_handle.await {
422        Ok(result) => result,
423        Err(err) => {
424            warn!(error = %err, "heartbeat loop task failed");
425            HeartbeatLoopResult::Completed
426        }
427    };
428
429    match heartbeat_result {
430        HeartbeatLoopResult::Completed => {}
431        HeartbeatLoopResult::Cancelled => {
432            info!(
433                task_id = %lease.task.id,
434                "Lease cancelled during execution; skipping completion"
435            );
436            runner_cancel.cancel();
437            return Ok(true);
438        }
439        HeartbeatLoopResult::LostLease(err) => {
440            warn!(
441                task_id = %lease.task.id,
442                error = %err,
443                "Lease lost during heartbeat; abandoning task"
444            );
445            runner_cancel.cancel();
446            return Ok(true);
447        }
448    }
449
450    let uploaded_artifacts = ports.uploaded_artifacts();
451    let artifacts_json: Vec<Value> = uploaded_artifacts
452        .iter()
453        .map(|artifact| {
454            json!({
455                "logical_path": artifact.logical_path,
456                "name": artifact.name,
457                "data_type": artifact.data_type,
458                "id": artifact.id,
459            })
460        })
461        .collect();
462    let output_cids: Vec<String> = uploaded_artifacts
463        .iter()
464        .filter_map(|artifact| artifact.id.clone())
465        .collect();
466    let job_info = json!({
467        "task_id": lease.task.id,
468        "job_id": lease.task.job_id,
469        "domain_id": lease.domain_id,
470        "capability": lease.task.capability,
471    });
472
473    // Complete or fail the task depending on runner outcome.
474    match run_res {
475        Ok(()) => {
476            let body = CompleteTaskRequest {
477                output_cids,
478                meta: json!({
479                    "job": job_info,
480                    "artifacts": artifacts_json,
481                }),
482            };
483            dms.complete(lease.task.id, &body).await?;
484        }
485        Err(err) => {
486            error!(
487                task_id = %lease.task.id,
488                job_id = ?lease.task.job_id,
489                capability = %lease.task.capability,
490                error = %err,
491                debug = ?err,
492                "Runner execution failed; reporting failure to DMS"
493            );
494            let body = FailTaskRequest {
495                reason: err.to_string(),
496                details: json!({
497                    "job": job_info,
498                    "artifacts": artifacts_json,
499                }),
500            };
501            dms.fail(lease.task.id, &body)
502                .await
503                .with_context(|| format!("report fail for task {} to DMS", lease.task.id))?;
504        }
505    }
506
507    Ok(true)
508}
509
510#[derive(Default)]
511pub struct ControlState {
512    progress: Value,
513    events: Vec<Value>,
514}
515
516struct EngineControlPlane {
517    cancel: CancellationToken,
518    progress_tx: ProgressSender,
519    state: Arc<Mutex<ControlState>>,
520}
521
522impl EngineControlPlane {
523    pub fn new(
524        cancel: CancellationToken,
525        progress_tx: ProgressSender,
526        state: Arc<Mutex<ControlState>>,
527    ) -> Self {
528        Self {
529            cancel,
530            progress_tx,
531            state,
532        }
533    }
534}
535
536#[async_trait]
537impl ControlPlane for EngineControlPlane {
538    async fn is_cancelled(&self) -> bool {
539        self.cancel.is_cancelled()
540    }
541
542    async fn progress(&self, value: Value) -> Result<()> {
543        let events = {
544            let mut state = self.state.lock().await;
545            state.progress = value.clone();
546            state.events.clone()
547        };
548        self.progress_tx.update(value, events);
549        Ok(())
550    }
551
552    async fn log_event(&self, fields: Value) -> Result<()> {
553        let (progress, events) = {
554            let mut state = self.state.lock().await;
555            state.events.push(fields.clone());
556            (state.progress.clone(), state.events.clone())
557        };
558        self.progress_tx.update(progress, events);
559        Ok(())
560    }
561}
562
563pub enum HeartbeatLoopResult {
564    Completed,
565    Cancelled,
566    LostLease(anyhow::Error),
567}
568
569#[async_trait]
570pub trait HeartbeatTransport: Send + Sync + Clone + 'static {
571    async fn post_heartbeat(
572        &self,
573        task_id: Uuid,
574        body: &crate::dms::types::HeartbeatRequest,
575    ) -> Result<crate::dms::types::HeartbeatResponse>;
576}
577
578#[async_trait]
579impl HeartbeatTransport for DmsClient {
580    async fn post_heartbeat(
581        &self,
582        task_id: Uuid,
583        body: &crate::dms::types::HeartbeatRequest,
584    ) -> Result<crate::dms::types::HeartbeatResponse> {
585        self.heartbeat(task_id, body).await
586    }
587}
588
589pub struct HeartbeatDriverArgs {
590    pub session: SessionManager,
591    pub policy: HeartbeatPolicy,
592    pub rng: StdRng,
593    pub progress_rx: ProgressReceiver,
594    pub state: Arc<Mutex<ControlState>>,
595    pub token_ref: crate::storage::TokenRef,
596    pub runner_cancel: CancellationToken,
597    pub shutdown: CancellationToken,
598    pub task_id: Uuid,
599}
600
601pub struct HeartbeatDriver<T>
602where
603    T: HeartbeatTransport,
604{
605    transport: T,
606    session: SessionManager,
607    policy: HeartbeatPolicy,
608    rng: StdRng,
609    progress_rx: ProgressReceiver,
610    state: Arc<Mutex<ControlState>>,
611    token_ref: crate::storage::TokenRef,
612    runner_cancel: CancellationToken,
613    shutdown: CancellationToken,
614    task_id: Uuid,
615    last_progress: Value,
616}
617
618impl<T> HeartbeatDriver<T>
619where
620    T: HeartbeatTransport,
621{
622    pub fn new(transport: T, args: HeartbeatDriverArgs) -> Self {
623        Self {
624            transport,
625            session: args.session,
626            policy: args.policy,
627            rng: args.rng,
628            progress_rx: args.progress_rx,
629            state: args.state,
630            token_ref: args.token_ref,
631            runner_cancel: args.runner_cancel,
632            shutdown: args.shutdown,
633            task_id: args.task_id,
634            last_progress: Value::default(),
635        }
636    }
637
638    pub async fn run(mut self) -> HeartbeatLoopResult {
639        loop {
640            if self.shutdown.is_cancelled() || self.runner_cancel.is_cancelled() {
641                return HeartbeatLoopResult::Completed;
642            }
643
644            let snapshot = match self.session.snapshot().await {
645                Some(s) => s,
646                None => return HeartbeatLoopResult::Completed,
647            };
648
649            let ttl_delay = snapshot
650                .next_heartbeat_due()
651                .map(|due| due.saturating_duration_since(Instant::now()));
652
653            if let Some(delay) = ttl_delay {
654                tokio::select! {
655                    _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
656                    progress = self.progress_rx.recv() => {
657                        if let Some(data) = progress {
658                            if let Some(outcome) = self.handle_progress(data).await {
659                                return outcome;
660                            }
661                        } else {
662                            return HeartbeatLoopResult::Completed;
663                        }
664                    }
665                    _ = tokio::time::sleep(delay) => {
666                        if let Some(outcome) = self.handle_ttl().await {
667                            return outcome;
668                        }
669                    }
670                }
671            } else {
672                tokio::select! {
673                    _ = self.shutdown.cancelled() => return HeartbeatLoopResult::Completed,
674                    progress = self.progress_rx.recv() => {
675                        if let Some(data) = progress {
676                            if let Some(outcome) = self.handle_progress(data).await {
677                                return outcome;
678                            }
679                        } else {
680                            return HeartbeatLoopResult::Completed;
681                        }
682                    }
683                }
684            }
685        }
686    }
687
688    async fn handle_progress(
689        &mut self,
690        data: crate::heartbeat::HeartbeatData,
691    ) -> Option<HeartbeatLoopResult> {
692        self.last_progress = data.progress.clone();
693        let (progress, events) = self.snapshot_state().await;
694        self.send_and_update(progress, events).await
695    }
696
697    async fn handle_ttl(&mut self) -> Option<HeartbeatLoopResult> {
698        let (progress, events) = self.snapshot_state().await;
699        self.send_and_update(progress, events).await
700    }
701
702    async fn snapshot_state(&self) -> (Value, Vec<Value>) {
703        let state = self.state.lock().await;
704        (state.progress.clone(), state.events.clone())
705    }
706
707    async fn send_and_update(
708        &mut self,
709        progress: Value,
710        events: Vec<Value>,
711    ) -> Option<HeartbeatLoopResult> {
712        let request = crate::dms::types::HeartbeatRequest {
713            progress: progress.clone(),
714            events: events.clone(),
715        };
716
717        match self.transport.post_heartbeat(self.task_id, &request).await {
718            Ok(update) => {
719                if !events.is_empty() {
720                    let mut state = self.state.lock().await;
721                    if state.events.len() >= events.len()
722                        && state.events[..events.len()] == events[..]
723                    {
724                        state.events.drain(0..events.len());
725                    }
726                }
727                apply_heartbeat_token_update(&self.token_ref, &update);
728                if let Some(task) = &update.task {
729                    self.task_id = task.id;
730                } else if let Some(task_id) = update.task_id {
731                    self.task_id = task_id;
732                }
733                if let Err(err) = self
734                    .session
735                    .apply_heartbeat(
736                        &update,
737                        Some(progress.clone()),
738                        Instant::now(),
739                        &self.policy,
740                        &mut self.rng,
741                    )
742                    .await
743                {
744                    return Some(HeartbeatLoopResult::LostLease(anyhow::Error::new(err)));
745                }
746                if update.cancel.unwrap_or(false) {
747                    self.runner_cancel.cancel();
748                    return Some(HeartbeatLoopResult::Cancelled);
749                }
750                None
751            }
752            Err(err) => {
753                self.runner_cancel.cancel();
754                Some(HeartbeatLoopResult::LostLease(err))
755            }
756        }
757    }
758}