1use reflow_graph::types::{GraphConnection, GraphEdge, GraphExport, GraphNode, PortType};
4use serde_json::{json, Value};
5use std::collections::HashMap;
6
// Template ids for computer-vision pre/post-processing actors. These strings
// are used as the `component` of `GraphNode`s in taskpack graph exports and
// are resolved to concrete actor types when the graph is instantiated (see
// `ml_template_mapping` for the id -> actor-name pairing).
pub const TPL_CV_IMAGE_TO_TENSOR: &str = "tpl_cv_image_to_tensor";
pub const TPL_CV_RESIZE_LETTERBOX: &str = "tpl_cv_resize_letterbox";
pub const TPL_CV_VIDEO_STREAM_TO_FRAMES: &str = "tpl_cv_video_stream_to_frames";
pub const TPL_CV_NORMALIZE_TENSOR: &str = "tpl_cv_normalize_tensor";
pub const TPL_CV_TENSOR_CROP_ROI: &str = "tpl_cv_tensor_crop_roi";
pub const TPL_CV_DETECTION_TO_ROI: &str = "tpl_cv_detection_to_roi";
pub const TPL_CV_TEMPORAL_SMOOTHER: &str = "tpl_cv_temporal_smoother";

// Template ids for machine-learning actors (model lifecycle, inference, and
// output decoding), plus a packet probe for diagnostics.
pub const TPL_ML_LOAD_MODEL: &str = "tpl_ml_load_model";
pub const TPL_ML_RUN_INFERENCE: &str = "tpl_ml_run_inference";
pub const TPL_ML_DECODE_DETECTIONS: &str = "tpl_ml_decode_detections";
pub const TPL_ML_DECODE_LANDMARKS: &str = "tpl_ml_decode_landmarks";
pub const TPL_ML_PACKET_PROBE: &str = "tpl_ml_packet_probe";
20
21pub fn hand_landmark_graph() -> GraphExport {
27 let mut processes = HashMap::new();
28 processes.insert(
29 "palm_letterbox".to_string(),
30 node(
31 "palm_letterbox",
32 TPL_CV_RESIZE_LETTERBOX,
33 json!({"width": 224, "height": 224, "fill": 0}),
34 ),
35 );
36 processes.insert(
37 "palm_tensor".to_string(),
38 node(
39 "palm_tensor",
40 TPL_CV_IMAGE_TO_TENSOR,
41 json!({"name": "image", "dtype": "f32", "layout": "nhwc", "channels": 3}),
42 ),
43 );
44 processes.insert(
45 "palm_inference".to_string(),
46 node(
47 "palm_inference",
48 TPL_ML_RUN_INFERENCE,
49 json!({
50 "model_id": "hand-palm-detector",
51 "backend": "mock",
52 "task": "palm_detection",
53 "inputs": [{"name": "image", "dtype": "f32", "shape": {"dims": [1, 224, 224, 3]}}],
54 "outputs": [{"name": "detections", "dtype": "f32", "shape": {"dims": [1, 5]}}]
55 }),
56 ),
57 );
58 processes.insert(
59 "palm_decode".to_string(),
60 node(
61 "palm_decode",
62 TPL_ML_DECODE_DETECTIONS,
63 json!({"threshold": 0.25, "values_per_detection": 5, "fallback_detection": true}),
64 ),
65 );
66 processes.insert(
67 "roi".to_string(),
68 node(
69 "roi",
70 TPL_CV_DETECTION_TO_ROI,
71 json!({"scale": 1.8, "square": true, "fallback_center": true}),
72 ),
73 );
74 processes.insert(
75 "crop".to_string(),
76 node("crop", TPL_CV_TENSOR_CROP_ROI, json!({})),
77 );
78 processes.insert(
79 "landmark_letterbox".to_string(),
80 node(
81 "landmark_letterbox",
82 TPL_CV_RESIZE_LETTERBOX,
83 json!({"width": 224, "height": 224, "fill": 0}),
84 ),
85 );
86 processes.insert(
87 "landmark_tensor".to_string(),
88 node(
89 "landmark_tensor",
90 TPL_CV_IMAGE_TO_TENSOR,
91 json!({"name": "roi_image", "dtype": "f32", "layout": "nhwc", "channels": 3}),
92 ),
93 );
94 processes.insert(
95 "landmark_inference".to_string(),
96 node(
97 "landmark_inference",
98 TPL_ML_RUN_INFERENCE,
99 json!({
100 "model_id": "hand-landmark",
101 "backend": "mock",
102 "task": "landmark",
103 "inputs": [{"name": "roi_image", "dtype": "f32", "shape": {"dims": [1, 224, 224, 3]}}],
104 "outputs": [{"name": "landmarks", "dtype": "f32", "shape": {"dims": [1, 21, 3]}}]
105 }),
106 ),
107 );
108 processes.insert(
109 "landmark_decode".to_string(),
110 node(
111 "landmark_decode",
112 TPL_ML_DECODE_LANDMARKS,
113 json!({"values_per_landmark": 3, "max_landmarks": 21}),
114 ),
115 );
116 processes.insert(
117 "smooth".to_string(),
118 node("smooth", TPL_CV_TEMPORAL_SMOOTHER, json!({"alpha": 0.55})),
119 );
120
121 GraphExport {
122 case_sensitive: true,
123 properties: HashMap::from([
124 ("name".to_string(), json!("Hand Landmark Taskpack")),
125 ("kind".to_string(), json!("reflow.taskpack")),
126 ("task".to_string(), json!("hand_landmark")),
127 ("version".to_string(), json!(1)),
128 ]),
129 inports: HashMap::from([("frame".to_string(), edge("palm_letterbox", "frame"))]),
130 outports: HashMap::from([
131 ("landmarks".to_string(), edge("smooth", "landmarks")),
132 ("detections".to_string(), edge("palm_decode", "detections")),
133 ("roi".to_string(), edge("roi", "roi")),
134 ]),
135 groups: Vec::new(),
136 processes,
137 connections: vec![
138 conn("palm_letterbox", "frame", "palm_tensor", "frame"),
139 conn("palm_tensor", "tensor", "palm_inference", "tensor"),
140 conn("palm_inference", "tensor", "palm_decode", "tensor"),
141 conn("palm_decode", "detections", "roi", "detections"),
142 conn("palm_letterbox", "frame", "crop", "frame"),
143 conn("roi", "roi", "crop", "roi"),
144 conn("crop", "frame", "landmark_letterbox", "frame"),
145 conn("landmark_letterbox", "frame", "landmark_tensor", "frame"),
146 conn("landmark_tensor", "tensor", "landmark_inference", "tensor"),
147 conn("landmark_inference", "tensor", "landmark_decode", "tensor"),
148 conn("landmark_decode", "landmarks", "smooth", "landmarks"),
149 ],
150 graph_dependencies: Vec::new(),
151 external_connections: Vec::new(),
152 provided_interfaces: HashMap::new(),
153 required_interfaces: HashMap::new(),
154 }
155}
156
157pub fn ml_template_mapping() -> Vec<(&'static str, &'static str)> {
158 vec![
159 (TPL_CV_IMAGE_TO_TENSOR, "ImageToTensorActor"),
160 (TPL_CV_RESIZE_LETTERBOX, "ResizeLetterboxActor"),
161 (TPL_CV_VIDEO_STREAM_TO_FRAMES, "VideoStreamToFramesActor"),
162 (TPL_CV_NORMALIZE_TENSOR, "NormalizeTensorActor"),
163 (TPL_CV_TENSOR_CROP_ROI, "TensorCropRoiActor"),
164 (TPL_CV_DETECTION_TO_ROI, "DetectionToRoiActor"),
165 (TPL_CV_TEMPORAL_SMOOTHER, "TemporalSmootherActor"),
166 (TPL_ML_LOAD_MODEL, "LoadModelActor"),
167 (TPL_ML_RUN_INFERENCE, "RunInferenceActor"),
168 (TPL_ML_DECODE_DETECTIONS, "DecodeDetectionsActor"),
169 (TPL_ML_DECODE_LANDMARKS, "DecodeLandmarksActor"),
170 (TPL_ML_PACKET_PROBE, "PacketProbeActor"),
171 ]
172}
173
174fn node(id: &str, component: &str, metadata: Value) -> GraphNode {
175 GraphNode {
176 id: id.to_string(),
177 component: component.to_string(),
178 metadata: Some(json_to_hash(metadata)),
179 }
180}
181
182fn conn(from_node: &str, from_port: &str, to_node: &str, to_port: &str) -> GraphConnection {
183 GraphConnection {
184 from: edge(from_node, from_port),
185 to: edge(to_node, to_port),
186 metadata: None,
187 data: None,
188 }
189}
190
191fn edge(node_id: &str, port: &str) -> GraphEdge {
192 GraphEdge {
193 port_name: port.to_string(),
194 port_id: port.to_string(),
195 node_id: node_id.to_string(),
196 index: None,
197 expose: false,
198 data: None,
199 metadata: None,
200 port_type: PortType::Any,
201 }
202}
203
204fn json_to_hash(value: Value) -> HashMap<String, Value> {
205 value
206 .as_object()
207 .map(|object| {
208 object
209 .iter()
210 .map(|(key, value)| (key.clone(), value.clone()))
211 .collect()
212 })
213 .unwrap_or_default()
214}
215
// Integration-style tests: build the hand-landmark taskpack graph, wrap it in
// a SubgraphActor with mock-backed actors, and drive a synthetic frame
// through it end to end.
#[cfg(test)]
mod tests {
    use super::*;
    use reflow_actor::{Actor, ActorConfig};
    use reflow_cv_ops::{
        DetectionToRoiActor, ImageToTensorActor, NormalizeTensorActor, ResizeLetterboxActor,
        TemporalSmootherActor, TensorCropRoiActor,
    };
    use reflow_media_codec::{frame_to_message, value_from_message_or_packet};
    use reflow_media_types::{ImageFormat, LandmarkSet, PacketMetadata, Timestamp, VideoFrame};
    use reflow_ml_ops::{
        DecodeDetectionsActor, DecodeLandmarksActor, LoadModelActor, PacketProbeActor,
        RunInferenceActor,
    };
    use reflow_network::subgraph::SubgraphActor;
    use std::{sync::Arc, time::Duration};

