chromaprint_rust/
lib.rs

1extern crate chromaprint_sys_next;
2extern crate thiserror;
3
4use std::time::Duration;
5
6use chromaprint_sys_next::*;
7
8pub mod simhash;
9
10/// Error type.
11#[derive(thiserror::Error, Debug)]
12pub enum Error {
13    #[error("operation failed")]
14    OperationFailed,
15    #[error("invalid fingerprint string")]
16    InvalidFingerprintString(#[from] std::str::Utf8Error),
17    #[error("invalid argument: `{0}`")]
18    InvalidArgument(String),
19}
20
21/// Result type.
22type Result<T> = std::result::Result<T, Error>;
23
/// Chromaprint algorithm to use.
///
/// The numeric value of each variant is what gets passed to Chromaprint when a
/// context is created.
///
/// The default is [`Algorithm::Test2`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Algorithm {
    Test1 = 0,
    Test2 = 1,
    Test3 = 2,
    /// Removes leading silence.
    Test4 = 3,
    Test5 = 4,
}
37
38impl Default for Algorithm {
39    fn default() -> Self {
40        Self::Test2
41    }
42}
43
44/// Holds a single fingerprint returned by Chromaprint.
45///
46/// This can be one of:
47///
48/// 1. Base64: A fingerprint as a base-64 string. This is what you would use to
49///    upload a fingerprint to acousticid.
50/// 2. Raw: A fingerprint as a raw byte array. This is the normal representation.
51/// 3. Hash: A 32-bit hash of the raw fingerprint. Basically a compressed version of the raw fingerprint.
52#[derive(Debug)]
53pub struct Fingerprint<F: FingerprintRef> {
54    inner: F,
55}
56
57pub trait FingerprintRef {}
58impl FingerprintRef for Base64 {}
59impl FingerprintRef for Raw {}
60impl FingerprintRef for Hash {}
61
62#[derive(Debug)]
63pub struct Base64 {
64    data: *const libc::c_char,
65}
66
/// Raw representation of a fingerprint: a pointer to `size` 32-bit items.
#[derive(Debug)]
pub struct Raw {
    /// First item of the buffer; released with `chromaprint_dealloc` when this value
    /// is dropped.
    data: *const u32,
    /// Number of `u32` items reachable through `data`.
    size: usize,
}
72
/// Hash representation of a fingerprint: a single 32-bit value.
#[derive(Debug)]
pub struct Hash(u32);
75
76impl From<Hash> for u32 {
77    fn from(hash: Hash) -> Self {
78        hash.0
79    }
80}
81
82impl Drop for Base64 {
83    fn drop(&mut self) {
84        unsafe { chromaprint_dealloc(self.data as *mut std::ffi::c_void) };
85    }
86}
87
88impl Drop for Raw {
89    fn drop(&mut self) {
90        unsafe { chromaprint_dealloc(self.data as *mut std::ffi::c_void) };
91    }
92}
93
94impl Fingerprint<Base64> {
95    pub fn get(&self) -> Result<&str> {
96        let s = unsafe { std::ffi::CStr::from_ptr(self.inner.data) }.to_str();
97        s.map_err(|e| Error::InvalidFingerprintString(e))
98    }
99}
100
101impl Fingerprint<Raw> {
102    pub fn get(&self) -> &[u32] {
103        let s = unsafe { std::slice::from_raw_parts(self.inner.data, self.inner.size) };
104        s
105    }
106}
107
108impl Fingerprint<Hash> {
109    pub fn get(&self) -> u32 {
110        self.inner.0
111    }
112}
113
114impl TryFrom<Fingerprint<Raw>> for Fingerprint<Hash> {
115    type Error = Error;
116    fn try_from(raw: Fingerprint<Raw>) -> Result<Self> {
117        let mut hash: u32 = 0;
118        let data = raw.get();
119        let rc =
120            unsafe { chromaprint_hash_fingerprint(data.as_ptr(), data.len() as i32, &mut hash) };
121        if rc != 1 {
122            return Err(Error::OperationFailed);
123        }
124        return Ok(Fingerprint { inner: Hash(hash) });
125    }
126}
127
128pub struct Context {
129    ctx: *mut ChromaprintContext,
130    algorithm: Algorithm,
131}
132
133impl Context {
134    /// Creates a new Chromaprint context with the given algorithm. To use the default algorithm,
135    /// call [`default`](Self::default).
136    pub fn new(algorithm: Algorithm) -> Self {
137        let ctx = unsafe { chromaprint_new(algorithm as i32) };
138        Self { ctx, algorithm }
139    }
140
141    pub fn algorithm(&self) -> Algorithm {
142        self.algorithm
143    }
144
145    /// Returns the sample rate used internally by Chromaprint. If you want to avoid having Chromaprint internally
146    /// resample audio, make sure to use this sample rate.
147    pub fn sample_rate(&self) -> u32 {
148        unsafe { chromaprint_get_sample_rate(self.ctx) as u32 }
149    }
150
151    /// Starts a fingerprinting session. Audio samples will be buffered by Chromaprint until [`finish`](Self::finish)
152    /// is called.
153    pub fn start(&mut self, sample_rate: u32, num_channels: u16) -> Result<()> {
154        let rc = unsafe { chromaprint_start(self.ctx, sample_rate as i32, num_channels as i32) };
155        if rc != 1 {
156            return Err(Error::OperationFailed);
157        }
158        Ok(())
159    }
160
161    /// Feeds a set of audio samples to the fingerprinter.
162    pub fn feed(&mut self, data: &[i16]) -> Result<()> {
163        let rc = unsafe { chromaprint_feed(self.ctx, data.as_ptr(), data.len() as i32) };
164        if rc != 1 {
165            return Err(Error::OperationFailed);
166        }
167        Ok(())
168    }
169
170    /// Signals to the fingerprinter that the audio clip is complete. You must call this method before
171    /// extracting a fingerprint.
172    ///
173    /// Important note: before calling [`finish`](Self::finish), you should provide at least 3 seconds worth of audio samples.
174    /// The reason is that the size of the raw fingerprint is directly related to the amount of audio data fed
175    /// to the fingerprinter.
176    ///
177    /// In general, the raw fingerprint size is `~= (duration_in_secs * 11025 - 4096) / 1365 - 15 - 4 + 1`
178    ///
179    /// See detailed discussion [here](https://github.com/acoustid/chromaprint/issues/45).
180    pub fn finish(&mut self) -> Result<()> {
181        let rc = unsafe { chromaprint_finish(self.ctx) };
182        if rc != 1 {
183            return Err(Error::OperationFailed);
184        }
185        Ok(())
186    }
187
188    /// Returns the raw fingerprint.
189    pub fn get_fingerprint_raw(&self) -> Result<Fingerprint<Raw>> {
190        let mut data_ptr = std::ptr::null_mut();
191        let mut size: i32 = 0;
192        let rc = unsafe { chromaprint_get_raw_fingerprint(self.ctx, &mut data_ptr, &mut size) };
193        if rc != 1 {
194            return Err(Error::OperationFailed);
195        }
196        Ok(Fingerprint {
197            inner: Raw {
198                data: data_ptr as *const _,
199                size: size as usize,
200            },
201        })
202    }
203
204    /// Returns a hash of the raw fingerprint.
205    ///
206    /// Under the hood, Chromaprint computes a 32-bit [SimHash](https://en.wikipedia.org/wiki/SimHash) of the raw fingerprint.
207    pub fn get_fingerprint_hash(&self) -> Result<Fingerprint<Hash>> {
208        let mut hash: u32 = 0;
209        let rc = unsafe { chromaprint_get_fingerprint_hash(self.ctx, &mut hash) };
210        if rc != 1 {
211            return Err(Error::OperationFailed);
212        }
213        Ok(Fingerprint { inner: Hash(hash) })
214    }
215
216    /// Returns a compressed version of the raw fingerprint in Base64 format. This is the format used by
217    /// the [AcousticID](https://acoustid.org/) service.
218    pub fn get_fingerprint_base64(&self) -> Result<Fingerprint<Base64>> {
219        let mut out_ptr = std::ptr::null_mut();
220        let rc = unsafe { chromaprint_get_fingerprint(self.ctx, &mut out_ptr) };
221        if rc != 1 {
222            return Err(Error::OperationFailed);
223        }
224        Ok(Fingerprint {
225            inner: Base64 {
226                data: out_ptr as *const _,
227            },
228        })
229    }
230
231    /// Returns the current delay for the fingerprint.
232    ///
233    /// This value represents the duration of samples that had to be buffered before
234    /// Chromaprint could start generating the fingerprint.
235    pub fn get_delay(&self) -> Result<Duration> {
236        let delay_ms = unsafe { chromaprint_get_delay_ms(self.ctx) };
237        if delay_ms < 0 {
238            return Err(Error::OperationFailed);
239        }
240        let delay = Duration::from_millis(delay_ms as u64);
241        Ok(delay)
242    }
243
244    /// Returns the duration of a single item in the raw fingerprint.
245    ///
246    /// For example, if you compute a raw fingerprint and it contains 1000 32-bit values,
247    /// the duration returned by this method will tell you how much time (in audio samples)
248    /// is represented by each 32-bit value in the fingerprint.
249    pub fn get_item_duration(&self) -> Result<Duration> {
250        let item_duration_ms = unsafe { chromaprint_get_item_duration_ms(self.ctx) };
251        if item_duration_ms < 0 {
252            return Err(Error::OperationFailed);
253        }
254        let item_duration = Duration::from_millis(item_duration_ms as u64);
255        Ok(item_duration)
256    }
257
258    /// Clear the current fingerprint.
259    pub fn clear_fingerprint(&mut self) -> Result<()> {
260        let rc = unsafe { chromaprint_clear_fingerprint(self.ctx) };
261        if rc != 1 {
262            return Err(Error::OperationFailed);
263        }
264        Ok(())
265    }
266}
267
268impl Default for Context {
269    fn default() -> Self {
270        Self::new(Algorithm::default())
271    }
272}
273
274impl Drop for Context {
275    fn drop(&mut self) {
276        unsafe { chromaprint_free(self.ctx) }
277    }
278}
279
280/// DelayedFingerprinter allows you to generate Chromaprint fingerprints at a higher resolution
281/// than allowed by default.
282///
283/// By design, Chromaprint requires at least 3 seconds of audio to generate a fingerprint. To get
284/// more precise fingerprints, we can use multiple Contexts separated by a fixed delay. For example,
285/// to obtain a fingerprint for each second of audio, run 3 contexts separated by one second.
286///
287/// DelayedFingerprinter abstracts this logic out into a simple API. Once created, you just need to
288/// call `feed()` and check if any hashes were returned. Each hash will be accompnaied with a timestamp
289/// that indicates when the fingerprint was generated.
290pub struct DelayedFingerprinter {
291    ctx: Vec<Context>,
292    next_fingerprint: Vec<Duration>,
293    interval: Duration,
294    sample_rate: u32,
295    num_channels: u16,
296    started: bool,
297    clock: Duration,
298    clock_delta: Duration,
299}
300
301impl DelayedFingerprinter {
302    pub fn new(
303        n: usize,
304        interval: Duration,
305        delay: Duration,
306        sample_rate: Option<u32>,
307        num_channels: u16,
308        algorithm: Option<Algorithm>,
309    ) -> Self {
310        let mut ctx = Vec::with_capacity(n);
311        for _ in 0..n {
312            if let Some(algorithm) = algorithm {
313                ctx.push(Context::new(algorithm));
314            } else {
315                ctx.push(Context::default());
316            }
317        }
318
319        // Use the default Chromaprint sample rate if not specified.
320        let sample_rate = sample_rate.unwrap_or_else(|| ctx[0].sample_rate());
321
322        // Determine when the first fingerprint is needed for each delay.
323        let mut next_fingerprint = Vec::with_capacity(n);
324        for i in 0..n {
325            next_fingerprint.push(interval + delay.mul_f32(i as f32));
326        }
327
328        Self {
329            ctx,
330            next_fingerprint,
331            interval,
332            sample_rate,
333            num_channels,
334            started: false,
335            clock: Duration::ZERO,
336            clock_delta: Duration::from_micros(1),
337        }
338    }
339
340    pub fn interval(&self) -> Duration {
341        self.interval
342    }
343
344    pub fn sample_rate(&self) -> u32 {
345        self.sample_rate
346    }
347
348    pub fn clock(&self) -> Duration {
349        self.clock
350    }
351
352    pub fn feed(&mut self, samples: &[i16]) -> Result<Vec<(Fingerprint<Raw>, Duration)>> {
353        // We can get multiple hashes in a single call (e.g., large number of samples).
354        let mut hashes = Vec::new();
355
356        if !self.started {
357            for ctx in self.ctx.iter_mut() {
358                ctx.start(self.sample_rate, self.num_channels)?;
359            }
360            self.started = true;
361        }
362
363        for (i, ctx) in self.ctx.iter_mut().enumerate() {
364            // If the clock is within `clock_delta` of a fingerprint, assume it's time to take one. This
365            // is done to handle floating point precision issues during comparison.
366            if self.clock + self.clock_delta >= self.next_fingerprint[i] {
367                ctx.finish()?;
368                hashes.push((ctx.get_fingerprint_raw()?, self.clock));
369                ctx.start(self.sample_rate, self.num_channels)?;
370                self.next_fingerprint[i] = self.clock + self.interval;
371            }
372        }
373
374        for ctx in &mut self.ctx {
375            ctx.feed(samples)?;
376        }
377
378        // Increment the clock based on number of samples and the configured sample rate.
379        self.clock += Duration::from_secs_f64(
380            samples.len() as f64 / self.sample_rate as f64 / self.num_channels as f64,
381        );
382
383        Ok(hashes)
384    }
385}
386
#[cfg(test)]
mod test {
    use std::path::{Path, PathBuf};

