use crate::clip::Clip;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Minimum confidence a suggestion needs to survive filtering when the tagger
/// is built from `AiTaggerConfig::default()`.
pub const DEFAULT_MIN_CONFIDENCE: f32 = 0.5_f32;
/// One keyword proposed for a clip, with a confidence score and a
/// human-readable explanation of where it came from.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct KeywordSuggestion {
    /// The proposed keyword, e.g. `"interview"` or `"b-roll"`.
    pub keyword: String,
    /// Confidence score; `KeywordSuggestion::new` clamps it into `[0.0, 1.0]`.
    pub confidence: f32,
    /// Short description of the signal that produced this suggestion.
    pub reason: String,
}
impl KeywordSuggestion {
    /// Creates a suggestion, clamping `confidence` into `[0.0, 1.0]`.
    ///
    /// A NaN `confidence` is stored as `0.0`. `f32::clamp` passes NaN through
    /// unchanged, and a NaN score can neither be compared against thresholds nor
    /// ordered against other suggestions in a meaningful way.
    #[must_use]
    pub fn new(keyword: impl Into<String>, confidence: f32, reason: impl Into<String>) -> Self {
        let confidence = if confidence.is_nan() {
            0.0
        } else {
            confidence.clamp(0.0, 1.0)
        };
        Self {
            keyword: keyword.into(),
            confidence,
            reason: reason.into(),
        }
    }
}
/// Output of tagging a single clip: the clip's identifier plus the keyword
/// suggestions produced for it.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct AiTagResult {
    /// Identifier of the clip these suggestions belong to.
    pub clip_id: crate::clip::ClipId,
    /// Suggested keywords; `ClipAiTagger::tag_clip` returns them ordered by
    /// descending confidence.
    pub suggestions: Vec<KeywordSuggestion>,
}
impl AiTagResult {
    /// Returns the suggestions whose confidence is at least `min_confidence`,
    /// in their stored order.
    ///
    /// A suggestion with a NaN confidence never passes the filter.
    #[must_use]
    pub fn filtered(&self, min_confidence: f32) -> Vec<&KeywordSuggestion> {
        self.suggestions
            .iter()
            .filter(|s| s.confidence >= min_confidence)
            .collect()
    }

    /// Returns up to `n` suggestions ordered by descending confidence.
    ///
    /// Suggestions with equal confidence keep their stored relative order.
    #[must_use]
    pub fn top_n(&self, n: usize) -> Vec<&KeywordSuggestion> {
        let mut sorted: Vec<&KeywordSuggestion> = self.suggestions.iter().collect();
        // `total_cmp` is a true total order. `partial_cmp(..).unwrap_or(Equal)` is
        // not once a NaN is present, and `sort_by` may panic on such comparators.
        // The sort is stable, so ties keep their input order.
        sorted.sort_by(|a, b| b.confidence.total_cmp(&a.confidence));
        sorted.truncate(n);
        sorted
    }
}
/// Tunable parameters for `ClipAiTagger`.
#[derive(Clone, Debug)]
pub struct AiTaggerConfig {
    /// Suggestions whose final confidence falls below this value are dropped.
    pub min_confidence: f32,
    /// Maximum number of suggestions returned per clip.
    pub max_suggestions: usize,
    /// Optional per-keyword multipliers applied to raw confidences before
    /// duplicate keywords are merged; the product is clamped to `[0.0, 1.0]`.
    pub keyword_weights: HashMap<String, f32>,
}
impl Default for AiTaggerConfig {
    /// Uses `DEFAULT_MIN_CONFIDENCE` as the threshold, allows at most ten
    /// suggestions per clip and applies no keyword weighting.
    fn default() -> Self {
        let keyword_weights = HashMap::new();
        Self {
            keyword_weights,
            max_suggestions: 10,
            min_confidence: DEFAULT_MIN_CONFIDENCE,
        }
    }
}
/// Rule-based keyword tagger for clips.
///
/// Suggestions are derived from a clip's file name, duration, camera metadata
/// and existing keywords; see `ClipAiTagger::tag_clip`.
#[derive(Clone, Debug)]
pub struct ClipAiTagger {
    /// Threshold, result limit and keyword weights applied while tagging.
    config: AiTaggerConfig,
}
impl Default for ClipAiTagger {
    /// Builds a tagger from `AiTaggerConfig::default()`.
    fn default() -> Self {
        let config = AiTaggerConfig::default();
        Self::new(config)
    }
}
impl ClipAiTagger {
    /// Creates a tagger driven by `config`.
    #[must_use]
    pub fn new(config: AiTaggerConfig) -> Self {
        Self { config }
    }

