// reflow_taskpacks/lib.rs
1//! Reusable graph taskpacks for media/ML workflows.
2
3use reflow_graph::types::{GraphConnection, GraphEdge, GraphExport, GraphNode, PortType};
4use serde_json::{json, Value};
5use std::collections::HashMap;
6
/// Template id for converting a decoded image into a model-input tensor.
pub const TPL_CV_IMAGE_TO_TENSOR: &str = "tpl_cv_image_to_tensor";
/// Template id for aspect-preserving resize with letterbox padding.
pub const TPL_CV_RESIZE_LETTERBOX: &str = "tpl_cv_resize_letterbox";
/// Template id for splitting a video stream into individual frames.
pub const TPL_CV_VIDEO_STREAM_TO_FRAMES: &str = "tpl_cv_video_stream_to_frames";
/// Template id for tensor normalization.
pub const TPL_CV_NORMALIZE_TENSOR: &str = "tpl_cv_normalize_tensor";
/// Template id for cropping a tensor to a region of interest.
pub const TPL_CV_TENSOR_CROP_ROI: &str = "tpl_cv_tensor_crop_roi";
/// Template id for deriving an ROI from a detection result.
pub const TPL_CV_DETECTION_TO_ROI: &str = "tpl_cv_detection_to_roi";
/// Template id for temporal smoothing of per-frame outputs.
pub const TPL_CV_TEMPORAL_SMOOTHER: &str = "tpl_cv_temporal_smoother";

