use serde_json;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use thiserror::Error;
use tracing::info;
use crate::store::{ContentRef, ContentStore, Label};
use anyhow::Result;
pub mod template;
#[derive(Error, Debug)]
pub enum ReferenceError {
#[error("Failed to resolve reference: {0}")]
ResolveError(String),
}
/// In-memory store of resolved bytes keyed by string, guarded by an `RwLock`.
///
/// Values are held as `Arc<Vec<u8>>`, so handing out a stored entry is a reference-count
/// bump rather than a copy of the bytes.
#[derive(Default)]
pub struct ResourceCache {
    entries: RwLock<HashMap<String, Arc<Vec<u8>>>>,
}

impl ResourceCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            entries: RwLock::new(HashMap::new()),
        }
    }

    /// Returns how many entries are currently stored.
    ///
    /// # Panics
    /// Panics if the internal lock was poisoned by a panic in another thread.
    pub fn len(&self) -> usize {
        let guard = self.entries.read().expect("resource cache lock poisoned");
        guard.len()
    }

    /// Returns `true` when no entries are stored.
    ///
    /// # Panics
    /// Panics under the same condition as [`ResourceCache::len`].
    pub fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
pub async fn resolve_reference_cached(
reference: &str,
cache: &ResourceCache,
) -> Result<(Arc<Vec<u8>>, bool), ReferenceError> {
if let Some(hit) = cache
.entries
.read()
.expect("resource cache lock poisoned")
.get(reference)
.cloned()
{
return Ok((hit, true));
}
let bytes = resolve_reference(reference).await?;
let arc = Arc::new(bytes);
cache
.entries
.write()
.expect("resource cache lock poisoned")
.insert(reference.to_string(), arc.clone());
Ok((arc, false))
}
/// Resolves a reference string to the raw bytes it points at.
///
/// The reference form is selected by prefix:
/// - `store://<store_id>/hash/<hash>`: content fetched from a [`ContentStore`] by hash.
/// - `store://<store_id>/<label>`: content fetched from a [`ContentStore`] by label.
/// - `http://…` / `https://…`: fetched with an HTTP GET; a non-2xx status is an error.
/// - Anything else: treated as a local file path and read from disk.
///
/// # Errors
/// Returns [`ReferenceError::ResolveError`] when any of the following happens:
/// - a `store://` reference is malformed (including an empty store id, hash or label);
/// - the store lookup fails or finds no content for the label;
/// - the HTTP request fails or returns a non-success status;
/// - the file cannot be read.
pub async fn resolve_reference(reference: &str) -> Result<Vec<u8>, ReferenceError> {
    info!("Resolving reference: {}", reference);
    if reference.starts_with("store://") {
        resolve_store_reference(reference).await
    } else if reference.starts_with("http://") || reference.starts_with("https://") {
        resolve_http_reference(reference).await
    } else {
        info!("Reading from file path: {}", reference);
        tokio::fs::read(reference)
            .await
            .map_err(|e| ReferenceError::ResolveError(e.to_string()))
    }
}

