atomr_infer_runtime/
request.rs1use async_trait::async_trait;
11use atomr_core::actor::{Actor, ActorRef, Context};
12use tokio::sync::{mpsc, oneshot};
13
14use atomr_infer_core::batch::ExecuteBatch;
15use atomr_infer_core::error::InferenceError;
16use atomr_infer_core::tokens::{TokenChunk, Tokens};
17
18use crate::dp_coordinator::{DpCoordinatorMsg, RouteTarget};
19
20pub enum RequestMsg {
21 Dispatch { deployment: String, batch: ExecuteBatch },
24 Chunk(Result<TokenChunk, InferenceError>),
26 Cancel,
28}
29
30pub type StreamingResponse = mpsc::Receiver<Result<TokenChunk, InferenceError>>;
33
34pub struct RequestActor {
35 coordinator: ActorRef<DpCoordinatorMsg>,
36 output: mpsc::Sender<Result<TokenChunk, InferenceError>>,
39 dispatched: bool,
41 accumulator: Tokens,
43 done: Option<oneshot::Sender<Tokens>>,
44}
45
46impl RequestActor {
47 pub fn new(
48 coordinator: ActorRef<DpCoordinatorMsg>,
49 output: mpsc::Sender<Result<TokenChunk, InferenceError>>,
50 done: oneshot::Sender<Tokens>,
51 ) -> Self {
52 Self {
53 coordinator,
54 output,
55 dispatched: false,
56 accumulator: Tokens::default(),
57 done: Some(done),
58 }
59 }
60
61 async fn dispatch(&mut self, ctx: &mut Context<Self>, deployment: String, batch: ExecuteBatch) {
62 if self.dispatched {
63 return;
64 }
65 self.dispatched = true;
66 self.accumulator.request_id = batch.request_id.clone();
67
68 let target = match self
69 .coordinator
70 .ask_with(
71 |reply| DpCoordinatorMsg::RouteTo {
72 deployment: deployment.clone(),
73 reply,
74 },
75 std::time::Duration::from_secs(2),
76 )
77 .await
78 {
79 Ok(Ok(t)) => t,
80 Ok(Err(e)) => {
81 let _ = self.output.send(Err(e)).await;
82 self.finish().await;
83 ctx.stop_self();
84 return;
85 }
86 Err(_) => {
87 let _ = self
88 .output
89 .send(Err(InferenceError::Internal("coordinator timeout".into())))
90 .await;
91 self.finish().await;
92 ctx.stop_self();
93 return;
94 }
95 };
96
97 let (chunk_tx, mut chunk_rx) = mpsc::channel::<Result<TokenChunk, InferenceError>>(64);
101 let self_ref = ctx.self_ref().clone();
102 tokio::spawn(async move {
103 while let Some(c) = chunk_rx.recv().await {
104 self_ref.tell(RequestMsg::Chunk(c));
105 }
106 });
107
108 let _ = target;
118 let _ = batch;
119 let _ = chunk_tx;
120 }
121
122 async fn finish(&mut self) {
123 if let Some(d) = self.done.take() {
124 let _ = d.send(std::mem::take(&mut self.accumulator));
125 }
126 }
127}
128
129#[async_trait]
130impl Actor for RequestActor {
131 type Msg = RequestMsg;
132
133 async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
134 match msg {
135 RequestMsg::Dispatch { deployment, batch } => {
136 self.dispatch(ctx, deployment, batch).await;
137 }
138 RequestMsg::Chunk(item) => {
139 let is_terminal = match &item {
140 Ok(c) => c.finish_reason.is_some(),
141 Err(_) => true,
142 };
143 if let Ok(c) = &item {
144 self.accumulator.append(c);
145 }
146 let _ = self.output.send(item).await;
147 if is_terminal {
148 self.finish().await;
149 ctx.stop_self();
150 }
151 }
152 RequestMsg::Cancel => {
153 self.finish().await;
154 ctx.stop_self();
155 }
156 }
157 }
158}
159
160pub type Route = RouteTarget;