// inference_runtime/worker.rs
1//! Local-GPU worker — two-tier supervision adapter (doc §4, §5.3).
2//!
3//! `WorkerActor` is the **stable** parent: addressable, supervised by
4//! the engine-core, never restarts. Its child `ContextActor` is
5//! **restartable** and owns the runtime-specific resources (CUDA
//! context, weights, etc). When the runner reports
//! `CudaContextPoisoned` the child's execute task panics with the
//! `rakka_accel::cuda::error::CONTEXT_POISONED_TAG` marker so that
//! `rakka_accel::cuda::error::device_supervisor_strategy` routes the
//! failure to `Directive::Restart`.
11//!
12//! The supervision *policy* (3 retries / 60s, decider, marker tags) is
13//! re-used verbatim from rakka-accel's `error` module — that's the
14//! upstream substrate for the doc's §5.11 two-tier pattern. The
15//! *body* this crate adds is the runtime-polymorphic
16//! `Box<dyn ModelRunner>` slot, which is inference-specific.
17//!
18//! Per-runtime crates supply the runner via the `WorkerSlot` factory.
19//! Remote runtimes go through `inference-remote-core::RemoteWorkerActor`
20//! instead.
21
22use std::sync::Arc;
23
24use async_trait::async_trait;
25use futures::StreamExt;
26use parking_lot::Mutex;
27use rakka_core::actor::{Actor, ActorRef, Context, Props};
28use rakka_core::supervision::SupervisorStrategy;
29use tokio::sync::Mutex as AsyncMutex;
30use tokio::sync::{mpsc, oneshot};
31
32use inference_core::batch::ExecuteBatch;
33use inference_core::error::InferenceError;
34use inference_core::runner::{ModelRunner, SessionRebuildCause};
35use inference_core::tokens::TokenChunk;
36
/// What the parent hands to its child on construction. The runner
/// owns the GPU context indirectly (via `cudarc::driver::CudaContext`,
/// `rakka_accel::cuda::device::DeviceState`, or whatever the backend uses);
/// when the parent decides to rebuild, it constructs a fresh
/// `WorkerSlot` and the child cell starts anew.
pub struct WorkerSlot {
    /// Runtime-specific runner; consumed exactly once by the child on start.
    pub runner: Box<dyn ModelRunner>,
}
45
/// Messages accepted by the stable parent [`WorkerActor`].
pub enum WorkerMsg {
    /// Run a batch; token chunks (or errors) stream back over the channel.
    Execute(ExecuteBatch, mpsc::Sender<Result<TokenChunk, InferenceError>>),
    /// Forwarded from the runner when a sticky CUDA error is detected.
    /// Triggers a child restart.
    ContextPoisoned(String),
    /// Operator-triggered rebuild.
    RebuildSession {
        cause: SessionRebuildCause,
        reply: oneshot::Sender<Result<(), InferenceError>>,
    },
}
57
/// Messages accepted by the restartable child [`ContextActor`].
pub enum ContextMsg {
    /// Run a batch; token chunks (or errors) stream back over the channel.
    Execute(ExecuteBatch, mpsc::Sender<Result<TokenChunk, InferenceError>>),
    /// Ask the runner to rebuild its session; outcome goes to `reply`.
    Rebuild {
        cause: SessionRebuildCause,
        reply: oneshot::Sender<Result<(), InferenceError>>,
    },
}
65
/// Stable parent actor: addressable, supervised by the engine-core,
/// never restarts; owns and rebuilds its `ContextActor` child.
pub struct WorkerActor {
    /// Slot factory — invoked once on initial child spawn and once per
    /// rebuild. Per-runtime crates supply this.
    slot_factory: Box<dyn Fn() -> WorkerSlot + Send + Sync>,
    /// Address of the live child, if one was spawned successfully.
    child: Option<ActorRef<ContextMsg>>,
    /// Monotonic counter giving each child spawn a unique name
    /// (`ctx-1`, `ctx-2`, …).
    parent_to_child_seq: u64,
}
73
74impl WorkerActor {
75 pub fn new<F>(slot_factory: F) -> Self
76 where
77 F: Fn() -> WorkerSlot + Send + Sync + 'static,
78 {
79 Self {
80 slot_factory: Box::new(slot_factory),
81 child: None,
82 parent_to_child_seq: 0,
83 }
84 }
85
86 fn spawn_child(&mut self, ctx: &mut Context<Self>) {
87 // Factory is called once per spawn. ContextActor itself isn't
88 // restarted by rakka's supervisor — we tear it down and spawn
89 // a fresh one with a new slot when context poisoning happens.
90 self.parent_to_child_seq += 1;
91 let name = format!("ctx-{}", self.parent_to_child_seq);
92 let cell = Mutex::new(Some((self.slot_factory)()));
93 let props = Props::create(move || {
94 let s = cell.lock().take().expect("worker context factory invoked twice");
95 ContextActor::new(s)
96 });
97 match ctx.spawn(props, &name) {
98 Ok(addr) => self.child = Some(addr),
99 Err(e) => tracing::error!(?e, "spawn ContextActor failed"),
100 }
101 }
102}
103
#[async_trait]
impl Actor for WorkerActor {
    type Msg = WorkerMsg;

    /// Spawn the first child as soon as the parent starts.
    async fn pre_start(&mut self, ctx: &mut Context<Self>) {
        self.spawn_child(ctx);
    }

