1use std::fs::File;
2use std::io::Cursor;
3use std::path::Path;
4
5use moodbar_analysis::{
6 analysis_to_raw_rgb_bytes, analyze_pcm_mono, GenerateOptions, MoodbarAnalysis,
7};
8use symphonia::core::audio::SampleBuffer;
9use symphonia::core::codecs::DecoderOptions;
10use symphonia::core::errors::Error as SymphoniaError;
11use symphonia::core::formats::FormatOptions;
12use symphonia::core::io::MediaSourceStream;
13use symphonia::core::meta::MetadataOptions;
14use symphonia::core::probe::Hint;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18pub enum MoodbarDecodeError {
19 #[error("no playable audio track found")]
20 NoAudioTrack,
21 #[error("decoded stream has no samples")]
22 EmptyAudio,
23 #[error("I/O error: {0}")]
24 Io(#[from] std::io::Error),
25 #[error("decode error: {0}")]
26 Decode(#[from] SymphoniaError),
27 #[error("invalid options: {0}")]
28 InvalidOptions(String),
29}
30
/// Counters for recoverable anomalies encountered while decoding a stream.
///
/// All counters start at zero (see the derived [`Default`]).
#[derive(Clone, Debug, Default)]
pub struct DecodeDiagnostics {
    /// Number of packets skipped because the decoder reported a decode error for them.
    pub decode_errors: usize,
    /// Number of decoded packets skipped because their signal spec declared zero channels.
    pub zero_channel_packets: usize,
    /// Number of interleaved frames dropped because they held fewer samples than the
    /// channel count.
    pub truncated_frames: usize,
}
37
38pub fn analyze_path(
39 path: &Path,
40 options: &GenerateOptions,
41) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
42 validate_options(options)?;
43
44 let file = File::open(path)?;
45 let mss = MediaSourceStream::new(Box::new(file), Default::default());
46
47 let mut hint = Hint::new();
48 if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
49 hint.with_extension(ext);
50 }
51
52 analyze_media_source(mss, hint, options)
53}
54
55pub fn analyze_bytes(
56 bytes: &[u8],
57 extension: Option<&str>,
58 options: &GenerateOptions,
59) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
60 validate_options(options)?;
61
62 let cursor = Cursor::new(bytes.to_vec());
63 let mss = MediaSourceStream::new(Box::new(cursor), Default::default());
64 let mut hint = Hint::new();
65 if let Some(ext) = extension {
66 if !ext.is_empty() {
67 hint.with_extension(ext);
68 }
69 }
70
71 analyze_media_source(mss, hint, options)
72}
73
74pub fn generate_moodbar_from_path(
75 path: &Path,
76 options: &GenerateOptions,
77) -> Result<Vec<u8>, MoodbarDecodeError> {
78 let analysis = analyze_path(path, options)?;
79 Ok(analysis_to_raw_rgb_bytes(&analysis))
80}
81
82pub fn generate_moodbar_from_bytes(
83 bytes: &[u8],
84 extension: Option<&str>,
85 options: &GenerateOptions,
86) -> Result<Vec<u8>, MoodbarDecodeError> {
87 let analysis = analyze_bytes(bytes, extension, options)?;
88 Ok(analysis_to_raw_rgb_bytes(&analysis))
89}
90
91fn analyze_media_source(
92 mss: MediaSourceStream,
93 hint: Hint,
94 options: &GenerateOptions,
95) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
96 let probed = symphonia::default::get_probe().format(
97 &hint,
98 mss,
99 &FormatOptions::default(),
100 &MetadataOptions::default(),
101 )?;
102
103 let mut format = probed.format;
104 let track = format
105 .tracks()
106 .iter()
107 .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
108 .ok_or(MoodbarDecodeError::NoAudioTrack)?;
109
110 let sample_rate = track
111 .codec_params
112 .sample_rate
113 .ok_or(MoodbarDecodeError::NoAudioTrack)?;
114 let track_id = track.id;
115
116 let mut decoder =
117 symphonia::default::get_codecs().make(&track.codec_params, &DecoderOptions::default())?;
118 let mut samples = Vec::<f32>::new();
119 let mut saw_samples = false;
120 let mut diagnostics = DecodeDiagnostics::default();
121 let mut sample_buf: Option<SampleBuffer<f32>> = None;
122
123 loop {
124 let packet = match format.next_packet() {
125 Ok(packet) => packet,
126 Err(SymphoniaError::IoError(err))
127 if err.kind() == std::io::ErrorKind::UnexpectedEof =>
128 {
129 break;
130 }
131 Err(err) => return Err(err.into()),
132 };
133
134 if packet.track_id() != track_id {
135 continue;
136 }
137
138 match decoder.decode(&packet) {
139 Ok(decoded) => {
140 let spec = *decoded.spec();
141 let channels = spec.channels.count();
142
143 if sample_buf.is_none()
144 || sample_buf.as_ref().unwrap().capacity() < decoded.capacity()
145 {
146 sample_buf = Some(SampleBuffer::<f32>::new(decoded.capacity() as u64, spec));
147 }
148
149 let buf = sample_buf.as_mut().unwrap();
150 buf.copy_interleaved_ref(decoded);
151 let interleaved = buf.samples();
152
153 if channels == 0 {
154 diagnostics.zero_channel_packets += 1;
155 continue;
156 }
157
158 let max_channels = channels.min(2);
159 for frame in interleaved.chunks(channels) {
160 if frame.len() != channels {
161 diagnostics.truncated_frames += 1;
162 continue;
163 }
164 let sum = frame[..max_channels].iter().copied().sum::<f32>();
165 samples.push(sum / max_channels as f32);
166 saw_samples = true;
167 }
168 }
169 Err(SymphoniaError::DecodeError(_)) => {
170 diagnostics.decode_errors += 1;
171 continue;
172 }
173 Err(SymphoniaError::IoError(err))
174 if err.kind() == std::io::ErrorKind::UnexpectedEof =>
175 {
176 break;
177 }
178 Err(err) => return Err(err.into()),
179 }
180 }
181
182 if !saw_samples {
183 return Err(MoodbarDecodeError::EmptyAudio);
184 }
185
186 let mut analysis = analyze_pcm_mono(sample_rate, &samples, options);
187 analysis.diagnostics.decode_errors = diagnostics.decode_errors;
188 analysis.diagnostics.zero_channel_packets = diagnostics.zero_channel_packets;
189 analysis.diagnostics.truncated_frames = diagnostics.truncated_frames;
190 Ok(analysis)
191}
192
193fn validate_options(options: &GenerateOptions) -> Result<(), MoodbarDecodeError> {
194 if !options.fft_size.is_power_of_two() || options.fft_size < 64 {
195 return Err(MoodbarDecodeError::InvalidOptions(
196 "fft_size must be a power of two and >= 64".to_string(),
197 ));
198 }
199 if !(options.deterministic_floor.is_finite() && options.deterministic_floor > 0.0) {
200 return Err(MoodbarDecodeError::InvalidOptions(
201 "deterministic_floor must be finite and > 0".to_string(),
202 ));
203 }
204 if options.frames_per_color == 0 {
205 return Err(MoodbarDecodeError::InvalidOptions(
206 "frames_per_color must be >= 1".to_string(),
207 ));
208 }
209
210 let edges = if options.band_edges_hz.is_empty() {
211 vec![options.low_cut_hz, options.mid_cut_hz]
212 } else {
213 options.band_edges_hz.clone()
214 };
215
216 if edges.is_empty() {
217 return Err(MoodbarDecodeError::InvalidOptions(
218 "at least one band edge is required".to_string(),
219 ));
220 }
221 for pair in edges.windows(2) {
222 if pair[0] >= pair[1] {
223 return Err(MoodbarDecodeError::InvalidOptions(
224 "band edges must be strictly increasing".to_string(),
225 ));
226 }
227 }
228 Ok(())
229}