use crate::catalog::backend::{BackendError, Row, SqlValue, TxOptions};
use crate::error::Result;
use crate::model_task::ModelTask;
use crate::tenant::TenantId;
use super::Catalog;
/// Builds the primary key stored in the `models.model_id` column.
///
/// Rows registered under a tenant are keyed `"{tenant}::{name}::{version}"`; rows
/// registered without a tenant are keyed `"{name}::{version}"`.
///
/// NOTE(review): the `::` separator is not escaped, so a tenant-less `name` that itself
/// contains `::` (e.g. `"{tenant}::{other}"`) yields the same key as a tenant-scoped row.
/// Changing the format would change the keys of rows already stored, so it is kept as-is.
pub(crate) fn model_pk(tenant: Option<TenantId>, name: &str, version: i64) -> String {
    let tenant_prefix = tenant.map(|t| format!("{t}::")).unwrap_or_default();
    format!("{tenant_prefix}{name}::{version}")
}
/// One row of the `models` table as decoded by the catalog lookups in this module.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ModelRecord {
    /// Caller-facing model identifier (read from the `name` column).
    pub model_id: String,
    /// Primary key of the row (read from the `model_id` column); built by `model_pk`.
    pub catalog_pk: String,
    /// Model version number (defaults to 1 when the column has no value).
    pub version: i32,
    /// Model type string as registered.
    pub model_type: String,
    /// Base model id, taken from the JSON `metadata` column when present.
    pub base_model_id: Option<String>,
    /// Backend name as registered; empty when the column has no value.
    pub backend: String,
    /// Task decoded from the `task` column.
    pub task: ModelTask,
    /// Artifact path as registered, if any.
    pub artifact_path: Option<String>,
    /// Config string, taken from the JSON `metadata` column when present (not parsed here).
    pub config_json: Option<String>,
    /// Status text (new rows are inserted as `registered`); empty when the column has no value.
    pub status: String,
    /// Creation timestamp as stored text; empty when the column has no value.
    pub created_at: String,
}
/// Input for `Catalog::register_model`.
#[derive(Debug)]
pub struct RegisterModelParams<'a> {
    /// Caller-facing model id; stored in the `name` column and used to build the row key.
    pub model_id: &'a str,
    /// Version number; part of the row key.
    pub version: i32,
    /// Model type string; overwritten on re-registration.
    pub model_type: &'a str,
    /// Backend name; overwritten on re-registration.
    pub backend: &'a str,
    /// Task, stored via its database string form.
    pub task: ModelTask,
    /// Optional base model id; stored inside the JSON `metadata` column.
    pub base_model_id: Option<&'a str>,
    /// Optional artifact path; when `None` on re-registration the stored path is kept.
    pub artifact_path: Option<&'a str>,
    /// Optional config string; stored inside the JSON `metadata` column.
    pub config_json: Option<&'a str>,
}
/// Column list shared by every `SELECT` on `models` in this file; `parse_model_row`
/// looks each of these up by name.
const SELECT_COLS: &str = concat!(
    "model_id, name, model_type, task, backend, ",
    "version, status, metadata, artifact_path, created_at"
);
impl Catalog {
    /// Registers a model version for the current tenant, upserting on key conflict.
    ///
    /// The row key is built by [`model_pk`] from the current tenant, `model_id` and
    /// `version`; the caller-facing `model_id` is stored in the `name` column.
    /// `base_model_id` and `config_json` are packed into the JSON `metadata` column
    /// (as JSON `null` when absent). New rows start with status `'registered'`.
    ///
    /// If a row with the same key already exists, `metadata`, `backend`, `task` and
    /// `model_type` are overwritten, `artifact_path` is replaced only when a new one is
    /// supplied, `updated_at` is refreshed, and `status` is left unchanged.
    ///
    /// # Errors
    ///
    /// Propagates failures from the transaction, the tenant guard
    /// (`assert_tenant_matches`), or the insert itself.
    pub async fn register_model(&self, params: RegisterModelParams<'_>) -> Result<()> {
        let tenant = self.current_tenant();
        // Widen once; the same value feeds both the key and the `version` column.
        let version = i64::from(params.version);
        let pk = model_pk(tenant, params.model_id, version);
        let metadata = serde_json::json!({
            "base_model_id": params.base_model_id,
            "config_json": params.config_json,
        })
        .to_string();
        let model_id = params.model_id.to_string();
        let model_type = params.model_type.to_string();
        let task = params.task.as_db_str();
        let backend = params.backend.to_string();
        let artifact_path = params.artifact_path.map(str::to_string);
        self.backend()
            .transaction(TxOptions::default(), |tx| {
                Box::pin(async move {
                    tx.set_tenant(tenant);
                    tx.assert_tenant_matches(tenant, "models")?;
                    tx.execute(
                        "INSERT INTO models (model_id, name, model_type, task, backend, version, status, metadata, artifact_path, tenant_id) \
                         VALUES ($1, $2, $3, $4, $5, $6, 'registered', $7, $8, $9) \
                         ON CONFLICT(model_id) DO UPDATE SET \
                         metadata = excluded.metadata, \
                         backend = excluded.backend, \
                         task = excluded.task, \
                         model_type = excluded.model_type, \
                         artifact_path = COALESCE(excluded.artifact_path, models.artifact_path), \
                         updated_at = CAST(CURRENT_TIMESTAMP AS TEXT)",
                        &[
                            // `model_id` column: the composite catalog key.
                            SqlValue::TextOwned(pk),
                            // `name` column: the caller-facing model id.
                            SqlValue::TextOwned(model_id),
                            SqlValue::TextOwned(model_type),
                            SqlValue::Text(task),
                            SqlValue::TextOwned(backend),
                            SqlValue::Int(version),
                            SqlValue::TextOwned(metadata),
                            SqlValue::from(artifact_path),
                            SqlValue::from(tenant.map(|t| t.to_string())),
                        ],
                    )
                    .await?;
                    Ok(())
                })
            })
            .await?;
        Ok(())
    }

