use crate::core::models::openai::*;
use crate::utils::error::{GatewayError, Result};
use actix_web::http::header::{CACHE_CONTROL, CONTENT_TYPE};
use actix_web::{HttpResponse, web};
use futures::stream::{Stream, StreamExt};
use serde_json::json;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::error;
use uuid::Uuid;
/// A single Server-Sent Events (SSE) message: an optional event name plus a
/// data payload. Built with the chained `event`/`data` setters and turned into
/// wire bytes by `to_bytes`.
#[derive(Debug, Clone, Default)]
pub struct Event {
/// Optional event name; when present, `to_bytes` writes it on an `event:` line.
pub event: Option<String>,
/// Payload that `to_bytes` writes after the `data: ` prefix.
pub data: String,
}
impl Event {
    /// Creates an empty event: no event name and an empty data payload.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the event name and returns the updated event (builder style).
    pub fn event(mut self, event: &str) -> Self {
        self.event = Some(event.to_string());
        self
    }

    /// Replaces the data payload and returns the updated event (builder style).
    pub fn data(mut self, data: &str) -> Self {
        self.data = data.to_string();
        self
    }

    /// Serializes the event into SSE wire format.
    ///
    /// The output is an optional `event: <name>` line, one `data: <line>` line
    /// per line of the payload, and a terminating blank line. For a payload
    /// without line breaks this is exactly `data: <payload>\n\n`.
    ///
    /// A payload containing `\n`, `\r\n` or a lone `\r` is split so that every
    /// line carries its own `data: ` prefix; writing it verbatim would produce
    /// a frame whose continuation lines are not data fields and are dropped by
    /// SSE clients. Line breaks inside the event name are removed because a
    /// field value cannot span lines.
    pub fn to_bytes(&self) -> web::Bytes {
        let mut result = String::with_capacity(self.data.len() + 32);

        if let Some(event) = &self.event {
            result.push_str("event: ");
            result.extend(event.chars().filter(|c| !matches!(c, '\r' | '\n')));
            result.push('\n');
        }

        // Split on "\n" first, drop the "\r" of a "\r\n" pair, then split on any
        // remaining lone "\r": all three are line terminators in an event stream.
        let lines = self
            .data
            .split('\n')
            .flat_map(|segment| segment.strip_suffix('\r').unwrap_or(segment).split('\r'));
        for line in lines {
            result.push_str("data: ");
            result.push_str(line);
            result.push('\n');
        }

        // The blank line dispatches the event on the client side.
        result.push('\n');
        web::Bytes::from(result)
    }
}
/// One streamed chat-completion chunk. In this file it is serialized to JSON
/// and sent as the payload of an SSE `data:` line.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ChatCompletionChunk {
/// Completion id; the handler in this file reuses its `chatcmpl-<uuid>` request id for every chunk.
pub id: String,
/// Object type tag; this file always sets it to `"chat.completion.chunk"`.
pub object: String,
/// Creation time in seconds since the Unix epoch (filled from `chrono::Utc::now().timestamp()` here).
pub created: u64,
/// Model name the chunk is attributed to.
pub model: String,
/// Always `None` when built in this file.
/// NOTE(review): presumably mirrors the upstream field of the same name — confirm.
pub system_fingerprint: Option<String>,
/// Choices carried by this chunk; this file always emits exactly one, at index 0.
pub choices: Vec<ChatCompletionChunkChoice>,
/// Token usage; in this file only the final chunk from `create_final_chunk` sets it.
pub usage: Option<Usage>,
}
/// A single choice inside a [`ChatCompletionChunk`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ChatCompletionChunkChoice {
/// Position of this choice; this file always uses 0.
pub index: u32,
/// Incremental message content carried by this chunk.
pub delta: ChatCompletionDelta,
/// `None` on content chunks; the final chunk built in this file sets `"stop"`.
pub finish_reason: Option<String>,
/// Untyped log-probability payload; always `None` when built in this file.
pub logprobs: Option<serde_json::Value>,
}
/// The incremental part of a message delivered by one chunk.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ChatCompletionDelta {
/// Message role; `process_chunk` sets it to `Assistant` on the first emitted chunk only.
pub role: Option<MessageRole>,
/// Text fragment for this chunk; `None` on the final chunk built in this file.
pub content: Option<String>,
/// Tool-call fragments; always `None` when built in this file.
pub tool_calls: Option<Vec<ToolCallDelta>>,
}
/// Incremental fragment of a tool call. Not constructed anywhere in this file.
/// NOTE(review): field meanings below are inferred from the names — confirm
/// against the upstream API this mirrors.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ToolCallDelta {
/// Position of the tool call this fragment belongs to (presumably within the message's tool-call list).
pub index: u32,
/// Tool-call id, when the fragment carries one.
pub id: Option<String>,
/// Serialized under the JSON key `type` (`type` is a Rust keyword, hence the rename).
#[serde(rename = "type")]
pub tool_type: Option<String>,
/// Function-call fragment, when the fragment carries one.
pub function: Option<FunctionCallDelta>,
}
/// Incremental fragment of a function call inside a [`ToolCallDelta`].
/// Not constructed anywhere in this file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FunctionCallDelta {
/// Function name, when the fragment carries it.
pub name: Option<String>,
/// Fragment of the argument string.
/// NOTE(review): presumably partial JSON to be concatenated by the client — confirm.
pub arguments: Option<String>,
}
/// Per-request state used to turn a provider's text stream into
/// OpenAI-style chat-completion chunks delivered over SSE.
pub struct StreamingHandler {
// Id of the form `chatcmpl-<uuid v4>`, stamped on every emitted chunk.
request_id: String,
// Model name echoed in every chunk and used for the token estimates.
model: String,
// True until the first content chunk has been built; controls whether the role is emitted.
is_first_chunk: bool,
// All content forwarded so far; used to count completion tokens for the final chunk.
accumulated_content: String,
// Captured in `new`. NOTE(review): never read in the code visible here — confirm whether it is still needed.
start_time: std::time::Instant,
}
impl StreamingHandler {
/// Creates a handler for one streaming request.
///
/// Generates a fresh `chatcmpl-<uuid v4>` request id, records the start
/// instant, and starts with no accumulated content and the "first chunk"
/// flag set.
pub fn new(model: String) -> Self {
    let request_id = format!("chatcmpl-{}", Uuid::new_v4());
    let accumulated_content = String::new();
    let start_time = std::time::Instant::now();

    Self {
        request_id,
        model,
        is_first_chunk: true,
        accumulated_content,
        start_time,
    }
}
pub fn create_sse_stream<S>(
mut self,
provider_stream: S,
) -> impl Stream<Item = Result<web::Bytes>>
where
S: Stream<Item = Result<String>> + Send + 'static,
{
let (tx, rx) = mpsc::channel(100);
tokio::spawn(async move {
tokio::pin!(provider_stream);
while let Some(chunk_result) = provider_stream.next().await {
match chunk_result {
Ok(chunk_data) => {
match self.process_chunk(&chunk_data).await {
Ok(Some(event)) => {
if tx.send(Ok(event.to_bytes())).await.is_err() {
break;
}
}
Ok(None) => continue, Err(e) => {
error!("Error processing chunk: {}", e);
let error_event = Event::default()
.event("error")
.data(&json!({"error": e.to_string()}).to_string());
let _ = tx.send(Ok(error_event.to_bytes())).await;
break;
}
}
}
Err(e) => {
error!("Provider stream error: {}", e);
let error_event = Event::default()
.event("error")
.data(&json!({"error": e.to_string()}).to_string());
let _ = tx.send(Ok(error_event.to_bytes())).await;
break;
}
}
}
if let Ok(final_event) = self.create_final_chunk().await {
let _ = tx.send(Ok(final_event.to_bytes())).await;
}
let done_event = Event::default().data("[DONE]");
let _ = tx.send(Ok(done_event.to_bytes())).await;
});
ReceiverStream::new(rx)
}
/// Turns one provider chunk into an SSE event carrying a chat-completion chunk.
///
/// Returns `Ok(None)` when the provider chunk yields no text. Otherwise the
/// text is appended to the accumulated content and wrapped in a single-choice
/// `chat.completion.chunk`; the assistant role is attached only to the first
/// chunk emitted by this handler.
///
/// # Errors
/// Propagates failures from content extraction and from JSON serialization.
async fn process_chunk(&mut self, chunk_data: &str) -> Result<Option<Event>> {
    let content = self.extract_content_from_chunk(chunk_data)?;
    if content.is_empty() {
        return Ok(None);
    }
    self.accumulated_content.push_str(&content);

    // Only the first emitted chunk announces the role.
    let role = self.is_first_chunk.then_some(MessageRole::Assistant);
    let delta = ChatCompletionDelta {
        role,
        content: Some(content),
        tool_calls: None,
    };
    let choice = ChatCompletionChunkChoice {
        index: 0,
        delta,
        finish_reason: None,
        logprobs: None,
    };
    let chunk = ChatCompletionChunk {
        id: self.request_id.clone(),
        object: "chat.completion.chunk".to_string(),
        created: chrono::Utc::now().timestamp() as u64,
        model: self.model.clone(),
        system_fingerprint: None,
        choices: vec![choice],
        usage: None,
    };
    self.is_first_chunk = false;

    let payload = serde_json::to_string(&chunk)?;
    Ok(Some(Event::default().data(&payload)))
}
fn extract_content_from_chunk(&self, chunk_data: &str) -> Result<String> {
if chunk_data.starts_with("data: ") {
let data = chunk_data.strip_prefix("data: ").unwrap_or(chunk_data);
if data.trim() == "[DONE]" {
return Ok(String::new());
}
if let Ok(json_chunk) = serde_json::from_str::<serde_json::Value>(data) {
if let Some(choices) = json_chunk.get("choices").and_then(|c| c.as_array()) {
if let Some(choice) = choices.first() {
if let Some(delta) = choice.get("delta") {
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
return Ok(content.to_string());
}
}
}
}
if let Some(delta) = json_chunk.get("delta") {
if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
return Ok(text.to_string());
}
}
if let Some(text) = json_chunk.get("text").and_then(|t| t.as_str()) {
return Ok(text.to_string());
}
}
}
Ok(chunk_data.to_string())
}
/// Builds the closing chunk of a stream: an empty delta with
/// `finish_reason: "stop"` and a usage block.
///
/// Completion tokens come from the project token counter applied to the
/// accumulated content, falling back to `estimate_token_count` when the
/// counter fails. Prompt tokens are a fixed per-model-family guess from
/// `estimate_prompt_tokens`, not a measurement.
///
/// # Errors
/// Propagates JSON serialization failures.
async fn create_final_chunk(&self) -> Result<Event> {
    let counter = crate::utils::TokenCounter::new();
    // NOTE(review): the completion count is read from `input_tokens` of the
    // returned estimate — confirm that is the intended field.
    let completion_tokens = match counter
        .count_completion_tokens(&self.model, &self.accumulated_content)
    {
        Ok(estimate) => estimate.input_tokens,
        Err(_) => self.estimate_token_count(&self.accumulated_content),
    };
    let prompt_tokens = self.estimate_prompt_tokens();

    let usage = Usage {
        prompt_tokens,
        completion_tokens,
        total_tokens: prompt_tokens + completion_tokens,
        prompt_tokens_details: None,
        completion_tokens_details: None,
    };

    let closing_choice = ChatCompletionChunkChoice {
        index: 0,
        delta: ChatCompletionDelta {
            role: None,
            content: None,
            tool_calls: None,
        },
        finish_reason: Some("stop".to_string()),
        logprobs: None,
    };

    let final_chunk = ChatCompletionChunk {
        id: self.request_id.clone(),
        object: "chat.completion.chunk".to_string(),
        created: chrono::Utc::now().timestamp() as u64,
        model: self.model.clone(),
        system_fingerprint: None,
        choices: vec![closing_choice],
        usage: Some(usage),
    };

    let payload = serde_json::to_string(&final_chunk)?;
    Ok(Event::default().data(&payload))
}
/// Rough token estimate: one token per four bytes of `text`, rounded up.
///
/// The result saturates at `u32::MAX` for inputs too large to represent.
fn estimate_token_count(&self, text: &str) -> u32 {
    let byte_len = text.len();
    // Ceiling division by 4 without going through floating point.
    let quarters = byte_len / 4 + usize::from(byte_len % 4 != 0);
    u32::try_from(quarters).unwrap_or(u32::MAX)
}
/// Returns a fixed prompt-token guess chosen by model family.
///
/// The model name is matched by substring, first match wins, in the order
/// `gpt-4`, `gpt-3.5`, `claude`, `gemini`; anything else gets 100.
fn estimate_prompt_tokens(&self) -> u32 {
    // Ordered: earlier entries take precedence when several markers match.
    const FAMILY_ESTIMATES: [(&str, u32); 4] = [
        ("gpt-4", 150),
        ("gpt-3.5", 100),
        ("claude", 200),
        ("gemini", 120),
    ];
    const FALLBACK_ESTIMATE: u32 = 100;

    FAMILY_ESTIMATES
        .iter()
        .find(|(marker, _)| self.model.contains(*marker))
        .map_or(FALLBACK_ESTIMATE, |&(_, tokens)| tokens)
}
}
/// Wraps a byte stream in a `200 OK` response configured for Server-Sent
/// Events: `Content-Type: text/event-stream`, `Cache-Control: no-cache` and
/// `Connection: keep-alive`.
pub fn create_sse_response<S>(stream: S) -> HttpResponse
where
    S: Stream<Item = Result<web::Bytes>> + Send + 'static,
{
    let mut response = HttpResponse::Ok();
    response.insert_header((CONTENT_TYPE, "text/event-stream"));
    response.insert_header((CACHE_CONTROL, "no-cache"));
    response.insert_header(("Connection", "keep-alive"));
    response.streaming(stream)
}
pub mod providers {
use super::*;
use futures::stream::BoxStream;
pub struct OpenAIStreaming;
impl OpenAIStreaming {
pub fn create_stream(response: reqwest::Response) -> BoxStream<'static, Result<String>> {
let stream = response.bytes_stream().map(|chunk_result| {
chunk_result
.map_err(|e| GatewayError::Network(e.to_string()))
.and_then(|chunk| {
String::from_utf8(chunk.to_vec())
.map_err(|e| GatewayError::Parsing(e.to_string()))
})
});
Box::pin(stream)
}
}
pub struct AnthropicStreaming;
impl AnthropicStreaming {
pub fn create_stream(response: reqwest::Response) -> BoxStream<'static, Result<String>> {
let stream = response.bytes_stream().map(|chunk_result| {
chunk_result
.map_err(|e| GatewayError::network(e.to_string()))
.and_then(|chunk| {
String::from_utf8(chunk.to_vec())
.map_err(|e| GatewayError::internal(format!("Parsing error: {}", e)))
})
});
Box::pin(stream)
}
}
pub struct GenericStreaming;
impl GenericStreaming {
pub fn create_stream(response: reqwest::Response) -> BoxStream<'static, Result<String>> {
let stream = response.bytes_stream().map(|chunk_result| {
chunk_result
.map_err(|e| GatewayError::network(e.to_string()))
.and_then(|chunk| {
String::from_utf8(chunk.to_vec())
.map_err(|e| GatewayError::internal(format!("Parsing error: {}", e)))
})
});
Box::pin(stream)
}
}
}
pub mod utils {
use super::*;
/// Returns the payload of an SSE data line.
///
/// Yields `Some(rest)` when `line` starts with exactly `data: ` (including
/// the space), and `None` for any other line.
pub fn parse_sse_line(line: &str) -> Option<String> {
    line.strip_prefix("data: ").map(String::from)
}
/// Reports whether `line` is the stream terminator, either `data: [DONE]`
/// or bare `[DONE]`, ignoring surrounding whitespace.
pub fn is_done_line(line: &str) -> bool {
    matches!(line.trim(), "data: [DONE]" | "[DONE]")
}
pub fn create_error_event(error: &str) -> Event {
Event::default()
.event("error")
.data(&json!({"error": error}).to_string())
}
/// Builds an SSE event named `heartbeat` with the data `ping`.
pub fn create_heartbeat_event() -> Event {
    let heartbeat = Event::default().event("heartbeat");
    heartbeat.data("ping")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new handler keeps the model name and starts in its initial state.
    #[test]
    fn test_streaming_handler_creation() {
        let handler = StreamingHandler::new("gpt-4".to_string());

        assert_eq!(handler.model, "gpt-4");
        assert!(handler.is_first_chunk);
        assert!(handler.accumulated_content.is_empty());
    }

    /// Content is extracted from OpenAI-style and Anthropic-style events, and
    /// the terminator yields no content.
    #[test]
    fn test_extract_content_from_chunk() {
        let handler = StreamingHandler::new("gpt-4".to_string());

        let openai_chunk = r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#;
        let openai_content = handler.extract_content_from_chunk(openai_chunk).unwrap();
        assert_eq!(openai_content, "Hello");

        let anthropic_chunk = r#"data: {"delta":{"text":"World"}}"#;
        let anthropic_content = handler.extract_content_from_chunk(anthropic_chunk).unwrap();
        assert_eq!(anthropic_content, "World");

        let done_chunk = "data: [DONE]";
        let done_content = handler.extract_content_from_chunk(done_chunk).unwrap();
        assert!(done_content.is_empty());
    }

    /// The fallback estimate is one token per four bytes, rounded up.
    #[test]
    fn test_token_estimation() {
        let handler = StreamingHandler::new("gpt-4".to_string());

        // 11 bytes -> ceil(11 / 4) = 3
        let text = "Hello world";
        let tokens = handler.estimate_token_count(text);
        assert_eq!(tokens, 3);

        // 33 bytes -> ceil(33 / 4) = 9
        let longer_text = "This is a longer text for testing";
        let tokens = handler.estimate_token_count(longer_text);
        assert_eq!(tokens, 9);
    }

    /// Line parsing, terminator detection and error-event construction.
    #[tokio::test]
    async fn test_sse_utils() {
        let line = "data: Hello World";
        let parsed = utils::parse_sse_line(line);
        assert_eq!(parsed, Some("Hello World".to_string()));

        assert!(utils::is_done_line("data: [DONE]"));
        assert!(utils::is_done_line("[DONE]"));
        assert!(!utils::is_done_line("data: Hello"));

        // Only checks that construction does not panic.
        let _error_event = utils::create_error_event("Test error");
    }
}