/// Template id for loading an ML model.
pub const TPL_ML_LOAD_MODEL: &str = "tpl_ml_load_model";
/// Template id for running inference against a loaded model.
pub const TPL_ML_RUN_INFERENCE: &str = "tpl_ml_run_inference";
/// Template id for decoding raw inference output into detections.
pub const TPL_ML_DECODE_DETECTIONS: &str = "tpl_ml_decode_detections";
/// Template id for decoding raw inference output into landmarks.
pub const TPL_ML_DECODE_LANDMARKS: &str = "tpl_ml_decode_landmarks";
/// Template id for a diagnostic packet probe.
pub const TPL_ML_PACKET_PROBE: &str = "tpl_ml_packet_probe";
20
21/// A hand-landmark-style task graph using generic CV and ML actors.
22///
23/// This is an authoring convenience and test fixture, not privileged runtime
24/// behavior. Model details stay in actor config and can be swapped by graph
25/// editors or manifests.
26pub fn hand_landmark_graph() -> GraphExport {
27    let mut processes = HashMap::new();
28    processes.insert(
29        "palm_letterbox".to_string(),
30        node(
31            "palm_letterbox",
32            TPL_CV_RESIZE_LETTERBOX,
33            json!({"width": 224, "height": 224, "fill": 0}),
34        ),
35    );
36    processes.insert(
37        "palm_tensor".to_string(),
38        node(
39            "palm_tensor",
40            TPL_CV_IMAGE_TO_TENSOR,
41            json!({"name": "image", "dtype": "f32", "layout": "nhwc", "channels": 3}),
42        ),
43    );
44    processes.insert(
45        "palm_inference".to_string(),
46        node(
47            "palm_inference",
48            TPL_ML_RUN_INFERENCE,
49            json!({
50                "model_id": "hand-palm-detector",
51                "backend": "mock",
52                "task": "palm_detection",
53                "inputs": [{"name": "image", "dtype": "f32", "shape": {"dims": [1, 224, 224, 3]}}],
54                "outputs": [{"name": "detections", "dtype": "f32", "shape": {"dims": [1, 5]}}]
55            }),
56        ),
57    );
58    processes.insert(
59        "palm_decode".to_string(),
60        node(
61            "palm_decode",
62            TPL_ML_DECODE_DETECTIONS,
63            json!({"threshold": 0.25, "values_per_detection": 5, "fallback_detection": true}),
64        ),
65    );
66    processes.insert(
67        "roi".to_string(),
68        node(
69            "roi",
70            TPL_CV_DETECTION_TO_ROI,
71            json!({"scale": 1.8, "square": true, "fallback_center": true}),
72        ),
73    );
74    processes.insert(
75        "crop".to_string(),
76        node("crop", TPL_CV_TENSOR_CROP_ROI, json!({})),
77    );
78    processes.insert(
79        "landmark_letterbox".to_string(),
80        node(
81            "landmark_letterbox",
82            TPL_CV_RESIZE_LETTERBOX,
83            json!({"width": 224, "height": 224, "fill": 0}),
84        ),
85    );
86    processes.insert(
87        "landmark_tensor".to_string(),
88        node(
89            "landmark_tensor",
90            TPL_CV_IMAGE_TO_TENSOR,
91            json!({"name": "roi_image", "dtype": "f32", "layout": "nhwc", "channels": 3}),
92        ),
93    );
94    processes.insert(
95        "landmark_inference".to_string(),
96        node(
97            "landmark_inference",
98            TPL_ML_RUN_INFERENCE,
99            json!({
100                "model_id": "hand-landmark",
101                "backend": "mock",
102                "task": "landmark",
103                "inputs": [{"name": "roi_image", "dtype": "f32", "shape": {"dims": [1, 224, 224, 3]}}],
104                "outputs": [{"name": "landmarks", "dtype": "f32", "shape": {"dims": [1, 21, 3]}}]
105            }),
106        ),
107    );
108    processes.insert(
109        "landmark_decode".to_string(),
110        node(
111            "landmark_decode",
112            TPL_ML_DECODE_LANDMARKS,
113            json!({"values_per_landmark": 3, "max_landmarks": 21}),
114        ),
115    );
116    processes.insert(
117        "smooth".to_string(),
118        node("smooth", TPL_CV_TEMPORAL_SMOOTHER, json!({"alpha": 0.55})),
119    );
120
121    GraphExport {
122        case_sensitive: true,
123        properties: HashMap::from([
124            ("name".to_string(), json!("Hand Landmark Taskpack")),
125            ("kind".to_string(), json!("reflow.taskpack")),
126            ("task".to_string(), json!("hand_landmark")),
127            ("version".to_string(), json!(1)),
128        ]),
129        inports: HashMap::from([("frame".to_string(), edge("palm_letterbox", "frame"))]),
130        outports: HashMap::from([
131            ("landmarks".to_string(), edge("smooth", "landmarks")),
132            ("detections".to_string(), edge("palm_decode", "detections")),
133            ("roi".to_string(), edge("roi", "roi")),
134        ]),
135        groups: Vec::new(),
136        processes,
137        connections: vec![
138            conn("palm_letterbox", "frame", "palm_tensor", "frame"),
139            conn("palm_tensor", "tensor", "palm_inference", "tensor"),
140            conn("palm_inference", "tensor", "palm_decode", "tensor"),
141            conn("palm_decode", "detections", "roi", "detections"),
142            conn("palm_letterbox", "frame", "crop", "frame"),
143            conn("roi", "roi", "crop", "roi"),
144            conn("crop", "frame", "landmark_letterbox", "frame"),
145            conn("landmark_letterbox", "frame", "landmark_tensor", "frame"),
146            conn("landmark_tensor", "tensor", "landmark_inference", "tensor"),
147            conn("landmark_inference", "tensor", "landmark_decode", "tensor"),
148            conn("landmark_decode", "landmarks", "smooth", "landmarks"),
149        ],
150        graph_dependencies: Vec::new(),
151        external_connections: Vec::new(),
152        provided_interfaces: HashMap::new(),
153        required_interfaces: HashMap::new(),
154    }
155}
156
157pub fn ml_template_mapping() -> Vec<(&'static str, &'static str)> {
158    vec![
159        (TPL_CV_IMAGE_TO_TENSOR, "ImageToTensorActor"),
160        (TPL_CV_RESIZE_LETTERBOX, "ResizeLetterboxActor"),
161        (TPL_CV_VIDEO_STREAM_TO_FRAMES, "VideoStreamToFramesActor"),
162        (TPL_CV_NORMALIZE_TENSOR, "NormalizeTensorActor"),
163        (TPL_CV_TENSOR_CROP_ROI, "TensorCropRoiActor"),
164        (TPL_CV_DETECTION_TO_ROI, "DetectionToRoiActor"),
165        (TPL_CV_TEMPORAL_SMOOTHER, "TemporalSmootherActor"),
166        (TPL_ML_LOAD_MODEL, "LoadModelActor"),
167        (TPL_ML_RUN_INFERENCE, "RunInferenceActor"),
168        (TPL_ML_DECODE_DETECTIONS, "DecodeDetectionsActor"),
169        (TPL_ML_DECODE_LANDMARKS, "DecodeLandmarksActor"),
170        (TPL_ML_PACKET_PROBE, "PacketProbeActor"),
171    ]
172}
173
174fn node(id: &str, component: &str, metadata: Value) -> GraphNode {
175    GraphNode {
176        id: id.to_string(),
177        component: component.to_string(),
178        metadata: Some(json_to_hash(metadata)),
179    }
180}
181
182fn conn(from_node: &str, from_port: &str, to_node: &str, to_port: &str) -> GraphConnection {
183    GraphConnection {
184        from: edge(from_node, from_port),
185        to: edge(to_node, to_port),
186        metadata: None,
187        data: None,
188    }
189}
190
191fn edge(node_id: &str, port: &str) -> GraphEdge {
192    GraphEdge {
193        port_name: port.to_string(),
194        port_id: port.to_string(),
195        node_id: node_id.to_string(),
196        index: None,
197        expose: false,
198        data: None,
199        metadata: None,
200        port_type: PortType::Any,
201    }
202}
203
204fn json_to_hash(value: Value) -> HashMap<String, Value> {
205    value
206        .as_object()
207        .map(|object| {
208            object
209                .iter()
210                .map(|(key, value)| (key.clone(), value.clone()))
211                .collect()
212        })
213        .unwrap_or_default()
214}
215
216#[cfg(test)]
217mod tests {
218    use super::*;
219    use reflow_actor::{Actor, ActorConfig};
220    use reflow_cv_ops::{
221        DetectionToRoiActor, ImageToTensorActor, NormalizeTensorActor, ResizeLetterboxActor,
222        TemporalSmootherActor, TensorCropRoiActor,
223    };
224    use reflow_media_codec::{frame_to_message, value_from_message_or_packet};
225    use reflow_media_types::{ImageFormat, LandmarkSet, PacketMetadata, Timestamp, VideoFrame};
226    use reflow_ml_ops::{
227        DecodeDetectionsActor, DecodeLandmarksActor, LoadModelActor, PacketProbeActor,
228        RunInferenceActor,
229    };
230    use reflow_network::subgraph::SubgraphActor;
231    use std::{sync::Arc, time::Duration};
232
233    #[test]
234    fn hand_landmark_graph_has_subgraph_boundary() {
235        let graph = hand_landmark_graph();
236
237        assert!(graph.inports.contains_key("frame"));
238        assert!(graph.outports.contains_key("landmarks"));
239        assert_eq!(
240            graph.processes["palm_inference"].component,
241            TPL_ML_RUN_INFERENCE
242        );
243    }
244
245    #[test]
246    fn hand_landmark_graph_can_build_subgraph_actor() {
247        let graph = hand_landmark_graph();
248        let subgraph =
249            SubgraphActor::from_graph_export(&graph, taskpack_actor_templates()).unwrap();
250
251        assert!(subgraph.inport_map().contains_key("frame"));
252        assert!(subgraph.outport_map().contains_key("landmarks"));
253    }
254
    /// End-to-end: push one frame through the taskpack subgraph and expect a
    /// 21-point landmark set carrying the input frame's timestamp.
    #[tokio::test]
    async fn hand_landmark_subgraph_processes_frame() {
        let graph = hand_landmark_graph();
        let subgraph =
            SubgraphActor::from_graph_export(&graph, taskpack_actor_templates()).unwrap();
        // Channel endpoints for pushing packets in and reading results out.
        let inport_sender = subgraph.get_inports().0;
        let outport_receiver = subgraph.get_outports().1;

        // Run the subgraph's process loop on a background task so the test
        // body can drive I/O concurrently.
        let handle = tokio::spawn(subgraph.create_process(ActorConfig::default(), None));
        let frame = sample_frame();
        inport_sender
            .send_async(HashMap::from([(
                "frame".to_string(),
                frame_to_message(&frame).unwrap(),
            )]))
            .await
            .unwrap();

        // Other exposed outports ("detections", "roi") may emit first; keep
        // reading until a packet containing "landmarks" arrives, bounded by a
        // 3-second timeout so a wedged graph fails the test instead of hanging.
        let output = tokio::time::timeout(Duration::from_secs(3), async {
            loop {
                let packet = outport_receiver
                    .recv_async()
                    .await
                    .expect("subgraph outport closed before producing landmarks");
                if packet.contains_key("landmarks") {
                    break packet;
                }
            }
        })
        .await
        .expect("hand landmark taskpack did not produce landmarks");

        // 21 points matches the decoder's `max_landmarks` config; the 42 ms
        // timestamp must propagate from `sample_frame`'s metadata.
        let landmarks: LandmarkSet =
            value_from_message_or_packet(output.get("landmarks").unwrap()).unwrap();
        assert_eq!(landmarks.landmarks.len(), 21);
        assert_eq!(
            landmarks.metadata.timestamp,
            Some(Timestamp::from_millis(42))
        );

        // Tear down: stop the subgraph, then cancel the spawned process task.
        subgraph.shutdown();
        handle.abort();
    }
298
299    fn taskpack_actor_templates() -> HashMap<String, Arc<dyn Actor>> {
300        HashMap::from([
301            (
302                TPL_CV_IMAGE_TO_TENSOR.to_string(),
303                Arc::new(ImageToTensorActor::new()) as Arc<dyn Actor>,
304            ),
305            (
306                TPL_CV_RESIZE_LETTERBOX.to_string(),
307                Arc::new(ResizeLetterboxActor::new()) as Arc<dyn Actor>,
308            ),
309            (
310                TPL_CV_NORMALIZE_TENSOR.to_string(),
311                Arc::new(NormalizeTensorActor::new()) as Arc<dyn Actor>,
312            ),
313            (
314                TPL_CV_TENSOR_CROP_ROI.to_string(),
315                Arc::new(TensorCropRoiActor::new()) as Arc<dyn Actor>,
316            ),
317            (
318                TPL_CV_DETECTION_TO_ROI.to_string(),
319                Arc::new(DetectionToRoiActor::new()) as Arc<dyn Actor>,
320            ),
321            (
322                TPL_CV_TEMPORAL_SMOOTHER.to_string(),
323                Arc::new(TemporalSmootherActor::new()) as Arc<dyn Actor>,
324            ),
325            (
326                TPL_ML_LOAD_MODEL.to_string(),
327                Arc::new(LoadModelActor::new()) as Arc<dyn Actor>,
328            ),
329            (
330                TPL_ML_RUN_INFERENCE.to_string(),
331                Arc::new(RunInferenceActor::new()) as Arc<dyn Actor>,
332            ),
333            (
334                TPL_ML_DECODE_DETECTIONS.to_string(),
335                Arc::new(DecodeDetectionsActor::new()) as Arc<dyn Actor>,
336            ),
337            (
338                TPL_ML_DECODE_LANDMARKS.to_string(),
339                Arc::new(DecodeLandmarksActor::new()) as Arc<dyn Actor>,
340            ),
341            (
342                TPL_ML_PACKET_PROBE.to_string(),
343                Arc::new(PacketProbeActor::new()) as Arc<dyn Actor>,
344            ),
345        ])
346    }
347
348    fn sample_frame() -> VideoFrame {
349        let width = 32;
350        let height = 24;
351        let mut data = Vec::with_capacity(width * height * 4);
352        for y in 0..height {
353            for x in 0..width {
354                data.extend_from_slice(&[
355                    (x * 255 / width) as u8,
356                    (y * 255 / height) as u8,
357                    160,
358                    255,
359                ]);
360            }
361        }
362
363        let mut metadata = PacketMetadata::with_timestamp(Timestamp::from_millis(42));
364        metadata.sequence = Some(7);
365        let mut frame = VideoFrame::new(width as u32, height as u32, ImageFormat::Rgba8, data);
366        frame.metadata = metadata;
367        frame
368    }
369}