1use std::fs::File;
2use std::io::Cursor;
3use std::path::Path;
4
5use moodbar_analysis::{
6 analysis_to_raw_rgb_bytes, analyze_pcm_mono, GenerateOptions, MoodbarAnalysis,
7};
8use symphonia::core::audio::SampleBuffer;
9use symphonia::core::codecs::DecoderOptions;
10use symphonia::core::errors::Error as SymphoniaError;
11use symphonia::core::formats::FormatOptions;
12use symphonia::core::io::MediaSourceStream;
13use symphonia::core::meta::MetadataOptions;
14use symphonia::core::probe::Hint;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18pub enum MoodbarDecodeError {
19 #[error("no playable audio track found")]
20 NoAudioTrack,
21 #[error("decoded stream has no samples")]
22 EmptyAudio,
23 #[error("I/O error: {0}")]
24 Io(#[from] std::io::Error),
25 #[error("decode error: {0}")]
26 Decode(#[from] SymphoniaError),
27 #[error("invalid options: {0}")]
28 InvalidOptions(String),
29}
30
/// Tallies of recoverable problems encountered while decoding a stream.
///
/// Decoding keeps going past each of these conditions; the counts are reported so callers can
/// judge how faithful the resulting analysis is.
#[derive(Clone, Debug, Default)]
pub struct DecodeDiagnostics {
    /// Packets the decoder rejected with a recoverable decode error; each one was skipped.
    pub decode_errors: usize,
    /// Decoded packets whose signal spec reported zero channels; each one was skipped.
    pub zero_channel_packets: usize,
    /// Interleaved frames that held fewer samples than the channel count and were dropped.
    pub truncated_frames: usize,
}
37
38pub fn analyze_path(
39 path: &Path,
40 options: &GenerateOptions,
41) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
42 validate_options(options)?;
43
44 let file = File::open(path)?;
45 let mss = MediaSourceStream::new(Box::new(file), Default::default());
46
47 let mut hint = Hint::new();
48 if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
49 hint.with_extension(ext);
50 }
51
52 analyze_media_source(mss, hint, options)
53}
54
55pub fn analyze_bytes(
56 bytes: &[u8],
57 extension: Option<&str>,
58 options: &GenerateOptions,
59) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
60 validate_options(options)?;
61
62 let cursor = Cursor::new(bytes.to_vec());
63 let mss = MediaSourceStream::new(Box::new(cursor), Default::default());
64 let mut hint = Hint::new();
65 if let Some(ext) = extension {
66 if !ext.is_empty() {
67 hint.with_extension(ext);
68 }
69 }
70
71 analyze_media_source(mss, hint, options)
72}
73
74pub fn generate_moodbar_from_path(
75 path: &Path,
76 options: &GenerateOptions,
77) -> Result<Vec<u8>, MoodbarDecodeError> {
78 let analysis = analyze_path(path, options)?;
79 Ok(analysis_to_raw_rgb_bytes(&analysis))
80}
81
82pub fn generate_moodbar_from_bytes(
83 bytes: &[u8],
84 extension: Option<&str>,
85 options: &GenerateOptions,
86) -> Result<Vec<u8>, MoodbarDecodeError> {
87 let analysis = analyze_bytes(bytes, extension, options)?;
88 Ok(analysis_to_raw_rgb_bytes(&analysis))
89}
90
91fn analyze_media_source(
92 mss: MediaSourceStream,
93 hint: Hint,
94 options: &GenerateOptions,
95) -> Result<MoodbarAnalysis, MoodbarDecodeError> {
96 let probed = symphonia::default::get_probe().format(
97 &hint,
98 mss,
99 &FormatOptions::default(),
100 &MetadataOptions::default(),
101 )?;
102
103 let mut format = probed.format;
104 let track = format
105 .tracks()
106 .iter()
107 .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
108 .ok_or(MoodbarDecodeError::NoAudioTrack)?;
109
110 let sample_rate = track
111 .codec_params
112 .sample_rate
113 .ok_or(MoodbarDecodeError::NoAudioTrack)?;
114 let track_id = track.id;
115
116 let mut decoder =
117 symphonia::default::get_codecs().make(&track.codec_params, &DecoderOptions::default())?;
118 let mut samples = Vec::<f32>::new();
119 let mut saw_samples = false;
120 let mut diagnostics = DecodeDiagnostics::default();
121
122 loop {
123 let packet = match format.next_packet() {
124 Ok(packet) => packet,
125 Err(SymphoniaError::IoError(err))
126 if err.kind() == std::io::ErrorKind::UnexpectedEof =>
127 {
128 break;
129 }
130 Err(err) => return Err(err.into()),
131 };
132
133 if packet.track_id() != track_id {
134 continue;
135 }
136
137 match decoder.decode(&packet) {
138 Ok(decoded) => {
139 let spec = *decoded.spec();
140 let channels = spec.channels.count();
141 let mut sample_buf = SampleBuffer::<f32>::new(decoded.capacity() as u64, spec);
142 sample_buf.copy_interleaved_ref(decoded);
143 let interleaved = sample_buf.samples();
144
145 if channels == 0 {
146 diagnostics.zero_channel_packets += 1;
147 continue;
148 }
149
150 for frame in interleaved.chunks(channels) {
151 if frame.len() != channels {
152 diagnostics.truncated_frames += 1;
153 continue;
154 }
155 let sum = frame.iter().copied().sum::<f32>();
156 samples.push(sum / channels as f32);
157 saw_samples = true;
158 }
159 }
160 Err(SymphoniaError::DecodeError(_)) => {
161 diagnostics.decode_errors += 1;
162 continue;
163 }
164 Err(SymphoniaError::IoError(err))
165 if err.kind() == std::io::ErrorKind::UnexpectedEof =>
166 {
167 break;
168 }
169 Err(err) => return Err(err.into()),
170 }
171 }
172
173 if !saw_samples {
174 return Err(MoodbarDecodeError::EmptyAudio);
175 }
176
177 let mut analysis = analyze_pcm_mono(sample_rate, &samples, options);
178 analysis.diagnostics.decode_errors = diagnostics.decode_errors;
179 analysis.diagnostics.zero_channel_packets = diagnostics.zero_channel_packets;
180 analysis.diagnostics.truncated_frames = diagnostics.truncated_frames;
181 Ok(analysis)
182}
183
184fn validate_options(options: &GenerateOptions) -> Result<(), MoodbarDecodeError> {
185 if !options.fft_size.is_power_of_two() || options.fft_size < 64 {
186 return Err(MoodbarDecodeError::InvalidOptions(
187 "fft_size must be a power of two and >= 64".to_string(),
188 ));
189 }
190 if !(options.deterministic_floor.is_finite() && options.deterministic_floor > 0.0) {
191 return Err(MoodbarDecodeError::InvalidOptions(
192 "deterministic_floor must be finite and > 0".to_string(),
193 ));
194 }
195 if options.frames_per_color == 0 {
196 return Err(MoodbarDecodeError::InvalidOptions(
197 "frames_per_color must be >= 1".to_string(),
198 ));
199 }
200
201 let edges = if options.band_edges_hz.is_empty() {
202 vec![options.low_cut_hz, options.mid_cut_hz]
203 } else {
204 options.band_edges_hz.clone()
205 };
206
207 if edges.is_empty() {
208 return Err(MoodbarDecodeError::InvalidOptions(
209 "at least one band edge is required".to_string(),
210 ));
211 }
212 for pair in edges.windows(2) {
213 if pair[0] >= pair[1] {
214 return Err(MoodbarDecodeError::InvalidOptions(
215 "band edges must be strictly increasing".to_string(),
216 ));
217 }
218 }
219 Ok(())
220}