inference_remote_core/
worker.rs1use std::sync::Arc;
9
10use arc_swap::ArcSwap;
11use async_trait::async_trait;
12use futures::StreamExt;
13use rakka_core::actor::{Actor, Context};
14use tokio::sync::mpsc;
15
16use inference_core::batch::ExecuteBatch;
17use inference_core::error::InferenceError;
18use inference_core::runner::ModelRunner;
19
20use crate::circuit_breaker::CircuitBreakerHandle;
21use crate::queue::PriorityRequest;
22use crate::rate_limit::{AcquirePermit, RateLimiterHandle};
23use crate::retry::{Attempt, RetryDecision, RetryEngine};
24use crate::session::SessionSnapshot;
25
26pub struct WorkerSlot {
29 pub runner: Box<dyn ModelRunner>,
30 pub circuit_breaker: Arc<CircuitBreakerHandle>,
31 pub rate_limiter: RateLimiterHandle,
32 pub session: Arc<ArcSwap<SessionSnapshot>>,
33 pub retry_engine: Arc<RetryEngine>,
34}
35
36#[derive(Debug)]
37pub enum WorkerMsg {
38 Dispatch(PriorityRequest),
39 Shutdown,
40}
41
42pub struct RemoteWorkerActor {
43 slot: WorkerSlot,
44 idle_tx: mpsc::UnboundedSender<()>,
46}
47
48impl RemoteWorkerActor {
49 pub fn new(slot: WorkerSlot, idle_tx: mpsc::UnboundedSender<()>) -> Self {
50 Self { slot, idle_tx }
51 }
52
53 async fn dispatch(&mut self, req: PriorityRequest) {
54 let request_id = req.batch.request_id.clone();
55 let result = self.execute_with_retries(req.batch.clone(), &req.output).await;
56 if let Err(e) = result {
57 let _ = req.output.send(Err(e)).await;
61 }
62 let _ = self.idle_tx.send(());
64 tracing::trace!(request_id, "worker idle");
65 }
66
67 async fn execute_with_retries(
68 &mut self,
69 batch: ExecuteBatch,
70 output: &mpsc::Sender<Result<inference_core::tokens::TokenChunk, InferenceError>>,
71 ) -> Result<(), InferenceError> {
72 let mut attempt = Attempt(0);
73 'outer: loop {
74 self.acquire_permit(&batch).await?;
77 self.slot.circuit_breaker.check()?;
78
79 let res = self.slot.runner.execute(batch.clone()).await;
80 match res {
81 Ok(handle) => {
82 let mut stream = handle.into_stream();
83 while let Some(item) = stream.next().await {
84 match item {
85 Ok(chunk) => {
86 if output.send(Ok(chunk)).await.is_err() {
87 return Ok(());
89 }
90 }
91 Err(err) => match self.slot.retry_engine.decide(attempt, &err) {
92 RetryDecision::Retry { after } => {
93 tokio::time::sleep(after).await;
94 attempt.0 += 1;
95 continue 'outer;
98 }
99 RetryDecision::GiveUp => return Err(err),
100 },
101 }
102 }
103 return Ok(());
104 }
105 Err(err) => {
106 if let RetryDecision::Retry { after } = self.slot.retry_engine.decide(attempt, &err) {
107 tokio::time::sleep(after).await;
108 attempt.0 += 1;
109 continue;
110 }
111 return Err(err);
112 }
113 }
114 }
115 }
116
117 async fn acquire_permit(&self, batch: &ExecuteBatch) -> Result<(), InferenceError> {
118 let _hint = self.slot.rate_limiter.snapshot();
123 let _ = AcquirePermit {
124 requests: 1,
125 tokens: batch.estimated_tokens(),
126 reply: dummy_permit_reply(),
127 };
128 Ok(())
129 }
130}
131
132#[async_trait]
133impl Actor for RemoteWorkerActor {
134 type Msg = WorkerMsg;
135
136 async fn handle(&mut self, ctx: &mut Context<Self>, msg: Self::Msg) {
137 match msg {
138 WorkerMsg::Dispatch(req) => self.dispatch(req).await,
139 WorkerMsg::Shutdown => ctx.stop_self(),
140 }
141 }
142}
143
144fn dummy_permit_reply() -> tokio::sync::oneshot::Sender<Result<crate::rate_limit::Permit, InferenceError>> {
145 let (tx, rx) = tokio::sync::oneshot::channel();
146 drop(rx);
147 tx
148}