use crate::crypto::CryptoUtils;
use crate::custom_error::KSMRError;
use crate::dto::{EncryptedPayload, KsmHttpResponse, TransmissionKey};
use log::{debug, warn};
use reqwest::blocking::Client;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use std::env;
use std::fs::{File, OpenOptions};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::str::FromStr;
/// File name of the on-disk response cache.
const DEFAULT_CACHE_FILE: &str = "ksm_cache.bin";

/// Returns the path of the cache file: `$KSM_CACHE_DIR/ksm_cache.bin`, or `./ksm_cache.bin`
/// when `KSM_CACHE_DIR` is not set.
///
/// The directory is read with [`env::var_os`], so a directory name that is not valid UTF-8 is
/// honored instead of being silently replaced by the current directory.
pub fn get_cache_file_path() -> PathBuf {
    match env::var_os("KSM_CACHE_DIR") {
        Some(cache_dir) => Path::new(&cache_dir).join(DEFAULT_CACHE_FILE),
        None => Path::new(".").join(DEFAULT_CACHE_FILE),
    }
}
pub fn save_cache(data: &[u8]) -> Result<(), KSMRError> {
let cache_path = get_cache_file_path();
let mut file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&cache_path)
.map_err(|e| KSMRError::CacheSaveError(format!("Failed to open cache file: {}", e)))?;
file.write_all(data)
.map_err(|e| KSMRError::CacheSaveError(format!("Failed to write cache: {}", e)))?;
file.sync_all()
.map_err(|e| KSMRError::CacheSaveError(format!("Failed to sync cache: {}", e)))?;
debug!("Cache saved to {:?}", cache_path);
Ok(())
}
/// Loads the full contents of the cache file.
///
/// Yields `None` if the cache file does not exist or if opening/reading it fails; the
/// underlying I/O error is discarded.
pub fn get_cached_data() -> Option<Vec<u8>> {
    let path = get_cache_file_path();
    let contents = match path.exists() {
        false => None,
        true => File::open(&path)
            .and_then(|mut handle| {
                let mut buffer = Vec::new();
                handle.read_to_end(&mut buffer).map(|_| buffer)
            })
            .ok(),
    }?;
    debug!("Cache loaded from {:?}", path);
    Some(contents)
}
/// Deletes the cache file if it is present.
///
/// A missing file is not an error. Deletion is attempted directly rather than after an
/// existence check, so a file removed concurrently by another process cannot cause a spurious
/// `NotFound` failure.
///
/// # Errors
/// Returns [`KSMRError::CacheRetrieveError`] for any deletion failure other than the file not
/// existing. NOTE(review): the variant name is kept for compatibility with existing callers.
pub fn clear_cache() -> Result<(), KSMRError> {
    let cache_path = get_cache_file_path();
    match std::fs::remove_file(&cache_path) {
        Ok(()) => Ok(()),
        // Already gone: the desired end state holds, so treat it as success.
        Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
        Err(e) => Err(KSMRError::CacheRetrieveError(format!(
            "Failed to delete cache: {}",
            e
        ))),
    }
}
/// Reports whether filesystem metadata can be read for the cache file path (i.e. a cache file
/// is currently present and reachable).
pub fn cache_exists() -> bool {
    std::fs::metadata(get_cache_file_path()).is_ok()
}
/// Byte length of the transmission key stored at the start of the cache file.
///
/// The cache layout is `transmission key || encrypted response body`; the writer and the reader
/// below split at this fixed offset, so they must agree on it.
const CACHED_TRANSMISSION_KEY_LEN: usize = 32;

/// Posts the encrypted payload to `url`, caching successful responses and falling back to the
/// cache when the request fails.
///
/// The proxy, if any, is taken from the `KSM_PROXY_URL` environment variable.
/// - On a 200 response, the current transmission key followed by the response body is written to
///   the cache (best-effort), and the response is returned.
/// - Other status codes are returned unchanged and are not cached.
/// - On a request error, the cached body is decrypted with the cached key and re-encrypted with
///   `transmission_key.key`, so the caller can process it like a live response.
///
/// # Errors
/// Returns the original request error when the request fails and no usable cache entry exists:
/// the cache may be missing, too short, or impossible to decrypt or re-encrypt. Cache problems
/// are logged but never replace the request error.
pub fn caching_post_function(
    url: String,
    transmission_key: TransmissionKey,
    encrypted_payload: EncryptedPayload,
) -> Result<KsmHttpResponse, KSMRError> {
    let proxy_url = std::env::var("KSM_PROXY_URL").ok();
    match make_http_request(url, transmission_key.clone(), encrypted_payload, proxy_url) {
        Ok(response) => {
            if response.status_code == 200 {
                store_response_in_cache(&transmission_key, &response);
            }
            Ok(response)
        }
        Err(network_error) => {
            warn!(
                "Network request failed: {}, attempting to use cached data",
                network_error
            );
            match response_from_cache(&transmission_key) {
                Ok(Some(cached_response)) => Ok(cached_response),
                Ok(None) => Err(network_error),
                Err(cache_error) => {
                    // A corrupt cache must not mask the real cause of the failure.
                    warn!("Failed to use cached data: {}", cache_error);
                    Err(network_error)
                }
            }
        }
    }
}

