Skip to main content

oxillama_server/
queue.rs

//! Request queue types for the continuous-batching inference worker.
//!
//! Instead of each HTTP handler holding the engine mutex directly, every
//! handler constructs a [`BatchRequest`] and sends it through a
//! `tokio::sync::mpsc::Sender`.  A single background worker receives these
//! requests one at a time and drives the `InferenceEngine`, so all engine
//! access goes through one owner and concurrent requests never contend for a
//! mutex.
8
9use std::sync::Arc;
10
11use oxillama_runtime::sampling::SamplerConfig;
12use tokio::sync::oneshot;
13
/// Token-ID → UTF-8 byte-sequence table covering the whole vocabulary.
///
/// Consumed by grammar-constrained sampling.  Wrapping the table in an `Arc`
/// means handing it to `AppState` and to each `SamplerConfig` is a
/// reference-count bump rather than a deep copy.
pub type VocabBytes = std::sync::Arc<Vec<(u32, Vec<u8>)>>;
19
20/// Token usage statistics for a generation request.
21#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
22pub struct UsageStats {
23    /// Number of tokens in the prompt.
24    pub prompt_tokens: usize,
25    /// Number of tokens generated.
26    pub completion_tokens: usize,
27    /// Total tokens (prompt + completion).
28    pub total_tokens: usize,
29}
30
/// Per-token sink used by streaming generation.
///
/// The worker calls it for each generated token with that token's text.  The
/// `&str` is borrowed only for the duration of a single call (hence the
/// higher-ranked lifetime), and the boxed closure itself must own everything
/// it captures (`'static`).  It executes on the blocking worker thread rather
/// than inside an async task, so calling
/// `tokio::sync::mpsc::Sender::blocking_send` from within it is safe.
pub type StreamCallback = Box<dyn for<'tok> FnMut(&'tok str) + Send + 'static>;
36
/// LoRA adapter selection for a single request.
///
/// Each entry is `(adapter_name, scale_multiplier)`.  The adapter must have
/// been pre-loaded via `POST /admin/loras` and registered in `AppState::loras`.
/// An empty selection means no LoRA adapter is applied to the request.
pub type LoraSelection = Vec<(String, f32)>;
42
/// A single inference request dispatched to the worker task.
///
/// Every variant carries a `reply` oneshot sender.  The worker returns its
/// result through it and reports failures as an `Err(String)` message.
pub enum BatchRequest {
    /// Non-streaming generation: prompt → full response string.
    Generate {
        /// The formatted prompt to generate from.
        prompt: String,
        /// Maximum number of tokens to generate.
        max_tokens: usize,
        /// Per-request sampler configuration.
        config: SamplerConfig,
        /// Whether to look up and store the prompt's KV state in the prefix
        /// cache.  When `true` (default), the worker checks for a matching
        /// cached prefix and skips the redundant prefill if found.
        cache_prompt: bool,
        /// LoRA adapters to apply for this request.  Empty means no LoRA.
        lora_selection: LoraSelection,
        /// Receives `Ok((response_text, usage))` when generation completes,
        /// or `Err(message)` on failure.
        reply: oneshot::Sender<Result<(String, UsageStats), String>>,
    },

    /// Streaming generation: invokes `callback` for every decoded token.
    GenerateStream {
        /// The formatted prompt to generate from.
        prompt: String,
        /// Maximum number of tokens to generate.
        max_tokens: usize,
        /// Per-request sampler configuration.
        config: SamplerConfig,
        /// Whether to look up and store the prompt's KV state in the prefix
        /// cache.  Same meaning as `cache_prompt` on `Generate`.
        cache_prompt: bool,
        /// LoRA adapters to apply for this request.  Empty means no LoRA.
        lora_selection: LoraSelection,
        /// Called with each token's text inside the blocking worker thread
        /// (see [`StreamCallback`]).
        callback: StreamCallback,
        /// Receives `Ok(UsageStats)` once generation is complete, or
        /// `Err(message)` on failure.  The generated text itself is delivered
        /// only through `callback`.
        reply: oneshot::Sender<Result<UsageStats, String>>,
    },

