use crate::db::runs::TrainingMetrics;
use axum::extract::ws::{Message, WebSocket};
use futures::{SinkExt, StreamExt};
use serde::Serialize;
use tokio::sync::broadcast;
/// Envelope for a metrics update sent to a WebSocket client as JSON.
#[derive(Debug, Clone, Serialize)]
pub struct MetricsMessage {
/// Message discriminator; serialized under the JSON key `type`.
/// `MetricsStreamer` sets it to `"metrics"`.
#[serde(rename = "type")]
pub msg_type: String,
/// The metrics payload carried by this message.
pub data: MetricsData,
}
/// Payload of a [`MetricsMessage`].
///
/// `epoch`, `step` and `timestamp` are always serialized; each optional field
/// is left out of the JSON entirely (rather than emitted as `null`) when it is
/// `None`.
#[derive(Debug, Clone, Serialize)]
pub struct MetricsData {
/// Epoch number copied from the source `TrainingMetrics`.
pub epoch: u32,
/// Step number copied from the source `TrainingMetrics`.
pub step: u32,
/// Loss value; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub loss: Option<f64>,
/// Accuracy value; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub accuracy: Option<f64>,
/// Presumably the learning rate (inferred from the name; confirm against the
/// producer). Omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub lr: Option<f64>,
/// Presumably GPU utilization; the unit/range is not visible here. Omitted
/// from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub gpu_util: Option<f64>,
/// Presumably memory usage in megabytes (inferred from the name). Omitted
/// from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub memory_mb: Option<f64>,
/// Timestamp as a string; `MetricsStreamer` fills it with the RFC 3339
/// rendering of the source metrics' timestamp.
pub timestamp: String,
}
/// Envelope for a status update sent to a WebSocket client as JSON.
#[derive(Debug, Clone, Serialize)]
pub struct StatusMessage {
/// Message discriminator; serialized under the JSON key `type`.
/// `MetricsStreamer` sets it to `"status"`.
#[serde(rename = "type")]
pub msg_type: String,
/// The status payload carried by this message.
pub data: StatusData,
}
/// Payload of a [`StatusMessage`].
#[derive(Debug, Clone, Serialize)]
pub struct StatusData {
/// Status label, e.g. `"completed"` (the value `MetricsStreamer::stream`
/// sends when the metrics channel closes).
pub status: String,
/// Completion time as a string; omitted from the JSON when `None`.
/// `MetricsStreamer::stream` fills it with an RFC 3339 UTC timestamp.
#[serde(skip_serializing_if = "Option::is_none")]
pub completed_at: Option<String>,
}
pub struct MetricsStreamer;
impl MetricsStreamer {
pub async fn stream(socket: WebSocket, mut receiver: broadcast::Receiver<TrainingMetrics>) {
let (mut sender, mut ws_receiver) = socket.split();
let recv_handle = tokio::spawn(async move {
while let Some(msg) = ws_receiver.next().await {
if let Ok(msg) = msg {
if matches!(msg, Message::Close(_)) {
break;
}
} else {
break;
}
}
});
loop {
if recv_handle.is_finished() {
break;
}
match receiver.recv().await {
Ok(metrics) => {
let message = MetricsMessage {
msg_type: "metrics".to_string(),
data: MetricsData {
epoch: metrics.epoch,
step: metrics.step,
loss: metrics.loss,
accuracy: metrics.accuracy,
lr: metrics.lr,
gpu_util: metrics.gpu_util,
memory_mb: metrics.memory_mb,
timestamp: metrics.timestamp.to_rfc3339(),
},
};
let json = serde_json::to_string(&message).unwrap_or_default();
if sender.send(Message::Text(json)).await.is_err() {
break;
}
}
Err(broadcast::error::RecvError::Closed) => {
let status = StatusMessage {
msg_type: "status".to_string(),
data: StatusData {
status: "completed".to_string(),
completed_at: Some(chrono::Utc::now().to_rfc3339()),
},
};
let json = serde_json::to_string(&status).unwrap_or_default();
let _ = sender.send(Message::Text(json)).await;
break;
}
Err(broadcast::error::RecvError::Lagged(_)) => {
continue;
}
}
}
recv_handle.abort();
}
pub fn format_metrics(metrics: &TrainingMetrics) -> String {
let message = MetricsMessage {
msg_type: "metrics".to_string(),
data: MetricsData {
epoch: metrics.epoch,
step: metrics.step,
loss: metrics.loss,
accuracy: metrics.accuracy,
lr: metrics.lr,
gpu_util: metrics.gpu_util,
memory_mb: metrics.memory_mb,
timestamp: metrics.timestamp.to_rfc3339(),
},
};
serde_json::to_string(&message).unwrap_or_default()
}
pub fn format_status(status: &str, completed_at: Option<&str>) -> String {
let message = StatusMessage {
msg_type: "status".to_string(),
data: StatusData {
status: status.to_string(),
completed_at: completed_at.map(String::from),
},
};
serde_json::to_string(&message).unwrap_or_default()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses formatter output so assertions check structure rather than the
    /// exact byte layout produced by the serializer.
    fn parse(json: &str) -> serde_json::Value {
        serde_json::from_str(json).expect("formatter output should be valid JSON")
    }

    #[test]
    fn test_format_metrics() {
        let metrics = TrainingMetrics {
            epoch: 5,
            step: 1000,
            loss: Some(0.5),
            accuracy: Some(0.9),
            lr: Some(0.001),
            gpu_util: None,
            memory_mb: None,
            custom: serde_json::json!({}),
            timestamp: chrono::Utc::now(),
        };

        let value = parse(&MetricsStreamer::format_metrics(&metrics));
        assert_eq!(value["type"], "metrics");

        let data = &value["data"];
        assert_eq!(data["epoch"], 5);
        assert_eq!(data["step"], 1000);
        assert_eq!(data["loss"], 0.5);
        assert!(data.get("accuracy").is_some());
        assert!(data.get("lr").is_some());
        assert!(data["timestamp"].is_string());

        // `None` values must be omitted, not serialized as `null`.
        assert!(data.get("gpu_util").is_none());
        assert!(data.get("memory_mb").is_none());
    }

    #[test]
    fn test_format_status() {
        let value = parse(&MetricsStreamer::format_status(
            "completed",
            Some("2024-01-15T10:30:00Z"),
        ));
        assert_eq!(value["type"], "status");
        assert_eq!(value["data"]["status"], "completed");
        assert_eq!(value["data"]["completed_at"], "2024-01-15T10:30:00Z");
    }

    #[test]
    fn test_format_status_omits_missing_completed_at() {
        let value = parse(&MetricsStreamer::format_status("running", None));
        assert_eq!(value["type"], "status");
        assert_eq!(value["data"]["status"], "running");
        assert!(value["data"].get("completed_at").is_none());
    }
}