use crate::buffers::sample::{LatestAndRandomSampler, LatestDataSampler, RandomSampler, Sampler};
use crate::config::CONFIG;
use crate::entities::buffer;
use crate::models::V1UserProfile;
use crate::mutation::Mutation;
use crate::org::get_organization_names;
use crate::query::Query;
use crate::resources::v1::buffers::models::{
V1ReplayBuffer, V1ReplayBufferData, V1ReplayBufferRequest, V1ReplayBufferStatus,
V1UpdateReplayBufferRequest,
};
use crate::state::AppState;
use crate::validate::ValidatedJson;
use anyhow::{anyhow, Error, Result};
use aws_sdk_s3::types::ObjectCannedAcl;
use axum::{
extract::{Extension, Json, Path, Query as QueryExtractor, State},
http::StatusCode,
response::IntoResponse,
};
use chrono::Utc;
use nebulous::client::NebulousClient;
use nebulous::models::V1ResourceMeta;
use nebulous::resources::v1::containers::models::{V1ContainerRequest, V1EnvVar};
use nebulous::resources::v1::volumes::models::V1VolumeDriver;
use nebulous::resources::v1::volumes::models::V1VolumePath;
use rand::distributions::{Alphanumeric, DistString};
use sea_orm::IntoActiveModel;
use sea_orm::Set;
use sea_orm::*;
use serde_json::json;
use short_uuid::ShortUuid;
use std::collections::HashMap;
use tokio::fs::{create_dir_all, OpenOptions};
use tokio::io::AsyncWriteExt;
use tracing::{debug, error, info};
pub async fn create_buffer(
State(state): State<AppState>,
Extension(user_profile): Extension<crate::models::V1UserProfile>,
Json(payload): Json<V1ReplayBufferRequest>,
) -> impl IntoResponse {
info!("Creating buffer: {:?}", payload);
let db = state.db_pool.clone();
let id = ShortUuid::generate().to_string();
let mut owner_ids: Vec<String> = if let Some(orgs) = &user_profile.organizations {
orgs.keys().cloned().collect()
} else {
Vec::new()
};
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(|s| s.as_str()).collect();
let owner = payload
.metadata
.owner
.clone()
.filter(|o| !o.is_empty())
.unwrap_or_else(|| user_profile.email.clone());
if !owner_ids.contains(&owner) {
return (
StatusCode::FORBIDDEN,
Json(json!({ "error": "Unauthorized owner specified" })),
);
}
match _create_buffer(&db, id, owner, &user_profile, payload).await {
Ok(replay_buffer) => {
let response = json!(replay_buffer);
(StatusCode::CREATED, Json(response))
}
Err(e) => {
info!("Error creating buffer: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Failed to create buffer" })),
)
}
}
}
pub async fn _create_buffer(
db: &DatabaseConnection,
id: String,
owner: String,
user_profile: &V1UserProfile,
payload: V1ReplayBufferRequest,
) -> Result<V1ReplayBuffer, Error> {
let namespace = payload.metadata.namespace.clone();
let name = payload
.metadata
.name
.clone()
.filter(|n| !n.is_empty())
.unwrap_or_else(|| petname::petname(3, "-").unwrap());
let namespace = namespace.unwrap_or_else(|| {
user_profile.handle.clone().unwrap_or(
user_profile
.email
.clone()
.replace("@", "-")
.replace(".", "-"),
)
});
let new_buffer = buffer::ActiveModel {
id: Set(id.clone()),
name: Set(name.clone()),
namespace: Set(namespace.clone()),
full_name: Set(format!("{}/{}", namespace.clone(), name.clone())),
owner_id: Set(owner.clone()),
train_every: Set(payload.train_every),
sample_n: Set(payload.sample_n),
sample_strategy: Set(payload.sample_strategy),
num_epochs: Set(payload.num_epochs),
train_job: Set(Some(
serde_json::to_value(&payload.train_job).unwrap_or_default(),
)),
labels: Set(payload
.metadata
.labels
.clone()
.map(|labels| serde_json::to_value(labels).unwrap())),
created_at: Set(Utc::now().into()),
updated_at: Set(Utc::now().into()),
created_by: Set(Some(user_profile.email.clone())),
..Default::default()
};
info!("Creating new buffer: {:?}", new_buffer);
match Mutation::create_buffer(&db, new_buffer).await {
Ok(buffer_model) => {
let replay_buffer = V1ReplayBuffer {
metadata: V1ResourceMeta {
id: buffer_model.id.clone(),
name: name.clone(),
namespace: namespace.clone(),
owner: owner.clone(),
labels: buffer_model
.labels
.clone()
.and_then(|v| serde_json::from_value(v).ok()),
created_at: buffer_model.created_at.timestamp(),
updated_at: buffer_model.updated_at.timestamp(),
created_by: user_profile.email.clone(),
owner_ref: None,
},
train_every: buffer_model.train_every.clone(),
sample_n: buffer_model.sample_n.clone(),
sample_strategy: buffer_model.sample_strategy.clone(),
status: V1ReplayBufferStatus {
num_records: None,
train_idx: None,
num_train_jobs: None,
last_train_job: None,
num_epochs: None,
},
train_job: buffer_model
.train_job
.clone()
.and_then(|v| serde_json::from_value(v).ok())
.unwrap_or_default(),
num_epochs: payload.num_epochs,
};
Ok(replay_buffer)
}
Err(e) => {
info!("Error creating buffer: {:?}", e);
Err(e.into())
}
}
}
pub async fn get_buffer(
State(state): State<AppState>,
Extension(user_profile): Extension<crate::models::V1UserProfile>,
Path((namespace, name)): Path<(String, String)>,
) -> impl IntoResponse {
let db = state.db_pool.clone();
let mut owner_ids: Vec<String> = if let Some(orgs) = &user_profile.organizations {
orgs.keys().cloned().collect()
} else {
Vec::new()
};
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(|s| s.as_str()).collect();
match Query::find_buffer_by_name_and_owners(&db, &name, &namespace, &owner_id_refs).await {
Ok(Some(buffer_model)) => {
if !owner_ids.contains(&buffer_model.owner_id) {
let error_response = json!({ "error": "Buffer not found" });
return (StatusCode::NOT_FOUND, Json(error_response));
}
let replay_buffer = V1ReplayBuffer {
metadata: V1ResourceMeta {
id: buffer_model.id.clone(),
name: buffer_model.name.clone(),
namespace: buffer_model.namespace.clone(),
owner: buffer_model.owner_id.clone(),
labels: buffer_model
.labels
.clone()
.and_then(|v| serde_json::from_value(v).ok()),
created_at: buffer_model.created_at.timestamp(),
updated_at: buffer_model.updated_at.timestamp(),
created_by: user_profile.email.clone(),
owner_ref: None,
},
train_every: buffer_model.train_every.clone(),
sample_n: buffer_model.sample_n,
sample_strategy: buffer_model.sample_strategy.clone(),
num_epochs: buffer_model.num_epochs,
status: V1ReplayBufferStatus {
num_records: buffer_model.num_records,
train_idx: buffer_model.train_idx,
num_train_jobs: None,
last_train_job: None,
num_epochs: buffer_model.train_idx,
},
train_job: buffer_model
.train_job
.clone()
.and_then(|v| serde_json::from_value(v).ok())
.unwrap_or_default(),
};
let response = json!(replay_buffer);
(StatusCode::OK, Json(response))
}
Ok(None) => {
let error_response = json!({ "error": "Buffer not found" });
(StatusCode::NOT_FOUND, Json(error_response))
}
Err(e) => {
info!("Error fetching buffer: {:?}", e);
let error_response = json!({ "error": "Failed to retrieve buffer" });
(StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
}
}
}
pub async fn delete_buffer(
State(state): State<AppState>,
Extension(user_profile): Extension<crate::models::V1UserProfile>,
Path((namespace, name)): Path<(String, String)>,
) -> impl IntoResponse {
let db = state.db_pool.clone();
let mut owner_ids: Vec<String> = if let Some(orgs) = &user_profile.organizations {
orgs.keys().cloned().collect()
} else {
Vec::new()
};
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(|s| s.as_str()).collect();
match Query::find_buffer_by_name_and_owners(&db, &name, &namespace, &owner_id_refs).await {
Ok(Some(buffer_model)) => {
if buffer_model.owner_id != user_profile.email
&& !owner_ids.contains(&buffer_model.owner_id)
{
let error_response = json!({ "error": "Buffer not found" });
return (StatusCode::NOT_FOUND, Json(error_response));
}
match Mutation::delete_buffer(&db, &buffer_model.id).await {
Ok(_) => (StatusCode::NO_CONTENT, Json(json!({}))),
Err(e) => {
info!("Error deleting buffer: {:?}", e);
let error_response = json!({ "error": "Failed to delete buffer" });
(StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
}
}
}
Ok(None) => {
let error_response = json!({ "error": "Buffer not found" });
(StatusCode::NOT_FOUND, Json(error_response))
}
Err(e) => {
info!("Error fetching buffer: {:?}", e);
let error_response = json!({ "error": "Failed to delete buffer" });
(StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
}
}
}
/// Query-string parameters accepted by [`list_buffers`].
///
/// NOTE(review): `labels` is a nested map. Axum's `Query` extractor is backed by
/// `serde_urlencoded`, which may be unable to deserialize nested maps from a flat query
/// string. Confirm this filter is reachable from real requests.
#[derive(serde::Deserialize, Debug)]
pub struct ListBuffersQuery {
    /// When present, only buffers whose labels contain every one of these key/value pairs are
    /// listed.
    pub labels: Option<HashMap<String, String>>,
}
pub async fn list_buffers(
State(state): State<AppState>,
Extension(user_profile): Extension<crate::models::V1UserProfile>,
QueryExtractor(query): QueryExtractor<ListBuffersQuery>,
) -> impl IntoResponse {
let db = state.db_pool.clone();
let mut owner_ids: Vec<String> = if let Some(orgs) = &user_profile.organizations {
orgs.keys().cloned().collect()
} else {
Vec::new()
};
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(|s| s.as_str()).collect();
match Query::find_buffers_by_owners(&db, &owner_id_refs).await {
Ok(buffer_models) => {
let filtered_buffers = if let Some(ref requested_labels) = query.labels {
buffer_models
.into_iter()
.filter(|buffer_model| {
let db_labels: Option<HashMap<String, String>> = buffer_model
.labels
.clone()
.and_then(|val| serde_json::from_value(val).ok());
if let Some(ref actual_labels) = db_labels {
requested_labels
.iter()
.all(|(k, v)| actual_labels.get(k) == Some(v))
} else {
false
}
})
.collect::<Vec<_>>()
} else {
buffer_models
};
let replay_buffers: Vec<V1ReplayBuffer> = filtered_buffers
.into_iter()
.map(|buffer_model| V1ReplayBuffer {
metadata: V1ResourceMeta {
id: buffer_model.id.clone(),
name: buffer_model.name.clone(),
namespace: buffer_model.namespace.clone(),
owner: buffer_model.owner_id.clone(),
labels: buffer_model
.labels
.clone()
.and_then(|v| serde_json::from_value(v).ok()),
created_at: buffer_model.created_at.timestamp(),
updated_at: buffer_model.updated_at.timestamp(),
created_by: user_profile.email.clone(),
owner_ref: None,
},
train_every: buffer_model.train_every.clone(),
sample_n: buffer_model.sample_n,
sample_strategy: buffer_model.sample_strategy.clone(),
num_epochs: buffer_model.num_epochs,
status: V1ReplayBufferStatus {
num_records: buffer_model.num_records,
train_idx: buffer_model.train_idx,
num_train_jobs: None,
last_train_job: None,
num_epochs: Some(buffer_model.num_epochs),
},
train_job: buffer_model
.train_job
.clone()
.and_then(|v| serde_json::from_value(v).ok())
.unwrap_or_default(),
})
.collect();
let response = json!({ "buffers": replay_buffers });
(StatusCode::OK, Json(response))
}
Err(e) => {
info!("Error fetching buffers: {:?}", e);
let error_response = json!({ "error": "Failed to retrieve buffers" });
(StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
}
}
}
async fn trigger_training_job(
state: &AppState,
buffer_model: &buffer::Model,
current_idx: i32,
user_profile: &crate::models::V1UserProfile,
train_file_path: &str,
container_request: &V1ContainerRequest,
) -> anyhow::Result<()> {
use crate::buffers::sample::{
LatestAndRandomSampler, LatestDataSampler, RandomSampler, Sampler,
};
use rand::distributions::{Alphanumeric, DistString};
use tokio::fs::{create_dir_all, OpenOptions};
use tokio::io::AsyncWriteExt;
let mut container_request = container_request.clone();
let sample_strategy = buffer_model.sample_strategy.clone();
let sample_size = buffer_model.sample_n;
let sampler: Box<dyn Sampler + Send + Sync> = match sample_strategy.as_str() {
"Random" => {
let seed = u64::from_str_radix(&buffer_model.id[..16], 16).unwrap_or(42);
Box::new(RandomSampler::new(sample_size, seed))
}
"LatestWithRandom" => {
let seed = u64::from_str_radix(&buffer_model.id[..16], 16).unwrap_or(42);
Box::new(LatestAndRandomSampler::new(current_idx, sample_size, seed))
}
"Latest" => Box::new(LatestDataSampler {
last_index: current_idx,
}),
invalid => {
return Err(anyhow::anyhow!("Invalid sampling strategy: {}", invalid));
}
};
info!("Using sampler: {}", sampler.name());
let epoch_idx = buffer_model.epoch_idx.unwrap_or(0);
let epoch_delta = buffer_model.num_epochs;
let new_epoch_idx = epoch_idx + epoch_delta;
info!(
"Current epoch_idx: {}, adding {}, new: {}",
epoch_idx, epoch_delta, new_epoch_idx
);
let samples = sampler.sample(train_file_path).map_err(|e| {
anyhow::anyhow!(
"Sampling failed on file {} using {}: {:?}",
train_file_path,
sample_strategy,
e
)
})?;
info!("Sampled {} lines from {}", samples.len(), train_file_path);
let temp_dir = format!("/datasets/temp/{}", buffer_model.id);
create_dir_all(&temp_dir).await?;
let random_suffix = Alphanumeric.sample_string(&mut rand::thread_rng(), 5);
let temp_file_path = format!("{}/train-{}.jsonl", temp_dir, random_suffix);
{
let mut file = OpenOptions::new()
.create(true)
.write(true)
.open(&temp_file_path)
.await?;
for line in &samples {
file.write_all(line.as_bytes()).await?;
file.write_all(b"\n").await?;
}
}
info!("Wrote sampled data to {:?}", temp_file_path);
let file_content = tokio::fs::read(&temp_file_path).await?;
let bucket_name = std::env::var("S3_BUCKET_NAME")
.map_err(|_| anyhow::anyhow!("S3_BUCKET_NAME environment variable not set"))?;
let timestamp = chrono::Utc::now().timestamp();
let s3_dir = format!("buffers/{}/{}", buffer_model.id, timestamp);
let s3_key = format!("{}/train-{}.jsonl", s3_dir, timestamp);
let s3_client = aws_sdk_s3::Client::new(&aws_config::load_from_env().await);
s3_client
.put_object()
.bucket(&bucket_name)
.key(&s3_key)
.body(file_content.into())
.acl(aws_sdk_s3::types::ObjectCannedAcl::PublicRead)
.content_type("application/json")
.send()
.await?;
info!(
"Uploaded sample data to S3: s3://{}/{}",
bucket_name, s3_key
);
let s3_url = format!("https://{}.s3.amazonaws.com/{}", bucket_name, s3_key);
info!("Public URL for training file: {}", s3_url);
let nebulous_client = match NebulousClient::new_from_config() {
Ok(client) => client,
Err(e) => {
info!("Error creating nebulous client: {:?}", e);
return Err(anyhow::anyhow!("{:?}", e));
}
};
let mut current_env = container_request.env.clone().unwrap_or_default();
current_env.push(V1EnvVar {
key: "DATASET_URI".to_string(),
value: Some(s3_url.to_string()),
secret_name: None,
});
let mut current_volumes = container_request.volumes.clone().unwrap_or_default();
current_volumes.push(V1VolumePath {
source: format!("s3://{}", s3_dir),
dest: "/datasets/".to_string(),
resync: false,
continuous: false,
driver: V1VolumeDriver::RCLONE_SYNC,
});
current_env.push(V1EnvVar {
key: "DATASET_PATH".to_string(),
value: Some(temp_file_path.to_string()),
secret_name: None,
});
current_env.push(V1EnvVar {
key: "NUM_EPOCHS".to_string(),
value: Some(new_epoch_idx.to_string()),
secret_name: None,
});
container_request.env = Some(current_env);
container_request.volumes = Some(current_volumes);
debug!(
"Triggering training job with container_request: {:?}",
container_request
);
let container_response = nebulous_client
.create_container(&container_request)
.await
.map_err(|err| anyhow::anyhow!("{:?}", err))?;
debug!("Container response: {:?}", container_response);
let mut active_model = buffer_model.clone().into_active_model();
active_model.epoch_idx = sea_orm::Set(Some(new_epoch_idx));
Mutation::update_buffer(&state.db_pool, &active_model).await?;
info!("Updated buffer epoch_idx to {:?}", new_epoch_idx);
Ok(())
}
#[axum::debug_handler]
pub async fn send_examples(
State(state): State<AppState>,
Extension(user_profile): Extension<crate::models::V1UserProfile>,
Path((namespace, name)): Path<(String, String)>,
Json(payload): Json<V1ReplayBufferData>,
) -> impl IntoResponse {
let db = &state.db_pool;
match _send_examples(db, &state, &user_profile, &namespace, &name, &payload).await {
Ok(success_json) => (StatusCode::OK, Json(success_json)),
Err(e) => {
info!("send_examples encountered an error: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Failed to send examples: {:#}", e) })),
)
}
}
}
pub async fn _send_examples(
db: &DatabaseConnection,
state: &AppState,
user_profile: &crate::models::V1UserProfile,
namespace: &str,
name: &str,
payload: &V1ReplayBufferData,
) -> anyhow::Result<serde_json::Value> {
let mut owner_ids: Vec<String> = if let Some(orgs) = &user_profile.organizations {
orgs.keys().cloned().collect()
} else {
Vec::new()
};
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(|s| s.as_str()).collect();
let buffer_model =
match Query::find_buffer_by_name_and_owners(db, name, namespace, &owner_id_refs).await {
Ok(Some(buffer_model)) => buffer_model,
Ok(None) => {
anyhow::bail!("Buffer not found for namespace={} name={}", namespace, name);
}
Err(e) => {
anyhow::bail!("Database error querying buffer: {:?}", e);
}
};
info!("buffer_model: {:?}", buffer_model);
let dataset_dir = format!("{}/{}", CONFIG.dataset_dir, buffer_model.id);
debug!("creating dataset dir {:?}", dataset_dir);
create_dir_all(&dataset_dir).await.map_err(|e| {
error!("Failed to create directory: {:?}", e);
anyhow::anyhow!("Failed to create directory: {:?}", e)
})?;
debug!("dataset_dir: {:?}", dataset_dir);
let train_file_path = format!("{}/train.jsonl", dataset_dir);
let mut file = OpenOptions::new()
.create(true)
.append(true)
.open(&train_file_path)
.await
.map_err(|e| anyhow::anyhow!("Failed to open train.jsonl: {:?}", e))?;
debug!("train_file_path: {:?}", train_file_path);
let original_num_records = buffer_model.num_records.unwrap_or(0);
let new_num_records = original_num_records + payload.examples.len() as i32;
info!("num_records: {:?}", buffer_model.num_records);
info!("new examples: {:?}", payload.examples.len());
info!("new_num_records: {:?}", new_num_records);
for example in &payload.examples {
let json_str = serde_json::to_string(example)
.map_err(|e| anyhow::anyhow!("Error converting JSON to string: {:?}", e))?;
file.write_all(json_str.as_bytes())
.await
.map_err(|e| anyhow::anyhow!("Failed to write example: {:?}", e))?;
file.write_all(b"\n")
.await
.map_err(|e| anyhow::anyhow!("Failed to write newline: {:?}", e))?;
}
info!(
"Wrote {} examples to file: {:?}",
payload.examples.len(),
train_file_path
);
let mut buffer_active_model: buffer::ActiveModel = buffer_model.clone().into();
buffer_active_model.num_records = Set(Some(new_num_records));
buffer_active_model.updated_at = Set(Utc::now().into());
if let Err(e) = Mutation::update_buffer(db, &buffer_active_model).await {
info!("Error updating buffer record count: {:?}", e);
}
if let Some(train_every) = buffer_model.train_every {
let train_idx = buffer_model.train_idx.unwrap_or(0);
let diff = new_num_records - train_idx;
info!("train_every: {}, diff: {}", train_every, diff);
if diff >= train_every || payload.train.unwrap_or(false) {
let container_req_result = buffer_model
.train_job
.clone()
.map(|raw_json| serde_json::from_value::<V1ContainerRequest>(raw_json));
let container_request = match container_req_result {
Some(Ok(parsed)) => parsed,
Some(Err(e)) => {
anyhow::bail!("Invalid train_job JSON: {:?}", e);
}
None => {
anyhow::bail!("No train_job present in this buffer");
}
};
if let Err(err) = trigger_training_job(
state,
&buffer_model,
new_num_records,
user_profile,
&train_file_path,
&container_request,
)
.await
{
anyhow::bail!("Failed to trigger training job: {:?}", err);
}
let mut buffer_after_training: buffer::ActiveModel = buffer_model.into_active_model();
buffer_after_training.train_idx = Set(Some(new_num_records));
if let Err(e) = Mutation::update_buffer(db, &buffer_after_training).await {
anyhow::bail!("Failed to update train_idx: {:?}", e);
}
}
}
Ok(json!({ "message": "Examples saved successfully" }))
}
pub async fn train_buffer(
State(state): State<AppState>,
Extension(user_profile): Extension<V1UserProfile>,
Path((namespace, name)): Path<(String, String)>,
) -> impl IntoResponse {
match _train_buffer(&state.db_pool, &state, &user_profile, &namespace, &name).await {
Ok(_) => {
let success = json!({ "message": "Buffer triggered successfully" });
(StatusCode::OK, Json(success))
}
Err(e) => {
info!("Error in train_buffer: {:?}", e);
let status_code = if e.to_string().contains("Buffer not found") {
StatusCode::NOT_FOUND
} else if e.to_string().contains("No train_job present")
|| e.to_string().contains("Invalid train_job")
{
StatusCode::BAD_REQUEST
} else {
StatusCode::INTERNAL_SERVER_ERROR
};
(status_code, Json(json!({ "error": e.to_string() })))
}
}
}
pub async fn _train_buffer(
db: &DatabaseConnection,
state: &AppState,
user_profile: &V1UserProfile,
namespace: &str,
name: &str,
) -> anyhow::Result<()> {
let mut owner_ids: Vec<String> = if let Some(orgs) = &user_profile.organizations {
orgs.keys().cloned().collect()
} else {
Vec::new()
};
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(|s| s.as_str()).collect();
let buffer_model =
match Query::find_buffer_by_name_and_owners(db, name, namespace, &owner_id_refs).await {
Ok(Some(buffer_model)) => buffer_model,
Ok(None) => {
anyhow::bail!("Buffer not found");
}
Err(e) => {
info!("Error fetching buffer: {:?}", e);
anyhow::bail!("Failed to query buffer: {}", e);
}
};
info!("buffer_model: {:?}", buffer_model);
let dataset_dir = format!("/datasets/buffers/{}", buffer_model.id.clone());
create_dir_all(&dataset_dir)
.await
.map_err(|e| anyhow::anyhow!("Failed to create directory: {}", e))?;
let train_file_path = format!("{}/train.jsonl", dataset_dir);
OpenOptions::new()
.create(true)
.append(true)
.open(&train_file_path)
.await
.map_err(|e| anyhow::anyhow!("Failed to open train.jsonl: {}", e))?;
let container_req_result = buffer_model
.train_job
.clone()
.map(|raw_json| serde_json::from_value::<V1ContainerRequest>(raw_json));
let container_request = match container_req_result {
Some(Ok(parsed)) => parsed,
Some(Err(e)) => {
anyhow::bail!("Invalid train_job JSON: {}", e);
}
None => {
anyhow::bail!("No train_job present in this buffer");
}
};
info!("triggering training job");
trigger_training_job(
state,
&buffer_model.clone(),
buffer_model.num_records.clone().unwrap_or(0),
user_profile,
&train_file_path,
&container_request,
)
.await
.map_err(|e| anyhow::anyhow!("Failed to trigger training job: {}", e))?;
let mut buffer_after_training: buffer::ActiveModel = buffer_model.clone().into_active_model();
buffer_after_training.train_idx = Set(Some(buffer_model.num_records.unwrap_or(0)));
Mutation::update_buffer(db, &buffer_after_training)
.await
.map_err(|e| anyhow::anyhow!("Failed to update train_idx: {}", e))?;
Ok(())
}
#[axum::debug_handler]
pub async fn update_buffer(
State(state): State<AppState>,
Extension(user_profile): Extension<V1UserProfile>,
Path((namespace, name)): Path<(String, String)>,
ValidatedJson(payload): ValidatedJson<V1UpdateReplayBufferRequest>,
) -> impl IntoResponse {
let db = state.db_pool.clone();
match _update_buffer(&db, &user_profile, &namespace, &name, &payload).await {
Ok(replay_buffer) => {
let response_body = json!({ "buffer": replay_buffer });
(StatusCode::OK, Json(response_body))
}
Err(e) => {
info!("Error updating buffer: {:?}", e);
let error_response = json!({ "error": "Failed to update buffer" });
(StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
}
}
}
pub async fn _update_buffer(
db: &DatabaseConnection,
user_profile: &V1UserProfile,
namespace: &str,
name: &str,
payload: &V1UpdateReplayBufferRequest,
) -> Result<V1ReplayBuffer, Error> {
let mut owner_ids: Vec<String> = if let Some(orgs) = &user_profile.organizations {
orgs.keys().cloned().collect()
} else {
Vec::new()
};
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(|s| s.as_str()).collect();
let buffer_model =
match Query::find_buffer_by_name_and_owners(&db, &name, &namespace, &owner_id_refs).await {
Ok(Some(model)) => model,
Ok(None) => {
return Err(anyhow::anyhow!("Buffer not found or not authorized"));
}
Err(e) => {
info!("Error fetching buffer for update: {:?}", e);
return Err(anyhow::anyhow!("Failed to retrieve buffer"));
}
};
let mut buffer_active_model: buffer::ActiveModel = buffer_model.clone().into_active_model();
let mut something_changed = false;
if let Some(new_val) = &payload.train_every {
if *new_val != buffer_model.train_every.unwrap_or_default() {
buffer_active_model.train_every = Set(Some(*new_val));
something_changed = true;
}
}
if let Some(new_val) = &payload.sample_n {
if *new_val != buffer_model.sample_n {
buffer_active_model.sample_n = Set(*new_val);
something_changed = true;
}
}
if let Some(new_val) = &payload.sample_strategy {
if new_val != &buffer_model.sample_strategy {
buffer_active_model.sample_strategy = Set(new_val.clone());
something_changed = true;
}
}
if let Some(new_val) = &payload.train_job {
let new_val_json = serde_json::to_value(new_val).unwrap_or_default();
if Some(new_val_json.clone()) != buffer_model.train_job {
buffer_active_model.train_job = Set(Some(new_val_json));
something_changed = true;
}
}
if !something_changed {
debug!("No changes detected");
return Ok(V1ReplayBuffer {
metadata: V1ResourceMeta {
id: buffer_model.id.clone(),
name: buffer_model.name.clone(),
namespace: buffer_model.namespace.clone(),
owner: buffer_model.owner_id.clone(),
labels: buffer_model
.labels
.clone()
.and_then(|v| serde_json::from_value(v).ok()),
created_at: buffer_model.created_at.timestamp(),
updated_at: buffer_model.updated_at.timestamp(),
created_by: user_profile.email.clone(),
owner_ref: None,
},
train_every: buffer_model.train_every.clone(),
sample_n: buffer_model.sample_n.clone(),
sample_strategy: buffer_model.sample_strategy.clone(),
num_epochs: buffer_model.num_epochs,
train_job: buffer_model
.train_job
.clone()
.and_then(|v| serde_json::from_value(v).ok())
.unwrap_or_default(),
status: V1ReplayBufferStatus {
num_records: buffer_model.num_records,
train_idx: buffer_model.train_idx,
num_train_jobs: None,
last_train_job: None,
num_epochs: Some(buffer_model.num_epochs),
},
});
}
buffer_active_model.updated_at = Set(Utc::now().into());
match Mutation::update_buffer(&db, &buffer_active_model).await {
Ok(updated_model) => {
let response_body = V1ReplayBuffer {
metadata: V1ResourceMeta {
id: updated_model.id.clone(),
name: updated_model.name.clone(),
namespace: updated_model.namespace.clone(),
owner: updated_model.owner_id.clone(),
labels: updated_model
.labels
.clone()
.and_then(|v| serde_json::from_value(v).ok()),
created_at: updated_model.created_at.timestamp(),
updated_at: updated_model.updated_at.timestamp(),
created_by: user_profile.email.clone(),
owner_ref: None,
},
train_every: updated_model.train_every.clone(),
sample_n: updated_model.sample_n.clone(),
sample_strategy: updated_model.sample_strategy.clone(),
num_epochs: updated_model.num_epochs,
train_job: updated_model
.train_job
.clone()
.and_then(|v| serde_json::from_value(v).ok())
.unwrap_or_default(),
status: V1ReplayBufferStatus {
num_records: updated_model.num_records,
train_idx: updated_model.train_idx,
num_train_jobs: None,
last_train_job: None,
num_epochs: Some(updated_model.num_epochs),
},
};
Ok(response_body)
}
Err(e) => {
info!("Error updating buffer: {:?}", e);
Err(anyhow::anyhow!("Failed to update buffer"))
}
}
}