Skip to main content

yscv_cli/
pipeline.rs

use std::fs;
use std::time::{Duration, Instant};

use serde_json::{Value, json};
use yscv_detect::{
    BoundingBox, Detection, FrameFaceDetectScratch, FramePeopleDetectScratch,
    Rgb8FaceDetectScratch, Rgb8PeopleDetectScratch, detect_faces_from_frame_with_scratch,
    detect_faces_from_rgb8_with_scratch, detect_people_from_frame_with_scratch,
    detect_people_from_rgb8_with_scratch,
};
use yscv_eval::{
    PipelineDurations, parse_pipeline_benchmark_thresholds, summarize_pipeline_durations,
    validate_pipeline_benchmark_thresholds,
};
use yscv_recognize::Recognizer;
use yscv_track::{TrackedDetection, Tracker, TrackerConfig};
use yscv_video::FrameStream;

use crate::benchmark::{BenchmarkCollector, format_benchmark_report, format_benchmark_violations};
use crate::config::{CliConfig, CliError, DetectTarget, resolve_detect_config};
use crate::error::AppError;
use crate::event_log::JsonlEventWriter;
use crate::source::{build_source, open_camera_source};
use crate::util::{duration_to_ms, ensure_parent_dir};
25
/// Per-frame metadata captured once per frame and passed to `process_frame`
/// for console output, event-log records, and embedding normalization.
#[derive(Debug, Clone, Copy)]
struct FrameMeta {
    // Frame index as reported by the source.
    index: u64,
    // Source timestamp in microseconds.
    timestamp_us: u64,
    // Frame dimensions in pixels; used to normalize bbox embedding components.
    width: usize,
    height: usize,
}
33
/// Mutable state shared across every frame of one pipeline run.
///
/// Owns the reusable per-frame buffers (`tracked_detections`,
/// `tracked_event_records`) so allocations are amortized across frames.
struct FrameProcessContext<'a> {
    cli: &'a CliConfig,
    // Stream mode label ("deterministic" or "camera"); written into event records.
    mode: &'a str,
    tracker: &'a mut Tracker,
    recognizer: &'a mut Recognizer,
    // JSONL event sink; `None` when no event-log path was configured.
    event_log: &'a mut Option<JsonlEventWriter>,
    // Stage-timing collector; `None` when benchmarking is disabled.
    benchmark: &'a mut Option<BenchmarkCollector>,
    // Reused output buffer for the tracker's per-frame update.
    tracked_detections: Vec<TrackedDetection>,
    // Reused buffer of per-detection JSON records for the event log.
    tracked_event_records: Vec<Value>,
}
44
/// Per-frame inputs handed to `process_frame` after the detection stage ran.
struct FrameProcessInput<'a> {
    frame: FrameMeta,
    // Detections produced for this frame, prior to tracking.
    detections: &'a [Detection],
    // Wall-clock time spent in the detection stage.
    detect_duration: Duration,
    // Instant captured at the start of the frame; used for end-to-end timing.
    frame_start: Instant,
}
51
52pub fn run_pipeline(cli: &CliConfig) -> Result<(), AppError> {
53    let mut recognizer = if let Some(path) = cli.identities_path.as_deref() {
54        Recognizer::load_json_file(path)?
55    } else {
56        Recognizer::new(cli.recognition_threshold)?
57    };
58    recognizer.set_threshold(cli.recognition_threshold)?;
59
60    let mut event_log = if let Some(path) = cli.event_log_path.as_deref() {
61        Some(JsonlEventWriter::create(path)?)
62    } else {
63        None
64    };
65    let mut tracker = Tracker::new(TrackerConfig {
66        match_iou_threshold: cli.track_iou_threshold,
67        max_missed_frames: cli.track_max_missed_frames,
68        max_tracks: cli.track_max_tracks,
69    })?;
70    let mut benchmark = cli.benchmark.then(BenchmarkCollector::default);
71    if cli.camera {
72        println!(
73            "yscv-cli demo: starting camera frame stream (raw rgb8 {} path)",
74            cli.detect_target.as_str()
75        );
76        run_camera_rgb8_pipeline(
77            cli,
78            &mut recognizer,
79            &mut tracker,
80            &mut event_log,
81            &mut benchmark,
82        )?;
83    } else {
84        let mode = "deterministic";
85        println!("yscv-cli demo: starting {mode} frame stream");
86        run_standard_pipeline(
87            cli,
88            mode,
89            &mut recognizer,
90            &mut tracker,
91            &mut event_log,
92            &mut benchmark,
93        )?;
94    }
95
96    finalize_benchmark(cli, benchmark)?;
97    flush_event_log(event_log)?;
98    println!("yscv-cli demo: stream completed");
99    Ok(())
100}
101
102fn run_standard_pipeline(
103    cli: &CliConfig,
104    mode: &str,
105    recognizer: &mut Recognizer,
106    tracker: &mut Tracker,
107    event_log: &mut Option<JsonlEventWriter>,
108    benchmark: &mut Option<BenchmarkCollector>,
109) -> Result<(), AppError> {
110    let source = build_source(cli, recognizer)?;
111    let mut stream = FrameStream::new(source);
112    if let Some(max_frames) = cli.max_frames {
113        stream = stream.with_max_frames(max_frames);
114    }
115    let mut process_context = FrameProcessContext {
116        cli,
117        mode,
118        tracker,
119        recognizer,
120        event_log,
121        benchmark,
122        tracked_detections: Vec::new(),
123        tracked_event_records: Vec::new(),
124    };
125    let mut people_scratch = FramePeopleDetectScratch::default();
126    let mut face_scratch = FrameFaceDetectScratch::default();
127
128    while let Some(frame) = stream.try_next()? {
129        let frame_start = Instant::now();
130        let frame_height = frame.image().shape()[0];
131        let frame_width = frame.image().shape()[1];
132        let detect_config = resolve_detect_config(cli, frame_width, frame_height);
133        let detect_start = Instant::now();
134        let detections = match cli.detect_target {
135            DetectTarget::People => detect_people_from_frame_with_scratch(
136                &frame,
137                detect_config.score_threshold,
138                detect_config.min_area,
139                detect_config.iou_threshold,
140                detect_config.max_detections,
141                &mut people_scratch,
142            )?,
143            DetectTarget::Faces => detect_faces_from_frame_with_scratch(
144                &frame,
145                detect_config.score_threshold,
146                detect_config.min_area,
147                detect_config.iou_threshold,
148                detect_config.max_detections,
149                &mut face_scratch,
150            )?,
151        };
152        let detect_duration = detect_start.elapsed();
153
154        process_frame(
155            &mut process_context,
156            FrameProcessInput {
157                frame: FrameMeta {
158                    index: frame.index(),
159                    timestamp_us: frame.timestamp_us(),
160                    width: frame_width,
161                    height: frame_height,
162                },
163                detections: &detections,
164                detect_duration,
165                frame_start,
166            },
167        )?;
168    }
169
170    Ok(())
171}
172
173fn run_camera_rgb8_pipeline(
174    cli: &CliConfig,
175    recognizer: &mut Recognizer,
176    tracker: &mut Tracker,
177    event_log: &mut Option<JsonlEventWriter>,
178    benchmark: &mut Option<BenchmarkCollector>,
179) -> Result<(), AppError> {
180    let mut source = open_camera_source(cli)?;
181    let max_frames = cli.max_frames.unwrap_or(usize::MAX);
182    let mut frames_seen = 0usize;
183    let mut people_scratch = yscv_detect::Rgb8PeopleDetectScratch::default();
184    let mut face_scratch = Rgb8FaceDetectScratch::default();
185    let mut process_context = FrameProcessContext {
186        cli,
187        mode: "camera",
188        tracker,
189        recognizer,
190        event_log,
191        benchmark,
192        tracked_detections: Vec::new(),
193        tracked_event_records: Vec::new(),
194    };
195
196    while frames_seen < max_frames {
197        let Some(frame) = source.next_rgb8_frame()? else {
198            break;
199        };
200        frames_seen += 1;
201
202        let frame_start = Instant::now();
203        let frame_width = frame.width();
204        let frame_height = frame.height();
205        let detect_config = resolve_detect_config(cli, frame_width, frame_height);
206        let detect_start = Instant::now();
207        let detections = match cli.detect_target {
208            DetectTarget::People => detect_people_from_rgb8_with_scratch(
209                (frame_width, frame_height),
210                frame.data(),
211                detect_config.score_threshold,
212                detect_config.min_area,
213                detect_config.iou_threshold,
214                detect_config.max_detections,
215                &mut people_scratch,
216            )?,
217            DetectTarget::Faces => detect_faces_from_rgb8_with_scratch(
218                (frame_width, frame_height),
219                frame.data(),
220                detect_config.score_threshold,
221                detect_config.min_area,
222                detect_config.iou_threshold,
223                detect_config.max_detections,
224                &mut face_scratch,
225            )?,
226        };
227        let detect_duration = detect_start.elapsed();
228
229        process_frame(
230            &mut process_context,
231            FrameProcessInput {
232                frame: FrameMeta {
233                    index: frame.index(),
234                    timestamp_us: frame.timestamp_us(),
235                    width: frame_width,
236                    height: frame_height,
237                },
238                detections: &detections,
239                detect_duration,
240                frame_start,
241            },
242        )?;
243    }
244
245    Ok(())
246}
247
248fn process_frame(
249    context: &mut FrameProcessContext<'_>,
250    input: FrameProcessInput<'_>,
251) -> Result<(), AppError> {
252    let FrameProcessInput {
253        frame,
254        detections,
255        detect_duration,
256        frame_start,
257    } = input;
258    let track_start = Instant::now();
259    context
260        .tracker
261        .update_into(detections, &mut context.tracked_detections);
262    let track_duration = track_start.elapsed();
263    let tracked = context.tracked_detections.as_slice();
264
265    let tracked_targets = context
266        .tracker
267        .count_by_class(context.cli.detect_target.class_id());
268    println!(
269        "frame={} ts_us={} detections={} tracked_{}={}",
270        frame.index,
271        frame.timestamp_us,
272        detections.len(),
273        context.cli.detect_target.count_label(),
274        tracked_targets,
275    );
276
277    let mut recognize_duration = Duration::ZERO;
278    let collect_event_records = context.event_log.is_some();
279    if collect_event_records {
280        context.tracked_event_records.clear();
281        if context.tracked_event_records.capacity() < tracked.len() {
282            context
283                .tracked_event_records
284                .reserve(tracked.len() - context.tracked_event_records.capacity());
285        }
286    }
287
288    for (idx, item) in tracked.iter().enumerate() {
289        let recognize_start = Instant::now();
290        let embedding =
291            bbox_embedding_components(item.detection.bbox, frame.width as f32, frame.height as f32);
292        let recognition = context.recognizer.recognize_slice(&embedding)?;
293        recognize_duration += recognize_start.elapsed();
294        let identity_label = recognition.identity.as_deref().unwrap_or("unknown");
295        println!(
296            "  det#{idx} track_id={} score={:.3} identity={} sim={:.3} bbox=({:.1},{:.1},{:.1},{:.1})",
297            item.track_id,
298            item.detection.score,
299            identity_label,
300            recognition.score,
301            item.detection.bbox.x1,
302            item.detection.bbox.y1,
303            item.detection.bbox.x2,
304            item.detection.bbox.y2,
305        );
306        if collect_event_records {
307            context.tracked_event_records.push(json!({
308                "det_index": idx,
309                "track_id": item.track_id,
310                "class_id": item.detection.class_id,
311                "score": item.detection.score,
312                "identity": recognition.identity.clone(),
313                "similarity": recognition.score,
314                "bbox": {
315                    "x1": item.detection.bbox.x1,
316                    "y1": item.detection.bbox.y1,
317                    "x2": item.detection.bbox.x2,
318                    "y2": item.detection.bbox.y2,
319                },
320            }));
321        }
322    }
323
324    let end_to_end_duration = frame_start.elapsed();
325    if let Some(writer) = context.event_log.as_mut() {
326        writer.write_record(&json!({
327            "frame_index": frame.index,
328            "timestamp_us": frame.timestamp_us,
329            "mode": context.mode,
330            "detect_target": context.cli.detect_target.as_str(),
331            "detection_count": detections.len(),
332            "tracked_target_count": tracked_targets,
333            "timings_ms": {
334                "detect": duration_to_ms(detect_duration),
335                "track": duration_to_ms(track_duration),
336                "recognize": duration_to_ms(recognize_duration),
337                "end_to_end": duration_to_ms(end_to_end_duration),
338            },
339            "tracked": &context.tracked_event_records,
340        }))?;
341    }
342    if let Some(collector) = context.benchmark.as_mut() {
343        collector.detect.push(detect_duration);
344        collector.track.push(track_duration);
345        collector.recognize.push(recognize_duration);
346        collector.end_to_end.push(end_to_end_duration);
347    }
348    Ok(())
349}
350
351fn bbox_embedding_components(bbox: BoundingBox, frame_width: f32, frame_height: f32) -> [f32; 3] {
352    let cx = ((bbox.x1 + bbox.x2) * 0.5) / frame_width;
353    let cy = ((bbox.y1 + bbox.y2) * 0.5) / frame_height;
354    let area = bbox.area() / (frame_width * frame_height);
355    [cx, cy, area]
356}
357
358fn finalize_benchmark(
359    cli: &CliConfig,
360    benchmark: Option<BenchmarkCollector>,
361) -> Result<(), AppError> {
362    if let Some(collector) = benchmark {
363        let report = summarize_pipeline_durations(PipelineDurations {
364            detect: &collector.detect,
365            track: &collector.track,
366            recognize: &collector.recognize,
367            end_to_end: &collector.end_to_end,
368        })?;
369        let text_report = format_benchmark_report(&report);
370        println!("{text_report}");
371
372        if let Some(path) = cli.benchmark_baseline_path.as_deref() {
373            let baseline_text = fs::read_to_string(path)?;
374            let thresholds = parse_pipeline_benchmark_thresholds(&baseline_text)?;
375            let violations = validate_pipeline_benchmark_thresholds(&report, &thresholds);
376            if violations.is_empty() {
377                println!("benchmark baseline check passed ({})", path.display());
378            } else {
379                println!(
380                    "benchmark baseline check failed ({}):\n{}",
381                    path.display(),
382                    format_benchmark_violations(&violations)
383                );
384                return Err(CliError::Message("benchmark regression detected".to_string()).into());
385            }
386        }
387        if let Some(path) = cli.benchmark_report_path.as_deref() {
388            ensure_parent_dir(path)?;
389            fs::write(path, &text_report)?;
390            println!("benchmark report saved to {}", path.display());
391        }
392    }
393    Ok(())
394}
395
396fn flush_event_log(event_log: Option<JsonlEventWriter>) -> Result<(), AppError> {
397    if let Some(mut writer) = event_log {
398        writer.flush()?;
399        println!("event log saved to {}", writer.path().display());
400    }
401    Ok(())
402}