use crate::catalog::backend::{
BackendError, IsolationLevel, Row, SqlValue, Transaction, TxOptions,
};
use crate::error::{JammiError, Result};
use crate::model_task::ModelTask;
use crate::tenant::TenantId;
use super::Catalog;
/// Builds the catalog primary key stored in the `models.model_id` column for one model
/// version.
///
/// The key is `"{name}::{version}"`. When a tenant is present, it is prefixed with the
/// tenant's `Display` form followed by `::`.
pub(crate) fn model_pk(tenant: Option<TenantId>, name: &str, version: i64) -> String {
    if let Some(owner) = tenant {
        format!("{owner}::{name}::{version}")
    } else {
        format!("{name}::{version}")
    }
}
/// One model version as read back from the `models` catalog table.
///
/// Rows are decoded by `parse_model_row`; see the field notes for which columns feed which
/// fields.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ModelRecord {
    /// User-facing model name (read from the `name` column).
    pub model_id: String,
    /// Catalog primary key (read from the `model_id` column). `register_model` writes this
    /// value using `model_pk`, i.e. `[tenant::]name::version`.
    pub catalog_pk: String,
    /// Model version. `parse_model_row` falls back to `1` when the column yields no value.
    pub version: i32,
    /// Free-form model type string as supplied at registration.
    pub model_type: String,
    /// Optional base model, stored inside the JSON `metadata` column.
    pub base_model_id: Option<String>,
    /// Backend identifier; empty when the column yields no value.
    pub backend: String,
    /// Task the model serves, decoded with `ModelTask::try_from_db_str`.
    pub task: ModelTask,
    /// Optional location of the model artifact.
    pub artifact_path: Option<String>,
    /// Optional configuration string, stored inside the JSON `metadata` column.
    pub config_json: Option<String>,
    /// Lifecycle status; newly registered rows are inserted with `'registered'`.
    pub status: String,
    /// Creation timestamp as text; empty when the column yields no value.
    pub created_at: String,
}
/// Lightweight summary of a model: a subset of [`ModelRecord`] fields.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ModelDescriptor {
    /// User-facing model name (same as `ModelRecord::model_id`).
    pub model_id: String,
    /// Backend identifier.
    pub backend: String,
    /// Task the model serves.
    pub task: ModelTask,
    /// Lifecycle status.
    pub status: String,
}
impl From<&ModelRecord> for ModelDescriptor {
fn from(record: &ModelRecord) -> Self {
Self {
model_id: record.model_id.clone(),
backend: record.backend.clone(),
task: record.task,
status: record.status.clone(),
}
}
}
/// Borrowed input for `Catalog::register_model`.
#[derive(Debug)]
pub struct RegisterModelParams<'a> {
    /// User-facing model name; combined with tenant and version to form the catalog key.
    pub model_id: &'a str,
    /// Model version; part of the catalog key.
    pub version: i32,
    /// Free-form model type string.
    pub model_type: &'a str,
    /// Backend identifier.
    pub backend: &'a str,
    /// Task the model serves; persisted via `ModelTask::as_db_str`.
    pub task: ModelTask,
    /// Optional base model; persisted inside the JSON `metadata` column.
    pub base_model_id: Option<&'a str>,
    /// Optional artifact location. `None` keeps any existing value when re-registering.
    pub artifact_path: Option<&'a str>,
    /// Optional configuration string; persisted inside the JSON `metadata` column.
    pub config_json: Option<&'a str>,
}
/// Column list shared by every `SELECT` on `models` in this module.
///
/// It must stay in sync with the column names read by `parse_model_row`.
const SELECT_COLS: &str =
    "model_id, name, model_type, task, backend, version, status, metadata, artifact_path, \
     created_at";
