inference_remote_core/
engine.rs1use std::sync::Arc;
9
10use async_trait::async_trait;
11use parking_lot::Mutex;
12use rakka_core::actor::{Actor, ActorRef, Context, Props};
13use tokio::sync::{mpsc, oneshot};
14
15use inference_core::batch::ExecuteBatch;
16use inference_core::deployment::CapacityPolicy;
17use inference_core::error::InferenceError;
18use inference_core::tokens::TokenChunk;
19
20use crate::queue::{Priority, PriorityRequest, RequestQueue};
21use crate::worker::{RemoteWorkerActor, WorkerMsg, WorkerSlot};
22
23#[derive(Clone)]
24pub struct RemoteEngineConfig {
25 pub queue_capacity: usize,
26 pub worker_count: u32,
27 pub on_capacity_exhausted: CapacityPolicy,
28}
29
30impl Default for RemoteEngineConfig {
31 fn default() -> Self {
32 Self {
33 queue_capacity: 1024,
34 worker_count: 8,
35 on_capacity_exhausted: CapacityPolicy::Queue,
36 }
37 }
38}
39
/// Counters describing the engine's request flow.
///
/// Shared behind an `Arc<Mutex<_>>` by the engine; callers typically lock and
/// clone it to take a consistent snapshot. `Debug` makes snapshots loggable and
/// `PartialEq`/`Eq` makes them comparable in tests.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct EngineMetrics {
    /// Requests admitted to the queue but not yet handed to a worker.
    pub queued: u64,
    /// Requests dispatched to a worker whose completion has not been seen yet.
    pub in_flight: u64,
    /// Requests whose worker has signalled completion.
    pub completed: u64,
    /// Requests refused at admission because the queue push failed.
    pub rejected_backpressure: u64,
}
47
48pub struct AddRequest {
49 pub priority: Priority,
50 pub batch: ExecuteBatch,
51 pub output: mpsc::Sender<Result<TokenChunk, InferenceError>>,
52 pub admission: oneshot::Sender<Result<(), InferenceError>>,
53}
54
55pub enum EngineMsg {
56 Add(AddRequest),
57 WorkerIdle,
58}
59
60pub type WorkerSlotFactory = Box<dyn FnMut() -> WorkerSlot + Send>;
64
65struct WorkerEntry {
66 addr: ActorRef<WorkerMsg>,
67 idle: bool,
69}
70
71pub struct RemoteEngineCoreActor {
72 #[allow(dead_code)] config: RemoteEngineConfig,
74 queue: RequestQueue,
75 workers: Vec<WorkerEntry>,
76 metrics: Arc<Mutex<EngineMetrics>>,
77 factory: Option<WorkerSlotFactory>,
80 worker_count: u32,
81 idle_tx: mpsc::UnboundedSender<()>,
85 idle_rx: Option<mpsc::UnboundedReceiver<()>>,
86}
87
88impl RemoteEngineCoreActor {
89 pub fn new(config: RemoteEngineConfig, factory: WorkerSlotFactory) -> Self {
90 let (idle_tx, idle_rx) = mpsc::unbounded_channel();
91 let queue = RequestQueue::new(config.queue_capacity);
92 let worker_count = config.worker_count;
93 Self {
94 config,
95 queue,
96 workers: Vec::new(),
97 metrics: Arc::new(Mutex::new(EngineMetrics::default())),
98 factory: Some(factory),
99 worker_count,
100 idle_tx,
101 idle_rx: Some(idle_rx),
102 }
103 }
104
105 pub fn metrics_handle(&self) -> Arc<Mutex<EngineMetrics>> {
106 self.metrics.clone()
107 }
108
109 fn enqueue(&mut self, req: AddRequest) {
110 let priority_request = PriorityRequest {
111 priority: req.priority,
112 arrival_seq: 0,
113 batch: req.batch,
114 output: req.output,
115 };
116 match self.queue.push(priority_request) {
117 Ok(()) => {
118 self.metrics.lock().queued += 1;
119 let _ = req.admission.send(Ok(()));
120 }
121 Err(_rejected) => {
122 self.metrics.lock().rejected_backpressure += 1;
123 let _ = req
124 .admission
125 .send(Err(InferenceError::Backpressure("engine queue full".into())));
126 }
127 }
128 }
129
130 fn try_dispatch(&mut self) {
131 while !self.queue.is_empty() {
132 let Some(idx) = self.workers.iter().position(|w| w.idle) else {
133 break;
134 };
135 let Some(req) = self.queue.pop() else { break };
136 self.workers[idx].idle = false;
137 self.workers[idx].addr.tell(WorkerMsg::Dispatch(req));
138 let mut m = self.metrics.lock();
139 m.queued = m.queued.saturating_sub(1);
140 m.in_flight += 1;
141 }
142 }
143}
144
145#[async_trait]
146impl Actor for RemoteEngineCoreActor {
147 type Msg = EngineMsg;
148
149 async fn pre_start(&mut self, ctx: &mut Context<Self>) {
150 let mut factory = match self.factory.take() {
158 Some(f) => f,
159 None => {
160 tracing::error!("RemoteEngineCoreActor pre_start with no factory");
161 return;
162 }
163 };
164 for i in 0..self.worker_count {
165 let slot = factory();
166 let idle_tx = self.idle_tx.clone();
167 let cell = parking_lot::Mutex::new(Some(slot));
168 let props = Props::create(move || {
169 let s = cell
170 .lock()
171 .take()
172 .expect("worker factory invoked twice — restart not yet supported");
173 RemoteWorkerActor::new(s, idle_tx.clone())
174 });
175 let name = format!("worker-{i}");
176 match ctx.spawn(props, &name) {
177 Ok(addr) => self.workers.push(WorkerEntry { addr, idle: true }),
178 Err(e) => tracing::error!(?e, "spawn worker {i} failed"),
179 }
180 }
181
182 let self_ref = ctx.self_ref().clone();
185 let mut rx = self.idle_rx.take().expect("idle_rx set in new()");
186 tokio::spawn(async move {
187 while rx.recv().await.is_some() {
188 self_ref.tell(EngineMsg::WorkerIdle);
189 }
190 });
191 }
192
193 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: Self::Msg) {
194 match msg {
195 EngineMsg::Add(req) => {
196 self.enqueue(req);
197 self.try_dispatch();
198 }
199 EngineMsg::WorkerIdle => {
200 if let Some(w) = self.workers.iter_mut().find(|w| !w.idle) {
207 w.idle = true;
208 let mut m = self.metrics.lock();
209 m.in_flight = m.in_flight.saturating_sub(1);
210 m.completed += 1;
211 }
212 self.try_dispatch();
213 }
214 }
215 }
216}