use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use super::error::tool_error;
use crate::error::NikaError;
use crate::media::{CasStore, MediaBudget};
/// Shared state handed to media tools: storage, budgets, a compute pool, and a
/// cancellation token.
pub struct MediaToolContext {
/// Store that `read_media` reads from and `store_media` writes to.
pub cas: CasStore,
/// Byte budget charged by `store_media` before data is written to `cas`.
pub budget: Arc<MediaBudget>,
/// Rayon-backed pool whose `compute` method runs closures off the async task.
pub compute: Arc<ComputePool>,
/// Cap on in-flight working memory, reserved through `WorkingMemoryBudget::acquire`.
pub working_memory: Arc<WorkingMemoryBudget>,
/// Token inspected by `check_cancelled`.
pub cancel: CancellationToken,
}
impl MediaToolContext {
    /// Builds a context around `cas` with a default media budget, a new compute
    /// pool, a default working-memory budget, and a fresh cancellation token.
    pub fn new(cas: CasStore) -> Self {
        let budget = Arc::new(MediaBudget::new());
        let compute = Arc::new(ComputePool::new());
        let working_memory = Arc::new(WorkingMemoryBudget::new());
        let cancel = CancellationToken::new();
        Self {
            cas,
            budget,
            compute,
            working_memory,
            cancel,
        }
    }

    /// Reads the bytes stored under `hash` from the CAS.
    ///
    /// # Errors
    /// Returns the store's read error converted into [`NikaError`].
    pub async fn read_media(&self, hash: &str) -> Result<Vec<u8>, NikaError> {
        match self.cas.read(hash).await {
            Ok(bytes) => Ok(bytes),
            Err(read_err) => Err(read_err.into()),
        }
    }

    /// Charges `data.len()` bytes to the media budget for `task_id`, then writes
    /// `data` to the CAS.
    ///
    /// If the write fails, the charged bytes are rolled back before the error is
    /// returned.
    ///
    /// # Errors
    /// Returns the budget's error when the charge is rejected, or the store's
    /// error when the write fails, each converted into [`NikaError`].
    pub async fn store_media(
        &self,
        data: &[u8],
        task_id: &str,
    ) -> Result<crate::media::store::StoreResult, NikaError> {
        let charged_bytes = data.len() as u64;

        // Charge first so an over-budget request never reaches the store.
        if let Err(budget_err) = self.budget.check_and_add(charged_bytes, task_id) {
            return Err(budget_err.into());
        }

        // NOTE(review): if this future is dropped while the store call is pending,
        // the charge above is not rolled back — confirm whether that matters.
        self.cas.store(data).await.map_err(|store_err| {
            self.budget.rollback(charged_bytes);
            store_err.into()
        })
    }

    /// Returns `Ok(())` while the cancellation token is not cancelled.
    ///
    /// # Errors
    /// Returns a `media` tool error once the token has been cancelled.
    pub fn check_cancelled(&self) -> Result<(), NikaError> {
        if !self.cancel.is_cancelled() {
            return Ok(());
        }
        Err(tool_error("media", "workflow cancelled"))
    }
}
/// Wrapper around a dedicated rayon thread pool used for media computation.
pub struct ComputePool {
// Pool that `compute` spawns closures onto.
pool: rayon::ThreadPool,
}
impl ComputePool {
    /// Creates a pool with at most four worker threads named `nika-media-<idx>`.
    ///
    /// Panics on a worker are reported through `tracing::error!` by the pool's
    /// panic handler.
    ///
    /// # Panics
    /// Panics if the underlying rayon pool cannot be built.
    pub fn new() -> Self {
        let worker_count = num_cpus().min(4);
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(worker_count)
            .thread_name(|idx| format!("nika-media-{idx}"))
            .panic_handler(|info| {
                tracing::error!("media compute thread panicked: {info:?}");
            })
            .build()
            .expect("failed to create media compute pool");
        Self { pool }
    }