impl Catalog {
    /// Registers a model version under the current tenant, or updates it in place if a row
    /// with the same catalog key already exists.
    ///
    /// The primary key is built by `model_pk` from the current tenant, `params.model_id` and
    /// `params.version`. New rows start with status `'registered'`.
    ///
    /// On conflict the existing row keeps its `status`:
    /// - `metadata`, `backend`, `task` and `model_type` are overwritten.
    /// - `artifact_path` is only replaced when a new one is supplied.
    /// - `updated_at` is bumped.
    ///
    /// `base_model_id` and `config_json` are persisted together as a JSON object in the
    /// `metadata` column.
    ///
    /// # Errors
    /// Propagates failures from the tenant assertion, the insert, or the transaction.
    pub async fn register_model(&self, params: RegisterModelParams<'_>) -> Result<()> {
        let tenant = self.current_tenant();
        // Lossless widening, computed once for both the key and the column value.
        let version = i64::from(params.version);
        let pk = model_pk(tenant, params.model_id, version);
        let metadata = serde_json::json!({
            "base_model_id": params.base_model_id,
            "config_json": params.config_json,
        })
        .to_string();
        let model_id = params.model_id.to_string();
        let model_type = params.model_type.to_string();
        let task = params.task.as_db_str();
        let backend = params.backend.to_string();
        let artifact_path = params.artifact_path.map(str::to_string);
        self.backend()
            .transaction(TxOptions::default(), |tx| {
                Box::pin(async move {
                    tx.set_tenant(tenant);
                    tx.assert_tenant_matches(tenant, "models")?;
                    tx.execute(
                        "INSERT INTO models (model_id, name, model_type, task, backend, version, status, metadata, artifact_path, tenant_id) \
                         VALUES ($1, $2, $3, $4, $5, $6, 'registered', $7, $8, $9) \
                         ON CONFLICT(model_id) DO UPDATE SET \
                         metadata = excluded.metadata, \
                         backend = excluded.backend, \
                         task = excluded.task, \
                         model_type = excluded.model_type, \
                         artifact_path = COALESCE(excluded.artifact_path, models.artifact_path), \
                         updated_at = CAST(CURRENT_TIMESTAMP AS TEXT)",
                        &[
                            SqlValue::TextOwned(pk),
                            SqlValue::TextOwned(model_id),
                            SqlValue::TextOwned(model_type),
                            SqlValue::Text(task),
                            SqlValue::TextOwned(backend),
                            SqlValue::Int(version),
                            SqlValue::TextOwned(metadata),
                            SqlValue::from(artifact_path),
                            SqlValue::from(tenant.map(|t| t.to_string())),
                        ],
                    )
                    .await?;
                    Ok(())
                })
            })
            .await?;
        Ok(())
    }

    /// Returns the newest visible version of the model named `model_id`, if any.
    ///
    /// Both the current tenant's rows and global rows (`tenant_id IS NULL`) are candidates.
    /// A tenant-owned row wins over a global one; after that, the highest `version` wins.
    /// With no current tenant the tenant parameter binds as NULL. `tenant_id = NULL` never
    /// matches, so only global rows are considered.
    pub async fn get_model(&self, model_id: &str) -> Result<Option<ModelRecord>> {
        let sql = format!(
            "SELECT {SELECT_COLS} FROM models \
             WHERE name = $1 AND (tenant_id = $2 OR tenant_id IS NULL) \
             ORDER BY (tenant_id IS NOT NULL) DESC, version DESC LIMIT 1"
        );
        let mid = model_id.to_string();
        let tenant = self.current_tenant();
        Ok(self
            .backend()
            .transaction(
                TxOptions {
                    read_only: true,
                    ..Default::default()
                },
                |tx| {
                    Box::pin(async move {
                        tx.query_opt(
                            &sql,
                            &[
                                SqlValue::TextOwned(mid),
                                SqlValue::from(tenant.map(|t| t.to_string())),
                            ],
                            parse_model_row,
                        )
                        .await
                    })
                },
            )
            .await?)
    }

    /// Returns the exact `version` of the model named `model_id` visible to the current
    /// tenant, if any.
    ///
    /// Visibility rules match `get_model`: tenant-owned rows are preferred over global rows
    /// (`tenant_id IS NULL`).
    pub async fn get_model_version(
        &self,
        model_id: &str,
        version: i32,
    ) -> Result<Option<ModelRecord>> {
        let sql = format!(
            "SELECT {SELECT_COLS} FROM models \
             WHERE name = $1 AND version = $2 \
             AND (tenant_id = $3 OR tenant_id IS NULL) \
             ORDER BY (tenant_id IS NOT NULL) DESC LIMIT 1"
        );
        let mid = model_id.to_string();
        let v = i64::from(version);
        let tenant = self.current_tenant();
        Ok(self
            .backend()
            .transaction(
                TxOptions {
                    read_only: true,
                    ..Default::default()
                },
                |tx| {
                    Box::pin(async move {
                        tx.query_opt(
                            &sql,
                            &[
                                SqlValue::TextOwned(mid),
                                SqlValue::Int(v),
                                SqlValue::from(tenant.map(|t| t.to_string())),
                            ],
                            parse_model_row,
                        )
                        .await
                    })
                },
            )
            .await?)
    }

