reflex_server/gateway/
streaming.rs1use async_openai::types::chat::{
2 ChatChoiceStream, ChatCompletionStreamResponseDelta, CreateChatCompletionRequest,
3 CreateChatCompletionStreamResponse,
4};
5use axum::response::sse::{Event, Sse};
6use futures_util::stream::{Stream, StreamExt};
7use genai::Client;
8use genai::chat::ChatStreamEvent;
9use std::convert::Infallible;
10use std::sync::Arc;
11use tokio::sync::Mutex;
12use tracing::error;
13
14use crate::gateway::adapter::adapt_openai_to_genai;
15use crate::gateway::error::GatewayError;
16use reflex::cache::BqSearchBackend;
17
18pub async fn handle_streaming_request<B>(
47 client: Client,
48 model: &str,
49 request: CreateChatCompletionRequest,
50 _tenant_id_hash: u64,
51 _context_hash: u64,
52 _semantic_text: String,
53) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>> + Send + 'static>, GatewayError>
54where
55 B: BqSearchBackend + Clone + Send + Sync + 'static,
56{
57 let genai_req = adapt_openai_to_genai(request.clone());
58 let model_owned = model.to_string();
59
60 let chat_stream_resp = client
61 .exec_chat_stream(&model_owned, genai_req, None)
62 .await
63 .map_err(|e| {
64 error!("Provider stream init error: {}", e);
65 GatewayError::ProviderError("Upstream service stream init failed".to_string())
66 })?;
67
68 let stream = chat_stream_resp.stream;
69
70 let _accumulated_content = Arc::new(Mutex::new(String::new()));
71
72 let event_stream = stream.map(move |result| match result {
73 Ok(ChatStreamEvent::Start) => Ok(Event::default().comment("start")),
74 Ok(ChatStreamEvent::Chunk(chunk)) => {
75 let text = chunk.content;
76 if !text.is_empty() {
77 let delta: ChatCompletionStreamResponseDelta = match serde_json::from_value(
78 serde_json::json!({ "role": "assistant", "content": text }),
79 ) {
80 Ok(d) => d,
81 Err(e) => {
82 error!("Failed to construct delta: {}", e);
83 return Ok(Event::default().comment("delta-error"));
84 }
85 };
86
87 let response: CreateChatCompletionStreamResponse =
88 match serde_json::from_value(serde_json::json!({
89 "id": format!("chatcmpl-{}", uuid::Uuid::new_v4()),
90 "object": "chat.completion.chunk",
91 "created": chrono::Utc::now().timestamp() as u32,
92 "model": model_owned.clone(),
93 "choices": vec![ChatChoiceStream {
94 index: 0,
95 delta,
96 finish_reason: None,
97 logprobs: None,
98 }],
99 "usage": serde_json::Value::Null,
100 })) {
101 Ok(r) => r,
102 Err(e) => {
103 error!("Failed to construct streaming response: {}", e);
104 return Ok(Event::default().comment("delta-error"));
105 }
106 };
107
108 match serde_json::to_string(&response) {
109 Ok(json) => Ok(Event::default().data(json)),
110 Err(e) => {
111 error!("Failed to serialize response: {}", e);
112 Ok(Event::default().comment("serialization-error"))
113 }
114 }
115 } else {
116 Ok(Event::default().comment("keep-alive"))
117 }
118 }
119 Ok(ChatStreamEvent::End(_)) => Ok(Event::default().data("[DONE]")),
120 Ok(_) => Ok(Event::default().comment("ignored-event")),
121 Err(e) => {
122 error!("Stream error: {}", e);
123 Ok(Event::default()
124 .event("error")
125 .data("Stream interrupted by upstream error"))
126 }
127 });
128
129 Ok(Sse::new(event_stream))
130}