/// Resolves a `store://` reference by content hash or by label.
///
/// Path segments after the hash or label are ignored.
async fn resolve_store_reference(reference: &str) -> Result<Vec<u8>, ReferenceError> {
    // "store://id/x/y".split('/') yields ["store:", "", "id", "x", "y"]: the store id is
    // segment 2 and the hash/label selector begins at segment 3.
    let parts: Vec<&str> = reference.split('/').collect();
    let store_id = match parts.get(2) {
        Some(id) if !id.is_empty() => *id,
        _ => {
            return Err(ReferenceError::ResolveError(format!(
                "Invalid store path: {}",
                reference
            )))
        }
    };
    let invalid_format = || {
        ReferenceError::ResolveError(format!("Invalid store path format: {}", reference))
    };

    match (parts.get(3).copied(), parts.get(4).copied()) {
        (Some("hash"), Some(hash)) => {
            if hash.is_empty() {
                return Err(invalid_format());
            }
            let store = ContentStore::from_id(store_id);
            let content_ref = ContentRef::from_str(hash);
            info!("Resolving store path with hash: {}", hash);
            store
                .get(&content_ref)
                .await
                .map_err(|e| ReferenceError::ResolveError(e.to_string()))
        }
        // A bare `store://id/hash` (no hash value) falls through here and is treated as a
        // label named "hash", matching the original dispatch.
        (Some(label_string), _) if !label_string.is_empty() => {
            let store = ContentStore::from_id(store_id);
            let label = Label::new(label_string.to_string());
            info!("Resolving store path with label: {}", label);
            match store.get_content_by_label(&label).await {
                Ok(Some(content)) => Ok(content),
                Ok(None) => Err(ReferenceError::ResolveError(format!(
                    "Label not found: {}",
                    label
                ))),
                Err(e) => Err(ReferenceError::ResolveError(e.to_string())),
            }
        }
        _ => Err(invalid_format()),
    }
}