    use super::*;

    /// Builds the path of a file inside the crate's `resources` directory.
    fn resource(name: &str) -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("resources")
            .join(name)
    }

    /// Loads raw audio as little-endian `i16` samples.
    ///
    /// The file is read in one go instead of two bytes at a time; a trailing odd byte
    /// is ignored. Panics if the file cannot be read.
    fn load_audio(path: impl AsRef<Path>) -> Vec<i16> {
        std::fs::read(path)
            .unwrap()
            .chunks_exact(2)
            .map(|pair| i16::from_le_bytes([pair[0], pair[1]]))
            .collect()
    }

    #[test]
    fn test_load_audio() {
        let data = load_audio(resource("test_mono_44100.raw"));
        assert_eq!(data.len(), 2 * 44100); // 2 seconds @ 44.1 kHz
        assert_eq!(data[1000], 0);
        assert_eq!(data[2000], 107);
        assert_eq!(data[3000], 128);
    }

    #[test]
    #[ignore = "failing"]
    fn test_mono() {
        let data = load_audio(resource("test_mono_44100.raw"));

        let mut ctx = Context::default();
        ctx.start(44100, 1).unwrap();
        ctx.feed(&data).unwrap();
        ctx.finish().unwrap();
        dbg!(ctx.get_fingerprint_hash().unwrap());
        dbg!(ctx.get_fingerprint_base64().unwrap());
        dbg!(ctx.get_fingerprint_raw().unwrap());
    }

    #[test]
    fn test_stereo() {
        let data = load_audio(resource("test_stereo_44100.raw"));

        let mut ctx = Context::default();
        ctx.start(44100, 1).unwrap();
        ctx.feed(&data).unwrap();
        ctx.finish().unwrap();

        assert_eq!(ctx.get_fingerprint_hash().unwrap().get(), 3732003127);
        assert_eq!(
            ctx.get_fingerprint_raw().unwrap().get(),
            &[
                3740390231, 3739276119, 3730871573, 3743460629, 3743525173, 3744594229, 3727948087,
                1584920886, 1593302326, 1593295926, 1584907318,
            ]
        );
        assert_eq!(
            ctx.get_fingerprint_base64().unwrap().get().unwrap(),
            "AQAAC0kkZUqYREkUnFAXHk8uuMZl6EfO4zu-4ABKFGESWIIMEQE"
        );
    }

    #[test]
    fn test_sample_rate() {
        let ctx = Context::default();
        assert_eq!(ctx.sample_rate(), 11025);
    }

    #[test]
    fn test_delayed_fingerprinter() {
        let mut s = DelayedFingerprinter::new(
            2,
            Duration::from_secs(3),
            Duration::from_millis(100),
            Some(44100),
            1,
            None,
        );
        let data = load_audio(resource("test_stereo_44100.raw"));

        // Feed 100ms chunks and ensure that exactly two hashes are returned.
        let hashes = data
            .chunks(4410)
            .flat_map(|samples| s.feed(samples).unwrap())
            .map(|(f, ts)| (Fingerprint::<Hash>::try_from(f).unwrap().get(), ts))
            .collect::<Vec<(u32, Duration)>>();

        assert_eq!(
            &hashes,
            &[
                (3739276119, Duration::from_secs(3)),
                (3730870549, Duration::from_millis(3100))
            ]
        );
    }
}