atomr_infer_runtime/engine_core.rs

//! `EngineCoreActor` — local-GPU per-replica orchestrator. Doc §4, §5.1.
//!
//! Wraps a `Box<dyn ModelRunner>` whose `transport_kind() ==
//! LocalGpu`. The continuous-batch scheduler and KV-cache manager are
//! per-runtime *modules* (vLLM has them; TensorRT/ORT batch by
//! stacking inputs); this actor just owns the runner, dispatches
//! `ExecuteBatch` requests through it, and pumps the resulting chunk
//! stream into the per-request output channel.
//!
//! `RemoteEngineCoreActor` (in `inference-remote-core`) is the
//! network-shaped sibling.
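//!
//! A minimal wiring sketch. `VllmRunner` and the `spawn` entry point
//! are hypothetical stand-ins (the concrete runner types live in the
//! runtime crates and the real spawn API in `atomr_core`):
//!
//! ```ignore
//! let runner: Box<dyn ModelRunner> = Box::new(VllmRunner::new(runner_cfg)?);
//! let engine = EngineCoreActor::new(runner, LocalEngineConfig::default());
//! let engine_addr = atomr_core::actor::spawn(engine);
//! ```
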
use std::sync::Arc;

use async_trait::async_trait;
use atomr_core::actor::{Actor, Context};
use futures::StreamExt;
use parking_lot::Mutex;
use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex};

use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::error::InferenceError;
use atomr_infer_core::runner::ModelRunner;
use atomr_infer_core::tokens::TokenChunk;

#[derive(Clone)]
pub struct LocalEngineConfig {
    /// Admission ceiling: requests beyond this many in flight are
    /// rejected with `Backpressure`.
    pub max_concurrent: u32,
    /// Suggested capacity for per-request queues; not read by this
    /// actor directly.
    pub queue_capacity: usize,
}

impl Default for LocalEngineConfig {
    fn default() -> Self {
        Self {
            max_concurrent: 32,
            queue_capacity: 1024,
        }
    }
}

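/// A unit of work submitted via [`EngineCoreMsg::Add`]. The actor
/// decides admission synchronously and reports it on `admission`;
/// token chunks then arrive on `output` until the stream ends.
///
/// Caller-side sketch of the protocol. `engine_addr.send(..)` is a
/// hypothetical mailbox handle from `atomr_core`, and sizing the
/// output channel with `queue_capacity` is one plausible convention,
/// not something this actor enforces:
///
/// ```ignore
/// let (out_tx, mut out_rx) = mpsc::channel(config.queue_capacity);
/// let (adm_tx, adm_rx) = oneshot::channel();
/// engine_addr.send(EngineCoreMsg::Add(AddRequest {
///     batch,
///     output: out_tx,
///     admission: adm_tx,
/// }));
/// adm_rx.await??; // `Backpressure` surfaces here, before any tokens.
/// while let Some(chunk) = out_rx.recv().await {
///     let chunk = chunk?; // per-chunk inference errors
///     // forward `chunk` to the client
/// }
/// ```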
pub struct AddRequest {
    pub batch: ExecuteBatch,
    pub output: mpsc::Sender<Result<TokenChunk, InferenceError>>,
    pub admission: oneshot::Sender<Result<(), InferenceError>>,
}

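/// Mailbox protocol for [`EngineCoreActor`].
///
/// Load-poll sketch, matching `DpCoordinatorActor`'s periodic poll
/// (`engine_addr` is the same hypothetical mailbox handle as above):
///
/// ```ignore
/// let (tx, rx) = oneshot::channel();
/// engine_addr.send(EngineCoreMsg::GetLoad { reply: tx });
/// // One plausible policy: treat a dropped reply as a saturated replica.
/// let load = rx.await.unwrap_or(1.0);
/// ```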
pub enum EngineCoreMsg {
    Add(AddRequest),
    /// Request a load-score snapshot. Used by `DpCoordinatorActor`'s
    /// periodic poll. The reply is `in_flight / max_concurrent`: `0.0`
    /// when idle, `1.0` when saturated.
    GetLoad {
        reply: oneshot::Sender<f64>,
    },
}

pub struct EngineCoreActor {
    /// Async mutex because the guard is held across the
    /// `ModelRunner::execute` await; a `parking_lot::Mutex` guard
    /// would not be `Send` across that await point.
    runner: Arc<AsyncMutex<Box<dyn ModelRunner>>>,
    config: LocalEngineConfig,
    /// Admitted requests currently executing or streaming; bumped by
    /// `try_admit`, dropped by `release` when a stream finishes.
    in_flight: Arc<Mutex<u32>>,
}

impl EngineCoreActor {
    pub fn new(runner: Box<dyn ModelRunner>, config: LocalEngineConfig) -> Self {
        Self {
            runner: Arc::new(AsyncMutex::new(runner)),
            config,
            in_flight: Arc::new(Mutex::new(0)),
        }
    }

    /// Reserve an in-flight slot, or fail fast with `Backpressure`.
    fn try_admit(&self) -> Result<(), InferenceError> {
        let mut g = self.in_flight.lock();
        if *g >= self.config.max_concurrent {
            return Err(InferenceError::Backpressure("engine at capacity".into()));
        }
        *g += 1;
        Ok(())
    }

    /// Give an in-flight slot back. Takes the counter rather than
    /// `&self` so the spawned streaming task can call it through its
    /// cloned `Arc`.
    fn release(in_flight: &Mutex<u32>) {
        let mut g = in_flight.lock();
        *g = g.saturating_sub(1);
    }
}

#[async_trait]
impl Actor for EngineCoreActor {
    type Msg = EngineCoreMsg;

    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: Self::Msg) {
        match msg {
            EngineCoreMsg::Add(req) => match self.try_admit() {
                Err(e) => {
                    let _ = req.admission.send(Err(e));
                }
                Ok(()) => {
                    let _ = req.admission.send(Ok(()));
                    let runner = self.runner.clone();
                    let in_flight = self.in_flight.clone();
                    let output = req.output;
                    let batch = req.batch;
                    tokio::spawn(async move {
                        // Hold the async mutex only across the
                        // `execute` await — the single runner owns the
                        // GPU context exclusively for the duration of
                        // a batched step. For runtimes that batch
                        // across requests (vLLM), `execute` returns
                        // quickly after enqueueing onto the engine's
                        // internal step loop; dropping the guard here
                        // lets other admitted requests start their own
                        // steps while this one streams.
                        let result = {
                            let mut g = runner.lock().await;
                            g.execute(batch).await
                        };
                        match result {
                            Ok(handle) => {
                                let mut s = handle.into_stream();
                                while let Some(chunk) = s.next().await {
                                    // A closed receiver means the
                                    // client went away; stop pumping.
                                    if output.send(chunk).await.is_err() {
                                        break;
                                    }
                                }
                            }
                            Err(e) => {
                                let _ = output.send(Err(e)).await;
                            }
                        }
                        // Exactly one release per admitted request,
                        // once the stream is drained or the request
                        // has failed; the slot stays occupied for the
                        // whole lifetime of the work it admitted.
                        Self::release(&in_flight);
                    });
                }
            },
            EngineCoreMsg::GetLoad { reply } => {
                let load = *self.in_flight.lock() as f64 / self.config.max_concurrent as f64;
                let _ = reply.send(load);
            }
        }
    }
}