use crate::config::CONFIG;
use crate::handlers::v1::buffers::{_create_buffer, _send_examples, _train_buffer, _update_buffer};
use crate::mutation::Mutation;
use crate::query::Query;
use crate::resources::v1::buffers::models::{
V1ReplayBuffer, V1ReplayBufferData, V1UpdateReplayBufferRequest,
};
use crate::resources::v1::llms::models::{
V1BufferOption, V1BufferOptionRequest, V1OnlineLLM, V1OnlineLLMRequest, V1OnlineLLMStatus,
V1OnlineLLMs, V1ResourceReference, V1ServerOption, V1ServerOptionRequest, V1UpdateBufferOption,
V1UpdateOnlineLLMRequest, V1UpdateServerOption,
};
use crate::state::AppState;
use anyhow::{anyhow, bail};
use axum::http::StatusCode;
use axum::{
extract::{Extension, Json, Path, State},
response::IntoResponse,
};
use chrono::Utc;
use nebulous::client::NebulousClient;
use nebulous::models::{V1ResourceMeta, V1ResourceMetaRequest};
use openai_api_rs::v1::api::OpenAIClient;
use openai_api_rs::v1::chat_completion::ChatCompletionRequest;
use sea_orm::ActiveModelTrait;
use serde_json::json;
use short_uuid::ShortUuid;
use tracing::{debug, error, info};
/// Best-effort deletion of a buffer that was created earlier in the same request.
///
/// Does nothing when `was_created` is `false` (the buffer pre-existed and must be kept).
/// A failed delete is logged together with `error_context` and otherwise ignored.
///
/// This is a synchronous helper called from async handlers. On a multi-threaded Tokio
/// runtime it blocks the current worker (via `block_in_place`) until the delete finishes.
/// `block_in_place` panics on a current-thread runtime, so there the delete is spawned as
/// a detached task instead of being awaited.
///
/// # Panics
/// Panics if called outside of a Tokio runtime context (`Handle::current`).
fn cleanup_buffer_if_created(
    db: &sea_orm::DatabaseConnection,
    buffer_id: &str,
    was_created: bool,
    error_context: &str,
) {
    if !was_created {
        return;
    }
    debug!("Cleaning up buffer: {:?}", buffer_id);
    let handle = tokio::runtime::Handle::current();
    if matches!(
        handle.runtime_flavor(),
        tokio::runtime::RuntimeFlavor::CurrentThread
    ) {
        // Blocking is impossible here; run the delete in the background instead.
        let db = db.clone();
        let buffer_id = buffer_id.to_owned();
        let error_context = error_context.to_owned();
        handle.spawn(async move {
            if let Err(cleanup_err) = Mutation::delete_buffer(&db, &buffer_id).await {
                error!(
                    "Failed to clean up buffer after {}: {}",
                    error_context, cleanup_err
                );
            }
        });
        return;
    }
    if let Err(cleanup_err) =
        tokio::task::block_in_place(|| handle.block_on(Mutation::delete_buffer(db, buffer_id)))
    {
        error!(
            "Failed to clean up buffer after {}: {}",
            error_context, cleanup_err
        );
    }
}
#[axum::debug_handler]
pub async fn create_llm(
State(state): State<AppState>,
Extension(user_profile): Extension<crate::models::V1UserProfile>,
Json(payload): Json<V1OnlineLLMRequest>,
) -> impl IntoResponse {
let db = state.db_pool.clone();
let id = ShortUuid::generate().to_string();
let mut owner_ids: Vec<String> = if let Some(orgs) = &user_profile.organizations {
orgs.keys().cloned().collect()
} else {
Vec::new()
};
owner_ids.push(user_profile.email.clone());
let name = match payload.metadata.name {
Some(name) => name,
None => petname::petname(3, "-").unwrap(),
};
let namespace = match payload.metadata.namespace {
Some(namespace) => namespace,
None => user_profile.handle.clone().unwrap_or(
user_profile
.email
.clone()
.replace("@", "-")
.replace(".", "-"),
),
};
let owner = payload
.metadata
.owner
.clone()
.filter(|o| !o.is_empty())
.unwrap_or_else(|| user_profile.email.clone());
if !owner_ids.contains(&owner) {
return (
StatusCode::FORBIDDEN,
Json(json!({ "error": "Unauthorized owner specified" })),
);
}
let (buffer_id, buffer, buffer_was_created) = match payload.buffer.clone() {
V1BufferOptionRequest::Id(id) => {
debug!("Finding buffer by id: {:?}", id);
let owner_id_refs: Vec<&str> = owner_ids.iter().map(String::as_str).collect();
let buffer = match Query::find_buffer_by_id_and_owners(&db, &id, &owner_id_refs).await {
Ok(Some(buf)) => buf,
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({ "error": "Buffer not found" })),
);
}
Err(err) => {
error!("Failed to find buffer: {:?}", err);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("DB error: {}", err) })),
);
}
};
debug!("Found buffer: {:?}", buffer);
(
buffer.id.clone(),
V1ReplayBuffer::from_model(&buffer.clone()),
false,
)
}
V1BufferOptionRequest::Buffer(mut buffer) => {
debug!("Creating buffer: {:?}", buffer);
let mut meta = V1ResourceMetaRequest::default();
meta.name = Some(name.clone());
meta.namespace = Some(namespace.clone());
meta.owner = Some(owner.clone());
meta.owner_ref = Some(format!("{}.{}.OnlineLLM", name, namespace));
buffer.metadata = meta;
let buffer = match _create_buffer(
&db,
id.clone(),
owner.clone(),
&user_profile,
buffer.clone(),
)
.await
{
Ok(buffer) => buffer,
Err(err) => {
error!("Failed to create buffer: {:?}", err);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Failed to create buffer: {}", err) })),
);
}
};
debug!("Created buffer: {:?}", buffer);
(buffer.metadata.id.clone(), buffer.clone(), true)
}
};
let client = match NebulousClient::new_from_config() {
Ok(client) => client,
Err(err) => {
error!("Nebulous client error: {:?}", err);
cleanup_buffer_if_created(&db, &buffer_id, buffer_was_created, "Nebulous client error");
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous error: {}", err) })),
);
}
};
debug!("Nebulous client: {:?}", client);
let (server_ref, server) = match payload.server {
V1ServerOptionRequest::Ref(reference) => {
match reference.kind.as_str().to_lowercase().as_str() {
"container" => {
debug!("Checking if server exists: {:?}", reference);
let found = match client
.get_container(&reference.name, &reference.namespace)
.await
{
Ok(container) => container,
Err(err) => {
error!("Container lookup error: {:?}", err);
cleanup_buffer_if_created(
&db,
&buffer_id,
buffer_was_created,
"container lookup error",
);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous error: {}", err) })),
);
}
};
(reference.clone(), found.clone())
}
_ => {
error!("Invalid server kind: {:?}", reference);
cleanup_buffer_if_created(
&db,
&buffer_id,
buffer_was_created,
"invalid server kind",
);
return (
StatusCode::BAD_REQUEST,
Json(json!({ "error": "Invalid server kind" })),
);
}
}
}
V1ServerOptionRequest::Server(mut server) => {
debug!("Creating server container: {:?}", server);
let mut meta = V1ResourceMetaRequest::default();
meta.name = Some(name.clone());
meta.namespace = Some(namespace.clone());
meta.owner = Some(owner.clone());
meta.owner_ref = Some(format!("{}.{}.OnlineLLM", name, namespace));
server.metadata = Some(meta);
let container = match client.create_container(&server).await {
Ok(container) => container.clone(),
Err(err) => {
error!("Container creation error: {:?}", err);
cleanup_buffer_if_created(
&db,
&buffer_id,
buffer_was_created,
"container creation error",
);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous error: {}", err) })),
);
}
};
debug!("Server container: {:?}", container);
(
V1ResourceReference {
kind: "Container".to_string(),
name: container.metadata.name.clone(),
namespace: container.metadata.namespace.clone(),
},
container,
)
}
_ => {
error!("Invalid server option: {:?}", payload.server);
cleanup_buffer_if_created(&db, &buffer_id, buffer_was_created, "invalid server option");
return (
StatusCode::BAD_REQUEST,
Json(json!({ "error": "Invalid server kind" })),
);
}
};
debug!("Server ref: {:?}", server_ref);
let now = Utc::now();
let full_name = format!("{}/{}", buffer.metadata.namespace, buffer.metadata.name);
let llm = crate::entities::llm::ActiveModel {
id: sea_orm::Set(id.clone()),
namespace: sea_orm::Set(buffer.metadata.namespace.clone()),
name: sea_orm::Set(buffer.metadata.name.clone()),
full_name: sea_orm::Set(full_name),
owner_id: sea_orm::Set(owner),
buffer_id: sea_orm::Set(buffer_id),
model: sea_orm::Set(payload.model),
server_ref: sea_orm::Set(server_ref.to_string_encoded()),
status: sea_orm::Set(None),
labels: sea_orm::Set(None),
chat_schema: sea_orm::Set(payload.chat_schema),
created_at: sea_orm::Set(now.into()),
updated_at: sea_orm::Set(now.into()),
created_by: sea_orm::Set(user_profile.email.clone()),
..Default::default()
};
debug!("LLM: {:?}", llm);
let inserted_llm: crate::entities::llm::Model = match llm.insert(&db).await {
Ok(model) => model,
Err(err) => {
error!("Failed to insert LLM: {:?}", err);
cleanup_buffer_if_created(
&db,
&buffer.metadata.id,
buffer_was_created,
"Failed to insert LLM",
);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Failed to insert LLM: {}", err) })),
);
}
};
debug!("Inserted new LLM with id {:?}", inserted_llm.id);
(
StatusCode::CREATED,
Json(json!(V1OnlineLLM {
metadata: V1ResourceMeta {
id: inserted_llm.id,
name: inserted_llm.name,
namespace: inserted_llm.namespace,
owner: inserted_llm.owner_id,
created_at: inserted_llm.created_at.timestamp(),
updated_at: inserted_llm.updated_at.timestamp(),
created_by: inserted_llm.created_by,
labels: inserted_llm.labels.map(|json_value| {
serde_json::from_value::<std::collections::HashMap<String, String>>(json_value)
.unwrap_or_default()
}),
owner_ref: None,
},
model: inserted_llm.model,
buffer: buffer,
server: server,
chat_schema: inserted_llm.chat_schema,
status: V1OnlineLLMStatus {
is_online: None,
endpoint: None,
last_error: None,
},
})),
)
}
pub async fn list_llms(
State(state): State<AppState>,
Extension(user_profile): Extension<crate::models::V1UserProfile>,
) -> impl IntoResponse {
debug!("Listing LLMs for user {:?}", user_profile.email);
let db = state.db_pool.clone();
let mut owner_ids: Vec<String> = if let Some(orgs) = &user_profile.organizations {
orgs.keys().cloned().collect()
} else {
Vec::new()
};
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(|s| s.as_str()).collect();
debug!("Owner IDs: {:?}", owner_id_refs);
let llm_models = match Query::find_llms_by_owners(&db, &owner_id_refs).await {
Ok(models) => models,
Err(err) => {
debug!("Failed to query LLMs: {}", err);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Failed to query LLMs: {}", err) })),
)
.into_response();
}
};
debug!("Found {} LLMs", llm_models.len());
if llm_models.is_empty() {
debug!("No LLMs found, returning empty list");
let response = V1OnlineLLMs { llms: Vec::new() };
return (StatusCode::OK, Json(response)).into_response();
}
let client = match NebulousClient::new_from_config() {
Ok(client) => client,
Err(err) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous error: {}", err) })),
)
.into_response();
}
};
debug!("Nebulous client initialized");
let mut llm_list = Vec::new();
for llm_model in llm_models {
debug!("Finding buffer for LLM {:?}", llm_model.id);
let buffer =
match Query::find_buffer_by_id_and_owners(&db, &llm_model.buffer_id, &owner_id_refs)
.await
{
Ok(Some(buf)) => V1ReplayBuffer::from_model(&buf),
_ => {
return (
StatusCode::NOT_FOUND,
Json(json!({ "error": format!("Buffer not found for LLM {}", llm_model.id) })),
).into_response();
}
};
debug!("Found buffer: {:?}", buffer);
let server_ref_str = llm_model.server_ref.clone();
let server_ref = match V1ResourceReference::from_str_encoded(&server_ref_str) {
Ok(sr) => sr,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Failed to parse server ref: {}", e) })),
)
.into_response();
}
};
debug!("Parsing server ref: {:?}", server_ref);
let server_container = match client
.get_container(&server_ref.name, &server_ref.namespace)
.await
{
Ok(container) => container,
Err(err) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous error: {}", err) })),
)
.into_response();
}
};
debug!("Found server container: {:?}", server_container);
let metadata = V1ResourceMeta {
id: llm_model.id.clone(),
name: llm_model.name.clone(),
namespace: llm_model.namespace.clone(),
owner: llm_model.owner_id.clone(),
created_at: llm_model.created_at.timestamp(),
updated_at: llm_model.updated_at.timestamp(),
created_by: llm_model.created_by,
labels: llm_model.labels.map(|json_value| {
serde_json::from_value::<std::collections::HashMap<String, String>>(json_value)
.unwrap_or_default()
}),
owner_ref: None,
};
debug!("Constructed metadata: {:?}", metadata);
let status = V1OnlineLLMStatus {
is_online: None,
endpoint: None,
last_error: None,
};
llm_list.push(V1OnlineLLM {
metadata,
model: llm_model.model,
buffer,
server: server_container,
chat_schema: llm_model.chat_schema,
status,
});
debug!("Pushed LLM into list");
}
let response = V1OnlineLLMs { llms: llm_list };
debug!("Returning response: {:?}", json!(response.clone()));
(StatusCode::OK, Json(response)).into_response()
}
async fn fetch_llm_v1(
db_pool: &sea_orm::DatabaseConnection,
client: &nebulous::client::NebulousClient,
namespace: &str,
name: &str,
owner_id_refs: &[&str],
) -> Result<V1OnlineLLM, anyhow::Error> {
let llm_model =
Query::find_llm_by_name_and_namespace_and_owners(db_pool, name, namespace, owner_id_refs)
.await?
.ok_or_else(|| anyhow!("LLM not found"))?;
let buffer_model =
Query::find_buffer_by_id_and_owners(db_pool, &llm_model.buffer_id, owner_id_refs)
.await?
.ok_or_else(|| anyhow!("Buffer not found for LLM {}", llm_model.id))?;
let buffer = V1ReplayBuffer::from_model(&buffer_model);
let server_ref_str = llm_model.server_ref.clone();
let server_ref = V1ResourceReference::from_str_encoded(&server_ref_str)
.map_err(|e| anyhow!("Failed to parse server ref: {}", e))?;
if server_ref.kind.to_lowercase() != "container" {
bail!("Unsupported server ref kind: {}", server_ref.kind);
}
let server_container = client
.get_container(&server_ref.name, &server_ref.namespace)
.await
.map_err(|e| anyhow!("Nebulous error: {}", e))?;
let metadata = V1ResourceMeta {
id: llm_model.id.clone(),
name: llm_model.name.clone(),
namespace: llm_model.namespace.clone(),
owner: llm_model.owner_id.clone(),
created_at: llm_model.created_at.timestamp(),
updated_at: llm_model.updated_at.timestamp(),
created_by: llm_model.created_by,
labels: llm_model.labels.map(|json_value| {
serde_json::from_value::<std::collections::HashMap<String, String>>(json_value)
.unwrap_or_default()
}),
owner_ref: None,
};
let status = V1OnlineLLMStatus {
is_online: None,
endpoint: None,
last_error: None,
};
Ok(V1OnlineLLM {
metadata,
model: llm_model.model,
buffer,
server: server_container,
chat_schema: llm_model.chat_schema,
status,
})
}
pub async fn get_llm(
Path((namespace, name)): Path<(String, String)>,
State(state): State<AppState>,
Extension(user_profile): Extension<crate::models::V1UserProfile>,
) -> impl IntoResponse {
let mut owner_ids: Vec<String> = if let Some(orgs) = &user_profile.organizations {
orgs.keys().cloned().collect()
} else {
Vec::new()
};
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(|s| s.as_str()).collect();
let client = match NebulousClient::new_from_config() {
Ok(client) => client,
Err(err) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous error: {}", err) })),
)
.into_response();
}
};
match fetch_llm_v1(&state.db_pool, &client, &namespace, &name, &owner_id_refs).await {
Ok(llm) => (StatusCode::OK, Json(llm)).into_response(),
Err(e) => {
let err_str = e.to_string();
if err_str.contains("not found") {
(StatusCode::NOT_FOUND, Json(json!({ "error": err_str }))).into_response()
} else {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": err_str })),
)
.into_response()
}
}
}
}
#[axum::debug_handler]
pub async fn patch_llm(
Path((namespace, name)): Path<(String, String)>,
State(state): State<AppState>,
Extension(user_profile): Extension<crate::models::V1UserProfile>,
Json(payload): Json<V1UpdateOnlineLLMRequest>,
) -> impl IntoResponse {
debug!(
"Patching LLM {:?} in namespace {:?} with payload: {:?}",
name, namespace, payload
);
let db = state.db_pool.clone();
let mut owner_ids: Vec<String> = if let Some(orgs) = &user_profile.organizations {
orgs.keys().cloned().collect()
} else {
Vec::new()
};
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(String::as_str).collect();
let llm_model = match Query::find_llm_by_name_and_namespace_and_owners(
&db,
&name,
&namespace,
&owner_id_refs,
)
.await
{
Ok(Some(llm)) => llm,
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": "LLM not found"})),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to query LLM: {}", e)})),
)
.into_response();
}
};
if let Some(new_buffer) = &payload.buffer {
match new_buffer {
V1UpdateBufferOption::Id(new_id) => {
debug!("New buffer ID: {:?}", new_id);
if *new_id != llm_model.buffer_id {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": "Changing buffer ID not allowed via patch"})),
)
.into_response();
}
}
V1UpdateBufferOption::Buffer(updated_buffer) => {
debug!("Updating buffer: {:?}", updated_buffer);
_update_buffer(&db, &user_profile, &namespace, &name, updated_buffer).await;
}
}
}
if let Some(new_server) = &payload.server {
let server_ref_str = llm_model.server_ref.clone();
let old_ref = match V1ResourceReference::from_str_encoded(&server_ref_str) {
Ok(sr) => sr,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to parse existing server ref: {}", e)})),
)
.into_response();
}
};
let client = match NebulousClient::new_from_config() {
Ok(c) => c,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Nebulous error: {}", e)})),
)
.into_response();
}
};
let old_container = match client
.get_container(&old_ref.name, &old_ref.namespace)
.await
{
Ok(c) => c,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous read error: {}", e) })),
)
.into_response();
}
};
let no_delete = payload.no_delete.unwrap_or(false);
match new_server {
V1UpdateServerOption::Ref(new_ref) => {
let is_different = new_ref.name != old_container.metadata.name
|| new_ref.namespace != old_container.metadata.namespace
|| new_ref.kind.to_lowercase() != "container";
if is_different {
if payload.no_delete.unwrap_or(false) {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "Container changes require deletion, but no_delete=true"
})),
)
.into_response();
}
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": "Container changes require deletion, but no_delete=true"
})),
)
.into_response();
}
if let Err(e) = client
.delete_container(
&old_container.metadata.name,
&old_container.metadata.namespace,
)
.await
{
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to delete container: {}", e)})),
)
.into_response();
}
match client
.get_container(&new_ref.name, &new_ref.namespace)
.await
{
Ok(container) => container,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous error: {}", e) })),
)
.into_response();
}
};
let updated_ref_str = new_ref.to_string_encoded();
if let Err(e) = Mutation::update_llm(
&db,
&crate::entities::llm::ActiveModel {
id: sea_orm::Set(llm_model.id.clone()),
server_ref: sea_orm::Set(updated_ref_str),
..Default::default()
},
)
.await
{
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("DB error: {}", e)})),
)
.into_response();
}
}
V1UpdateServerOption::Server(new_container) => {
debug!("Updating server: {:?}", new_container);
let created = match client
.patch_container(
&old_container.metadata.name,
&old_container.metadata.namespace,
&new_container,
)
.await
{
Ok(container) => container,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to create container: {}", e)})),
)
.into_response();
}
};
debug!("Created new container: {:?}", created);
let encoded_ref = V1ResourceReference {
kind: "Container".to_string(),
name: created.metadata.name.clone(),
namespace: created.metadata.namespace.clone(),
}
.to_string_encoded();
if let Err(e) = Mutation::update_llm(
&db,
&crate::entities::llm::ActiveModel {
id: sea_orm::Set(llm_model.id.clone()),
server_ref: sea_orm::Set(encoded_ref),
..Default::default()
},
)
.await
{
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("DB error: {}", e)})),
)
.into_response();
}
}
}
}
if let Some(new_schema) = &payload.chat_schema {
debug!("Updating chat schema: {:?}", new_schema);
if let Err(e) = Mutation::update_llm(
&db,
&crate::entities::llm::ActiveModel {
id: sea_orm::Set(llm_model.id.clone()),
chat_schema: sea_orm::Set(Some(new_schema.clone())),
..Default::default()
},
)
.await
{
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("DB error: {}", e)})),
)
.into_response();
}
}
let client = match NebulousClient::new_from_config() {
Ok(c) => c,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Nebulous error: {}", e)})),
)
.into_response();
}
};
debug!("Fetching updated LLM");
match fetch_llm_v1(&db, &client, &namespace, &name, &owner_id_refs).await {
Ok(updated_llm) => (StatusCode::OK, Json(updated_llm)).into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": e.to_string()})),
)
.into_response(),
}
}
pub async fn delete_llm(
Path((namespace, name)): Path<(String, String)>,
State(state): State<AppState>,
Extension(user_profile): Extension<crate::models::V1UserProfile>,
) -> impl IntoResponse {
let db = state.db_pool.clone();
debug!("Deleting LLM {:?} in namespace {:?}", name, namespace);
let mut owner_ids: Vec<String> = user_profile
.organizations
.as_ref()
.map(|orgs| orgs.keys().cloned().collect())
.unwrap_or_default();
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(String::as_str).collect();
let llm_model = match Query::find_llm_by_name_and_namespace_and_owners(
&db,
&name,
&namespace,
&owner_id_refs,
)
.await
{
Ok(Some(llm)) => llm,
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({ "error": "LLM not found" })),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Failed to query LLM: {}", e) })),
)
.into_response();
}
};
let buffer_model = match Query::find_buffer_by_id_and_owners(
&db,
&llm_model.buffer_id,
&owner_id_refs,
)
.await
{
Ok(Some(buf)) => buf,
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({ "error": "Buffer not found" })),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Failed to query buffer: {}", e) })),
)
.into_response();
}
};
let server_ref_str = llm_model.server_ref.clone();
let server_ref = match V1ResourceReference::from_str_encoded(&server_ref_str) {
Ok(sr) => sr,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Failed to parse server ref: {}", e) })),
)
.into_response();
}
};
if server_ref.kind.to_lowercase() != "container" {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Unsupported server ref kind: {}", server_ref.kind) })),
)
.into_response();
}
let client = match NebulousClient::new_from_config() {
Ok(client) => client,
Err(err) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous error: {}", err) })),
)
.into_response();
}
};
debug!("Nebulous client initialized");
let server_container = match client
.get_container(&server_ref.name, &server_ref.namespace)
.await
{
Ok(container) => container,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous error: {}", e) })),
)
.into_response();
}
};
debug!(
"Deleting container {:?} in namespace {:?}",
server_container.metadata.name, server_container.metadata.namespace
);
match client
.delete_container(
&server_container.metadata.name,
&server_container.metadata.namespace,
)
.await
{
Ok(_) => debug!("Container deleted successfully"),
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous error: {}", e) })),
)
.into_response();
}
}
debug!("Container deleted");
debug!("Deleting buffer {:?} in namespace {:?}", name, namespace);
match Mutation::delete_buffer(&db, &buffer_model.id).await {
Ok(_) => (),
Err(e) => {
info!("Error deleting buffer: {:?}", e);
let error_response = json!({ "error": "Failed to delete buffer" });
return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response();
}
}
debug!("Buffer deleted");
debug!("Deleting LLM {:?} in namespace {:?}", name, namespace);
match Mutation::delete_llm_by_name_namespace_and_owners(
&state.db_pool,
&name,
&namespace,
&owner_id_refs,
)
.await
{
Ok(0) => {
return (
StatusCode::NOT_FOUND,
Json(json!({ "error": "LLM not found" })),
)
.into_response();
}
Ok(_) => {
debug!("LLM deleted");
return StatusCode::NO_CONTENT.into_response();
}
Err(err) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": format!("DB error: {}", err)
})),
)
.into_response();
}
}
}
pub async fn chat_llm(
Path((namespace, name)): Path<(String, String)>,
State(state): State<AppState>,
Extension(user_profile): Extension<crate::models::V1UserProfile>,
Json(raw_json): Json<serde_json::Value>,
) -> impl IntoResponse {
debug!("Chatting with LLM {:?} in namespace {:?}", name, namespace);
let payload = match serde_json::from_value::<ChatCompletionRequest>(raw_json.clone()) {
Ok(parsed) => parsed,
Err(e) => {
debug!("Error parsing payload: {}", e);
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!("Invalid JSON for ChatCompletionRequest: {}", e)
})),
)
.into_response();
}
};
debug!("Payload: {:?}", payload);
let mut owner_ids: Vec<String> = user_profile
.organizations
.as_ref()
.map(|orgs| orgs.keys().cloned().collect())
.unwrap_or_default();
owner_ids.push(user_profile.email.clone());
let owner_id_refs: Vec<&str> = owner_ids.iter().map(String::as_str).collect();
let nebu_client = match NebulousClient::new_from_config() {
Ok(client) => client,
Err(err) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Nebulous error: {}", err) })),
)
.into_response();
}
};
let llm = match fetch_llm_v1(
&state.db_pool,
&nebu_client,
&namespace,
&name,
&owner_id_refs,
)
.await
{
Ok(llm) => llm,
Err(err) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": err.to_string() })),
)
.into_response();
}
};
debug!("LLM found: {:?}", llm.metadata.id);
let resource_ref = V1ResourceReference {
kind: "Container".to_string(),
name: llm.server.metadata.name.clone(),
namespace: llm.server.metadata.namespace.clone(),
};
let mut oai_client = match OpenAIClient::builder()
.with_api_key(nebu_client.api_key.clone())
.with_endpoint(format!("{}/v1", CONFIG.nebu_proxy_url))
.with_header("X-Resource", resource_ref.to_string_encoded())
.build()
{
Ok(client) => client,
Err(err) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": err.to_string() })),
)
.into_response();
}
};
debug!("OAI client built");
let result = match oai_client.chat_completion(payload).await {
Ok(result) => result,
Err(err) => {
debug!("Error calling OAI client: {}", err);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": err.to_string() })),
)
.into_response();
}
};
println!("Content: {:?}", result.choices[0].message.content);
(StatusCode::OK, Json(result)).into_response()
}
/// Send training examples to the buffer backing the LLM at `namespace`/`name`.
///
/// Delegates to `_send_examples` with the caller's profile and returns its result with
/// `200 OK`. Any error becomes `500` with the error message.
/// NOTE(review): this handler performs no ownership check itself; presumably
/// `_send_examples` does, using `user_profile` — confirm.
pub async fn learn_llm(
    Path((namespace, name)): Path<(String, String)>,
    State(state): State<AppState>,
    Extension(user_profile): Extension<crate::models::V1UserProfile>,
    Json(payload): Json<V1ReplayBufferData>,
) -> impl IntoResponse {
    debug!("Learning LLM {:?} in namespace {:?}", name, namespace);
    let db_pool = state.db_pool.clone();
    debug!(
        "Sending examples to LLM {:?} in namespace {:?}",
        name, namespace
    );
    match _send_examples(&db_pool, &state, &user_profile, &namespace, &name, &payload).await {
        Ok(res) => (StatusCode::OK, Json(res)).into_response(),
        Err(err) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({ "error": err.to_string() })),
        )
            .into_response(),
    }
}
/// Trigger training on the buffer backing the LLM at `namespace`/`name`.
///
/// Delegates to `_train_buffer` with the caller's profile and returns its result with
/// `200 OK`. Any error becomes `500` with the error message.
/// NOTE(review): no ownership check happens here; presumably `_train_buffer` performs
/// it using `user_profile` — confirm.
pub async fn train_llm(
    Path((namespace, name)): Path<(String, String)>,
    State(state): State<AppState>,
    Extension(user_profile): Extension<crate::models::V1UserProfile>,
) -> impl IntoResponse {
    let db_pool = state.db_pool.clone();
    match _train_buffer(&db_pool, &state, &user_profile, &namespace, &name).await {
        Ok(res) => (StatusCode::OK, Json(res)).into_response(),
        Err(err) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({ "error": err.to_string() })),
        )
            .into_response(),
    }
}