nika_media/tools/
context.rs1use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12
13use tokio_util::sync::CancellationToken;
14
15use super::error::tool_error;
16use super::error::MediaToolError;
17use crate::{CasStore, MediaBudget};
18
19pub struct MediaToolContext {
21 pub cas: CasStore,
23 pub budget: Arc<MediaBudget>,
25 pub compute: Arc<ComputePool>,
27 pub working_memory: Arc<WorkingMemoryBudget>,
29 pub cancel: CancellationToken,
31}
32
33impl MediaToolContext {
34 pub fn new(cas: CasStore) -> Result<Self, MediaToolError> {
38 Ok(Self {
39 cas,
40 budget: Arc::new(MediaBudget::new()),
41 compute: Arc::new(ComputePool::new()?),
42 working_memory: Arc::new(WorkingMemoryBudget::new()),
43 cancel: CancellationToken::new(),
44 })
45 }
46
47 pub async fn read_media(&self, hash: &str) -> Result<Vec<u8>, MediaToolError> {
49 self.cas.read(hash).await.map_err(|e| e.into())
50 }
51
52 pub async fn store_media(
54 &self,
55 data: &[u8],
56 task_id: &str,
57 ) -> Result<crate::store::StoreResult, MediaToolError> {
58 self.budget
60 .check_and_add(data.len() as u64, task_id)
61 .map_err(|e| -> MediaToolError { e.into() })?;
62
63 match self.cas.store(data).await {
65 Ok(result) => Ok(result),
66 Err(e) => {
67 self.budget.rollback(data.len() as u64);
69 Err(e.into())
70 }
71 }
72 }
73
74 pub fn check_cancelled(&self) -> Result<(), MediaToolError> {
76 if self.cancel.is_cancelled() {
77 Err(tool_error("media", "workflow cancelled"))
78 } else {
79 Ok(())
80 }
81 }
82}
83
84pub struct ComputePool {
91 pool: rayon::ThreadPool,
92}
93
94impl ComputePool {
95 pub fn new() -> Result<Self, MediaToolError> {
97 Ok(Self {
98 pool: rayon::ThreadPoolBuilder::new()
99 .num_threads(num_cpus().min(4))
100 .thread_name(|idx| format!("nika-media-{idx}"))
101 .panic_handler(|info| {
102 tracing::error!("media compute thread panicked: {info:?}");
105 })
106 .build()
107 .map_err(|e| {
108 tool_error(
109 "compute_pool",
110 format!("Failed to create media compute pool: {e}"),
111 )
112 })?,
113 })
114 }
115
116 pub async fn compute<F, T>(&self, f: F) -> Result<T, MediaToolError>
121 where
122 F: FnOnce() -> T + Send + 'static,
123 T: Send + 'static,
124 {
125 let (tx, rx) = tokio::sync::oneshot::channel();
126 self.pool.spawn(move || {
127 let _ = tx.send(f());
128 });
129 rx.await
130 .map_err(|_| tool_error("compute", "task panicked on rayon thread"))
131 }
132}
133
/// Number of threads the host can run in parallel, or 2 when the standard
/// library cannot determine it.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 2,
    }
}
140
/// Atomic counter that caps how many bytes of working memory may be reserved
/// at once.
///
/// Reservations are taken with [`WorkingMemoryBudget::acquire`] and returned
/// automatically when the resulting guard is dropped.
#[derive(Debug)]
pub struct WorkingMemoryBudget {
    /// Bytes currently reserved by live guards.
    used: AtomicUsize,
    /// Ceiling that `used` is never allowed to exceed.
    max_bytes: usize,
}
150
151impl WorkingMemoryBudget {
152 pub const DEFAULT_MAX: usize = 512 * 1024 * 1024;
154
155 pub fn new() -> Self {
157 Self {
158 used: AtomicUsize::new(0),
159 max_bytes: Self::DEFAULT_MAX,
160 }
161 }
162
163 pub fn with_max(max_bytes: usize) -> Self {
165 Self {
166 used: AtomicUsize::new(0),
167 max_bytes,
168 }
169 }
170
171 pub fn acquire(&self, size: usize) -> Result<WorkingMemoryGuard<'_>, MediaToolError> {
173 let result = self
174 .used
175 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
176 let new_total = current + size;
177 if new_total > self.max_bytes {
178 None
179 } else {
180 Some(new_total)
181 }
182 });
183
184 match result {
185 Ok(_) => Ok(WorkingMemoryGuard { budget: self, size }),
186 Err(current) => Err(tool_error(
187 "memory",
188 format!(
189 "working memory exhausted ({} + {} > {} limit)",
190 current, size, self.max_bytes
191 ),
192 )),
193 }
194 }
195
196 pub fn current(&self) -> usize {
198 self.used.load(Ordering::Acquire)
199 }
200}
201
202impl Default for WorkingMemoryBudget {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208#[derive(Debug)]
210pub struct WorkingMemoryGuard<'a> {
211 budget: &'a WorkingMemoryBudget,
212 size: usize,
213}
214
215impl<'a> Drop for WorkingMemoryGuard<'a> {
216 fn drop(&mut self) {
217 let _ = self
218 .budget
219 .used
220 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
221 Some(current.saturating_sub(self.size))
222 });
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context backed by a fresh temporary CAS directory.
    ///
    /// The `TempDir` is returned so the caller keeps it alive for the test.
    fn temp_context() -> (tempfile::TempDir, MediaToolContext) {
        let dir = tempfile::tempdir().unwrap();
        let ctx = MediaToolContext::new(CasStore::new(dir.path())).unwrap();
        (dir, ctx)
    }

    #[tokio::test]
    async fn compute_pool_executes_on_rayon_thread() {
        let pool = ComputePool::new().unwrap();
        let thread_name = pool
            .compute(|| std::thread::current().name().map(str::to_owned))
            .await
            .unwrap()
            .unwrap_or_else(|| "unknown".to_string());
        assert!(
            thread_name.starts_with("nika-media"),
            "expected nika-media thread, got: {thread_name}"
        );
    }

    #[tokio::test]
    async fn compute_pool_returns_result() {
        let pool = ComputePool::new().unwrap();
        assert_eq!(pool.compute(|| 2 + 2).await.unwrap(), 4);
    }

    #[tokio::test]
    async fn compute_pool_handles_panic() {
        let pool = ComputePool::new().unwrap();
        let err = pool
            .compute(|| -> i32 { panic!("intentional test panic") })
            .await
            .unwrap_err();
        assert!(err.to_string().contains("panicked"));
    }

    #[test]
    fn working_memory_acquire_release() {
        let budget = WorkingMemoryBudget::with_max(1024);
        assert_eq!(budget.current(), 0);

        let guard = budget.acquire(100).unwrap();
        assert_eq!(budget.current(), 100);

        drop(guard);
        assert_eq!(budget.current(), 0);
    }

    #[test]
    fn working_memory_blocks_when_full() {
        let budget = WorkingMemoryBudget::with_max(512);

        let _full = budget.acquire(512).unwrap();
        assert_eq!(budget.current(), 512);

        let message = budget.acquire(1).unwrap_err().to_string();
        assert!(message.contains("working memory exhausted"));
    }

    #[test]
    fn working_memory_multiple_guards() {
        let budget = WorkingMemoryBudget::with_max(300);

        let first = budget.acquire(100).unwrap();
        let second = budget.acquire(100).unwrap();
        assert_eq!(budget.current(), 200);

        drop(first);
        assert_eq!(budget.current(), 100);

        drop(second);
        assert_eq!(budget.current(), 0);
    }

    #[tokio::test]
    async fn context_check_cancelled_ok() {
        let (_dir, ctx) = temp_context();
        assert!(ctx.check_cancelled().is_ok());
    }

    #[tokio::test]
    async fn context_check_cancelled_err() {
        let (_dir, ctx) = temp_context();
        ctx.cancel.cancel();
        assert!(ctx.check_cancelled().is_err());
    }

    #[tokio::test]
    async fn context_store_charges_budget() {
        let (_dir, ctx) = temp_context();
        let payload: &[u8] = b"test media data";
        assert!(ctx.store_media(payload, "test_task").await.is_ok());
        assert_eq!(ctx.budget.current_bytes(), payload.len() as u64);
    }

    #[tokio::test]
    async fn context_read_missing_hash() {
        let (_dir, ctx) = temp_context();
        let missing = "blake3:0000000000000000000000000000000000000000000000000000000000000000";
        assert!(ctx.read_media(missing).await.is_err());
    }
}