use std::sync::Arc;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::extract::State;
use axum::response::Response;
use serde::{Deserialize, Serialize};
use crate::state::AppState;
/// Chat-completion request sent by the client as the first (text) frame of a
/// WebSocket session.
#[derive(Deserialize, Debug)]
pub struct WsRequest {
    /// Optional model identifier; `None` when the client omits the field.
    pub model: Option<String>,
    /// Conversation so far; this field is required.
    pub messages: Vec<WsMessage>,
    /// Token budget for the reply; falls back to `default_max_tokens()` (512).
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    /// Temperature value; falls back to `default_temperature()` (0.7).
    #[serde(default = "default_temperature")]
    pub temperature: f32,
}
/// One chat message: who said it (e.g. `"user"`) and what was said.
#[derive(Serialize, Deserialize, Debug)]
pub struct WsMessage {
    /// Speaker role string, passed through verbatim.
    pub role: String,
    /// Message text.
    pub content: String,
}
/// Serde default for [`WsRequest::max_tokens`] when the client omits it.
fn default_max_tokens() -> u32 {
    const FALLBACK_BUDGET: u32 = 512;
    FALLBACK_BUDGET
}
/// Serde default for [`WsRequest::temperature`] when the client omits it.
fn default_temperature() -> f32 {
    const FALLBACK_TEMPERATURE: f32 = 0.7;
    FALLBACK_TEMPERATURE
}
/// Server-to-client event, serialized as internally tagged JSON whose `type`
/// field is the snake_case variant name, e.g. `{"type":"token","delta":"Hi"}`.
#[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum WsEvent {
    /// One incremental chunk of reply text.
    Token { delta: String },
    /// Final event of a successful reply: why it ended plus token accounting.
    Done {
        finish_reason: String,
        usage: UsageSummary,
    },
    /// Sent just before the server ends a session it could not serve.
    Error { message: String },
}
/// Token accounting carried by a `done` event.
#[derive(Serialize, Debug)]
pub struct UsageSummary {
    /// Number of tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens streamed back in the reply.
    pub completion_tokens: u32,
}
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> Response {
ws.on_upgrade(move |socket| handle_socket(socket, state))
}
/// Tokens emitted by the placeholder generator until a real model backend is
/// wired in.
const STUB_TOKENS: &[&str] = &["Hello", " from", " OxiLLaMa", " via", " WebSocket"];

/// Picks the stub tokens to stream for a `max_tokens` budget and the
/// `finish_reason` to report.
///
/// Returns `"length"` when the budget cuts the canned reply short and
/// `"stop"` when the whole reply fits.
fn stub_completion(max_tokens: u32) -> (&'static [&'static str], &'static str) {
    // Saturate instead of panicking on targets where `usize` is narrower than `u32`.
    let budget = usize::try_from(max_tokens).unwrap_or(usize::MAX);
    if budget < STUB_TOKENS.len() {
        (&STUB_TOKENS[..budget], "length")
    } else {
        (STUB_TOKENS, "stop")
    }
}