    /// Suggests keywords for a single clip.
    ///
    /// 1. Every analyser (file name, duration, camera metadata, existing
    ///    keywords) contributes raw suggestions.
    /// 2. Each raw confidence is multiplied by the matching entry in
    ///    `keyword_weights`, if there is one, and clamped to `[0.0, 1.0]`.
    /// 3. Duplicate keywords are merged, keeping their highest confidence.
    /// 4. Keywords below `min_confidence` are dropped.
    /// 5. The rest are sorted by descending confidence, with ties broken
    ///    alphabetically by keyword so the result is deterministic.
    /// 6. The list is truncated to `max_suggestions`.
    ///
    /// Every returned suggestion carries the reason `"combined-signal"`.
    #[must_use]
    pub fn tag_clip(&self, clip: &Clip) -> AiTagResult {
        let mut raw: Vec<KeywordSuggestion> = Vec::new();
        self.analyse_file_name(clip, &mut raw);
        self.analyse_duration(clip, &mut raw);
        self.analyse_camera_metadata(clip, &mut raw);
        self.analyse_existing_keywords(clip, &mut raw);

        // Weight each raw signal, then keep the strongest one per keyword.
        // Entries start at 0.0, so a NaN confidence (e.g. from a NaN weight)
        // never replaces a real value.
        let mut best: HashMap<String, f32> = HashMap::new();
        for s in raw {
            let confidence = match self.config.keyword_weights.get(&s.keyword) {
                Some(&w) => (s.confidence * w).clamp(0.0, 1.0),
                None => s.confidence,
            };
            let entry = best.entry(s.keyword).or_insert(0.0_f32);
            if confidence > *entry {
                *entry = confidence;
            }
        }

        let mut suggestions: Vec<KeywordSuggestion> = best
            .into_iter()
            .filter(|(_, conf)| *conf >= self.config.min_confidence)
            .map(|(keyword, confidence)| KeywordSuggestion {
                keyword,
                confidence,
                reason: "combined-signal".to_string(),
            })
            .collect();
        // HashMap iteration order is randomised, so confidence ties must be broken
        // explicitly. Otherwise `truncate` could keep a different subset on every
        // run. Keywords are unique here, so this is a strict total order.
        suggestions.sort_unstable_by(|a, b| {
            b.confidence
                .total_cmp(&a.confidence)
                .then_with(|| a.keyword.cmp(&b.keyword))
        });
        suggestions.truncate(self.config.max_suggestions);
        AiTagResult {
            clip_id: clip.id,
            suggestions,
        }
    }

    /// Tags each clip independently; results are returned in the same order as `clips`.
    #[must_use]
    pub fn tag_clips(&self, clips: &[Clip]) -> Vec<AiTagResult> {
        clips.iter().map(|c| self.tag_clip(c)).collect()
    }

