Skip to main content

moodbar_decode/
lib.rs

1use std::fs::File;
2use std::io::Cursor;
3use std::path::Path;
4
5use moodbar_analysis::{
6    analysis_to_raw_rgb_bytes, analyze_pcm_mono, GenerateOptions, MoodbarAnalysis,
7};
8use symphonia::core::audio::SampleBuffer;
9use symphonia::core::codecs::DecoderOptions;
10use symphonia::core::errors::Error as SymphoniaError;
11use symphonia::core::formats::FormatOptions;
12use symphonia::core::io::MediaSourceStream;
13use symphonia::core::meta::MetadataOptions;
14use symphonia::core::probe::Hint;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18pub enum MoodbarDecodeError {
19    #[error("no playable audio track found")]
20    NoAudioTrack,
21    #[error("decoded stream has no samples")]
22    EmptyAudio,
23    #[error("I/O error: {0}")]
24    Io(#[from] std::io::Error),
25    #[error("decode error: {0}")]
26    Decode(#[from] SymphoniaError),
27    #[error("invalid options: {0}")]
28    InvalidOptions(String),
29}
30
/// Counters for recoverable problems met while decoding a stream.
///
/// Each counted item is skipped rather than aborting the decode; the totals are copied
/// into the analysis diagnostics once decoding finishes.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DecodeDiagnostics {
    /// Packets the decoder rejected with a recoverable decode error.
    pub decode_errors: usize,
    /// Decoded packets whose signal spec reported zero channels.
    pub zero_channel_packets: usize,
    /// Trailing partial frames (fewer samples than the channel count) that were dropped.
    pub truncated_frames: usize,
}
37
38pub fn analyze_path(
39    path: &Path,
40    options: &GenerateOptions,
41) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
42    validate_options(options)?;
43
44    let file = File::open(path)?;
45    let mss = MediaSourceStream::new(Box::new(file), Default::default());
46
47    let mut hint = Hint::new();
48    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
49        hint.with_extension(ext);
50    }
51
52    analyze_media_source(mss, hint, options)
53}
54
55pub fn analyze_bytes(
56    bytes: &[u8],
57    extension: Option<&str>,
58    options: &GenerateOptions,
59) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
60    validate_options(options)?;
61
62    let cursor = Cursor::new(bytes.to_vec());
63    let mss = MediaSourceStream::new(Box::new(cursor), Default::default());
64    let mut hint = Hint::new();
65    if let Some(ext) = extension {
66        if !ext.is_empty() {
67            hint.with_extension(ext);
68        }
69    }
70
71    analyze_media_source(mss, hint, options)
72}
73
74pub fn generate_moodbar_from_path(
75    path: &Path,
76    options: &GenerateOptions,
77) -> Result<Vec<u8>, MoodbarDecodeError> {
78    let analysis = analyze_path(path, options)?;
79    Ok(analysis_to_raw_rgb_bytes(&analysis))
80}
81
82pub fn generate_moodbar_from_bytes(
83    bytes: &[u8],
84    extension: Option<&str>,
85    options: &GenerateOptions,
86) -> Result<Vec<u8>, MoodbarDecodeError> {
87    let analysis = analyze_bytes(bytes, extension, options)?;
88    Ok(analysis_to_raw_rgb_bytes(&analysis))
89}
90
91fn analyze_media_source(
92    mss: MediaSourceStream,
93    hint: Hint,
94    options: &GenerateOptions,
95) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
96    let probed = symphonia::default::get_probe().format(
97        &hint,
98        mss,
99        &FormatOptions::default(),
100        &MetadataOptions::default(),
101    )?;
102
103    let mut format = probed.format;
104    let track = format
105        .tracks()
106        .iter()
107        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
108        .ok_or(MoodbarDecodeError::NoAudioTrack)?;
109
110    let sample_rate = track
111        .codec_params
112        .sample_rate
113        .ok_or(MoodbarDecodeError::NoAudioTrack)?;
114    let track_id = track.id;
115
116    let mut decoder =
117        symphonia::default::get_codecs().make(&track.codec_params, &DecoderOptions::default())?;
118    let mut samples = Vec::<f32>::new();
119    let mut saw_samples = false;
120    let mut diagnostics = DecodeDiagnostics::default();
121    let mut sample_buf: Option<SampleBuffer<f32>> = None;
122
123    loop {
124        let packet = match format.next_packet() {
125            Ok(packet) => packet,
126            Err(SymphoniaError::IoError(err))
127                if err.kind() == std::io::ErrorKind::UnexpectedEof =>
128            {
129                break;
130            }
131            Err(err) => return Err(err.into()),
132        };
133
134        if packet.track_id() != track_id {
135            continue;
136        }
137
138        match decoder.decode(&packet) {
139            Ok(decoded) => {
140                let spec = *decoded.spec();
141                let channels = spec.channels.count();
142
143                if sample_buf.is_none()
144                    || sample_buf.as_ref().unwrap().capacity() < decoded.capacity()
145                {
146                    sample_buf = Some(SampleBuffer::<f32>::new(decoded.capacity() as u64, spec));
147                }
148
149                let buf = sample_buf.as_mut().unwrap();
150                buf.copy_interleaved_ref(decoded);
151                let interleaved = buf.samples();
152
153                if channels == 0 {
154                    diagnostics.zero_channel_packets += 1;
155                    continue;
156                }
157
158                let max_channels = channels.min(2);
159                for frame in interleaved.chunks(channels) {
160                    if frame.len() != channels {
161                        diagnostics.truncated_frames += 1;
162                        continue;
163                    }
164                    let sum = frame[..max_channels].iter().copied().sum::<f32>();
165                    samples.push(sum / max_channels as f32);
166                    saw_samples = true;
167                }
168            }
169            Err(SymphoniaError::DecodeError(_)) => {
170                diagnostics.decode_errors += 1;
171                continue;
172            }
173            Err(SymphoniaError::IoError(err))
174                if err.kind() == std::io::ErrorKind::UnexpectedEof =>
175            {
176                break;
177            }
178            Err(err) => return Err(err.into()),
179        }
180    }
181
182    if !saw_samples {
183        return Err(MoodbarDecodeError::EmptyAudio);
184    }
185
186    let mut analysis = analyze_pcm_mono(sample_rate, &samples, options);
187    analysis.diagnostics.decode_errors = diagnostics.decode_errors;
188    analysis.diagnostics.zero_channel_packets = diagnostics.zero_channel_packets;
189    analysis.diagnostics.truncated_frames = diagnostics.truncated_frames;
190    Ok(analysis)
191}
192
193fn validate_options(options: &GenerateOptions) -> Result<(), MoodbarDecodeError> {
194    if !options.fft_size.is_power_of_two() || options.fft_size < 64 {
195        return Err(MoodbarDecodeError::InvalidOptions(
196            "fft_size must be a power of two and >= 64".to_string(),
197        ));
198    }
199    if !(options.deterministic_floor.is_finite() && options.deterministic_floor > 0.0) {
200        return Err(MoodbarDecodeError::InvalidOptions(
201            "deterministic_floor must be finite and > 0".to_string(),
202        ));
203    }
204    if options.frames_per_color == 0 {
205        return Err(MoodbarDecodeError::InvalidOptions(
206            "frames_per_color must be >= 1".to_string(),
207        ));
208    }
209
210    let edges = if options.band_edges_hz.is_empty() {
211        vec![options.low_cut_hz, options.mid_cut_hz]
212    } else {
213        options.band_edges_hz.clone()
214    };
215
216    if edges.is_empty() {
217        return Err(MoodbarDecodeError::InvalidOptions(
218            "at least one band edge is required".to_string(),
219        ));
220    }
221    for pair in edges.windows(2) {
222        if pair[0] >= pair[1] {
223            return Err(MoodbarDecodeError::InvalidOptions(
224                "band edges must be strictly increasing".to_string(),
225            ));
226        }
227    }
228    Ok(())
229}