use crate::api::AppState;
use crate::auth::{AuthError, AuthUser};
use crate::db::runs::{NewTrainingRun, RunConfig, RunRepository, RunStatus};
use crate::training::websocket::MetricsStreamer;
use axum::{
Json,
extract::{
Path, Query, State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
http::StatusCode,
response::IntoResponse,
};
use chrono::Utc;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Query-string parameters accepted by [`list_runs`].
#[derive(Debug, Deserialize)]
pub struct ListRunsQuery {
    /// Optional status filter such as `running`; omitted means "no filter".
    #[serde(default)]
    pub status: Option<String>,
    /// Page size forwarded to the repository; [`default_limit`] supplies the
    /// value when the parameter is omitted.
    #[serde(default = "default_limit")]
    pub limit: u32,
    /// Number of runs to skip, forwarded to the repository; 0 when omitted.
    #[serde(default)]
    pub offset: u32,
}
/// Serde default for [`ListRunsQuery::limit`] when the client omits `limit`.
fn default_limit() -> u32 {
    const DEFAULT_PAGE_SIZE: u32 = 100;
    DEFAULT_PAGE_SIZE
}
/// JSON body accepted by [`create_run`].
///
/// `model_version_id` and `dataset_id` may be omitted; they are stored as-is
/// on the new run.
#[derive(Debug, Deserialize)]
pub struct CreateRunRequest {
    pub name: String,
    pub model_type: String,
    #[serde(default)]
    pub model_version_id: Option<String>,
    #[serde(default)]
    pub dataset_id: Option<String>,
    /// Hyper-parameters; see [`RunConfigRequest`].
    pub config: RunConfigRequest,
}
/// Run configuration as submitted by the client.
///
/// JSON keys that do not match one of the named fields are collected into
/// `extra` (via `#[serde(flatten)]`) and carried through to the stored config.
#[derive(Debug, Deserialize)]
pub struct RunConfigRequest {
    pub epochs: u32,
    pub batch_size: u32,
    pub learning_rate: f64,
    /// Optional; `create_run` substitutes 100 when absent.
    #[serde(default)]
    pub steps_per_epoch: Option<u32>,
    /// Empty when omitted; `create_run` replaces an empty value with `"adam"`.
    #[serde(default)]
    pub optimizer: String,
    /// All remaining, unrecognised keys from the `config` object.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
/// JSON representation of a training run returned by the run handlers.
///
/// Timestamps are RFC 3339 strings. `None` optionals are omitted from the
/// serialized output rather than emitted as `null`.
#[derive(Debug, Serialize)]
pub struct RunResponse {
    pub id: String,
    pub user_id: String,
    pub name: String,
    pub model_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_version_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dataset_id: Option<String>,
    /// Lower-case status name, e.g. `running` or `completed`.
    pub status: String,
    /// The stored run configuration rendered as JSON.
    pub config: serde_json::Value,
    /// Most recent metrics snapshot, if any has been recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latest_metrics: Option<MetricsResponse>,
    pub started_at: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<String>,
    pub created_at: String,
}
/// One metrics snapshot of a run, as sent to API clients.
///
/// Optional measurements are omitted from the JSON when absent; `custom`
/// holds arbitrary user-defined values and defaults to `null` when missing
/// on input. `timestamp` is an RFC 3339 string.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MetricsResponse {
    pub epoch: u32,
    pub step: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub loss: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub accuracy: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lr: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gpu_util: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub memory_mb: Option<f64>,
    #[serde(default)]
    pub custom: serde_json::Value,
    pub timestamp: String,
}
/// JSON body accepted by [`record_metrics`].
///
/// Only `epoch` and `step` are required; every measurement may be omitted.
/// `custom` defaults to `null`.
#[derive(Debug, Deserialize)]
pub struct RecordMetricsRequest {
    pub epoch: u32,
    pub step: u32,
    #[serde(default)]
    pub loss: Option<f64>,
    #[serde(default)]
    pub accuracy: Option<f64>,
    #[serde(default)]
    pub lr: Option<f64>,
    #[serde(default)]
    pub gpu_util: Option<f64>,
    #[serde(default)]
    pub memory_mb: Option<f64>,
    #[serde(default)]
    pub custom: serde_json::Value,
}
/// JSON body accepted by [`append_log`].
#[derive(Debug, Deserialize)]
pub struct AppendLogRequest {
    pub message: String,
    /// Log level; empty when omitted, in which case `append_log` uses `INFO`.
    #[serde(default)]
    pub level: String,
}
/// A single line of a run's `logs.jsonl` file (one JSON object per line).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry {
    /// RFC 3339 time at which the entry was appended.
    pub timestamp: String,
    /// Upper-cased log level.
    pub level: String,
    pub message: String,
}
/// Response body of [`get_logs`]: the parsed log entries in file order.
#[derive(Debug, Serialize)]
pub struct LogsResponse {
    pub logs: Vec<LogEntry>,
}
/// Lists the training runs visible to the caller.
///
/// Admins (`role == "admin"`) see every run; other users see only the runs
/// they own. `limit` and `offset` are forwarded to the repository.
///
/// The optional `status` filter is matched case-insensitively against the
/// known statuses (`pending`, `running`, `completed`, `failed`, `stopped`).
/// A blank value means "no filter". An unrecognised value matches no run, so
/// an empty list is returned. The filter is not ignored, because ignoring it
/// would make a typo return every run.
///
/// # Errors
/// Returns [`AuthError::Internal`] if the repository query fails.
pub async fn list_runs(
    State(state): State<AppState>,
    user: AuthUser,
    Query(query): Query<ListRunsQuery>,
) -> Result<Json<Vec<RunResponse>>, AuthError> {
    let status = match query.status.as_deref().map(str::trim) {
        None | Some("") => None,
        Some(raw) => match raw.to_ascii_lowercase().as_str() {
            "pending" => Some(RunStatus::Pending),
            "running" => Some(RunStatus::Running),
            "completed" => Some(RunStatus::Completed),
            "failed" => Some(RunStatus::Failed),
            "stopped" => Some(RunStatus::Stopped),
            // A filter naming no known status cannot match any run.
            _ => return Ok(Json(Vec::new())),
        },
    };

    let repo = RunRepository::new(&state.db);
    let runs = if user.role == "admin" {
        repo.list_all(status, Some(query.limit), Some(query.offset))
            .await
    } else {
        repo.list_by_user(&user.id, status, Some(query.limit), Some(query.offset))
            .await
    }
    .map_err(|e| AuthError::Internal(e.to_string()))?;

    let response: Vec<RunResponse> = runs
        .into_iter()
        .map(|r| RunResponse {
            id: r.id,
            user_id: r.user_id,
            name: r.name,
            model_type: r.model_type,
            model_version_id: r.model_version_id,
            dataset_id: r.dataset_id,
            status: format!("{:?}", r.status).to_lowercase(),
            config: serde_json::to_value(&r.config).unwrap_or_default(),
            latest_metrics: r.latest_metrics.map(|m| MetricsResponse {
                epoch: m.epoch,
                step: m.step,
                loss: m.loss,
                accuracy: m.accuracy,
                lr: m.lr,
                gpu_util: m.gpu_util,
                memory_mb: m.memory_mb,
                custom: m.custom,
                timestamp: m.timestamp.to_rfc3339(),
            }),
            started_at: r.started_at.to_rfc3339(),
            completed_at: r.completed_at.map(|t| t.to_rfc3339()),
            created_at: r.created_at.to_rfc3339(),
        })
        .collect();
    Ok(Json(response))
}
pub async fn create_run(
State(state): State<AppState>,
user: AuthUser,
Json(req): Json<CreateRunRequest>,
) -> Result<(StatusCode, Json<RunResponse>), AuthError> {
let repo = RunRepository::new(&state.db);
let config = RunConfig {
epochs: req.config.epochs,
batch_size: req.config.batch_size,
learning_rate: req.config.learning_rate,
steps_per_epoch: req.config.steps_per_epoch.unwrap_or(100),
optimizer: if req.config.optimizer.is_empty() {
"adam".to_string()
} else {
req.config.optimizer
},
extra: serde_json::to_value(&req.config.extra).unwrap_or_default(),
};
let run = repo
.create(NewTrainingRun {
user_id: user.id,
name: req.name,
model_type: req.model_type,
model_version_id: req.model_version_id,
dataset_id: req.dataset_id,
config,
})
.await
.map_err(|e| AuthError::Internal(e.to_string()))?;
let logs_dir = state.config.runs_dir().join(&run.id);
std::fs::create_dir_all(&logs_dir).ok();
if let Err(e) = state.tracker.start_run(&run.id).await {
tracing::warn!(run_id = %run.id, error = %e, "Failed to start run tracking");
}
if let Err(e) = state.executor.start_training(run.clone()).await {
tracing::error!(run_id = %run.id, error = %e, "Failed to start training execution");
} else {
tracing::info!(run_id = %run.id, "Training execution started");
}
Ok((
StatusCode::CREATED,
Json(RunResponse {
id: run.id,
user_id: run.user_id,
name: run.name,
model_type: run.model_type,
model_version_id: run.model_version_id,
dataset_id: run.dataset_id,
status: "running".to_string(), config: serde_json::to_value(&run.config).unwrap_or_default(),
latest_metrics: None,
started_at: run.started_at.to_rfc3339(),
completed_at: None,
created_at: run.created_at.to_rfc3339(),
}),
))
}
/// Fetches a single run by id.
///
/// # Errors
/// - [`AuthError::NotFound`] if no run has this id.
/// - [`AuthError::Unauthorized`] if the caller is neither the owner nor an admin.
/// - [`AuthError::Internal`] if the repository lookup fails.
pub async fn get_run(
    State(state): State<AppState>,
    user: AuthUser,
    Path(id): Path<String>,
) -> Result<Json<RunResponse>, AuthError> {
    let repo = RunRepository::new(&state.db);
    let record = match repo.find_by_id(&id).await {
        Ok(Some(found)) => found,
        Ok(None) => return Err(AuthError::NotFound("Run not found".to_string())),
        Err(err) => return Err(AuthError::Internal(err.to_string())),
    };

    let may_view = record.user_id == user.id || user.role == "admin";
    if !may_view {
        return Err(AuthError::Unauthorized);
    }

    let status = format!("{:?}", record.status).to_lowercase();
    let config = serde_json::to_value(&record.config).unwrap_or_default();
    let latest_metrics = record.latest_metrics.map(|snapshot| MetricsResponse {
        epoch: snapshot.epoch,
        step: snapshot.step,
        loss: snapshot.loss,
        accuracy: snapshot.accuracy,
        lr: snapshot.lr,
        gpu_util: snapshot.gpu_util,
        memory_mb: snapshot.memory_mb,
        custom: snapshot.custom,
        timestamp: snapshot.timestamp.to_rfc3339(),
    });

    let body = RunResponse {
        id: record.id,
        user_id: record.user_id,
        name: record.name,
        model_type: record.model_type,
        model_version_id: record.model_version_id,
        dataset_id: record.dataset_id,
        status,
        config,
        latest_metrics,
        started_at: record.started_at.to_rfc3339(),
        completed_at: record.completed_at.map(|at| at.to_rfc3339()),
        created_at: record.created_at.to_rfc3339(),
    };
    Ok(Json(body))
}
pub async fn delete_run(
State(state): State<AppState>,
user: AuthUser,
Path(id): Path<String>,
) -> Result<StatusCode, AuthError> {
let repo = RunRepository::new(&state.db);
let run = repo
.find_by_id(&id)
.await
.map_err(|e| AuthError::Internal(e.to_string()))?
.ok_or(AuthError::NotFound("Run not found".to_string()))?;
if run.user_id != user.id && user.role != "admin" {
return Err(AuthError::Unauthorized);
}
if state.tracker.is_tracking(&id).await {
let _ = state.tracker.stop_run(&id).await;
}
let logs_dir = state.config.runs_dir().join(&id);
std::fs::remove_dir_all(&logs_dir).ok();
repo.delete(&id)
.await
.map_err(|e| AuthError::Internal(e.to_string()))?;
Ok(StatusCode::NO_CONTENT)
}
/// JSON body accepted by [`complete_run`].
#[derive(Debug, Deserialize)]
pub struct CompleteRunRequest {
    /// Outcome forwarded to the tracker; `false` when omitted.
    #[serde(default)]
    pub success: bool,
}
/// Marks a run as finished through the tracker and returns its refreshed record.
///
/// `req.success` is forwarded to `tracker.complete_run`; a tracker failure is
/// only logged. The run is then re-read from the repository so the response
/// reflects whatever the tracker persisted.
///
/// # Errors
/// - [`AuthError::NotFound`] if the run is missing (before or after completion).
/// - [`AuthError::Unauthorized`] if the caller is neither the owner nor an admin.
/// - [`AuthError::Internal`] on repository failures.
pub async fn complete_run(
    State(state): State<AppState>,
    user: AuthUser,
    Path(id): Path<String>,
    Json(req): Json<CompleteRunRequest>,
) -> Result<Json<RunResponse>, AuthError> {
    let repo = RunRepository::new(&state.db);

    let before = match repo.find_by_id(&id).await {
        Ok(Some(found)) => found,
        Ok(None) => return Err(AuthError::NotFound("Run not found".to_string())),
        Err(err) => return Err(AuthError::Internal(err.to_string())),
    };
    let permitted = before.user_id == user.id || user.role == "admin";
    if !permitted {
        return Err(AuthError::Unauthorized);
    }

    if let Err(e) = state.tracker.complete_run(&id, req.success).await {
        tracing::warn!(run_id = %id, error = %e, "Failed to complete run tracking");
    }

    let after = match repo.find_by_id(&id).await {
        Ok(Some(found)) => found,
        Ok(None) => return Err(AuthError::NotFound("Run not found".to_string())),
        Err(err) => return Err(AuthError::Internal(err.to_string())),
    };

    let status = format!("{:?}", after.status).to_lowercase();
    let config = serde_json::to_value(&after.config).unwrap_or_default();
    let latest_metrics = after.latest_metrics.map(|snapshot| MetricsResponse {
        epoch: snapshot.epoch,
        step: snapshot.step,
        loss: snapshot.loss,
        accuracy: snapshot.accuracy,
        lr: snapshot.lr,
        gpu_util: snapshot.gpu_util,
        memory_mb: snapshot.memory_mb,
        custom: snapshot.custom,
        timestamp: snapshot.timestamp.to_rfc3339(),
    });

    Ok(Json(RunResponse {
        id: after.id,
        user_id: after.user_id,
        name: after.name,
        model_type: after.model_type,
        model_version_id: after.model_version_id,
        dataset_id: after.dataset_id,
        status,
        config,
        latest_metrics,
        started_at: after.started_at.to_rfc3339(),
        completed_at: after.completed_at.map(|at| at.to_rfc3339()),
        created_at: after.created_at.to_rfc3339(),
    }))
}
/// Stops an active training run and marks it `stopped`.
///
/// Only the run's owner or an admin may stop it. For a run that is still
/// active, execution and tracking are stopped best-effort (failures are
/// logged) and the status is set to `stopped`.
///
/// A run that has already reached a terminal status (`completed`, `failed` or
/// `stopped`) is returned unchanged. Stopping a finished run therefore does
/// not overwrite its real outcome.
///
/// # Errors
/// - [`AuthError::NotFound`] if no run has this id.
/// - [`AuthError::Unauthorized`] if the caller is neither the owner nor an admin.
/// - [`AuthError::Internal`] on repository failures.
pub async fn stop_run(
    State(state): State<AppState>,
    user: AuthUser,
    Path(id): Path<String>,
) -> Result<Json<RunResponse>, AuthError> {
    let repo = RunRepository::new(&state.db);
    let run = repo
        .find_by_id(&id)
        .await
        .map_err(|e| AuthError::Internal(e.to_string()))?
        .ok_or(AuthError::NotFound("Run not found".to_string()))?;
    if run.user_id != user.id && user.role != "admin" {
        return Err(AuthError::Unauthorized);
    }

    let already_finished = matches!(
        run.status,
        RunStatus::Completed | RunStatus::Failed | RunStatus::Stopped
    );
    let run = if already_finished {
        run
    } else {
        if let Err(e) = state.executor.stop_run(&id).await {
            tracing::warn!(run_id = %id, error = %e, "Failed to stop training execution (may have already completed)");
        }
        if let Err(e) = state.tracker.stop_run(&id).await {
            tracing::warn!(run_id = %id, error = %e, "Failed to stop run tracking");
        }
        repo.update_status(&id, RunStatus::Stopped)
            .await
            .map_err(|e| AuthError::Internal(e.to_string()))?
    };

    Ok(Json(RunResponse {
        id: run.id,
        user_id: run.user_id,
        name: run.name,
        model_type: run.model_type,
        model_version_id: run.model_version_id,
        dataset_id: run.dataset_id,
        status: format!("{:?}", run.status).to_lowercase(),
        config: serde_json::to_value(&run.config).unwrap_or_default(),
        latest_metrics: run.latest_metrics.map(|m| MetricsResponse {
            epoch: m.epoch,
            step: m.step,
            loss: m.loss,
            accuracy: m.accuracy,
            lr: m.lr,
            gpu_util: m.gpu_util,
            memory_mb: m.memory_mb,
            custom: m.custom,
            timestamp: m.timestamp.to_rfc3339(),
        }),
        started_at: run.started_at.to_rfc3339(),
        completed_at: run.completed_at.map(|t| t.to_rfc3339()),
        created_at: run.created_at.to_rfc3339(),
    }))
}
/// Returns the recorded metrics history of a run.
///
/// A cap of `10000` entries is passed to the repository query.
///
/// # Errors
/// - [`AuthError::NotFound`] if no run has this id.
/// - [`AuthError::Unauthorized`] if the caller is neither the owner nor an admin.
/// - [`AuthError::Internal`] on repository failures.
pub async fn get_metrics(
    State(state): State<AppState>,
    user: AuthUser,
    Path(id): Path<String>,
) -> Result<Json<Vec<MetricsResponse>>, AuthError> {
    let repo = RunRepository::new(&state.db);
    let owner_id = match repo.find_by_id(&id).await {
        Ok(Some(found)) => found.user_id,
        Ok(None) => return Err(AuthError::NotFound("Run not found".to_string())),
        Err(err) => return Err(AuthError::Internal(err.to_string())),
    };
    if !(owner_id == user.id || user.role == "admin") {
        return Err(AuthError::Unauthorized);
    }

    let history = match repo.get_metrics_history(&id, Some(10000)).await {
        Ok(points) => points,
        Err(err) => return Err(AuthError::Internal(err.to_string())),
    };

    let body = history
        .into_iter()
        .map(|point| MetricsResponse {
            epoch: point.epoch,
            step: point.step,
            loss: point.loss,
            accuracy: point.accuracy,
            lr: point.lr,
            gpu_util: point.gpu_util,
            memory_mb: point.memory_mb,
            custom: point.custom,
            timestamp: point.timestamp.to_rfc3339(),
        })
        .collect::<Vec<MetricsResponse>>();
    Ok(Json(body))
}
/// Records one metrics snapshot for a run through the tracker.
///
/// When the run is still `pending`, recording its first snapshot also moves
/// it to `running`.
///
/// # Errors
/// - [`AuthError::NotFound`] if no run has this id.
/// - [`AuthError::Unauthorized`] if the caller is neither the owner nor an admin.
/// - [`AuthError::Internal`] if the tracker rejects the snapshot or a
///   repository call fails.
pub async fn record_metrics(
    State(state): State<AppState>,
    user: AuthUser,
    Path(id): Path<String>,
    Json(req): Json<RecordMetricsRequest>,
) -> Result<StatusCode, AuthError> {
    let repo = RunRepository::new(&state.db);
    let run = repo
        .find_by_id(&id)
        .await
        .map_err(|e| AuthError::Internal(e.to_string()))?
        .ok_or(AuthError::NotFound("Run not found".to_string()))?;
    if run.user_id != user.id && user.role != "admin" {
        return Err(AuthError::Unauthorized);
    }
    state
        .tracker
        .record_metrics(
            &id,
            req.epoch,
            req.step,
            req.loss,
            req.accuracy,
            req.lr,
            req.gpu_util,
            req.memory_mb,
            // `req` is owned and not used afterwards, so move instead of cloning.
            req.custom,
        )
        .await
        .map_err(AuthError::Internal)?;
    if run.status == RunStatus::Pending {
        repo.update_status(&id, RunStatus::Running)
            .await
            .map_err(|e| AuthError::Internal(e.to_string()))?;
    }
    Ok(StatusCode::CREATED)
}
/// Returns the log entries recorded for a run.
///
/// Entries are read from `<runs_dir>/<id>/logs.jsonl`, one JSON object per
/// line, in file order. Lines that do not parse as a [`LogEntry`] are
/// skipped. A missing log file yields an empty list. The file is read
/// directly rather than checked with `exists()` first, so there is no race
/// with a concurrent delete.
///
/// # Errors
/// - [`AuthError::NotFound`] if no run has this id.
/// - [`AuthError::Unauthorized`] if the caller is neither the owner nor an admin.
/// - [`AuthError::Internal`] on repository failures, or if the log file exists
///   but cannot be read.
pub async fn get_logs(
    State(state): State<AppState>,
    user: AuthUser,
    Path(id): Path<String>,
) -> Result<Json<LogsResponse>, AuthError> {
    let repo = RunRepository::new(&state.db);
    let run = repo
        .find_by_id(&id)
        .await
        .map_err(|e| AuthError::Internal(e.to_string()))?
        .ok_or(AuthError::NotFound("Run not found".to_string()))?;
    if run.user_id != user.id && user.role != "admin" {
        return Err(AuthError::Unauthorized);
    }
    let logs_path = state.config.runs_dir().join(&id).join("logs.jsonl");
    let logs: Vec<LogEntry> = match std::fs::read_to_string(&logs_path) {
        Ok(contents) => contents
            .lines()
            .filter_map(|line| serde_json::from_str::<LogEntry>(line).ok())
            .collect(),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Vec::new(),
        Err(e) => return Err(AuthError::Internal(e.to_string())),
    };
    Ok(Json(LogsResponse { logs }))
}
/// Appends one entry to the run's `logs.jsonl` file.
///
/// The entry is stamped with the current UTC time. `level` is trimmed and
/// upper-cased, and defaults to `INFO` when empty or blank. The run's log
/// directory is created if it does not exist yet.
///
/// # Errors
/// - [`AuthError::NotFound`] if no run has this id.
/// - [`AuthError::Unauthorized`] if the caller is neither the owner nor an admin.
/// - [`AuthError::Internal`] on repository failures, or if the directory cannot
///   be created, the entry cannot be serialised, or the file cannot be
///   opened or written.
pub async fn append_log(
    State(state): State<AppState>,
    user: AuthUser,
    Path(id): Path<String>,
    Json(req): Json<AppendLogRequest>,
) -> Result<StatusCode, AuthError> {
    use std::io::Write;

    let repo = RunRepository::new(&state.db);
    let run = repo
        .find_by_id(&id)
        .await
        .map_err(|e| AuthError::Internal(e.to_string()))?
        .ok_or(AuthError::NotFound("Run not found".to_string()))?;
    if run.user_id != user.id && user.role != "admin" {
        return Err(AuthError::Unauthorized);
    }

    let level = req.level.trim();
    let entry = LogEntry {
        timestamp: Utc::now().to_rfc3339(),
        level: if level.is_empty() {
            "INFO".to_string()
        } else {
            level.to_uppercase()
        },
        message: req.message,
    };

    let logs_dir = state.config.runs_dir().join(&id);
    // Surface the real cause instead of a later, less specific open() failure.
    std::fs::create_dir_all(&logs_dir).map_err(|e| AuthError::Internal(e.to_string()))?;
    let logs_path = logs_dir.join("logs.jsonl");

    // Never write an empty line in place of a failed serialisation.
    let mut line =
        serde_json::to_string(&entry).map_err(|e| AuthError::Internal(e.to_string()))?;
    line.push('\n');

    let mut file = std::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(&logs_path)
        .map_err(|e| AuthError::Internal(e.to_string()))?;
    file.write_all(line.as_bytes())
        .map_err(|e| AuthError::Internal(e.to_string()))?;
    Ok(StatusCode::CREATED)
}
/// Query-string parameters of the metrics WebSocket endpoint.
#[derive(Debug, Deserialize)]
pub struct WsStreamQuery {
    /// JWT access token (`?token=<jwt>`); `stream_metrics` rejects the
    /// upgrade when it is missing or invalid.
    pub token: Option<String>,
}
/// Upgrades the connection to a WebSocket that streams metrics for run `id`.
///
/// The JWT is taken from the `token` query parameter and validated before
/// the upgrade. A missing or invalid token yields `401 Unauthorized` with an
/// explanatory message.
///
/// NOTE(review): unlike the REST handlers in this module, this endpoint only
/// validates the token. It never checks that the token's subject owns run
/// `id` or is an admin, so any authenticated user can stream any run's
/// metrics. Add that check once the claims' user-id/role fields are
/// confirmed.
pub async fn stream_metrics(
    ws: WebSocketUpgrade,
    State(state): State<AppState>,
    Path(id): Path<String>,
    Query(query): Query<WsStreamQuery>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    let token = match query.token {
        Some(token) => token,
        None => {
            return Err((
                StatusCode::UNAUTHORIZED,
                "Missing token parameter. Connect with ?token=<jwt>".to_string(),
            ));
        }
    };

    if let Err(e) = state.jwt.validate_access_token(&token) {
        tracing::warn!(run_id = %id, "Unauthorized WebSocket connection attempt");
        return Err((StatusCode::UNAUTHORIZED, format!("Invalid token: {}", e)));
    }

    let response = ws.on_upgrade(move |socket| handle_metrics_stream(socket, state, id));
    Ok(response)
}
/// Drives an upgraded metrics socket for `run_id`.
///
/// Uses the tracker's live subscription when one is available. Otherwise it
/// falls back to polling the repository.
async fn handle_metrics_stream(socket: WebSocket, state: AppState, run_id: String) {
    match state.tracker.subscribe(&run_id).await {
        Some(live_feed) => {
            MetricsStreamer::stream(socket, live_feed).await;
        }
        None => {
            handle_metrics_stream_polling(socket, state, run_id).await;
        }
    }
}
/// Streams metrics for `run_id` by polling the repository every 500 ms.
///
/// A background task sends each new `latest_metrics` snapshot. When the run
/// leaves the `pending`/`running` states, it sends a final status message
/// and stops. It also stops when the run disappears, the lookup fails, or a
/// send fails. When it stops, it sends a Close frame so the client ends the
/// connection. Meanwhile this function reads client frames until the client
/// closes or errors, then aborts the poller.
async fn handle_metrics_stream_polling(socket: WebSocket, state: AppState, run_id: String) {
    let (mut sender, mut receiver) = socket.split();

    let poll_handle = tokio::spawn(async move {
        let repo = RunRepository::new(&state.db);
        // (epoch, step) of the last snapshot sent. It starts as `None`, so a
        // snapshot recorded at step 0 is streamed too; a `0` sentinel with `>`
        // would suppress it. Comparing the pair also keeps working if `step`
        // restarts at every epoch.
        let mut last_sent: Option<(u32, u32)> = None;
        loop {
            tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
            let run = match repo.find_by_id(&run_id).await {
                Ok(Some(run)) => run,
                // Run deleted or lookup failed: nothing more to stream.
                _ => break,
            };
            if run.status != RunStatus::Running && run.status != RunStatus::Pending {
                let status_json = MetricsStreamer::format_status(
                    &format!("{:?}", run.status).to_lowercase(),
                    run.completed_at.as_ref().map(|t| t.to_rfc3339()).as_deref(),
                );
                let _ = sender.send(Message::Text(status_json)).await;
                break;
            }
            if let Some(metrics) = &run.latest_metrics {
                let position = (metrics.epoch, metrics.step);
                // `None < Some(_)`, so the first snapshot always qualifies.
                if last_sent < Some(position) {
                    last_sent = Some(position);
                    let json = MetricsStreamer::format_metrics(metrics);
                    if sender.send(Message::Text(json)).await.is_err() {
                        break;
                    }
                }
            }
        }
        // Start the close handshake. Otherwise the socket stays open and the
        // read loop below waits until the client hangs up on its own.
        let _ = sender.send(Message::Close(None)).await;
    });

    while let Some(Ok(frame)) = receiver.next().await {
        if matches!(frame, Message::Close(_)) {
            break;
        }
    }
    poll_handle.abort();
}