/// Saves `transmission key || response body` to the cache; failures are logged, not returned.
fn store_response_in_cache(transmission_key: &TransmissionKey, response: &KsmHttpResponse) {
    if transmission_key.key.len() != CACHED_TRANSMISSION_KEY_LEN {
        // The reader splits at a fixed offset, so any other key length would corrupt the cache.
        warn!(
            "Not caching response: transmission key is {} bytes, expected {}",
            transmission_key.key.len(),
            CACHED_TRANSMISSION_KEY_LEN
        );
        return;
    }
    let mut cache_data = transmission_key.key.clone();
    cache_data.extend_from_slice(&response.data);
    if let Err(e) = save_cache(&cache_data) {
        warn!("Failed to save cache: {}", e);
    }
}

/// Rebuilds a response from the cache, re-encrypted under `transmission_key.key`.
///
/// Returns `Ok(None)` when there is no cache file, or when it is too short to hold both a key
/// and a body. Returns `Err` when an entry exists but cannot be decrypted or re-encrypted.
fn response_from_cache(
    transmission_key: &TransmissionKey,
) -> Result<Option<KsmHttpResponse>, KSMRError> {
    let cached_data = match get_cached_data() {
        Some(data) if data.len() > CACHED_TRANSMISSION_KEY_LEN => data,
        _ => return Ok(None),
    };
    let cached_transmission_key = cached_data[..CACHED_TRANSMISSION_KEY_LEN].to_vec();
    let cached_response_data = cached_data[CACHED_TRANSMISSION_KEY_LEN..].to_vec();
    debug!("Using cached data ({} bytes)", cached_response_data.len());

    let decrypted_data = CryptoUtils::decrypt_aes(&cached_response_data, &cached_transmission_key)
        .map_err(|e| KSMRError::CryptoError(format!("Cache decryption failed: {}", e)))?;
    let re_encrypted_data =
        CryptoUtils::encrypt_aes_gcm(&decrypted_data, &transmission_key.key, None)?;
    debug!("Successfully decrypted cached data and re-encrypted with current transmission key");

    Ok(Some(KsmHttpResponse {
        status_code: 200,
        data: re_encrypted_data,
        http_response: Some("Cached response (re-encrypted)".to_string()),
    }))
}
fn make_http_request(
url: String,
transmission_key: TransmissionKey,
encrypted_payload: EncryptedPayload,
proxy_url: Option<String>,
) -> Result<KsmHttpResponse, KSMRError> {
let mut client_builder = Client::builder();
if let Some(ref proxy) = proxy_url {
if let Ok(p) = reqwest::Proxy::all(proxy) {
client_builder = client_builder.proxy(p);
}
}
let client = client_builder
.build()
.map_err(|e| KSMRError::HTTPError(format!("Failed to build client: {}", e)))?;
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_str("Content-Type").unwrap(),
HeaderValue::from_str("application/octet-stream").unwrap(),
);
headers.insert(
HeaderName::from_str("PublicKeyId").unwrap(),
HeaderValue::from_str(&transmission_key.public_key_id).unwrap(),
);
headers.insert(
HeaderName::from_str("TransmissionKey").unwrap(),
HeaderValue::from_str(&crate::utils::bytes_to_base64(
&transmission_key.encrypted_key,
))
.unwrap(),
);
headers.insert(
HeaderName::from_str("Authorization").unwrap(),
HeaderValue::from_str(&format!(
"Signature {}",
crate::utils::bytes_to_base64(&encrypted_payload.signature.to_bytes())
))
.unwrap(),
);
let response = client
.post(&url)
.headers(headers)
.body(encrypted_payload.encrypted_payload.clone())
.send()
.map_err(|e| KSMRError::HTTPError(format!("HTTP request failed: {}", e)))?;
let status_code = response.status().as_u16();
let response_body = response
.bytes()
.map_err(|e| KSMRError::HTTPError(format!("Failed to read response: {}", e)))?
.to_vec();
Ok(KsmHttpResponse {
status_code,
data: response_body,
http_response: None,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_file_path() {
        let path = get_cache_file_path();
        // Compare the file name component directly; avoids panicking on a non-UTF-8 cache dir.
        assert_eq!(
            path.file_name().and_then(|name| name.to_str()),
            Some(DEFAULT_CACHE_FILE)
        );
    }

    // NOTE(review): this exercises the real cache location (KSM_CACHE_DIR or the current
    // directory), so running it deletes any cache file already present there.
    #[test]
    fn test_cache_operations() {
        clear_cache().expect("clearing cache before test");
        assert!(!cache_exists());
        assert!(get_cached_data().is_none());

        let test_data = b"test cache data";
        // Fail loudly here rather than on the following `cache_exists` assertion.
        save_cache(test_data).expect("saving cache");
        assert!(cache_exists());
        assert_eq!(get_cached_data().as_deref(), Some(&test_data[..]));

        clear_cache().expect("clearing cache after test");
        assert!(!cache_exists());
    }
}