// reflex_server/gateway/handler.rs

1use async_openai::types::chat::*;
2use axum::{
3    Json,
4    extract::State,
5    http::{HeaderMap, HeaderValue, StatusCode},
6    response::{
7        IntoResponse, Response,
8        sse::{Event, Sse},
9    },
10};
11use futures_util::stream;
12use std::convert::Infallible;
13use tracing::{debug, error, info, instrument};
14
15use crate::gateway::error::GatewayError;
16use crate::gateway::payload::CachePayload;
17use crate::gateway::state::HandlerState;
18use crate::gateway::streaming::handle_streaming_request;
19use reflex::cache::{
20    BqSearchBackend, REFLEX_STATUS_HEADER, ReflexStatus, StorageLoader, TieredLookupResult,
21};
22use reflex::payload::TauqEncoder;
23use reflex::scoring::VerificationResult;
24use reflex::storage::{ArchivedCacheEntry, CacheEntry, StorageWriter};
25use reflex::vectordb::{VectorPoint, generate_point_id};
26
/// Axum handler for `POST /v1/chat/completions`.
///
/// Non-streaming flow: validate the raw JSON, hash the full request for an
/// exact-match key, perform a tiered cache lookup (L1 exact hash, L2 semantic
/// vector search, L3 scorer verification), and on a miss call the upstream
/// provider (or a mock), persist the new entry, then answer.
///
/// Streaming requests (`stream: true`) bypass the cache entirely and are
/// delegated to `handle_streaming_request` (or a mock SSE stream when
/// `state.mock_provider` is set).
///
/// # Errors
///
/// Returns a `GatewayError` variant for schema/validation failures, cache
/// lookup/scoring failures, provider errors, and storage/serialization
/// failures; each maps to an HTTP response via `GatewayError`'s
/// `IntoResponse` impl (defined elsewhere — not visible in this file).
#[instrument(skip(state, request, headers), fields(model = tracing::field::Empty))]
pub async fn chat_completions_handler<B, S>(
    State(state): State<HandlerState<B, S>>,
    headers: HeaderMap,
    Json(request): Json<serde_json::Value>,
) -> Result<Response, GatewayError>
where
    B: BqSearchBackend + Clone + Send + Sync + 'static,
    S: StorageLoader + StorageWriter + Clone + Send + Sync + 'static,
{
    // Reject legacy function-calling fields before typed deserialization so
    // the error message is specific rather than a generic schema failure.
    validate_no_legacy_fields(&request)?;
    let request: CreateChatCompletionRequest = serde_json::from_value(request)
        .map_err(|e| GatewayError::InvalidRequest(format!("Invalid request schema: {}", e)))?;
    tracing::Span::current().record("model", tracing::field::display(&request.model));

    // The bearer token (if present) doubles as the tenant identifier below.
    let _auth_token = headers
        .get("Authorization")
        .and_then(|val| val.to_str().ok())
        .and_then(|s| s.strip_prefix("Bearer "))
        .map(|s| s.trim().to_string());

    // Exact-match cache key: blake3 over the complete serialized request
    // (includes sampling params, unlike the semantic key below).
    let request_bytes = serde_json::to_vec(&request)
        .map_err(|e| GatewayError::InvalidRequest(format!("Serialization failed: {}", e)))?;
    let request_hash = blake3::hash(&request_bytes);
    let request_hash_u64 = reflex::hashing::hash_to_u64(request_hash.as_bytes());

    debug!(hash = %request_hash, "Processing chat completion request");

    // Semantic key: model/messages/tools only — see semantic_text_from_request.
    let semantic_text = semantic_text_from_request(&request);

    // Unauthenticated requests share the "default" tenant bucket.
    let token = _auth_token.unwrap_or_else(|| "default".to_string());
    let tenant_id_hash = reflex::hashing::hash_tenant_id(&token);

    let stream_requested = request.stream.unwrap_or(false);

    if stream_requested {
        debug!("Streaming request received - bypassing cache");
        if state.mock_provider {
            debug!("Mock provider enabled - returning mock streaming response");
            let mock_sse =
                create_mock_streaming_response(request.model.clone(), semantic_text.clone());
            return Ok(mock_sse.into_response());
        }
        let sse = handle_streaming_request::<B>(
            state.genai_client.clone(),
            &request.model,
            request.clone(),
            tenant_id_hash,
            request_hash_u64,
            semantic_text,
        )
        .await?;
        return Ok(sse.into_response());
    }

    // Tiered lookup: L1 by exact hash string, then L2 by semantic query.
    let request_hash_str = request_hash.to_string();
    let tiered_result = state
        .tiered_cache
        .lookup_with_semantic_query(&request_hash_str, &semantic_text, tenant_id_hash)
        .await
        .map_err(|e| GatewayError::CacheLookupFailed(e.to_string()))?;

    let cached_response = match tiered_result {
        TieredLookupResult::HitL1(l1_result) => {
            info!("L1 Cache Hit");
            // Access the rkyv-archived entry in place (no deserialization copy).
            let archived = l1_result
                .handle()
                .access_archived::<ArchivedCacheEntry>()
                .map_err(|e| GatewayError::CacheLookupFailed(e.to_string()))?;

            let raw_payload = String::from_utf8_lossy(&archived.payload_blob);
            match serde_json::from_str::<CachePayload>(&raw_payload) {
                Ok(cache_payload) => Some((cache_payload, ReflexStatus::HitL1Exact)),
                Err(e) => {
                    // A corrupt cached payload degrades to a miss, not a 5xx.
                    tracing::warn!("Failed to parse L1 payload: {}. Treating as miss.", e);
                    None
                }
            }
        }
        TieredLookupResult::HitL2(l2_result) => {
            debug!(
                candidates = l2_result.candidates().len(),
                "L2 semantic hit, verifying..."
            );

            // Parse each candidate's payload, and replace its blob with the
            // original semantic request text so the scorer compares
            // request-to-request; unparseable candidates are dropped.
            let mut valid_candidates = Vec::new();
            for c in l2_result.candidates() {
                let raw_payload = String::from_utf8_lossy(&c.entry.payload_blob);
                if let Ok(payload) = serde_json::from_str::<CachePayload>(&raw_payload) {
                    let mut temp_entry = c.entry.clone();
                    temp_entry.payload_blob = payload.semantic_request.as_bytes().to_vec();
                    valid_candidates.push((temp_entry, c.score, payload));
                }
            }

            let candidates_for_scoring: Vec<(CacheEntry, f32)> = valid_candidates
                .iter()
                .map(|(e, s, _)| (e.clone(), *s))
                .collect();

            // L3: the scorer verifies the candidates against the live query.
            let (verified_entry, verification_result) = state
                .scorer
                .verify_candidates(&semantic_text, candidates_for_scoring)
                .map_err(GatewayError::ScoringFailed)?;

            match verification_result {
                VerificationResult::Verified { score } => {
                    info!(score = score, "L3 verification passed");
                    let entry = verified_entry.ok_or_else(|| {
                        GatewayError::InternalError(
                            "L3 verification returned Verified without an entry".to_string(),
                        )
                    })?;
                    // Map the verified entry back to its parsed payload via
                    // context_hash (assumed unique among candidates).
                    let payload = valid_candidates
                        .iter()
                        .find(|(e, _, _)| e.context_hash == entry.context_hash)
                        .map(|(_, _, p)| p.clone())
                        .ok_or_else(|| {
                            GatewayError::InternalError(
                                "Lost track of verified payload".to_string(),
                            )
                        })?;

                    Some((payload, ReflexStatus::HitL3Verified))
                }
                VerificationResult::Rejected { top_score } => {
                    debug!(score = top_score, "L3 verification rejected");
                    None
                }
                VerificationResult::NoCandidates => {
                    debug!("L3 verification - no candidates");
                    None
                }
            }
        }
        TieredLookupResult::Miss => None,
    };

    if let Some((resp, status)) = cached_response {
        return make_response(resp, status);
    }

    debug!("Cache Miss - Calling Provider");

    let model = request.model.clone();

    let response = if state.mock_provider {
        // Canned completion shaped like the OpenAI schema, for tests/dev.
        let content = format!("Mock response for: {}", semantic_text);
        let response_value = serde_json::json!({
            "id": format!("chatcmpl-{}", uuid::Uuid::new_v4()),
            "object": "chat.completion",
            "created": chrono::Utc::now().timestamp() as u32,
            "model": model.clone(),
            "choices": [{
                "index": 0,
                "message": { "role": "assistant", "content": content },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20
            }
        });

        serde_json::from_value::<CreateChatCompletionResponse>(response_value)
            .map_err(|e| GatewayError::SerializationFailed(e.to_string()))?
    } else {
        // Translate OpenAI schema -> genai, call upstream, translate back.
        let genai_req = crate::gateway::adapter::adapt_openai_to_genai(request.clone());

        let genai_resp = state
            .genai_client
            .exec_chat(&model, genai_req, None)
            .await
            .map_err(|e| {
                // Upstream details are logged but not leaked to the client.
                error!("Provider error: {}", e);
                GatewayError::ProviderError("Upstream service request failed".to_string())
            })?;

        crate::gateway::adapter::adapt_genai_to_openai(genai_resp, model.clone())
    };

    let timestamp = chrono::Utc::now().timestamp();

    // Persist the fresh entry: JSON payload plus the f16 embedding of the
    // semantic key, rkyv-serialized into the storage backend.
    let payload = CachePayload {
        semantic_request: semantic_text.clone(),
        response: response.clone(),
    };
    let payload_json = serde_json::to_string(&payload)
        .map_err(|e| GatewayError::SerializationFailed(e.to_string()))?;

    let embedding_f16 = state
        .tiered_cache
        .l2()
        .embedder()
        .embed(&semantic_text)
        .map_err(|e| GatewayError::EmbeddingFailed(e.to_string()))?;

    // Embeddings are stored little-endian, component by component.
    let embedding_bytes: Vec<u8> = embedding_f16.iter().flat_map(|v| v.to_le_bytes()).collect();

    let cache_entry = CacheEntry {
        tenant_id: tenant_id_hash,
        context_hash: request_hash_u64,
        timestamp,
        embedding: embedding_bytes,
        payload_blob: payload_json.into_bytes(),
    };

    // Storage layout: "<tenant>/<16-hex context hash>.rkyv".
    let entry_id = format!("{:016x}", request_hash_u64);
    let storage_key = format!("{}/{}.rkyv", tenant_id_hash, entry_id);

    let serialized_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&cache_entry)
        .map_err(|e| GatewayError::SerializationFailed(e.to_string()))?;

    // The storage write is blocking, so it is moved off the async runtime.
    // The outer map_err handles task-join failure, the inner one the write.
    let storage = state.tiered_cache.l2().storage().clone();
    let storage_key_for_write = storage_key.clone();
    let mmap_handle = tokio::task::spawn_blocking(move || {
        storage.write(&storage_key_for_write, serialized_bytes.as_ref())
    })
    .await
    .map_err(|e| GatewayError::StorageError(format!("Storage write task failed: {}", e)))?
    .map_err(|e| GatewayError::StorageError(e.to_string()))?;

    let l1_key = request_hash.to_string();
    state
        .tiered_cache
        .insert_l1(&l1_key, tenant_id_hash, mmap_handle);

    // Vector indexing runs in a detached background task; the client response
    // is not delayed by it.
    let embedding_f32: Vec<f32> = embedding_f16.iter().map(|v| v.to_f32()).collect();
    let vector_dim = state.tiered_cache.l2().config().vector_size;
    spawn_index_update(
        state.bq_client.clone(),
        state.collection_name.clone(),
        tenant_id_hash,
        request_hash_u64,
        timestamp,
        embedding_f32,
        storage_key,
        vector_dim,
    );

    make_response(payload, ReflexStatus::Miss)
}
270
271pub(crate) fn make_response(
272    payload: CachePayload,
273    status: ReflexStatus,
274) -> Result<Response, GatewayError> {
275    let payload_json = serde_json::to_value(&payload).unwrap_or_default();
276    let tauq_content = TauqEncoder::encode(&payload_json);
277
278    let message: ChatCompletionResponseMessage = serde_json::from_value(serde_json::json!({
279        "role": "assistant",
280        "content": tauq_content,
281    }))
282    .map_err(|e| GatewayError::SerializationFailed(e.to_string()))?;
283
284    let choice = ChatChoice {
285        index: 0,
286        message,
287        finish_reason: Some(FinishReason::Stop),
288        logprobs: None,
289    };
290
291    let mut wrapper = payload.response;
292    wrapper.object = "chat.completion".to_string();
293    wrapper.choices = vec![choice];
294
295    let mut headers = HeaderMap::new();
296    headers.insert(
297        REFLEX_STATUS_HEADER,
298        HeaderValue::from_static(status.as_header_value()),
299    );
300    Ok((StatusCode::OK, headers, Json(wrapper)).into_response())
301}
302
303pub(crate) fn validate_no_legacy_fields(req: &serde_json::Value) -> Result<(), GatewayError> {
304    if req.get("functions").is_some() || req.get("function_call").is_some() {
305        return Err(GatewayError::InvalidRequest(
306            "Legacy function-calling fields are not supported; use `tools`/`tool_choice`."
307                .to_string(),
308        ));
309    }
310
311    let messages = req
312        .get("messages")
313        .and_then(|m| m.as_array())
314        .ok_or_else(|| GatewayError::InvalidRequest("Missing or invalid `messages`".to_string()))?;
315
316    for m in messages {
317        if m.get("role").and_then(|r| r.as_str()) == Some("function") {
318            return Err(GatewayError::InvalidRequest(
319                "Unsupported message role: `function`.".to_string(),
320            ));
321        }
322        if m.get("role").and_then(|r| r.as_str()) == Some("assistant")
323            && m.get("function_call").is_some()
324        {
325            return Err(GatewayError::InvalidRequest(
326                "Legacy `function_call` on assistant messages is not supported; use `tool_calls`."
327                    .to_string(),
328            ));
329        }
330    }
331
332    Ok(())
333}
334
335/// Builds a semantic key from a chat completion request for cache lookups.
336///
337/// # Semantic Cache Design Decision
338///
339/// This function deliberately extracts only a subset of request fields to form the
340/// cache key. This is an intentional product decision that prioritizes cache hit rates
341/// over exact output distribution matching.
342///
343/// ## Included Parameters (affect semantic meaning)
344///
345/// These parameters are included because they fundamentally change what the model
346/// should produce:
347///
348/// - **`model`**: Different models produce different outputs
349/// - **`messages`**: The conversation context determines the response content
350/// - **`tools`**: Available function definitions change model behavior
351/// - **`tool_choice`**: Controls whether/which tools the model should use
352/// - **`response_format`**: Structured output schemas (e.g., JSON mode) change output shape
353///
354/// ## Excluded Parameters (sampling/generation controls)
355///
356/// These parameters are deliberately omitted from the cache key:
357///
358/// - **`temperature`**: Controls randomness (0.0 = deterministic, 2.0 = creative)
359/// - **`max_tokens`** / **`max_completion_tokens`**: Limits response length
360/// - **`top_p`**: Nucleus sampling threshold
361/// - **`frequency_penalty`**: Penalizes repeated tokens
362/// - **`presence_penalty`**: Penalizes tokens that have appeared at all
363/// - **`n`**: Number of completions to generate
364/// - **`stop`**: Stop sequences
365/// - **`seed`**: Random seed for reproducibility
366/// - **`logprobs`** / **`top_logprobs`**: Log probability settings
367/// - **`logit_bias`**: Token probability adjustments
368/// - **`user`**: End-user identifier (for abuse tracking, not semantics)
369/// - **`stream`**: Delivery mechanism, not content
370/// - **`stream_options`**: Stream configuration
371/// - **`service_tier`**: Infrastructure routing
372/// - **`store`**: Whether to store for fine-tuning
373/// - **`metadata`**: Arbitrary metadata
374/// - **`parallel_tool_calls`**: Execution strategy
375///
376/// ## Tradeoff
377///
378/// **Benefit**: Higher cache hit rates. Requests like "summarize X" with temperature=0.7
379/// will return the same cached response as temperature=0.3, avoiding redundant LLM calls.
380///
381/// **Cost**: The cached response may not reflect the exact output distribution the caller
382/// requested. A request with `max_tokens=100` might return a cached 500-token response.
383/// Callers expecting low-temperature determinism may receive a response generated with
384/// higher temperature.
385///
386/// ## Rationale
387///
388/// For most use cases, the semantic content of the request (what is being asked) matters
389/// more than the sampling parameters (how the response should be generated). Two users
390/// asking the same question with different temperature settings are likely satisfied by
391/// the same high-quality cached answer. This design assumes that cache freshness and
392/// hit rates outweigh perfect sampling parameter fidelity.
393///
394/// If exact parameter matching is required for your use case, consider implementing
395/// a separate cache key strategy or bypassing the semantic cache entirely.
396pub(crate) fn semantic_text_from_request(req: &CreateChatCompletionRequest) -> String {
397    let mut root = serde_json::Map::new();
398    root.insert("model".to_string(), serde_json::json!(req.model));
399    root.insert(
400        "messages".to_string(),
401        serde_json::to_value(&req.messages).unwrap_or_else(|_| serde_json::json!([])),
402    );
403
404    if let Some(tools) = &req.tools {
405        root.insert(
406            "tools".to_string(),
407            serde_json::to_value(tools).unwrap_or_else(|_| serde_json::json!([])),
408        );
409    }
410
411    if let Some(tool_choice) = &req.tool_choice {
412        root.insert(
413            "tool_choice".to_string(),
414            serde_json::to_value(tool_choice).unwrap_or(serde_json::Value::Null),
415        );
416    }
417
418    if let Some(response_format) = &req.response_format {
419        root.insert(
420            "response_format".to_string(),
421            serde_json::to_value(response_format).unwrap_or(serde_json::Value::Null),
422        );
423    }
424
425    serde_json::to_string(&serde_json::Value::Object(root))
426        .unwrap_or_else(|_| format!("model={} messages={}", req.model, req.messages.len()))
427}
428
429#[allow(clippy::too_many_arguments)]
430pub(crate) fn spawn_index_update<B>(
431    bq_client: B,
432    collection_name: String,
433    tenant_id: u64,
434    context_hash: u64,
435    timestamp: i64,
436    vector: Vec<f32>,
437    storage_key: String,
438    vector_dim: u64,
439) -> bool
440where
441    B: BqSearchBackend + Send + Sync + 'static,
442{
443    let point_id = generate_point_id(tenant_id, context_hash);
444
445    let point = VectorPoint {
446        id: point_id,
447        vector,
448        tenant_id,
449        context_hash,
450        timestamp,
451        storage_key: Some(storage_key),
452    };
453
454    tokio::spawn(async move {
455        if let Err(e) = bq_client
456            .ensure_collection(&collection_name, vector_dim)
457            .await
458        {
459            error!(error = %e, "Failed to ensure BQ collection");
460            return;
461        }
462
463        if let Err(e) = bq_client
464            .upsert_points(
465                &collection_name,
466                vec![point],
467                reflex::vectordb::WriteConsistency::Eventual,
468            )
469            .await
470        {
471            error!(error = %e, "Failed to upsert point to BQ index");
472            return;
473        }
474
475        debug!(
476            point_id = point_id,
477            "Successfully indexed point in BQ collection"
478        );
479    });
480
481    true
482}
483
484/// Creates a mock SSE streaming response for testing purposes.
485///
486/// This function generates a simple SSE stream that emits a single mock chunk
487/// followed by a `[DONE]` marker, mimicking the OpenAI streaming response format.
488fn create_mock_streaming_response(
489    model: String,
490    semantic_text: String,
491) -> Sse<impl futures_util::Stream<Item = Result<Event, Infallible>> + Send> {
492    let content = format!("Mock streaming response for: {}", semantic_text);
493
494    let chunk_response = serde_json::json!({
495        "id": format!("chatcmpl-{}", uuid::Uuid::new_v4()),
496        "object": "chat.completion.chunk",
497        "created": chrono::Utc::now().timestamp() as u32,
498        "model": model,
499        "choices": [{
500            "index": 0,
501            "delta": { "role": "assistant", "content": content },
502            "finish_reason": null
503        }]
504    });
505
506    let done_response = serde_json::json!({
507        "id": format!("chatcmpl-{}", uuid::Uuid::new_v4()),
508        "object": "chat.completion.chunk",
509        "created": chrono::Utc::now().timestamp() as u32,
510        "model": model,
511        "choices": [{
512            "index": 0,
513            "delta": {},
514            "finish_reason": "stop"
515        }]
516    });
517
518    let events = vec![
519        Ok(Event::default().data(chunk_response.to_string())),
520        Ok(Event::default().data(done_response.to_string())),
521        Ok(Event::default().data("[DONE]")),
522    ];
523
524    Sse::new(stream::iter(events))
525}