    fn supervisor_strategy(&self) -> SupervisorStrategy {
        // With the `local-gpu` feature on, defer to the upstream
        // supervisor strategy (3 retries / 60s window, decider over
        // `ContextPoisoned` / `OutOfMemory` / `Unrecoverable` markers).
        // Without the feature, fall back to a hand-rolled policy that
        // restarts on the same string-tag panic-message — this keeps
        // the workspace buildable for `remote-only` consumers that
        // don't pull rakka-accel but still happen to mount a local
        // ModelRunner (e.g. inference-testkit's MockRunner in tests).
        #[cfg(feature = "local-gpu")]
        {
            // The CUDA backend is re-exported at `rakka_accel::cuda`
            // when the `cuda` feature is on. We carry that feature
            // forward via our own `local-gpu` feature.
            rakka_accel::cuda::error::device_supervisor_strategy()
        }
        #[cfg(not(feature = "local-gpu"))]
        {
            use rakka_core::supervision::{Directive, OneForOneStrategy};
            OneForOneStrategy::new()
                .with_max_retries(3)
                .with_within(std::time::Duration::from_secs(60))
                .with_decider(|err| {
                    // Mirror rakka_accel::cuda::error::decider's tag set.
                    // NOTE(review): OutOfMemory maps to Resume here —
                    // confirm this matches the upstream decider's mapping.
                    if err.contains("ContextPoisoned") {
                        Directive::Restart
                    } else if err.contains("OutOfMemory") {
                        Directive::Resume
                    } else if err.contains("Unrecoverable") {
                        Directive::Stop
                    } else {
                        Directive::Escalate
                    }
                })
                .into()
        }
    }

    async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
        match msg {
            WorkerMsg::Execute(batch, output) => {
                // No child (spawn failed earlier): drop the sender so the
                // caller observes a closed channel instead of hanging.
                let Some(child) = self.child.as_ref() else { return };
                child.tell(ContextMsg::Execute(batch, output));
            }
            WorkerMsg::ContextPoisoned(reason) => {
                // Two-tier rebuild: stop the poisoned child outright and
                // spawn a fresh one from a brand-new slot.
                tracing::warn!(reason, "context poisoned — rebuilding child");
                if let Some(child) = self.child.take() {
                    child.stop();
                }
                self.spawn_child(ctx);
            }
            WorkerMsg::RebuildSession { cause, reply } => {
                // Rebuild needs a reply either way; report the missing
                // child explicitly rather than dropping the oneshot.
                let Some(child) = self.child.as_ref() else {
                    let _ = reply.send(Err(InferenceError::Internal("no child".into())));
                    return;
                };
                child.tell(ContextMsg::Rebuild { cause, reply });
            }
        }
    }
}
173
174// ---------------------------------------------------------------------------
175
/// `ContextActor` — restartable child holding the CUDA context (or the
/// remote-network analogue). Distinct from
/// `rakka_accel::cuda::device::ContextActor`: that one specialises to CUDA
/// memory / streams; this one holds the polymorphic
/// `Box<dyn ModelRunner>` so the same supervision shape covers
/// remote-network runners too.
pub struct ContextActor {
    /// Shared with spawned tasks; tokio's async mutex so the guard can
    /// be held across `.await` while a batch executes.
    runner: Arc<AsyncMutex<Box<dyn ModelRunner>>>,
}
185
186impl ContextActor {
187 pub fn new(slot: WorkerSlot) -> Self {
188 Self {
189 runner: Arc::new(AsyncMutex::new(slot.runner)),
190 }
191 }
192}
193
#[async_trait]
impl Actor for ContextActor {
    type Msg = ContextMsg;

    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: Self::Msg) {
        match msg {
            ContextMsg::Execute(batch, output) => {
                // Run the batch off the actor loop so the mailbox stays
                // responsive while tokens stream back.
                let runner = self.runner.clone();
                tokio::spawn(async move {
                    let mut g = runner.lock().await;
                    match g.execute(batch).await {
                        Ok(handle) => {
                            drop(g); // release runner mutex while we drain
                            let mut s = handle.into_stream();
                            while let Some(chunk) = s.next().await {
                                // Receiver gone — stop draining the stream.
                                if output.send(chunk).await.is_err() {
                                    break;
                                }
                            }
                        }
                        Err(e) => {
                            // Sticky CUDA errors propagate as panics
                            // tagged with rakka_accel's CONTEXT_POISONED_TAG
                            // so the parent's supervisor strategy can
                            // route them to Restart.
                            // NOTE(review): this panic fires inside a
                            // detached `tokio::spawn` task, not the actor's
                            // own handler — confirm rakka surfaces spawned-
                            // task panics to the supervisor; otherwise it
                            // only aborts this task and no restart happens.
                            if matches!(e, InferenceError::CudaContextPoisoned(_)) {
                                // Best-effort: tell the caller first, then panic.
                                let _ = output.send(Err(e.clone())).await;
                                #[cfg(feature = "local-gpu")]
                                panic!("{}: {e}", rakka_accel::cuda::error::CONTEXT_POISONED_TAG);
                                #[cfg(not(feature = "local-gpu"))]
                                panic!("ContextPoisoned: {e}");
                            }
                            // Non-sticky errors just go back to the caller.
                            let _ = output.send(Err(e)).await;
                        }
                    }
                });
            }
            ContextMsg::Rebuild { cause, reply } => {
                // Rebuild also runs off-loop; a failed reply send just
                // means the requester stopped waiting.
                let runner = self.runner.clone();
                tokio::spawn(async move {
                    let mut g = runner.lock().await;
                    let r = g.rebuild_session(cause).await;
                    let _ = reply.send(r);
                });
            }
        }
    }
}