    /// Suggests keywords from tokens found in the clip's lower-cased file stem.
    ///
    /// Most tokens match anywhere in the stem. The short abbreviations `int`,
    /// `ext` and `day` must appear as whole words, meaning they are bounded by
    /// non-letters or the ends of the stem. Without this, `interview_001` would
    /// be tagged `interior`, `extreme_closeup` would be tagged `exterior`, and
    /// `today` would be tagged `day`. Non-UTF-8 stems are decoded lossily rather
    /// than being ignored.
    fn analyse_file_name(&self, clip: &Clip, out: &mut Vec<KeywordSuggestion>) {
        let stem = clip
            .file_path
            .file_stem()
            .map(|s| s.to_string_lossy().to_lowercase())
            .unwrap_or_default();
        // (token, keyword, confidence, whole_word)
        let token_map: &[(&str, &str, f32, bool)] = &[
            ("int", "interior", 0.75, true),
            ("interior", "interior", 0.85, false),
            ("ext", "exterior", 0.75, true),
            ("exterior", "exterior", 0.85, false),
            ("interview", "interview", 0.90, false),
            ("broll", "b-roll", 0.85, false),
            ("b_roll", "b-roll", 0.85, false),
            ("vox", "vox-pop", 0.80, false),
            ("aerial", "aerial", 0.85, false),
            ("drone", "aerial", 0.85, false),
            ("timelapse", "time-lapse", 0.90, false),
            ("slowmo", "slow-motion", 0.85, false),
            ("slo_mo", "slow-motion", 0.85, false),
            ("night", "night", 0.80, false),
            ("day", "day", 0.60, true),
            ("wide", "wide-shot", 0.75, false),
            ("closeup", "close-up", 0.80, false),
            ("close_up", "close-up", 0.80, false),
            ("cutaway", "cutaway", 0.85, false),
        ];
        for &(token, keyword, conf, whole_word) in token_map {
            let found = if whole_word {
                Self::contains_word(&stem, token)
            } else {
                stem.contains(token)
            };
            if found {
                out.push(KeywordSuggestion::new(
                    keyword,
                    conf,
                    format!("file-name token '{token}'"),
                ));
            }
        }
    }

    /// Returns `true` if `needle` occurs in `haystack` with no alphabetic
    /// character directly before or after the match. Separators, digits and the
    /// ends of the string all count as word boundaries.
    fn contains_word(haystack: &str, needle: &str) -> bool {
        haystack.match_indices(needle).any(|(start, matched)| {
            let before = haystack[..start].chars().next_back();
            let after = haystack[start + matched.len()..].chars().next();
            !matches!(before, Some(c) if c.is_alphabetic())
                && !matches!(after, Some(c) if c.is_alphabetic())
        })
    }

    /// Tags very short (< 5 s) or long (> 120 s) clips.
    ///
    /// `effective_duration()` is divided by the frame rate, so it is treated as a
    /// frame count. The clip's frame rate is used when it is a positive finite
    /// number; otherwise 24 fps is assumed. Without that guard, a zero rate would
    /// produce an infinite duration and tag every non-empty clip `clip-long`.
    fn analyse_duration(&self, clip: &Clip, out: &mut Vec<KeywordSuggestion>) {
        if let Some(dur) = clip.effective_duration() {
            let fps = clip
                .frame_rate
                .map(|fr| fr.to_f64())
                .filter(|rate| rate.is_finite() && *rate > 0.0)
                .unwrap_or(24.0);
            let seconds = dur as f64 / fps;
            if seconds < 5.0 {
                out.push(KeywordSuggestion::new("clip-short", 0.70, "duration < 5 s"));
            } else if seconds > 120.0 {
                out.push(KeywordSuggestion::new(
                    "clip-long",
                    0.65,
                    "duration > 2 min",
                ));
            }
        }
    }

    /// Suggests `low-light` for clips shot at ISO 3200 or above, and
    /// `slow-motion` for clips with a frame rate above 60 fps.
    fn analyse_camera_metadata(&self, clip: &Clip, out: &mut Vec<KeywordSuggestion>) {
        if let Some(cam) = &clip.camera {
            if let Some(iso) = cam.iso {
                if iso >= 3200 {
                    out.push(KeywordSuggestion::new(
                        "low-light",
                        0.80,
                        format!("ISO {iso}"),
                    ));
                }
            }
        }
        if let Some(fr) = clip.frame_rate {
            let fps = fr.to_f64();
            if fps > 60.0 {
                out.push(KeywordSuggestion::new(
                    "slow-motion",
                    0.85,
                    format!("{fps:.0} fps"),
                ));
            }
        }
    }

