use crate::types::{VersionMeta, VersionSelector};
use aes_gcm::{
aead::{Aead, KeyInit},
Aes256Gcm, Key, Nonce,
};
use anyhow::{Context, Result};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use similar::{Algorithm, DiffOp, TextDiff};
use std::{collections::HashMap, fs, path::PathBuf};
use std::{io::Read, path::Path};
/// Versioned prompt store persisted in an embedded `sled` database.
///
/// Entries are laid out under these keys:
/// - `version:{key}:{n}` — bincode-encoded `VersionMeta` for version `n` of prompt `key`
/// - `content:{key}:{n}` — UTF-8 content of that version
/// - `diff:{key}:{n}`    — diff against the parent version (read for non-snapshot versions)
/// - `tag:{key}:{tag}`   — version number the tag points at, as a little-endian `u64`
#[derive(Clone)]
pub struct PromptVault {
    /// Handle to the backing `sled` database.
    db: sled::Db,
}
impl PromptVault {
    /// Magic header written at the start of a password-protected dump.
    const MAGIC_ENCRYPTED: &'static [u8] = b"VAULT_ENC";
    /// Magic header written at the start of a plaintext dump (same length as the encrypted one).
    const MAGIC_RAW: &'static [u8] = b"VAULT_RAW";
    /// Length in bytes of the random salt that prefixes an encrypted payload.
    const SALT_LEN: usize = 32;
    /// Length in bytes of the AES-GCM nonce that follows the salt.
    const NONCE_LEN: usize = 12;

    /// Restores the dump at `input_path` if that file exists; otherwise opens the default
    /// vault at `$HOME/.promptpro/default_vault`.
    ///
    /// `password` is only needed when the dump is encrypted.
    ///
    /// # Errors
    /// Propagates failures from [`PromptVault::restore`] or [`PromptVault::open_default`].
    pub fn restore_or_default(input_path: &str, password: Option<&str>) -> Result<Self> {
        let input = Path::new(input_path);
        if input.exists() {
            println!("🔄 Found vault file at '{}', restoring...", input.display());
            Self::restore(input_path, password)
        } else {
            println!(
                "⚠️ Vault file '{}' not found — opening default vault instead.",
                input.display()
            );
            Self::open_default().map_err(|e| anyhow::anyhow!("Failed to open default vault: {}", e))
        }
    }

