use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::RwLock;
use crate::error::EnvelopeError;
use crate::models::{Payee, PayeeId};
use super::file_io::{read_json, write_json_atomic};
/// Serialized shape of the payee file, as read by `read_json` and written by
/// `write_json_atomic`.
///
/// The repository keeps payees in hash maps at runtime; this wrapper is only
/// the on-disk form.
#[derive(Debug, Clone, Default)]
#[derive(serde::Serialize, serde::Deserialize)]
struct PayeeData {
    /// Every stored payee. The field name is part of the file format.
    payees: Vec<Payee>,
}
/// File-backed store of payees with an index for lookups by normalized name.
///
/// Both maps sit behind their own `RwLock`. `data` holds the payees
/// themselves, and `by_name` lets name lookups skip a full scan.
pub struct PayeeRepository {
    /// Payees keyed by their id.
    data: RwLock<HashMap<PayeeId, Payee>>,
    /// `Payee::normalize_name(name)` -> id of the payee carrying that name.
    by_name: RwLock<HashMap<String, PayeeId>>,
    /// Backing file read by `load` and written by `save`.
    path: PathBuf,
}
impl PayeeRepository {
    /// Creates an empty repository backed by the file at `path`.
    ///
    /// Nothing is read from disk until [`PayeeRepository::load`] is called.
    pub fn new(path: PathBuf) -> Self {
        Self {
            path,
            data: RwLock::new(HashMap::new()),
            by_name: RwLock::new(HashMap::new()),
        }
    }

