1use async_openai::types::chat::*;
2use axum::{
3 Json,
4 extract::State,
5 http::{HeaderMap, HeaderValue, StatusCode},
6 response::{
7 IntoResponse, Response,
8 sse::{Event, Sse},
9 },
10};
11use futures_util::stream;
12use std::convert::Infallible;
13use tracing::{debug, error, info, instrument};
14
15use crate::gateway::error::GatewayError;
16use crate::gateway::payload::CachePayload;
17use crate::gateway::state::HandlerState;
18use crate::gateway::streaming::handle_streaming_request;
19use reflex::cache::{
20 BqSearchBackend, REFLEX_STATUS_HEADER, ReflexStatus, StorageLoader, TieredLookupResult,
21};
22use reflex::payload::TauqEncoder;
23use reflex::scoring::VerificationResult;
24use reflex::storage::{ArchivedCacheEntry, CacheEntry, StorageWriter};
25use reflex::vectordb::{VectorPoint, generate_point_id};
26
/// Axum handler for `POST /v1/chat/completions`.
///
/// Request flow:
/// 1. Reject legacy function-calling fields, then parse into the typed
///    OpenAI request schema.
/// 2. Streaming requests bypass the cache entirely.
/// 3. Non-streaming: exact-hash L1 lookup, then semantic L2 lookup with
///    L3 score verification.
/// 4. On miss: call the provider (or mock), persist the entry to storage,
///    seed L1, and spawn a background vector-index update.
///
/// Generic over the vector-search backend `B` and the storage layer `S` so
/// the handler can be exercised against test doubles.
#[instrument(skip(state, request, headers), fields(model = tracing::field::Empty))]
pub async fn chat_completions_handler<B, S>(
    State(state): State<HandlerState<B, S>>,
    headers: HeaderMap,
    Json(request): Json<serde_json::Value>,
) -> Result<Response, GatewayError>
where
    B: BqSearchBackend + Clone + Send + Sync + 'static,
    S: StorageLoader + StorageWriter + Clone + Send + Sync + 'static,
{
    // Validate on the raw JSON first: the typed schema may silently drop
    // unknown legacy fields, so this check must happen before parsing.
    validate_no_legacy_fields(&request)?;
    let request: CreateChatCompletionRequest = serde_json::from_value(request)
        .map_err(|e| GatewayError::InvalidRequest(format!("Invalid request schema: {}", e)))?;
    tracing::Span::current().record("model", tracing::field::display(&request.model));

    // Bearer token from the Authorization header; only used below to derive
    // the tenant hash (no upstream auth is performed here).
    let _auth_token = headers
        .get("Authorization")
        .and_then(|val| val.to_str().ok())
        .and_then(|s| s.strip_prefix("Bearer "))
        .map(|s| s.trim().to_string());

    // Exact-match cache key: blake3 over the canonical serialized request,
    // plus a u64 projection used for storage keys and vector-point ids.
    let request_bytes = serde_json::to_vec(&request)
        .map_err(|e| GatewayError::InvalidRequest(format!("Serialization failed: {}", e)))?;
    let request_hash = blake3::hash(&request_bytes);
    let request_hash_u64 = reflex::hashing::hash_to_u64(request_hash.as_bytes());

    debug!(hash = %request_hash, "Processing chat completion request");

    // Canonical text used for embedding / semantic lookup.
    let semantic_text = semantic_text_from_request(&request);

    // Missing token maps to the shared "default" tenant.
    let token = _auth_token.unwrap_or_else(|| "default".to_string());
    let tenant_id_hash = reflex::hashing::hash_tenant_id(&token);

    let stream_requested = request.stream.unwrap_or(false);

    // Streaming responses are not cached; hand off to the streaming path.
    if stream_requested {
        debug!("Streaming request received - bypassing cache");
        if state.mock_provider {
            debug!("Mock provider enabled - returning mock streaming response");
            let mock_sse =
                create_mock_streaming_response(request.model.clone(), semantic_text.clone());
            return Ok(mock_sse.into_response());
        }
        let sse = handle_streaming_request::<B>(
            state.genai_client.clone(),
            &request.model,
            request.clone(),
            tenant_id_hash,
            request_hash_u64,
            semantic_text,
        )
        .await?;
        return Ok(sse.into_response());
    }

    // Tiered lookup: L1 exact (by hash string) and L2 semantic (by text)
    // in one call; the result tells us which tier answered.
    let request_hash_str = request_hash.to_string();
    let tiered_result = state
        .tiered_cache
        .lookup_with_semantic_query(&request_hash_str, &semantic_text, tenant_id_hash)
        .await
        .map_err(|e| GatewayError::CacheLookupFailed(e.to_string()))?;

    let cached_response = match tiered_result {
        TieredLookupResult::HitL1(l1_result) => {
            info!("L1 Cache Hit");
            // L1 stores rkyv-archived entries; access without deserializing.
            let archived = l1_result
                .handle()
                .access_archived::<ArchivedCacheEntry>()
                .map_err(|e| GatewayError::CacheLookupFailed(e.to_string()))?;

            let raw_payload = String::from_utf8_lossy(&archived.payload_blob);
            // A corrupt payload degrades to a cache miss rather than a 500.
            match serde_json::from_str::<CachePayload>(&raw_payload) {
                Ok(cache_payload) => Some((cache_payload, ReflexStatus::HitL1Exact)),
                Err(e) => {
                    tracing::warn!("Failed to parse L1 payload: {}. Treating as miss.", e);
                    None
                }
            }
        }
        TieredLookupResult::HitL2(l2_result) => {
            debug!(
                candidates = l2_result.candidates().len(),
                "L2 semantic hit, verifying..."
            );

            // Parse each candidate's payload; unparseable candidates are
            // silently dropped from scoring.
            let mut valid_candidates = Vec::new();
            for c in l2_result.candidates() {
                let raw_payload = String::from_utf8_lossy(&c.entry.payload_blob);
                if let Ok(payload) = serde_json::from_str::<CachePayload>(&raw_payload) {
                    // The scorer compares request texts, so swap the blob
                    // for the candidate's original semantic request.
                    let mut temp_entry = c.entry.clone();
                    temp_entry.payload_blob = payload.semantic_request.as_bytes().to_vec();
                    valid_candidates.push((temp_entry, c.score, payload));
                }
            }

            let candidates_for_scoring: Vec<(CacheEntry, f32)> = valid_candidates
                .iter()
                .map(|(e, s, _)| (e.clone(), *s))
                .collect();

            // L3: cross-check semantic candidates against the incoming text.
            let (verified_entry, verification_result) = state
                .scorer
                .verify_candidates(&semantic_text, candidates_for_scoring)
                .map_err(GatewayError::ScoringFailed)?;

            match verification_result {
                VerificationResult::Verified { score } => {
                    info!(score = score, "L3 verification passed");
                    let entry = verified_entry.ok_or_else(|| {
                        GatewayError::InternalError(
                            "L3 verification returned Verified without an entry".to_string(),
                        )
                    })?;
                    // Map the verified entry back to its parsed payload via
                    // context_hash.
                    let payload = valid_candidates
                        .iter()
                        .find(|(e, _, _)| e.context_hash == entry.context_hash)
                        .map(|(_, _, p)| p.clone())
                        .ok_or_else(|| {
                            GatewayError::InternalError(
                                "Lost track of verified payload".to_string(),
                            )
                        })?;

                    Some((payload, ReflexStatus::HitL3Verified))
                }
                VerificationResult::Rejected { top_score } => {
                    debug!(score = top_score, "L3 verification rejected");
                    None
                }
                VerificationResult::NoCandidates => {
                    debug!("L3 verification - no candidates");
                    None
                }
            }
        }
        TieredLookupResult::Miss => None,
    };

    if let Some((resp, status)) = cached_response {
        return make_response(resp, status);
    }

    debug!("Cache Miss - Calling Provider");

    let model = request.model.clone();

    let response = if state.mock_provider {
        // Fabricate a minimal OpenAI-shaped completion for tests/dev.
        let content = format!("Mock response for: {}", semantic_text);
        let response_value = serde_json::json!({
            "id": format!("chatcmpl-{}", uuid::Uuid::new_v4()),
            "object": "chat.completion",
            "created": chrono::Utc::now().timestamp() as u32,
            "model": model.clone(),
            "choices": [{
                "index": 0,
                "message": { "role": "assistant", "content": content },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20
            }
        });

        serde_json::from_value::<CreateChatCompletionResponse>(response_value)
            .map_err(|e| GatewayError::SerializationFailed(e.to_string()))?
    } else {
        // Translate OpenAI schema -> genai, call upstream, translate back.
        let genai_req = crate::gateway::adapter::adapt_openai_to_genai(request.clone());

        let genai_resp = state
            .genai_client
            .exec_chat(&model, genai_req, None)
            .await
            .map_err(|e| {
                // Upstream details are logged but not leaked to the client.
                error!("Provider error: {}", e);
                GatewayError::ProviderError("Upstream service request failed".to_string())
            })?;

        crate::gateway::adapter::adapt_genai_to_openai(genai_resp, model.clone())
    };

    let timestamp = chrono::Utc::now().timestamp();

    // Persist the semantic request alongside the response so later L2 hits
    // can be re-verified against the original text.
    let payload = CachePayload {
        semantic_request: semantic_text.clone(),
        response: response.clone(),
    };
    let payload_json = serde_json::to_string(&payload)
        .map_err(|e| GatewayError::SerializationFailed(e.to_string()))?;

    let embedding_f16 = state
        .tiered_cache
        .l2()
        .embedder()
        .embed(&semantic_text)
        .map_err(|e| GatewayError::EmbeddingFailed(e.to_string()))?;

    // f16 embedding is stored as little-endian bytes in the cache entry.
    let embedding_bytes: Vec<u8> = embedding_f16.iter().flat_map(|v| v.to_le_bytes()).collect();

    let cache_entry = CacheEntry {
        tenant_id: tenant_id_hash,
        context_hash: request_hash_u64,
        timestamp,
        embedding: embedding_bytes,
        payload_blob: payload_json.into_bytes(),
    };

    // Storage layout: <tenant_hash>/<context_hash_hex>.rkyv
    let entry_id = format!("{:016x}", request_hash_u64);
    let storage_key = format!("{}/{}.rkyv", tenant_id_hash, entry_id);

    let serialized_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&cache_entry)
        .map_err(|e| GatewayError::SerializationFailed(e.to_string()))?;

    // Storage write is blocking I/O; run it off the async executor. The
    // returned mmap handle is what L1 caches.
    let storage = state.tiered_cache.l2().storage().clone();
    let storage_key_for_write = storage_key.clone();
    let mmap_handle = tokio::task::spawn_blocking(move || {
        storage.write(&storage_key_for_write, serialized_bytes.as_ref())
    })
    .await
    .map_err(|e| GatewayError::StorageError(format!("Storage write task failed: {}", e)))?
    .map_err(|e| GatewayError::StorageError(e.to_string()))?;

    let l1_key = request_hash.to_string();
    state
        .tiered_cache
        .insert_l1(&l1_key, tenant_id_hash, mmap_handle);

    // Vector index wants f32; update happens fire-and-forget in a task.
    let embedding_f32: Vec<f32> = embedding_f16.iter().map(|v| v.to_f32()).collect();
    let vector_dim = state.tiered_cache.l2().config().vector_size;
    spawn_index_update(
        state.bq_client.clone(),
        state.collection_name.clone(),
        tenant_id_hash,
        request_hash_u64,
        timestamp,
        embedding_f32,
        storage_key,
        vector_dim,
    );

    make_response(payload, ReflexStatus::Miss)
}
270
271pub(crate) fn make_response(
272 payload: CachePayload,
273 status: ReflexStatus,
274) -> Result<Response, GatewayError> {
275 let payload_json = serde_json::to_value(&payload).unwrap_or_default();
276 let tauq_content = TauqEncoder::encode(&payload_json);
277
278 let message: ChatCompletionResponseMessage = serde_json::from_value(serde_json::json!({
279 "role": "assistant",
280 "content": tauq_content,
281 }))
282 .map_err(|e| GatewayError::SerializationFailed(e.to_string()))?;
283
284 let choice = ChatChoice {
285 index: 0,
286 message,
287 finish_reason: Some(FinishReason::Stop),
288 logprobs: None,
289 };
290
291 let mut wrapper = payload.response;
292 wrapper.object = "chat.completion".to_string();
293 wrapper.choices = vec![choice];
294
295 let mut headers = HeaderMap::new();
296 headers.insert(
297 REFLEX_STATUS_HEADER,
298 HeaderValue::from_static(status.as_header_value()),
299 );
300 Ok((StatusCode::OK, headers, Json(wrapper)).into_response())
301}
302
303pub(crate) fn validate_no_legacy_fields(req: &serde_json::Value) -> Result<(), GatewayError> {
304 if req.get("functions").is_some() || req.get("function_call").is_some() {
305 return Err(GatewayError::InvalidRequest(
306 "Legacy function-calling fields are not supported; use `tools`/`tool_choice`."
307 .to_string(),
308 ));
309 }
310
311 let messages = req
312 .get("messages")
313 .and_then(|m| m.as_array())
314 .ok_or_else(|| GatewayError::InvalidRequest("Missing or invalid `messages`".to_string()))?;
315
316 for m in messages {
317 if m.get("role").and_then(|r| r.as_str()) == Some("function") {
318 return Err(GatewayError::InvalidRequest(
319 "Unsupported message role: `function`.".to_string(),
320 ));
321 }
322 if m.get("role").and_then(|r| r.as_str()) == Some("assistant")
323 && m.get("function_call").is_some()
324 {
325 return Err(GatewayError::InvalidRequest(
326 "Legacy `function_call` on assistant messages is not supported; use `tool_calls`."
327 .to_string(),
328 ));
329 }
330 }
331
332 Ok(())
333}
334
335pub(crate) fn semantic_text_from_request(req: &CreateChatCompletionRequest) -> String {
397 let mut root = serde_json::Map::new();
398 root.insert("model".to_string(), serde_json::json!(req.model));
399 root.insert(
400 "messages".to_string(),
401 serde_json::to_value(&req.messages).unwrap_or_else(|_| serde_json::json!([])),
402 );
403
404 if let Some(tools) = &req.tools {
405 root.insert(
406 "tools".to_string(),
407 serde_json::to_value(tools).unwrap_or_else(|_| serde_json::json!([])),
408 );
409 }
410
411 if let Some(tool_choice) = &req.tool_choice {
412 root.insert(
413 "tool_choice".to_string(),
414 serde_json::to_value(tool_choice).unwrap_or(serde_json::Value::Null),
415 );
416 }
417
418 if let Some(response_format) = &req.response_format {
419 root.insert(
420 "response_format".to_string(),
421 serde_json::to_value(response_format).unwrap_or(serde_json::Value::Null),
422 );
423 }
424
425 serde_json::to_string(&serde_json::Value::Object(root))
426 .unwrap_or_else(|_| format!("model={} messages={}", req.model, req.messages.len()))
427}
428
429#[allow(clippy::too_many_arguments)]
430pub(crate) fn spawn_index_update<B>(
431 bq_client: B,
432 collection_name: String,
433 tenant_id: u64,
434 context_hash: u64,
435 timestamp: i64,
436 vector: Vec<f32>,
437 storage_key: String,
438 vector_dim: u64,
439) -> bool
440where
441 B: BqSearchBackend + Send + Sync + 'static,
442{
443 let point_id = generate_point_id(tenant_id, context_hash);
444
445 let point = VectorPoint {
446 id: point_id,
447 vector,
448 tenant_id,
449 context_hash,
450 timestamp,
451 storage_key: Some(storage_key),
452 };
453
454 tokio::spawn(async move {
455 if let Err(e) = bq_client
456 .ensure_collection(&collection_name, vector_dim)
457 .await
458 {
459 error!(error = %e, "Failed to ensure BQ collection");
460 return;
461 }
462
463 if let Err(e) = bq_client
464 .upsert_points(
465 &collection_name,
466 vec![point],
467 reflex::vectordb::WriteConsistency::Eventual,
468 )
469 .await
470 {
471 error!(error = %e, "Failed to upsert point to BQ index");
472 return;
473 }
474
475 debug!(
476 point_id = point_id,
477 "Successfully indexed point in BQ collection"
478 );
479 });
480
481 true
482}
483
484fn create_mock_streaming_response(
489 model: String,
490 semantic_text: String,
491) -> Sse<impl futures_util::Stream<Item = Result<Event, Infallible>> + Send> {
492 let content = format!("Mock streaming response for: {}", semantic_text);
493
494 let chunk_response = serde_json::json!({
495 "id": format!("chatcmpl-{}", uuid::Uuid::new_v4()),
496 "object": "chat.completion.chunk",
497 "created": chrono::Utc::now().timestamp() as u32,
498 "model": model,
499 "choices": [{
500 "index": 0,
501 "delta": { "role": "assistant", "content": content },
502 "finish_reason": null
503 }]
504 });
505
506 let done_response = serde_json::json!({
507 "id": format!("chatcmpl-{}", uuid::Uuid::new_v4()),
508 "object": "chat.completion.chunk",
509 "created": chrono::Utc::now().timestamp() as u32,
510 "model": model,
511 "choices": [{
512 "index": 0,
513 "delta": {},
514 "finish_reason": "stop"
515 }]
516 });
517
518 let events = vec![
519 Ok(Event::default().data(chunk_response.to_string())),
520 Ok(Event::default().data(done_response.to_string())),
521 Ok(Event::default().data("[DONE]")),
522 ];
523
524 Sse::new(stream::iter(events))
525}