oxillama_server/queue.rs
1//! Request queue types for the continuous-batching inference worker.
2//!
3//! Instead of each HTTP handler holding the engine mutex directly, every
4//! handler constructs a [`BatchRequest`] and sends it through a
5//! `tokio::sync::mpsc::Sender`. A single background worker receives these
6//! requests one at a time and drives the `InferenceEngine`, eliminating
7//! mutex contention across concurrent requests.
8
9use std::sync::Arc;
10
11use oxillama_runtime::sampling::SamplerConfig;
12use tokio::sync::oneshot;
13
/// Shared table pairing every token ID with that token's UTF-8 byte sequence.
///
/// Grammar-constrained sampling reads this table. It sits behind an [`Arc`]
/// so that `AppState` and each `SamplerConfig` that needs it can point at the
/// same allocation; cloning copies only the pointer, never the table.
pub type VocabBytes = Arc<Vec<(u32, Vec<u8>)>>;
19
20/// Token usage statistics for a generation request.
21#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
22pub struct UsageStats {
23 /// Number of tokens in the prompt.
24 pub prompt_tokens: usize,
25 /// Number of tokens generated.
26 pub completion_tokens: usize,
27 /// Total tokens (prompt + completion).
28 pub total_tokens: usize,
29}
30
/// Per-token hook used by streaming generation.
///
/// The boxed closure receives the text of each generated token. It is
/// `FnMut`, so it may update state it captured, and `Send`, so it can be
/// handed over to the worker. Because it is invoked on the blocking worker
/// thread, it is safe for the closure to call
/// `tokio::sync::mpsc::Sender::blocking_send`.
pub type StreamCallback = Box<dyn FnMut(&str) + Send>;
36
/// The LoRA adapters requested for one inference call.
///
/// Every element is an `(adapter_name, scale_multiplier)` pair. A named
/// adapter has to be loaded ahead of time through `POST /admin/loras` and be
/// present in `AppState::loras`.
pub type LoraSelection = Vec<(String, f32)>;
42
43/// A single inference request dispatched to the worker task.
44pub enum BatchRequest {
45 /// Non-streaming generation: prompt → full response string.
46 Generate {
47 /// The formatted prompt to generate from.
48 prompt: String,
49 /// Maximum number of tokens to generate.
50 max_tokens: usize,
51 /// Per-request sampler configuration.
52 config: SamplerConfig,
53 /// Whether to look up and store the prompt's KV state in the prefix
54 /// cache. When `true` (default), the worker checks for a matching
55 /// cached prefix and skips the redundant prefill if found.
56 cache_prompt: bool,
57 /// LoRA adapters to apply for this request. Empty means no LoRA.
58 lora_selection: LoraSelection,
59 /// Channel to send the result back to the caller.
60 reply: oneshot::Sender<Result<(String, UsageStats), String>>,
61 },
62
63 /// Streaming generation: invokes `callback` for every decoded token.
64 GenerateStream {
65 /// The formatted prompt to generate from.
66 prompt: String,
67 /// Maximum number of tokens to generate.
68 max_tokens: usize,
69 /// Per-request sampler configuration.
70 config: SamplerConfig,
71 /// Whether to look up and store the prompt's KV state in the prefix
72 /// cache.
73 cache_prompt: bool,
74 /// LoRA adapters to apply for this request. Empty means no LoRA.
75 lora_selection: LoraSelection,
76 /// Called with each token text inside the blocking worker thread.
77 callback: StreamCallback,
78 /// Channel that receives `Ok(UsageStats)` once generation is complete, or
79 /// `Err(message)` on failure.
80 reply: oneshot::Sender<Result<UsageStats, String>>,
81 },
82
83 /// Embedding computation: text → L2-normalised vector.
84 Embed {
85 /// The text to embed.
86 text: String,
87 /// Channel to return the embedding vector (or an error message).
88 reply: oneshot::Sender<Result<Vec<f32>, String>>,
89 },
90}
91
92// Implement Debug manually because StreamCallback is not Debug.
93impl std::fmt::Debug for BatchRequest {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 match self {
96 BatchRequest::Generate {
97 prompt,
98 max_tokens,
99 cache_prompt,
100 lora_selection,
101 ..
102 } => f
103 .debug_struct("Generate")
104 .field("prompt_len", &prompt.len())
105 .field("max_tokens", max_tokens)
106 .field("cache_prompt", cache_prompt)
107 .field("lora_count", &lora_selection.len())
108 .finish(),
109 BatchRequest::GenerateStream {
110 prompt,
111 max_tokens,
112 cache_prompt,
113 lora_selection,
114 ..
115 } => f
116 .debug_struct("GenerateStream")
117 .field("prompt_len", &prompt.len())
118 .field("max_tokens", max_tokens)
119 .field("cache_prompt", cache_prompt)
120 .field("lora_count", &lora_selection.len())
121 .finish(),
122 BatchRequest::Embed { text, .. } => f
123 .debug_struct("Embed")
124 .field("text_len", &text.len())
125 .finish(),
126 }
127 }
128}
129
130/// Metadata about the loaded model, cached at startup so route handlers do
131/// not need to hold a reference to the (now moved) engine.
132#[derive(Debug, Clone)]
133pub struct ModelMeta {
134 /// Default sampler configuration from the engine config.
135 pub default_sampler: SamplerConfig,
136 /// Vocabulary byte table for grammar-constrained sampling.
137 ///
138 /// `None` when no tokenizer is loaded (should not happen at serve time).
139 pub vocab_bytes: Option<VocabBytes>,
140 /// Hidden-state dimension for the embeddings endpoint.
141 pub hidden_size: usize,
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Sends a `BatchRequest::Generate` through an in-memory mpsc channel and
    /// answers it the way a worker would.
    ///
    /// Checks two things:
    /// 1. The variant can be built and pushed through the channel.
    /// 2. The oneshot reply channel carries the result back to the sender.
    #[tokio::test]
    async fn test_generate_round_trip() {
        let (reply_tx, reply_rx) = oneshot::channel::<Result<(String, UsageStats), String>>();
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BatchRequest>(8);

        let outgoing = BatchRequest::Generate {
            prompt: "hello".to_string(),
            max_tokens: 16,
            config: SamplerConfig::default(),
            cache_prompt: true,
            lora_selection: vec![],
            reply: reply_tx,
        };
        tx.send(outgoing)
            .await
            .expect("channel should accept the request");

        // Stand-in for the worker: take the request off the queue and answer
        // it straight away.
        let received = rx.recv().await.expect("worker should receive request");
        match received {
            BatchRequest::Generate {
                prompt,
                max_tokens,
                reply,
                ..
            } => {
                assert_eq!(prompt, "hello");
                assert_eq!(max_tokens, 16);
                let answer = (
                    "world".to_string(),
                    UsageStats {
                        prompt_tokens: 1,
                        completion_tokens: 1,
                        total_tokens: 2,
                    },
                );
                reply.send(Ok(answer)).expect("reply should succeed");
            }
            other => panic!("unexpected variant: {other:?}"),
        }

        let outcome = reply_rx.await.expect("reply future should resolve");
        let (text, usage) = outcome.expect("should be Ok");
        assert_eq!(text, "world");
        assert_eq!(usage.total_tokens, 2);
    }

    /// The `Debug` output must not panic and must show the prompt's length
    /// instead of its text (privacy / log hygiene).
    #[test]
    fn test_debug_does_not_expose_full_prompt() {
        let (reply_tx, _reply_rx) = oneshot::channel::<Result<(String, UsageStats), String>>();
        let rendered = format!(
            "{:?}",
            BatchRequest::Generate {
                prompt: "secret prompt contents".to_string(),
                max_tokens: 32,
                config: SamplerConfig::default(),
                cache_prompt: true,
                lora_selection: vec![],
                reply: reply_tx,
            }
        );
        assert!(rendered.contains("prompt_len"));
        assert!(!rendered.contains("secret prompt contents"));
    }
}
217}