#![allow(dead_code, unused_imports, unused_variables)]
use crate::{
InfernoError,
api::openai::{
ChatChunkChoice, ChatCompletionChunk, ChatCompletionRequest, ChatDelta, ChatMessage,
},
backends::{Backend, InferenceParams},
cli::serve::ServerState,
streaming::{StreamingConfig, StreamingManager},
upgrade::{ApplicationVersion, UpgradeEvent, UpgradeStatus},
};
use axum::{
extract::{
State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::Response,
};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
/// Envelope for every JSON frame exchanged on the streaming WebSocket.
///
/// Frames are internally tagged: the variant name in snake_case is carried in a
/// `"type"` field alongside the variant's own fields, e.g.
/// `{"type":"upgrade_check_request","id":"42","force":false}`.
///
/// `handle_ws_message` accepts `chat_request`, `upgrade_check_request` and
/// `upgrade_install_request` from clients and rejects every other type.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WSMessage {
    /// `chat_request`: asks for a streamed chat completion; `id` tags every reply frame.
    ChatRequest {
        id: String,
        data: ChatCompletionRequest,
    },
    /// `chat_chunk`: one streamed completion chunk for the request `id`.
    ChatChunk {
        id: String,
        data: ChatCompletionChunk,
    },
    /// `error`: a failure report; `id` is the related request when one is known.
    Error {
        id: Option<String>,
        message: String,
        code: String,
    },
    /// `heartbeat`: periodic keep-alive carrying the current number of active streams.
    Heartbeat {
        timestamp: chrono::DateTime<chrono::Utc>,
        active_streams: usize,
    },
    /// `stream_metrics`: aggregate streaming statistics.
    StreamMetrics {
        active_streams: usize,
        total_tokens: u64,
        average_latency: f32,
    },
    /// `connection_info`: greeting sent once a connection is established.
    ConnectionInfo {
        connection_id: String,
        server_version: String,
        capabilities: Vec<String>,
    },
    /// `upgrade_status`: result of an upgrade check or installation.
    UpgradeStatus {
        status: UpgradeStatus,
        current_version: ApplicationVersion,
    },
    /// `upgrade_event`: an upgrade-subsystem event notification.
    UpgradeEvent { event: UpgradeEvent },
    /// `upgrade_check_request`: asks the server to look for a newer version.
    UpgradeCheckRequest { id: String, force: bool },
    /// `upgrade_install_request`: asks the server to install an update
    /// (optionally pinned to `version`).
    UpgradeInstallRequest {
        id: String,
        version: Option<String>,
        auto_backup: bool,
    },
}
pub async fn websocket_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<ServerState>>,
) -> Response {
info!("New WebSocket connection for streaming");
ws.on_upgrade(move |socket| handle_websocket(socket, state))
}
async fn handle_websocket(socket: WebSocket, state: Arc<ServerState>) {
let connection_id = Uuid::new_v4().to_string();
info!("WebSocket connection established: {}", connection_id);
let streaming_config = StreamingConfig {
max_concurrent_streams: 5, enable_metrics: true,
heartbeat_interval_ms: 30000, ..Default::default()
};
let streaming_manager = Arc::new(StreamingManager::new(streaming_config));
if let Err(e) = streaming_manager.start().await {
error!("Failed to start streaming manager: {}", e);
return;
}
let (sender, mut receiver) = socket.split();
let sender = Arc::new(Mutex::new(sender));
let connection_info = WSMessage::ConnectionInfo {
connection_id: connection_id.clone(),
server_version: std::env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "0.1.0".to_string()),
capabilities: vec![
"streaming_chat".to_string(),
"real_time_metrics".to_string(),
"heartbeat".to_string(),
"upgrade_notifications".to_string(),
"upgrade_management".to_string(),
],
};
if let Ok(msg) = serde_json::to_string(&connection_info) {
if sender.lock().await.send(Message::Text(msg)).await.is_err() {
return;
}
}
let heartbeat_sender = sender.clone();
let heartbeat_manager = streaming_manager.clone();
let heartbeat_connection_id = connection_id.clone();
let heartbeat_handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
loop {
interval.tick().await;
let metrics = heartbeat_manager.get_metrics();
let heartbeat = WSMessage::Heartbeat {
timestamp: chrono::Utc::now(),
active_streams: metrics.active_streams,
};
if let Ok(msg) = serde_json::to_string(&heartbeat) {
if heartbeat_sender
.lock()
.await
.send(Message::Text(msg))
.await
.is_err()
{
debug!(
"Heartbeat failed for connection: {}",
heartbeat_connection_id
);
break;
}
}
}
});
while let Some(msg) = receiver.next().await {
match msg {
Ok(Message::Text(text)) => match serde_json::from_str::<WSMessage>(&text) {
Ok(ws_message) => {
if let Err(e) = handle_ws_message(
ws_message,
&state,
&streaming_manager,
&sender,
&connection_id,
)
.await
{
error!("Error handling WebSocket message: {}", e);
let error_msg = WSMessage::Error {
id: None,
message: format!("Message handling failed: {}", e),
code: "INTERNAL_ERROR".to_string(),
};
if let Ok(error_json) = serde_json::to_string(&error_msg) {
let _ = sender.lock().await.send(Message::Text(error_json)).await;
}
}
}
Err(e) => {
warn!("Invalid WebSocket message format: {}", e);
let error_msg = WSMessage::Error {
id: None,
message: format!("Invalid message format: {}", e),
code: "INVALID_FORMAT".to_string(),
};
if let Ok(error_json) = serde_json::to_string(&error_msg) {
let _ = sender.lock().await.send(Message::Text(error_json)).await;
}
}
},
Ok(Message::Binary(_)) => {
warn!("Binary messages not supported in streaming WebSocket");
}
Ok(Message::Ping(data)) => {
debug!("Received ping, sending pong");
let _ = sender.lock().await.send(Message::Pong(data)).await;
}
Ok(Message::Pong(_)) => {
debug!("Received pong");
}
Ok(Message::Close(_)) => {
info!("WebSocket connection closed: {}", connection_id);
break;
}
Err(e) => {
error!("WebSocket error: {}", e);
break;
}
}
}
heartbeat_handle.abort();
info!("WebSocket connection handler finished: {}", connection_id);
}
async fn handle_ws_message(
message: WSMessage,
state: &Arc<ServerState>,
streaming_manager: &Arc<StreamingManager>,
sender: &Arc<Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
connection_id: &str,
) -> Result<(), InfernoError> {
match message {
WSMessage::ChatRequest { id, data } => {
info!(
"Processing chat request {} for connection {}",
id, connection_id
);
let backend = get_or_load_backend_for_ws(state, &data.model).await?;
let prompt = format_chat_messages(&data.messages);
let inference_params = InferenceParams {
max_tokens: data.max_tokens,
temperature: data.temperature,
top_k: 40,
top_p: data.top_p,
stream: true, stop_sequences: data.stop.unwrap_or_default(),
seed: None,
};
let mut stream = streaming_manager
.create_enhanced_stream(&mut *backend.lock().await, &prompt, &inference_params)
.await
.map_err(|e| InfernoError::WebSocket(format!("Stream creation failed: {}", e)))?;
let sender_clone = sender.clone();
let request_id = id.clone();
let model_name = data.model.clone();
tokio::spawn(async move {
let initial_chunk = ChatCompletionChunk {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created: chrono::Utc::now().timestamp(),
model: model_name.clone(),
choices: vec![ChatChunkChoice {
index: 0,
delta: ChatDelta {
role: Some("assistant".to_string()),
content: None,
},
finish_reason: None,
}],
};
let initial_ws_msg = WSMessage::ChatChunk {
id: request_id.clone(),
data: initial_chunk,
};
if let Ok(msg_json) = serde_json::to_string(&initial_ws_msg) {
let _ = sender_clone
.lock()
.await
.send(Message::Text(msg_json))
.await;
}
while let Some(token_result) = stream.next().await {
match token_result {
Ok(streaming_token) => {
if !streaming_token.is_heartbeat() {
let chunk = ChatCompletionChunk {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created: chrono::Utc::now().timestamp(),
model: model_name.clone(),
choices: vec![ChatChunkChoice {
index: 0,
delta: ChatDelta {
role: None,
content: Some(streaming_token.content),
},
finish_reason: None,
}],
};
let ws_msg = WSMessage::ChatChunk {
id: request_id.clone(),
data: chunk,
};
if let Ok(msg_json) = serde_json::to_string(&ws_msg) {
if sender_clone
.lock()
.await
.send(Message::Text(msg_json))
.await
.is_err()
{
break;
}
}
}
}
Err(e) => {
error!("Streaming error: {}", e);
let error_msg = WSMessage::Error {
id: Some(request_id.clone()),
message: format!("Streaming failed: {}", e),
code: "STREAM_ERROR".to_string(),
};
if let Ok(error_json) = serde_json::to_string(&error_msg) {
let _ = sender_clone
.lock()
.await
.send(Message::Text(error_json))
.await;
}
break;
}
}
}
let final_chunk = ChatCompletionChunk {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created: chrono::Utc::now().timestamp(),
model: model_name,
choices: vec![ChatChunkChoice {
index: 0,
delta: ChatDelta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()),
}],
};
let final_ws_msg = WSMessage::ChatChunk {
id: request_id,
data: final_chunk,
};
if let Ok(msg_json) = serde_json::to_string(&final_ws_msg) {
let _ = sender_clone
.lock()
.await
.send(Message::Text(msg_json))
.await;
}
});
Ok(())
}
WSMessage::UpgradeCheckRequest { id, force } => {
info!(
"Processing upgrade check request {} for connection {}",
id, connection_id
);
let upgrade_manager = match &state.upgrade_manager {
Some(manager) => manager.clone(),
None => {
let error_msg = WSMessage::Error {
id: Some(id),
message: "Upgrade system not initialized".to_string(),
code: "UPGRADE_NOT_AVAILABLE".to_string(),
};
send_ws_message(sender, &error_msg).await?;
return Ok(());
}
};
let sender_clone = sender.clone();
let manager_clone = upgrade_manager.clone();
tokio::spawn(async move {
match manager_clone.check_for_updates().await {
Ok(update_info) => {
let status = if let Some(update) = update_info {
crate::upgrade::UpgradeStatus::Available(update)
} else {
crate::upgrade::UpgradeStatus::UpToDate
};
let status_msg = WSMessage::UpgradeStatus {
status,
current_version: ApplicationVersion::current(),
};
let _ = send_ws_message(&sender_clone, &status_msg).await;
}
Err(e) => {
let error_msg = WSMessage::Error {
id: Some(id),
message: format!("Update check failed: {}", e),
code: "UPDATE_CHECK_FAILED".to_string(),
};
let _ = send_ws_message(&sender_clone, &error_msg).await;
}
}
});
Ok(())
}
WSMessage::UpgradeInstallRequest {
id,
version,
auto_backup,
} => {
info!(
"Processing upgrade install request {} for connection {}",
id, connection_id
);
let upgrade_manager = match &state.upgrade_manager {
Some(manager) => manager.clone(),
None => {
let error_msg = WSMessage::Error {
id: Some(id),
message: "Upgrade system not initialized".to_string(),
code: "UPGRADE_NOT_AVAILABLE".to_string(),
};
send_ws_message(sender, &error_msg).await?;
return Ok(());
}
};
let sender_clone = sender.clone();
let manager_clone = upgrade_manager.clone();
tokio::spawn(async move {
match manager_clone.check_for_updates().await {
Ok(Some(update_info)) => {
if let Some(requested_version) = version {
if update_info.version.to_string() != requested_version {
let error_msg = WSMessage::Error {
id: Some(id),
message: format!(
"Requested version {} not available",
requested_version
),
code: "VERSION_NOT_FOUND".to_string(),
};
let _ = send_ws_message(&sender_clone, &error_msg).await;
return;
}
}
match manager_clone.install_update(&update_info).await {
Ok(_) => {
let status_msg = WSMessage::UpgradeStatus {
status: crate::upgrade::UpgradeStatus::Completed {
old_version: ApplicationVersion::current(),
new_version: update_info.version,
restart_required: true,
},
current_version: ApplicationVersion::current(),
};
let _ = send_ws_message(&sender_clone, &status_msg).await;
}
Err(e) => {
let error_msg = WSMessage::Error {
id: Some(id),
message: format!("Installation failed: {}", e),
code: "INSTALLATION_FAILED".to_string(),
};
let _ = send_ws_message(&sender_clone, &error_msg).await;
}
}
}
Ok(None) => {
let error_msg = WSMessage::Error {
id: Some(id),
message: "No updates available".to_string(),
code: "NO_UPDATES_AVAILABLE".to_string(),
};
let _ = send_ws_message(&sender_clone, &error_msg).await;
}
Err(e) => {
let error_msg = WSMessage::Error {
id: Some(id),
message: format!("Update check failed: {}", e),
code: "UPDATE_CHECK_FAILED".to_string(),
};
let _ = send_ws_message(&sender_clone, &error_msg).await;
}
}
});
Ok(())
}
_ => {
warn!("Unsupported WebSocket message type");
Err(InfernoError::WebSocket(
"Unsupported message type".to_string(),
))
}
}
}
async fn get_or_load_backend_for_ws(
state: &Arc<ServerState>,
model_name: &str,
) -> Result<Arc<tokio::sync::Mutex<Backend>>, InfernoError> {
if let Some(ref _distributed) = state.distributed {
return Err(InfernoError::WebSocket(
"WebSocket streaming not supported with distributed inference yet".to_string(),
));
}
if let Some(ref loaded_model) = state.loaded_model {
if loaded_model == model_name {
if let Some(ref backend) = state.backend {
return Ok(backend.inner().clone());
}
}
}
let model_info = state
.model_manager
.resolve_model(model_name)
.await
.map_err(|e| InfernoError::WebSocket(format!("Model resolution failed: {}", e)))?;
let backend_type =
crate::backends::BackendType::from_model_path(&model_info.path).ok_or_else(|| {
InfernoError::WebSocket(format!(
"No suitable backend found for model: {}",
model_info.path.display()
))
})?;
let mut backend = Backend::new(backend_type, &state.config.backend_config)
.map_err(|e| InfernoError::WebSocket(format!("Backend creation failed: {}", e)))?;
backend
.load_model(&model_info)
.await
.map_err(|e| InfernoError::WebSocket(format!("Model loading failed: {}", e)))?;
Ok(Arc::new(tokio::sync::Mutex::new(backend)))
}
/// Flattens chat messages into a plain-text prompt, one `role: content` line per message.
///
/// Lines are separated by `\n` with no trailing newline; an empty slice yields `""`.
fn format_chat_messages(messages: &[ChatMessage]) -> String {
    let mut prompt = String::new();
    for (position, message) in messages.iter().enumerate() {
        if position > 0 {
            prompt.push('\n');
        }
        prompt.push_str(&format!("{}: {}", message.role, message.content));
    }
    prompt
}
/// Serializes `message` to JSON and sends it as one WebSocket text frame.
///
/// The shared sink is locked for the duration of the send. Every writer on a
/// connection therefore takes turns on it: the heartbeat task, chat stream tasks,
/// and request replies.
///
/// # Errors
///
/// Returns [`InfernoError::WebSocket`] if serialization fails or the frame cannot be sent.
async fn send_ws_message(
    sender: &Arc<Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
    message: &WSMessage,
) -> Result<(), InfernoError> {
    let payload = match serde_json::to_string(message) {
        Ok(text) => text,
        Err(err) => {
            return Err(InfernoError::WebSocket(format!(
                "Failed to serialize message: {}",
                err
            )))
        }
    };

    let mut sink = sender.lock().await;
    sink.send(Message::Text(payload))
        .await
        .map_err(|err| InfernoError::WebSocket(format!("Failed to send message: {}", err)))?;
    Ok(())
}