Skip to main content

moodbar_decode/
lib.rs

1use std::fs::File;
2use std::io::Cursor;
3use std::path::Path;
4
5use moodbar_analysis::{
6    analysis_to_raw_rgb_bytes, analyze_pcm_mono, GenerateOptions, MoodbarAnalysis,
7};
8use symphonia::core::audio::SampleBuffer;
9use symphonia::core::codecs::DecoderOptions;
10use symphonia::core::errors::Error as SymphoniaError;
11use symphonia::core::formats::FormatOptions;
12use symphonia::core::io::MediaSourceStream;
13use symphonia::core::meta::MetadataOptions;
14use symphonia::core::probe::Hint;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18pub enum MoodbarDecodeError {
19    #[error("no playable audio track found")]
20    NoAudioTrack,
21    #[error("decoded stream has no samples")]
22    EmptyAudio,
23    #[error("I/O error: {0}")]
24    Io(#[from] std::io::Error),
25    #[error("decode error: {0}")]
26    Decode(#[from] SymphoniaError),
27    #[error("invalid options: {0}")]
28    InvalidOptions(String),
29}
30
/// Counters for recoverable problems encountered while decoding a stream.
///
/// All counters start at zero (`Default`).
#[derive(Clone, Debug, Default)]
pub struct DecodeDiagnostics {
    /// Packets the decoder rejected with a recoverable decode error; they were skipped.
    pub decode_errors: usize,
    /// Decoded packets whose signal spec reported zero channels; they were skipped.
    pub zero_channel_packets: usize,
    /// Partial frames (fewer samples than channels) that were dropped during downmixing.
    pub truncated_frames: usize,
}
37
38pub fn analyze_path(
39    path: &Path,
40    options: &GenerateOptions,
41) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
42    validate_options(options)?;
43
44    let file = File::open(path)?;
45    let mss = MediaSourceStream::new(Box::new(file), Default::default());
46
47    let mut hint = Hint::new();
48    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
49        hint.with_extension(ext);
50    }
51
52    analyze_media_source(mss, hint, options)
53}
54
55pub fn analyze_bytes(
56    bytes: &[u8],
57    extension: Option<&str>,
58    options: &GenerateOptions,
59) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
60    validate_options(options)?;
61
62    let cursor = Cursor::new(bytes.to_vec());
63    let mss = MediaSourceStream::new(Box::new(cursor), Default::default());
64    let mut hint = Hint::new();
65    if let Some(ext) = extension {
66        if !ext.is_empty() {
67            hint.with_extension(ext);
68        }
69    }
70
71    analyze_media_source(mss, hint, options)
72}
73
74pub fn generate_moodbar_from_path(
75    path: &Path,
76    options: &GenerateOptions,
77) -> Result<Vec<u8>, MoodbarDecodeError> {
78    let analysis = analyze_path(path, options)?;
79    Ok(analysis_to_raw_rgb_bytes(&analysis))
80}
81
82pub fn generate_moodbar_from_bytes(
83    bytes: &[u8],
84    extension: Option<&str>,
85    options: &GenerateOptions,
86) -> Result<Vec<u8>, MoodbarDecodeError> {
87    let analysis = analyze_bytes(bytes, extension, options)?;
88    Ok(analysis_to_raw_rgb_bytes(&analysis))
89}
90
91fn analyze_media_source(
92    mss: MediaSourceStream,
93    hint: Hint,
94    options: &GenerateOptions,
95) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
96    let probed = symphonia::default::get_probe().format(
97        &hint,
98        mss,
99        &FormatOptions::default(),
100        &MetadataOptions::default(),
101    )?;
102
103    let mut format = probed.format;
104    let track = format
105        .tracks()
106        .iter()
107        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
108        .ok_or(MoodbarDecodeError::NoAudioTrack)?;
109
110    let sample_rate = track
111        .codec_params
112        .sample_rate
113        .ok_or(MoodbarDecodeError::NoAudioTrack)?;
114    let track_id = track.id;
115
116    let mut decoder =
117        symphonia::default::get_codecs().make(&track.codec_params, &DecoderOptions::default())?;
118    let mut samples = Vec::<f32>::new();
119    let mut saw_samples = false;
120    let mut diagnostics = DecodeDiagnostics::default();
121
122    loop {
123        let packet = match format.next_packet() {
124            Ok(packet) => packet,
125            Err(SymphoniaError::IoError(err))
126                if err.kind() == std::io::ErrorKind::UnexpectedEof =>
127            {
128                break;
129            }
130            Err(err) => return Err(err.into()),
131        };
132
133        if packet.track_id() != track_id {
134            continue;
135        }
136
137        match decoder.decode(&packet) {
138            Ok(decoded) => {
139                let spec = *decoded.spec();
140                let channels = spec.channels.count();
141                let mut sample_buf = SampleBuffer::<f32>::new(decoded.capacity() as u64, spec);
142                sample_buf.copy_interleaved_ref(decoded);
143                let interleaved = sample_buf.samples();
144
145                if channels == 0 {
146                    diagnostics.zero_channel_packets += 1;
147                    continue;
148                }
149
150                for frame in interleaved.chunks(channels) {
151                    if frame.len() != channels {
152                        diagnostics.truncated_frames += 1;
153                        continue;
154                    }
155                    let sum = frame.iter().copied().sum::<f32>();
156                    samples.push(sum / channels as f32);
157                    saw_samples = true;
158                }
159            }
160            Err(SymphoniaError::DecodeError(_)) => {
161                diagnostics.decode_errors += 1;
162                continue;
163            }
164            Err(SymphoniaError::IoError(err))
165                if err.kind() == std::io::ErrorKind::UnexpectedEof =>
166            {
167                break;
168            }
169            Err(err) => return Err(err.into()),
170        }
171    }
172
173    if !saw_samples {
174        return Err(MoodbarDecodeError::EmptyAudio);
175    }
176
177    let mut analysis = analyze_pcm_mono(sample_rate, &samples, options);
178    analysis.diagnostics.decode_errors = diagnostics.decode_errors;
179    analysis.diagnostics.zero_channel_packets = diagnostics.zero_channel_packets;
180    analysis.diagnostics.truncated_frames = diagnostics.truncated_frames;
181    Ok(analysis)
182}
183
184fn validate_options(options: &GenerateOptions) -> Result<(), MoodbarDecodeError> {
185    if !options.fft_size.is_power_of_two() || options.fft_size < 64 {
186        return Err(MoodbarDecodeError::InvalidOptions(
187            "fft_size must be a power of two and >= 64".to_string(),
188        ));
189    }
190    if !(options.deterministic_floor.is_finite() && options.deterministic_floor > 0.0) {
191        return Err(MoodbarDecodeError::InvalidOptions(
192            "deterministic_floor must be finite and > 0".to_string(),
193        ));
194    }
195    if options.frames_per_color == 0 {
196        return Err(MoodbarDecodeError::InvalidOptions(
197            "frames_per_color must be >= 1".to_string(),
198        ));
199    }
200
201    let edges = if options.band_edges_hz.is_empty() {
202        vec![options.low_cut_hz, options.mid_cut_hz]
203    } else {
204        options.band_edges_hz.clone()
205    };
206
207    if edges.is_empty() {
208        return Err(MoodbarDecodeError::InvalidOptions(
209            "at least one band edge is required".to_string(),
210        ));
211    }
212    for pair in edges.windows(2) {
213        if pair[0] >= pair[1] {
214            return Err(MoodbarDecodeError::InvalidOptions(
215                "band edges must be strictly increasing".to_string(),
216            ));
217        }
218    }
219    Ok(())
220}