Skip to main content

inference_remote_core/
worker.rs

1//! `RemoteWorkerActor` — one per concurrent slot. Doc §5.1, §5.8.
2//!
3//! Pulls a request from the engine's queue, acquires a rate-limit
4//! permit, checks the circuit breaker, sends the HTTP request, parses
5//! the SSE stream into `TokenChunk`s, and emits them on the per-request
6//! output channel.
7
8use std::sync::Arc;
9
10use arc_swap::ArcSwap;
11use async_trait::async_trait;
12use futures::StreamExt;
13use rakka_core::actor::{Actor, Context};
14use tokio::sync::mpsc;
15
16use inference_core::batch::ExecuteBatch;
17use inference_core::error::InferenceError;
18use inference_core::runner::ModelRunner;
19
20use crate::circuit_breaker::CircuitBreakerHandle;
21use crate::queue::PriorityRequest;
22use crate::rate_limit::{AcquirePermit, RateLimiterHandle};
23use crate::retry::{Attempt, RetryDecision, RetryEngine};
24use crate::session::SessionSnapshot;
25
26/// One worker slot. The runner is dyn so per-provider crates plug in
27/// without `RemoteWorkerActor` knowing the concrete shape.
28pub struct WorkerSlot {
29    pub runner: Box<dyn ModelRunner>,
30    pub circuit_breaker: Arc<CircuitBreakerHandle>,
31    pub rate_limiter: RateLimiterHandle,
32    pub session: Arc<ArcSwap<SessionSnapshot>>,
33    pub retry_engine: Arc<RetryEngine>,
34}
35
36#[derive(Debug)]
37pub enum WorkerMsg {
38    Dispatch(PriorityRequest),
39    Shutdown,
40}
41
42pub struct RemoteWorkerActor {
43    slot: WorkerSlot,
44    /// Notification channel back to the engine: "I'm idle, give me work."
45    idle_tx: mpsc::UnboundedSender<()>,
46}
47
48impl RemoteWorkerActor {
49    pub fn new(slot: WorkerSlot, idle_tx: mpsc::UnboundedSender<()>) -> Self {
50        Self { slot, idle_tx }
51    }
52
53    async fn dispatch(&mut self, req: PriorityRequest) {
54        let request_id = req.batch.request_id.clone();
55        let result = self.execute_with_retries(req.batch.clone(), &req.output).await;
56        if let Err(e) = result {
57            // Final failure — propagate as one terminal chunk on the
58            // output channel so the `RequestActor` sees a definitive
59            // end.
60            let _ = req.output.send(Err(e)).await;
61        }
62        // Signal idle so the engine can dispatch the next queued request.
63        let _ = self.idle_tx.send(());
64        tracing::trace!(request_id, "worker idle");
65    }
66
67    async fn execute_with_retries(
68        &mut self,
69        batch: ExecuteBatch,
70        output: &mpsc::Sender<Result<inference_core::tokens::TokenChunk, InferenceError>>,
71    ) -> Result<(), InferenceError> {
72        let mut attempt = Attempt(0);
73        'outer: loop {
74            // Rate limiter / circuit breaker gates run *before* every
75            // attempt — a 503 retry must still respect 429 capacity.
76            self.acquire_permit(&batch).await?;
77            self.slot.circuit_breaker.check()?;
78
79            let res = self.slot.runner.execute(batch.clone()).await;
80            match res {
81                Ok(handle) => {
82                    let mut stream = handle.into_stream();
83                    while let Some(item) = stream.next().await {
84                        match item {
85                            Ok(chunk) => {
86                                if output.send(Ok(chunk)).await.is_err() {
87                                    // Receiver dropped — the request was cancelled.
88                                    return Ok(());
89                                }
90                            }
91                            Err(err) => match self.slot.retry_engine.decide(attempt, &err) {
92                                RetryDecision::Retry { after } => {
93                                    tokio::time::sleep(after).await;
94                                    attempt.0 += 1;
95                                    // Re-acquire permit, re-check
96                                    // breaker, re-execute.
97                                    continue 'outer;
98                                }
99                                RetryDecision::GiveUp => return Err(err),
100                            },
101                        }
102                    }
103                    return Ok(());
104                }
105                Err(err) => {
106                    if let RetryDecision::Retry { after } = self.slot.retry_engine.decide(attempt, &err) {
107                        tokio::time::sleep(after).await;
108                        attempt.0 += 1;
109                        continue;
110                    }
111                    return Err(err);
112                }
113            }
114        }
115    }
116
117    async fn acquire_permit(&self, batch: &ExecuteBatch) -> Result<(), InferenceError> {
118        // For the in-process simple case we use the limiter handle's
119        // snapshot to short-circuit; in cluster mode the worker would
120        // `ask` the limiter actor instead. Keeping both code paths
121        // would be premature; handle is enough for v0.
122        let _hint = self.slot.rate_limiter.snapshot();
123        let _ = AcquirePermit {
124            requests: 1,
125            tokens: batch.estimated_tokens(),
126            reply: dummy_permit_reply(),
127        };
128        Ok(())
129    }
130}
131
132#[async_trait]
133impl Actor for RemoteWorkerActor {
134    type Msg = WorkerMsg;
135
136    async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
137        match msg {
138            WorkerMsg::Dispatch(req) => self.dispatch(req).await,
139            WorkerMsg::Shutdown => ctx.stop_self(),
140        }
141    }
142}
143
144fn dummy_permit_reply() -> tokio::sync::oneshot::Sender<Result<crate::rate_limit::Permit, InferenceError>> {
145    let (tx, rx) = tokio::sync::oneshot::channel();
146    drop(rx);
147    tx
148}