Skip to main content

rlx_runtime/
router.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Multi-protocol request router (plan #32).
17//!
18//! Borrowed from MAX's `serve/router/{openai_routes, kserve_routes,
19//! sagemaker_routes, openresponses_routes}.py`. The shape: one
20//! inference engine, multiple wire protocols layered as thin
21//! per-protocol adapters. Adding a new protocol = implementing
22//! [`WireProtocol`] for its raw request type, not editing the hot
23//! path.
24//!
25//! Today's adapter is OpenAI-shaped (chat completions +
26//! embeddings) since [`crate::mock_requests`] already defines
27//! those structs. KServe / SageMaker / OpenResponses slot in by
28//! impl'ing `WireProtocol` for their respective request types.
29//!
30//! All conversion is pure-data: no I/O, no async. The actual HTTP
31//! parsing happens upstream (in the future serving crate); this
32//! module owns the translation between wire types and the
33//! internal [`RoutedRequest`].
34
35use crate::mock_requests::{ChatCompletionRequest, EmbeddingRequest, Input};
36
/// The category of inference work a request asks for.
///
/// Downstream code uses this to decide which pipeline serves the
/// request (for example text generation versus an embedding pool).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestKind {
    /// Chat completion: autoregressive token generation.
    ChatCompletion,
    /// Embedding: a single forward pass that returns pooled output.
    Embedding,
    /// Plain text completion, i.e. the legacy `/v1/completions` shape.
    TextCompletion,
}
49
50/// Internal canonical request shape. Every wire protocol parses
51/// into this; downstream schedulers / engines consume only this
52/// type.
53#[derive(Debug, Clone)]
54pub struct RoutedRequest {
55    pub id: u64,
56    pub kind: RequestKind,
57    /// Pre-tokenized input (one entry per text in a batched embed
58    /// request). For chat completion: the system + user history
59    /// flattened into a single token list.
60    pub inputs: Vec<Vec<u32>>,
61    pub max_tokens: u32,
62    pub temperature: f32,
63    pub stream: bool,
64    /// Optional LoRA adapter name; passed through to
65    /// [`crate::lora_scheduler::LoraRequest::adapter`].
66    pub adapter: Option<String>,
67    /// Model name from the wire request — kept for telemetry /
68    /// validation; the actual model selection happens upstream.
69    pub model: String,
70}
71
/// Failures an adapter can report while converting a wire request.
#[derive(Clone, Debug)]
pub enum RouteError {
    /// The named wire protocol is not known.
    UnknownProtocol { name: String },
    /// The request was rejected; `reason` says why.
    InvalidRequest { reason: String },
    /// The request relies on a feature that is not supported.
    UnsupportedFeature { feature: &'static str },
}

impl std::fmt::Display for RouteError {
    /// Renders the error as `<category>: <detail>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RouteError::UnknownProtocol { name } => {
                f.write_fmt(format_args!("unknown protocol: {name}"))
            }
            RouteError::InvalidRequest { reason } => {
                f.write_fmt(format_args!("invalid request: {reason}"))
            }
            RouteError::UnsupportedFeature { feature } => {
                f.write_fmt(format_args!("unsupported feature: {feature}"))
            }
        }
    }
}