/// Fetches an `http://` or `https://` reference with a GET request.
///
/// Only a success (2xx) status yields the body; any other status becomes an error.
async fn resolve_http_reference(reference: &str) -> Result<Vec<u8>, ReferenceError> {
    info!("Fetching from URL: {}", reference);
    let response = reqwest::Client::new()
        .get(reference)
        .send()
        .await
        .map_err(|e| {
            ReferenceError::ResolveError(format!("Failed to fetch from {}: {}", reference, e))
        })?;

    let status = response.status();
    if !status.is_success() {
        return Err(ReferenceError::ResolveError(format!(
            "HTTP request failed for {}: {} {}",
            reference,
            status.as_u16(),
            status.canonical_reason().unwrap_or("Unknown error")
        )));
    }

    let bytes = response.bytes().await.map_err(|e| {
        ReferenceError::ResolveError(format!(
            "Failed to read response body from {}: {}",
            reference, e
        ))
    })?;
    Ok(bytes.to_vec())
}
/// Combines an initial state from configuration with an optional override state.
///
/// - If only one side is present, it is returned unchanged.
/// - If both parse as JSON objects, the result is the config object with every top-level key
///   from the override object inserted over it, re-serialized as JSON. This is a shallow
///   merge: a nested object in the override replaces the config's value wholesale.
/// - If either side fails to parse as JSON, or either is not a JSON object, the override
///   bytes are returned unchanged.
///
/// # Errors
/// Returns an error only if re-serializing the merged JSON object fails.
pub fn merge_initial_states(
    config_state: Option<Vec<u8>>,
    override_state: Option<Vec<u8>>,
) -> Result<Option<Vec<u8>>> {
    match (config_state, override_state) {
        (None, None) => Ok(None),
        (Some(state), None) | (None, Some(state)) => Ok(Some(state)),
        (Some(config_state), Some(override_state)) => {
            let config_json = serde_json::from_slice::<serde_json::Value>(&config_state);
            let override_json = serde_json::from_slice::<serde_json::Value>(&override_state);
            match (config_json, override_json) {
                (
                    Ok(serde_json::Value::Object(mut config_map)),
                    Ok(serde_json::Value::Object(override_map)),
                ) => {
                    // Both maps are owned here, so override entries are moved in rather than
                    // deep-cloned; existing keys are replaced exactly as `insert` would.
                    config_map.extend(override_map);
                    Ok(Some(serde_json::to_vec(&serde_json::Value::Object(
                        config_map,
                    ))?))
                }
                (Ok(_), Ok(_)) => {
                    info!("Either initial state is not a JSON object, using override state");
                    Ok(Some(override_state))
                }
                _ => {
                    info!(
                        "Failed to parse one of the initial states as JSON, using override state"
                    );
                    Ok(Some(override_state))
                }
            }
        }
    }
}
/// Returns the Theater home directory as a string.
///
/// Resolution order:
/// 1. `THEATER_HOME`, when set to a non-empty value.
/// 2. `$HOME/.theater`, when `HOME` is set to a non-empty value.
/// 3. The relative path `.theater` (resolved against the current working directory).
///
/// Falling back to a relative path avoids producing `/.theater` at the filesystem root when
/// `HOME` is missing. Variables that are unset or not valid Unicode are treated as absent.
pub fn get_theater_home() -> String {
    let non_empty_var = |name: &str| std::env::var(name).ok().filter(|value| !value.is_empty());
    if let Some(theater_home) = non_empty_var("THEATER_HOME") {
        return theater_home;
    }
    match non_empty_var("HOME") {
        Some(home) => format!("{}/{}", home, ".theater"),
        None => ".theater".to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Writes `contents` into a fresh temp file; `label` is the panic message if creation fails.
    /// The file is deleted when the returned handle drops, so callers keep it alive.
    fn temp_file_with(contents: &[u8], label: &str) -> tempfile::NamedTempFile {
        let mut file = tempfile::NamedTempFile::new().expect(label);
        file.write_all(contents).expect("write");
        file
    }

    /// The temp file's path as an owned string, usable as a file-path reference.
    fn reference_for(file: &tempfile::NamedTempFile) -> String {
        file.path().to_string_lossy().into_owned()
    }

    #[tokio::test]
    async fn resource_cache_hits_on_repeat_reference() {
        let file = temp_file_with(b"hello cache", "tempfile");
        let reference = reference_for(&file);
        let cache = ResourceCache::new();
        assert_eq!(cache.len(), 0);

        // The first lookup has nothing cached and must read the file.
        let (first, first_hit) = resolve_reference_cached(&reference, &cache).await.unwrap();
        assert_eq!(first.as_slice(), b"hello cache");
        assert!(!first_hit, "first fetch should be a miss");
        assert_eq!(cache.len(), 1);

        // The second lookup is served from the cache and shares the same allocation.
        let (second, second_hit) = resolve_reference_cached(&reference, &cache).await.unwrap();
        assert_eq!(second.as_slice(), b"hello cache");
        assert!(second_hit, "second fetch should be a hit");
        assert_eq!(cache.len(), 1);
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[tokio::test]
    async fn resource_cache_concurrent_misses_both_fetch_safely() {
        let file = temp_file_with(b"raced bytes", "tempfile");
        let reference = reference_for(&file);
        let cache = Arc::new(ResourceCache::new());

        // Each spawned task needs its own owned reference string and cache handle.
        let spawn_lookup = |reference: String, cache: Arc<ResourceCache>| {
            tokio::spawn(async move { resolve_reference_cached(&reference, &cache).await })
        };
        let (first_join, second_join) = tokio::join!(
            spawn_lookup(reference.clone(), Arc::clone(&cache)),
            spawn_lookup(reference.clone(), Arc::clone(&cache)),
        );

        let (first_bytes, _) = first_join.unwrap().unwrap();
        let (second_bytes, _) = second_join.unwrap().unwrap();
        assert_eq!(first_bytes.as_slice(), b"raced bytes");
        assert_eq!(second_bytes.as_slice(), b"raced bytes");
        // However the two lookups interleave, only one entry exists for the reference.
        assert_eq!(cache.len(), 1);
    }

    #[tokio::test]
    async fn resource_cache_keys_by_reference_string() {
        let file_a = temp_file_with(b"A", "a");
        let file_b = temp_file_with(b"B", "b");
        let cache = ResourceCache::new();

        let (from_a, _) = resolve_reference_cached(&reference_for(&file_a), &cache)
            .await
            .unwrap();
        let (from_b, _) = resolve_reference_cached(&reference_for(&file_b), &cache)
            .await
            .unwrap();

        assert_eq!(from_a.as_slice(), b"A");
        assert_eq!(from_b.as_slice(), b"B");
        assert_eq!(
            cache.len(),
            2,
            "different references must be distinct entries"
        );
    }
}