    /// Acquires a shared lock, mapping a poisoned lock to `EnvelopeError::Storage`.
    fn read_lock<T>(
        lock: &RwLock<T>,
    ) -> Result<std::sync::RwLockReadGuard<'_, T>, EnvelopeError> {
        lock.read()
            .map_err(|e| EnvelopeError::Storage(format!("Failed to acquire read lock: {}", e)))
    }

    /// Acquires an exclusive lock, mapping a poisoned lock to `EnvelopeError::Storage`.
    fn write_lock<T>(
        lock: &RwLock<T>,
    ) -> Result<std::sync::RwLockWriteGuard<'_, T>, EnvelopeError> {
        lock.write()
            .map_err(|e| EnvelopeError::Storage(format!("Failed to acquire write lock: {}", e)))
    }

    /// Clones all payees and sorts them case-insensitively by name (stable sort).
    fn sorted_payees(data: &HashMap<PayeeId, Payee>) -> Vec<Payee> {
        let mut payees: Vec<Payee> = data.values().cloned().collect();
        // Compute each lowercase key once rather than twice per comparison.
        payees.sort_by_cached_key(|p| p.name.to_lowercase());
        payees
    }

    /// Drops the index entry for `name`, but only if it still points at `id`.
    ///
    /// Several payees may share a normalized name (`upsert` does not forbid
    /// it). The index entry may therefore belong to another payee, and that
    /// entry must be left intact.
    fn unindex_name(by_name: &mut HashMap<String, PayeeId>, name: &str, id: PayeeId) {
        let normalized = Payee::normalize_name(name);
        if by_name.get(&normalized) == Some(&id) {
            by_name.remove(&normalized);
        }
    }

    /// Inserts or replaces `payee` in both maps.
    ///
    /// Callers must already hold the write locks for both maps.
    fn upsert_locked(
        data: &mut HashMap<PayeeId, Payee>,
        by_name: &mut HashMap<String, PayeeId>,
        payee: Payee,
    ) {
        // A rename must not leave the previous name resolving to this payee.
        if let Some(old) = data.get(&payee.id) {
            Self::unindex_name(by_name, &old.name, payee.id);
        }
        by_name.insert(Payee::normalize_name(&payee.name), payee.id);
        data.insert(payee.id, payee);
    }

    /// Replaces the in-memory contents with the payees stored in the backing file.
    ///
    /// The file is parsed and the new maps are built before any lock is taken.
    /// The existing contents are therefore untouched if reading fails.
    ///
    /// # Errors
    /// Propagates errors from `read_json`. Returns `EnvelopeError::Storage` if
    /// a lock is poisoned.
    pub fn load(&self) -> Result<(), EnvelopeError> {
        let file_data: PayeeData = read_json(&self.path)?;

        let capacity = file_data.payees.len();
        let mut new_data = HashMap::with_capacity(capacity);
        let mut new_by_name = HashMap::with_capacity(capacity);
        for payee in file_data.payees {
            new_by_name.insert(Payee::normalize_name(&payee.name), payee.id);
            new_data.insert(payee.id, payee);
        }

        // Take both locks (data first, matching every other method) before
        // mutating either, so a lock failure leaves both maps unchanged.
        let mut data = Self::write_lock(&self.data)?;
        let mut by_name = Self::write_lock(&self.by_name)?;
        *data = new_data;
        *by_name = new_by_name;
        Ok(())
    }

    /// Writes all payees, sorted case-insensitively by name, to the backing file.
    ///
    /// # Errors
    /// Returns `EnvelopeError::Storage` if the lock is poisoned. Otherwise
    /// propagates errors from `write_json_atomic`.
    pub fn save(&self) -> Result<(), EnvelopeError> {
        // Snapshot under the read lock, then release it before the disk I/O
        // so writers are not blocked while the file is written.
        let payees = {
            let data = Self::read_lock(&self.data)?;
            Self::sorted_payees(&data)
        };
        write_json_atomic(&self.path, &PayeeData { payees })
    }

    /// Returns a copy of the payee with `id`, if present.
    ///
    /// # Errors
    /// Returns `EnvelopeError::Storage` if the lock is poisoned.
    pub fn get(&self, id: PayeeId) -> Result<Option<Payee>, EnvelopeError> {
        Ok(Self::read_lock(&self.data)?.get(&id).cloned())
    }

    /// Returns copies of all payees, sorted case-insensitively by name.
    ///
    /// # Errors
    /// Returns `EnvelopeError::Storage` if the lock is poisoned.
    pub fn get_all(&self) -> Result<Vec<Payee>, EnvelopeError> {
        let data = Self::read_lock(&self.data)?;
        Ok(Self::sorted_payees(&data))
    }

    /// Looks up a payee by name after normalizing it with `Payee::normalize_name`.
    ///
    /// # Errors
    /// Returns `EnvelopeError::Storage` if a lock is poisoned.
    pub fn get_by_name(&self, name: &str) -> Result<Option<Payee>, EnvelopeError> {
        let normalized = Payee::normalize_name(name);
        let data = Self::read_lock(&self.data)?;
        let by_name = Self::read_lock(&self.by_name)?;
        Ok(by_name.get(&normalized).and_then(|id| data.get(id)).cloned())
    }

    /// Returns up to `limit` payees whose `similarity_score(query)` exceeds 0.3.
    ///
    /// Results are ordered by descending score. Equal scores are ordered
    /// case-insensitively by name, so the result does not depend on hash-map
    /// iteration order.
    ///
    /// # Errors
    /// Returns `EnvelopeError::Storage` if the lock is poisoned.
    pub fn search(&self, query: &str, limit: usize) -> Result<Vec<Payee>, EnvelopeError> {
        let data = Self::read_lock(&self.data)?;
        // Score by reference; only the payees that are returned get cloned.
        let mut scored: Vec<_> = data
            .values()
            .map(|p| (p, p.similarity_score(query)))
            .filter(|(_, score)| *score > 0.3)
            .collect();
        scored.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.0.name.to_lowercase().cmp(&b.0.name.to_lowercase()))
        });
        Ok(scored
            .into_iter()
            .take(limit)
            .map(|(payee, _)| payee.clone())
            .collect())
    }

    /// Returns the payee matching `name` (normalized), creating it if absent.
    ///
    /// The lookup and the insertion happen under the same pair of write locks.
    /// Concurrent callers with the same name therefore cannot both create a
    /// payee.
    ///
    /// # Errors
    /// Returns `EnvelopeError::Storage` if a lock is poisoned.
    pub fn get_or_create(&self, name: &str) -> Result<Payee, EnvelopeError> {
        let normalized = Payee::normalize_name(name);
        let mut data = Self::write_lock(&self.data)?;
        let mut by_name = Self::write_lock(&self.by_name)?;

        if let Some(existing) = by_name.get(&normalized).and_then(|id| data.get(id)) {
            return Ok(existing.clone());
        }

        let payee = Payee::new(name);
        Self::upsert_locked(&mut data, &mut by_name, payee.clone());
        Ok(payee)
    }

    /// Inserts `payee`, or replaces the stored payee with the same id.
    ///
    /// If the name changed, the old name stops resolving to this payee.
    ///
    /// # Errors
    /// Returns `EnvelopeError::Storage` if a lock is poisoned.
    pub fn upsert(&self, payee: Payee) -> Result<(), EnvelopeError> {
        let mut data = Self::write_lock(&self.data)?;
        let mut by_name = Self::write_lock(&self.by_name)?;
        Self::upsert_locked(&mut data, &mut by_name, payee);
        Ok(())
    }

    /// Removes the payee with `id`. Returns `true` if one was removed.
    ///
    /// # Errors
    /// Returns `EnvelopeError::Storage` if a lock is poisoned.
    pub fn delete(&self, id: PayeeId) -> Result<bool, EnvelopeError> {
        let mut data = Self::write_lock(&self.data)?;
        let mut by_name = Self::write_lock(&self.by_name)?;
        if let Some(payee) = data.remove(&id) {
            Self::unindex_name(&mut by_name, &payee.name, id);
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Returns the number of stored payees.
    ///
    /// # Errors
    /// Returns `EnvelopeError::Storage` if the lock is poisoned.
    pub fn count(&self) -> Result<usize, EnvelopeError> {
        Ok(Self::read_lock(&self.data)?.len())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a repository backed by `payees.json` inside a fresh temp dir.
    ///
    /// The `TempDir` is returned so the directory lives as long as the test.
    fn create_test_repo() -> (TempDir, PayeeRepository) {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("payees.json");
        let repo = PayeeRepository::new(path);
        (temp_dir, repo)
    }

    #[test]
    fn test_empty_load() {
        let (_temp_dir, repo) = create_test_repo();
        repo.load().unwrap();
        assert_eq!(repo.count().unwrap(), 0);
    }

    #[test]
    fn test_upsert_and_get() {
        let (_temp_dir, repo) = create_test_repo();
        repo.load().unwrap();
        let payee = Payee::new("Test Store");
        let id = payee.id;
        repo.upsert(payee).unwrap();
        let retrieved = repo.get(id).unwrap().unwrap();
        assert_eq!(retrieved.name, "Test Store");
    }

    #[test]
    fn test_get_by_name() {
        let (_temp_dir, repo) = create_test_repo();
        repo.load().unwrap();
        repo.upsert(Payee::new("Grocery Store")).unwrap();
        let found = repo.get_by_name("grocery store").unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "Grocery Store");
        let not_found = repo.get_by_name("other store").unwrap();
        assert!(not_found.is_none());
    }

    #[test]
    fn test_get_or_create() {
        let (_temp_dir, repo) = create_test_repo();
        repo.load().unwrap();
        let p1 = repo.get_or_create("New Store").unwrap();
        assert_eq!(p1.name, "New Store");
        assert_eq!(repo.count().unwrap(), 1);
        let p2 = repo.get_or_create("new store").unwrap();
        assert_eq!(p1.id, p2.id);
        assert_eq!(repo.count().unwrap(), 1);
    }

    #[test]
    fn test_search() {
        let (_temp_dir, repo) = create_test_repo();
        repo.load().unwrap();
        repo.upsert(Payee::new("Grocery Store")).unwrap();
        repo.upsert(Payee::new("Gas Station")).unwrap();
        repo.upsert(Payee::new("Restaurant")).unwrap();
        let results = repo.search("groc", 10).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].name, "Grocery Store");
        let results2 = repo.search("st", 10).unwrap();
        assert!(results2.len() >= 2);
        // `limit` caps the number of results even when more payees match.
        let limited = repo.search("st", 1).unwrap();
        assert_eq!(limited.len(), 1);
    }

    #[test]
    fn test_save_and_reload() {
        let (temp_dir, repo) = create_test_repo();
        repo.load().unwrap();
        let payee = Payee::new("Test Store");
        let id = payee.id;
        repo.upsert(payee).unwrap();
        repo.save().unwrap();
        let path = temp_dir.path().join("payees.json");
        let repo2 = PayeeRepository::new(path);
        repo2.load().unwrap();
        let retrieved = repo2.get(id).unwrap().unwrap();
        assert_eq!(retrieved.name, "Test Store");
    }

    #[test]
    fn test_delete() {
        let (_temp_dir, repo) = create_test_repo();
        repo.load().unwrap();
        let payee = Payee::new("Test Store");
        let id = payee.id;
        repo.upsert(payee).unwrap();
        assert_eq!(repo.count().unwrap(), 1);
        assert!(repo.delete(id).unwrap(), "first delete should remove the payee");
        assert_eq!(repo.count().unwrap(), 0);
        let not_found = repo.get_by_name("Test Store").unwrap();
        assert!(not_found.is_none());
        assert!(
            !repo.delete(id).unwrap(),
            "deleting a missing id should report nothing removed"
        );
    }

    #[test]
    fn test_upsert_rename_updates_name_index() {
        let (_temp_dir, repo) = create_test_repo();
        repo.load().unwrap();
        let payee = Payee::new("Corner Shop");
        let id = payee.id;
        repo.upsert(payee.clone()).unwrap();

        let mut renamed = payee;
        renamed.name = String::from("Main Street Market");
        repo.upsert(renamed).unwrap();

        assert_eq!(repo.count().unwrap(), 1);
        assert!(repo.get_by_name("corner shop").unwrap().is_none());
        let found = repo.get_by_name("main street market").unwrap().unwrap();
        assert_eq!(found.id, id);
    }
}