    /// Deletes one version of a model owned by the current tenant.
    ///
    /// `version = Some(v)` targets that exact version. `None` targets the version that
    /// `get_model` resolves (the newest visible one). Deletion is refused while any check in
    /// `REFERENCE_EDGES` still finds referencing rows. The reference scan and the delete run
    /// in one transaction requested at `Serializable` isolation.
    ///
    /// # Errors
    /// - `JammiError::ModelNotFound` if no matching row is found, or none is deleted. When
    ///   `if_exists` is set, both cases succeed as a no-op instead.
    /// - `JammiError::ModelReferenced` listing the edges that still reference the model.
    /// - Backend and transaction errors are propagated.
    pub async fn delete_model(
        &self,
        model_id: &str,
        version: Option<i32>,
        if_exists: bool,
    ) -> Result<()> {
        let record = match version {
            Some(v) => self.get_model_version(model_id, v).await?,
            None => self.get_model(model_id).await?,
        };
        let record = match record {
            Some(r) => r,
            None if if_exists => return Ok(()),
            None => {
                return Err(JammiError::ModelNotFound {
                    model_id: model_id.to_string(),
                })
            }
        };
        let pk = record.catalog_pk;
        let name = record.model_id;
        let tenant = self.current_tenant();
        let outcome = self
            .backend()
            .transaction(
                TxOptions {
                    isolation: IsolationLevel::Serializable,
                    ..Default::default()
                },
                |tx| {
                    Box::pin(async move {
                        tx.set_tenant(tenant);
                        tx.assert_tenant_matches(tenant, "models")?;
                        let tenant_val = SqlValue::from(tenant.map(|t| t.to_string()));
                        let referenced_by =
                            scan_model_references(tx, &name, &pk, &tenant_val).await?;
                        if !referenced_by.is_empty() {
                            return Ok(DeleteOutcome::Referenced(referenced_by));
                        }
                        let affected = tx
                            .execute(
                                "DELETE FROM models \
                                 WHERE model_id = $1 \
                                 AND (tenant_id = $2 OR (tenant_id IS NULL AND $2 IS NULL))",
                                &[SqlValue::TextOwned(pk), tenant_val],
                            )
                            .await?;
                        Ok(DeleteOutcome::Deleted(affected))
                    })
                },
            )
            .await?;
        match outcome {
            DeleteOutcome::Referenced(referenced_by) => Err(JammiError::ModelReferenced {
                model_id: model_id.to_string(),
                referenced_by,
            }),
            // The lookup above ran in its own transaction. The row may therefore have been
            // deleted concurrently, or it may be a global row this tenant's DELETE cannot
            // match. Honour `if_exists` here the same way as for a failed lookup.
            DeleteOutcome::Deleted(0) if if_exists => Ok(()),
            DeleteOutcome::Deleted(0) => Err(JammiError::ModelNotFound {
                model_id: model_id.to_string(),
            }),
            DeleteOutcome::Deleted(_) => Ok(()),
        }
    }

