use anyhow::Result;
use chrono::{DateTime, Utc};
use rusqlite::{params, Connection};
use std::path::Path;
/// A single commit record as stored in the `commits` table.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare whole
/// commits directly instead of field-by-field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Commit {
    /// Content hash uniquely identifying the commit (primary key).
    pub hash: String,
    /// Commit message; `create_commit` rejects blank messages.
    pub message: String,
    /// Creation time; persisted as milliseconds since the Unix epoch.
    pub timestamp: DateTime<Utc>,
    /// Hash of the parent commit, or `None` for a root commit.
    pub parent_hash: Option<String>,
}
/// Owns the SQLite connection backing the repository and exposes the
/// commit/staging persistence operations implemented on it.
pub struct DatabaseManager {
    // The open SQLite handle; also exposed read-only via `connection()`.
    conn: Connection,
}
impl DatabaseManager {
    /// Opens (or creates) the SQLite database at `db_path` and ensures the
    /// schema exists.
    ///
    /// # Errors
    /// Fails if the file cannot be opened or any schema statement errors.
    pub fn new(db_path: &Path) -> Result<Self> {
        let conn = Connection::open(db_path)?;
        let db = Self { conn };
        db.init_schema()?;
        Ok(db)
    }

    /// Read-only access to the underlying connection for ad-hoc queries.
    pub fn connection(&self) -> &Connection {
        &self.conn
    }

