oxibonsai_runtime/
async_engine.rs1#![cfg(not(target_arch = "wasm32"))]
11
12use std::sync::Arc;
13use tokio::sync::Mutex;
14use tokio::sync::Semaphore;
15
16use crate::engine::InferenceEngine;
17use crate::error::{RuntimeError, RuntimeResult};
18use crate::metrics::InferenceMetrics;
19
20pub struct AsyncInferenceEngine {
27 engine: Arc<Mutex<InferenceEngine<'static>>>,
28 concurrency_limit: Arc<Semaphore>,
29 max_concurrent: usize,
30 metrics: Option<Arc<InferenceMetrics>>,
31}
32
33impl AsyncInferenceEngine {
34 pub fn new(engine: InferenceEngine<'static>, max_concurrent: usize) -> Self {
39 let effective_max = max_concurrent.max(1);
40 Self {
41 engine: Arc::new(Mutex::new(engine)),
42 concurrency_limit: Arc::new(Semaphore::new(effective_max)),
43 max_concurrent: effective_max,
44 metrics: None,
45 }
46 }
47
48 pub fn with_metrics(mut self, metrics: Arc<InferenceMetrics>) -> Self {
50 self.metrics = Some(metrics);
51 self
52 }
53
54 pub async fn generate(
59 &self,
60 prompt_tokens: Vec<u32>,
61 max_tokens: usize,
62 ) -> RuntimeResult<Vec<u32>> {
63 let _permit = self
65 .concurrency_limit
66 .acquire()
67 .await
68 .map_err(|_| RuntimeError::Server("semaphore closed".to_string()))?;
69
70 if let Some(m) = &self.metrics {
71 m.active_requests.inc();
72 }
73
74 let engine = Arc::clone(&self.engine);
75 let metrics = self.metrics.clone();
76
77 let result = tokio::task::spawn_blocking(move || {
79 let rt = tokio::runtime::Handle::current();
80 let mut engine_guard = rt.block_on(engine.lock());
81 engine_guard.generate(&prompt_tokens, max_tokens)
82 })
83 .await
84 .map_err(|e| RuntimeError::Server(format!("task join error: {e}")))?;
85
86 if let Some(m) = &metrics {
87 m.active_requests.dec();
88 }
89
90 result
91 }
92
93 pub async fn generate_streaming(
99 &self,
100 prompt_tokens: Vec<u32>,
101 max_tokens: usize,
102 ) -> RuntimeResult<tokio::sync::mpsc::UnboundedReceiver<u32>> {
103 let permit = self
105 .concurrency_limit
106 .clone()
107 .acquire_owned()
108 .await
109 .map_err(|_| RuntimeError::Server("semaphore closed".to_string()))?;
110
111 if let Some(m) = &self.metrics {
112 m.active_requests.inc();
113 }
114
115 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
116 let engine = Arc::clone(&self.engine);
117 let metrics = self.metrics.clone();
118
119 tokio::task::spawn_blocking(move || {
120 let rt = tokio::runtime::Handle::current();
121 let mut engine_guard = rt.block_on(engine.lock());
122 let _result = engine_guard.generate_streaming(&prompt_tokens, max_tokens, &tx);
123 if let Some(m) = &metrics {
126 m.active_requests.dec();
127 }
128
129 drop(permit);
131 });
132
133 Ok(rx)
134 }
135
136 pub fn active_requests(&self) -> usize {
140 self.max_concurrent - self.concurrency_limit.available_permits()
141 }
142
143 pub fn max_concurrent(&self) -> usize {
145 self.max_concurrent
146 }
147
148 pub fn has_capacity(&self) -> bool {
150 self.concurrency_limit.available_permits() > 0
151 }
152
153 pub fn engine(&self) -> &Arc<Mutex<InferenceEngine<'static>>> {
155 &self.engine
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplingParams;
    use oxibonsai_core::config::Qwen3Config;

    /// Builds a fresh engine with a fixed seed for deterministic tests.
    fn make_engine() -> InferenceEngine<'static> {
        let config = Qwen3Config::bonsai_8b();
        InferenceEngine::new(config, SamplingParams::default(), 42)
    }

    #[test]
    fn async_engine_creation() {
        let async_engine = AsyncInferenceEngine::new(make_engine(), 4);
        assert_eq!(async_engine.max_concurrent(), 4);
        assert_eq!(async_engine.active_requests(), 0);
        assert!(async_engine.has_capacity());
    }

    #[test]
    fn async_engine_min_concurrency_is_one() {
        let async_engine = AsyncInferenceEngine::new(make_engine(), 0);
        assert_eq!(async_engine.max_concurrent(), 1);
    }

    #[test]
    fn async_engine_with_metrics() {
        let metrics = Arc::new(InferenceMetrics::new());
        let async_engine =
            AsyncInferenceEngine::new(make_engine(), 2).with_metrics(Arc::clone(&metrics));
        assert_eq!(async_engine.max_concurrent(), 2);
        assert!(async_engine.has_capacity());
    }

    #[test]
    fn async_engine_capacity_tracking() {
        let async_engine = AsyncInferenceEngine::new(make_engine(), 3);
        assert_eq!(async_engine.active_requests(), 0);
        assert!(async_engine.has_capacity());
        assert_eq!(async_engine.max_concurrent(), 3);
    }

    #[tokio::test]
    async fn async_engine_generate_empty_prompt() {
        let async_engine = AsyncInferenceEngine::new(make_engine(), 1);
        let tokens = async_engine
            .generate(vec![], 10)
            .await
            .expect("generation with an empty prompt should succeed");
        assert!(tokens.is_empty());
        // The permit is returned once the request completes.
        assert_eq!(async_engine.active_requests(), 0);
    }

    #[tokio::test]
    async fn async_engine_generate_with_metrics_releases_capacity() {
        let metrics = Arc::new(InferenceMetrics::new());
        let async_engine = AsyncInferenceEngine::new(make_engine(), 1).with_metrics(metrics);
        let result = async_engine.generate(vec![], 1).await;
        assert!(result.is_ok());
        assert!(async_engine.has_capacity());
        assert_eq!(async_engine.active_requests(), 0);
    }

    #[tokio::test]
    async fn async_engine_streaming_empty_prompt() {
        let async_engine = AsyncInferenceEngine::new(make_engine(), 1);
        let mut rx = async_engine
            .generate_streaming(vec![], 10)
            .await
            .expect("streaming with an empty prompt should start");
        assert!(rx.recv().await.is_none());
        // The request slot is released before the sender is dropped, so capacity
        // is back once the stream has ended.
        assert_eq!(async_engine.active_requests(), 0);
        assert!(async_engine.has_capacity());
    }

    #[tokio::test]
    async fn async_engine_concurrency_respected() {
        let async_engine = Arc::new(AsyncInferenceEngine::new(make_engine(), 2));
        assert!(async_engine.has_capacity());
        assert_eq!(async_engine.active_requests(), 0);

        // Two requests in flight together both complete and return their permits.
        let (r1, r2) = tokio::join!(
            async_engine.generate(vec![], 1),
            async_engine.generate(vec![], 1)
        );
        assert!(r1.is_ok());
        assert!(r2.is_ok());
        assert!(async_engine.has_capacity());
        assert_eq!(async_engine.active_requests(), 0);
    }

    #[tokio::test]
    async fn async_engine_queues_beyond_limit() {
        // With a single permit, the second concurrent request waits instead of failing.
        let async_engine = AsyncInferenceEngine::new(make_engine(), 1);
        let (r1, r2) = tokio::join!(
            async_engine.generate(vec![], 1),
            async_engine.generate(vec![], 1)
        );
        assert!(r1.is_ok());
        assert!(r2.is_ok());
        assert_eq!(async_engine.active_requests(), 0);
    }
}