use std::sync::Arc;
use arc_swap::ArcSwap;
use async_trait::async_trait;
use futures::StreamExt;
use rakka_core::actor::{Actor, Context};
use tokio::sync::mpsc;
use inference_core::batch::ExecuteBatch;
use inference_core::error::InferenceError;
use inference_core::runner::ModelRunner;
use crate::circuit_breaker::CircuitBreakerHandle;
use crate::queue::PriorityRequest;
use crate::rate_limit::{AcquirePermit, RateLimiterHandle};
use crate::retry::{Attempt, RetryDecision, RetryEngine};
use crate::session::SessionSnapshot;
/// The resources a worker uses to execute requests.
pub struct WorkerSlot {
    /// Executes batches; `execute` returns a handle that is turned into a chunk stream.
    pub runner: Box<dyn ModelRunner>,
    /// Circuit breaker checked before every execution attempt.
    pub circuit_breaker: Arc<CircuitBreakerHandle>,
    /// Rate limiter associated with this slot.
    pub rate_limiter: RateLimiterHandle,
    /// Shared session snapshot that can be swapped atomically.
    pub session: Arc<ArcSwap<SessionSnapshot>>,
    /// Decides whether a failed attempt is retried, and after what delay.
    pub retry_engine: Arc<RetryEngine>,
}
/// Messages accepted by [`RemoteWorkerActor`].
#[derive(Debug)]
pub enum WorkerMsg {
    /// Execute the wrapped request on this worker.
    Dispatch(PriorityRequest),
    /// Stop the actor.
    Shutdown,
}
/// Actor that executes dispatched requests using a single [`WorkerSlot`].
pub struct RemoteWorkerActor {
    /// Execution resources used for every request.
    slot: WorkerSlot,
    /// Signalled with `()` after each dispatched request has been handled.
    idle_tx: mpsc::UnboundedSender<()>,
}
impl RemoteWorkerActor {
    /// Creates a worker that executes requests using `slot`.
    ///
    /// `idle_tx` is sent a `()` each time a dispatched request has been fully
    /// handled, whether it succeeded or failed.
    pub fn new(slot: WorkerSlot, idle_tx: mpsc::UnboundedSender<()>) -> Self {
        Self { slot, idle_tx }
    }

    /// Executes one request (with retries), then reports the worker as idle.
    ///
    /// Chunks are streamed to `req.output` by [`Self::execute_with_retries`].
    /// A terminal error is forwarded to the same channel. Send failures are
    /// ignored: they only mean the receiving side has gone away.
    async fn dispatch(&mut self, req: PriorityRequest) {
        let request_id = req.batch.request_id.clone();
        let result = self.execute_with_retries(req.batch.clone(), &req.output).await;
        if let Err(e) = result {
            // Best effort: the requester may already have dropped its receiver.
            let _ = req.output.send(Err(e)).await;
        }
        // Best effort: the owner may have dropped the idle receiver.
        let _ = self.idle_tx.send(());
        tracing::trace!(request_id, "worker idle");
    }

    /// Runs `batch` on the slot's runner and streams every chunk into `output`.
    ///
    /// Retries follow the slot's [`RetryEngine`].
    ///
    /// Each attempt first calls [`Self::acquire_permit`] and then the circuit
    /// breaker's `check()`. Errors from either are returned immediately,
    /// without consulting the retry engine.
    ///
    /// An error while starting execution is passed to `RetryEngine::decide`.
    /// So is a stream error that arrives before any chunk has been forwarded.
    /// On `Retry { after }` the method sleeps for `after` and starts a fresh
    /// attempt.
    ///
    /// Once at least one chunk has reached `output`, a stream error is returned
    /// unchanged. A retry re-executes the whole batch, so it would re-send the
    /// already-delivered chunks and the consumer would see them twice.
    ///
    /// Returns `Ok(())` when the stream ends, or when `output` is closed (there
    /// is no one left to deliver to).
    ///
    /// # Errors
    ///
    /// Returns the error the retry engine gave up on, an error forwarded after
    /// partial output, or an error from permit acquisition / the circuit breaker.
    ///
    /// NOTE(review): the circuit breaker is only consulted via `check()` here;
    /// no success or failure is reported back to it. Confirm that happens
    /// elsewhere.
    async fn execute_with_retries(
        &mut self,
        batch: ExecuteBatch,
        output: &mpsc::Sender<Result<inference_core::tokens::TokenChunk, InferenceError>>,
    ) -> Result<(), InferenceError> {
        let mut attempt = Attempt(0);
        'outer: loop {
            self.acquire_permit(&batch).await?;
            self.slot.circuit_breaker.check()?;

            let handle = match self.slot.runner.execute(batch.clone()).await {
                Ok(handle) => handle,
                Err(err) => match self.slot.retry_engine.decide(attempt, &err) {
                    RetryDecision::Retry { after } => {
                        tokio::time::sleep(after).await;
                        attempt.0 += 1;
                        continue 'outer;
                    }
                    RetryDecision::GiveUp => return Err(err),
                },
            };

            let mut stream = handle.into_stream();
            // Tracks whether the consumer has seen any output from this attempt.
            // After that point a retry would replay (duplicate) delivered chunks.
            let mut forwarded_any = false;
            while let Some(item) = stream.next().await {
                match item {
                    Ok(chunk) => {
                        if output.send(Ok(chunk)).await.is_err() {
                            // Consumer is gone; nothing more to deliver.
                            return Ok(());
                        }
                        forwarded_any = true;
                    }
                    Err(err) if forwarded_any => return Err(err),
                    Err(err) => match self.slot.retry_engine.decide(attempt, &err) {
                        RetryDecision::Retry { after } => {
                            tokio::time::sleep(after).await;
                            attempt.0 += 1;
                            continue 'outer;
                        }
                        RetryDecision::GiveUp => return Err(err),
                    },
                }
            }
            return Ok(());
        }
    }

    /// Intended to acquire rate-limit capacity for `batch` before an attempt.
    ///
    /// NOTE(review): this is currently a placeholder and always returns
    /// `Ok(())`. It reads a limiter snapshot and builds an [`AcquirePermit`]
    /// whose reply sender comes from `dummy_permit_reply` (receiver already
    /// dropped). That message is never sent to the limiter, so no permit is
    /// actually acquired. TODO: wire this to the real `RateLimiterHandle`
    /// request path once its API is confirmed.
    async fn acquire_permit(&self, batch: &ExecuteBatch) -> Result<(), InferenceError> {
        let _hint = self.slot.rate_limiter.snapshot();
        let _ = AcquirePermit {
            requests: 1,
            tokens: batch.estimated_tokens(),
            reply: dummy_permit_reply(),
        };
        Ok(())
    }
}
#[async_trait]
impl Actor for RemoteWorkerActor {
    type Msg = WorkerMsg;

    /// Routes one mailbox message.
    ///
    /// `Shutdown` asks the context to stop this actor. `Dispatch` runs the
    /// request and awaits it to completion inside this handler.
    async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
        match msg {
            WorkerMsg::Shutdown => ctx.stop_self(),
            WorkerMsg::Dispatch(request) => self.dispatch(request).await,
        }
    }
}
/// Returns a permit reply sender whose receiving half has already been dropped.
///
/// Nothing can ever observe a value sent through it; it exists only to fill the
/// `reply` field of an `AcquirePermit` that is never delivered.
fn dummy_permit_reply() -> tokio::sync::oneshot::Sender<Result<crate::rate_limit::Permit, InferenceError>> {
    let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
    // Close the receiving side up front; only the sender is handed out.
    std::mem::drop(reply_rx);
    reply_tx
}