    /// Creates the `commits`, `commit_files`, and `staging` tables when
    /// absent and enables foreign-key enforcement (off by default in SQLite,
    /// and per-connection, so it must be re-issued on every open).
    fn init_schema(&self) -> Result<()> {
        self.conn.execute("PRAGMA foreign_keys = ON", [])?;
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS commits (
                hash TEXT PRIMARY KEY,
                message TEXT NOT NULL,
                timestamp INTEGER NOT NULL,
                parent_hash TEXT,
                FOREIGN KEY (parent_hash) REFERENCES commits(hash)
            )",
            [],
        )?;
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS commit_files (
                commit_hash TEXT NOT NULL,
                file_path TEXT NOT NULL,
                file_hash TEXT NOT NULL,
                file_size INTEGER NOT NULL,
                PRIMARY KEY (commit_hash, file_path),
                FOREIGN KEY (commit_hash) REFERENCES commits(hash) ON DELETE CASCADE
            )",
            [],
        )?;
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS staging (
                file_path TEXT PRIMARY KEY,
                file_hash TEXT NOT NULL,
                file_size INTEGER NOT NULL
            )",
            [],
        )?;
        Ok(())
    }

    /// Stages a file, replacing any existing entry for the same `path`.
    pub fn add_to_staging(&self, path: &str, hash: &str, size: u64) -> Result<()> {
        self.conn.execute(
            "INSERT OR REPLACE INTO staging (file_path, file_hash, file_size) VALUES (?1, ?2, ?3)",
            params![path, hash, size],
        )?;
        Ok(())
    }

    /// Returns all staged files as `(path, hash, size)`, sorted by path.
    pub fn get_staged_files(&self) -> Result<Vec<(String, String, u64)>> {
        let mut stmt = self
            .conn
            .prepare("SELECT file_path, file_hash, file_size FROM staging ORDER BY file_path")?;
        let rows = stmt.query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?;
        // collect() into Result short-circuits on the first row error.
        Ok(rows.collect::<std::result::Result<Vec<_>, _>>()?)
    }

    /// Empties the staging area.
    pub fn clear_staging(&self) -> Result<()> {
        self.conn.execute("DELETE FROM staging", [])?;
        Ok(())
    }

    /// Records a commit and its file snapshot atomically.
    ///
    /// If a commit with this `hash` already exists, the call succeeds as a
    /// no-op and the existing row and file list are left untouched
    /// (preserving the previous skip-if-present behavior).
    ///
    /// # Errors
    /// Fails if `hash` is empty, `message` is blank, or any SQL statement
    /// errors. On failure the transaction is rolled back, so a commit row
    /// can never be persisted with a partial file list.
    pub fn create_commit(
        &self,
        hash: &str,
        message: &str,
        parent: Option<&str>,
        files: &[(String, String, u64)],
    ) -> Result<()> {
        if hash.is_empty() {
            return Err(anyhow::anyhow!("Commit hash cannot be empty"));
        }
        if message.trim().is_empty() {
            return Err(anyhow::anyhow!("Commit message cannot be empty"));
        }
        let timestamp = Utc::now().timestamp_millis();
        // The commit row and its file rows must land together. BEGIN/COMMIT
        // via execute_batch keeps &self (Connection::transaction needs &mut,
        // which would break the public interface).
        self.conn.execute_batch("BEGIN")?;
        match self.insert_commit_rows(hash, message, timestamp, parent, files) {
            Ok(()) => {
                self.conn.execute_batch("COMMIT")?;
                Ok(())
            }
            Err(e) => {
                // Best-effort rollback; the original failure is what matters.
                let _ = self.conn.execute_batch("ROLLBACK");
                Err(e)
            }
        }
    }

    /// Inserts the commit row plus its file rows; expects an open transaction.
    fn insert_commit_rows(
        &self,
        hash: &str,
        message: &str,
        timestamp: i64,
        parent: Option<&str>,
        files: &[(String, String, u64)],
    ) -> Result<()> {
        // INSERT OR IGNORE makes the duplicate check race-free: zero affected
        // rows means the commit already existed, so keep its file list as-is.
        let inserted = self.conn.execute(
            "INSERT INTO commits (hash, message, timestamp, parent_hash) VALUES (?1, ?2, ?3, ?4) ON CONFLICT(hash) DO NOTHING",
            params![hash, message, timestamp, parent],
        )?;
        if inserted == 0 {
            return Ok(());
        }
        for (file_path, file_hash, file_size) in files {
            // Errors propagate; previously they were only printed, which
            // could silently drop files from a commit.
            self.conn.execute(
                "INSERT INTO commit_files (commit_hash, file_path, file_hash, file_size) VALUES (?1, ?2, ?3, ?4)",
                params![hash, file_path, file_hash, file_size],
            )?;
        }
        Ok(())
    }

    /// Looks up a commit by hash; `Ok(None)` when it does not exist.
    pub fn get_commit(&self, hash: &str) -> Result<Option<Commit>> {
        let mut stmt = self
            .conn
            .prepare("SELECT hash, message, timestamp, parent_hash FROM commits WHERE hash = ?1")?;
        let mut rows = stmt.query([hash])?;
        if let Some(row) = rows.next()? {
            Ok(Some(Commit {
                hash: row.get(0)?,
                message: row.get(1)?,
                // Stored as epoch millis; an out-of-range value falls back to
                // the epoch instead of failing the whole lookup.
                timestamp: DateTime::from_timestamp_millis(row.get(2)?).unwrap_or_default(),
                parent_hash: row.get(3)?,
            }))
        } else {
            Ok(None)
        }
    }

    /// Returns a commit's files as `(path, hash, size)`, sorted by path.
    pub fn get_commit_files(&self, hash: &str) -> Result<Vec<(String, String, u64)>> {
        let mut stmt = self.conn.prepare(
            "SELECT file_path, file_hash, file_size FROM commit_files WHERE commit_hash = ?1 ORDER BY file_path"
        )?;
        let rows = stmt.query_map([hash], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?;
        Ok(rows.collect::<std::result::Result<Vec<_>, _>>()?)
    }

    /// Hash of the newest commit (latest timestamp, hash as a deterministic
    /// tiebreaker), or `Ok(None)` for an empty repository.
    pub fn get_head(&self) -> Result<Option<String>> {
        let mut stmt = self
            .conn
            .prepare("SELECT hash FROM commits ORDER BY timestamp DESC, hash DESC LIMIT 1")?;
        let mut rows = stmt.query([])?;
        if let Some(row) = rows.next()? {
            Ok(Some(row.get(0)?))
        } else {
            Ok(None)
        }
    }

    /// Returns commits newest-first, optionally capped at `limit` entries.
    pub fn get_commit_history(&self, limit: Option<i32>) -> Result<Vec<Commit>> {
        // SQLite treats a negative LIMIT as "no limit", so a single
        // parameterized statement covers both cases — no SQL string building.
        let mut stmt = self.conn.prepare(
            "SELECT hash, message, timestamp, parent_hash FROM commits ORDER BY timestamp DESC LIMIT ?1",
        )?;
        let rows = stmt.query_map(params![limit.unwrap_or(-1)], |row| {
            Ok(Commit {
                hash: row.get(0)?,
                message: row.get(1)?,
                timestamp: DateTime::from_timestamp_millis(row.get(2)?).unwrap_or_default(),
                parent_hash: row.get(3)?,
            })
        })?;
        Ok(rows.collect::<std::result::Result<Vec<_>, _>>()?)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a `DatabaseManager` backed by a file in a fresh temp dir.
    /// The `TempDir` guard is returned so the caller keeps the directory
    /// alive for the duration of the test.
    fn create_temp_database() -> (DatabaseManager, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let db = DatabaseManager::new(&db_path).unwrap();
        (db, temp_dir)
    }

    #[test]
    fn test_staging_operations() {
        let (db, _temp_dir) = create_temp_database();
        db.add_to_staging("test.txt", "hash123", 100).unwrap();
        db.add_to_staging("model.onnx", "hash456", 2048).unwrap();
        let staged = db.get_staged_files().unwrap();
        assert_eq!(staged.len(), 2);
        assert!(staged.contains(&("model.onnx".to_string(), "hash456".to_string(), 2048)));
        assert!(staged.contains(&("test.txt".to_string(), "hash123".to_string(), 100)));
        db.clear_staging().unwrap();
        let staged = db.get_staged_files().unwrap();
        assert_eq!(staged.len(), 0);
    }

    #[test]
    fn test_commit_operations() {
        let (db, _temp_dir) = create_temp_database();
        let files = vec![
            ("file1.txt".to_string(), "hash1".to_string(), 100),
            ("file2.txt".to_string(), "hash2".to_string(), 200),
        ];
        db.create_commit("commit1", "Initial commit", None, &files)
            .unwrap();
        let commit = db.get_commit("commit1").unwrap();
        assert!(commit.is_some());
        let commit = commit.unwrap();
        assert_eq!(commit.hash, "commit1");
        assert_eq!(commit.message, "Initial commit");
        assert_eq!(commit.parent_hash, None);
        let commit_files = db.get_commit_files("commit1").unwrap();
        assert_eq!(commit_files.len(), 2);
        assert!(commit_files.contains(&("file1.txt".to_string(), "hash1".to_string(), 100)));
        assert!(commit_files.contains(&("file2.txt".to_string(), "hash2".to_string(), 200)));
        let head = db.get_head().unwrap();
        assert_eq!(head, Some("commit1".to_string()));
        // HEAD ordering is timestamp-based with millisecond resolution, so
        // guarantee the second commit gets a strictly later timestamp.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let files2 = vec![
            ("file1.txt".to_string(), "hash1_updated".to_string(), 150),
            ("file3.txt".to_string(), "hash3".to_string(), 300),
        ];
        db.create_commit("commit2", "Second commit", Some("commit1"), &files2)
            .unwrap();
        assert_eq!(db.get_head().unwrap(), Some("commit2".to_string()));
        let history = db.get_commit_history(None).unwrap();
        assert_eq!(history.len(), 2);
        assert!(history.iter().any(|c| c.hash == "commit1"));
        assert!(history.iter().any(|c| c.hash == "commit2"));
        assert_eq!(history[0].hash, "commit2");
    }
}