1use crate::catalog::{Catalog, db_err};
10use orbok_core::{ModelId, OrbokResult, now_iso8601};
11use rusqlite::params;
12
/// Role a model is registered under.
///
/// Each variant has a fixed lowercase text form (see [`ModelRole::as_str`]),
/// which is the value bound to the `role` column of the `models` table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelRole {
    /// Text form: `"embedding"`.
    Embedding,
    /// Text form: `"reranker"`.
    Reranker,
}

impl ModelRole {
    /// Every variant; [`ModelRole::parse`] scans this table.
    ///
    /// Keep it in sync with the enum when adding a variant.
    const ALL: [Self; 2] = [Self::Embedding, Self::Reranker];

    /// Returns the canonical lowercase text form of this role.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Embedding => "embedding",
            Self::Reranker => "reranker",
        }
    }

    /// Parses the text form produced by [`ModelRole::as_str`].
    ///
    /// The match is exact (case-sensitive, no trimming); any other input
    /// yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // Reverse lookup through `as_str` so both directions share one
        // spelling of each name.
        Self::ALL.iter().copied().find(|role| role.as_str() == s)
    }
}
35
/// Status recorded for a registered model.
///
/// Each variant has a fixed lowercase text form (see [`ModelStatus::as_str`]),
/// which is the value bound to the `status` column of the `models` table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelStatus {
    /// Text form: `"available"`.
    Available,
    /// Text form: `"missing"`.
    Missing,
    /// Text form: `"invalid"`.
    Invalid,
    /// Text form: `"installing"`.
    Installing,
    /// Text form: `"disabled"`.
    Disabled,
}

impl ModelStatus {
    /// Every variant; [`ModelStatus::parse`] scans this table.
    ///
    /// Keep it in sync with the enum when adding a variant.
    const ALL: [Self; 5] = [
        Self::Available,
        Self::Missing,
        Self::Invalid,
        Self::Installing,
        Self::Disabled,
    ];