    /// Returns the preferred version of `model_id` visible to the current tenant.
    ///
    /// Candidates are rows whose `name` is `model_id` and that either belong to the
    /// current tenant or have `tenant_id IS NULL`. A tenant-owned row always wins over
    /// a `tenant_id IS NULL` row, even one with a higher version. Within the same scope
    /// the highest `version` wins. With no current tenant, `tenant_id = NULL` never
    /// matches, so only `tenant_id IS NULL` rows are considered.
    ///
    /// Returns `Ok(None)` when no row matches.
    ///
    /// # Errors
    ///
    /// Propagates transaction, query, or row-decoding failures (see `parse_model_row`).
    pub async fn get_model(&self, model_id: &str) -> Result<Option<ModelRecord>> {
        let sql = format!(
            "SELECT {SELECT_COLS} FROM models \
             WHERE name = $1 AND (tenant_id = $2 OR tenant_id IS NULL) \
             ORDER BY (tenant_id IS NOT NULL) DESC, version DESC LIMIT 1"
        );
        let mid = model_id.to_string();
        let tenant = self.current_tenant();
        Ok(self
            .backend()
            .transaction(
                TxOptions {
                    read_only: true,
                    ..Default::default()
                },
                |tx| {
                    Box::pin(async move {
                        tx.query_opt(
                            &sql,
                            &[
                                SqlValue::TextOwned(mid),
                                SqlValue::from(tenant.map(|t| t.to_string())),
                            ],
                            parse_model_row,
                        )
                        .await
                    })
                },
            )
            .await?)
    }

    /// Returns one exact `version` of `model_id` visible to the current tenant.
    ///
    /// Visibility follows `get_model`: the current tenant's row is preferred over a
    /// `tenant_id IS NULL` row for the same name and version. Returns `Ok(None)` when
    /// neither exists.
    ///
    /// # Errors
    ///
    /// Propagates transaction, query, or row-decoding failures.
    pub async fn get_model_version(
        &self,
        model_id: &str,
        version: i32,
    ) -> Result<Option<ModelRecord>> {
        let sql = format!(
            "SELECT {SELECT_COLS} FROM models \
             WHERE name = $1 AND version = $2 \
             AND (tenant_id = $3 OR tenant_id IS NULL) \
             ORDER BY (tenant_id IS NOT NULL) DESC LIMIT 1"
        );
        let mid = model_id.to_string();
        let v = i64::from(version);
        let tenant = self.current_tenant();
        Ok(self
            .backend()
            .transaction(
                TxOptions {
                    read_only: true,
                    ..Default::default()
                },
                |tx| {
                    Box::pin(async move {
                        tx.query_opt(
                            &sql,
                            &[
                                SqlValue::TextOwned(mid),
                                SqlValue::Int(v),
                                SqlValue::from(tenant.map(|t| t.to_string())),
                            ],
                            parse_model_row,
                        )
                        .await
                    })
                },
            )
            .await?)
    }

