1use chrono::{DateTime, Utc};
5use modelexpress_common::models::{ModelProvider, ModelStatus};
6use rusqlite::{Connection, Result as SqliteResult, params};
7use std::sync::{Arc, Mutex};
8use tracing::info;
9
/// One row of the `models` table, decoded into Rust types.
#[derive(Debug, Clone)]
pub struct ModelRecord {
    /// Model identifier; the primary key of the `models` table.
    pub model_name: String,
    /// Provider the model belongs to (stored as text in the `provider` column).
    pub provider: ModelProvider,
    /// Current lifecycle status (stored as text in the `status` column).
    pub status: ModelStatus,
    /// When the row was first inserted; later `set_status` calls carry it forward.
    pub created_at: DateTime<Utc>,
    /// Last time the model was written or touched; rows are ordered by this for
    /// least-recently-used listing.
    pub last_used_at: DateTime<Utc>,
    /// Optional free-form status message (e.g. `Starting download...`).
    pub message: Option<String>,
}
20
/// SQLite-backed registry tracking the status of each known model.
///
/// All operations go through a single connection guarded by a `Mutex`. Because the
/// connection sits behind an `Arc`, cloning a `ModelDatabase` yields another handle
/// to the same connection rather than opening a new one.
#[derive(Debug, Clone)]
pub struct ModelDatabase {
    // Shared, mutex-serialised SQLite connection; see `acquire_connection`.
    connection: Arc<Mutex<Connection>>,
}
26
27impl ModelDatabase {
28 fn acquire_connection(&self) -> SqliteResult<std::sync::MutexGuard<'_, Connection>> {
30 self.connection.lock().map_err(|_| {
31 rusqlite::Error::SqliteFailure(
32 rusqlite::ffi::Error::new(rusqlite::ffi::SQLITE_LOCKED),
33 Some("Rust mutex protecting database connection is poisoned".to_string()),
34 )
35 })
36 }
37
38 pub fn new(database_path: &str) -> Result<Self, Box<dyn std::error::Error>> {
40 let conn = Connection::open(database_path)?;
41
42 conn.execute(
44 r"
45 CREATE TABLE IF NOT EXISTS models (
46 model_name TEXT PRIMARY KEY,
47 provider TEXT NOT NULL,
48 status TEXT NOT NULL,
49 created_at TEXT NOT NULL,
50 last_used_at TEXT NOT NULL,
51 message TEXT
52 )
53 ",
54 [],
55 )?;
56
57 conn.execute(
59 "CREATE INDEX IF NOT EXISTS idx_last_used_at ON models(last_used_at)",
60 [],
61 )?;
62
63 info!("Model database initialized at: {}", database_path);
64
65 Ok(Self {
66 connection: Arc::new(Mutex::new(conn)),
67 })
68 }
69
70 pub fn get_status(&self, model_name: &str) -> SqliteResult<Option<ModelStatus>> {
72 let conn = self.acquire_connection()?;
73 let mut stmt = conn.prepare("SELECT status FROM models WHERE model_name = ?1")?;
74
75 let mut rows = stmt.query_map([model_name], |row| {
76 let status_str: String = row.get(0)?;
77 Ok(status_str)
78 })?;
79
80 if let Some(row) = rows.next() {
81 let status_str = row?;
82 let status = match status_str.as_str() {
83 "DOWNLOADING" => ModelStatus::DOWNLOADING,
84 "DOWNLOADED" => ModelStatus::DOWNLOADED,
85 "ERROR" => ModelStatus::ERROR,
86 _ => ModelStatus::ERROR,
87 };
88 Ok(Some(status))
89 } else {
90 Ok(None)
91 }
92 }
93
94 pub fn get_model_record(&self, model_name: &str) -> SqliteResult<Option<ModelRecord>> {
96 let conn = self.acquire_connection()?;
97 let mut stmt = conn.prepare(
98 "SELECT model_name, provider, status, created_at, last_used_at, message FROM models WHERE model_name = ?1"
99 )?;
100
101 let mut rows = stmt.query_map([model_name], |row| {
102 let provider_str: String = row.get(1)?;
103 let status_str: String = row.get(2)?;
104 let created_at_str: String = row.get(3)?;
105 let last_used_at_str: String = row.get(4)?;
106 let message: Option<String> = row.get(5)?;
107
108 let provider = match provider_str.as_str() {
109 "HuggingFace" => ModelProvider::HuggingFace,
110 _ => ModelProvider::HuggingFace, };
112
113 let status = match status_str.as_str() {
114 "DOWNLOADING" => ModelStatus::DOWNLOADING,
115 "DOWNLOADED" => ModelStatus::DOWNLOADED,
116 "ERROR" => ModelStatus::ERROR,
117 _ => ModelStatus::ERROR,
118 };
119
120 let created_at = DateTime::parse_from_rfc3339(&created_at_str)
121 .map_err(|_| {
122 rusqlite::Error::InvalidColumnType(
123 3,
124 "created_at".to_string(),
125 rusqlite::types::Type::Text,
126 )
127 })?
128 .with_timezone(&Utc);
129
130 let last_used_at = DateTime::parse_from_rfc3339(&last_used_at_str)
131 .map_err(|_| {
132 rusqlite::Error::InvalidColumnType(
133 4,
134 "last_used_at".to_string(),
135 rusqlite::types::Type::Text,
136 )
137 })?
138 .with_timezone(&Utc);
139
140 Ok(ModelRecord {
141 model_name: row.get(0)?,
142 provider,
143 status,
144 created_at,
145 last_used_at,
146 message,
147 })
148 })?;
149
150 if let Some(row) = rows.next() {
151 Ok(Some(row?))
152 } else {
153 Ok(None)
154 }
155 }
156
157 pub fn set_status(
159 &self,
160 model_name: &str,
161 provider: ModelProvider,
162 status: ModelStatus,
163 message: Option<String>,
164 ) -> SqliteResult<()> {
165 let conn = self.acquire_connection()?;
166 let now = Utc::now();
167
168 let provider_str = match provider {
169 ModelProvider::HuggingFace => "HuggingFace",
170 };
171
172 let status_str = match status {
173 ModelStatus::DOWNLOADING => "DOWNLOADING",
174 ModelStatus::DOWNLOADED => "DOWNLOADED",
175 ModelStatus::ERROR => "ERROR",
176 };
177
178 conn.execute(
180 r"
181 INSERT OR REPLACE INTO models (model_name, provider, status, created_at, last_used_at, message)
182 VALUES (?1, ?2, ?3,
183 COALESCE((SELECT created_at FROM models WHERE model_name = ?1), ?4),
184 ?4, ?5)
185 ",
186 params![
187 model_name,
188 provider_str,
189 status_str,
190 now.to_rfc3339(),
191 message
192 ],
193 )?;
194
195 Ok(())
196 }
197
198 pub fn touch_model(&self, model_name: &str) -> SqliteResult<()> {
200 let conn = self.acquire_connection()?;
201 let now = Utc::now();
202
203 conn.execute(
204 "UPDATE models SET last_used_at = ?1 WHERE model_name = ?2",
205 params![now.to_rfc3339(), model_name],
206 )?;
207
208 Ok(())
209 }
210
211 pub fn delete_model(&self, model_name: &str) -> SqliteResult<()> {
213 let conn = self.acquire_connection()?;
214 conn.execute("DELETE FROM models WHERE model_name = ?1", [model_name])?;
215 Ok(())
216 }
217
218 pub fn get_models_by_last_used(&self, limit: Option<u32>) -> SqliteResult<Vec<ModelRecord>> {
220 let conn = self.acquire_connection()?;
221
222 let query = if let Some(limit) = limit {
223 format!(
224 "SELECT model_name, provider, status, created_at, last_used_at, message FROM models ORDER BY last_used_at ASC LIMIT {limit}"
225 )
226 } else {
227 "SELECT model_name, provider, status, created_at, last_used_at, message FROM models ORDER BY last_used_at ASC".to_string()
228 };
229
230 let mut stmt = conn.prepare(&query)?;
231 let rows = stmt.query_map([], |row| {
232 let provider_str: String = row.get(1)?;
233 let status_str: String = row.get(2)?;
234 let created_at_str: String = row.get(3)?;
235 let last_used_at_str: String = row.get(4)?;
236 let message: Option<String> = row.get(5)?;
237
238 let provider = match provider_str.as_str() {
239 "HuggingFace" => ModelProvider::HuggingFace,
240 _ => ModelProvider::HuggingFace,
241 };
242
243 let status = match status_str.as_str() {
244 "DOWNLOADING" => ModelStatus::DOWNLOADING,
245 "DOWNLOADED" => ModelStatus::DOWNLOADED,
246 "ERROR" => ModelStatus::ERROR,
247 _ => ModelStatus::ERROR,
248 };
249
250 let created_at = DateTime::parse_from_rfc3339(&created_at_str)
251 .map_err(|_| {
252 rusqlite::Error::InvalidColumnType(
253 3,
254 "created_at".to_string(),
255 rusqlite::types::Type::Text,
256 )
257 })?
258 .with_timezone(&Utc);
259
260 let last_used_at = DateTime::parse_from_rfc3339(&last_used_at_str)
261 .map_err(|_| {
262 rusqlite::Error::InvalidColumnType(
263 4,
264 "last_used_at".to_string(),
265 rusqlite::types::Type::Text,
266 )
267 })?
268 .with_timezone(&Utc);
269
270 Ok(ModelRecord {
271 model_name: row.get(0)?,
272 provider,
273 status,
274 created_at,
275 last_used_at,
276 message,
277 })
278 })?;
279
280 let mut models = Vec::new();
281 for row in rows {
282 models.push(row?);
283 }
284
285 Ok(models)
286 }
287
288 pub fn get_status_counts(&self) -> SqliteResult<(u32, u32, u32)> {
290 let conn = self.acquire_connection()?;
291
292 let mut downloading = 0u32;
293 let mut downloaded = 0u32;
294 let mut error = 0u32;
295
296 let mut stmt = conn.prepare("SELECT status, COUNT(*) FROM models GROUP BY status")?;
297 let rows = stmt.query_map([], |row| {
298 let status: String = row.get(0)?;
299 let count: u32 = row.get(1)?;
300 Ok((status, count))
301 })?;
302
303 for row in rows {
304 let (status, count) = row?;
305 match status.as_str() {
306 "DOWNLOADING" => downloading = count,
307 "DOWNLOADED" => downloaded = count,
308 "ERROR" => error = count,
309 _ => {}
310 }
311 }
312
313 Ok((downloading, downloaded, error))
314 }
315
316 pub fn try_claim_for_download(
322 &self,
323 model_name: &str,
324 provider: ModelProvider,
325 ) -> SqliteResult<ModelStatus> {
326 let conn = self.acquire_connection()?;
327 let now = Utc::now();
328
329 let provider_str = match provider {
330 ModelProvider::HuggingFace => "HuggingFace",
331 };
332
333 let rows_affected = conn.execute(
336 r#"
337 INSERT OR IGNORE INTO models (model_name, provider, status, created_at, last_used_at, message)
338 VALUES (?1, ?2, 'DOWNLOADING', ?3, ?3, 'Starting download...')
339 "#,
340 params![model_name, provider_str, now.to_rfc3339()],
341 )?;
342
343 if rows_affected > 0 {
344 Ok(ModelStatus::DOWNLOADING)
346 } else {
347 let mut stmt = conn.prepare("SELECT status FROM models WHERE model_name = ?1")?;
349 let mut rows = stmt.query_map([model_name], |row| {
350 let status_str: String = row.get(0)?;
351 Ok(status_str)
352 })?;
353
354 if let Some(row) = rows.next() {
355 let status_str = row?;
356 let status = match status_str.as_str() {
357 "DOWNLOADING" => ModelStatus::DOWNLOADING,
358 "DOWNLOADED" => ModelStatus::DOWNLOADED,
359 "ERROR" => ModelStatus::ERROR,
360 _ => ModelStatus::ERROR,
361 };
362 Ok(status)
363 } else {
364 Err(rusqlite::Error::QueryReturnedNoRows)
366 }
367 }
368 }
369}
370
#[cfg(test)]
// Any unexpected error should fail the test immediately; `expect` documents which step broke.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a fresh database file inside a new temporary directory.
    ///
    /// The returned `TempDir` owns the directory holding the database file, so keep it
    /// bound for the whole test.
    fn create_test_database() -> (ModelDatabase, TempDir) {
        let temp_dir = TempDir::new().expect("Failed to create temporary directory");
        let db_path = temp_dir.path().join("test_models.db");
        let db = ModelDatabase::new(db_path.to_str().expect("Invalid path"))
            .expect("Failed to create test database");
        (db, temp_dir)
    }

    /// Inserts a row with raw SQL, bypassing the typed API, so tests can simulate rows
    /// written by other code or corrupted by hand.
    fn insert_raw_row(db: &ModelDatabase, model_name: &str, status: &str, timestamp: &str) {
        let conn = db
            .acquire_connection()
            .expect("Failed to lock database connection");
        conn.execute(
            "INSERT INTO models (model_name, provider, status, created_at, last_used_at, message) \
             VALUES (?1, 'HuggingFace', ?2, ?3, ?3, NULL)",
            rusqlite::params![model_name, status, timestamp],
        )
        .expect("Failed to insert raw row");
    }

    #[test]
    fn test_database_creation() {
        let (db, _temp_dir) = create_test_database();
        let status = db
            .get_status("non-existent-model")
            .expect("Failed to get status");
        assert!(status.is_none());
    }

    #[test]
    fn test_set_and_get_status() {
        let (db, _temp_dir) = create_test_database();
        let model_name = "test-model";
        let provider = ModelProvider::HuggingFace;
        let status = ModelStatus::DOWNLOADING;

        db.set_status(model_name, provider, status, None)
            .expect("Failed to set status");

        let retrieved_status = db.get_status(model_name).expect("Failed to get status");
        assert_eq!(retrieved_status, Some(status));
    }

    #[test]
    fn test_update_status() {
        let (db, _temp_dir) = create_test_database();
        let model_name = "test-model";
        let provider = ModelProvider::HuggingFace;

        db.set_status(model_name, provider, ModelStatus::DOWNLOADING, None)
            .expect("Failed to set initial status");
        db.set_status(
            model_name,
            provider,
            ModelStatus::DOWNLOADED,
            Some("Success".to_string()),
        )
        .expect("Failed to update status");

        let status = db
            .get_status(model_name)
            .expect("Failed to get status")
            .expect("Status should be present");
        assert_eq!(status, ModelStatus::DOWNLOADED);
    }

    #[test]
    fn test_set_status_preserves_created_at() {
        let (db, _temp_dir) = create_test_database();
        let model_name = "test-model";
        let provider = ModelProvider::HuggingFace;

        db.set_status(model_name, provider, ModelStatus::DOWNLOADING, None)
            .expect("Failed to set initial status");
        let first = db
            .get_model_record(model_name)
            .expect("Failed to get initial record")
            .expect("Record should be present");

        std::thread::sleep(std::time::Duration::from_millis(10));

        db.set_status(
            model_name,
            provider,
            ModelStatus::DOWNLOADED,
            Some("done".to_string()),
        )
        .expect("Failed to update status");
        let second = db
            .get_model_record(model_name)
            .expect("Failed to get updated record")
            .expect("Record should be present");

        assert_eq!(second.created_at, first.created_at);
        assert!(second.last_used_at > first.last_used_at);
        assert_eq!(second.status, ModelStatus::DOWNLOADED);
        assert_eq!(second.message.as_deref(), Some("done"));
    }

    #[test]
    fn test_get_full_model_record() {
        let (db, _temp_dir) = create_test_database();
        let model_name = "test-model";
        let provider = ModelProvider::HuggingFace;
        let message = "Test message";

        db.set_status(
            model_name,
            provider,
            ModelStatus::DOWNLOADED,
            Some(message.to_string()),
        )
        .expect("Failed to set status");

        let record = db
            .get_model_record(model_name)
            .expect("Failed to get model record")
            .expect("Record should be present");
        assert_eq!(record.model_name, model_name);
        assert_eq!(record.provider, provider);
        assert_eq!(record.status, ModelStatus::DOWNLOADED);
        assert_eq!(record.message.as_deref(), Some(message));
    }

    #[test]
    fn test_touch_model() {
        let (db, _temp_dir) = create_test_database();
        let model_name = "test-model";
        let provider = ModelProvider::HuggingFace;

        db.set_status(model_name, provider, ModelStatus::DOWNLOADED, None)
            .expect("Failed to create model record");

        let initial_record = db
            .get_model_record(model_name)
            .expect("Failed to get initial record")
            .expect("Record should be present");

        std::thread::sleep(std::time::Duration::from_millis(10));

        db.touch_model(model_name).expect("Failed to touch model");

        let updated_record = db
            .get_model_record(model_name)
            .expect("Failed to get updated record")
            .expect("Record should be present");

        assert!(updated_record.last_used_at > initial_record.last_used_at);
        assert_eq!(updated_record.created_at, initial_record.created_at);
    }

    #[test]
    fn test_delete_model() {
        let (db, _temp_dir) = create_test_database();
        let model_name = "test-model";
        let provider = ModelProvider::HuggingFace;

        db.set_status(model_name, provider, ModelStatus::DOWNLOADED, None)
            .expect("Failed to create model record");
        assert!(db
            .get_status(model_name)
            .expect("Failed to get status")
            .is_some());

        db.delete_model(model_name).expect("Failed to delete model");
        assert!(db
            .get_status(model_name)
            .expect("Failed to get status")
            .is_none());
    }

    #[test]
    fn test_get_models_by_last_used() {
        let (db, _temp_dir) = create_test_database();
        let provider = ModelProvider::HuggingFace;

        db.set_status("model1", provider, ModelStatus::DOWNLOADED, None)
            .expect("Failed to create model1");
        std::thread::sleep(std::time::Duration::from_millis(10));
        db.set_status("model2", provider, ModelStatus::DOWNLOADED, None)
            .expect("Failed to create model2");
        std::thread::sleep(std::time::Duration::from_millis(10));
        db.set_status("model3", provider, ModelStatus::DOWNLOADED, None)
            .expect("Failed to create model3");

        let models = db
            .get_models_by_last_used(None)
            .expect("Failed to get models");
        let names: Vec<&str> = models.iter().map(|m| m.model_name.as_str()).collect();
        assert_eq!(names, ["model1", "model2", "model3"]);

        let limited_models = db
            .get_models_by_last_used(Some(2))
            .expect("Failed to get limited models");
        let limited_names: Vec<&str> = limited_models
            .iter()
            .map(|m| m.model_name.as_str())
            .collect();
        assert_eq!(limited_names, ["model1", "model2"]);
    }

    #[test]
    fn test_get_models_by_last_used_zero_limit() {
        let (db, _temp_dir) = create_test_database();
        db.set_status("model1", ModelProvider::HuggingFace, ModelStatus::DOWNLOADED, None)
            .expect("Failed to create model1");

        let models = db
            .get_models_by_last_used(Some(0))
            .expect("Failed to get models");
        assert!(models.is_empty());
    }

    #[test]
    fn test_get_status_counts() {
        let (db, _temp_dir) = create_test_database();
        let provider = ModelProvider::HuggingFace;

        assert_eq!(
            db.get_status_counts().expect("Failed to get status counts"),
            (0, 0, 0)
        );

        db.set_status("model1", provider, ModelStatus::DOWNLOADING, None)
            .expect("Failed to set model1 status");
        db.set_status("model2", provider, ModelStatus::DOWNLOADING, None)
            .expect("Failed to set model2 status");
        db.set_status("model3", provider, ModelStatus::DOWNLOADED, None)
            .expect("Failed to set model3 status");
        db.set_status("model4", provider, ModelStatus::ERROR, None)
            .expect("Failed to set model4 status");

        assert_eq!(
            db.get_status_counts().expect("Failed to get status counts"),
            (2, 1, 1)
        );
    }

    #[test]
    fn test_model_provider_string_conversion() {
        let (db, _temp_dir) = create_test_database();
        let model_name = "test-model";

        db.set_status(
            model_name,
            ModelProvider::HuggingFace,
            ModelStatus::DOWNLOADED,
            None,
        )
        .expect("Failed to set status");

        let record = db
            .get_model_record(model_name)
            .expect("Failed to get record")
            .expect("Record should be present");
        assert_eq!(record.provider, ModelProvider::HuggingFace);
    }

    #[test]
    fn test_model_status_string_conversion() {
        let (db, _temp_dir) = create_test_database();
        let provider = ModelProvider::HuggingFace;

        let statuses = [
            ModelStatus::DOWNLOADING,
            ModelStatus::DOWNLOADED,
            ModelStatus::ERROR,
        ];

        for (i, status) in statuses.iter().enumerate() {
            let model_name = format!("model{i}");
            db.set_status(&model_name, provider, *status, None)
                .expect("Failed to set status");

            let retrieved_status = db
                .get_status(&model_name)
                .expect("Failed to get status")
                .expect("Status should be present");
            assert_eq!(retrieved_status, *status);
        }
    }

    #[test]
    fn test_unknown_status_reads_as_error() {
        let (db, _temp_dir) = create_test_database();
        insert_raw_row(&db, "odd-model", "PAUSED", "2024-01-01T00:00:00+00:00");

        assert_eq!(
            db.get_status("odd-model").expect("Failed to get status"),
            Some(ModelStatus::ERROR)
        );
        let record = db
            .get_model_record("odd-model")
            .expect("Failed to get record")
            .expect("Record should be present");
        assert_eq!(record.status, ModelStatus::ERROR);
    }

    #[test]
    fn test_corrupt_timestamp_is_an_error() {
        let (db, _temp_dir) = create_test_database();
        insert_raw_row(&db, "bad-time", "DOWNLOADED", "not-a-timestamp");

        assert!(db.get_model_record("bad-time").is_err());
        assert!(db.get_models_by_last_used(None).is_err());
        // Status lookups do not read timestamps, so they are unaffected.
        assert_eq!(
            db.get_status("bad-time").expect("Failed to get status"),
            Some(ModelStatus::DOWNLOADED)
        );
    }

    #[test]
    fn test_concurrent_access() {
        let (db, _temp_dir) = create_test_database();
        let provider = ModelProvider::HuggingFace;
        let shared = &db;

        // Each thread works on its own model through the same shared handle.
        std::thread::scope(|scope| {
            for i in 0..10 {
                scope.spawn(move || {
                    let model_name = format!("model{i}");
                    shared
                        .set_status(&model_name, provider, ModelStatus::DOWNLOADED, None)
                        .expect("Failed to set status");
                    let status = shared
                        .get_status(&model_name)
                        .expect("Failed to get status");
                    assert_eq!(status, Some(ModelStatus::DOWNLOADED));
                    shared
                        .touch_model(&model_name)
                        .expect("Failed to touch model");
                });
            }
        });

        let models = db
            .get_models_by_last_used(None)
            .expect("Failed to get models");
        assert_eq!(models.len(), 10);
    }

    #[test]
    fn test_try_claim_for_download_new_model() {
        let (db, _temp_dir) = create_test_database();
        let model_name = "new-model";
        let provider = ModelProvider::HuggingFace;

        let status = db
            .try_claim_for_download(model_name, provider)
            .expect("Failed to claim for download");
        assert_eq!(status, ModelStatus::DOWNLOADING);

        let record = db
            .get_model_record(model_name)
            .expect("Failed to get record")
            .expect("Record should be present");
        assert_eq!(record.model_name, model_name);
        assert_eq!(record.provider, provider);
        assert_eq!(record.status, ModelStatus::DOWNLOADING);
    }

    #[test]
    fn test_try_claim_for_download_existing_model() {
        let (db, _temp_dir) = create_test_database();
        let model_name = "existing-model";
        let provider = ModelProvider::HuggingFace;

        db.set_status(model_name, provider, ModelStatus::DOWNLOADED, None)
            .expect("Failed to set initial status");

        let status = db
            .try_claim_for_download(model_name, provider)
            .expect("Failed to claim for download");
        assert_eq!(status, ModelStatus::DOWNLOADED);

        let record = db
            .get_model_record(model_name)
            .expect("Failed to get record")
            .expect("Record should be present");
        assert_eq!(record.model_name, model_name);
        assert_eq!(record.provider, provider);
        assert_eq!(record.status, ModelStatus::DOWNLOADED);
    }

    #[test]
    fn test_try_claim_for_download_race_condition() {
        let (db, _temp_dir) = create_test_database();
        let model_name = "race-condition-model";
        let provider = ModelProvider::HuggingFace;
        let shared = &db;

        // Several threads race to claim the same model at once.
        let statuses: Vec<ModelStatus> = std::thread::scope(|scope| {
            let handles: Vec<_> = (0..8)
                .map(|_| {
                    scope.spawn(move || shared.try_claim_for_download(model_name, provider))
                })
                .collect();
            handles
                .into_iter()
                .map(|handle| {
                    handle
                        .join()
                        .expect("Claiming thread panicked")
                        .expect("Failed to claim for download")
                })
                .collect()
        });

        assert_eq!(statuses.len(), 8);
        assert!(statuses
            .iter()
            .all(|status| *status == ModelStatus::DOWNLOADING));

        let models = db
            .get_models_by_last_used(None)
            .expect("Failed to get models");
        assert_eq!(models.len(), 1);

        let record = db
            .get_model_record(model_name)
            .expect("Failed to get record")
            .expect("Record should be present");
        assert_eq!(record.model_name, model_name);
        assert_eq!(record.provider, provider);
        assert_eq!(record.status, ModelStatus::DOWNLOADING);
    }

    #[test]
    fn test_try_claim_for_download_compare_and_swap() {
        let (db, _temp_dir) = create_test_database();
        let model_name = "test-cas-model";
        let provider = ModelProvider::HuggingFace;

        let status1 = db
            .try_claim_for_download(model_name, provider)
            .expect("Failed to claim for download 1");
        assert_eq!(status1, ModelStatus::DOWNLOADING);

        let status2 = db
            .try_claim_for_download(model_name, provider)
            .expect("Failed to claim for download 2");
        assert_eq!(status2, ModelStatus::DOWNLOADING);

        db.set_status(model_name, provider, ModelStatus::DOWNLOADED, None)
            .expect("Failed to update status");

        let status3 = db
            .try_claim_for_download(model_name, provider)
            .expect("Failed to claim for download 3");
        assert_eq!(status3, ModelStatus::DOWNLOADED);
    }
}