Skip to main content

oxibonsai_runtime/
async_engine.rs

1//! Async wrapper around the synchronous [`InferenceEngine`].
2//!
3//! Uses [`tokio::task::spawn_blocking`] for CPU-bound inference work,
4//! ensuring the Tokio runtime is not blocked. Bounded concurrency is
5//! enforced via a [`Semaphore`] to prevent resource exhaustion.
6//!
7//! This module is not available on WASM targets (`wasm32`) because tokio's
8//! full feature set (including threads and network I/O) is not supported there.
9
10#![cfg(not(target_arch = "wasm32"))]
11
12use std::sync::Arc;
13use tokio::sync::Mutex;
14use tokio::sync::Semaphore;
15
16use crate::engine::InferenceEngine;
17use crate::error::{RuntimeError, RuntimeResult};
18use crate::metrics::InferenceMetrics;
19
20/// Async inference engine with bounded concurrency.
21///
22/// Wraps a synchronous [`InferenceEngine`] and provides async methods
23/// that use `spawn_blocking` under the hood. The semaphore limits how
24/// many concurrent inference requests can be in flight, protecting
25/// both memory and CPU utilization.
26pub struct AsyncInferenceEngine {
27    engine: Arc<Mutex<InferenceEngine<'static>>>,
28    concurrency_limit: Arc<Semaphore>,
29    max_concurrent: usize,
30    metrics: Option<Arc<InferenceMetrics>>,
31}
32
33impl AsyncInferenceEngine {
34    /// Create a new async inference engine wrapping the given engine.
35    ///
36    /// `max_concurrent` controls how many inference requests may execute
37    /// concurrently. A value of 1 serializes all requests.
38    pub fn new(engine: InferenceEngine<'static>, max_concurrent: usize) -> Self {
39        let effective_max = max_concurrent.max(1);
40        Self {
41            engine: Arc::new(Mutex::new(engine)),
42            concurrency_limit: Arc::new(Semaphore::new(effective_max)),
43            max_concurrent: effective_max,
44            metrics: None,
45        }
46    }
47
48    /// Attach shared metrics for recording inference telemetry.
49    pub fn with_metrics(mut self, metrics: Arc<InferenceMetrics>) -> Self {
50        self.metrics = Some(metrics);
51        self
52    }
53
54    /// Generate tokens asynchronously.
55    ///
56    /// Blocks the caller until a semaphore permit is acquired, then
57    /// dispatches the CPU-bound generation to a blocking thread.
58    pub async fn generate(
59        &self,
60        prompt_tokens: Vec<u32>,
61        max_tokens: usize,
62    ) -> RuntimeResult<Vec<u32>> {
63        // Acquire concurrency permit
64        let _permit = self
65            .concurrency_limit
66            .acquire()
67            .await
68            .map_err(|_| RuntimeError::Server("semaphore closed".to_string()))?;
69
70        if let Some(m) = &self.metrics {
71            m.active_requests.inc();
72        }
73
74        let engine = Arc::clone(&self.engine);
75        let metrics = self.metrics.clone();
76
77        // Move CPU-bound work to a blocking thread
78        let result = tokio::task::spawn_blocking(move || {
79            let rt = tokio::runtime::Handle::current();
80            let mut engine_guard = rt.block_on(engine.lock());
81            engine_guard.generate(&prompt_tokens, max_tokens)
82        })
83        .await
84        .map_err(|e| RuntimeError::Server(format!("task join error: {e}")))?;
85
86        if let Some(m) = &metrics {
87            m.active_requests.dec();
88        }
89
90        result
91    }
92
93    /// Generate tokens with streaming via an unbounded channel.
94    ///
95    /// Returns a receiver that yields tokens as they are generated.
96    /// The generation happens on a blocking thread; the receiver can
97    /// be consumed asynchronously.
98    pub async fn generate_streaming(
99        &self,
100        prompt_tokens: Vec<u32>,
101        max_tokens: usize,
102    ) -> RuntimeResult<tokio::sync::mpsc::UnboundedReceiver<u32>> {
103        // Acquire concurrency permit
104        let permit = self
105            .concurrency_limit
106            .clone()
107            .acquire_owned()
108            .await
109            .map_err(|_| RuntimeError::Server("semaphore closed".to_string()))?;
110
111        if let Some(m) = &self.metrics {
112            m.active_requests.inc();
113        }
114
115        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
116        let engine = Arc::clone(&self.engine);
117        let metrics = self.metrics.clone();
118
119        tokio::task::spawn_blocking(move || {
120            let rt = tokio::runtime::Handle::current();
121            let mut engine_guard = rt.block_on(engine.lock());
122            let _result = engine_guard.generate_streaming(&prompt_tokens, max_tokens, &tx);
123            // tx is dropped here, closing the channel
124
125            if let Some(m) = &metrics {
126                m.active_requests.dec();
127            }
128
129            // Permit is dropped here, releasing the semaphore slot
130            drop(permit);
131        });
132
133        Ok(rx)
134    }
135
136    /// Current number of active (in-flight) requests.
137    ///
138    /// Computed as `max_concurrent - available_permits`.
139    pub fn active_requests(&self) -> usize {
140        self.max_concurrent - self.concurrency_limit.available_permits()
141    }
142
143    /// Maximum concurrent requests this engine allows.
144    pub fn max_concurrent(&self) -> usize {
145        self.max_concurrent
146    }
147
148    /// Check if the engine has capacity for at least one more request.
149    pub fn has_capacity(&self) -> bool {
150        self.concurrency_limit.available_permits() > 0
151    }
152
153    /// Get a reference to the underlying engine (behind a mutex).
154    pub fn engine(&self) -> &Arc<Mutex<InferenceEngine<'static>>> {
155        &self.engine
156    }
157}
158
159// ─── Tests ─────────────────────────────────────────────────────────────
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplingParams;
    use oxibonsai_core::config::Qwen3Config;

    /// Build a deterministic engine (fixed seed) for tests.
    fn make_engine() -> InferenceEngine<'static> {
        let config = Qwen3Config::bonsai_8b();
        InferenceEngine::new(config, SamplingParams::default(), 42)
    }

    #[test]
    fn async_engine_creation() {
        let engine = make_engine();
        let async_engine = AsyncInferenceEngine::new(engine, 4);
        assert_eq!(async_engine.max_concurrent(), 4);
        assert_eq!(async_engine.active_requests(), 0);
        assert!(async_engine.has_capacity());
    }

    #[test]
    fn async_engine_min_concurrency_is_one() {
        let engine = make_engine();
        let async_engine = AsyncInferenceEngine::new(engine, 0);
        assert_eq!(async_engine.max_concurrent(), 1);
    }

    #[test]
    fn async_engine_with_metrics() {
        let engine = make_engine();
        let metrics = Arc::new(InferenceMetrics::new());
        let async_engine = AsyncInferenceEngine::new(engine, 2).with_metrics(Arc::clone(&metrics));
        assert_eq!(async_engine.max_concurrent(), 2);
        assert!(async_engine.has_capacity());
    }

    /// Exercise the metrics increment/decrement path during a real request.
    #[tokio::test]
    async fn async_engine_generate_with_metrics_attached() {
        let metrics = Arc::new(InferenceMetrics::new());
        let async_engine = AsyncInferenceEngine::new(make_engine(), 1).with_metrics(metrics);
        let tokens = async_engine
            .generate(vec![], 4)
            .await
            .expect("generation with metrics attached should succeed");
        assert!(tokens.is_empty());
        assert_eq!(async_engine.active_requests(), 0);
        assert!(async_engine.has_capacity());
    }

    /// An admitted request that cannot finish must be visible as in-flight,
    /// and its slot must be returned once it completes.
    #[tokio::test]
    async fn async_engine_capacity_tracking() {
        let async_engine = Arc::new(AsyncInferenceEngine::new(make_engine(), 1));
        assert_eq!(async_engine.active_requests(), 0);
        assert!(async_engine.has_capacity());

        // Hold the engine lock so the admitted request cannot complete yet.
        let engine_guard = async_engine.engine().lock().await;

        let worker = {
            let async_engine = Arc::clone(&async_engine);
            tokio::spawn(async move { async_engine.generate(vec![], 1).await })
        };

        // Yield until the spawned task has taken the only permit.
        for _ in 0..1_000 {
            if async_engine.active_requests() == 1 {
                break;
            }
            tokio::task::yield_now().await;
        }
        assert_eq!(async_engine.active_requests(), 1);
        assert!(!async_engine.has_capacity());

        drop(engine_guard);
        let tokens = worker
            .await
            .expect("worker task should not panic")
            .expect("generation should succeed");
        assert!(tokens.is_empty());
        assert_eq!(async_engine.active_requests(), 0);
        assert!(async_engine.has_capacity());
    }

    #[tokio::test]
    async fn async_engine_generate_empty_prompt() {
        let engine = make_engine();
        let async_engine = AsyncInferenceEngine::new(engine, 1);
        let result = async_engine.generate(vec![], 10).await;
        assert!(result.is_ok());
        let tokens = result.expect("should succeed");
        assert!(tokens.is_empty());
    }

    #[tokio::test]
    async fn async_engine_streaming_empty_prompt() {
        let engine = make_engine();
        let async_engine = AsyncInferenceEngine::new(engine, 1);
        let result = async_engine.generate_streaming(vec![], 10).await;
        assert!(result.is_ok());
        let mut rx = result.expect("should succeed");
        // Channel should be closed immediately (empty prompt produces no tokens)
        let token = rx.recv().await;
        assert!(token.is_none());
        // The slot is released before the channel closes.
        assert_eq!(async_engine.active_requests(), 0);
        assert!(async_engine.has_capacity());
    }

    #[tokio::test]
    async fn async_engine_concurrency_respected() {
        let async_engine = AsyncInferenceEngine::new(make_engine(), 2);

        assert!(async_engine.has_capacity());
        assert_eq!(async_engine.active_requests(), 0);

        // Two requests in flight at once, both within the limit of 2.
        let (r1, r2) = tokio::join!(
            async_engine.generate(vec![], 1),
            async_engine.generate(vec![], 1)
        );
        assert!(r1.is_ok());
        assert!(r2.is_ok());

        // After completion, all permits should be returned.
        assert!(async_engine.has_capacity());
        assert_eq!(async_engine.active_requests(), 0);
    }

    /// Requests beyond the limit queue for a permit instead of failing.
    #[tokio::test]
    async fn async_engine_queues_excess_requests() {
        let async_engine = AsyncInferenceEngine::new(make_engine(), 1);

        let (r1, r2, r3) = tokio::join!(
            async_engine.generate(vec![], 1),
            async_engine.generate(vec![], 1),
            async_engine.generate(vec![], 1)
        );
        assert!(r1.is_ok());
        assert!(r2.is_ok());
        assert!(r3.is_ok());

        assert_eq!(async_engine.active_requests(), 0);
        assert!(async_engine.has_capacity());
    }
}