    /// Opens the vault at `path`, falling back to the default vault (with a warning on
    /// stderr) if that fails.
    ///
    /// # Errors
    /// Returns an error only when both the requested and the default vault fail to open.
    pub fn open_or_default<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path_ref = path.as_ref();
        match Self::open(path_ref) {
            Ok(vault) => Ok(vault),
            Err(e) => {
                eprintln!(
                    "⚠️ Failed to open vault at {:?}: {}. Falling back to default vault...",
                    path_ref, e
                );
                Self::open_default().with_context(|| {
                    format!(
                        "Failed to open both specified vault {:?} and default vault",
                        path_ref
                    )
                })
            }
        }
    }

    /// Opens (or creates) a sled database at `path` and wraps it in a vault.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let db = sled::open(path)?;
        Ok(PromptVault { db })
    }

    /// Opens the default vault at `$HOME/.promptpro/default_vault`, creating the directory
    /// if needed.
    ///
    /// # Errors
    /// Fails if `HOME` is unset, the directory cannot be created, or sled cannot open it.
    pub fn open_default() -> Result<Self> {
        let home_dir = std::env::var("HOME").context("HOME environment variable is not set")?;
        let path = PathBuf::from(home_dir)
            .join(".promptpro")
            .join("default_vault");
        fs::create_dir_all(&path)
            .with_context(|| format!("Failed to create vault directory {}", path.display()))?;
        Self::open(path)
    }

    /// Creates prompt `key` with `content` as version 1 and points the `dev` tag at it.
    ///
    /// # Errors
    /// Fails if any version of `key` already exists or the write fails.
    pub fn add(&self, key: &str, content: &str) -> Result<()> {
        if self.get_latest_version_number(key)?.is_some() {
            return Err(anyhow::anyhow!("Prompt with key '{}' already exists", key));
        }
        let version_meta = VersionMeta::new(key.to_string(), 1, content, None, None);
        self.store_version(&version_meta, content, None)?;
        // `update` keeps `dev` on the newest version; do the same for the first one so
        // `VersionSelector::Tag("dev")` resolves right after `add`.
        self.tag_dev_best_effort(key, 1);
        Ok(())
    }

    /// Stores `content` as a new snapshot version of `key` and moves the `dev` tag to it.
    ///
    /// # Errors
    /// Fails if `key` does not exist, `content` equals the latest version's content, or the
    /// write fails.
    pub fn update(&self, key: &str, content: &str, message: Option<String>) -> Result<()> {
        let parent_version = self
            .get_latest_version_number(key)?
            .ok_or_else(|| anyhow::anyhow!("Prompt with key '{}' does not exist", key))?;
        let current_content = self.get_content(key, &VersionSelector::Version(parent_version))?;
        if current_content == content {
            return Err(anyhow::anyhow!("No changes detected in content"));
        }
        let new_version = parent_version + 1;
        let mut version_meta = VersionMeta::new(
            key.to_string(),
            new_version,
            content,
            Some(parent_version),
            message,
        );
        // Every version is stored as a full snapshot; no diff is written.
        version_meta.snapshot = true;
        self.store_version(&version_meta, content, None)?;
        self.tag_dev_best_effort(key, new_version);
        Ok(())
    }

    /// Points the `dev` tag at `version`, which the caller has just stored.
    ///
    /// The version is already persisted at this point, so a tagging failure is reported
    /// on stderr instead of turning a successful write into an error.
    fn tag_dev_best_effort(&self, key: &str, version: u64) {
        if let Err(e) = self.tag(key, "dev", version) {
            eprintln!(
                "⚠️ Stored v{} of '{}' but failed to move the 'dev' tag: {}",
                version, key, e
            );
        }
    }

    /// Returns the content of `key` at the version chosen by `selector`.
    ///
    /// # Errors
    /// Fails if the selector resolves to no version or the content cannot be read.
    pub fn get(&self, key: &str, selector: VersionSelector) -> Result<String> {
        let version_number = match selector {
            VersionSelector::Latest => self
                .get_latest_version_number(key)?
                .ok_or_else(|| anyhow::anyhow!("No versions found for key '{}'", key))?,
            VersionSelector::Version(v) => v,
            VersionSelector::Tag(tag) => self
                .get_version_by_tag(key, tag)?
                .ok_or_else(|| anyhow::anyhow!("Tag '{}' not found for key '{}'", tag, key))?,
            VersionSelector::Time(time) => {
                self.get_version_by_time(key, time)?.ok_or_else(|| {
                    anyhow::anyhow!("No version found for key '{}' at time {}", key, time)
                })?
            }
        };
        self.get_content(key, &VersionSelector::Version(version_number))
    }

    /// Returns every version's metadata for `key`, sorted by ascending version number.
    pub fn history(&self, key: &str) -> Result<Vec<VersionMeta>> {
        let mut versions = self.load_versions(key)?;
        versions.sort_by_key(|v| v.version);
        Ok(versions)
    }

    /// Points `tag` at `version` of `key`, removing it from the version it pointed at before.
    ///
    /// The `dev` tag may only point at the latest version.
    ///
    /// # Errors
    /// Fails if `version` does not exist, `dev` is aimed at a non-latest version, or a
    /// metadata read/write fails.
    pub fn tag(&self, key: &str, tag: &str, version: u64) -> Result<()> {
        let version_key = format!("version:{}:{}", key, version);
        if self.db.get(version_key.as_bytes())?.is_none() {
            return Err(anyhow::anyhow!(
                "Version {} does not exist for key '{}'",
                version,
                key
            ));
        }
        if tag == "dev" {
            let latest_version = self
                .get_latest_version_number(key)?
                .ok_or_else(|| anyhow::anyhow!("No versions found for key '{}'", key))?;
            if version != latest_version {
                return Err(anyhow::anyhow!(
                    "'dev' tag can only be set to the latest version (v{})",
                    latest_version
                ));
            }
        }
        // Errors reading the previous tag (e.g. a malformed stored value) are ignored here;
        // the tag entry is overwritten below regardless.
        if let Ok(Some(old_version)) = self.get_version_by_tag(key, tag) {
            if old_version != version {
                let mut old_version_meta =
                    self.get_version_meta(key, old_version)?.ok_or_else(|| {
                        anyhow::anyhow!("Version {} not found for key '{}'", old_version, key)
                    })?;
                old_version_meta.tags.retain(|t| t != tag);
                self.update_version_meta(&old_version_meta)?;
            }
        }
        let tag_key = format!("tag:{}:{}", key, tag);
        self.db.insert(tag_key.as_bytes(), &version.to_le_bytes())?;
        let mut version_meta = self
            .get_version_meta(key, version)?
            .ok_or_else(|| anyhow::anyhow!("Version {} not found for key '{}'", version, key))?;
        if !version_meta.tags.contains(&tag.to_string()) {
            version_meta.tags.push(tag.to_string());
            self.update_version_meta(&version_meta)?;
        }
        Ok(())
    }

    /// Points `tag` at the latest version of `key`.
    pub fn promote(&self, key: &str, tag: &str) -> Result<()> {
        let latest_version = self
            .get_latest_version_number(key)?
            .ok_or_else(|| anyhow::anyhow!("No versions found for key '{}'", key))?;
        self.tag(key, tag, latest_version)
    }

    /// Returns the highest version number stored for `key`, or `None` if it has none.
    pub fn get_latest_version_number(&self, key: &str) -> Result<Option<u64>> {
        Ok(self.load_versions(key)?.iter().map(|v| v.version).max())
    }

    /// Loads the metadata of every version stored for exactly `key`, in storage order.
    ///
    /// Metadata lives under `version:{key}:{n}`, so a prefix scan for `key` also matches
    /// keys that extend it with a colon (`foo` matches `version:foo:bar:1`). Entries are
    /// therefore filtered on the deserialized `VersionMeta::key`.
    fn load_versions(&self, key: &str) -> Result<Vec<VersionMeta>> {
        let prefix = format!("version:{}:", key);
        let mut versions = Vec::new();
        for item in self.db.scan_prefix(prefix.as_bytes()) {
            let (_, value) = item?;
            let meta: VersionMeta = bincode::deserialize(&value)?;
            if meta.key == key {
                versions.push(meta);
            }
        }
        Ok(versions)
    }

    /// Returns the version `tag` points at for `key`, stored as a little-endian `u64`.
    fn get_version_by_tag(&self, key: &str, tag: &str) -> Result<Option<u64>> {
        let tag_key = format!("tag:{}:{}", key, tag);
        if let Some(value) = self.db.get(tag_key.as_bytes())? {
            let version_bytes: [u8; 8] = value
                .as_ref()
                .try_into()
                .map_err(|_| anyhow::anyhow!("Failed to read version from tag"))?;
            Ok(Some(u64::from_le_bytes(version_bytes)))
        } else {
            Ok(None)
        }
    }

    /// Returns the highest version of `key` whose timestamp is at or before `time`.
    fn get_version_by_time(
        &self,
        key: &str,
        time: chrono::DateTime<chrono::Utc>,
    ) -> Result<Option<u64>> {
        Ok(self
            .load_versions(key)?
            .iter()
            .filter(|v| v.timestamp <= time)
            .map(|v| v.version)
            .max())
    }

    /// Reads the content of an explicit version (`VersionSelector::Version` only).
    ///
    /// Snapshot versions are read from `content:{key}:{n}`. Non-snapshot versions are rebuilt
    /// by applying `diff:{key}:{n}` to the (recursively resolved) parent content.
    fn get_content(&self, key: &str, selector: &VersionSelector) -> Result<String> {
        let version = match selector {
            VersionSelector::Version(v) => *v,
            _ => return Err(anyhow::anyhow!("Invalid selector for content retrieval")),
        };
        let version_meta = self
            .get_version_meta(key, version)?
            .ok_or_else(|| anyhow::anyhow!("Version {} not found for key '{}'", version, key))?;
        if version_meta.snapshot {
            let content_key = format!("content:{}:{}", key, version);
            if let Some(content_bytes) = self.db.get(content_key.as_bytes())? {
                Ok(String::from_utf8(content_bytes.to_vec())?)
            } else {
                Err(anyhow::anyhow!(
                    "Content not found for key '{}', version {}, make sure key were added.",
                    key,
                    version
                ))
            }
        } else {
            let diff_key = format!("diff:{}:{}", key, version);
            if let Some(diff_bytes) = self.db.get(diff_key.as_bytes())? {
                let diff_str = String::from_utf8(diff_bytes.to_vec())?;
                let parent_version = version_meta.parent.ok_or_else(|| {
                    anyhow::anyhow!("Diff version {} missing parent reference", version)
                })?;
                let parent_content =
                    self.get_content(key, &VersionSelector::Version(parent_version))?;
                apply_diff(&parent_content, &diff_str)
            } else {
                Err(anyhow::anyhow!(
                    "Diff not found for key '{}', version {}",
                    key,
                    version
                ))
            }
        }
    }

    /// Persists a version's metadata and full content in one atomic batch, so a crash
    /// cannot leave metadata whose content is missing. `_diff_content` is currently unused.
    fn store_version(
        &self,
        version_meta: &VersionMeta,
        content: &str,
        _diff_content: Option<String>,
    ) -> Result<()> {
        let version_key = format!("version:{}:{}", version_meta.key, version_meta.version);
        let content_key = format!("content:{}:{}", version_meta.key, version_meta.version);
        let meta_bytes = bincode::serialize(version_meta)?;
        let mut batch = sled::Batch::default();
        batch.insert(version_key.as_bytes(), meta_bytes);
        batch.insert(content_key.as_bytes(), content.as_bytes());
        self.db.apply_batch(batch)?;
        Ok(())
    }

    /// Loads the metadata of a single version, or `None` if it is not stored.
    fn get_version_meta(&self, key: &str, version: u64) -> Result<Option<VersionMeta>> {
        let version_key = format!("version:{}:{}", key, version);
        match self.db.get(version_key.as_bytes())? {
            Some(value) => Ok(Some(bincode::deserialize(&value)?)),
            None => Ok(None),
        }
    }

    /// Overwrites the stored metadata of `version_meta.key` / `version_meta.version`.
    fn update_version_meta(&self, version_meta: &VersionMeta) -> Result<()> {
        let version_key = format!("version:{}:{}", version_meta.key, version_meta.version);
        let meta_bytes = bincode::serialize(version_meta)?;
        self.db.insert(version_key.as_bytes(), meta_bytes)?;
        Ok(())
    }

    /// Exposes the underlying sled database.
    pub fn db(&self) -> &sled::Db {
        &self.db
    }

    /// Deletes every version, content, diff and tag entry of `key` in one atomic batch.
    pub fn delete_prompt_key(&self, key: &str) -> Result<()> {
        let versions = self.load_versions(key)?;
        let mut batch = sled::Batch::default();
        for meta in &versions {
            for kind in ["version", "content", "diff"] {
                batch.remove(format!("{}:{}:{}", kind, key, meta.version).as_bytes());
            }
        }
        let tag_prefix = format!("tag:{}:", key);
        for item in self.db.scan_prefix(tag_prefix.as_bytes()) {
            let (tag_key, _) = item?;
            let tag_name = &tag_key[tag_prefix.len()..];
            // `tag:{key}:` is also a prefix of tags owned by longer keys such as
            // `{key}:child`, so only remove colon-free tag names or tags this key's
            // versions actually carry.
            let owned_by_key = !tag_name.contains(&b':')
                || versions
                    .iter()
                    .any(|v| v.tags.iter().any(|t| t.as_bytes() == tag_name));
            if owned_by_key {
                batch.remove(tag_key);
            }
        }
        self.db.apply_batch(batch)?;
        Ok(())
    }

    /// Writes every database entry to `output_path` as a bincode dump.
    ///
    /// The file starts with `VAULT_RAW`, or with `VAULT_ENC` followed by
    /// salt ‖ nonce ‖ AES-256-GCM ciphertext when `password` is given.
    pub fn dump(&self, output_path: &str, password: Option<&str>) -> Result<()> {
        let entries: Vec<(Vec<u8>, Vec<u8>)> = self
            .db
            .iter()
            .map(|item| item.map(|(k, v)| (k.to_vec(), v.to_vec())))
            .collect::<std::result::Result<_, sled::Error>>()?;
        let serialized = bincode::serialize(&entries)?;
        let output = match password {
            Some(pwd) => {
                let encrypted = self.encrypt_data(&serialized, pwd)?;
                [Self::MAGIC_ENCRYPTED, encrypted.as_slice()].concat()
            }
            None => [Self::MAGIC_RAW, serialized.as_slice()].concat(),
        };
        fs::write(output_path, &output)
            .with_context(|| format!("Failed to write vault dump to {}", output_path))?;
        Ok(())
    }

    /// Restores a dump file into `$HOME/.promptpro/<file stem>` and opens it.
    ///
    /// If that directory already exists the dump is NOT read (and `password` is not
    /// checked); the existing vault is opened instead.
    ///
    /// # Errors
    /// Fails on a missing or malformed file, a missing/wrong password for an encrypted
    /// dump, or database errors.
    pub fn restore(input_path: &str, password: Option<&str>) -> Result<Self> {
        let input_path = Path::new(input_path);
        if !input_path.exists() {
            return Err(anyhow::anyhow!(
                "Vault file not found: {}",
                input_path.display()
            ));
        }
        let vault_name = input_path
            .file_stem()
            .and_then(|s| s.to_str())
            .ok_or_else(|| anyhow::anyhow!("Invalid vault filename"))?;
        let home = std::env::var("HOME").map_err(|_| anyhow::anyhow!("HOME env not found"))?;
        let target_path = PathBuf::from(home).join(".promptpro").join(vault_name);
        if target_path.exists() {
            println!(
                "✅ Vault '{}' already exists — skipping restore.",
                vault_name
            );
            return Self::open(&target_path);
        }
        let data = fs::read(input_path)
            .with_context(|| format!("Failed to read vault file {}", input_path.display()))?;
        let raw = Self::decode_dump(&data, password)?;
        let entries: Vec<(Vec<u8>, Vec<u8>)> = bincode::deserialize(&raw)
            .map_err(|_| anyhow::anyhow!("Failed to deserialize vault"))?;
        fs::create_dir_all(&target_path)?;
        let vault = Self::open(&target_path)?;
        let mut batch = sled::Batch::default();
        for (k, v) in entries {
            batch.insert(k, v);
        }
        vault.db.apply_batch(batch)?;
        vault.db.flush()?;
        println!(
            "✅ Restored vault '{}' → {}",
            vault_name,
            target_path.display()
        );
        Ok(vault)
    }

    /// Checks a dump's magic header and returns its bincode payload, decrypting it when
    /// the header marks the dump as encrypted.
    fn decode_dump(data: &[u8], password: Option<&str>) -> Result<Vec<u8>> {
        if data.len() < Self::MAGIC_RAW.len() {
            return Err(anyhow::anyhow!("Invalid vault file: too short"));
        }
        if let Some(payload) = data.strip_prefix(Self::MAGIC_ENCRYPTED) {
            let pwd = password
                .ok_or_else(|| anyhow::anyhow!("Vault encrypted but no password provided"))?;
            Self::decrypt_data(payload, pwd)
        } else if let Some(payload) = data.strip_prefix(Self::MAGIC_RAW) {
            Ok(payload.to_vec())
        } else {
            Err(anyhow::anyhow!("Invalid vault file header"))
        }
    }

    /// Derives the AES-256 key for a dump as `blake3(password ‖ salt)`.
    ///
    /// This must stay byte-for-byte stable so existing encrypted dumps remain readable.
    // NOTE(review): BLAKE3 is a fast hash, not a password-hashing KDF, so encrypted dumps
    // are cheap to brute-force offline. Moving to Argon2/scrypt would need a new header.
    fn derive_key(password: &str, salt: &[u8]) -> [u8; 32] {
        let mut hasher = blake3::Hasher::new();
        hasher.update(password.as_bytes());
        hasher.update(salt);
        *hasher.finalize().as_bytes()
    }

    /// Encrypts `data` with AES-256-GCM under a key derived from `password` and a fresh
    /// random salt. Returns salt ‖ nonce ‖ ciphertext.
    fn encrypt_data(&self, data: &[u8], password: &str) -> Result<Vec<u8>> {
        let mut rng = rand::thread_rng();
        let mut salt = [0u8; Self::SALT_LEN];
        rng.fill_bytes(&mut salt);
        let mut nonce_bytes = [0u8; Self::NONCE_LEN];
        rng.fill_bytes(&mut nonce_bytes);

        let key_bytes = Self::derive_key(password, &salt);
        let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&key_bytes));
        let ciphertext = cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), data)
            .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;

        let mut out = Vec::with_capacity(salt.len() + nonce_bytes.len() + ciphertext.len());
        out.extend_from_slice(&salt);
        out.extend_from_slice(&nonce_bytes);
        out.extend_from_slice(&ciphertext);
        Ok(out)
    }

    /// Reverses [`PromptVault::encrypt_data`]: splits salt ‖ nonce ‖ ciphertext and decrypts.
    ///
    /// # Errors
    /// Fails if `data` is shorter than salt + nonce or authentication fails (e.g. a wrong
    /// password).
    fn decrypt_data(data: &[u8], password: &str) -> Result<Vec<u8>> {
        if data.len() < Self::SALT_LEN + Self::NONCE_LEN {
            return Err(anyhow::anyhow!("Encrypted data is too short"));
        }
        let (salt, rest) = data.split_at(Self::SALT_LEN);
        let (nonce_bytes, ciphertext) = rest.split_at(Self::NONCE_LEN);
        let key_bytes = Self::derive_key(password, salt);
        let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&key_bytes));
        cipher
            .decrypt(Nonce::from_slice(nonce_bytes), ciphertext)
            .map_err(|e| anyhow::anyhow!("Decryption failed: {}", e))
    }
}
fn apply_diff(_old_content: &str, _diff_str: &str) -> Result<String> {
Ok("".to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Returns a dump-file path inside `dir` with a per-run unique file stem, plus the
    /// `$HOME/.promptpro/<stem>` directory `PromptVault::restore` restores it into.
    ///
    /// `restore` skips work when that directory already exists. With fixed names, reruns
    /// would "restore" stale data and never exercise decoding or decryption.
    fn unique_dump_paths(dir: &Path, prefix: &str) -> (PathBuf, PathBuf) {
        let suffix = dir
            .file_name()
            .expect("tempdir paths have a final component")
            .to_string_lossy()
            .trim_start_matches('.')
            .to_string();
        let stem = format!("{}_{}", prefix, suffix);
        let home = std::env::var("HOME").expect("HOME must be set for restore tests");
        let target = PathBuf::from(home).join(".promptpro").join(&stem);
        (dir.join(format!("{}.vault", stem)), target)
    }

    #[test]
    fn test_vault_operations() -> Result<()> {
        let dir = tempdir()?;
        let vault = PromptVault::open(dir.path())?;
        vault.add("test_key", "initial content")?;
        let content = vault.get("test_key", VersionSelector::Latest)?;
        assert_eq!(content, "initial content");
        vault.update(
            "test_key",
            "updated content",
            Some("test message".to_string()),
        )?;
        let content = vault.get("test_key", VersionSelector::Latest)?;
        assert_eq!(content, "updated content");
        let history = vault.history("test_key")?;
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].version, 1);
        assert_eq!(history[1].version, 2);
        assert_eq!(history[1].message, Some("test message".to_string()));
        Ok(())
    }

    #[test]
    fn test_tagging() -> Result<()> {
        let dir = tempdir()?;
        let vault = PromptVault::open(dir.path())?;
        vault.add("test_key", "content v1")?;
        vault.update("test_key", "content v2", None)?;
        vault.tag("test_key", "stable", 1)?;
        let content = vault.get("test_key", VersionSelector::Tag("stable"))?;
        assert_eq!(content, "content v1");
        vault.promote("test_key", "stable")?;
        let content = vault.get("test_key", VersionSelector::Tag("stable"))?;
        assert_eq!(content, "content v2");
        Ok(())
    }

    #[test]
    fn test_dev_tag_logic() -> Result<()> {
        let dir = tempdir()?;
        let vault = PromptVault::open(dir.path())?;
        vault.add("test_key", "content v1")?;
        vault.update("test_key", "content v2", None)?;
        let history = vault.history("test_key")?;
        assert_eq!(history.len(), 2);
        let latest_version = history.last().unwrap();
        assert_eq!(latest_version.version, 2);
        assert!(latest_version.tags.contains(&"dev".to_string()));
        let first_version = &history[0];
        assert_eq!(first_version.version, 1);
        assert!(!first_version.tags.contains(&"dev".to_string()));
        let result = vault.tag("test_key", "dev", 1);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("'dev' tag can only be set to the latest version"));
        vault.update("test_key", "content v3", None)?;
        let new_history = vault.history("test_key")?;
        assert_eq!(new_history.len(), 3);
        let latest_version = new_history.last().unwrap();
        assert_eq!(latest_version.version, 3);
        assert!(latest_version.tags.contains(&"dev".to_string()));
        let second_version = &new_history[1];
        assert_eq!(second_version.version, 2);
        assert!(!second_version.tags.contains(&"dev".to_string()));
        Ok(())
    }

    #[test]
    fn test_dump_restore_unencrypted() -> Result<()> {
        let source_dir = tempdir()?;
        let source_vault = PromptVault::open(source_dir.path().join("db"))?;
        source_vault.add("test_key", "test content")?;
        source_vault.update(
            "test_key",
            "updated content",
            Some("test update".to_string()),
        )?;
        source_vault.tag("test_key", "stable", 1)?;
        let original_content = source_vault.get("test_key", VersionSelector::Latest)?;
        assert_eq!(original_content, "updated content");

        let (dump_file, target) = unique_dump_paths(source_dir.path(), "test_dump");
        source_vault.dump(dump_file.to_str().unwrap(), None)?;
        assert!(fs::read(&dump_file)?.starts_with(b"VAULT_RAW"));
        // Guarantees the restore below really decodes the dump instead of reusing a vault.
        assert!(!target.exists());

        let restored_vault = PromptVault::restore(dump_file.to_str().unwrap(), None)?;
        let content = restored_vault.get("test_key", VersionSelector::Latest)?;
        assert_eq!(content, "updated content");
        let history = restored_vault.history("test_key")?;
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].version, 1);
        assert!(history[0].tags.contains(&"stable".to_string()));

        drop(restored_vault);
        // Best-effort cleanup of the directory `restore` created under $HOME.
        let _ = fs::remove_dir_all(&target);
        Ok(())
    }

    #[test]
    fn test_dump_restore_encrypted() -> Result<()> {
        let source_dir = tempdir()?;
        let source_vault = PromptVault::open(source_dir.path().join("db"))?;
        source_vault.add("encrypted_key", "secret content")?;
        source_vault.update(
            "encrypted_key",
            "updated secret content",
            Some("update message".to_string()),
        )?;
        source_vault.tag("encrypted_key", "secret", 1)?;

        let (dump_file, target) = unique_dump_paths(source_dir.path(), "encrypted_dump");
        source_vault.dump(dump_file.to_str().unwrap(), Some("mypassword"))?;

        // The payload must only open with the password it was sealed with.
        let bytes = fs::read(&dump_file)?;
        let header = b"VAULT_ENC";
        assert!(bytes.starts_with(header));
        assert!(PromptVault::decrypt_data(&bytes[header.len()..], "mypassword").is_ok());
        assert!(PromptVault::decrypt_data(&bytes[header.len()..], "wrongpassword").is_err());

        assert!(!target.exists());
        let restored_vault =
            PromptVault::restore(dump_file.to_str().unwrap(), Some("mypassword"))?;
        let content = restored_vault.get("encrypted_key", VersionSelector::Latest)?;
        assert_eq!(content, "updated secret content");
        let history = restored_vault.history("encrypted_key")?;
        assert_eq!(history.len(), 2);
        assert!(history[0].tags.contains(&"secret".to_string()));
        drop(restored_vault);
        let _ = fs::remove_dir_all(&target);

        // Use a fresh stem so `restore` cannot short-circuit on an existing directory and
        // has to attempt decryption.
        let (wrong_dump, wrong_target) = unique_dump_paths(source_dir.path(), "wrong_password");
        fs::copy(&dump_file, &wrong_dump)?;
        let result = PromptVault::restore(wrong_dump.to_str().unwrap(), Some("wrongpassword"));
        assert!(result.is_err());
        let result = PromptVault::restore(wrong_dump.to_str().unwrap(), None);
        assert!(result.is_err());
        // A failed restore must not leave a half-created vault behind.
        assert!(!wrong_target.exists());
        Ok(())
    }
}