    /// Lists every model row visible to the current tenant (its own rows plus global rows),
    /// oldest first.
    ///
    /// `model_id` is the unique catalog key targeted by `ON CONFLICT(model_id)` in
    /// `register_model`. It breaks ties between rows with equal `created_at`, so the order
    /// is deterministic even when several models are registered within the timestamp's
    /// resolution.
    pub async fn list_models(&self) -> Result<Vec<ModelRecord>> {
        let sql = format!(
            "SELECT {SELECT_COLS} FROM models \
             WHERE (tenant_id = $1 OR tenant_id IS NULL) \
             ORDER BY created_at, model_id"
        );
        let tenant = self.current_tenant();
        Ok(self
            .backend()
            .transaction(
                TxOptions {
                    read_only: true,
                    ..Default::default()
                },
                |tx| {
                    Box::pin(async move {
                        tx.query(
                            &sql,
                            &[SqlValue::from(tenant.map(|t| t.to_string()))],
                            parse_model_row,
                        )
                        .await
                    })
                },
            )
            .await?)
    }
}
/// Result of the delete transaction in `Catalog::delete_model`.
enum DeleteOutcome {
    /// The model is still referenced; carries the names of the referencing edges.
    Referenced(Vec<String>),
    /// The `DELETE` ran; carries the number of rows it affected.
    Deleted(u64),
}
/// One "does anything still point at this model?" check used by `scan_model_references`.
struct ReferenceEdge {
    /// Label reported back to the caller when the check finds at least one row.
    name: &'static str,
    /// Counting query returning a column `n`. `$1` is the lookup key and `$2` the tenant
    /// value.
    sql: &'static str,
    /// `true` binds the catalog primary key as `$1`; `false` binds the model name.
    keyed_by_pk: bool,
}
/// Every catalog table/column checked by `scan_model_references` before a model row is
/// deleted.
///
/// Each query counts matching rows as `n`. It binds `$1` to the lookup key and `$2` to the
/// tenant value; a NULL tenant matches only rows whose `tenant_id` is NULL.
// NOTE(review): the first two edges are looked up by model *name*, the last two by the
// tenant-qualified catalog key of a single version. If those name-keyed columns hold bare
// names, they block deletion of every version of that model. Confirm this matches the schema.
const REFERENCE_EDGES: [ReferenceEdge; 4] = [
    ReferenceEdge {
        name: "result_tables",
        sql: "SELECT COUNT(*) AS n FROM result_tables \
              WHERE model_id = $1 AND (tenant_id = $2 OR (tenant_id IS NULL AND $2 IS NULL))",
        keyed_by_pk: false,
    },
    ReferenceEdge {
        name: "training_jobs.output_model_id",
        sql:
            "SELECT COUNT(*) AS n FROM training_jobs \
             WHERE output_model_id = $1 AND (tenant_id = $2 OR (tenant_id IS NULL AND $2 IS NULL))",
        keyed_by_pk: false,
    },
    ReferenceEdge {
        name: "training_jobs.base_model_id",
        sql: "SELECT COUNT(*) AS n FROM training_jobs \
              WHERE base_model_id = $1 AND (tenant_id = $2 OR (tenant_id IS NULL AND $2 IS NULL))",
        keyed_by_pk: true,
    },
    ReferenceEdge {
        name: "eval_runs",
        sql: "SELECT COUNT(*) AS n FROM eval_runs \
              WHERE model_id = $1 AND (tenant_id = $2 OR (tenant_id IS NULL AND $2 IS NULL))",
        keyed_by_pk: true,
    },
];
async fn scan_model_references(
tx: &mut Transaction<'_>,
name: &str,
pk: &str,
tenant_val: &SqlValue<'_>,
) -> std::result::Result<Vec<String>, BackendError> {
let mut referenced_by = Vec::new();
for edge in &REFERENCE_EDGES {
let key = if edge.keyed_by_pk { pk } else { name };
let count = tx
.query_opt(
edge.sql,
&[SqlValue::TextOwned(key.to_string()), tenant_val.clone()],
|row| row.get::<i64>("n"),
)
.await?
.unwrap_or(0);
if count > 0 {
referenced_by.push(edge.name.to_string());
}
}
Ok(referenced_by)
}
/// Decodes one row selected with `SELECT_COLS` into a [`ModelRecord`].
///
/// - Errors reading `model_id`, `name`, `model_type` or `task` are propagated.
/// - A `task` string that `ModelTask::try_from_db_str` rejects becomes
///   `BackendError::TypeConversion`.
/// - Columns read with `try_get` fall back when they yield no value: empty strings for
///   `backend`, `status` and `created_at`, and `1` for `version`.
/// - The `metadata` column is parsed as JSON on a best-effort basis. If it is absent or not
///   valid JSON, both `base_model_id` and `config_json` are `None`. Otherwise each is taken
///   from the same-named key when that key holds a JSON string.
fn parse_model_row(row: &Row<'_>) -> std::result::Result<ModelRecord, BackendError> {
    let catalog_pk: String = row.get("model_id")?;
    let model_id: String = row.get("name")?;
    let model_type: String = row.get("model_type")?;
    let task_text: String = row.get("task")?;
    let task = match ModelTask::try_from_db_str(&task_text) {
        Ok(parsed) => parsed,
        Err(err) => {
            return Err(BackendError::TypeConversion {
                column: "task".into(),
                detail: err.to_string(),
            })
        }
    };
    let backend: String = row.try_get("backend")?.unwrap_or_default();
    let version: i32 = row.try_get("version")?.unwrap_or(1);
    let status: String = row.try_get("status")?.unwrap_or_default();
    let metadata: Option<String> = row.try_get("metadata")?;
    let created_at: String = row.try_get("created_at")?.unwrap_or_default();
    let artifact_path: Option<String> = row.try_get("artifact_path")?;

    // Malformed metadata is tolerated: it simply yields no extra fields.
    let parsed_meta = metadata
        .as_deref()
        .and_then(|text| serde_json::from_str::<serde_json::Value>(text).ok());
    let string_field = |key: &str| -> Option<String> {
        parsed_meta
            .as_ref()
            .and_then(|meta| meta[key].as_str())
            .map(str::to_owned)
    };
    let base_model_id = string_field("base_model_id");
    let config_json = string_field("config_json");

    Ok(ModelRecord {
        model_id,
        catalog_pk,
        version,
        model_type,
        base_model_id,
        backend,
        task,
        artifact_path,
        config_json,
        status,
        created_at,
    })
}