    /// Suggests related keywords for keywords the clip already carries.
    /// For example, `interview` implies `talking-head` and `dialogue`.
    /// Matching is exact and case-sensitive.
    fn analyse_existing_keywords(&self, clip: &Clip, out: &mut Vec<KeywordSuggestion>) {
        // (existing keyword, suggested keyword, confidence)
        let co_occur: &[(&str, &str, f32)] = &[
            ("interview", "talking-head", 0.70),
            ("interview", "dialogue", 0.65),
            ("b-roll", "cutaway", 0.65),
            ("aerial", "establishing-shot", 0.70),
            ("slow-motion", "action", 0.60),
            ("exterior", "establishing-shot", 0.55),
            ("low-light", "cinematic", 0.60),
        ];
        for (trigger, suggestion, conf) in co_occur {
            if clip.keywords.iter().any(|k| k == trigger) {
                out.push(KeywordSuggestion::new(
                    *suggestion,
                    *conf,
                    format!("co-occurrence with '{trigger}'"),
                ));
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::camera_metadata::CameraMetadata;
    use std::path::PathBuf;

    /// Builds a clip whose file stem is `name`.
    fn make_clip(name: &str) -> Clip {
        Clip::new(PathBuf::from(format!("/media/{name}.mov")))
    }

    /// Returns `true` if `result` contains a suggestion for `keyword`.
    fn has_keyword(result: &AiTagResult, keyword: &str) -> bool {
        result.suggestions.iter().any(|s| s.keyword == keyword)
    }

    /// Builds a result holding exactly `suggestions`, in the given order.
    fn result_with(suggestions: Vec<KeywordSuggestion>) -> AiTagResult {
        AiTagResult {
            clip_id: make_clip("fixture").id,
            suggestions,
        }
    }

    /// Three suggestions deliberately stored out of confidence order.
    fn unordered_fixture() -> AiTagResult {
        result_with(vec![
            KeywordSuggestion::new("low", 0.2, "r"),
            KeywordSuggestion::new("high", 0.9, "r"),
            KeywordSuggestion::new("mid", 0.5, "r"),
        ])
    }

    #[test]
    fn test_keyword_suggestion_new_clamps_confidence() {
        assert_eq!(KeywordSuggestion::new("k", 1.5, "r").confidence, 1.0);
        assert_eq!(KeywordSuggestion::new("k", -0.25, "r").confidence, 0.0);
        assert_eq!(KeywordSuggestion::new("k", 0.4, "r").confidence, 0.4);
    }

    #[test]
    fn test_tag_clip_empty_no_suggestions_below_threshold() {
        let tagger = ClipAiTagger::default();
        let clip = make_clip("generic_clip");
        let result = tagger.tag_clip(&clip);
        assert!(result.clip_id == clip.id);
        assert!(
            result
                .suggestions
                .iter()
                .all(|s| s.confidence >= DEFAULT_MIN_CONFIDENCE),
            "every suggestion must meet the default threshold"
        );
    }

    #[test]
    fn test_tag_clip_file_name_interview() {
        let tagger = ClipAiTagger::default();
        let result = tagger.tag_clip(&make_clip("interview_001"));
        assert!(has_keyword(&result, "interview"), "Expected 'interview' suggestion");
    }

    #[test]
    fn test_tag_clip_file_name_broll() {
        let tagger = ClipAiTagger::default();
        let result = tagger.tag_clip(&make_clip("broll_city"));
        assert!(has_keyword(&result, "b-roll"), "Expected 'b-roll' suggestion");
    }

    #[test]
    fn test_tag_clip_duration_short() {
        let tagger = ClipAiTagger::default();
        let mut clip = make_clip("stinger");
        clip.set_duration(60);
        let result = tagger.tag_clip(&clip);
        assert!(has_keyword(&result, "clip-short"), "Expected 'clip-short' suggestion");
    }

    #[test]
    fn test_tag_clip_camera_low_light() {
        let tagger = ClipAiTagger::default();
        let mut clip = make_clip("night_scene");
        let mut cam = CameraMetadata::default();
        cam.iso = Some(6400);
        clip.set_camera_metadata(cam);
        let result = tagger.tag_clip(&clip);
        assert!(
            has_keyword(&result, "low-light"),
            "Expected 'low-light' suggestion from ISO"
        );
    }

    #[test]
    fn test_tag_clip_camera_slow_motion() {
        use oximedia_core::types::Rational;
        let tagger = ClipAiTagger::default();
        let mut clip = make_clip("sports");
        clip.set_frame_rate(Rational::new(120, 1));
        let result = tagger.tag_clip(&clip);
        assert!(
            has_keyword(&result, "slow-motion"),
            "Expected 'slow-motion' from high fps"
        );
    }

    #[test]
    fn test_tag_clip_co_occurrence() {
        let tagger = ClipAiTagger::default();
        let mut clip = make_clip("clip");
        clip.add_keyword("interview");
        let result = tagger.tag_clip(&clip);
        assert!(
            has_keyword(&result, "talking-head"),
            "Expected 'talking-head' co-occurrence suggestion"
        );
    }

    #[test]
    fn test_top_n_returns_sorted_descending() {
        let tagger = ClipAiTagger::default();
        let result = tagger.tag_clip(&make_clip("broll_interview_aerial"));
        let top = result.top_n(3);
        assert_eq!(top.len(), 3);
        assert!(top.windows(2).all(|w| w[0].confidence >= w[1].confidence));
    }

    #[test]
    fn test_top_n_orders_and_truncates() {
        let result = unordered_fixture();
        let top = result.top_n(2);
        let keywords: Vec<&str> = top.iter().map(|s| s.keyword.as_str()).collect();
        assert_eq!(keywords, ["high", "mid"]);
        assert!(result.top_n(0).is_empty());
        assert_eq!(result.top_n(10).len(), 3);
    }

    #[test]
    fn test_filtered_keeps_stored_order() {
        let result = unordered_fixture();
        let kept = result.filtered(0.5);
        let keywords: Vec<&str> = kept.iter().map(|s| s.keyword.as_str()).collect();
        assert_eq!(keywords, ["high", "mid"]);
    }

    #[test]
    fn test_filtered_min_confidence() {
        let tagger = ClipAiTagger::default();
        let result = tagger.tag_clip(&make_clip("interview_clip"));
        let high = result.filtered(0.9);
        assert!(high.iter().all(|s| s.confidence >= 0.9));
        assert!(high.iter().any(|s| s.keyword == "interview"));
    }

    #[test]
    fn test_tag_clips_batch() {
        let tagger = ClipAiTagger::default();
        let clips = vec![make_clip("interview_01"), make_clip("broll_outdoor")];
        let results = tagger.tag_clips(&clips);
        assert_eq!(results.len(), 2);
        for (result, clip) in results.iter().zip(&clips) {
            assert!(result.clip_id == clip.id);
        }
        assert!(has_keyword(&results[0], "interview"));
        assert!(has_keyword(&results[1], "b-roll"));
    }

    #[test]
    fn test_max_suggestions_respected() {
        let config = AiTaggerConfig {
            max_suggestions: 2,
            min_confidence: 0.0,
            ..AiTaggerConfig::default()
        };
        let tagger = ClipAiTagger::new(config);
        let clip = make_clip("broll_interview_aerial_night_drone");
        let result = tagger.tag_clip(&clip);
        assert_eq!(result.suggestions.len(), 2);
    }

    #[test]
    fn test_keyword_weights_scale_confidence() {
        let mut config = AiTaggerConfig::default();
        config.keyword_weights.insert("interview".to_string(), 0.0);
        let tagger = ClipAiTagger::new(config);
        let result = tagger.tag_clip(&make_clip("interview_001"));
        assert!(!has_keyword(&result, "interview"));
    }

    #[test]
    fn test_duplicate_keywords_are_merged() {
        let tagger = ClipAiTagger::default();
        let result = tagger.tag_clip(&make_clip("aerial_drone"));
        let count = result
            .suggestions
            .iter()
            .filter(|s| s.keyword == "aerial")
            .count();
        assert_eq!(count, 1);
    }
}