use async_trait::async_trait;
use atomr_core::actor::{Actor, ActorRef, Context};
use tokio::sync::{mpsc, oneshot};
use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::error::InferenceError;
use atomr_infer_core::tokens::{TokenChunk, Tokens};
use crate::dp_coordinator::{DpCoordinatorMsg, RouteTarget};
/// Messages understood by [`RequestActor`].
pub enum RequestMsg {
    /// Start the request for `deployment`; only the first `Dispatch` a
    /// `RequestActor` receives has any effect.
    Dispatch {
        deployment: String,
        batch: ExecuteBatch,
    },
    /// One streamed result. It is forwarded to the actor's output stream; an
    /// `Err`, or an `Ok` chunk whose `finish_reason` is set, ends the request.
    Chunk(Result<TokenChunk, InferenceError>),
    /// End the request now, delivering whatever tokens have been accumulated.
    Cancel,
}
/// Receiver type matching the `output` sender passed to `RequestActor::new`;
/// it yields each streamed chunk or error in the order the actor sent them.
pub type StreamingResponse = mpsc::Receiver<Result<TokenChunk, InferenceError>>;
/// Actor that drives one request: it asks the coordinator for a route,
/// forwards streamed chunks to `output`, and accumulates them into `Tokens`.
///
/// NOTE: struct fields are dropped in declaration order, so reordering these
/// fields changes whether `output` or `done` is closed first on drop.
pub struct RequestActor {
    /// Asked (via `DpCoordinatorMsg::RouteTo`) for a route to the deployment.
    coordinator: ActorRef<DpCoordinatorMsg>,
    /// Output stream; every chunk and every error is sent here.
    output: mpsc::Sender<Result<TokenChunk, InferenceError>>,
    /// Set by the first `Dispatch` so that later ones are ignored.
    dispatched: bool,
    /// Collects `Ok` chunks; its `request_id` is copied from the dispatched batch.
    accumulator: Tokens,
    /// Receives the accumulated tokens when the request ends; `None` once used.
    done: Option<oneshot::Sender<Tokens>>,
}
impl RequestActor {
    /// How long `dispatch` waits for the coordinator to answer a `RouteTo` ask.
    const ROUTE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);

    /// Creates a request actor that streams results into `output` and hands the
    /// accumulated [`Tokens`] to `done` when the request ends.
    ///
    /// The actor does nothing until it receives a [`RequestMsg::Dispatch`].
    pub fn new(
        coordinator: ActorRef<DpCoordinatorMsg>,
        output: mpsc::Sender<Result<TokenChunk, InferenceError>>,
        done: oneshot::Sender<Tokens>,
    ) -> Self {
        Self {
            coordinator,
            output,
            dispatched: false,
            accumulator: Tokens::default(),
            done: Some(done),
        }
    }

    /// Handles a `Dispatch` by asking the coordinator for a route to `deployment`.
    ///
    /// Only the first call has any effect; later calls return immediately.
    /// Routing can fail because the coordinator replies with an error, or because
    /// the ask itself fails within [`Self::ROUTE_TIMEOUT`]. In either case the
    /// error is pushed to `output`, the accumulated tokens are delivered to
    /// `done`, and the actor stops.
    async fn dispatch(&mut self, ctx: &mut Context<Self>, deployment: String, batch: ExecuteBatch) {
        if self.dispatched {
            return;
        }
        self.dispatched = true;
        self.accumulator.request_id = batch.request_id.clone();

        // Any failure of the ask itself (whatever its cause) is reported to the
        // output stream as a coordinator timeout.
        let routed = self
            .coordinator
            .ask_with(
                |reply| DpCoordinatorMsg::RouteTo {
                    deployment: deployment.clone(),
                    reply,
                },
                Self::ROUTE_TIMEOUT,
            )
            .await
            .unwrap_or_else(|_| Err(InferenceError::Internal("coordinator timeout".into())));
        let target = match routed {
            Ok(target) => target,
            Err(err) => {
                self.fail(ctx, err).await;
                return;
            }
        };

        // NOTE(review): `target` and `batch` are not used yet, so nothing is sent
        // to the routed worker from here. `chunk_tx` is dropped when this function
        // returns, which closes the channel; the forwarding task therefore sees
        // `None` and exits without forwarding anything. Chunks reach this actor
        // only if something else sends it `RequestMsg::Chunk`.
        let (chunk_tx, mut chunk_rx) = mpsc::channel::<Result<TokenChunk, InferenceError>>(64);
        let self_ref = ctx.self_ref().clone();
        tokio::spawn(async move {
            // Forward each received item to this actor as a `RequestMsg::Chunk`.
            while let Some(c) = chunk_rx.recv().await {
                self_ref.tell(RequestMsg::Chunk(c));
            }
        });
        let _ = target;
        let _ = batch;
        let _ = chunk_tx;
    }

    /// Reports `err` on the output stream, delivers whatever tokens were
    /// accumulated, and stops the actor.
    async fn fail(&mut self, ctx: &mut Context<Self>, err: InferenceError) {
        // A send error only means the output receiver has been dropped.
        let _ = self.output.send(Err(err)).await;
        self.finish().await;
        ctx.stop_self();
    }

    /// Sends the accumulated tokens to the `done` receiver, at most once.
    ///
    /// The sender is taken on the first call, so later calls are no-ops.
    async fn finish(&mut self) {
        if let Some(d) = self.done.take() {
            // Ignored: the `done` receiver may already have been dropped.
            let _ = d.send(std::mem::take(&mut self.accumulator));
        }
    }
}
#[async_trait]
impl Actor for RequestActor {
    type Msg = RequestMsg;

    /// Drives the request lifecycle.
    ///
    /// * `Dispatch` — starts routing (see `RequestActor::dispatch`).
    /// * `Chunk` — forwards the item to the output stream. `Ok` chunks are also
    ///   appended to the accumulator. An `Err`, or a chunk whose `finish_reason`
    ///   is set, ends the request.
    /// * `Cancel` — ends the request immediately.
    ///
    /// Ending the request delivers the accumulated tokens to `done` and stops
    /// the actor.
    async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
        let ends_request = match msg {
            RequestMsg::Dispatch { deployment, batch } => {
                self.dispatch(ctx, deployment, batch).await;
                false
            }
            RequestMsg::Chunk(item) => {
                let last = match &item {
                    Err(_) => true,
                    Ok(chunk) => {
                        let finished = chunk.finish_reason.is_some();
                        self.accumulator.append(chunk);
                        finished
                    }
                };
                // Send failures are ignored; accumulation and `done` delivery
                // continue regardless of whether anyone reads the stream.
                let _ = self.output.send(item).await;
                last
            }
            RequestMsg::Cancel => true,
        };

        if ends_request {
            self.finish().await;
            ctx.stop_self();
        }
    }
}
/// Public alias for the coordinator's `RouteTarget`.
pub type Route = RouteTarget;