fai_protocol/database/
mod.rs

1//! Database layer for FAI Protocol
2//!
3//! Handles SQLite database operations for commits, staging, and file tracking.
4
5use anyhow::Result;
6use chrono::{DateTime, Utc};
7use rusqlite::{params, Connection};
8use std::path::Path;
9
/// Represents a commit in the FAI repository.
///
/// Each value corresponds to one row of the `commits` table.
#[derive(Debug, Clone)]
pub struct Commit {
    /// Commit hash (primary key of the `commits` table)
    pub hash: String,
    /// Commit message
    pub message: String,
    /// Commit timestamp (persisted in the database as milliseconds since the Unix epoch)
    pub timestamp: DateTime<Utc>,
    /// Parent commit hash (None for initial commit)
    pub parent_hash: Option<String>,
}
22
/// Database manager for FAI Protocol.
///
/// Owns the SQLite connection used to read and write the `commits`,
/// `commit_files`, and `staging` tables.
pub struct DatabaseManager {
    /// SQLite database connection
    conn: Connection,
}
28
29impl DatabaseManager {
30    /// Create a new database manager with the specified database path
31    ///
32    /// # Arguments
33    /// * `db_path` - Path to the SQLite database file
34    ///
35    /// # Returns
36    /// A new DatabaseManager instance
37    pub fn new(db_path: &Path) -> Result<Self> {
38        let conn = Connection::open(db_path)?;
39        let db = Self { conn };
40        db.init_schema()?;
41        Ok(db)
42    }
43
44    /// Get access to the database connection (for network module usage)
45    pub fn connection(&self) -> &Connection {
46        &self.conn
47    }
48
49    /// Initialize the database schema
50    ///
51    /// Creates the necessary tables if they don't exist
52    fn init_schema(&self) -> Result<()> {
53        // Enable foreign key support
54        self.conn.execute("PRAGMA foreign_keys = ON", [])?;
55
56        // Create commits table
57        self.conn.execute(
58            "CREATE TABLE IF NOT EXISTS commits (
59                hash TEXT PRIMARY KEY,
60                message TEXT NOT NULL,
61                timestamp INTEGER NOT NULL,
62                parent_hash TEXT,
63                FOREIGN KEY (parent_hash) REFERENCES commits(hash)
64            )",
65            [],
66        )?;
67
68        // Create commit_files table
69        self.conn.execute(
70            "CREATE TABLE IF NOT EXISTS commit_files (
71                commit_hash TEXT NOT NULL,
72                file_path TEXT NOT NULL,
73                file_hash TEXT NOT NULL,
74                file_size INTEGER NOT NULL,
75                PRIMARY KEY (commit_hash, file_path),
76                FOREIGN KEY (commit_hash) REFERENCES commits(hash) ON DELETE CASCADE
77            )",
78            [],
79        )?;
80
81        // Create staging table
82        self.conn.execute(
83            "CREATE TABLE IF NOT EXISTS staging (
84                file_path TEXT PRIMARY KEY,
85                file_hash TEXT NOT NULL,
86                file_size INTEGER NOT NULL
87            )",
88            [],
89        )?;
90
91        Ok(())
92    }
93
94    /// Add a file to the staging area
95    ///
96    /// # Arguments
97    /// * `path` - File path relative to repository root
98    /// * `hash` - Content hash of the file
99    /// * `size` - File size in bytes
100    pub fn add_to_staging(&self, path: &str, hash: &str, size: u64) -> Result<()> {
101        self.conn.execute(
102            "INSERT OR REPLACE INTO staging (file_path, file_hash, file_size) VALUES (?1, ?2, ?3)",
103            params![path, hash, size],
104        )?;
105        Ok(())
106    }
107
108    /// Get all staged files
109    ///
110    /// # Returns
111    /// Vector of tuples containing (file_path, file_hash, file_size)
112    pub fn get_staged_files(&self) -> Result<Vec<(String, String, u64)>> {
113        let mut stmt = self
114            .conn
115            .prepare("SELECT file_path, file_hash, file_size FROM staging ORDER BY file_path")?;
116
117        let rows = stmt.query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?;
118
119        let mut files = Vec::new();
120        for row in rows {
121            files.push(row?);
122        }
123
124        Ok(files)
125    }
126
127    /// Clear all files from the staging area
128    pub fn clear_staging(&self) -> Result<()> {
129        self.conn.execute("DELETE FROM staging", [])?;
130        Ok(())
131    }
132
133    /// Create a new commit
134    ///
135    /// # Arguments
136    /// * `hash` - Commit hash
137    /// * `message` - Commit message
138    /// * `parent` - Optional parent commit hash
139    /// * `files` - List of files included in this commit
140    pub fn create_commit(
141        &self,
142        hash: &str,
143        message: &str,
144        parent: Option<&str>,
145        files: &[(String, String, u64)],
146    ) -> Result<()> {
147        // Validate inputs
148        if hash.is_empty() {
149            return Err(anyhow::anyhow!("Commit hash cannot be empty"));
150        }
151        if message.trim().is_empty() {
152            return Err(anyhow::anyhow!("Commit message cannot be empty"));
153        }
154
155        // Insert commit with current timestamp in milliseconds for uniqueness
156        let timestamp = Utc::now().timestamp_millis();
157
158        // Check if commit already exists
159        let existing_count: i64 = self
160            .conn
161            .query_row(
162                "SELECT COUNT(*) FROM commits WHERE hash = ?1",
163                [hash],
164                |row| row.get(0),
165            )
166            .unwrap_or(0);
167
168        if existing_count > 0 {
169            println!("DEBUG: Commit {} already exists, skipping insertion", hash);
170            return Ok(());
171        }
172
173        println!(
174            "DEBUG: Creating commit: hash={}, message={}, timestamp={}",
175            hash, message, timestamp
176        );
177
178        // Insert commit
179        match self.conn.execute(
180            "INSERT INTO commits (hash, message, timestamp, parent_hash) VALUES (?1, ?2, ?3, ?4)",
181            params![hash, message, timestamp, parent],
182        ) {
183            Ok(rows) => {
184                println!(
185                    "DEBUG: Successfully inserted commit, rows affected: {}",
186                    rows
187                );
188            }
189            Err(e) => {
190                println!("DEBUG: Failed to insert commit: {}", e);
191                return Err(anyhow::anyhow!("Failed to insert commit: {}", e));
192            }
193        }
194
195        // Insert commit files
196        for (file_path, file_hash, file_size) in files {
197            println!(
198                "DEBUG: Inserting commit file: path={}, hash={}, size={}",
199                file_path, file_hash, file_size
200            );
201            match self.conn.execute(
202                "INSERT INTO commit_files (commit_hash, file_path, file_hash, file_size) VALUES (?1, ?2, ?3, ?4)",
203                params![hash, file_path, file_hash, file_size],
204            ) {
205                Ok(rows) => {
206                    println!("DEBUG: Successfully inserted commit file, rows affected: {}", rows);
207                }
208                Err(e) => {
209                    println!("DEBUG: Failed to insert commit file: {}", e);
210                    // Continue with other files even if one fails
211                }
212            }
213        }
214
215        Ok(())
216    }
217
218    /// Get commit information by hash
219    ///
220    /// # Arguments
221    /// * `hash` - Commit hash
222    ///
223    /// # Returns
224    /// The commit if found, None otherwise
225    pub fn get_commit(&self, hash: &str) -> Result<Option<Commit>> {
226        let mut stmt = self
227            .conn
228            .prepare("SELECT hash, message, timestamp, parent_hash FROM commits WHERE hash = ?1")?;
229
230        let mut rows = stmt.query([hash])?;
231        if let Some(row) = rows.next()? {
232            Ok(Some(Commit {
233                hash: row.get(0)?,
234                message: row.get(1)?,
235                timestamp: DateTime::from_timestamp_millis(row.get(2)?).unwrap_or_default(),
236                parent_hash: row.get(3)?,
237            }))
238        } else {
239            Ok(None)
240        }
241    }
242
243    /// Get all files associated with a commit
244    ///
245    /// # Arguments
246    /// * `hash` - Commit hash
247    ///
248    /// # Returns
249    /// Vector of tuples containing (file_path, file_hash, file_size)
250    pub fn get_commit_files(&self, hash: &str) -> Result<Vec<(String, String, u64)>> {
251        let mut stmt = self.conn.prepare(
252            "SELECT file_path, file_hash, file_size FROM commit_files WHERE commit_hash = ?1 ORDER BY file_path"
253        )?;
254
255        let rows = stmt.query_map([hash], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?;
256
257        let mut files = Vec::new();
258        for row in rows {
259            files.push(row?);
260        }
261
262        Ok(files)
263    }
264
265    /// Get the latest commit (HEAD)
266    ///
267    /// # Returns
268    /// The latest commit hash if any commits exist
269    pub fn get_head(&self) -> Result<Option<String>> {
270        let mut stmt = self
271            .conn
272            .prepare("SELECT hash FROM commits ORDER BY timestamp DESC, hash DESC LIMIT 1")?;
273
274        let mut rows = stmt.query([])?;
275        if let Some(row) = rows.next()? {
276            Ok(Some(row.get(0)?))
277        } else {
278            Ok(None)
279        }
280    }
281
282    /// Get commit history
283    ///
284    /// # Arguments
285    /// * `limit` - Maximum number of commits to return (None for all)
286    ///
287    /// # Returns
288    /// Vector of commits ordered by timestamp (newest first)
289    pub fn get_commit_history(&self, limit: Option<i32>) -> Result<Vec<Commit>> {
290        let query = if let Some(limit) = limit {
291            format!("SELECT hash, message, timestamp, parent_hash FROM commits ORDER BY timestamp DESC LIMIT {}", limit)
292        } else {
293            "SELECT hash, message, timestamp, parent_hash FROM commits ORDER BY timestamp DESC"
294                .to_string()
295        };
296
297        let mut stmt = self.conn.prepare(&query)?;
298
299        let rows = stmt.query_map([], |row| {
300            Ok(Commit {
301                hash: row.get(0)?,
302                message: row.get(1)?,
303                timestamp: DateTime::from_timestamp_millis(row.get(2)?).unwrap_or_default(),
304                parent_hash: row.get(3)?,
305            })
306        })?;
307
308        let mut commits = Vec::new();
309        for row in rows {
310            commits.push(row?);
311        }
312
313        Ok(commits)
314    }
315}
316
317#[cfg(test)]
318mod tests {
319    use super::*;
320    use tempfile::TempDir;
321
322    fn create_temp_database() -> (DatabaseManager, TempDir) {
323        let temp_dir = TempDir::new().unwrap();
324        let db_path = temp_dir.path().join("test.db");
325        let db = DatabaseManager::new(&db_path).unwrap();
326        (db, temp_dir)
327    }
328
329    #[test]
330    fn test_staging_operations() {
331        let (db, _temp_dir) = create_temp_database();
332
333        // Test adding to staging
334        db.add_to_staging("test.txt", "hash123", 100).unwrap();
335        db.add_to_staging("model.onnx", "hash456", 2048).unwrap();
336
337        // Test getting staged files
338        let staged = db.get_staged_files().unwrap();
339        assert_eq!(staged.len(), 2);
340        // Check that both files are present (order guaranteed by ORDER BY file_path)
341        assert!(staged.contains(&("model.onnx".to_string(), "hash456".to_string(), 2048)));
342        assert!(staged.contains(&("test.txt".to_string(), "hash123".to_string(), 100)));
343
344        // Test clearing staging
345        db.clear_staging().unwrap();
346        let staged = db.get_staged_files().unwrap();
347        assert_eq!(staged.len(), 0);
348    }
349
350    #[test]
351    fn test_commit_operations() {
352        let (db, _temp_dir) = create_temp_database();
353
354        // Create initial commit
355        let files = vec![
356            ("file1.txt".to_string(), "hash1".to_string(), 100),
357            ("file2.txt".to_string(), "hash2".to_string(), 200),
358        ];
359        db.create_commit("commit1", "Initial commit", None, &files)
360            .unwrap();
361
362        // Test getting commit
363        let commit = db.get_commit("commit1").unwrap();
364        assert!(commit.is_some());
365        let commit = commit.unwrap();
366        assert_eq!(commit.hash, "commit1");
367        assert_eq!(commit.message, "Initial commit");
368        assert_eq!(commit.parent_hash, None);
369
370        // Test getting commit files
371        let commit_files = db.get_commit_files("commit1").unwrap();
372        assert_eq!(commit_files.len(), 2);
373        // Check that both files are present (order guaranteed by ORDER BY file_path)
374        assert!(commit_files.contains(&("file1.txt".to_string(), "hash1".to_string(), 100)));
375        assert!(commit_files.contains(&("file2.txt".to_string(), "hash2".to_string(), 200)));
376
377        // Test HEAD
378        let head = db.get_head().unwrap();
379        assert_eq!(head, Some("commit1".to_string()));
380
381        // Create second commit with parent (add small delay to ensure different timestamp)
382        std::thread::sleep(std::time::Duration::from_millis(10));
383        let files2 = vec![
384            ("file1.txt".to_string(), "hash1_updated".to_string(), 150),
385            ("file3.txt".to_string(), "hash3".to_string(), 300),
386        ];
387        db.create_commit("commit2", "Second commit", Some("commit1"), &files2)
388            .unwrap();
389
390        // Test HEAD updated
391        let head = db.get_head().unwrap();
392        println!("Debug: HEAD = {:?}", head);
393        println!("Debug: Expected = {:?}", Some("commit2".to_string()));
394        assert_eq!(head, Some("commit2".to_string()));
395
396        // Test commit history
397        let history = db.get_commit_history(None).unwrap();
398        assert_eq!(history.len(), 2);
399        // Check that both commits are present (order guaranteed by ORDER BY timestamp DESC)
400        assert!(history.iter().any(|c| c.hash == "commit1"));
401        assert!(history.iter().any(|c| c.hash == "commit2"));
402        // Most recent commit should be first
403        assert_eq!(history[0].hash, "commit2");
404    }
405}