    /// Runs `f` on the pool and awaits its return value.
    ///
    /// # Errors
    /// Returns a `compute` tool error when the result channel is closed without a
    /// value, which is what happens when `f` panics and the sender is dropped.
    pub async fn compute<F, T>(&self, f: F) -> Result<T, NikaError>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        let (result_tx, result_rx) = tokio::sync::oneshot::channel();
        self.pool.spawn(move || {
            let output = f();
            // The receiver may already be gone; a failed send is deliberately ignored.
            let _ = result_tx.send(output);
        });
        match result_rx.await {
            Ok(output) => Ok(output),
            Err(_) => Err(tool_error("compute", "task panicked on rayon thread")),
        }
    }
}
impl Default for ComputePool {
fn default() -> Self {
Self::new()
}
}
/// Returns the parallelism reported by the standard library, or `2` when it
/// cannot be determined.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 2,
    }
}
/// Atomic counter of reserved working-memory bytes with a fixed upper limit.
#[derive(Debug)]
pub struct WorkingMemoryBudget {
// Bytes currently reserved by live guards.
used: AtomicUsize,
// Maximum value `used` is allowed to reach.
max_bytes: usize,
}
impl WorkingMemoryBudget {
    /// Default limit: 512 MiB.
    pub const DEFAULT_MAX: usize = 512 * 1024 * 1024;

    /// Creates a budget limited to [`Self::DEFAULT_MAX`] bytes.
    pub fn new() -> Self {
        Self {
            used: AtomicUsize::new(0),
            max_bytes: Self::DEFAULT_MAX,
        }
    }

    /// Creates a budget limited to `max_bytes` bytes.
    pub fn with_max(max_bytes: usize) -> Self {
        Self {
            used: AtomicUsize::new(0),
            max_bytes,
        }
    }