// No `source()` override: these errors do not wrap an underlying cause.
impl std::error::Error for RouteError {}
90
91/// Adapter trait — implement once per wire protocol.
92///
93/// Implementations are tiny by design: they parse / validate the
94/// wire shape and produce a `RoutedRequest`. They do NOT touch
95/// the engine.
96pub trait WireProtocol {
97    type Request;
98    fn name(&self) -> &'static str;
99    fn parse(&self, req: Self::Request) -> Result<RoutedRequest, RouteError>;
100}
101
102/// OpenAI-style adapter — handles ChatCompletionRequest and
103/// EmbeddingRequest from [`crate::mock_requests`].
104pub struct OpenAIProtocol;
105
106impl WireProtocol for OpenAIProtocol {
107    type Request = OpenAIRequest;
108    fn name(&self) -> &'static str {
109        "openai"
110    }
111    fn parse(&self, req: OpenAIRequest) -> Result<RoutedRequest, RouteError> {
112        match req {
113            OpenAIRequest::Chat(c) => parse_chat(c),
114            OpenAIRequest::Embedding(e) => parse_embed(e),
115        }
116    }
117}
118
119#[derive(Debug, Clone)]
120pub enum OpenAIRequest {
121    Chat(ChatCompletionRequest),
122    Embedding(EmbeddingRequest),
123}
124
125fn parse_chat(req: ChatCompletionRequest) -> Result<RoutedRequest, RouteError> {
126    if req.messages.is_empty() {
127        return Err(RouteError::InvalidRequest {
128            reason: "messages cannot be empty".into(),
129        });
130    }
131    // Tokenization is the consumer's job — here we synthesize a
132    // single placeholder token-list per message. Real serving
133    // hooks call into a tokenizer before this point and feeds
134    // the resulting vec<u32> in directly.
135    let flat: Vec<u32> = req
136        .messages
137        .iter()
138        .flat_map(|m| pseudo_tokenize(&m.role, &m.content))
139        .collect();
140    Ok(RoutedRequest {
141        id: hash_request_id(&req.model, &flat),
142        kind: RequestKind::ChatCompletion,
143        inputs: vec![flat],
144        max_tokens: req.max_tokens.unwrap_or(256),
145        temperature: req.temperature.unwrap_or(1.0),
146        stream: req.stream.unwrap_or(false),
147        adapter: None, // OpenAI shape doesn't carry adapter; future ext.
148        model: req.model,
149    })
150}
151
152fn parse_embed(req: EmbeddingRequest) -> Result<RoutedRequest, RouteError> {
153    let inputs: Vec<Vec<u32>> = match req.input {
154        Input::Single(s) => vec![pseudo_tokenize("input", &s)],
155        Input::Batch(v) => v.iter().map(|s| pseudo_tokenize("input", s)).collect(),
156    };
157    if inputs.is_empty() {
158        return Err(RouteError::InvalidRequest {
159            reason: "embedding input cannot be empty".into(),
160        });
161    }
162    Ok(RoutedRequest {
163        id: hash_request_id(
164            &req.model,
165            inputs.first().map(|v| v.as_slice()).unwrap_or(&[]),
166        ),
167        kind: RequestKind::Embedding,
168        inputs,
169        max_tokens: 0, // not meaningful for embeddings
170        temperature: 0.0,
171        stream: false,
172        adapter: None,
173        model: req.model,
174    })
175}
176
/// Placeholder tokenizer used until a real one is plugged in.
///
/// The output starts with a role-header pseudo-token (`1` = system,
/// `2` = user, `3` = assistant, `4` = any other role) followed by the
/// `u32` code point of every `char` in `text`. The routing layer does
/// not depend on which tokenizer eventually replaces this.
fn pseudo_tokenize(role: &str, text: &str) -> Vec<u32> {
    const KNOWN_ROLES: [(&str, u32); 3] = [("system", 1), ("user", 2), ("assistant", 3)];

    let header = KNOWN_ROLES
        .iter()
        .find(|(name, _)| *name == role)
        .map_or(4, |&(_, id)| id);

    // The byte length is an upper bound on the number of chars, so one
    // reservation covers the header plus every code point.
    let mut out = Vec::with_capacity(1 + text.len());
    out.push(header);
    for ch in text.chars() {
        out.push(u32::from(ch));
    }
    out
}
193
/// Derives a `u64` id by feeding `model` and then `tokens` through the
/// standard library's `DefaultHasher`.
///
/// The value is only "stable-ish": equal `(model, tokens)` pairs give
/// equal ids, so it is not a unique request identifier. The router does
/// not need uniqueness; a consumer that does can overwrite
/// `RoutedRequest::id` after parsing.
fn hash_request_id(model: &str, tokens: &[u32]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    Hash::hash(model, &mut hasher);
    Hash::hash(tokens, &mut hasher);
    hasher.finish()
}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_requests::*;

    /// Builds a chat message with role `user` carrying `text`.
    fn user_message(text: &'static str) -> ChatMessage {
        ChatMessage {
            role: "user".into(),
            content: text.into(),
        }
    }

    /// Runs `req` through the OpenAI adapter.
    fn route(req: OpenAIRequest) -> Result<RoutedRequest, RouteError> {
        OpenAIProtocol.parse(req)
    }

    /// Explicit chat fields must be carried over to the routed request.
    #[test]
    fn openai_chat_routes_to_chat_completion() {
        let chat = ChatCompletionRequest {
            model: "gpt-4o-mini".into(),
            messages: vec![user_message("Hi")],
            max_tokens: Some(64),
            temperature: Some(0.7),
            stream: Some(false),
        };
        let routed = route(OpenAIRequest::Chat(chat)).unwrap();
        assert_eq!(routed.kind, RequestKind::ChatCompletion);
        assert_eq!(routed.inputs.len(), 1);
        assert_eq!(routed.max_tokens, 64);
        assert!((routed.temperature - 0.7).abs() < 1e-6);
        assert_eq!(routed.model, "gpt-4o-mini");
    }

    /// A single string becomes exactly one token list.
    #[test]
    fn openai_embedding_single_string() {
        let embed = EmbeddingRequest {
            model: "text-embedding-3-small".into(),
            input: Input::Single("Hello".into()),
            encoding_format: None,
        };
        let routed = route(OpenAIRequest::Embedding(embed)).unwrap();
        assert_eq!(routed.kind, RequestKind::Embedding);
        assert_eq!(routed.inputs.len(), 1);
        // One role header (4, because "input" is not a known role) plus 5 chars.
        assert_eq!(routed.inputs[0].len(), 6);
    }

    /// A batch becomes one token list per text, in order.
    #[test]
    fn openai_embedding_batch_input() {
        let texts = vec!["a".into(), "bb".into(), "ccc".into()];
        let embed = EmbeddingRequest {
            model: "text-embedding-3-small".into(),
            input: Input::Batch(texts),
            encoding_format: None,
        };
        let routed = route(OpenAIRequest::Embedding(embed)).unwrap();
        assert_eq!(routed.inputs.len(), 3);
        // Role header plus the 2 chars of "bb".
        assert_eq!(routed.inputs[1].len(), 3);
    }

    /// A chat request without messages is an invalid request.
    #[test]
    fn empty_chat_messages_rejected() {
        let chat = ChatCompletionRequest {
            model: "x".into(),
            messages: vec![],
            max_tokens: None,
            temperature: None,
            stream: None,
        };
        let err = route(OpenAIRequest::Chat(chat)).unwrap_err();
        assert!(matches!(err, RouteError::InvalidRequest { .. }));
    }

    /// Absent optional fields fall back to the adapter's defaults.
    #[test]
    fn defaults_applied_when_optional_fields_missing() {
        let chat = ChatCompletionRequest {
            model: "m".into(),
            messages: vec![user_message("x")],
            max_tokens: None,
            temperature: None,
            stream: None,
        };
        let routed = route(OpenAIRequest::Chat(chat)).unwrap();
        assert_eq!(routed.max_tokens, 256);
        assert_eq!(routed.temperature, 1.0);
        assert!(!routed.stream);
    }

    /// The adapter reports its protocol name.
    #[test]
    fn protocol_name_introspectable() {
        assert_eq!(OpenAIProtocol.name(), "openai");
    }
}