use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, ActorRef, Context, Props};
use atomr_core::supervision::SupervisorStrategy;
use futures::StreamExt;
use parking_lot::Mutex;
use tokio::sync::Mutex as AsyncMutex;
use tokio::sync::{mpsc, oneshot};
use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::error::InferenceError;
use atomr_infer_core::runner::{ModelRunner, SessionRebuildCause};
use atomr_infer_core::tokens::TokenChunk;
/// Everything a [`ContextActor`] needs in order to serve requests; currently just the model runner.
pub struct WorkerSlot {
    /// Runner that executes batches and rebuilds its session. It is moved into the
    /// `ContextActor` (wrapped in an `Arc<tokio::sync::Mutex<_>>`) when the child is built.
    pub runner: Box<dyn ModelRunner>,
}
/// Messages accepted by [`WorkerActor`].
pub enum WorkerMsg {
    /// Run a batch. Token chunks (or an error) are sent on the channel. The message is
    /// forwarded to the current [`ContextActor`] child, if there is one.
    Execute(ExecuteBatch, mpsc::Sender<Result<TokenChunk, InferenceError>>),
    /// The child's context is poisoned. The worker stops the current child and spawns a
    /// fresh one. The string is a human-readable reason that is only logged.
    ContextPoisoned(String),
    /// Ask the current child to rebuild its runner session.
    RebuildSession {
        /// Why the rebuild was requested; passed through to `ModelRunner::rebuild_session`.
        cause: SessionRebuildCause,
        /// Receives the rebuild outcome, or `Internal("no child")` if there is no child.
        reply: oneshot::Sender<Result<(), InferenceError>>,
    },
}
/// Messages accepted by [`ContextActor`]; sent by its parent [`WorkerActor`].
pub enum ContextMsg {
    /// Execute a batch on the runner and stream the results (or an error) to the sender.
    Execute(ExecuteBatch, mpsc::Sender<Result<TokenChunk, InferenceError>>),
    /// Rebuild the runner's session.
    Rebuild {
        /// Passed through to `ModelRunner::rebuild_session`.
        cause: SessionRebuildCause,
        /// Receives the result returned by `rebuild_session`.
        reply: oneshot::Sender<Result<(), InferenceError>>,
    },
}
/// Parent actor that owns a single [`ContextActor`] child and replaces it when the child's
/// context is reported as poisoned.
pub struct WorkerActor {
    // Produces a fresh `WorkerSlot` every time a child is spawned.
    slot_factory: Box<dyn Fn() -> WorkerSlot + Send + Sync>,
    // Current child. It is `None` before `pre_start`, and after a failed re-spawn.
    child: Option<ActorRef<ContextMsg>>,
    // Incremented before each spawn so that every child gets a distinct `ctx-<n>` name.
    parent_to_child_seq: u64,
}
impl WorkerActor {
pub fn new<F>(slot_factory: F) -> Self
where
F: Fn() -> WorkerSlot + Send + Sync + 'static,
{
Self {
slot_factory: Box::new(slot_factory),
child: None,
parent_to_child_seq: 0,
}
}
fn spawn_child(&mut self, ctx: &mut Context<Self>) {
self.parent_to_child_seq += 1;
let name = format!("ctx-{}", self.parent_to_child_seq);
let cell = Mutex::new(Some((self.slot_factory)()));
let props = Props::create(move || {
let s = cell.lock().take().expect("worker context factory invoked twice");
ContextActor::new(s)
});
match ctx.spawn(props, &name) {
Ok(addr) => self.child = Some(addr),
Err(e) => tracing::error!(?e, "spawn ContextActor failed"),
}
}
}
#[async_trait]
impl Actor for WorkerActor {
    type Msg = WorkerMsg;

    /// Spawns the initial [`ContextActor`] child before any message is handled.
    async fn pre_start(&mut self, ctx: &mut Context<Self>) {
        self.spawn_child(ctx);
    }

    /// Supervision policy for this actor's children.
    ///
    /// With the `local-gpu` feature, the policy comes from `atomr_accel_cuda`. Otherwise a
    /// one-for-one strategy (at most 3 retries within 60 s) maps the failure text to a
    /// directive:
    /// - `ContextPoisoned` → restart
    /// - `OutOfMemory` → resume
    /// - `Unrecoverable` → stop
    /// - anything else → escalate
    fn supervisor_strategy(&self) -> SupervisorStrategy {
        #[cfg(feature = "local-gpu")]
        {
            atomr_accel_cuda::error::device_supervisor_strategy()
        }
        #[cfg(not(feature = "local-gpu"))]
        {
            use atomr_core::supervision::{Directive, OneForOneStrategy};
            OneForOneStrategy::new()
                .with_max_retries(3)
                .with_within(std::time::Duration::from_secs(60))
                .with_decider(|err| {
                    if err.contains("ContextPoisoned") {
                        Directive::Restart
                    } else if err.contains("OutOfMemory") {
                        Directive::Resume
                    } else if err.contains("Unrecoverable") {
                        Directive::Stop
                    } else {
                        Directive::Escalate
                    }
                })
                .into()
        }
    }

