use autonomi::client::merkle_payments::MerklePaymentReceipt;
use color_eyre::eyre::{Context, Result};
use std::fs::{DirEntry, File};
use std::io::{BufReader, BufWriter};
use std::path::PathBuf;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Age in seconds (one week) after which a cached payment file counts as expired;
/// see `is_expired_file`.
const PAYMENT_EXPIRATION_SECS: u64 = 7 * 24 * 60 * 60;
/// Return the `payments` directory inside the client data directory, creating it if needed.
///
/// # Errors
/// Propagates any error from `super::data_dir::get_client_data_dir_path`, and
/// wraps a failure to create the directory with context.
pub fn get_payments_dir() -> Result<PathBuf> {
    let payments_dir = super::data_dir::get_client_data_dir_path()?.join("payments");
    std::fs::create_dir_all(&payments_dir)
        .map(|()| payments_dir)
        .wrap_err("Could not create cached payments directory")
}
/// Cache `receipt` on disk so a later lookup for `file` can find it.
///
/// The receipt is written as JSON to `{timestamp}_{key}` inside
/// `get_payments_dir()`. Here `timestamp` comes from
/// `get_timestamp_from_merkle_receipt` and `key` is `filename_short(file)`.
/// The timestamp prefix is what `is_expired_file` later uses to age the entry out.
///
/// # Errors
/// Fails if the payments directory or the cache file cannot be created, or if
/// serializing or flushing the JSON fails. In the latter case the partially
/// written file is removed (best effort), so `load_merkle_payment_for_file`
/// never encounters truncated JSON.
pub fn save_merkle_payment(file: &str, receipt: &MerklePaymentReceipt) -> Result<()> {
    let dir = get_payments_dir()?;
    let timestamp = get_timestamp_from_merkle_receipt(receipt);
    let file_hash = filename_short(file);
    let file_path = dir.join(format!("{timestamp}_{file_hash}"));
    let file_handle = File::create(&file_path)?;
    if let Err(err) = write_receipt_json(file_handle, receipt) {
        // A half-written receipt would make every later load for this file
        // fail to parse until the entry expires, so drop it. The handle was
        // consumed (and closed) by the helper, so removal also works on Windows.
        let _ = std::fs::remove_file(&file_path);
        return Err(err);
    }
    println!(
        "Cached Merkle payment for {file:?} to {}",
        file_path.display()
    );
    Ok(())
}