    // Structural smoke test: the exported graph exposes the expected boundary
    // ports and assigns the inference template to the palm detector node.
    #[test]
    fn hand_landmark_graph_has_subgraph_boundary() {
        let graph = hand_landmark_graph();

        assert!(graph.inports.contains_key("frame"));
        assert!(graph.outports.contains_key("landmarks"));
        assert_eq!(
            graph.processes["palm_inference"].component,
            TPL_ML_RUN_INFERENCE
        );
    }

    // The graph export plus the template->actor map is sufficient to
    // materialize a SubgraphActor whose port maps mirror the graph boundary.
    #[test]
    fn hand_landmark_graph_can_build_subgraph_actor() {
        let graph = hand_landmark_graph();
        let subgraph =
            SubgraphActor::from_graph_export(&graph, taskpack_actor_templates()).unwrap();

        assert!(subgraph.inport_map().contains_key("frame"));
        assert!(subgraph.outport_map().contains_key("landmarks"));
    }

    // End-to-end: send one synthetic RGBA frame into the subgraph and wait
    // (bounded at 3 s) for a "landmarks" packet on the outport.
    #[tokio::test]
    async fn hand_landmark_subgraph_processes_frame() {
        let graph = hand_landmark_graph();
        let subgraph =
            SubgraphActor::from_graph_export(&graph, taskpack_actor_templates()).unwrap();
        let inport_sender = subgraph.get_inports().0;
        let outport_receiver = subgraph.get_outports().1;

        // Run the subgraph process concurrently so the send/receive below can
        // make progress.
        let handle = tokio::spawn(subgraph.create_process(ActorConfig::default(), None));
        let frame = sample_frame();
        inport_sender
            .send_async(HashMap::from([(
                "frame".to_string(),
                frame_to_message(&frame).unwrap(),
            )]))
            .await
            .unwrap();

        // The outport may emit packets without a "landmarks" key first (the
        // graph also exposes detections/roi outports), so keep draining until
        // a landmarks packet arrives or the timeout fires.
        let output = tokio::time::timeout(Duration::from_secs(3), async {
            loop {
                let packet = outport_receiver
                    .recv_async()
                    .await
                    .expect("subgraph outport closed before producing landmarks");
                if packet.contains_key("landmarks") {
                    break packet;
                }
            }
        })
        .await
        .expect("hand landmark taskpack did not produce landmarks");

        // 21 landmarks matches the landmark decoder's max_landmarks config in
        // hand_landmark_graph, and the timestamp must be propagated from the
        // input frame (42 ms, set in sample_frame below).
        let landmarks: LandmarkSet =
            value_from_message_or_packet(output.get("landmarks").unwrap()).unwrap();
        assert_eq!(landmarks.landmarks.len(), 21);
        assert_eq!(
            landmarks.metadata.timestamp,
            Some(Timestamp::from_millis(42))
        );

        subgraph.shutdown();
        handle.abort();
    }