    /// Sets `status` on every version of `model_id` owned by the current scope.
    ///
    /// With a current tenant, only that tenant's rows are updated. With no tenant,
    /// only rows with `tenant_id IS NULL` are updated. Rows with `tenant_id IS NULL`
    /// are readable by every tenant (see `get_model`), but a tenant must not be able to
    /// change them, so they are deliberately excluded from tenant-scoped updates.
    /// A name that matches no rows is not treated as an error.
    ///
    /// # Errors
    ///
    /// Propagates failures from the transaction, the tenant guard, or the update.
    pub async fn update_model_status(
        &self,
        model_id: &str,
        status: super::status::ModelStatus,
    ) -> Result<()> {
        let status_str = status.to_string();
        let mid = model_id.to_string();
        let tenant = self.current_tenant();
        self.backend()
            .transaction(TxOptions::default(), |tx| {
                Box::pin(async move {
                    tx.set_tenant(tenant);
                    // Same write guard that `register_model` runs before mutating `models`.
                    tx.assert_tenant_matches(tenant, "models")?;
                    // Two statements instead of one null-safe comparison: binding NULL to
                    // `tenant_id = $3` never matches, and reusing a placeholder or using
                    // `IS NOT DISTINCT FROM` is not portable across every backend.
                    match tenant {
                        Some(t) => {
                            tx.execute(
                                "UPDATE models SET status = $1, updated_at = CAST(CURRENT_TIMESTAMP AS TEXT) \
                                 WHERE name = $2 AND tenant_id = $3",
                                &[
                                    SqlValue::TextOwned(status_str),
                                    SqlValue::TextOwned(mid),
                                    SqlValue::TextOwned(t.to_string()),
                                ],
                            )
                            .await?;
                        }
                        None => {
                            tx.execute(
                                "UPDATE models SET status = $1, updated_at = CAST(CURRENT_TIMESTAMP AS TEXT) \
                                 WHERE name = $2 AND tenant_id IS NULL",
                                &[SqlValue::TextOwned(status_str), SqlValue::TextOwned(mid)],
                            )
                            .await?;
                        }
                    }
                    Ok(())
                })
            })
            .await?;
        Ok(())
    }

    /// Lists every model row visible to the current tenant, all versions included.
    ///
    /// Visible rows are the current tenant's own rows plus rows with
    /// `tenant_id IS NULL`. Rows are ordered by `created_at`. Ties are broken by
    /// `name`, `version` and finally the unique key, so the order is stable between
    /// calls even when timestamps collide.
    ///
    /// # Errors
    ///
    /// Propagates transaction, query, or row-decoding failures.
    pub async fn list_models(&self) -> Result<Vec<ModelRecord>> {
        let sql = format!(
            "SELECT {SELECT_COLS} FROM models \
             WHERE tenant_id = $1 OR tenant_id IS NULL \
             ORDER BY created_at, name, version, model_id"
        );
        let tenant = self.current_tenant();
        Ok(self
            .backend()
            .transaction(
                TxOptions {
                    read_only: true,
                    ..Default::default()
                },
                |tx| {
                    Box::pin(async move {
                        tx.query(
                            &sql,
                            &[SqlValue::from(tenant.map(|t| t.to_string()))],
                            parse_model_row,
                        )
                        .await
                    })
                },
            )
            .await?)
    }
}
/// Decodes one `models` row (selected with `SELECT_COLS`) into a [`ModelRecord`].
///
/// Column mapping:
/// - the `model_id` column becomes `catalog_pk`;
/// - the `name` column becomes `ModelRecord::model_id`.
///
/// When `try_get` reports no value, `backend`, `status` and `created_at` fall back to
/// an empty string and `version` falls back to `1`. `base_model_id` and `config_json`
/// come from the JSON `metadata` column. A missing column value, unparsable JSON, or a
/// non-string entry yields `None` rather than an error.
///
/// # Errors
///
/// Returns a [`BackendError`] if a column cannot be read. Returns
/// `BackendError::TypeConversion` if the stored `task` string is not accepted by
/// `ModelTask::try_from_db_str`.
fn parse_model_row(row: &Row<'_>) -> std::result::Result<ModelRecord, BackendError> {
    // Column reads stay in this order so the first failing column is the one reported.
    let catalog_pk: String = row.get("model_id")?;
    let model_id: String = row.get("name")?;
    let model_type: String = row.get("model_type")?;

    let raw_task: String = row.get("task")?;
    let task = match ModelTask::try_from_db_str(&raw_task) {
        Ok(parsed) => parsed,
        Err(err) => {
            return Err(BackendError::TypeConversion {
                column: "task".into(),
                detail: err.to_string(),
            });
        }
    };

    let backend: String = row.try_get("backend")?.unwrap_or_default();
    let version: i32 = row.try_get("version")?.unwrap_or(1);
    let status: String = row.try_get("status")?.unwrap_or_default();
    let raw_metadata: Option<String> = row.try_get("metadata")?;
    let created_at: String = row.try_get("created_at")?.unwrap_or_default();
    let artifact_path: Option<String> = row.try_get("artifact_path")?;

    // Best-effort: malformed metadata simply contributes no optional fields.
    let metadata_json: Option<serde_json::Value> = raw_metadata
        .as_deref()
        .and_then(|text| serde_json::from_str(text).ok());
    let metadata_str = |key: &str| -> Option<String> {
        metadata_json
            .as_ref()?
            .get(key)?
            .as_str()
            .map(str::to_owned)
    };

    Ok(ModelRecord {
        base_model_id: metadata_str("base_model_id"),
        config_json: metadata_str("config_json"),
        model_id,
        catalog_pk,
        version,
        model_type,
        backend,
        task,
        artifact_path,
        status,
        created_at,
    })
}