atomr_infer_runtime/request.rs

//! `RequestActor` — one per active client request. Doc §6.1, §6.2.
//!
//! Owns the per-request `Tokens` accumulation and the streaming channel
//! back to the gateway's HTTP response. The actor is created by the
//! gateway, asks the `DpCoordinatorActor` for a route, then `tell`s
//! the chosen engine an `AddRequest`. The engine writes chunks into
//! the `mpsc::Sender` we provided; we mirror each chunk to the gateway's
//! streaming channel and, on completion, send the accumulated `Tokens`
//! over the `done` oneshot.
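//!
//! # Example (sketch)
//!
//! A minimal, illustrative sketch of the gateway-side wiring: `spawn_actor`
//! is a placeholder for however the actor system starts actors, `coordinator`
//! is an existing `ActorRef<DpCoordinatorMsg>`, and `batch` is the
//! `ExecuteBatch` built from the client request. The real wiring lives in
//! the `gateway` module.
//!
//! ```ignore
//! // One streaming channel for the HTTP body, one oneshot for the final
//! // `Tokens` aggregate.
//! let (chunk_tx, mut chunk_rx) = mpsc::channel(64);
//! let (done_tx, done_rx) = oneshot::channel::<Tokens>();
//!
//! let request = spawn_actor(RequestActor::new(coordinator.clone(), chunk_tx, done_tx));
//! request.tell(RequestMsg::Dispatch { deployment: "my-deployment".into(), batch });
//!
//! // Stream chunks into the response body until the channel closes, then
//! // collect the aggregate for accounting.
//! while let Some(item) = chunk_rx.recv().await { /* forward `item` */ }
//! let totals = done_rx.await.ok();
//! ```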

use async_trait::async_trait;
use atomr_core::actor::{Actor, ActorRef, Context};
use tokio::sync::{mpsc, oneshot};

use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::error::InferenceError;
use atomr_infer_core::tokens::{TokenChunk, Tokens};

use crate::dp_coordinator::{DpCoordinatorMsg, RouteTarget};

pub enum RequestMsg {
    /// Kick off the request: routes via the coordinator and dispatches
    /// to the chosen engine.
    Dispatch { deployment: String, batch: ExecuteBatch },
    /// Forwarded chunk from the engine.
    Chunk(Result<TokenChunk, InferenceError>),
    /// Gateway gave up on the response (client disconnected). Cancel.
    Cancel,
}

/// Streaming response handed back to the gateway for forwarding into
/// the HTTP body. The gateway calls `recv()` until it returns `None`.
pub type StreamingResponse = mpsc::Receiver<Result<TokenChunk, InferenceError>>;
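
// A gateway-side consumption sketch (illustrative only; the real loop lives in
// the `gateway` module and will differ in SSE/error handling):
//
//     async fn forward(mut stream: StreamingResponse) {
//         while let Some(item) = stream.recv().await {
//             match item {
//                 Ok(chunk) => { /* write `chunk` into the HTTP body */ }
//                 Err(e) => { /* surface `e` to the client and stop */ }
//             }
//         }
//         // `None` means the actor dropped its sender: the request finished
//         // (terminal chunk or error) or was cancelled.
//     }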

pub struct RequestActor {
    /// Handle to the `DpCoordinatorActor`; asked for a `RouteTarget` at
    /// dispatch time.
    coordinator: ActorRef<DpCoordinatorMsg>,
    /// The gateway-facing channel. Each chunk we receive from the
    /// engine is mirrored here.
    output: mpsc::Sender<Result<TokenChunk, InferenceError>>,
    /// Whether `Dispatch` has happened — guards against double-dispatch.
    dispatched: bool,
    /// Aggregate accumulator (exposed at the end via the `done` channel).
    accumulator: Tokens,
    /// Completion signal carrying the final `Tokens` aggregate to the gateway.
    done: Option<oneshot::Sender<Tokens>>,
}

impl RequestActor {
    pub fn new(
        coordinator: ActorRef<DpCoordinatorMsg>,
        output: mpsc::Sender<Result<TokenChunk, InferenceError>>,
        done: oneshot::Sender<Tokens>,
    ) -> Self {
        Self {
            coordinator,
            output,
            dispatched: false,
            accumulator: Tokens::default(),
            done: Some(done),
        }
    }

    async fn dispatch(&mut self, ctx: &mut Context<Self>, deployment: String, batch: ExecuteBatch) {
        if self.dispatched {
            return;
        }
        self.dispatched = true;
        self.accumulator.request_id = batch.request_id.clone();

        let target = match self
            .coordinator
            .ask_with(
                |reply| DpCoordinatorMsg::RouteTo {
                    deployment: deployment.clone(),
                    reply,
                },
                std::time::Duration::from_secs(2),
            )
            .await
        {
            Ok(Ok(t)) => t,
            Ok(Err(e)) => {
                let _ = self.output.send(Err(e)).await;
                self.finish().await;
                ctx.stop_self();
                return;
            }
            Err(_) => {
                let _ = self
                    .output
                    .send(Err(InferenceError::Internal("coordinator timeout".into())))
                    .await;
                self.finish().await;
                ctx.stop_self();
                return;
            }
        };

        // Bridge the engine→our chunk channel into RequestMsg::Chunk so
        // we observe each chunk on our own mailbox and update the
        // accumulator in actor context (no shared state).
        let (chunk_tx, mut chunk_rx) = mpsc::channel::<Result<TokenChunk, InferenceError>>(64);
        let self_ref = ctx.self_ref().clone();
        tokio::spawn(async move {
            while let Some(c) = chunk_rx.recv().await {
                self_ref.tell(RequestMsg::Chunk(c));
            }
        });

        // Send through the appropriate transport. Local engine cores and
        // remote engine cores have different message types, so the send is
        // bridged via a small typed adapter at the placement site: in v0 the
        // caller boxes a closure, and the gateway constructs that closure
        // with knowledge of the engine kind.
        // For now the routed `target` is not used for the actual dispatch —
        // that wiring is the gateway's job (see the `gateway` module); we
        // retain the route only for observability.
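        //
        // (Illustrative only; not an existing type.) The boxed adapter the
        // gateway constructs might look roughly like:
        //
        //     type DispatchFn = Box<
        //         dyn FnOnce(ExecuteBatch, mpsc::Sender<Result<TokenChunk, InferenceError>>) + Send,
        //     >;
        //
        // with one closure per engine kind (local `tell` vs. remote transport).
        // Until that wiring lands, `chunk_tx` is never handed to an engine, so
        // the bridging task above exits as soon as the sender drops below.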
        let _ = target;
        let _ = batch;
        let _ = chunk_tx;
    }

    async fn finish(&mut self) {
        if let Some(d) = self.done.take() {
            let _ = d.send(std::mem::take(&mut self.accumulator));
        }
    }
}

#[async_trait]
impl Actor for RequestActor {
    type Msg = RequestMsg;

    async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
        match msg {
            RequestMsg::Dispatch { deployment, batch } => {
                self.dispatch(ctx, deployment, batch).await;
            }
            RequestMsg::Chunk(item) => {
                // A chunk is terminal if it carries a `finish_reason` or is an error.
                let is_terminal = match &item {
                    Ok(c) => c.finish_reason.is_some(),
                    Err(_) => true,
                };
                if let Ok(c) = &item {
                    self.accumulator.append(c);
                }
                let _ = self.output.send(item).await;
                if is_terminal {
                    self.finish().await;
                    ctx.stop_self();
                }
            }
            RequestMsg::Cancel => {
                self.finish().await;
                ctx.stop_self();
            }
        }
    }
}

/// Public alias: `RouteTarget` exposed under the `Route` name so
/// callers can hold typed targets without reaching into
/// `dp_coordinator`.
pub type Route = RouteTarget;