Skip to main content

oximedia_graph/
graph_evaluator.rs

//! Graph evaluation: topological execution, cycle detection facade, dynamic editing,
//! serialization helpers, structural statistics, conditional routing, a gain node,
//! a multi-input node, and a crop-region helper.
//!
//! This module wires together the various graph-processing primitives (topological sort,
//! cycle detection, frame pool) into concrete high-level APIs.
6
7#![allow(dead_code)]
8
9#[allow(unused_imports)]
10use std::collections::{HashMap, HashSet, VecDeque};
11
12use crate::cycle_detect::CycleGraph;
13use crate::error::{GraphError, GraphResult};
14use crate::frame::FilterFrame;
15use crate::node::{Node, NodeId, NodeState, NodeType};
16use crate::port::{AudioPortFormat, InputPort, OutputPort, PortFormat, PortId, PortType};
17
18#[allow(unused_imports)]
19use oximedia_audio::{AudioBuffer, AudioFrame, ChannelLayout};
20#[allow(unused_imports)]
21use oximedia_codec::VideoFrame;
22#[allow(unused_imports)]
23use oximedia_core::{PixelFormat, SampleFormat, Timestamp};
24
25// ─── Cycle detection facade ──────────────────────────────────────────────────
26
27/// Error returned when a cycle is found during graph validation.
28#[derive(Debug, Clone, PartialEq, Eq)]
29pub struct CycleError {
30    /// Nodes that form or participate in a cycle.
31    pub cycle_nodes: Vec<NodeId>,
32}
33
34impl std::fmt::Display for CycleError {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(
37            f,
38            "Cycle detected in graph involving {} nodes",
39            self.cycle_nodes.len()
40        )
41    }
42}
43
44impl std::error::Error for CycleError {}
45
46// ─── Simple graph descriptor used by evaluator / validator ───────────────────
47
48/// A lightweight directed-graph descriptor used by the evaluator layer.
49///
50/// This is a plain data structure — it does not own `Node` trait objects.
51/// Use `FilterGraph` from `crate::graph` for full node execution.
52#[derive(Debug, Clone, Default)]
53pub struct SimpleGraph {
54    /// Adjacency list: source NodeId → list of successor NodeIds.
55    pub edges: HashMap<NodeId, Vec<NodeId>>,
56    /// Node labels for serialization and diagnostics.
57    pub labels: HashMap<NodeId, String>,
58}
59
60impl SimpleGraph {
61    /// Create a new empty graph descriptor.
62    #[must_use]
63    pub fn new() -> Self {
64        Self::default()
65    }
66
67    /// Add a node with a label.
68    pub fn add_node(&mut self, id: NodeId, label: impl Into<String>) {
69        self.labels.insert(id, label.into());
70        self.edges.entry(id).or_default();
71    }
72
73    /// Add a directed edge (implicitly creates both nodes if absent).
74    pub fn add_edge(&mut self, from: NodeId, to: NodeId) {
75        self.edges.entry(from).or_default().push(to);
76        self.edges.entry(to).or_default();
77        self.labels
78            .entry(from)
79            .or_insert_with(|| format!("{}", from.0));
80        self.labels.entry(to).or_insert_with(|| format!("{}", to.0));
81    }
82
83    /// Return the number of nodes.
84    #[must_use]
85    pub fn node_count(&self) -> usize {
86        self.edges.len()
87    }
88
89    /// Return the number of directed edges.
90    #[must_use]
91    pub fn edge_count(&self) -> usize {
92        self.edges.values().map(|v| v.len()).sum()
93    }
94
95    /// Return all node IDs sorted for determinism.
96    #[must_use]
97    pub fn node_ids(&self) -> Vec<NodeId> {
98        let mut ids: Vec<NodeId> = self.edges.keys().copied().collect();
99        ids.sort_by_key(|id| id.0);
100        ids
101    }
102}
103
104// ─── GraphEvaluator ──────────────────────────────────────────────────────────
105
106/// Evaluates a filter graph by executing nodes in topological order.
107///
108/// # Example
109///
110/// ```ignore
111/// use oximedia_graph::graph_evaluator::GraphEvaluator;
112/// let order = GraphEvaluator::topological_sort(&graph).unwrap();
113/// ```
114pub struct GraphEvaluator;
115
116impl GraphEvaluator {
117    /// Perform a topological sort of the graph using Kahn's algorithm.
118    ///
119    /// Returns the nodes in a valid processing order (every edge `a → b`
120    /// guarantees `a` appears before `b` in the result).
121    ///
122    /// # Errors
123    ///
124    /// Returns [`CycleError`] if the graph contains a cycle.
125    pub fn topological_sort(graph: &SimpleGraph) -> Result<Vec<NodeId>, CycleError> {
126        // Build in-degree map
127        let mut in_degree: HashMap<NodeId, usize> = graph.edges.keys().map(|&id| (id, 0)).collect();
128
129        for successors in graph.edges.values() {
130            for &succ in successors {
131                *in_degree.entry(succ).or_insert(0) += 1;
132            }
133        }
134
135        // Enqueue nodes with in-degree 0 (sorted for determinism)
136        let mut queue: VecDeque<NodeId> = {
137            let mut zeros: Vec<NodeId> = in_degree
138                .iter()
139                .filter(|(_, &d)| d == 0)
140                .map(|(&id, _)| id)
141                .collect();
142            zeros.sort_by_key(|id| id.0);
143            zeros.into_iter().collect()
144        };
145
146        let mut order = Vec::with_capacity(graph.node_count());
147
148        while let Some(node) = queue.pop_front() {
149            order.push(node);
150            let mut successors: Vec<NodeId> = graph.edges.get(&node).cloned().unwrap_or_default();
151            successors.sort_by_key(|id| id.0);
152            for succ in successors {
153                if let Some(d) = in_degree.get_mut(&succ) {
154                    *d -= 1;
155                    if *d == 0 {
156                        queue.push_back(succ);
157                    }
158                }
159            }
160        }
161
162        if order.len() != graph.node_count() {
163            let cycle_nodes: Vec<NodeId> = in_degree
164                .iter()
165                .filter(|(_, &d)| d > 0)
166                .map(|(&id, _)| id)
167                .collect();
168            return Err(CycleError { cycle_nodes });
169        }
170
171        Ok(order)
172    }
173
174    /// Process a frame through all nodes in topological order, applying
175    /// each node's transform in sequence.
176    ///
177    /// `processors` maps `NodeId` to a boxed transform function
178    /// `fn(Option<FilterFrame>) -> Option<FilterFrame>`.
179    pub fn evaluate(
180        order: &[NodeId],
181        processors: &mut HashMap<
182            NodeId,
183            Box<dyn FnMut(Option<FilterFrame>) -> Option<FilterFrame>>,
184        >,
185        input: FilterFrame,
186    ) -> Option<FilterFrame> {
187        let mut current = Some(input);
188        for &id in order {
189            if let Some(proc) = processors.get_mut(&id) {
190                current = proc(current);
191            }
192        }
193        current
194    }
195}
196
197// ─── GraphValidator ───────────────────────────────────────────────────────────
198
199/// Validates a filter graph for structural correctness.
200pub struct GraphValidator;
201
202impl GraphValidator {
203    /// Return `true` if the graph has at least one cycle.
204    ///
205    /// Uses DFS with gray/black node colouring (identical to
206    /// `CycleGraph::has_cycle` but operating on the `SimpleGraph` type).
207    #[must_use]
208    pub fn has_cycle(graph: &SimpleGraph) -> bool {
209        let mut cg = CycleGraph::new();
210        for (&from, successors) in &graph.edges {
211            for &to in successors {
212                cg.add_edge(
213                    crate::cycle_detect::CycleNodeId(from.0 as usize),
214                    crate::cycle_detect::CycleNodeId(to.0 as usize),
215                );
216            }
217            // Ensure isolated nodes are registered
218            if successors.is_empty() {
219                cg.add_node(crate::cycle_detect::CycleNodeId(from.0 as usize));
220            }
221        }
222        cg.has_cycle()
223    }
224
225    /// Return `true` if the graph is a valid DAG (no cycles, at least one node).
226    #[must_use]
227    pub fn is_valid_dag(graph: &SimpleGraph) -> bool {
228        !graph.edges.is_empty() && !Self::has_cycle(graph)
229    }
230
231    /// Collect all nodes with in-degree 0 (potential source nodes).
232    #[must_use]
233    pub fn source_nodes(graph: &SimpleGraph) -> Vec<NodeId> {
234        let mut in_degree: HashMap<NodeId, usize> = graph.edges.keys().map(|&k| (k, 0)).collect();
235        for succs in graph.edges.values() {
236            for &s in succs {
237                *in_degree.entry(s).or_insert(0) += 1;
238            }
239        }
240        let mut sources: Vec<NodeId> = in_degree
241            .into_iter()
242            .filter(|(_, d)| *d == 0)
243            .map(|(id, _)| id)
244            .collect();
245        sources.sort_by_key(|id| id.0);
246        sources
247    }
248
249    /// Collect all nodes with out-degree 0 (potential sink nodes).
250    #[must_use]
251    pub fn sink_nodes(graph: &SimpleGraph) -> Vec<NodeId> {
252        let mut sinks: Vec<NodeId> = graph
253            .edges
254            .iter()
255            .filter(|(_, succs)| succs.is_empty())
256            .map(|(&id, _)| id)
257            .collect();
258        sinks.sort_by_key(|id| id.0);
259        sinks
260    }
261}
262
263// ─── EdgeId ───────────────────────────────────────────────────────────────────
264
265/// Identifies a directed edge by its source and destination node IDs.
266#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
267pub struct EdgeId {
268    /// Source node.
269    pub from: NodeId,
270    /// Destination node.
271    pub to: NodeId,
272}
273
274impl EdgeId {
275    /// Create a new edge ID.
276    #[must_use]
277    pub fn new(from: NodeId, to: NodeId) -> Self {
278        Self { from, to }
279    }
280}
281
282// ─── GraphEditor ─────────────────────────────────────────────────────────────
283
284/// Supports dynamic modification of a graph without full rebuild.
285pub struct GraphEditor;
286
287impl GraphEditor {
288    /// Insert `new_node_id` on an existing edge `edge`, splitting it into two.
289    ///
290    /// Before: `edge.from → edge.to`
291    /// After:  `edge.from → new_node_id → edge.to`
292    ///
293    /// The edge label of the new segments inherits the label of `new_node_id`.
294    ///
295    /// # Errors
296    ///
297    /// Returns `Err` if the edge does not exist in the graph.
298    pub fn insert_node_between(
299        graph: &mut SimpleGraph,
300        new_node_id: NodeId,
301        new_node_label: impl Into<String>,
302        edge: EdgeId,
303    ) -> GraphResult<()> {
304        // Verify edge exists
305        let successors = graph
306            .edges
307            .get(&edge.from)
308            .ok_or(GraphError::NodeNotFound(edge.from))?;
309        if !successors.contains(&edge.to) {
310            return Err(GraphError::ConfigurationError(format!(
311                "Edge {:?} → {:?} does not exist",
312                edge.from, edge.to
313            )));
314        }
315
316        // Remove old edge from.to
317        if let Some(succs) = graph.edges.get_mut(&edge.from) {
318            succs.retain(|&id| id != edge.to);
319        }
320
321        // Insert new node
322        let label = new_node_label.into();
323        graph.labels.insert(new_node_id, label);
324        graph.edges.entry(new_node_id).or_default();
325
326        // Add edges: from → new → to
327        graph
328            .edges
329            .get_mut(&edge.from)
330            .unwrap_or(&mut vec![])
331            .push(new_node_id);
332        // Use entry to avoid double borrow
333        graph.edges.entry(edge.from).or_default().push(new_node_id);
334        // Add new_node → to
335        graph.edges.entry(new_node_id).or_default().push(edge.to);
336
337        Ok(())
338    }
339}
340
341// ─── GraphSerializer ─────────────────────────────────────────────────────────
342
343/// Serializes a filter graph to various text formats.
344pub struct GraphSerializer;
345
346impl GraphSerializer {
347    /// Produce a valid Graphviz DOT representation of the graph.
348    ///
349    /// ```text
350    /// digraph G {
351    ///     0 [label="source"];
352    ///     1 [label="filter"];
353    ///     2 [label="sink"];
354    ///     0 -> 1;
355    ///     1 -> 2;
356    /// }
357    /// ```
358    #[must_use]
359    pub fn to_dot(graph: &SimpleGraph) -> String {
360        let mut out = String::from("digraph G {\n");
361
362        let mut node_ids = graph.node_ids();
363        node_ids.sort_by_key(|id| id.0);
364
365        for id in &node_ids {
366            let label = graph
367                .labels
368                .get(id)
369                .cloned()
370                .unwrap_or_else(|| format!("{}", id.0));
371            // Escape quotes in labels
372            let escaped = label.replace('"', "\\\"");
373            out.push_str(&format!("    {} [label=\"{}\"];\n", id.0, escaped));
374        }
375
376        for id in &node_ids {
377            if let Some(succs) = graph.edges.get(id) {
378                let mut sorted = succs.clone();
379                sorted.sort_by_key(|s| s.0);
380                for succ in sorted {
381                    out.push_str(&format!("    {} -> {};\n", id.0, succ.0));
382                }
383            }
384        }
385
386        out.push_str("}\n");
387        out
388    }
389
390    /// Produce a simple adjacency-list text representation.
391    #[must_use]
392    pub fn to_adjacency_list(graph: &SimpleGraph) -> String {
393        let mut out = String::new();
394        let mut ids = graph.node_ids();
395        ids.sort_by_key(|id| id.0);
396        for id in ids {
397            let label = graph
398                .labels
399                .get(&id)
400                .cloned()
401                .unwrap_or_else(|| format!("{}", id.0));
402            let succs = graph.edges.get(&id).cloned().unwrap_or_default();
403            let succ_str: Vec<String> = succs.iter().map(|s| format!("{}", s.0)).collect();
404            out.push_str(&format!("{} ({}): {}\n", id.0, label, succ_str.join(", ")));
405        }
406        out
407    }
408}
409
410// ─── GraphStatsSnapshot ──────────────────────────────────────────────────────
411
412/// A snapshot of graph structural statistics.
413#[derive(Debug, Clone)]
414pub struct GraphStatsSnapshot {
415    /// Number of nodes.
416    pub node_count: usize,
417    /// Number of directed edges.
418    pub edge_count: usize,
419    /// Maximum depth (longest path from any source).
420    pub max_depth: usize,
421}
422
423/// Measures structural statistics of a graph.
424pub struct GraphStats;
425
426impl GraphStats {
427    /// Measure node count, edge count, and max depth.
428    ///
429    /// Max depth is computed via BFS from all source nodes (in-degree 0).
430    #[must_use]
431    pub fn measure(graph: &SimpleGraph) -> GraphStatsSnapshot {
432        let node_count = graph.node_count();
433        let edge_count = graph.edge_count();
434        let max_depth = Self::compute_max_depth(graph);
435        GraphStatsSnapshot {
436            node_count,
437            edge_count,
438            max_depth,
439        }
440    }
441
442    fn compute_max_depth(graph: &SimpleGraph) -> usize {
443        if graph.edges.is_empty() {
444            return 0;
445        }
446
447        // Build in-degree map
448        let mut in_degree: HashMap<NodeId, usize> = graph.edges.keys().map(|&k| (k, 0)).collect();
449        for succs in graph.edges.values() {
450            for &s in succs {
451                *in_degree.entry(s).or_insert(0) += 1;
452            }
453        }
454
455        // BFS depth propagation
456        let mut depth: HashMap<NodeId, usize> = HashMap::new();
457        let mut queue: VecDeque<NodeId> = in_degree
458            .iter()
459            .filter(|(_, &d)| d == 0)
460            .map(|(&id, _)| id)
461            .collect();
462        for &id in &queue {
463            depth.insert(id, 0);
464        }
465
466        let mut max_d = 0;
467        // Use topological BFS
468        let mut local_in_degree = in_degree.clone();
469        while let Some(node) = queue.pop_front() {
470            let cur = *depth.get(&node).unwrap_or(&0);
471            if let Some(succs) = graph.edges.get(&node) {
472                for &succ in succs {
473                    let new_d = cur + 1;
474                    let entry = depth.entry(succ).or_insert(0);
475                    if new_d > *entry {
476                        *entry = new_d;
477                        if new_d > max_d {
478                            max_d = new_d;
479                        }
480                    }
481                    if let Some(d) = local_in_degree.get_mut(&succ) {
482                        *d = d.saturating_sub(1);
483                        if *d == 0 {
484                            queue.push_back(succ);
485                        }
486                    }
487                }
488            }
489        }
490
491        max_d
492    }
493}
494
495// ─── MultiInputNode ───────────────────────────────────────────────────────────
496
497/// A node with a configurable number of inputs and outputs.
498///
499/// Used for fan-in topologies where multiple upstream outputs are combined.
500pub struct MultiInputNode {
501    id: NodeId,
502    name: String,
503    state: NodeState,
504    inputs: Vec<InputPort>,
505    outputs: Vec<OutputPort>,
506}
507
508impl MultiInputNode {
509    /// Create a new multi-input node.
510    ///
511    /// All ports are typed as `Any` (video or audio).
512    #[must_use]
513    pub fn new(inputs: u32, outputs: u32, name: impl Into<String>) -> Self {
514        let name = name.into();
515        let input_ports = (0..inputs)
516            .map(|i| InputPort::new(PortId(i), &format!("input_{i}"), PortType::Video))
517            .collect();
518        let output_ports = (0..outputs)
519            .map(|i| OutputPort::new(PortId(i), &format!("output_{i}"), PortType::Video))
520            .collect();
521        Self {
522            id: NodeId(0),
523            name,
524            state: NodeState::Idle,
525            inputs: input_ports,
526            outputs: output_ports,
527        }
528    }
529
530    /// Set the node ID.
531    #[must_use]
532    pub fn with_id(mut self, id: NodeId) -> Self {
533        self.id = id;
534        self
535    }
536
537    /// Return the number of input ports.
538    #[must_use]
539    pub fn input_count(&self) -> usize {
540        self.inputs.len()
541    }
542
543    /// Return the number of output ports.
544    #[must_use]
545    pub fn output_count(&self) -> usize {
546        self.outputs.len()
547    }
548}
549
550impl Node for MultiInputNode {
551    fn id(&self) -> NodeId {
552        self.id
553    }
554    fn name(&self) -> &str {
555        &self.name
556    }
557    fn node_type(&self) -> NodeType {
558        NodeType::Filter
559    }
560    fn state(&self) -> NodeState {
561        self.state
562    }
563
564    fn set_state(&mut self, state: NodeState) -> GraphResult<()> {
565        self.state = state;
566        Ok(())
567    }
568
569    fn inputs(&self) -> &[InputPort] {
570        &self.inputs
571    }
572    fn outputs(&self) -> &[OutputPort] {
573        &self.outputs
574    }
575
576    fn process(&mut self, input: Option<FilterFrame>) -> GraphResult<Option<FilterFrame>> {
577        // Pass-through: forward the first available input unchanged
578        Ok(input)
579    }
580
581    fn reset(&mut self) -> GraphResult<()> {
582        self.state = NodeState::Idle;
583        Ok(())
584    }
585}
586
587/// Connect multiple source nodes to a single target node's inputs.
588///
589/// Returns an error if `sources.len()` exceeds the target's input port count.
590pub fn connect_fan_in(
591    graph: &mut SimpleGraph,
592    sources: &[NodeId],
593    target: NodeId,
594) -> GraphResult<()> {
595    for &src in sources {
596        if !graph.edges.contains_key(&src) {
597            return Err(GraphError::NodeNotFound(src));
598        }
599        graph.edges.entry(src).or_default().push(target);
600    }
601    graph.edges.entry(target).or_default();
602    Ok(())
603}
604
605// ─── ConditionalRouter ────────────────────────────────────────────────────────
606
607/// A filter node that routes frames to one of two output ports based on a
608/// predicate function applied to each incoming frame.
609///
610/// - Output port 0: condition is `true`
611/// - Output port 1: condition is `false`
612pub struct ConditionalRouter {
613    id: NodeId,
614    name: String,
615    state: NodeState,
616    condition: Box<dyn Fn(&FilterFrame) -> bool + Send + Sync>,
617    /// Most recently routed output (0 = true branch, 1 = false branch).
618    last_route: Option<u32>,
619    inputs: Vec<InputPort>,
620    outputs: Vec<OutputPort>,
621}
622
623impl ConditionalRouter {
624    /// Create a new conditional router with the given predicate.
625    ///
626    /// The predicate receives a shared reference to the frame and returns
627    /// `true` to route to output 0 or `false` to route to output 1.
628    pub fn new(condition: Box<dyn Fn(&FilterFrame) -> bool + Send + Sync>) -> Self {
629        Self {
630            id: NodeId(0),
631            name: "conditional_router".to_string(),
632            state: NodeState::Idle,
633            condition,
634            last_route: None,
635            inputs: vec![InputPort::new(PortId(0), "input", PortType::Video)],
636            outputs: vec![
637                OutputPort::new(PortId(0), "true_output", PortType::Video),
638                OutputPort::new(PortId(1), "false_output", PortType::Video),
639            ],
640        }
641    }
642
643    /// Set the node ID.
644    #[must_use]
645    pub fn with_id(mut self, id: NodeId) -> Self {
646        self.id = id;
647        self
648    }
649
650    /// Return which output port the last frame was routed to.
651    #[must_use]
652    pub fn last_route(&self) -> Option<u32> {
653        self.last_route
654    }
655
656    /// Evaluate the condition and return the output port index (0 or 1).
657    #[must_use]
658    pub fn route(&self, frame: &FilterFrame) -> u32 {
659        if (self.condition)(frame) {
660            0
661        } else {
662            1
663        }
664    }
665}
666
667impl Node for ConditionalRouter {
668    fn id(&self) -> NodeId {
669        self.id
670    }
671    fn name(&self) -> &str {
672        &self.name
673    }
674    fn node_type(&self) -> NodeType {
675        NodeType::Filter
676    }
677    fn state(&self) -> NodeState {
678        self.state
679    }
680
681    fn set_state(&mut self, state: NodeState) -> GraphResult<()> {
682        self.state = state;
683        Ok(())
684    }
685
686    fn inputs(&self) -> &[InputPort] {
687        &self.inputs
688    }
689    fn outputs(&self) -> &[OutputPort] {
690        &self.outputs
691    }
692
693    fn process(&mut self, input: Option<FilterFrame>) -> GraphResult<Option<FilterFrame>> {
694        match input {
695            Some(frame) => {
696                self.last_route = Some(self.route(&frame));
697                Ok(Some(frame))
698            }
699            None => Ok(None),
700        }
701    }
702
703    fn reset(&mut self) -> GraphResult<()> {
704        self.last_route = None;
705        self.state = NodeState::Idle;
706        Ok(())
707    }
708}
709
710// ─── GainNode ─────────────────────────────────────────────────────────────────
711
712/// An audio gain node that scales sample amplitude by a dB value.
713///
714/// `gain_db` of 0 dB = unity gain; +6 dB ≈ 2×; −6 dB ≈ 0.5×.
715pub struct GainNode {
716    id: NodeId,
717    name: String,
718    state: NodeState,
719    /// Gain in dB.
720    gain_db: f32,
721    /// Precomputed linear multiplier: `10^(gain_db / 20)`.
722    linear_gain: f32,
723    inputs: Vec<InputPort>,
724    outputs: Vec<OutputPort>,
725}
726
727impl GainNode {
728    /// Create a new gain node with the specified dB gain.
729    #[must_use]
730    pub fn new(gain_db: f32) -> Self {
731        let linear_gain = 10.0_f32.powf(gain_db / 20.0);
732        let audio_fmt = PortFormat::Audio(AudioPortFormat::any());
733        Self {
734            id: NodeId(0),
735            name: "gain".to_string(),
736            state: NodeState::Idle,
737            gain_db,
738            linear_gain,
739            inputs: vec![
740                InputPort::new(PortId(0), "input", PortType::Audio).with_format(audio_fmt.clone())
741            ],
742            outputs: vec![
743                OutputPort::new(PortId(0), "output", PortType::Audio).with_format(audio_fmt)
744            ],
745        }
746    }
747
748    /// Set the node ID.
749    #[must_use]
750    pub fn with_id(mut self, id: NodeId) -> Self {
751        self.id = id;
752        self
753    }
754
755    /// Return the gain in dB.
756    #[must_use]
757    pub fn gain_db(&self) -> f32 {
758        self.gain_db
759    }
760
761    /// Return the linear gain multiplier.
762    #[must_use]
763    pub fn linear_gain(&self) -> f32 {
764        self.linear_gain
765    }
766
767    /// Apply gain to an audio frame, returning a new frame.
768    #[must_use]
769    pub fn process_frame(&self, frame: &AudioFrame) -> AudioFrame {
770        let channels = frame.channels.count();
771        let sample_count = frame.sample_count();
772
773        if sample_count == 0 || channels == 0 {
774            return frame.clone();
775        }
776
777        match &frame.samples {
778            AudioBuffer::Interleaved(data) => {
779                let bytes_per_sample = frame.format.bytes_per_sample();
780                if bytes_per_sample == 0 {
781                    return frame.clone();
782                }
783
784                let total_samples = data.len() / bytes_per_sample;
785                let mut out = Vec::with_capacity(data.len());
786
787                for i in 0..total_samples {
788                    let offset = i * bytes_per_sample;
789                    if offset + bytes_per_sample > data.len() {
790                        break;
791                    }
792                    let scaled = Self::scale_sample(
793                        &data[offset..offset + bytes_per_sample],
794                        frame.format,
795                        self.linear_gain,
796                    );
797                    out.extend_from_slice(&scaled);
798                }
799
800                let mut output_frame =
801                    AudioFrame::new(frame.format, frame.sample_rate, frame.channels.clone());
802                output_frame.samples = AudioBuffer::Interleaved(bytes::Bytes::from(out));
803                output_frame
804            }
805            AudioBuffer::Planar(planes) => {
806                let bytes_per_sample = frame.format.bytes_per_sample();
807                if bytes_per_sample == 0 {
808                    return frame.clone();
809                }
810
811                let scaled_planes: Vec<bytes::Bytes> = planes
812                    .iter()
813                    .map(|plane| {
814                        let total = plane.len() / bytes_per_sample;
815                        let mut out = Vec::with_capacity(plane.len());
816                        for i in 0..total {
817                            let offset = i * bytes_per_sample;
818                            if offset + bytes_per_sample > plane.len() {
819                                break;
820                            }
821                            let scaled = Self::scale_sample(
822                                &plane[offset..offset + bytes_per_sample],
823                                frame.format,
824                                self.linear_gain,
825                            );
826                            out.extend_from_slice(&scaled);
827                        }
828                        bytes::Bytes::from(out)
829                    })
830                    .collect();
831
832                let mut output_frame =
833                    AudioFrame::new(frame.format, frame.sample_rate, frame.channels.clone());
834                output_frame.samples = AudioBuffer::Planar(scaled_planes);
835                output_frame
836            }
837        }
838    }
839
840    /// Scale a single sample byte slice by `linear_gain`.
841    fn scale_sample(bytes: &[u8], format: SampleFormat, gain: f32) -> Vec<u8> {
842        match format {
843            SampleFormat::F32 => {
844                if bytes.len() < 4 {
845                    return bytes.to_vec();
846                }
847                let v = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
848                let scaled = (v * gain).clamp(-1.0, 1.0);
849                scaled.to_le_bytes().to_vec()
850            }
851            SampleFormat::S16 => {
852                if bytes.len() < 2 {
853                    return bytes.to_vec();
854                }
855                let v = i16::from_le_bytes([bytes[0], bytes[1]]);
856                let scaled = ((v as f32) * gain).clamp(i16::MIN as f32, i16::MAX as f32) as i16;
857                scaled.to_le_bytes().to_vec()
858            }
859            SampleFormat::S32 => {
860                if bytes.len() < 4 {
861                    return bytes.to_vec();
862                }
863                let v = i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
864                let scaled =
865                    ((v as f64) * gain as f64).clamp(i32::MIN as f64, i32::MAX as f64) as i32;
866                scaled.to_le_bytes().to_vec()
867            }
868            SampleFormat::U8 => {
869                if bytes.is_empty() {
870                    return bytes.to_vec();
871                }
872                // Map [0,255] → [-1,1], scale, map back
873                let v = (bytes[0] as f32 - 128.0) / 128.0;
874                let scaled = (v * gain).clamp(-1.0, 1.0);
875                let out = ((scaled * 128.0) + 128.0).clamp(0.0, 255.0) as u8;
876                vec![out]
877            }
878            _ => bytes.to_vec(),
879        }
880    }
881}
882
883impl Node for GainNode {
884    fn id(&self) -> NodeId {
885        self.id
886    }
887    fn name(&self) -> &str {
888        &self.name
889    }
890    fn node_type(&self) -> NodeType {
891        NodeType::Filter
892    }
893    fn state(&self) -> NodeState {
894        self.state
895    }
896
897    fn set_state(&mut self, state: NodeState) -> GraphResult<()> {
898        self.state = state;
899        Ok(())
900    }
901
902    fn inputs(&self) -> &[InputPort] {
903        &self.inputs
904    }
905    fn outputs(&self) -> &[OutputPort] {
906        &self.outputs
907    }
908
909    fn process(&mut self, input: Option<FilterFrame>) -> GraphResult<Option<FilterFrame>> {
910        match input {
911            Some(FilterFrame::Audio(audio_frame)) => {
912                let output = self.process_frame(&audio_frame);
913                Ok(Some(FilterFrame::Audio(output)))
914            }
915            Some(other) => Ok(Some(other)), // pass video frames through unchanged
916            None => Ok(None),
917        }
918    }
919
920    fn reset(&mut self) -> GraphResult<()> {
921        self.state = NodeState::Idle;
922        Ok(())
923    }
924}
925
926// ─── CropNode ────────────────────────────────────────────────────────────────
927
/// A simple crop-region descriptor for use with the evaluator layer.
///
/// Note: the full `CropFilter` in `filters::video::crop` provides a complete
/// `Node` implementation. This struct is a thin wrapper for the evaluator API.
#[derive(Debug, Clone, Copy)]
pub struct CropRegion {
    /// Left pixel offset.
    pub x: u32,
    /// Top pixel offset.
    pub y: u32,
    /// Output width.
    pub w: u32,
    /// Output height.
    pub h: u32,
}

