Skip to main content

modelexpress_server/
database.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use chrono::{DateTime, Utc};
5use modelexpress_common::models::{ModelProvider, ModelStatus};
6use rusqlite::{Connection, Result as SqliteResult, params};
7use std::sync::{Arc, Mutex};
8use tracing::info;
9
/// Database model for tracking model download status
///
/// One value of this type corresponds to one row of the `models` table.
#[derive(Debug, Clone)]
pub struct ModelRecord {
    /// Name of the model; this is the table's primary key.
    pub model_name: String,
    /// Provider the model is obtained from.
    pub provider: ModelProvider,
    /// Current download status of the model.
    pub status: ModelStatus,
    /// Time the row was first inserted; `set_status` keeps the existing value on updates.
    pub created_at: DateTime<Utc>,
    /// Time of the most recent `set_status` or `touch_model` call for this model.
    pub last_used_at: DateTime<Utc>,
    /// Optional free-form text stored alongside the status (nullable column).
    pub message: Option<String>,
}
20
/// SQLite-based model status tracker for distributed systems
///
/// The handle is cheap to clone: every clone shares the same underlying
/// connection, and access to it is serialized through the mutex.
#[derive(Debug, Clone)]
pub struct ModelDatabase {
    /// Single SQLite connection shared by all clones of this handle.
    connection: Arc<Mutex<Connection>>,
}
26
27impl ModelDatabase {
28    /// Helper method to acquire the database connection with proper poison recovery
29    fn acquire_connection(&self) -> SqliteResult<std::sync::MutexGuard<'_, Connection>> {
30        self.connection.lock().map_err(|_| {
31            rusqlite::Error::SqliteFailure(
32                rusqlite::ffi::Error::new(rusqlite::ffi::SQLITE_LOCKED),
33                Some("Rust mutex protecting database connection is poisoned".to_string()),
34            )
35        })
36    }
37
38    /// Create a new database instance and initialize the schema
39    pub fn new(database_path: &str) -> Result<Self, Box<dyn std::error::Error>> {
40        let conn = Connection::open(database_path)?;
41
42        // Create the models table if it doesn't exist
43        conn.execute(
44            r"
45            CREATE TABLE IF NOT EXISTS models (
46                model_name TEXT PRIMARY KEY,
47                provider TEXT NOT NULL,
48                status TEXT NOT NULL,
49                created_at TEXT NOT NULL,
50                last_used_at TEXT NOT NULL,
51                message TEXT
52            )
53            ",
54            [],
55        )?;
56
57        // Create an index on last_used_at for efficient LRU queries
58        conn.execute(
59            "CREATE INDEX IF NOT EXISTS idx_last_used_at ON models(last_used_at)",
60            [],
61        )?;
62
63        info!("Model database initialized at: {}", database_path);
64
65        Ok(Self {
66            connection: Arc::new(Mutex::new(conn)),
67        })
68    }
69
70    /// Get the status of a model
71    pub fn get_status(&self, model_name: &str) -> SqliteResult<Option<ModelStatus>> {
72        let conn = self.acquire_connection()?;
73        let mut stmt = conn.prepare("SELECT status FROM models WHERE model_name = ?1")?;
74
75        let mut rows = stmt.query_map([model_name], |row| {
76            let status_str: String = row.get(0)?;
77            Ok(status_str)
78        })?;
79
80        if let Some(row) = rows.next() {
81            let status_str = row?;
82            let status = match status_str.as_str() {
83                "DOWNLOADING" => ModelStatus::DOWNLOADING,
84                "DOWNLOADED" => ModelStatus::DOWNLOADED,
85                "ERROR" => ModelStatus::ERROR,
86                _ => ModelStatus::ERROR,
87            };
88            Ok(Some(status))
89        } else {
90            Ok(None)
91        }
92    }
93
94    /// Get the full record for a model
95    pub fn get_model_record(&self, model_name: &str) -> SqliteResult<Option<ModelRecord>> {
96        let conn = self.acquire_connection()?;
97        let mut stmt = conn.prepare(
98            "SELECT model_name, provider, status, created_at, last_used_at, message FROM models WHERE model_name = ?1"
99        )?;
100
101        let mut rows = stmt.query_map([model_name], |row| {
102            let provider_str: String = row.get(1)?;
103            let status_str: String = row.get(2)?;
104            let created_at_str: String = row.get(3)?;
105            let last_used_at_str: String = row.get(4)?;
106            let message: Option<String> = row.get(5)?;
107
108            let provider = match provider_str.as_str() {
109                "HuggingFace" => ModelProvider::HuggingFace,
110                _ => ModelProvider::HuggingFace, // Default fallback
111            };
112
113            let status = match status_str.as_str() {
114                "DOWNLOADING" => ModelStatus::DOWNLOADING,
115                "DOWNLOADED" => ModelStatus::DOWNLOADED,
116                "ERROR" => ModelStatus::ERROR,
117                _ => ModelStatus::ERROR,
118            };
119
120            let created_at = DateTime::parse_from_rfc3339(&created_at_str)
121                .map_err(|_| {
122                    rusqlite::Error::InvalidColumnType(
123                        3,
124                        "created_at".to_string(),
125                        rusqlite::types::Type::Text,
126                    )
127                })?
128                .with_timezone(&Utc);
129
130            let last_used_at = DateTime::parse_from_rfc3339(&last_used_at_str)
131                .map_err(|_| {
132                    rusqlite::Error::InvalidColumnType(
133                        4,
134                        "last_used_at".to_string(),
135                        rusqlite::types::Type::Text,
136                    )
137                })?
138                .with_timezone(&Utc);
139
140            Ok(ModelRecord {
141                model_name: row.get(0)?,
142                provider,
143                status,
144                created_at,
145                last_used_at,
146                message,
147            })
148        })?;
149
150        if let Some(row) = rows.next() {
151            Ok(Some(row?))
152        } else {
153            Ok(None)
154        }
155    }
156
157    /// Set the status of a model, creating or updating the record
158    pub fn set_status(
159        &self,
160        model_name: &str,
161        provider: ModelProvider,
162        status: ModelStatus,
163        message: Option<String>,
164    ) -> SqliteResult<()> {
165        let conn = self.acquire_connection()?;
166        let now = Utc::now();
167
168        let provider_str = match provider {
169            ModelProvider::HuggingFace => "HuggingFace",
170        };
171
172        let status_str = match status {
173            ModelStatus::DOWNLOADING => "DOWNLOADING",
174            ModelStatus::DOWNLOADED => "DOWNLOADED",
175            ModelStatus::ERROR => "ERROR",
176        };
177
178        // Use INSERT OR REPLACE to handle both creation and updates
179        conn.execute(
180            r"
181            INSERT OR REPLACE INTO models (model_name, provider, status, created_at, last_used_at, message)
182            VALUES (?1, ?2, ?3,
183                COALESCE((SELECT created_at FROM models WHERE model_name = ?1), ?4),
184                ?4, ?5)
185            ",
186            params![
187                model_name,
188                provider_str,
189                status_str,
190                now.to_rfc3339(),
191                message
192            ],
193        )?;
194
195        Ok(())
196    }
197
198    /// Update the `last_used_at` timestamp for a model
199    pub fn touch_model(&self, model_name: &str) -> SqliteResult<()> {
200        let conn = self.acquire_connection()?;
201        let now = Utc::now();
202
203        conn.execute(
204            "UPDATE models SET last_used_at = ?1 WHERE model_name = ?2",
205            params![now.to_rfc3339(), model_name],
206        )?;
207
208        Ok(())
209    }
210
211    /// Delete a model record
212    pub fn delete_model(&self, model_name: &str) -> SqliteResult<()> {
213        let conn = self.acquire_connection()?;
214        conn.execute("DELETE FROM models WHERE model_name = ?1", [model_name])?;
215        Ok(())
216    }
217
218    /// Get models ordered by last used (oldest first) - for future LRU cleanup
219    pub fn get_models_by_last_used(&self, limit: Option<u32>) -> SqliteResult<Vec<ModelRecord>> {
220        let conn = self.acquire_connection()?;
221
222        let query = if let Some(limit) = limit {
223            format!(
224                "SELECT model_name, provider, status, created_at, last_used_at, message FROM models ORDER BY last_used_at ASC LIMIT {limit}"
225            )
226        } else {
227            "SELECT model_name, provider, status, created_at, last_used_at, message FROM models ORDER BY last_used_at ASC".to_string()
228        };
229
230        let mut stmt = conn.prepare(&query)?;
231        let rows = stmt.query_map([], |row| {
232            let provider_str: String = row.get(1)?;
233            let status_str: String = row.get(2)?;
234            let created_at_str: String = row.get(3)?;
235            let last_used_at_str: String = row.get(4)?;
236            let message: Option<String> = row.get(5)?;
237
238            let provider = match provider_str.as_str() {
239                "HuggingFace" => ModelProvider::HuggingFace,
240                _ => ModelProvider::HuggingFace,
241            };
242
243            let status = match status_str.as_str() {
244                "DOWNLOADING" => ModelStatus::DOWNLOADING,
245                "DOWNLOADED" => ModelStatus::DOWNLOADED,
246                "ERROR" => ModelStatus::ERROR,
247                _ => ModelStatus::ERROR,
248            };
249
250            let created_at = DateTime::parse_from_rfc3339(&created_at_str)
251                .map_err(|_| {
252                    rusqlite::Error::InvalidColumnType(
253                        3,
254                        "created_at".to_string(),
255                        rusqlite::types::Type::Text,
256                    )
257                })?
258                .with_timezone(&Utc);
259
260            let last_used_at = DateTime::parse_from_rfc3339(&last_used_at_str)
261                .map_err(|_| {
262                    rusqlite::Error::InvalidColumnType(
263                        4,
264                        "last_used_at".to_string(),
265                        rusqlite::types::Type::Text,
266                    )
267                })?
268                .with_timezone(&Utc);
269
270            Ok(ModelRecord {
271                model_name: row.get(0)?,
272                provider,
273                status,
274                created_at,
275                last_used_at,
276                message,
277            })
278        })?;
279
280        let mut models = Vec::new();
281        for row in rows {
282            models.push(row?);
283        }
284
285        Ok(models)
286    }
287
288    /// Get count of models with each status - for monitoring
289    pub fn get_status_counts(&self) -> SqliteResult<(u32, u32, u32)> {
290        let conn = self.acquire_connection()?;
291
292        let mut downloading = 0u32;
293        let mut downloaded = 0u32;
294        let mut error = 0u32;
295
296        let mut stmt = conn.prepare("SELECT status, COUNT(*) FROM models GROUP BY status")?;
297        let rows = stmt.query_map([], |row| {
298            let status: String = row.get(0)?;
299            let count: u32 = row.get(1)?;
300            Ok((status, count))
301        })?;
302
303        for row in rows {
304            let (status, count) = row?;
305            match status.as_str() {
306                "DOWNLOADING" => downloading = count,
307                "DOWNLOADED" => downloaded = count,
308                "ERROR" => error = count,
309                _ => {}
310            }
311        }
312
313        Ok((downloading, downloaded, error))
314    }
315
316    /// Atomically attempt to claim a model for downloading using compare-and-swap semantics
317    /// Returns the current status of the model:
318    /// - If model doesn't exist, creates it with DOWNLOADING status and returns DOWNLOADING
319    /// - If model exists, returns its current status without modification
320    ///   This prevents race conditions in distributed environments
321    pub fn try_claim_for_download(
322        &self,
323        model_name: &str,
324        provider: ModelProvider,
325    ) -> SqliteResult<ModelStatus> {
326        let conn = self.acquire_connection()?;
327        let now = Utc::now();
328
329        let provider_str = match provider {
330            ModelProvider::HuggingFace => "HuggingFace",
331        };
332
333        // Use INSERT OR IGNORE to atomically create the record only if it doesn't exist
334        // This is our compare-and-swap operation
335        let rows_affected = conn.execute(
336            r#"
337            INSERT OR IGNORE INTO models (model_name, provider, status, created_at, last_used_at, message)
338            VALUES (?1, ?2, 'DOWNLOADING', ?3, ?3, 'Starting download...')
339            "#,
340            params![model_name, provider_str, now.to_rfc3339()],
341        )?;
342
343        if rows_affected > 0 {
344            // We successfully inserted the record, so we claimed it for download
345            Ok(ModelStatus::DOWNLOADING)
346        } else {
347            // Record already exists, get its current status directly
348            let mut stmt = conn.prepare("SELECT status FROM models WHERE model_name = ?1")?;
349            let mut rows = stmt.query_map([model_name], |row| {
350                let status_str: String = row.get(0)?;
351                Ok(status_str)
352            })?;
353
354            if let Some(row) = rows.next() {
355                let status_str = row?;
356                let status = match status_str.as_str() {
357                    "DOWNLOADING" => ModelStatus::DOWNLOADING,
358                    "DOWNLOADED" => ModelStatus::DOWNLOADED,
359                    "ERROR" => ModelStatus::ERROR,
360                    _ => ModelStatus::ERROR,
361                };
362                Ok(status)
363            } else {
364                // This should never happen, but handle it gracefully
365                Err(rusqlite::Error::QueryReturnedNoRows)
366            }
367        }
368    }
369}
370
#[cfg(test)]
// `expect` is the intended failure mode inside tests.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a database backed by a file in a fresh temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps the directory alive for the
    /// lifetime of the test.
    fn create_test_database() -> (ModelDatabase, TempDir) {
        let temp_dir = TempDir::new().expect("Failed to create temporary directory");
        let db_file = temp_dir.path().join("test_models.db");
        let db_file = db_file.to_str().expect("Invalid path");
        let db = ModelDatabase::new(db_file).expect("Failed to create test database");
        (db, temp_dir)
    }

    /// A freshly created database answers queries and knows no models.
    #[test]
    fn test_database_creation() {
        let (db, _guard) = create_test_database();

        let lookup = db.get_status("non-existent-model");
        assert!(lookup.is_ok());
        assert!(lookup.expect("Failed to get status").is_none());
    }

    /// A status written with `set_status` is read back by `get_status`.
    #[test]
    fn test_set_and_get_status() {
        let (db, _guard) = create_test_database();
        let name = "test-model";
        let wanted = ModelStatus::DOWNLOADING;

        let written = db.set_status(name, ModelProvider::HuggingFace, wanted, None);
        assert!(written.is_ok());

        let found = db.get_status(name).expect("Failed to get status");
        assert!(found.is_some());
        assert_eq!(found.expect("Status should be present"), wanted);
    }

    /// A second `set_status` call overwrites the stored status.
    #[test]
    fn test_update_status() {
        let (db, _guard) = create_test_database();
        let name = "test-model";
        let hf = ModelProvider::HuggingFace;

        db.set_status(name, hf, ModelStatus::DOWNLOADING, None)
            .expect("Failed to set initial status");

        let note = Some("Success".to_string());
        db.set_status(name, hf, ModelStatus::DOWNLOADED, note)
            .expect("Failed to update status");

        let current = db
            .get_status(name)
            .expect("Failed to get status")
            .expect("Status should be present");
        assert_eq!(current, ModelStatus::DOWNLOADED);
    }

    /// `get_model_record` returns every field that was stored.
    #[test]
    fn test_get_full_model_record() {
        let (db, _guard) = create_test_database();
        let name = "test-model";
        let hf = ModelProvider::HuggingFace;
        let note = "Test message";

        db.set_status(name, hf, ModelStatus::DOWNLOADED, Some(note.to_string()))
            .expect("Failed to set status");

        let fetched = db
            .get_model_record(name)
            .expect("Failed to get model record");
        assert!(fetched.is_some());

        let fetched = fetched.expect("Record should be present");
        assert_eq!(fetched.model_name, name);
        assert_eq!(fetched.provider, hf);
        assert_eq!(fetched.status, ModelStatus::DOWNLOADED);
        assert_eq!(
            fetched.message.as_ref().expect("Message should be present"),
            note
        );
    }

    /// `touch_model` advances `last_used_at` and leaves `created_at` alone.
    #[test]
    fn test_touch_model() {
        let (db, _guard) = create_test_database();
        let name = "test-model";

        db.set_status(name, ModelProvider::HuggingFace, ModelStatus::DOWNLOADED, None)
            .expect("Failed to create model record");

        let before = db
            .get_model_record(name)
            .expect("Failed to get initial record")
            .expect("Record should be present");

        // Give the clock time to move so the two timestamps can differ.
        std::thread::sleep(std::time::Duration::from_millis(10));

        db.touch_model(name).expect("Failed to touch model");

        let after = db
            .get_model_record(name)
            .expect("Failed to get updated record")
            .expect("Record should be present");

        assert!(after.last_used_at > before.last_used_at);
        assert_eq!(after.created_at, before.created_at);
    }

    /// A deleted model is no longer visible to `get_status`.
    #[test]
    fn test_delete_model() {
        let (db, _guard) = create_test_database();
        let name = "test-model";

        db.set_status(name, ModelProvider::HuggingFace, ModelStatus::DOWNLOADED, None)
            .expect("Failed to create model record");

        let present = db.get_status(name).expect("Failed to get status");
        assert!(present.is_some());

        db.delete_model(name).expect("Failed to delete model");

        let gone = db.get_status(name).expect("Failed to get status");
        assert!(gone.is_none());
    }

    /// Records come back oldest-first, and the optional limit truncates the list.
    #[test]
    fn test_get_models_by_last_used() {
        let (db, _guard) = create_test_database();
        let hf = ModelProvider::HuggingFace;
        let pause = std::time::Duration::from_millis(10);

        db.set_status("model1", hf, ModelStatus::DOWNLOADED, None)
            .expect("Failed to create model1");
        std::thread::sleep(pause);

        db.set_status("model2", hf, ModelStatus::DOWNLOADED, None)
            .expect("Failed to create model2");
        std::thread::sleep(pause);

        db.set_status("model3", hf, ModelStatus::DOWNLOADED, None)
            .expect("Failed to create model3");

        let everything = db
            .get_models_by_last_used(None)
            .expect("Failed to get models");
        assert_eq!(everything.len(), 3);

        let all_names: Vec<&str> = everything.iter().map(|m| m.model_name.as_str()).collect();
        assert_eq!(all_names, ["model1", "model2", "model3"]);

        let oldest_two = db
            .get_models_by_last_used(Some(2))
            .expect("Failed to get limited models");
        assert_eq!(oldest_two.len(), 2);

        let limited_names: Vec<&str> = oldest_two.iter().map(|m| m.model_name.as_str()).collect();
        assert_eq!(limited_names, ["model1", "model2"]);
    }

    /// Per-status counts start at zero and reflect what was inserted.
    #[test]
    fn test_get_status_counts() {
        let (db, _guard) = create_test_database();
        let hf = ModelProvider::HuggingFace;

        let empty = db.get_status_counts().expect("Failed to get status counts");
        assert_eq!(empty.0, 0);
        assert_eq!(empty.1, 0);
        assert_eq!(empty.2, 0);

        db.set_status("model1", hf, ModelStatus::DOWNLOADING, None)
            .expect("Failed to set model1 status");
        db.set_status("model2", hf, ModelStatus::DOWNLOADING, None)
            .expect("Failed to set model2 status");
        db.set_status("model3", hf, ModelStatus::DOWNLOADED, None)
            .expect("Failed to set model3 status");
        db.set_status("model4", hf, ModelStatus::ERROR, None)
            .expect("Failed to set model4 status");

        let filled = db.get_status_counts().expect("Failed to get status counts");
        assert_eq!(filled.0, 2);
        assert_eq!(filled.1, 1);
        assert_eq!(filled.2, 1);
    }

    /// The provider survives a write/read round trip.
    #[test]
    fn test_model_provider_string_conversion() {
        let (db, _guard) = create_test_database();
        let name = "test-model";

        db.set_status(name, ModelProvider::HuggingFace, ModelStatus::DOWNLOADED, None)
            .expect("Failed to set status");

        let fetched = db
            .get_model_record(name)
            .expect("Failed to get record")
            .expect("Record should be present");
        assert_eq!(fetched.provider, ModelProvider::HuggingFace);
    }

    /// Every status variant survives a write/read round trip.
    #[test]
    fn test_model_status_string_conversion() {
        let (db, _guard) = create_test_database();
        let hf = ModelProvider::HuggingFace;

        let variants = [
            ModelStatus::DOWNLOADING,
            ModelStatus::DOWNLOADED,
            ModelStatus::ERROR,
        ];

        for (i, expected) in variants.into_iter().enumerate() {
            let name = format!("model{i}");
            db.set_status(&name, hf, expected, None)
                .expect("Failed to set status");

            let actual = db
                .get_status(&name)
                .expect("Failed to get status")
                .expect("Status should be present");
            assert_eq!(actual, expected);
        }
    }

    /// A run of back-to-back operations completes without deadlocking.
    #[test]
    fn test_concurrent_access() {
        let (db, _guard) = create_test_database();
        let hf = ModelProvider::HuggingFace;

        for i in 0..10 {
            let name = format!("model{i}");
            db.set_status(&name, hf, ModelStatus::DOWNLOADED, None)
                .expect("Failed to set status");
            let _status = db.get_status(&name).expect("Failed to get status");
            db.touch_model(&name).expect("Failed to touch model");
        }

        let stored = db
            .get_models_by_last_used(None)
            .expect("Failed to get models");
        assert_eq!(stored.len(), 10);
    }

    /// Claiming an unknown model creates it in the DOWNLOADING state.
    #[test]
    fn test_try_claim_for_download_new_model() {
        let (db, _guard) = create_test_database();
        let name = "new-model";
        let hf = ModelProvider::HuggingFace;

        let claimed = db
            .try_claim_for_download(name, hf)
            .expect("Failed to claim for download");
        assert_eq!(claimed, ModelStatus::DOWNLOADING);

        let fetched = db
            .get_model_record(name)
            .expect("Failed to get record")
            .expect("Record should be present");
        assert_eq!(fetched.model_name, name);
        assert_eq!(fetched.provider, hf);
        assert_eq!(fetched.status, ModelStatus::DOWNLOADING);
    }

    /// Claiming an existing model reports its status and leaves it untouched.
    #[test]
    fn test_try_claim_for_download_existing_model() {
        let (db, _guard) = create_test_database();
        let name = "existing-model";
        let hf = ModelProvider::HuggingFace;

        db.set_status(name, hf, ModelStatus::DOWNLOADED, None)
            .expect("Failed to set initial status");

        let claimed = db
            .try_claim_for_download(name, hf)
            .expect("Failed to claim for download");
        assert_eq!(claimed, ModelStatus::DOWNLOADED);

        let fetched = db
            .get_model_record(name)
            .expect("Failed to get record")
            .expect("Record should be present");
        assert_eq!(fetched.model_name, name);
        assert_eq!(fetched.provider, hf);
        assert_eq!(fetched.status, ModelStatus::DOWNLOADED);
    }

    /// Two claims in a row: the first creates the record, the second observes it.
    #[test]
    fn test_try_claim_for_download_race_condition() {
        let (db, _guard) = create_test_database();
        let name = "race-condition-model";
        let hf = ModelProvider::HuggingFace;

        let first = db
            .try_claim_for_download(name, hf)
            .expect("Failed to claim for download 1");
        let second = db
            .try_claim_for_download(name, hf)
            .expect("Failed to claim for download 2");

        assert_eq!(first, ModelStatus::DOWNLOADING);
        assert_eq!(second, ModelStatus::DOWNLOADING);

        let fetched = db
            .get_model_record(name)
            .expect("Failed to get record")
            .expect("Record should be present");
        assert_eq!(fetched.model_name, name);
        assert_eq!(fetched.provider, hf);
        assert_eq!(fetched.status, ModelStatus::DOWNLOADING);
    }

    /// Claims always report the stored status, including after it changes.
    #[test]
    fn test_try_claim_for_download_compare_and_swap() {
        let (db, _guard) = create_test_database();
        let name = "test-cas-model";
        let hf = ModelProvider::HuggingFace;

        let first = db
            .try_claim_for_download(name, hf)
            .expect("Failed to claim for download 1");
        assert_eq!(first, ModelStatus::DOWNLOADING);

        let second = db
            .try_claim_for_download(name, hf)
            .expect("Failed to claim for download 2");
        assert_eq!(second, ModelStatus::DOWNLOADING);

        db.set_status(name, hf, ModelStatus::DOWNLOADED, None)
            .expect("Failed to update status");

        let third = db
            .try_claim_for_download(name, hf)
            .expect("Failed to claim for download 3");
        assert_eq!(third, ModelStatus::DOWNLOADED);
    }
}