atomr_infer_runtime/worker.rs

//! Local-GPU worker — two-tier supervision adapter (doc §4, §5.3).
//!
//! `WorkerActor` is the **stable** parent: addressable, supervised by
//! the engine-core, never restarts. Its child `ContextActor` is
//! **restartable** and owns the runtime-specific resources (CUDA
//! context, weights, etc.). When the runner reports
//! `CudaContextPoisoned`, the child panics with the
//! `atomr_accel_cuda::error::CONTEXT_POISONED_TAG` marker so that the
//! parent's `atomr_accel_cuda::error::device_supervisor_strategy`
//! routes the failure to `Directive::Restart`.
//!
//! The supervision *policy* (3 retries / 60 s, decider, marker tags) is
//! reused verbatim from atomr-accel-cuda's `error` module — that's the
//! upstream substrate for the doc's §5.11 two-tier pattern. What this
//! crate adds is the runtime-polymorphic `Box<dyn ModelRunner>` slot,
//! which is inference-specific.
//!
//! Per-runtime crates supply the runner via the `WorkerSlot` factory.
//! Remote runtimes go through `inference-remote-core::RemoteWorkerActor`
//! instead.
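//!
//! A minimal wiring sketch; `LlamaRunner`, `model_path`, and
//! `engine.attach_worker` are illustrative placeholders, not APIs this
//! workspace defines:
//!
//! ```ignore
//! use atomr_infer_runtime::worker::{WorkerActor, WorkerSlot};
//!
//! // The factory runs once per child (re)spawn, so it must be able to
//! // build a fresh runner (and with it a fresh CUDA context) each time.
//! let worker = WorkerActor::new(move || WorkerSlot {
//!     runner: Box::new(LlamaRunner::load(&model_path).expect("load weights")),
//! });
//! engine.attach_worker(worker); // hypothetical engine-core hook
//! ```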

use std::sync::Arc;

use async_trait::async_trait;
use atomr_core::actor::{Actor, ActorRef, Context, Props};
use atomr_core::supervision::SupervisorStrategy;
use futures::StreamExt;
use parking_lot::Mutex;
use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex};

use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::error::InferenceError;
use atomr_infer_core::runner::{ModelRunner, SessionRebuildCause};
use atomr_infer_core::tokens::TokenChunk;
/// What the parent hands to its child on construction. The runner
/// owns the GPU context indirectly (via `cudarc::driver::CudaContext`,
/// `atomr_accel_cuda::device::DeviceState`, or whatever the backend
/// uses); when the parent decides to rebuild, it constructs a fresh
/// `WorkerSlot` and the child cell starts anew.
pub struct WorkerSlot {
    pub runner: Box<dyn ModelRunner>,
}

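/// Message protocol of the stable parent. A usage sketch covering all
/// three shapes; `worker` is an `ActorRef<WorkerMsg>`, and `batch`,
/// `cause`, `cuda_err`, and `handle_chunk` are illustrative:
///
/// ```ignore
/// // Streaming execute: token chunks (or errors) arrive in-band.
/// let (tx, mut rx) = tokio::sync::mpsc::channel(32);
/// worker.tell(WorkerMsg::Execute(batch, tx));
/// while let Some(chunk) = rx.recv().await {
///     handle_chunk(chunk?);
/// }
///
/// // Runner-side poison forwarding (sticky CUDA error detected):
/// worker.tell(WorkerMsg::ContextPoisoned(format!("{cuda_err}")));
///
/// // Operator-triggered rebuild: resolves once the session is rebuilt.
/// let (tx, rx) = tokio::sync::oneshot::channel();
/// worker.tell(WorkerMsg::RebuildSession { cause, reply: tx });
/// rx.await??; // oneshot recv error first, then InferenceError
/// ```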
pub enum WorkerMsg {
    /// Run a batch; token chunks (or errors) stream back on the
    /// supplied channel.
    Execute(ExecuteBatch, mpsc::Sender<Result<TokenChunk, InferenceError>>),
    /// Forwarded from the runner when a sticky CUDA error is detected.
    /// Triggers a child restart.
    ContextPoisoned(String),
    /// Operator-triggered rebuild.
    RebuildSession {
        cause: SessionRebuildCause,
        reply: oneshot::Sender<Result<(), InferenceError>>,
    },
}

/// Child-side commands: the execute / rebuild halves of `WorkerMsg`.
/// Poisoning has no variant here; it surfaces as a tagged child panic.
pub enum ContextMsg {
    Execute(ExecuteBatch, mpsc::Sender<Result<TokenChunk, InferenceError>>),
    Rebuild {
        cause: SessionRebuildCause,
        reply: oneshot::Sender<Result<(), InferenceError>>,
    },
}

pub struct WorkerActor {
    /// Slot factory — invoked once on initial child spawn and once per
    /// rebuild. Per-runtime crates supply this.
    slot_factory: Box<dyn Fn() -> WorkerSlot + Send + Sync>,
    child: Option<ActorRef<ContextMsg>>,
    parent_to_child_seq: u64,
}

impl WorkerActor {
    pub fn new<F>(slot_factory: F) -> Self
    where
        F: Fn() -> WorkerSlot + Send + Sync + 'static,
    {
        Self {
            slot_factory: Box::new(slot_factory),
            child: None,
            parent_to_child_seq: 0,
        }
    }

    fn spawn_child(&mut self, ctx: &mut Context<Self>) {
        // The factory is called once per spawn. ContextActor itself isn't
        // restarted in place by atomr's supervisor — we tear it down and
        // spawn a fresh one with a new slot when context poisoning happens.
        self.parent_to_child_seq += 1;
        let name = format!("ctx-{}", self.parent_to_child_seq);
        // Move the one-shot slot into the Props closure through a
        // take-once cell, since Props factories must be re-callable `Fn`s.
        let cell = Mutex::new(Some((self.slot_factory)()));
        let props = Props::create(move || {
            let s = cell.lock().take().expect("worker context factory invoked twice");
            ContextActor::new(s)
        });
        match ctx.spawn(props, &name) {
            Ok(addr) => self.child = Some(addr),
            Err(e) => tracing::error!(?e, "spawn ContextActor failed"),
        }
    }
}

#[async_trait]
impl Actor for WorkerActor {
    type Msg = WorkerMsg;

    async fn pre_start(&mut self, ctx: &mut Context<Self>) {
        self.spawn_child(ctx);
    }

    fn supervisor_strategy(&self) -> SupervisorStrategy {
        // With the `local-gpu` feature on, defer to the upstream
        // supervisor strategy (3 retries / 60 s window, decider over
        // the `ContextPoisoned` / `OutOfMemory` / `Unrecoverable`
        // markers). Without the feature, fall back to a hand-rolled
        // policy that restarts on the same string-tagged panic
        // messages — this keeps the workspace buildable for
        // `remote-only` consumers that don't pull atomr-accel-cuda
        // but still mount a local ModelRunner (e.g. inference-testkit's
        // MockRunner in tests).
        #[cfg(feature = "local-gpu")]
        {
            // The NVIDIA CUDA backend lives in its own crate
            // (atomr-accel-cuda) since atomr-accel 0.3; the umbrella
            // crate only ships the trait surface.
            atomr_accel_cuda::error::device_supervisor_strategy()
        }
        #[cfg(not(feature = "local-gpu"))]
        {
            use atomr_core::supervision::{Directive, OneForOneStrategy};
            OneForOneStrategy::new()
                .with_max_retries(3)
                .with_within(std::time::Duration::from_secs(60))
                .with_decider(|err| {
                    // Mirror the tag set of atomr_accel_cuda::error's decider.
                    if err.contains("ContextPoisoned") {
                        Directive::Restart
                    } else if err.contains("OutOfMemory") {
                        Directive::Resume
                    } else if err.contains("Unrecoverable") {
                        Directive::Stop
                    } else {
                        Directive::Escalate
                    }
                })
                .into()
        }
    }

    async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
        match msg {
            WorkerMsg::Execute(batch, output) => {
                let Some(child) = self.child.as_ref() else { return };
                child.tell(ContextMsg::Execute(batch, output));
            }
            WorkerMsg::ContextPoisoned(reason) => {
                tracing::warn!(%reason, "context poisoned — rebuilding child");
                if let Some(child) = self.child.take() {
                    child.stop();
                }
                self.spawn_child(ctx);
            }
            WorkerMsg::RebuildSession { cause, reply } => {
                let Some(child) = self.child.as_ref() else {
                    let _ = reply.send(Err(InferenceError::Internal("no child".into())));
                    return;
                };
                child.tell(ContextMsg::Rebuild { cause, reply });
            }
        }
    }
}

// ---------------------------------------------------------------------------

/// `ContextActor` — restartable child holding the CUDA context (or the
/// remote-network analogue). Distinct from
/// `atomr_accel_cuda::device::ContextActor`: that one specialises to CUDA
/// memory / streams; this one holds the polymorphic
/// `Box<dyn ModelRunner>` so the same supervision shape covers
/// remote-network runners too.
pub struct ContextActor {
    runner: Arc<AsyncMutex<Box<dyn ModelRunner>>>,
}

impl ContextActor {
    pub fn new(slot: WorkerSlot) -> Self {
        Self {
            runner: Arc::new(AsyncMutex::new(slot.runner)),
        }
    }
}

#[async_trait]
impl Actor for ContextActor {
    type Msg = ContextMsg;

    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: Self::Msg) {
        match msg {
            ContextMsg::Execute(batch, output) => {
                // Submit the batch under the runner lock, but drain the
                // resulting token stream in a detached task so the actor
                // stays responsive to further messages.
                let mut g = self.runner.lock().await;
                match g.execute(batch).await {
                    Ok(handle) => {
                        drop(g); // release the runner mutex while we drain
                        tokio::spawn(async move {
                            let mut s = handle.into_stream();
                            while let Some(chunk) = s.next().await {
                                if output.send(chunk).await.is_err() {
                                    break;
                                }
                            }
                        });
                    }
                    Err(e) => {
                        drop(g); // nothing left to protect; free the runner
                        // Sticky CUDA errors propagate as panics tagged
                        // with atomr-accel-cuda's CONTEXT_POISONED_TAG so
                        // the parent's supervisor strategy can route them
                        // to Restart. The panic has to happen here, inside
                        // the actor's own handler: a panic inside
                        // `tokio::spawn` would only abort that task and
                        // never reach the supervisor.
                        if matches!(e, InferenceError::CudaContextPoisoned(_)) {
                            let _ = output.send(Err(e.clone())).await;
                            #[cfg(feature = "local-gpu")]
                            panic!("{}: {e}", atomr_accel_cuda::error::CONTEXT_POISONED_TAG);
                            #[cfg(not(feature = "local-gpu"))]
                            panic!("ContextPoisoned: {e}");
                        }
                        let _ = output.send(Err(e)).await;
                    }
                }
            }
            ContextMsg::Rebuild { cause, reply } => {
                // Rebuild can take a while (weight reload); run it off
                // the handler and reply through the oneshot when done.
                let runner = self.runner.clone();
                tokio::spawn(async move {
                    let mut g = runner.lock().await;
                    let r = g.rebuild_session(cause).await;
                    let _ = reply.send(r);
                });
            }
        }
    }
}
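
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // A hedged sketch of the factory contract: `WorkerActor::new` must
    // not invoke the slot factory; only `spawn_child` (driven by the
    // actor runtime via `pre_start`) may. The runner is stubbed with
    // `inference_testkit::MockRunner`, whose path and `Default` impl
    // are assumptions here, not a verified API.
    #[test]
    fn slot_factory_is_lazy() {
        static CALLS: AtomicUsize = AtomicUsize::new(0);
        let _worker = WorkerActor::new(|| {
            CALLS.fetch_add(1, Ordering::SeqCst);
            WorkerSlot {
                runner: Box::new(inference_testkit::MockRunner::default()),
            }
        });
        assert_eq!(CALLS.load(Ordering::SeqCst), 0, "factory must be lazy");
    }
}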