impl CropRegion {
    /// Create a new crop region.
    #[must_use]
    pub fn new(x: u32, y: u32, w: u32, h: u32) -> Self {
        Self { x, y, w, h }
    }

    /// Crop a raw RGBA/RGB pixel buffer (packed, row-major).
    ///
    /// `src_w` is the original image width in pixels.
    /// `channels` is bytes per pixel (e.g. 4 for RGBA).
    ///
    /// The region is intersected with the source image:
    /// - Columns past the right edge (`x + w > src_w`) are clipped, so a row
    ///   never picks up pixels from the following row.
    /// - Rows that are not fully present in `data` are omitted.
    ///
    /// Each emitted row therefore holds `min(w, src_w - x) * channels` bytes,
    /// and the result may be empty. Offsets are computed in `usize` with
    /// overflow checks, so extreme coordinates yield an empty or short result
    /// instead of panicking.
    #[must_use]
    pub fn crop_buffer(&self, data: &[u8], src_w: u32, channels: u32) -> Vec<u8> {
        let channels = channels as usize;
        let src_w = src_w as usize;
        let x = self.x as usize;

        let width = (self.w as usize).min(src_w.saturating_sub(x));
        // A saturated product exceeds any slice length, so it can only make
        // the bounds checks below fail; it can never alias valid data.
        let stride = src_w.saturating_mul(channels);
        let row_bytes = width.saturating_mul(channels);
        if row_bytes == 0 {
            return Vec::new();
        }
        let x_offset = x.saturating_mul(channels);

        let y = self.y as usize;
        let mut out =
            Vec::with_capacity(row_bytes.saturating_mul(self.h as usize).min(data.len()));
        for row in y..y.saturating_add(self.h as usize) {
            let range = row
                .checked_mul(stride)
                .and_then(|start| start.checked_add(x_offset))
                .and_then(|start| start.checked_add(row_bytes).map(|end| start..end));
            match range.and_then(|r| data.get(r)) {
                Some(pixels) => out.extend_from_slice(pixels),
                // Row start offsets strictly increase (stride >= row_bytes > 0),
                // so once one row is out of range every later row is too.
                None => break,
            }
        }

        out
    }
}
972
973// ─── Tests ────────────────────────────────────────────────────────────────────
974
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the linear chain `A(0) -> B(1) -> C(2)`.
    fn make_chain() -> SimpleGraph {
        let mut g = SimpleGraph::new();
        g.add_node(NodeId(0), "A");
        g.add_node(NodeId(1), "B");
        g.add_node(NodeId(2), "C");
        g.add_edge(NodeId(0), NodeId(1));
        g.add_edge(NodeId(1), NodeId(2));
        g
    }

    // ── GraphEvaluator ──

    #[test]
    fn test_topological_sort_linear_chain() {
        let g = make_chain();
        let order = GraphEvaluator::topological_sort(&g).expect("sort should succeed");
        assert_eq!(order, vec![NodeId(0), NodeId(1), NodeId(2)]);
    }

    #[test]
    fn test_topological_sort_cycle_returns_error() {
        let mut g = SimpleGraph::new();
        g.add_edge(NodeId(0), NodeId(0)); // self-loop
        let result = GraphEvaluator::topological_sort(&g);
        assert!(result.is_err());
    }

    #[test]
    fn test_topological_sort_diamond() {
        let mut g = SimpleGraph::new();
        g.add_edge(NodeId(0), NodeId(1));
        g.add_edge(NodeId(0), NodeId(2));
        g.add_edge(NodeId(1), NodeId(3));
        g.add_edge(NodeId(2), NodeId(3));
        let order = GraphEvaluator::topological_sort(&g).expect("sort should succeed");
        assert_eq!(order[0], NodeId(0));
        assert_eq!(*order.last().expect("last"), NodeId(3));
    }

    // ── GraphValidator ──

    #[test]
    fn test_has_cycle_self_loop() {
        let mut g = SimpleGraph::new();
        g.add_edge(NodeId(5), NodeId(5));
        assert!(GraphValidator::has_cycle(&g));
    }

    #[test]
    fn test_has_cycle_dag() {
        let g = make_chain();
        assert!(!GraphValidator::has_cycle(&g));
    }

    #[test]
    fn test_has_cycle_3_node_cycle() {
        let mut g = SimpleGraph::new();
        g.add_edge(NodeId(0), NodeId(1));
        g.add_edge(NodeId(1), NodeId(2));
        g.add_edge(NodeId(2), NodeId(0));
        assert!(GraphValidator::has_cycle(&g));
    }

    // ── GraphSerializer ──

    #[test]
    fn test_to_dot_contains_digraph() {
        let g = make_chain();
        let dot = GraphSerializer::to_dot(&g);
        assert!(dot.contains("digraph"), "DOT output must contain 'digraph'");
    }

    #[test]
    fn test_to_dot_contains_nodes() {
        let g = make_chain();
        let dot = GraphSerializer::to_dot(&g);
        assert!(dot.contains("label=\"A\""));
        assert!(dot.contains("label=\"B\""));
        assert!(dot.contains("label=\"C\""));
    }

    #[test]
    fn test_to_dot_contains_edges() {
        let g = make_chain();
        let dot = GraphSerializer::to_dot(&g);
        assert!(dot.contains("->"));
    }

    // ── GraphStats ──

    #[test]
    fn test_stats_measure_chain() {
        let g = make_chain();
        let snap = GraphStats::measure(&g);
        assert_eq!(snap.node_count, 3);
        assert_eq!(snap.edge_count, 2);
        assert_eq!(snap.max_depth, 2);
    }

    #[test]
    fn test_stats_measure_empty() {
        let g = SimpleGraph::new();
        let snap = GraphStats::measure(&g);
        assert_eq!(snap.node_count, 0);
        assert_eq!(snap.edge_count, 0);
        assert_eq!(snap.max_depth, 0);
    }

    // ── MultiInputNode ──

    #[test]
    fn test_multi_input_node_creation() {
        let n = MultiInputNode::new(3, 1, "fan_in");
        assert_eq!(n.input_count(), 3);
        assert_eq!(n.output_count(), 1);
        assert_eq!(n.name(), "fan_in");
    }

    #[test]
    fn test_connect_fan_in() {
        let mut g = SimpleGraph::new();
        g.add_node(NodeId(0), "src_a");
        g.add_node(NodeId(1), "src_b");
        g.add_node(NodeId(2), "mixer");
        let result = connect_fan_in(&mut g, &[NodeId(0), NodeId(1)], NodeId(2));
        assert!(result.is_ok());
        assert!(g.edges[&NodeId(0)].contains(&NodeId(2)));
        assert!(g.edges[&NodeId(1)].contains(&NodeId(2)));
    }

    // ── ConditionalRouter ──

    #[test]
    fn test_conditional_router_true_branch() {
        let router = ConditionalRouter::new(Box::new(|frame| frame.is_video()));
        let video = VideoFrame::new(PixelFormat::Yuv420p, 100, 100);
        let frame = FilterFrame::Video(video);
        assert_eq!(router.route(&frame), 0); // true → port 0
    }

    #[test]
    fn test_conditional_router_false_branch() {
        let router = ConditionalRouter::new(Box::new(|frame| frame.is_video()));
        let audio = AudioFrame::new(SampleFormat::F32, 48000, ChannelLayout::Stereo);
        let frame = FilterFrame::Audio(audio);
        assert_eq!(router.route(&frame), 1); // false → port 1
    }

    // ── GainNode ──

    #[test]
    fn test_gain_node_unity_gain() {
        let g = GainNode::new(0.0);
        assert!((g.linear_gain() - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_gain_node_6db() {
        let g = GainNode::new(6.0206);
        // 10^(6.0206/20) ≈ 2.0
        assert!((g.linear_gain() - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_gain_node_process_f32_sample() {
        use bytes::BytesMut;

        let gain = GainNode::new(6.0206); // ≈ 2×
        let mut frame = AudioFrame::new(SampleFormat::F32, 48000, ChannelLayout::Mono);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&0.25f32.to_le_bytes());
        frame.samples = AudioBuffer::Interleaved(buf.freeze());

        let out = gain.process_frame(&frame);
        // Fail loudly instead of silently skipping the check when the output
        // is not in the interleaved layout the input used.
        let data = match &out.samples {
            AudioBuffer::Interleaved(data) => data,
            _ => panic!("expected interleaved output samples"),
        };
        let v = f32::from_le_bytes([data[0], data[1], data[2], data[3]]);
        // 0.25 × 2 = 0.5
        assert!((v - 0.5).abs() < 0.05, "expected ~0.5, got {v}");
    }

    // ── CropRegion ──

    #[test]
    fn test_crop_region_extracts_subregion() {
        // 4×4 RGBA image, each pixel = (row*10, col*10, 0, 255)
        let mut data = vec![0u8; 4 * 4 * 4];
        for row in 0..4u32 {
            for col in 0..4u32 {
                let idx = ((row * 4 + col) * 4) as usize;
                data[idx] = (row * 10) as u8;
                data[idx + 1] = (col * 10) as u8;
                data[idx + 2] = 0;
                data[idx + 3] = 255;
            }
        }

        let crop = CropRegion::new(1, 1, 2, 2);
        let result = crop.crop_buffer(&data, 4, 4);
        assert_eq!(result.len(), 2 * 2 * 4);
        // First pixel of cropped region: row=1, col=1 → (10, 10, 0, 255)
        assert_eq!(result[0], 10);
        assert_eq!(result[1], 10);
    }

    // ── EdgeId / GraphEditor ──

    #[test]
    fn test_graph_editor_insert_node_between() {
        let mut g = SimpleGraph::new();
        g.add_node(NodeId(0), "A");
        g.add_node(NodeId(2), "C");
        g.add_edge(NodeId(0), NodeId(2));

        let edge = EdgeId::new(NodeId(0), NodeId(2));
        // Insert B between A and C
        let result = GraphEditor::insert_node_between(&mut g, NodeId(1), "B", edge);
        assert!(
            result.is_ok(),
            "insert_node_between should succeed: {result:?}"
        );

        // A should now have B as successor (not C directly). Both halves are
        // checked separately; a single `||` would pass if either held.
        assert!(
            g.edges[&NodeId(0)].contains(&NodeId(1)),
            "A -> B edge should exist after insertion"
        );
        assert!(
            !g.edges[&NodeId(0)].contains(&NodeId(2)),
            "direct A -> C edge should be replaced by A -> B -> C"
        );
        // B should have C as successor
        assert!(g.edges[&NodeId(1)].contains(&NodeId(2)));
    }

    #[test]
    fn test_graph_editor_nonexistent_edge_returns_error() {
        let mut g = SimpleGraph::new();
        g.add_node(NodeId(0), "A");
        g.add_node(NodeId(1), "B");
        // No edge 0→1 yet
        let edge = EdgeId::new(NodeId(0), NodeId(1));
        let result = GraphEditor::insert_node_between(&mut g, NodeId(2), "C", edge);
        assert!(result.is_err());
    }
}