Skip to main content

nika_media/tools/
context.rs

//! Media tool execution context.
//!
//! Provides `MediaToolContext` (shared per workflow run) with:
//! - CAS store access
//! - Media budget enforcement
//! - `ComputePool` (rayon, isolated from tokio)
//! - `WorkingMemoryBudget` (transient buffer limits)
//! - Cancellation support
9
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12
13use tokio_util::sync::CancellationToken;
14
15use super::error::tool_error;
16use super::error::MediaToolError;
17use crate::{CasStore, MediaBudget};
18
19/// Shared context for all media tool operations within a workflow run.
20pub struct MediaToolContext {
21    /// CAS store for reading/writing binary media.
22    pub cas: CasStore,
23    /// Per-run budget enforcement (500 MB default).
24    pub budget: Arc<MediaBudget>,
25    /// CPU-bound compute pool (rayon, 4 threads, isolated from tokio).
26    pub compute: Arc<ComputePool>,
27    /// Transient working memory budget (512 MB default).
28    pub working_memory: Arc<WorkingMemoryBudget>,
29    /// Workflow cancellation token.
30    pub cancel: CancellationToken,
31}
32
33impl MediaToolContext {
34    /// Create a new context with the given CAS store.
35    ///
36    /// Returns an error if the compute pool cannot be created.
37    pub fn new(cas: CasStore) -> Result<Self, MediaToolError> {
38        Ok(Self {
39            cas,
40            budget: Arc::new(MediaBudget::new()),
41            compute: Arc::new(ComputePool::new()?),
42            working_memory: Arc::new(WorkingMemoryBudget::new()),
43            cancel: CancellationToken::new(),
44        })
45    }
46
47    /// Read media data from CAS by hash.
48    pub async fn read_media(&self, hash: &str) -> Result<Vec<u8>, MediaToolError> {
49        self.cas.read(hash).await.map_err(|e| e.into())
50    }
51
52    /// Store binary data in CAS and charge the budget.
53    pub async fn store_media(
54        &self,
55        data: &[u8],
56        task_id: &str,
57    ) -> Result<crate::store::StoreResult, MediaToolError> {
58        // Charge budget first
59        self.budget
60            .check_and_add(data.len() as u64, task_id)
61            .map_err(|e| -> MediaToolError { e.into() })?;
62
63        // Store in CAS
64        match self.cas.store(data).await {
65            Ok(result) => Ok(result),
66            Err(e) => {
67                // Rollback budget on CAS failure
68                self.budget.rollback(data.len() as u64);
69                Err(e.into())
70            }
71        }
72    }
73
74    /// Check if the workflow has been cancelled.
75    pub fn check_cancelled(&self) -> Result<(), MediaToolError> {
76        if self.cancel.is_cancelled() {
77            Err(tool_error("media", "workflow cancelled"))
78        } else {
79            Ok(())
80        }
81    }
82}
83
84/// Dedicated rayon ThreadPool for CPU-bound media operations.
85///
86/// Isolates media compute (resize, encode, optimize) from the tokio runtime.
87/// Uses `min(num_cpus, 4)` threads named `nika-media-N`.
88///
89/// Communication: rayon thread → tokio via oneshot channel.
90pub struct ComputePool {
91    pool: rayon::ThreadPool,
92}
93
94impl ComputePool {
95    /// Create a new compute pool.
96    pub fn new() -> Result<Self, MediaToolError> {
97        Ok(Self {
98            pool: rayon::ThreadPoolBuilder::new()
99                .num_threads(num_cpus().min(4))
100                .thread_name(|idx| format!("nika-media-{idx}"))
101                .panic_handler(|info| {
102                    // Log the panic for debugging, then absorb — the oneshot channel
103                    // receiver gets RecvError which is mapped to a MediaToolError.
104                    tracing::error!("media compute thread panicked: {info:?}");
105                })
106                .build()
107                .map_err(|e| {
108                    tool_error(
109                        "compute_pool",
110                        format!("Failed to create media compute pool: {e}"),
111                    )
112                })?,
113        })
114    }
115
116    /// Execute a CPU-bound closure on the rayon pool.
117    ///
118    /// Bridges rayon → tokio via a oneshot channel.
119    /// If the closure panics, returns an error instead of crashing.
120    pub async fn compute<F, T>(&self, f: F) -> Result<T, MediaToolError>
121    where
122        F: FnOnce() -> T + Send + 'static,
123        T: Send + 'static,
124    {
125        let (tx, rx) = tokio::sync::oneshot::channel();
126        self.pool.spawn(move || {
127            let _ = tx.send(f());
128        });
129        rx.await
130            .map_err(|_| tool_error("compute", "task panicked on rayon thread"))
131    }
132}
133
/// Number of CPUs available to this process.
///
/// Falls back to 2 when the platform cannot report its parallelism.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 2,
    }
}
140
141/// Budget for transient working memory (in-flight decode buffers, etc.).
142///
143/// Prevents OOM from multiple concurrent image decodes.
144/// Default: 512 MB. Uses atomic CAS for lock-free concurrent access.
145#[derive(Debug)]
146pub struct WorkingMemoryBudget {
147    used: AtomicUsize,
148    max_bytes: usize,
149}
150
151impl WorkingMemoryBudget {
152    /// Default transient memory budget: 512 MB.
153    pub const DEFAULT_MAX: usize = 512 * 1024 * 1024;
154
155    /// Create with default budget.
156    pub fn new() -> Self {
157        Self {
158            used: AtomicUsize::new(0),
159            max_bytes: Self::DEFAULT_MAX,
160        }
161    }
162
163    /// Create with custom budget.
164    pub fn with_max(max_bytes: usize) -> Self {
165        Self {
166            used: AtomicUsize::new(0),
167            max_bytes,
168        }
169    }
170
171    /// Try to acquire `size` bytes. Returns a guard that releases on drop.
172    pub fn acquire(&self, size: usize) -> Result<WorkingMemoryGuard<'_>, MediaToolError> {
173        let result = self
174            .used
175            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
176                let new_total = current + size;
177                if new_total > self.max_bytes {
178                    None
179                } else {
180                    Some(new_total)
181                }
182            });
183
184        match result {
185            Ok(_) => Ok(WorkingMemoryGuard { budget: self, size }),
186            Err(current) => Err(tool_error(
187                "memory",
188                format!(
189                    "working memory exhausted ({} + {} > {} limit)",
190                    current, size, self.max_bytes
191                ),
192            )),
193        }
194    }
195
196    /// Current bytes in use.
197    pub fn current(&self) -> usize {
198        self.used.load(Ordering::Acquire)
199    }
200}
201
202impl Default for WorkingMemoryBudget {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208/// RAII guard that releases working memory on drop.
209#[derive(Debug)]
210pub struct WorkingMemoryGuard<'a> {
211    budget: &'a WorkingMemoryBudget,
212    size: usize,
213}
214
215impl<'a> Drop for WorkingMemoryGuard<'a> {
216    fn drop(&mut self) {
217        let _ = self
218            .budget
219            .used
220            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
221                Some(current.saturating_sub(self.size))
222            });
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a context backed by a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the context so the directory
    /// stays alive for as long as the calling test holds on to it.
    fn temp_context() -> (tempfile::TempDir, MediaToolContext) {
        let dir = tempfile::tempdir().unwrap();
        let ctx = MediaToolContext::new(CasStore::new(dir.path())).unwrap();
        (dir, ctx)
    }

    // ───────────── compute pool ─────────────

    /// Work submitted through `compute` runs on a pool-named thread.
    #[tokio::test]
    async fn compute_pool_executes_on_rayon_thread() {
        let pool = ComputePool::new().unwrap();
        let report_thread = || {
            let here = std::thread::current();
            here.name().unwrap_or("unknown").to_string()
        };

        let thread_name = pool.compute(report_thread).await.unwrap();

        assert!(
            thread_name.starts_with("nika-media"),
            "expected nika-media thread, got: {thread_name}"
        );
    }

    /// The closure's return value is delivered to the awaiting caller.
    #[tokio::test]
    async fn compute_pool_returns_result() {
        let pool = ComputePool::new().unwrap();
        let sum = pool.compute(|| 2 + 2).await.unwrap();
        assert_eq!(sum, 4);
    }

    /// A panicking closure surfaces as an error rather than aborting.
    #[tokio::test]
    async fn compute_pool_handles_panic() {
        let pool = ComputePool::new().unwrap();
        let outcome = pool
            .compute(|| -> i32 {
                panic!("intentional test panic");
            })
            .await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("panicked"));
    }

    // ───────────── working memory budget ─────────────

    /// Dropping a guard hands its bytes back to the budget.
    #[test]
    fn working_memory_acquire_release() {
        let budget = WorkingMemoryBudget::with_max(1024);
        assert_eq!(budget.current(), 0);

        let guard = budget.acquire(100).unwrap();
        assert_eq!(budget.current(), 100);

        drop(guard);
        assert_eq!(budget.current(), 0);
    }

    /// Once the limit is reached, even a 1-byte request is refused.
    #[test]
    fn working_memory_blocks_when_full() {
        let budget = WorkingMemoryBudget::with_max(512);

        let _held = budget.acquire(512).unwrap();
        assert_eq!(budget.current(), 512);

        let refused = budget.acquire(1);
        assert!(refused.is_err());
        let message = refused.unwrap_err().to_string();
        assert!(message.contains("working memory exhausted"));
    }

    /// Several guards are accounted for and released independently.
    #[test]
    fn working_memory_multiple_guards() {
        let budget = WorkingMemoryBudget::with_max(300);

        let first = budget.acquire(100).unwrap();
        let second = budget.acquire(100).unwrap();
        assert_eq!(budget.current(), 200);

        drop(first);
        assert_eq!(budget.current(), 100);

        drop(second);
        assert_eq!(budget.current(), 0);
    }

    // ───────────── media tool context ─────────────

    /// A fresh context is not cancelled.
    #[tokio::test]
    async fn context_check_cancelled_ok() {
        let (_dir, ctx) = temp_context();
        assert!(ctx.check_cancelled().is_ok());
    }

    /// After the token fires, `check_cancelled` reports an error.
    #[tokio::test]
    async fn context_check_cancelled_err() {
        let (_dir, ctx) = temp_context();
        ctx.cancel.cancel();
        assert!(ctx.check_cancelled().is_err());
    }

    /// A successful store charges exactly the payload length.
    #[tokio::test]
    async fn context_store_charges_budget() {
        let (_dir, ctx) = temp_context();
        let payload = b"test media data";

        let stored = ctx.store_media(payload, "test_task").await;

        assert!(stored.is_ok());
        assert_eq!(ctx.budget.current_bytes(), payload.len() as u64);
    }

    /// Reading a hash that was never stored fails.
    #[tokio::test]
    async fn context_read_missing_hash() {
        let (_dir, ctx) = temp_context();
        let missing = ctx
            .read_media("blake3:0000000000000000000000000000000000000000000000000000000000000000")
            .await;
        assert!(missing.is_err());
    }
}