    /// Reserves `size` bytes and returns a guard that releases them on drop.
    ///
    /// # Errors
    /// Returns a `memory` tool error when the reservation would push the total
    /// above the limit. A request so large that `current + size` overflows
    /// `usize` is rejected the same way.
    pub fn acquire(&self, size: usize) -> Result<WorkingMemoryGuard<'_>, NikaError> {
        let result = self
            .used
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                // `checked_add` keeps an oversized request from overflowing: a plain
                // `+` would panic in debug builds and wrap in release builds,
                // letting the request slip past the limit check.
                match current.checked_add(size) {
                    Some(new_total) if new_total <= self.max_bytes => Some(new_total),
                    _ => None,
                }
            });
        match result {
            Ok(_) => Ok(WorkingMemoryGuard { budget: self, size }),
            Err(current) => Err(tool_error(
                "memory",
                format!(
                    "working memory exhausted ({} + {} > {} limit)",
                    current, size, self.max_bytes
                ),
            )),
        }
    }

    /// Returns the number of bytes currently reserved.
    pub fn current(&self) -> usize {
        self.used.load(Ordering::Acquire)
    }
}
impl Default for WorkingMemoryBudget {
fn default() -> Self {
Self::new()
}
}
/// Reservation of `size` bytes against a [`WorkingMemoryBudget`].
///
/// The bytes are returned to the budget when the guard is dropped, so the guard
/// must be kept alive for as long as the memory is in use.
#[derive(Debug)]
#[must_use = "dropping the guard immediately releases the reserved working memory"]
pub struct WorkingMemoryGuard<'a> {
    // Budget the reservation was taken from.
    budget: &'a WorkingMemoryBudget,
    // Number of bytes to give back on drop.
    size: usize,
}
/// Releases the guard's bytes back to its budget.
impl Drop for WorkingMemoryGuard<'_> {
    fn drop(&mut self) {
        let released = self.size;
        let counter = &self.budget.used;
        // The closure always yields `Some`, so the update never reports failure;
        // the saturating subtraction keeps the counter from going below zero.
        let _ = counter.fetch_update(Ordering::AcqRel, Ordering::Acquire, |in_use| {
            Some(in_use.saturating_sub(released))
        });
    }
}
// Unit tests for the compute pool, the working-memory budget, and the tool context.
#[cfg(test)]
mod tests {
use super::*;
// Closures passed to `compute` must run on a pool thread named `nika-media-*`.
#[tokio::test]
async fn compute_pool_executes_on_rayon_thread() {
let pool = ComputePool::new();
let thread_name = pool
.compute(|| {
std::thread::current()
.name()
.unwrap_or("unknown")
.to_string()
})
.await
.unwrap();
assert!(
thread_name.starts_with("nika-media"),
"expected nika-media thread, got: {thread_name}"
);
}
// The closure's return value is handed back to the awaiting caller.
#[tokio::test]
async fn compute_pool_returns_result() {
let pool = ComputePool::new();
let result = pool.compute(|| 2 + 2).await.unwrap();
assert_eq!(result, 4);
}
// A panicking closure surfaces as an `Err` whose message mentions "panicked".
#[tokio::test]
async fn compute_pool_handles_panic() {
let pool = ComputePool::new();
let result: Result<i32, _> = pool
.compute(|| {
panic!("intentional test panic");
})
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("panicked"));
}
// Dropping a guard returns its bytes to the budget.
#[test]
fn working_memory_acquire_release() {
let budget = WorkingMemoryBudget::with_max(1024);
assert_eq!(budget.current(), 0);
{
let _guard = budget.acquire(100).unwrap();
assert_eq!(budget.current(), 100);
}
assert_eq!(budget.current(), 0);
}
// With the budget fully reserved, even a 1-byte request is rejected.
#[test]
fn working_memory_blocks_when_full() {
let budget = WorkingMemoryBudget::with_max(512);
let _guard = budget.acquire(512).unwrap();
assert_eq!(budget.current(), 512);
let result = budget.acquire(1);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("working memory exhausted"));
}
// Several guards accumulate, and each drop releases only its own share.
#[test]
fn working_memory_multiple_guards() {
let budget = WorkingMemoryBudget::with_max(300);
let g1 = budget.acquire(100).unwrap();
let g2 = budget.acquire(100).unwrap();
assert_eq!(budget.current(), 200);
drop(g1);
assert_eq!(budget.current(), 100);
drop(g2);
assert_eq!(budget.current(), 0);
}
// A fresh context is not cancelled.
#[tokio::test]
async fn context_check_cancelled_ok() {
let dir = tempfile::tempdir().unwrap();
let ctx = MediaToolContext::new(CasStore::new(dir.path()));
assert!(ctx.check_cancelled().is_ok());
}
// After the token is cancelled, `check_cancelled` reports an error.
#[tokio::test]
async fn context_check_cancelled_err() {
let dir = tempfile::tempdir().unwrap();
let ctx = MediaToolContext::new(CasStore::new(dir.path()));
ctx.cancel.cancel();
assert!(ctx.check_cancelled().is_err());
}
// A successful store charges exactly `data.len()` bytes to the media budget.
#[tokio::test]
async fn context_store_charges_budget() {
let dir = tempfile::tempdir().unwrap();
let ctx = MediaToolContext::new(CasStore::new(dir.path()));
let data = b"test media data";
let result = ctx.store_media(data, "test_task").await;
assert!(result.is_ok());
assert_eq!(ctx.budget.current_bytes(), data.len() as u64);
}
// Reading a hash that was never stored yields an error.
#[tokio::test]
async fn context_read_missing_hash() {
let dir = tempfile::tempdir().unwrap();
let ctx = MediaToolContext::new(CasStore::new(dir.path()));
let result = ctx
.read_media("blake3:0000000000000000000000000000000000000000000000000000000000000000")
.await;
assert!(result.is_err());
}
}