    /// Routes work to the current child, or replaces the child when its context is poisoned.
    async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
        match msg {
            WorkerMsg::Execute(batch, output) => {
                let Some(child) = self.child.as_ref() else {
                    // There is no child to run the batch (for example, the last spawn failed).
                    // Tell the caller, as `RebuildSession` does, instead of silently dropping
                    // `output`. `try_send` is used so that a full output channel can never
                    // stall this actor's mailbox.
                    tracing::warn!("no ContextActor child; rejecting Execute batch");
                    let _ = output.try_send(Err(InferenceError::Internal("no child".into())));
                    return;
                };
                child.tell(ContextMsg::Execute(batch, output));
            }
            WorkerMsg::ContextPoisoned(reason) => {
                tracing::warn!(reason, "context poisoned — rebuilding child");
                if let Some(child) = self.child.take() {
                    child.stop();
                }
                self.spawn_child(ctx);
            }
            WorkerMsg::RebuildSession { cause, reply } => {
                let Some(child) = self.child.as_ref() else {
                    let _ = reply.send(Err(InferenceError::Internal("no child".into())));
                    return;
                };
                child.tell(ContextMsg::Rebuild { cause, reply });
            }
        }
    }
}
/// Child actor that runs batches and session rebuilds on a single model runner.
pub struct ContextActor {
    // Cloned into each task spawned by `handle`. The async mutex serializes the
    // `execute` and `rebuild_session` calls on the runner.
    runner: Arc<AsyncMutex<Box<dyn ModelRunner>>>,
}
impl ContextActor {
pub fn new(slot: WorkerSlot) -> Self {
Self {
runner: Arc::new(AsyncMutex::new(slot.runner)),
}
}
}
#[async_trait]
impl Actor for ContextActor {
    type Msg = ContextMsg;

    /// Runs each message on its own Tokio task so that model work never blocks this actor's
    /// mailbox. The shared async mutex serializes access to the runner across those tasks.
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: Self::Msg) {
        match msg {
            ContextMsg::Execute(batch, output) => {
                let runner = self.runner.clone();
                tokio::spawn(async move {
                    // Hold the runner lock only for `execute` itself. It is released before
                    // any `output.send(..).await` (success stream or error report), so a slow
                    // or full consumer cannot stall other batches or rebuilds on this runner.
                    let mut guard = runner.lock().await;
                    let result = guard.execute(batch).await;
                    drop(guard);

                    match result {
                        Ok(handle) => {
                            let mut stream = handle.into_stream();
                            while let Some(chunk) = stream.next().await {
                                // The receiver is gone; stop pulling tokens nobody will read.
                                if output.send(chunk).await.is_err() {
                                    break;
                                }
                            }
                        }
                        Err(e) => {
                            if matches!(e, InferenceError::CudaContextPoisoned(_)) {
                                let _ = output.send(Err(e.clone())).await;
                                // NOTE(review): this panic unwinds only this detached
                                // `tokio::spawn` task. Tokio catches it for the JoinHandle
                                // (unless the build uses panic=abort), so it presumably does
                                // not fail this actor or reach the parent's supervisor
                                // strategy. Confirm how poisoning is meant to trigger
                                // `WorkerMsg::ContextPoisoned`.
                                #[cfg(feature = "local-gpu")]
                                panic!("{}: {e}", atomr_accel_cuda::error::CONTEXT_POISONED_TAG);
                                #[cfg(not(feature = "local-gpu"))]
                                panic!("ContextPoisoned: {e}");
                            }
                            let _ = output.send(Err(e)).await;
                        }
                    }
                });
            }
            ContextMsg::Rebuild { cause, reply } => {
                let runner = self.runner.clone();
                tokio::spawn(async move {
                    // Lock for the rebuild so it cannot interleave with an in-flight `execute`.
                    let mut guard = runner.lock().await;
                    let result = guard.rebuild_session(cause).await;
                    drop(guard);
                    let _ = reply.send(result);
                });
            }
        }
    }
}