    /// Embedding computation: text → L2-normalised vector.
    Embed {
        /// The text to embed.
        text: String,
        /// Channel to return the embedding vector (or an error message).
        reply: oneshot::Sender<Result<Vec<f32>, String>>,
    },
}
91
92// Implement Debug manually because StreamCallback is not Debug.
93impl std::fmt::Debug for BatchRequest {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        match self {
96            BatchRequest::Generate {
97                prompt,
98                max_tokens,
99                cache_prompt,
100                lora_selection,
101                ..
102            } => f
103                .debug_struct("Generate")
104                .field("prompt_len", &prompt.len())
105                .field("max_tokens", max_tokens)
106                .field("cache_prompt", cache_prompt)
107                .field("lora_count", &lora_selection.len())
108                .finish(),
109            BatchRequest::GenerateStream {
110                prompt,
111                max_tokens,
112                cache_prompt,
113                lora_selection,
114                ..
115            } => f
116                .debug_struct("GenerateStream")
117                .field("prompt_len", &prompt.len())
118                .field("max_tokens", max_tokens)
119                .field("cache_prompt", cache_prompt)
120                .field("lora_count", &lora_selection.len())
121                .finish(),
122            BatchRequest::Embed { text, .. } => f
123                .debug_struct("Embed")
124                .field("text_len", &text.len())
125                .finish(),
126        }
127    }
128}
129
/// Metadata about the loaded model, cached at startup so route handlers do
/// not need to hold a reference to the engine, which has been moved into the
/// background worker.
#[derive(Debug, Clone)]
pub struct ModelMeta {
    /// Default sampler configuration taken from the engine config.
    pub default_sampler: SamplerConfig,
    /// Vocabulary byte table for grammar-constrained sampling.
    ///
    /// `None` when no tokenizer is loaded (should not happen at serve time).
    /// Because `VocabBytes` is an `Arc`, cloning `ModelMeta` shares this table
    /// rather than copying it.
    pub vocab_bytes: Option<VocabBytes>,
    /// Hidden-state dimension for the embeddings endpoint (presumably the
    /// length of each returned embedding vector — NOTE(review): confirm).
    pub hidden_size: usize,
}
143
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Round-trip a `BatchRequest::Generate` through an in-memory mpsc channel.
    ///
    /// This verifies that:
    /// 1. The variant can be constructed and sent without compile errors.
    /// 2. The oneshot reply channel delivers the result back to the caller.
    #[tokio::test]
    async fn test_generate_round_trip() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BatchRequest>(8);

        let (reply_tx, reply_rx) = oneshot::channel::<Result<(String, UsageStats), String>>();

        tx.send(BatchRequest::Generate {
            prompt: "hello".to_string(),
            max_tokens: 16,
            config: SamplerConfig::default(),
            cache_prompt: true,
            lora_selection: vec![],
            reply: reply_tx,
        })
        .await
        .expect("channel should accept the request");

        // Simulate a minimal worker: receive and immediately reply.
        let req = rx.recv().await.expect("worker should receive request");
        match req {
            BatchRequest::Generate {
                prompt,
                max_tokens,
                reply,
                ..
            } => {
                assert_eq!(prompt, "hello");
                assert_eq!(max_tokens, 16);
                let usage = UsageStats {
                    prompt_tokens: 1,
                    completion_tokens: 1,
                    total_tokens: 2,
                };
                reply
                    .send(Ok(("world".to_string(), usage)))
                    .expect("reply should succeed");
            }
            other => panic!("unexpected variant: {other:?}"),
        }

        let result = reply_rx.await.expect("reply future should resolve");
        let (text, usage) = result.expect("should be Ok");
        assert_eq!(text, "world");
        assert_eq!(usage.total_tokens, 2);
    }

    /// Round-trip a `BatchRequest::Embed` through an mpsc channel and make
    /// sure the embedding vector reaches the caller unchanged.
    #[tokio::test]
    async fn test_embed_round_trip() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BatchRequest>(1);
        let (reply_tx, reply_rx) = oneshot::channel::<Result<Vec<f32>, String>>();

        tx.send(BatchRequest::Embed {
            text: "abc".to_string(),
            reply: reply_tx,
        })
        .await
        .expect("channel should accept the request");

        match rx.recv().await.expect("worker should receive request") {
            BatchRequest::Embed { text, reply } => {
                assert_eq!(text, "abc");
                reply
                    .send(Ok(vec![1.0, 0.0]))
                    .expect("reply should succeed");
            }
            other => panic!("unexpected variant: {other:?}"),
        }

        let embedding = reply_rx
            .await
            .expect("reply future should resolve")
            .expect("should be Ok");
        assert_eq!(embedding, vec![1.0, 0.0]);
    }

    /// Verify that the `Debug` implementation does not panic and includes
    /// the prompt length rather than the full text (privacy / log hygiene).
    #[test]
    fn test_debug_does_not_expose_full_prompt() {
        let (reply_tx, _reply_rx) = oneshot::channel::<Result<(String, UsageStats), String>>();
        let req = BatchRequest::Generate {
            prompt: "secret prompt contents".to_string(),
            max_tokens: 32,
            config: SamplerConfig::default(),
            cache_prompt: true,
            lora_selection: vec![],
            reply: reply_tx,
        };
        let debug_str = format!("{req:?}");
        assert!(debug_str.contains("prompt_len"));
        assert!(!debug_str.contains("secret prompt contents"));
    }

    /// The streaming variant must follow the same log-hygiene rule as
    /// `Generate`: report lengths and counts, never the prompt text.
    #[test]
    fn test_debug_stream_does_not_expose_prompt() {
        let (reply_tx, _reply_rx) = oneshot::channel::<Result<UsageStats, String>>();
        let req = BatchRequest::GenerateStream {
            prompt: "secret streaming prompt".to_string(),
            max_tokens: 8,
            config: SamplerConfig::default(),
            cache_prompt: false,
            lora_selection: vec![("adapter".to_string(), 1.0)],
            callback: Box::new(|_token: &str| {}),
            reply: reply_tx,
        };
        let debug_str = format!("{req:?}");
        assert!(debug_str.starts_with("GenerateStream"));
        assert!(debug_str.contains("prompt_len: 23"));
        assert!(debug_str.contains("lora_count: 1"));
        assert!(!debug_str.contains("secret streaming prompt"));
    }

    /// `Embed` requests must report only the input length, never the text.
    #[test]
    fn test_debug_embed_does_not_expose_text() {
        let (reply_tx, _reply_rx) = oneshot::channel::<Result<Vec<f32>, String>>();
        let req = BatchRequest::Embed {
            text: "secret text to embed".to_string(),
            reply: reply_tx,
        };
        let debug_str = format!("{req:?}");
        assert!(debug_str.starts_with("Embed"));
        assert!(debug_str.contains("text_len: 20"));
        assert!(!debug_str.contains("secret text to embed"));
    }
}