1use crate::mock_requests::{ChatCompletionRequest, EmbeddingRequest, Input};
36
/// Category of work carried by a [`RoutedRequest`].
///
/// The parsers in this file produce `ChatCompletion` and `Embedding`;
/// nothing visible here constructs `TextCompletion`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestKind {
    /// Request built from a list of chat messages.
    ChatCompletion,
    /// Request built from one or more embedding input strings.
    Embedding,
    /// Text-completion request (not produced by any parser in this file).
    TextCompletion,
}
49
50#[derive(Debug, Clone)]
54pub struct RoutedRequest {
55 pub id: u64,
56 pub kind: RequestKind,
57 pub inputs: Vec<Vec<u32>>,
61 pub max_tokens: u32,
62 pub temperature: f32,
63 pub stream: bool,
64 pub adapter: Option<String>,
67 pub model: String,
70}
71
/// Failures that can be reported while turning a wire request into a routed
/// request.
#[derive(Clone, Debug)]
pub enum RouteError {
    /// No protocol is registered under `name`.
    UnknownProtocol { name: String },
    /// The request was rejected; `reason` is a human-readable explanation.
    InvalidRequest { reason: String },
    /// The request relies on `feature`, which is not supported.
    UnsupportedFeature { feature: &'static str },
}

impl std::fmt::Display for RouteError {
    /// Renders the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RouteError::UnknownProtocol { name } => {
                f.write_fmt(format_args!("unknown protocol: {name}"))
            }
            RouteError::InvalidRequest { reason } => {
                f.write_fmt(format_args!("invalid request: {reason}"))
            }
            RouteError::UnsupportedFeature { feature } => {
                f.write_fmt(format_args!("unsupported feature: {feature}"))
            }
        }
    }
}