/// Serves one WebSocket session.
///
/// Reads a single JSON [`WsRequest`], then streams the stub reply as `token`
/// events (at most `max_tokens` of them) and finishes with a `done` event.
/// An unparsable request is answered with an `error` event. If a send fails,
/// for example because the client disconnected, the session ends without `done`.
async fn handle_socket(mut socket: WebSocket, _state: Arc<AppState>) {
    let Some(text) = receive_text(&mut socket).await else {
        return;
    };
    let req: WsRequest = match serde_json::from_str(&text) {
        Ok(r) => r,
        Err(e) => {
            send_error(&mut socket, &format!("Invalid JSON request: {e}")).await;
            return;
        }
    };

    // The stub accepts these fields but does not act on them yet; reading
    // them here makes that explicit.
    let _ = (&req.model, &req.messages, req.temperature);

    let (tokens, finish_reason) = stub_completion(req.max_tokens);
    let mut sent = 0u32;
    for token in tokens {
        let event = WsEvent::Token {
            delta: (*token).to_string(),
        };
        if !send_event(&mut socket, &event).await {
            return;
        }
        sent += 1;
    }

    let done = WsEvent::Done {
        finish_reason: finish_reason.to_string(),
        usage: UsageSummary {
            prompt_tokens: 0,
            completion_tokens: sent,
        },
    };
    // Best-effort: the session ends here whether or not the client receives it.
    let _ = send_event(&mut socket, &done).await;
}
/// Waits for the client's first data frame and returns its text.
///
/// Ping/Pong control frames are skipped; axum replies to pings on its own.
/// Returns `None` when the peer closes or the stream ends. A non-text data
/// frame or a receive error is reported to the client as an `error` event
/// (best-effort), and `None` is returned.
async fn receive_text(socket: &mut WebSocket) -> Option<String> {
    loop {
        match socket.recv().await {
            Some(Ok(Message::Text(t))) => return Some(t.to_string()),
            // Control frames may precede the request; they are not the
            // request, so keep waiting instead of rejecting the session.
            Some(Ok(Message::Ping(_) | Message::Pong(_))) => continue,
            Some(Ok(Message::Close(_))) | None => return None,
            Some(Ok(_)) => {
                send_error(socket, "Expected a JSON text frame as the first message").await;
                return None;
            }
            Some(Err(e)) => {
                send_error(socket, &format!("WebSocket receive error: {e}")).await;
                return None;
            }
        }
    }
}
/// Serializes `event` to JSON and sends it as a single text frame.
///
/// Returns `true` only if serialization succeeded and the send reported success.
async fn send_event(socket: &mut WebSocket, event: &WsEvent) -> bool {
    let Ok(payload) = serde_json::to_string(event) else {
        return false;
    };
    socket.send(Message::Text(payload.into())).await.is_ok()
}
async fn send_error(socket: &mut WebSocket, message: &str) {
let event = WsEvent::Error {
message: message.to_string(),
};
if let Ok(json) = serde_json::to_string(&event) {
let _ = socket.send(Message::Text(json.into())).await;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Serializes `event` the same way `send_event` does (`to_string`) and
    /// parses it back. Assertions can then compare the whole JSON structure,
    /// including field names, nesting and the absence of extra keys, rather
    /// than substrings.
    fn to_value(event: &WsEvent) -> Value {
        let text = serde_json::to_string(event).expect("serialize failed");
        serde_json::from_str(&text).expect("serialized event is not valid JSON")
    }

    #[test]
    fn ws_event_token_serializes_correctly() {
        let event = WsEvent::Token {
            delta: "hello".into(),
        };
        assert_eq!(to_value(&event), json!({"type": "token", "delta": "hello"}));
    }

    #[test]
    fn ws_event_done_serializes_correctly() {
        let event = WsEvent::Done {
            finish_reason: "stop".into(),
            usage: UsageSummary {
                prompt_tokens: 5,
                completion_tokens: 10,
            },
        };
        assert_eq!(
            to_value(&event),
            json!({
                "type": "done",
                "finish_reason": "stop",
                "usage": {"prompt_tokens": 5, "completion_tokens": 10}
            })
        );
    }

    #[test]
    fn ws_event_error_serializes_correctly() {
        let event = WsEvent::Error {
            message: "oops".into(),
        };
        assert_eq!(to_value(&event), json!({"type": "error", "message": "oops"}));
    }

    #[test]
    fn ws_request_deserializes_with_defaults() {
        let json = r#"{"messages": [{"role": "user", "content": "hello"}]}"#;
        let req: WsRequest = serde_json::from_str(json).expect("deserialize failed");
        assert_eq!(req.max_tokens, 512);
        assert!((req.temperature - 0.7).abs() < 0.001);
        assert!(req.model.is_none());
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, "user");
        assert_eq!(req.messages[0].content, "hello");
    }

    #[test]
    fn ws_request_deserializes_explicit_fields() {
        let json = r#"{
            "model": "local",
            "messages": [{"role": "user", "content": "hi"}],
            "max_tokens": 128,
            "temperature": 0.5
        }"#;
        let req: WsRequest = serde_json::from_str(json).expect("deserialize failed");
        assert_eq!(req.model.as_deref(), Some("local"));
        assert_eq!(req.max_tokens, 128);
        assert!((req.temperature - 0.5).abs() < 0.001);
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].content, "hi");
    }

    #[test]
    fn ws_request_requires_messages() {
        let json = r#"{"model": "local"}"#;
        assert!(serde_json::from_str::<WsRequest>(json).is_err());
    }
}