use crate::entities::model_deployment;
use crate::models::{
ChatRequest, EmbeddingRequest, ModelReadyResponse, OCRRequest, Request, V1UserProfile,
};
use crate::query::Query as DbQuery;
use crate::state::AppState;
use crate::state::MessageQueue;
use crate::CONFIG;
use axum::extract::Request as AxumRequest;
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Query as AxumQuery, State,
},
response::IntoResponse,
};
use futures::stream::{SplitSink, SplitStream};
use futures::{sink::SinkExt, stream::StreamExt};
use k8s_openapi::api::apps::v1::Deployment;
use kube::{Api, Client};
use rdkafka::consumer::{Consumer, StreamConsumer};
use rdkafka::producer::{FutureProducer, FutureRecord};
use rdkafka::ClientConfig;
use rdkafka::Message as KafkaMessage;
use rdkafka::TopicPartitionList;
use redis::AsyncCommands;
use sea_orm::DatabaseConnection;
use serde::Serializer;
use serde::{Deserialize, Serialize};
use serde_json::ser::CompactFormatter;
use short_uuid::ShortUuid;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use tracing::{debug, error, info, warn};
/// Query-string parameters accepted by the streaming WebSocket handlers in this file.
///
/// Both fields are optional; the handlers pass them to `validate_model_deployment`
/// and use them to build the queue/stream names.
#[derive(Deserialize, Clone, Debug)]
pub struct StreamParams {
// Name of the serving provider. The value "litellm" is special-cased in this file:
// for that provider the model name is not used when matching deployments.
provider: Option<String>,
// Name of the requested model; handlers fall back to an empty string when absent.
model: Option<String>,
}
/// Returns the model name that should be used when looking up a deployment.
///
/// The `litellm` provider is matched without a model filter, so `None` is returned
/// for it regardless of `model`; every other provider (or no provider) passes
/// `model` through untouched.
async fn get_effective_model<'a>(
    model: Option<&'a str>,
    provider: Option<&str>,
) -> Option<&'a str> {
    match provider {
        Some("litellm") => None,
        _ => model,
    }
}
/// Checks that a deployment matching the request exists and has ready replicas.
///
/// Matching deployments are looked up in the database (the model filter is dropped
/// for the `litellm` provider), the first match is taken, and Kubernetes is then
/// asked for the status of the `Deployment` named by the record's
/// `resource_namespace` / `resource_name`.
///
/// # Errors
///
/// Returns a human-readable message when the database query fails, no deployment
/// matches, the Kubernetes client cannot be created, the status request fails, or
/// the deployment reports no ready replicas.
async fn validate_model_deployment(
    db: &DatabaseConnection,
    model: Option<&str>,
    provider: Option<&str>,
    kind: &str,
) -> Result<model_deployment::Model, String> {
    let lookup_model = get_effective_model(model, provider).await;
    let candidates =
        match DbQuery::find_matching_deployments(db, lookup_model, provider, kind).await {
            Ok(rows) => rows,
            Err(e) => return Err(format!("Database error: {}", e)),
        };

    // Error messages always show the model the caller asked for, not the effective one.
    let requested = model.unwrap_or("");

    let deployment = match candidates.first() {
        Some(found) => found.clone(),
        None => {
            return Err(match provider {
                Some(provider) => format!(
                    "No matching {} model '{}' found for provider '{}'",
                    kind, requested, provider
                ),
                None => format!("No matching {} model '{}' found", kind, requested),
            });
        }
    };

    let k8s_client = match Client::try_default().await {
        Ok(client) => client,
        Err(e) => return Err(format!("Failed to create Kubernetes client: {}", e)),
    };
    let deployments_api: Api<Deployment> =
        Api::namespaced(k8s_client, &deployment.resource_namespace);
    let live_deployment = match deployments_api.get_status(&deployment.resource_name).await {
        Ok(found) => found,
        Err(e) => {
            return Err(format!(
                "Failed to get deployment status from Kubernetes: {}",
                e
            ))
        }
    };

    // A missing status block or a missing replica count both count as zero ready replicas.
    let ready_replicas = match live_deployment.status.as_ref() {
        Some(status) => status.ready_replicas.unwrap_or(0),
        None => 0,
    };

    if ready_replicas > 0 {
        Ok(deployment)
    } else {
        Err(format!("Model '{}' has no ready replicas", requested))
    }
}
/// WebSocket endpoint for chat models.
///
/// Logs the requested model/provider, copies the caller's [`V1UserProfile`] out of the
/// request extensions and upgrades the connection. After the upgrade a
/// `ModelReadyResponse` frame reports whether a matching `ChatModel` deployment is
/// ready; if it is, the socket is handed to the Redis or Kafka bridge selected by
/// `state.message_queue`. If it is not, the handler returns after sending that frame.
///
/// # Panics
///
/// Panics when the request carries no `V1UserProfile` extension.
pub async fn chat_stream_handler(
    AxumQuery(params): AxumQuery<StreamParams>,
    ws: WebSocketUpgrade,
    State(state): State<AppState>,
    request: AxumRequest,
) -> impl IntoResponse {
    info!(
        model = ?params.model,
        provider = ?params.provider,
        "Received WebSocket connection request"
    );

    let user_profile = request
        .extensions()
        .get::<V1UserProfile>()
        .cloned()
        .expect("User profile should be present after middleware");

    ws.on_upgrade(move |socket| async move {
        let (mut sender, receiver) = socket.split();

        // Validate first, then describe the outcome in a single readiness frame.
        let validation = validate_model_deployment(
            &state.db_pool,
            params.model.as_deref(),
            params.provider.as_deref(),
            "ChatModel",
        )
        .await;
        let failure = validation.err();
        let ready_response = ModelReadyResponse {
            response_type: "ModelReadyResponse".to_string(),
            request_id: ShortUuid::generate().to_string(),
            ready: failure.is_none(),
            error: failure,
        };

        let ready_frame = serde_json::to_string(&ready_response).unwrap();
        if let Err(e) = sender.send(Message::Text(ready_frame)).await {
            error!("Failed to send ready response: {}", e);
            return;
        }

        if !ready_response.ready {
            return;
        }

        let model_name = params.model.clone().unwrap_or_default();
        match &state.message_queue {
            MessageQueue::Redis { client } => {
                handle_socket_redis::<ChatRequest>(
                    (sender, receiver),
                    client.clone(),
                    "ChatModel".to_string(),
                    model_name,
                    params,
                    user_profile,
                )
                .await
            }
            MessageQueue::Kafka { producer, .. } => {
                handle_socket_kafka::<ChatRequest>(
                    (sender, receiver),
                    producer.clone(),
                    model_name,
                    params,
                    user_profile,
                )
                .await
            }
        }
    })
}
/// WebSocket endpoint for embedding models.
///
/// Logs the requested model/provider, copies the caller's [`V1UserProfile`] out of the
/// request extensions and upgrades the connection. After the upgrade a
/// `ModelReadyResponse` frame reports whether a matching `EmbeddingModel` deployment is
/// ready; if it is, the socket is handed to the Redis or Kafka bridge selected by
/// `state.message_queue`. If it is not, the handler returns after sending that frame.
///
/// # Panics
///
/// Panics when the request carries no `V1UserProfile` extension.
pub async fn embedding_stream_handler(
    AxumQuery(params): AxumQuery<StreamParams>,
    ws: WebSocketUpgrade,
    State(state): State<AppState>,
    request: AxumRequest,
) -> impl IntoResponse {
    info!(
        model = ?params.model,
        provider = ?params.provider,
        "Received WebSocket connection request"
    );

    let user_profile = request
        .extensions()
        .get::<V1UserProfile>()
        .cloned()
        .expect("User profile should be present after middleware");

    ws.on_upgrade(move |socket| async move {
        let (mut sender, receiver) = socket.split();

        // Validate first, then describe the outcome in a single readiness frame.
        let validation = validate_model_deployment(
            &state.db_pool,
            params.model.as_deref(),
            params.provider.as_deref(),
            "EmbeddingModel",
        )
        .await;
        let failure = validation.err();
        let ready_response = ModelReadyResponse {
            response_type: "ModelReadyResponse".to_string(),
            request_id: ShortUuid::generate().to_string(),
            ready: failure.is_none(),
            error: failure,
        };

        let ready_frame = serde_json::to_string(&ready_response).unwrap();
        if let Err(e) = sender.send(Message::Text(ready_frame)).await {
            error!("Failed to send ready response: {}", e);
            return;
        }

        if !ready_response.ready {
            return;
        }

        let model_name = params.model.clone().unwrap_or_default();
        match &state.message_queue {
            MessageQueue::Redis { client } => {
                handle_socket_redis::<EmbeddingRequest>(
                    (sender, receiver),
                    client.clone(),
                    "EmbeddingModel".to_string(),
                    model_name,
                    params,
                    user_profile,
                )
                .await
            }
            MessageQueue::Kafka { producer, .. } => {
                handle_socket_kafka::<EmbeddingRequest>(
                    (sender, receiver),
                    producer.clone(),
                    model_name,
                    params,
                    user_profile,
                )
                .await
            }
        }
    })
}
/// WebSocket endpoint for OCR models.
///
/// Logs the requested model/provider, copies the caller's [`V1UserProfile`] out of the
/// request extensions and upgrades the connection. After the upgrade a
/// `ModelReadyResponse` frame reports whether a matching `OCRModel` deployment is
/// ready; if it is, the socket is handed to the Redis or Kafka bridge selected by
/// `state.message_queue`. If it is not, the handler returns after sending that frame.
///
/// # Panics
///
/// Panics when the request carries no `V1UserProfile` extension.
pub async fn ocr_stream_handler(
    AxumQuery(params): AxumQuery<StreamParams>,
    ws: WebSocketUpgrade,
    State(state): State<AppState>,
    request: AxumRequest,
) -> impl IntoResponse {
    info!(
        model = ?params.model,
        provider = ?params.provider,
        "Received WebSocket connection request"
    );

    let user_profile = request
        .extensions()
        .get::<V1UserProfile>()
        .cloned()
        .expect("User profile should be present after middleware");

    ws.on_upgrade(move |socket| async move {
        let (mut sender, receiver) = socket.split();

        // Validate first, then describe the outcome in a single readiness frame.
        let validation = validate_model_deployment(
            &state.db_pool,
            params.model.as_deref(),
            params.provider.as_deref(),
            "OCRModel",
        )
        .await;
        let failure = validation.err();
        let ready_response = ModelReadyResponse {
            response_type: "ModelReadyResponse".to_string(),
            request_id: ShortUuid::generate().to_string(),
            ready: failure.is_none(),
            error: failure,
        };

        let ready_frame = serde_json::to_string(&ready_response).unwrap();
        if let Err(e) = sender.send(Message::Text(ready_frame)).await {
            error!("Failed to send ready response: {}", e);
            return;
        }

        if !ready_response.ready {
            return;
        }

        let model_name = params.model.clone().unwrap_or_default();
        match &state.message_queue {
            MessageQueue::Redis { client } => {
                handle_socket_redis::<OCRRequest>(
                    (sender, receiver),
                    client.clone(),
                    "OCRModel".to_string(),
                    model_name,
                    params,
                    user_profile,
                )
                .await
            }
            MessageQueue::Kafka { producer, .. } => {
                handle_socket_kafka::<OCRRequest>(
                    (sender, receiver),
                    producer.clone(),
                    model_name,
                    params,
                    user_profile,
                )
                .await
            }
        }
    })
}
async fn handle_socket_kafka<T>(
socket: (SplitSink<WebSocket, Message>, SplitStream<WebSocket>),
producer: Arc<FutureProducer>,
name: String,
params: StreamParams,
user_profile: V1UserProfile,
) where
T: Request + Serialize + for<'de> Deserialize<'de> + Send,
{
let params = params.clone();
let provider = params.provider.clone();
let (mut sender, mut receiver) = socket;
let partition_id = {
let mut hasher = DefaultHasher::new();
user_profile.email.hash(&mut hasher);
(hasher.finish() % 1000) as i32
};
let output_topic = if !provider.as_ref().map_or(false, |s| !s.is_empty()) {
format!("{}.{}-output", provider.unwrap(), name)
} else {
format!("{}-output", name)
};
let consumer: StreamConsumer = ClientConfig::new()
.set("bootstrap.servers", CONFIG.kafka_bootstrap_servers.as_str())
.set("enable.auto.commit", "false")
.create()
.expect("Consumer creation failed");
let mut topic_partition = TopicPartitionList::new();
topic_partition.add_partition(&output_topic, partition_id);
consumer
.assign(&topic_partition)
.expect("Topic partition assignment failed");
let (tx, mut rx) = tokio::sync::mpsc::channel(100);
let mut kafka_task = tokio::spawn(async move {
loop {
match consumer.recv().await {
Ok(message) => {
if let Some(payload) = message.payload() {
if let Ok(text) = String::from_utf8(payload.to_vec()) {
if tx.send(text).await.is_err() {
break;
}
}
}
}
Err(e) => {
eprintln!("Error receiving Kafka message: {}", e);
break;
}
}
}
});
let mut send_task = tokio::spawn(async move {
while let Some(text) = rx.recv().await {
if sender.send(Message::Text(text)).await.is_err() {
break;
}
}
});
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
match msg {
Message::Text(text) => {
match serde_json::from_str::<T>(&text) {
Ok(mut request) => {
request.set_user_id(user_profile.email.clone());
request.set_output_topic(output_topic.clone());
request.set_request_id(ShortUuid::generate().to_string());
let validated_json = serde_json::to_string(&request)
.expect("Failed to serialize validated request");
let backend = params.provider.clone();
let topic = if let Some(ref b) = backend {
if !b.is_empty() {
format!("{}.{}", b, name.clone())
} else {
name.clone()
}
} else {
name.clone()
};
if let Err((e, _)) = producer
.send::<str, [u8], _>(
FutureRecord::to(&topic)
.payload(validated_json.as_bytes())
.partition(partition_id),
std::time::Duration::from_secs(5),
)
.await
{
eprintln!("Failed to deliver message: {:?}", e);
}
}
Err(e) => {
eprintln!("Invalid message format: {}", e);
}
}
}
Message::Close(_) => break,
_ => {} }
}
});
tokio::select! {
_ = &mut kafka_task => {
send_task.abort();
recv_task.abort();
},
_ = &mut send_task => {
kafka_task.abort();
recv_task.abort();
},
_ = &mut recv_task => {
kafka_task.abort();
send_task.abort();
}
}
println!("WebSocket connection closed");
}
async fn handle_socket_redis<T>(
socket: (SplitSink<WebSocket, Message>, SplitStream<WebSocket>),
redis: Arc<redis::Client>,
kind: String,
name: String,
params: StreamParams,
user_profile: V1UserProfile,
) where
T: Request + Serialize + for<'de> Deserialize<'de> + Send,
{
info!(
model = %name,
provider = ?params.provider,
user = %user_profile.email,
"Starting Redis WebSocket handler"
);
let effective_name = if params.provider == Some("litellm".to_string()) {
"".to_string()
} else {
name.clone()
};
let (mut sender, mut receiver) = socket;
let provider = params.provider.clone();
let model = params.model.clone();
let session_id = ShortUuid::generate().to_string();
let input_stream = format!(
"{}:in:{}:{}",
kind.clone(),
provider.clone().unwrap_or_default(),
effective_name.clone(),
);
let output_stream = format!(
"{}:out:{}:{}:{}:{}",
kind,
provider.unwrap_or_default(),
model.unwrap_or_default(),
user_profile.email,
session_id,
);
info!(
input_stream = %input_stream,
output_stream = %output_stream,
"Redis streams created"
);
let redis_conn = match redis.get_multiplexed_async_connection().await {
Ok(conn) => conn,
Err(e) => {
error!("Failed to connect to Redis: {}", e);
return;
}
};
info!("Redis connection established");
let (tx, mut rx) = tokio::sync::mpsc::channel(100);
let mut redis_consumer = redis_conn.clone();
let consumer_stream = output_stream.clone();
let mut last_id = "0".to_string();
let mut redis_task = tokio::spawn(async move {
info!("Starting Redis consumer task");
loop {
debug!(last_id = %last_id, "Reading from Redis stream '{}'", consumer_stream);
let streams = vec![consumer_stream.as_str()];
let ids = vec![last_id.as_str()];
let result: Result<redis::streams::StreamReadReply, _> = redis_consumer
.xread_options(
&streams,
&ids,
&redis::streams::StreamReadOptions::default().block(1000),
)
.await;
match result {
Ok(reply) => {
if reply.keys.is_empty() {
debug!("No new messages in stream");
continue;
}
for stream in reply.keys {
for message in stream.ids {
let id = message.id.clone();
last_id = id.clone();
debug!(message_id = %id, "Processing Redis message");
if let Some(payload) = message.map.get("message") {
let text = payload.clone();
debug!("Forwarding message to WebSocket sender");
if tx.send(text).await.is_err() {
warn!("WebSocket sender dropped, stopping consumer task");
return;
}
} else {
warn!("Message payload missing 'message' field");
}
}
}
}
Err(e) => {
error!(
"Error reading from Redis stream '{}': {}",
consumer_stream, e
);
return;
}
}
}
});
let mut send_task = tokio::spawn(async move {
info!("Starting WebSocket sender task");
while let Some(text) = rx.recv().await {
info!("Received message to send to WebSocket: {:?}", text);
let text_string = match text {
redis::Value::BulkString(bytes) => match String::from_utf8(bytes) {
Ok(s) => s,
Err(e) => {
error!("Failed to convert Redis value bytes to string: {}", e);
continue;
}
},
redis::Value::SimpleString(s) => s,
redis::Value::VerbatimString { text, .. } => text,
_ => {
error!("Unexpected Redis value type: {:?}", text);
continue;
}
};
if let Err(e) = sender.send(Message::Text(text_string)).await {
error!("Failed to send message through WebSocket: {}", e);
break;
}
debug!("Successfully sent message through WebSocket");
}
info!("WebSocket sender task ended");
});
let mut redis_producer = redis_conn.clone();
let producer_stream = input_stream.clone();
let mut recv_task = tokio::spawn(async move {
info!("Starting WebSocket receiver task");
while let Some(Ok(msg)) = receiver.next().await {
match msg {
Message::Text(text) => {
info!("Received WebSocket message: {}", text);
match serde_json::from_str::<T>(&text) {
Ok(mut request) => {
debug!("Successfully parsed request");
request.set_organizations(
user_profile.organizations.clone().unwrap_or_default(),
);
request.set_handle(user_profile.handle.clone().unwrap_or_default());
request.set_user_id(user_profile.email.clone());
request.set_output_topic(output_stream.clone());
request.set_request_id(ShortUuid::generate().to_string());
let validated_json = {
let mut writer = Vec::new();
let formatter = CompactFormatter;
let mut ser =
serde_json::Serializer::with_formatter(&mut writer, formatter);
ser.serialize_some(&request)
.expect("Failed to serialize request");
String::from_utf8(writer).expect("Failed to convert to string")
};
info!("Adding message to Redis stream: {}", producer_stream);
match redis_producer
.xadd::<String, &str, &str, String, String>(
producer_stream.clone(),
"*",
&[("message", validated_json.clone())],
)
.await
{
Ok(id) => info!(
"Successfully added message to Redis stream '{}' with ID: {}",
producer_stream, id
),
Err(e) => error!(
"Failed to add message to Redis stream '{}': {}",
producer_stream, e
),
}
}
Err(e) => {
error!("Invalid message format: {}", e);
}
}
}
Message::Close(_) => {
info!("Received WebSocket close message");
break;
}
other => {
warn!("Received unexpected WebSocket message: {:?}", other);
}
}
}
info!("WebSocket receiver task ended");
});
info!("All tasks spawned, waiting for completion");
tokio::select! {
_ = &mut redis_task => {
warn!("Redis consumer task exited");
send_task.abort();
recv_task.abort();
},
_ = &mut send_task => {
warn!("WebSocket sender task exited");
redis_task.abort();
recv_task.abort();
},
_ = &mut recv_task => {
warn!("WebSocket receiver task exited");
redis_task.abort();
send_task.abort();
}
}
info!("WebSocket connection closed");
}