// No source or backtrace to expose, so the default trait methods suffice.
impl std::error::Error for RouteError {}
90
91pub trait WireProtocol {
97 type Request;
98 fn name(&self) -> &'static str;
99 fn parse(&self, req: Self::Request) -> Result<RoutedRequest, RouteError>;
100}
101
/// Stateless marker type for the protocol whose name is `"openai"`.
pub struct OpenAIProtocol;
105
106impl WireProtocol for OpenAIProtocol {
107 type Request = OpenAIRequest;
108 fn name(&self) -> &'static str {
109 "openai"
110 }
111 fn parse(&self, req: OpenAIRequest) -> Result<RoutedRequest, RouteError> {
112 match req {
113 OpenAIRequest::Chat(c) => parse_chat(c),
114 OpenAIRequest::Embedding(e) => parse_embed(e),
115 }
116 }
117}
118
119#[derive(Debug, Clone)]
120pub enum OpenAIRequest {
121 Chat(ChatCompletionRequest),
122 Embedding(EmbeddingRequest),
123}
124
125fn parse_chat(req: ChatCompletionRequest) -> Result<RoutedRequest, RouteError> {
126 if req.messages.is_empty() {
127 return Err(RouteError::InvalidRequest {
128 reason: "messages cannot be empty".into(),
129 });
130 }
131 let flat: Vec<u32> = req
136 .messages
137 .iter()
138 .flat_map(|m| pseudo_tokenize(&m.role, &m.content))
139 .collect();
140 Ok(RoutedRequest {
141 id: hash_request_id(&req.model, &flat),
142 kind: RequestKind::ChatCompletion,
143 inputs: vec![flat],
144 max_tokens: req.max_tokens.unwrap_or(256),
145 temperature: req.temperature.unwrap_or(1.0),
146 stream: req.stream.unwrap_or(false),
147 adapter: None, model: req.model,
149 })
150}
151
152fn parse_embed(req: EmbeddingRequest) -> Result<RoutedRequest, RouteError> {
153 let inputs: Vec<Vec<u32>> = match req.input {
154 Input::Single(s) => vec![pseudo_tokenize("input", &s)],
155 Input::Batch(v) => v.iter().map(|s| pseudo_tokenize("input", s)).collect(),
156 };
157 if inputs.is_empty() {
158 return Err(RouteError::InvalidRequest {
159 reason: "embedding input cannot be empty".into(),
160 });
161 }
162 Ok(RoutedRequest {
163 id: hash_request_id(
164 &req.model,
165 inputs.first().map(|v| v.as_slice()).unwrap_or(&[]),
166 ),
167 kind: RequestKind::Embedding,
168 inputs,
169 max_tokens: 0, temperature: 0.0,
171 stream: false,
172 adapter: None,
173 model: req.model,
174 })
175}
176
/// Builds a stand-in token sequence for `text`.
///
/// The first element is a role marker (`"system"` → 1, `"user"` → 2,
/// `"assistant"` → 3, anything else → 4). It is followed by one element per
/// `char` of `text`, holding that character's Unicode scalar value.
fn pseudo_tokenize(role: &str, text: &str) -> Vec<u32> {
    let marker: u32 = if role == "system" {
        1
    } else if role == "user" {
        2
    } else if role == "assistant" {
        3
    } else {
        4
    };

    // The byte length is an upper bound on the character count, so this
    // capacity always fits the marker plus every character.
    let mut out = Vec::with_capacity(text.len() + 1);
    out.push(marker);
    for ch in text.chars() {
        out.push(u32::from(ch));
    }
    out
}
193
/// Derives a request id by feeding `model` and then `tokens` into a freshly
/// created `DefaultHasher`.
///
/// The standard library does not specify `DefaultHasher`'s algorithm, so the
/// resulting ids should not be relied upon across Rust releases.
fn hash_request_id(model: &str, tokens: &[u32]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    // Order matters: the model is hashed first, then the token slice.
    Hash::hash(model, &mut state);
    Hash::hash(tokens, &mut state);
    Hasher::finish(&state)
}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_requests::*;

    /// Builds a chat request holding a single `user` message.
    fn user_chat(
        model: &'static str,
        content: &'static str,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        stream: Option<bool>,
    ) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: model.into(),
            messages: vec![ChatMessage {
                role: "user".into(),
                content: content.into(),
            }],
            max_tokens,
            temperature,
            stream,
        }
    }

    /// Builds an embedding request for the model used throughout these tests.
    fn embedding(input: Input) -> EmbeddingRequest {
        EmbeddingRequest {
            model: "text-embedding-3-small".into(),
            input,
            encoding_format: None,
        }
    }

    /// Runs a request through `OpenAIProtocol`.
    fn route(req: OpenAIRequest) -> Result<RoutedRequest, RouteError> {
        OpenAIProtocol.parse(req)
    }

    #[test]
    fn openai_chat_routes_to_chat_completion() {
        let chat = user_chat("gpt-4o-mini", "Hi", Some(64), Some(0.7), Some(false));
        let routed = route(OpenAIRequest::Chat(chat)).unwrap();
        assert_eq!(routed.kind, RequestKind::ChatCompletion);
        assert_eq!(routed.inputs.len(), 1);
        assert_eq!(routed.max_tokens, 64);
        assert!((routed.temperature - 0.7).abs() < 1e-6);
        assert_eq!(routed.model, "gpt-4o-mini");
    }

    #[test]
    fn openai_embedding_single_string() {
        let request = embedding(Input::Single("Hello".into()));
        let routed = route(OpenAIRequest::Embedding(request)).unwrap();
        assert_eq!(routed.kind, RequestKind::Embedding);
        assert_eq!(routed.inputs.len(), 1);
        // One role marker plus the five characters of "Hello".
        assert_eq!(routed.inputs[0].len(), 6);
    }

    #[test]
    fn openai_embedding_batch_input() {
        let request = embedding(Input::Batch(vec!["a".into(), "bb".into(), "ccc".into()]));
        let routed = route(OpenAIRequest::Embedding(request)).unwrap();
        assert_eq!(routed.inputs.len(), 3);
        // One role marker plus the two characters of "bb".
        assert_eq!(routed.inputs[1].len(), 3);
    }

    #[test]
    fn empty_chat_messages_rejected() {
        let chat = ChatCompletionRequest {
            model: "x".into(),
            messages: vec![],
            max_tokens: None,
            temperature: None,
            stream: None,
        };
        let err = route(OpenAIRequest::Chat(chat)).unwrap_err();
        assert!(matches!(err, RouteError::InvalidRequest { .. }));
    }

    #[test]
    fn defaults_applied_when_optional_fields_missing() {
        let chat = user_chat("m", "x", None, None, None);
        let routed = route(OpenAIRequest::Chat(chat)).unwrap();
        assert_eq!(routed.max_tokens, 256);
        assert_eq!(routed.temperature, 1.0);
        assert!(!routed.stream);
    }

    #[test]
    fn protocol_name_introspectable() {
        assert_eq!(OpenAIProtocol.name(), "openai");
    }
}