    /// Returns the canonical lowercase text form of this status.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Available => "available",
            Self::Missing => "missing",
            Self::Invalid => "invalid",
            Self::Installing => "installing",
            Self::Disabled => "disabled",
        }
    }

    /// Parses the text form produced by [`ModelStatus::as_str`].
    ///
    /// The match is exact (case-sensitive, no trimming); any other input
    /// yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // Reverse lookup through `as_str` so both directions share one
        // spelling of each name.
        Self::ALL
            .iter()
            .copied()
            .find(|status| status.as_str() == s)
    }
}
67
68#[derive(Debug, Clone)]
70pub struct ModelRecord {
71 pub model_id: ModelId,
72 pub role: ModelRole,
73 pub model_name: String,
74 pub model_version: String,
75 pub local_path: Option<String>,
76 pub license_summary: Option<String>,
77 pub size_bytes: Option<u64>,
78 pub backend: Option<String>,
79 pub dimension: Option<u32>,
80 pub status: ModelStatus,
81 pub last_validated_at: Option<String>,
82}
83
84#[derive(Debug, Clone)]
86pub struct NewModel {
87 pub role: ModelRole,
88 pub model_name: String,
89 pub model_version: String,
90 pub local_path: Option<String>,
91 pub license_summary: Option<String>,
92 pub size_bytes: Option<u64>,
93 pub backend: Option<String>,
94 pub dimension: Option<u32>,
95 pub status: ModelStatus,
96}
97
98pub struct ModelRepository<'a> {
99 catalog: &'a Catalog,
100}
101
102impl<'a> ModelRepository<'a> {
103 pub fn new(catalog: &'a Catalog) -> Self {
104 Self { catalog }
105 }
106
107 pub fn insert(&self, new: NewModel) -> OrbokResult<ModelRecord> {
109 let id = ModelId::generate();
110 let now = now_iso8601();
111 let conn = self.catalog.lock();
112 conn.execute(
113 "INSERT INTO models \
114 (model_id, role, model_name, model_version, local_path, license_summary, \
115 size_bytes, backend, dimension, status, created_at, updated_at) \
116 VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?11)",
117 params![
118 id.as_str(),
119 new.role.as_str(),
120 new.model_name,
121 new.model_version,
122 new.local_path,
123 new.license_summary,
124 new.size_bytes.map(|v| v as i64),
125 new.backend,
126 new.dimension.map(|v| v as i64),
127 new.status.as_str(),
128 now,
129 ],
130 )
131 .map_err(db_err)?;
132 drop(conn);
133 self.get(&id)?.ok_or(orbok_core::OrbokError::SourceNotFound)
134 }
135
136 pub fn get(&self, id: &ModelId) -> OrbokResult<Option<ModelRecord>> {
138 let conn = self.catalog.lock();
139 let result = conn.query_row(
140 "SELECT model_id, role, model_name, model_version, local_path, license_summary, \
141 size_bytes, backend, dimension, status, last_validated_at \
142 FROM models WHERE model_id = ?1",
143 params![id.as_str()],
144 row_to_record,
145 );
146 match result {
147 Ok(r) => Ok(Some(r)),
148 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
149 Err(e) => Err(db_err(e)),
150 }
151 }
152
153 pub fn list_by_role(&self, role: ModelRole) -> OrbokResult<Vec<ModelRecord>> {
155 let conn = self.catalog.lock();
156 let mut stmt = conn
157 .prepare(
158 "SELECT model_id, role, model_name, model_version, local_path, license_summary, \
159 size_bytes, backend, dimension, status, last_validated_at \
160 FROM models WHERE role = ?1 ORDER BY model_name, model_version",
161 )
162 .map_err(db_err)?;
163 let rows = stmt
164 .query_map(params![role.as_str()], row_to_record)
165 .map_err(db_err)?;
166 let mut out = Vec::new();
167 for row in rows {
168 out.push(row.map_err(db_err)?);
169 }
170 Ok(out)
171 }
172
173 pub fn list_all(&self) -> OrbokResult<Vec<ModelRecord>> {
175 let conn = self.catalog.lock();
176 let mut stmt = conn
177 .prepare(
178 "SELECT model_id, role, model_name, model_version, local_path, license_summary, \
179 size_bytes, backend, dimension, status, last_validated_at \
180 FROM models ORDER BY role, model_name",
181 )
182 .map_err(db_err)?;
183 let rows = stmt.query_map([], row_to_record).map_err(db_err)?;
184 let mut out = Vec::new();
185 for row in rows {
186 out.push(row.map_err(db_err)?);
187 }
188 Ok(out)
189 }
190
191 pub fn set_status(&self, id: &ModelId, status: ModelStatus) -> OrbokResult<()> {
193 let conn = self.catalog.lock();
194 conn.execute(
195 "UPDATE models SET status = ?2, updated_at = ?3 WHERE model_id = ?1",
196 params![id.as_str(), status.as_str(), now_iso8601()],
197 )
198 .map_err(db_err)?;
199 Ok(())
200 }
201
202 pub fn validate(&self, id: &ModelId, expected_dim: Option<u32>) -> OrbokResult<ModelStatus> {
205 let record = match self.get(id)? {
206 Some(r) => r,
207 None => return Ok(ModelStatus::Missing),
208 };
209 let status = if let Some(path) = &record.local_path {
210 if std::path::Path::new(path).exists() {
211 if let (Some(expected), Some(actual)) = (expected_dim, record.dimension) {
213 if expected != actual {
214 ModelStatus::Invalid
215 } else {
216 ModelStatus::Available
217 }
218 } else {
219 ModelStatus::Available
220 }
221 } else {
222 ModelStatus::Missing
223 }
224 } else {
225 ModelStatus::Missing
226 };
227 let now = now_iso8601();
228 {
229 let conn = self.catalog.lock();
230 conn.execute(
231 "UPDATE models SET status = ?2, last_validated_at = ?3, updated_at = ?3 \
232 WHERE model_id = ?1",
233 params![id.as_str(), status.as_str(), now],
234 )
235 .map_err(db_err)?;
236 }
237 Ok(status)
238 }
239
240 pub fn locate(
243 &self,
244 path: &str,
245 role: ModelRole,
246 name: &str,
247 version: &str,
248 dimension: Option<u32>,
249 ) -> OrbokResult<ModelRecord> {
250 let size_bytes = std::fs::metadata(path).map(|m| m.len()).ok();
251 let record = self.insert(NewModel {
252 role,
253 model_name: name.to_string(),
254 model_version: version.to_string(),
255 local_path: Some(path.to_string()),
256 license_summary: None,
257 size_bytes,
258 backend: None,
259 dimension,
260 status: if size_bytes.is_some() {
261 ModelStatus::Available
262 } else {
263 ModelStatus::Missing
264 },
265 })?;
266 Ok(record)
267 }
268
269 pub fn mark_embedding_dependents_stale(&self, model_id: &ModelId) -> OrbokResult<u64> {
272 let conn = self.catalog.lock();
273 let n = conn
274 .execute(
275 "UPDATE embeddings SET status = 'stale', updated_at = ?2 WHERE model_id = ?1",
276 params![model_id.as_str(), now_iso8601()],
277 )
278 .map_err(db_err)?;
279 Ok(n as u64)
280 }
281}
282
283fn row_to_record(row: &rusqlite::Row<'_>) -> rusqlite::Result<ModelRecord> {
284 Ok(ModelRecord {
285 model_id: ModelId::from_string(row.get::<_, String>(0)?),
286 role: {
287 let s: String = row.get(1)?;
288 ModelRole::parse(&s).unwrap_or(ModelRole::Embedding)
289 },
290 model_name: row.get(2)?,
291 model_version: row.get(3)?,
292 local_path: row.get(4)?,
293 license_summary: row.get(5)?,
294 size_bytes: row.get::<_, Option<i64>>(6)?.map(|v| v as u64),
295 backend: row.get(7)?,
296 dimension: row.get::<_, Option<i64>>(8)?.map(|v| v as u32),
297 status: {
298 let s: String = row.get(9)?;
299 ModelStatus::parse(&s).unwrap_or(ModelStatus::Missing)
300 },
301 last_validated_at: row.get(10)?,
302 })
303}
304
305pub fn verify_model_sha256(path: &str, expected_hash: &str) -> OrbokResult<bool> {
312 use sha2::{Digest, Sha256};
313 use std::io::Read;
314 let mut file = std::fs::File::open(path).map_err(|e| orbok_core::OrbokError::Io(e))?;
315 let mut hasher = Sha256::new();
316 let mut buf = [0u8; 64 * 1024];
317 loop {
318 let n = file.read(&mut buf).map_err(orbok_core::OrbokError::Io)?;
319 if n == 0 {
320 break;
321 }
322 hasher.update(&buf[..n]);
323 }
324 let actual: String = hasher
325 .finalize()
326 .iter()
327 .map(|b| format!("{b:02x}"))
328 .collect();
329 Ok(actual == expected_hash)
330}