// inference_runtime/engine_core.rs
use std::sync::Arc;
14
15use async_trait::async_trait;
16use futures::StreamExt;
17use parking_lot::Mutex;
18use rakka_core::actor::{Actor, Context};
19use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex};
20
21use inference_core::batch::ExecuteBatch;
22use inference_core::error::InferenceError;
23use inference_core::runner::ModelRunner;
24use inference_core::tokens::TokenChunk;
25
/// Configuration for a locally hosted inference engine.
///
/// `Debug` is derived in addition to `Clone` so the config can be logged —
/// a public configuration type should be printable for diagnostics.
#[derive(Clone, Debug)]
pub struct LocalEngineConfig {
    /// Maximum number of batches admitted at once; further `Add` requests
    /// are rejected with a backpressure error.
    pub max_concurrent: u32,
    /// Queue capacity. NOTE(review): not read anywhere in this file —
    /// presumably consumed by the caller that creates channels; confirm.
    pub queue_capacity: usize,
}
31
32impl Default for LocalEngineConfig {
33 fn default() -> Self {
34 Self {
35 max_concurrent: 32,
36 queue_capacity: 1024,
37 }
38 }
39}
40
/// A request to execute one batch on the engine.
pub struct AddRequest {
    /// The batch of work to run on the model runner.
    pub batch: ExecuteBatch,
    /// Streaming output channel: one token chunk (or error) per message.
    pub output: mpsc::Sender<Result<TokenChunk, InferenceError>>,
    /// One-shot admission decision: `Ok(())` when the batch is admitted,
    /// `Err(..)` when the engine rejects it (e.g. at capacity).
    pub admission: oneshot::Sender<Result<(), InferenceError>>,
}
46
/// Messages accepted by the engine core actor.
pub enum EngineCoreMsg {
    /// Submit a batch for execution (see [`AddRequest`]).
    Add(AddRequest),
    /// Query the current load as a fraction of `max_concurrent`.
    GetLoad {
        /// Receives in-flight count divided by `max_concurrent` as an `f64`.
        reply: oneshot::Sender<f64>,
    },
}
55
/// Actor that admits inference batches and executes them on a model runner.
pub struct EngineCoreActor {
    // The runner sits behind an async mutex that spawned tasks lock for the
    // full duration of a batch — so execution is effectively serialized even
    // when multiple batches have been admitted. NOTE(review): confirm this
    // serialization is intentional given `max_concurrent` > 1.
    runner: Arc<AsyncMutex<Box<dyn ModelRunner>>>,
    config: LocalEngineConfig,
    // Count of admitted-but-unfinished batches; an Arc so spawned tasks can
    // decrement it on completion.
    in_flight: Arc<Mutex<u32>>,
}
64
65impl EngineCoreActor {
66 pub fn new(runner: Box<dyn ModelRunner>, config: LocalEngineConfig) -> Self {
67 Self {
68 runner: Arc::new(AsyncMutex::new(runner)),
69 config,
70 in_flight: Arc::new(Mutex::new(0)),
71 }
72 }
73
74 fn try_admit(&self) -> Result<(), InferenceError> {
75 let mut g = self.in_flight.lock();
76 if *g >= self.config.max_concurrent {
77 return Err(InferenceError::Backpressure("engine at capacity".into()));
78 }
79 *g += 1;
80 Ok(())
81 }
82
83 fn release(&self) {
84 let mut g = self.in_flight.lock();
85 *g = g.saturating_sub(1);
86 }
87}
88
89#[async_trait]
90impl Actor for EngineCoreActor {
91 type Msg = EngineCoreMsg;
92
93 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: Self::Msg) {
94 match msg {
95 EngineCoreMsg::Add(req) => match self.try_admit() {
96 Err(e) => {
97 let _ = req.admission.send(Err(e));
98 }
99 Ok(()) => {
100 let _ = req.admission.send(Ok(()));
101 let runner = self.runner.clone();
102 let in_flight = self.in_flight.clone();
103 let output = req.output;
104 let batch = req.batch;
105 tokio::spawn(async move {
106 let mut g = runner.lock().await;
114 match g.execute(batch).await {
115 Ok(handle) => {
116 let mut s = handle.into_stream();
117 while let Some(chunk) = s.next().await {
118 if output.send(chunk).await.is_err() {
119 break;
120 }
121 }
122 }
123 Err(e) => {
124 let _ = output.send(Err(e)).await;
125 }
126 }
127 let mut g = in_flight.lock();
128 *g = g.saturating_sub(1);
129 });
130 self.release();
131 }
132 },
133 EngineCoreMsg::GetLoad { reply } => {
134 let load = *self.in_flight.lock() as f64 / self.config.max_concurrent as f64;
135 let _ = reply.send(load);
136 }
137 }
138 }
139}