/// Serialize `receipt` as JSON into `file` and flush it explicitly.
///
/// `BufWriter` flushes on drop but silently discards any error from that flush.
/// Without the explicit `flush`, a short write (e.g. disk full) would go unreported.
fn write_receipt_json(file: File, receipt: &MerklePaymentReceipt) -> Result<()> {
    let mut writer = BufWriter::new(file);
    serde_json::to_writer(&mut writer, receipt)?;
    std::io::Write::flush(&mut writer)?;
    Ok(())
}
/// Look up a cached Merkle payment receipt for `file_name`.
///
/// Expired cache entries are purged first via `cleanup_outdated_payments`.
/// The payments directory is then scanned, and the first entry accepted by
/// `matches_filename` for `filename_short(file_name)` is deserialized from JSON.
/// If several entries match, the one returned depends on directory iteration order.
///
/// Returns `Ok(None)` when no entry matches.
///
/// # Errors
/// Fails if cleanup fails, if the directory cannot be read, or if the matching
/// file cannot be opened or does not contain a valid receipt.
pub fn load_merkle_payment_for_file(file_name: &str) -> Result<Option<MerklePaymentReceipt>> {
    cleanup_outdated_payments()?;
    let payments_dir = get_payments_dir()?;
    let wanted = filename_short(file_name);

    // Unreadable directory entries are skipped, as before.
    let hit = std::fs::read_dir(payments_dir)?
        .find_map(|entry| matches_filename(entry.ok(), &wanted));

    match hit {
        None => Ok(None),
        Some(path) => {
            let reader = BufReader::new(File::open(path)?);
            let receipt: MerklePaymentReceipt = serde_json::from_reader(reader)?;
            println!("Found cached Merkle payment for {file_name}");
            Ok(Some(receipt))
        }
    }
}
fn cleanup_outdated_payments() -> Result<()> {
let dir = get_payments_dir()?;
let files = std::fs::read_dir(dir)?;
let expired_files = files.into_iter().filter_map(|file| {
let path = file.ok()?.path();
let file_name = path.file_name()?.to_str()?;
if is_expired_file(file_name) {
Some(path)
} else {
None
}
});
for file in expired_files {
println!("Removing expired cached payment file: {}", file.display());
std::fs::remove_file(file)?;
}
Ok(())
}
/// Return the entry's path if it is a regular file cached under `file_hash`.
///
/// Cache files are named `{timestamp}_{file_hash}`. The name is split at the
/// first `_` and the remainder must equal `file_hash` exactly. The previous
/// substring match had three problems:
/// - the key `test` matched `…_mytest` and `…_test2`;
/// - a short all-digit key could match the timestamp prefix of an unrelated entry;
/// - an empty key matched every file.
///
/// Returns `None` for a missing entry, a non-file, a non-UTF-8 name, a name
/// without `_`, or a key mismatch.
fn matches_filename(file: Option<DirEntry>, file_hash: &str) -> Option<PathBuf> {
    let path = file?.path();
    if !path.is_file() {
        return None;
    }
    let file_name = path.file_name()?.to_str()?;
    // The timestamp prefix is plain digits. The first '_' therefore always ends
    // it, even when the key itself contains underscores.
    let (_timestamp, key) = file_name.split_once('_')?;
    (key == file_hash).then_some(path)
}
/// Turn `filename` into the key used in cache file names.
///
/// The name is kept as-is when it is at most 32 bytes long and contains no
/// `/` or `\`. Otherwise it is replaced by `sha256::digest(filename)`
/// (presumably a hex string — confirm against the `sha256` crate). The
/// length check counts UTF-8 bytes, not characters.
fn filename_short(filename: &str) -> String {
    let must_hash = filename.len() > 32 || filename.contains(['/', '\\']);
    if must_hash {
        sha256::digest(filename)
    } else {
        filename.to_owned()
    }
}
fn is_expired_file(filename: &str) -> bool {
let exp = PAYMENT_EXPIRATION_SECS;
let expired_if_before = SystemTime::now() - Duration::from_secs(exp);
let timestr = filename.split('_').next().unwrap_or_default();
let sec = timestr.parse::<u64>().unwrap_or_default();
let timestamp = SystemTime::UNIX_EPOCH + Duration::from_secs(sec);
timestamp < expired_if_before
}
/// Current Unix time in whole seconds, as a decimal string.
///
/// Yields `"0"` if the system clock reports a time before the Unix epoch.
fn now() -> String {
    let since_epoch = UNIX_EPOCH.elapsed().unwrap_or_default();
    since_epoch.as_secs().to_string()
}
fn get_timestamp_from_merkle_receipt(receipt: &MerklePaymentReceipt) -> String {
if let Some(proof) = receipt.proofs.values().next() {
return proof
.winner_pool
.midpoint_proof
.merkle_payment_timestamp
.to_string();
}
now()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Whole-second Unix timestamp for the instant `ago` before now.
    fn unix_secs_ago(ago: Duration) -> u64 {
        (SystemTime::now() - ago)
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    #[test]
    fn test_is_expired_filename() {
        let expiry = Duration::from_secs(PAYMENT_EXPIRATION_SECS);
        let file_hash = filename_short("test");

        // Exactly one expiry window ago. Dropping the sub-second part puts this
        // (in practice strictly) before the cutoff, so it counts as expired.
        let just_expired = unix_secs_ago(expiry);
        let just_expired_1 = unix_secs_ago(expiry + Duration::from_secs(1));
        let not_expired = now();
        // Just inside the window. The one-minute margin keeps the assertion
        // stable even if the clock ticks over between here and the check.
        let not_expired_near_edge = unix_secs_ago(expiry - Duration::from_secs(60));
        // Timestamps in the future are never expired.
        let future = (SystemTime::now() + expiry)
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        assert!(is_expired_file(&format!("{just_expired}_{file_hash}")));
        assert!(is_expired_file(&format!("{just_expired_1}_{file_hash}")));
        assert!(!is_expired_file(&format!("{not_expired}_{file_hash}")));
        assert!(!is_expired_file(&format!(
            "{not_expired_near_edge}_{file_hash}"
        )));
        assert!(!is_expired_file(&format!("{future}_{file_hash}")));
    }

    #[test]
    fn test_cleanup_with_full_paths() -> Result<()> {
        use std::fs::File;
        use tempfile::TempDir;

        let temp_dir = TempDir::new()?;
        let temp_path = temp_dir.path();
        let file_hash = filename_short("test");

        let expired_timestamp =
            unix_secs_ago(Duration::from_secs(PAYMENT_EXPIRATION_SECS + 86400));
        let expired_filename = format!("{expired_timestamp}_{file_hash}");
        let expired_path = temp_path.join(&expired_filename);
        File::create(&expired_path)?;

        let fresh_timestamp = unix_secs_ago(Duration::from_secs(86400));
        let fresh_filename = format!("{fresh_timestamp}_{file_hash}");
        let fresh_path = temp_path.join(&fresh_filename);
        File::create(&fresh_path)?;

        assert!(
            expired_path.exists(),
            "Expired file should exist before cleanup"
        );
        assert!(
            fresh_path.exists(),
            "Fresh file should exist before cleanup"
        );

        // Mirrors the selection in `cleanup_outdated_payments`: only the file
        // name, never the full path, is handed to `is_expired_file`.
        let expired_files: Vec<PathBuf> = std::fs::read_dir(temp_path)?
            .filter_map(|file| {
                let path = file.ok()?.path();
                let file_name = path.file_name()?.to_str()?;
                if is_expired_file(file_name) {
                    Some(path)
                } else {
                    None
                }
            })
            .collect();
        assert_eq!(
            expired_files.len(),
            1,
            "Should find exactly one expired file"
        );
        assert_eq!(
            expired_files.first().expect("Should have one expired file"),
            &expired_path,
            "Should identify the correct expired file"
        );

        assert!(
            is_expired_file(&expired_filename),
            "Expired filename should be marked as expired"
        );
        assert!(
            !is_expired_file(&fresh_filename),
            "Fresh filename should not be marked as expired"
        );

        // The pitfall the name-only selection avoids: a full path has no
        // leading timestamp, so even a fresh file would look expired.
        let fresh_full_path = fresh_path
            .to_str()
            .expect("temp dir path should be valid UTF-8");
        assert!(
            is_expired_file(fresh_full_path),
            "A full path parses as timestamp 0 and always looks expired; \
             cleanup must pass only the file name"
        );
        Ok(())
    }
}