    // Template id -> actor instance map used to instantiate the subgraph.
    // NOTE(review): TPL_CV_VIDEO_STREAM_TO_FRAMES is not registered here; the
    // hand-landmark graph has no video-source node, so it is unused by these
    // tests — confirm before reusing this map for other taskpacks.
    fn taskpack_actor_templates() -> HashMap<String, Arc<dyn Actor>> {
        HashMap::from([
            (
                TPL_CV_IMAGE_TO_TENSOR.to_string(),
                Arc::new(ImageToTensorActor::new()) as Arc<dyn Actor>,
            ),
            (
                TPL_CV_RESIZE_LETTERBOX.to_string(),
                Arc::new(ResizeLetterboxActor::new()) as Arc<dyn Actor>,
            ),
            (
                TPL_CV_NORMALIZE_TENSOR.to_string(),
                Arc::new(NormalizeTensorActor::new()) as Arc<dyn Actor>,
            ),
            (
                TPL_CV_TENSOR_CROP_ROI.to_string(),
                Arc::new(TensorCropRoiActor::new()) as Arc<dyn Actor>,
            ),
            (
                TPL_CV_DETECTION_TO_ROI.to_string(),
                Arc::new(DetectionToRoiActor::new()) as Arc<dyn Actor>,
            ),
            (
                TPL_CV_TEMPORAL_SMOOTHER.to_string(),
                Arc::new(TemporalSmootherActor::new()) as Arc<dyn Actor>,
            ),
            (
                TPL_ML_LOAD_MODEL.to_string(),
                Arc::new(LoadModelActor::new()) as Arc<dyn Actor>,
            ),
            (
                TPL_ML_RUN_INFERENCE.to_string(),
                Arc::new(RunInferenceActor::new()) as Arc<dyn Actor>,
            ),
            (
                TPL_ML_DECODE_DETECTIONS.to_string(),
                Arc::new(DecodeDetectionsActor::new()) as Arc<dyn Actor>,
            ),
            (
                TPL_ML_DECODE_LANDMARKS.to_string(),
                Arc::new(DecodeLandmarksActor::new()) as Arc<dyn Actor>,
            ),
            (
                TPL_ML_PACKET_PROBE.to_string(),
                Arc::new(PacketProbeActor::new()) as Arc<dyn Actor>,
            ),
        ])
    }

    // Builds a 32x24 RGBA gradient frame (red ramps with x, green with y,
    // constant blue, opaque alpha) carrying a known timestamp (42 ms) and
    // sequence number (7) so metadata propagation can be asserted.
    fn sample_frame() -> VideoFrame {
        let width = 32;
        let height = 24;
        let mut data = Vec::with_capacity(width * height * 4);
        for y in 0..height {
            for x in 0..width {
                data.extend_from_slice(&[
                    (x * 255 / width) as u8,
                    (y * 255 / height) as u8,
                    160,
                    255,
                ]);
            }
        }

        let mut metadata = PacketMetadata::with_timestamp(Timestamp::from_millis(42));
        metadata.sequence = Some(7);
        let mut frame = VideoFrame::new(width as u32, height as u32, ImageFormat::Rgba8, data);
        frame.metadata = metadata;
        frame
    }
}