birdnet_onnx/classifier.rs

//! Classifier builder and implementation

use crate::detection::detect_model_type;
use crate::error::{Error, Result};
use crate::labels::load_labels_from_file;
use crate::postprocess::top_k_predictions;
use crate::types::{ExecutionProviderInfo, ModelConfig, ModelType, PredictionResult};
use ndarray::Array2;
use ort::session::Session;
use ort::value::Value;
use std::sync::{Arc, Mutex};

// Macro to generate execution provider builder methods
macro_rules! with_provider_method {
    ($fn_name:ident, $provider_struct:ident, $provider_enum:ident, $doc:expr) => {
        #[doc = $doc]
        #[must_use]
        pub fn $fn_name(mut self) -> Self {
            use ort::execution_providers::$provider_struct;
            self.execution_providers
                .push($provider_struct::default().into());
            // Only set requested_provider if it's still the default (CPU).
            // This aligns with ONNX Runtime's behavior: it tries providers
            // in the order they were added, so the first non-CPU provider
            // is the most relevant one to track.
            if self.requested_provider == ExecutionProviderInfo::Cpu {
                self.requested_provider = ExecutionProviderInfo::$provider_enum;
            }
            self
        }
    };
}
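
// For reference, an invocation like `with_provider_method!(with_cuda,
// CUDAExecutionProvider, Cuda, "...")` expands to roughly the following
// (a sketch of the expansion with the `#[doc]` attribute omitted, not
// compiler-generated output):
//
//     #[must_use]
//     pub fn with_cuda(mut self) -> Self {
//         use ort::execution_providers::CUDAExecutionProvider;
//         self.execution_providers
//             .push(CUDAExecutionProvider::default().into());
//         if self.requested_provider == ExecutionProviderInfo::Cpu {
//             self.requested_provider = ExecutionProviderInfo::Cuda;
//         }
//         self
//     }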

/// Labels source for builder
#[derive(Debug)]
enum Labels {
    Path(String),
    InMemory(Vec<String>),
}

/// Builder for constructing a Classifier
#[derive(Debug)]
pub struct ClassifierBuilder {
    model_path: Option<String>,
    labels: Option<Labels>,
    model_type_override: Option<ModelType>,
    execution_providers: Vec<ort::execution_providers::ExecutionProviderDispatch>,
    requested_provider: ExecutionProviderInfo,
    top_k: usize,
    min_confidence: Option<f32>,
}

impl Default for ClassifierBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl ClassifierBuilder {
    /// Create a new classifier builder
    #[must_use]
    pub const fn new() -> Self {
        Self {
            model_path: None,
            labels: None,
            model_type_override: None,
            execution_providers: Vec::new(),
            requested_provider: ExecutionProviderInfo::Cpu,
            top_k: 10,
            min_confidence: None,
        }
    }

    /// Set the path to the ONNX model file (required)
    #[must_use]
    pub fn model_path(mut self, path: impl Into<String>) -> Self {
        self.model_path = Some(path.into());
        self
    }

    /// Set the path to the labels file
    #[must_use]
    pub fn labels_path(mut self, path: impl Into<String>) -> Self {
        self.labels = Some(Labels::Path(path.into()));
        self
    }

    /// Set labels directly from a vector
    #[must_use]
    pub fn labels(mut self, labels: Vec<String>) -> Self {
        self.labels = Some(Labels::InMemory(labels));
        self
    }

    /// Override auto-detected model type (useful for `Perch` v2)
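    ///
    /// # Example
    ///
    /// A sketch forcing the `Perch` v2 layout when auto-detection needs a hint
    /// (this assumes `ModelType` is re-exported at the crate root):
    ///
    /// ```no_run
    /// use birdnet_onnx::{Classifier, ModelType};
    ///
    /// let classifier = Classifier::builder()
    ///     .model_path("model.onnx")
    ///     .labels_path("labels.txt")
    ///     .model_type(ModelType::PerchV2)
    ///     .build()?;
    /// # Ok::<(), birdnet_onnx::Error>(())
    /// ```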
    #[must_use]
    pub const fn model_type(mut self, model_type: ModelType) -> Self {
        self.model_type_override = Some(model_type);
        self
    }

    /// Add an execution provider (GPU, CPU, etc.)
    ///
    /// Multiple providers can be added; they are tried in order.
    /// If none specified, defaults to CPU.
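    ///
    /// # Example
    ///
    /// A minimal sketch adding the CPU provider explicitly; any type that
    /// converts into `ExecutionProviderDispatch` can be passed the same way:
    ///
    /// ```no_run
    /// use birdnet_onnx::Classifier;
    /// use ort::execution_providers::CPUExecutionProvider;
    ///
    /// let classifier = Classifier::builder()
    ///     .model_path("model.onnx")
    ///     .labels_path("labels.txt")
    ///     .execution_provider(CPUExecutionProvider::default())
    ///     .build()?;
    /// # Ok::<(), birdnet_onnx::Error>(())
    /// ```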
106    #[must_use]
107    pub fn execution_provider(
108        mut self,
109        provider: impl Into<ort::execution_providers::ExecutionProviderDispatch>,
110    ) -> Self {
111        self.execution_providers.push(provider.into());
112        self
113    }
114
115    /// Set the number of top predictions to return (default: 10)
116    #[must_use]
117    pub const fn top_k(mut self, k: usize) -> Self {
118        self.top_k = k;
119        self
120    }
121
122    /// Set minimum confidence threshold for predictions
123    #[must_use]
124    pub const fn min_confidence(mut self, threshold: f32) -> Self {
125        self.min_confidence = Some(threshold);
126        self
127    }
128
129    with_provider_method!(
130        with_cuda,
131        CUDAExecutionProvider,
132        Cuda,
133        "Request CUDA execution provider (NVIDIA GPU)"
134    );
135
136    /// Request `TensorRT` execution provider (NVIDIA GPU) with optimized defaults
137    ///
138    /// This method enables performance optimizations including FP16 precision,
139    /// CUDA graphs, and caching. Expected performance: 4x faster than unoptimized
140    /// `TensorRT` and comparable to or better than CUDA provider.
141    ///
142    /// For custom `TensorRT` settings, use [`with_tensorrt_config()`](Self::with_tensorrt_config).
143    ///
144    /// # Requirements
145    /// - NVIDIA GPU (compute capability 5.3+)
146    /// - `TensorRT` library installed
147    /// - ONNX Runtime built with `TensorRT` support
148    ///
149    /// # Performance Optimizations
150    ///
151    /// The default configuration enables:
152    /// - **FP16 precision**: 2x faster inference on GPUs with tensor cores
153    /// - **CUDA graphs**: Reduced CPU launch overhead for models with many small layers
154    /// - **Engine caching**: Reduces session creation from minutes to seconds
155    /// - **Timing cache**: Accelerates future builds with similar layer configurations
156    /// - **Optimization level 3**: Balanced optimization (`TensorRT` default)
157    ///
158    /// # Example
159    ///
160    /// ```no_run
161    /// use birdnet_onnx::Classifier;
162    ///
163    /// let classifier = Classifier::builder()
164    ///     .model_path("model.onnx")
165    ///     .labels_path("labels.txt")
166    ///     .with_tensorrt()
167    ///     .build()?;
168    /// # Ok::<(), birdnet_onnx::Error>(())
169    /// ```
170    #[must_use]
171    pub fn with_tensorrt(mut self) -> Self {
172        use ort::execution_providers::TensorRTExecutionProvider;
173
174        let config = crate::tensorrt_config::TensorRTConfig::new();
175        let provider = config.apply_to(TensorRTExecutionProvider::default());
176
177        self.execution_providers.push(provider.into());
178
179        if self.requested_provider == ExecutionProviderInfo::Cpu {
180            self.requested_provider = ExecutionProviderInfo::TensorRt;
181        }
182
183        self
184    }
185
186    /// Configure `TensorRT` with custom settings
187    ///
188    /// Use this method when you need fine-grained control over `TensorRT` behavior.
189    /// For most use cases, [`with_tensorrt()`](Self::with_tensorrt) provides optimal defaults.
190    ///
191    /// # Example
192    ///
193    /// ```no_run
194    /// use birdnet_onnx::{Classifier, TensorRTConfig};
195    ///
196    /// let trt_config = TensorRTConfig::new()
197    ///     .with_fp16(false)  // Disable FP16 for accuracy-critical work
198    ///     .with_builder_optimization_level(5)  // Maximum optimization
199    ///     .with_engine_cache_path("/tmp/trt_cache");
200    ///
201    /// let classifier = Classifier::builder()
202    ///     .model_path("model.onnx")
203    ///     .labels_path("labels.txt")
204    ///     .with_tensorrt_config(trt_config)
205    ///     .build()?;
206    /// # Ok::<(), birdnet_onnx::Error>(())
207    /// ```
208    #[must_use]
209    pub fn with_tensorrt_config(mut self, config: crate::tensorrt_config::TensorRTConfig) -> Self {
210        use ort::execution_providers::TensorRTExecutionProvider;
211
212        let provider = config.apply_to(TensorRTExecutionProvider::default());
213        self.execution_providers.push(provider.into());
214
215        if self.requested_provider == ExecutionProviderInfo::Cpu {
216            self.requested_provider = ExecutionProviderInfo::TensorRt;
217        }
218
219        self
220    }
221
222    with_provider_method!(
223        with_directml,
224        DirectMLExecutionProvider,
225        DirectMl,
226        "Request `DirectML` execution provider (Windows GPU)"
227    );
228    with_provider_method!(
229        with_coreml,
230        CoreMLExecutionProvider,
231        CoreMl,
232        "Request `CoreML` execution provider (Apple Neural Engine)"
233    );
234    with_provider_method!(
235        with_rocm,
236        ROCmExecutionProvider,
237        Rocm,
238        "Request `ROCm` execution provider (AMD GPU)"
239    );
240    with_provider_method!(
241        with_openvino,
242        OpenVINOExecutionProvider,
243        OpenVino,
244        "Request `OpenVINO` execution provider (Intel accelerator)"
245    );
246    with_provider_method!(
247        with_onednn,
248        OneDNNExecutionProvider,
249        OneDnn,
250        "Request oneDNN execution provider (Intel accelerator)"
251    );
252    with_provider_method!(
253        with_qnn,
254        QNNExecutionProvider,
255        Qnn,
256        "Request QNN execution provider (Qualcomm NPU)"
257    );
258    with_provider_method!(
259        with_acl,
260        ACLExecutionProvider,
261        Acl,
262        "Request ACL execution provider (Arm Compute Library)"
263    );
264    with_provider_method!(
265        with_armnn,
266        ArmNNExecutionProvider,
267        ArmNn,
268        "Request `ArmNN` execution provider (Arm Neural Network)"
269    );
270
271    /// Build the classifier
272    ///
273    /// # Errors
274    ///
275    /// Returns an error if:
276    /// - Model path was not set
277    /// - Labels were not provided
278    /// - Model file cannot be loaded
279    /// - Model type cannot be detected
280    /// - Label count doesn't match model
281    pub fn build(self) -> Result<Classifier> {
282        // Validate required fields
283        let model_path = self.model_path.ok_or(Error::ModelPathRequired)?;
284        let labels_source = self.labels.ok_or(Error::LabelsRequired)?;
285
286        // Build session with execution providers
287        let mut session_builder = Session::builder().map_err(Error::ModelLoad)?;
288
289        for provider in self.execution_providers {
290            session_builder = session_builder
291                .with_execution_providers([provider])
292                .map_err(Error::ModelLoad)?;
293        }
294
295        let session = session_builder
296            .commit_from_file(&model_path)
297            .map_err(Error::ModelLoad)?;
298
299        // Extract input/output shapes for model detection
300        let input_shape = extract_input_shape(&session)?;
301        let output_shapes = extract_output_shapes(&session)?;
302
303        // Detect model type
304        let config = detect_model_type(&input_shape, &output_shapes, self.model_type_override)?;
305
306        // Load labels
307        let labels = match labels_source {
308            Labels::Path(path) => load_labels_from_file(&path, config.model_type)?,
309            Labels::InMemory(labels) => labels,
310        };
311
312        // Validate label count matches model
313        if labels.len() != config.num_species {
314            return Err(Error::LabelCount {
315                expected: config.num_species,
316                got: labels.len(),
317            });
318        }
319
320        Ok(Classifier {
321            inner: Arc::new(ClassifierInner {
322                session: Mutex::new(session),
323                config,
324                labels,
325                requested_provider: self.requested_provider,
326                top_k: self.top_k,
327                min_confidence: self.min_confidence,
328            }),
329        })
330    }
331}
332
333/// Extract input tensor shape from session
334fn extract_input_shape(session: &Session) -> Result<Vec<i64>> {
335    let inputs = session
336        .inputs
337        .first()
338        .ok_or_else(|| Error::ModelDetection {
339            reason: "model has no inputs".to_string(),
340        })?;
341
342    let shape = inputs
343        .input_type
344        .tensor_shape()
345        .ok_or_else(|| Error::ModelDetection {
346            reason: "input is not a tensor".to_string(),
347        })?;
348
349    Ok(shape.iter().copied().collect())
350}
351
352/// Extract output tensor shapes from session
353fn extract_output_shapes(session: &Session) -> Result<Vec<Vec<i64>>> {
354    session
355        .outputs
356        .iter()
357        .map(|output| {
358            let shape = output
359                .output_type
360                .tensor_shape()
361                .ok_or_else(|| Error::ModelDetection {
362                    reason: "output is not a tensor".to_string(),
363                })?;
364            Ok(shape.iter().copied().collect())
365        })
366        .collect()
367}
368
369/// Internal state shared via Arc for thread safety
370struct ClassifierInner {
371    session: Mutex<Session>,
372    config: ModelConfig,
373    labels: Vec<String>,
374    requested_provider: ExecutionProviderInfo,
375    top_k: usize,
376    min_confidence: Option<f32>,
377}
378
379/// Thread-safe classifier for bird species detection
380///
381/// Use `Classifier::builder()` to construct.
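///
/// # Example
///
/// Cloning is cheap (it bumps an `Arc`); a sketch sharing one classifier
/// across threads (inference calls are serialized internally by a `Mutex`):
///
/// ```no_run
/// use birdnet_onnx::Classifier;
///
/// let classifier = Classifier::builder()
///     .model_path("model.onnx")
///     .labels_path("labels.txt")
///     .build()?;
///
/// let worker = classifier.clone();
/// std::thread::spawn(move || {
///     let segment = vec![0.0f32; worker.config().sample_count];
///     let _ = worker.predict(&segment);
/// });
/// # Ok::<(), birdnet_onnx::Error>(())
/// ```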
#[derive(Clone)]
pub struct Classifier {
    inner: Arc<ClassifierInner>,
}

impl std::fmt::Debug for Classifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Classifier")
            .field("config", &self.inner.config)
            .field("labels_count", &self.inner.labels.len())
            .field("requested_provider", &self.inner.requested_provider)
            .field("top_k", &self.inner.top_k)
            .field("min_confidence", &self.inner.min_confidence)
            .finish_non_exhaustive()
    }
}

impl Classifier {
    /// Create a new classifier builder
    #[must_use]
    pub const fn builder() -> ClassifierBuilder {
        ClassifierBuilder::new()
    }

    /// Get the model configuration
    #[must_use]
    pub fn config(&self) -> &ModelConfig {
        &self.inner.config
    }

    /// Get the species labels
    #[must_use]
    pub fn labels(&self) -> &[String] {
        &self.inner.labels
    }

    /// Returns the execution provider that was requested for this classifier.
    ///
    /// **Note:** This returns the provider that was *requested* during build,
    /// not necessarily the provider that is *actually active*. If the requested
    /// provider is unavailable, ONNX Runtime will silently fall back to CPU.
    ///
    /// This value is only set by the typed `with_<provider>()` builder methods
    /// (e.g., `with_cuda()`, `with_tensorrt()`). The generic `execution_provider()`
    /// method does not affect the value returned here.
    ///
    /// To verify the actual provider being used, enable ONNX Runtime verbose
    /// logging via environment variable: `ORT_LOG_LEVEL=Verbose`
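    ///
    /// # Example
    ///
    /// A sketch of inspecting the requested provider after build (assuming
    /// `ExecutionProviderInfo` is re-exported at the crate root):
    ///
    /// ```no_run
    /// use birdnet_onnx::{Classifier, ExecutionProviderInfo};
    ///
    /// let classifier = Classifier::builder()
    ///     .model_path("model.onnx")
    ///     .labels_path("labels.txt")
    ///     .with_cuda()
    ///     .build()?;
    /// assert_eq!(classifier.requested_provider(), ExecutionProviderInfo::Cuda);
    /// # Ok::<(), birdnet_onnx::Error>(())
    /// ```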
    #[must_use]
    pub fn requested_provider(&self) -> ExecutionProviderInfo {
        self.inner.requested_provider
    }

    /// Run inference on a single audio segment
    ///
    /// # Arguments
    /// * `segment` - Audio samples (must match `config().sample_count`)
    ///
    /// # Returns
    /// * `PredictionResult` with top predictions, embeddings (if available), and raw scores
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - Input segment size doesn't match expected sample count
    /// - Session lock is poisoned
    /// - ONNX inference fails
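    ///
    /// # Example
    ///
    /// A minimal sketch using a silence-filled buffer; real code would pass
    /// decoded audio at the model's expected sample rate:
    ///
    /// ```no_run
    /// use birdnet_onnx::Classifier;
    ///
    /// let classifier = Classifier::builder()
    ///     .model_path("model.onnx")
    ///     .labels_path("labels.txt")
    ///     .build()?;
    ///
    /// let segment = vec![0.0f32; classifier.config().sample_count];
    /// let result = classifier.predict(&segment)?;
    /// println!("{} predictions returned", result.predictions.len());
    /// # Ok::<(), birdnet_onnx::Error>(())
    /// ```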
    #[allow(clippy::significant_drop_tightening)]
    pub fn predict(&self, segment: &[f32]) -> Result<PredictionResult> {
        // Validate input size
        let expected = self.inner.config.sample_count;
        if segment.len() != expected {
            return Err(Error::InputSize {
                expected,
                got: segment.len(),
            });
        }

        // Create input tensor [1, sample_count]
        let input_array = Array2::from_shape_vec((1, segment.len()), segment.to_vec())
            .map_err(|e| Error::Inference(format!("failed to create input array: {e}")))?;

        let input_value = Value::from_array(input_array)
            .map_err(|e| Error::Inference(format!("failed to create input tensor: {e}")))?;

        // Run inference with locked session
        // IMPORTANT: Session lock must be held while outputs exist because ort::Value
        // borrows from the session. Dropping the lock before processing outputs would
        // cause a use-after-free. This is why clippy::significant_drop_tightening is
        // suppressed on this method.
        let mut session = self
            .inner
            .session
            .lock()
            .map_err(|e| Error::Inference(format!("session lock poisoned: {e}")))?;

        let outputs = session
            .run(ort::inputs![input_value])
            .map_err(|e| Error::Inference(e.to_string()))?;

        // Process outputs based on model type
        self.process_outputs(&outputs)
    }

    /// Run inference on multiple audio segments (more efficient for GPU)
    ///
    /// # Arguments
    /// * `segments` - Slice of audio segments (all must match `config().sample_count`)
    ///
    /// # Returns
    /// * Vector of `PredictionResult`, one per input segment
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - Any segment size doesn't match expected sample count
    /// - Session lock is poisoned
    /// - ONNX inference fails
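    ///
    /// # Example
    ///
    /// A sketch batching two silence-filled segments (real code would pass
    /// decoded audio):
    ///
    /// ```no_run
    /// use birdnet_onnx::Classifier;
    ///
    /// let classifier = Classifier::builder()
    ///     .model_path("model.onnx")
    ///     .labels_path("labels.txt")
    ///     .build()?;
    ///
    /// let n = classifier.config().sample_count;
    /// let (a, b) = (vec![0.0f32; n], vec![0.0f32; n]);
    /// let results = classifier.predict_batch(&[&a[..], &b[..]])?;
    /// assert_eq!(results.len(), 2);
    /// # Ok::<(), birdnet_onnx::Error>(())
    /// ```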
    #[allow(clippy::significant_drop_tightening)]
    pub fn predict_batch(&self, segments: &[&[f32]]) -> Result<Vec<PredictionResult>> {
        if segments.is_empty() {
            return Ok(Vec::new());
        }

        let expected = self.inner.config.sample_count;

        // Validate all segments
        for (i, seg) in segments.iter().enumerate() {
            if seg.len() != expected {
                return Err(Error::BatchInputSize {
                    index: i,
                    expected,
                    got: seg.len(),
                });
            }
        }

        let batch_size = segments.len();

        // Stack segments into [batch_size, sample_count]
        let mut batch_data = Vec::with_capacity(batch_size * expected);
        for seg in segments {
            batch_data.extend_from_slice(seg);
        }

        let input_array = Array2::from_shape_vec((batch_size, expected), batch_data)
            .map_err(|e| Error::Inference(format!("failed to create batch array: {e}")))?;

        let input_value = Value::from_array(input_array)
            .map_err(|e| Error::Inference(format!("failed to create input tensor: {e}")))?;

        // Run inference with locked session
        // IMPORTANT: Session lock must be held while outputs exist because ort::Value
        // borrows from the session. Dropping the lock before processing outputs would
        // cause a use-after-free. This is why clippy::significant_drop_tightening is
        // suppressed on this method.
        let mut session = self
            .inner
            .session
            .lock()
            .map_err(|e| Error::Inference(format!("session lock poisoned: {e}")))?;

        let outputs = session
            .run(ort::inputs![input_value])
            .map_err(|e| Error::Inference(e.to_string()))?;

        // Process batch outputs
        self.process_batch_outputs(&outputs, batch_size)
    }

    /// Process single inference outputs
    fn process_outputs(&self, outputs: &ort::session::SessionOutputs) -> Result<PredictionResult> {
        let model_type = self.inner.config.model_type;

        let (embeddings, logits) = match model_type {
            ModelType::BirdNetV24 => {
                // Single output: predictions
                let logits = extract_tensor_data(outputs, 0)?;
                (None, logits)
            }
            ModelType::BirdNetV30 => {
                // Two outputs: embeddings at 0, predictions at 1
                let embeddings = extract_tensor_data(outputs, 0)?;
                let logits = extract_tensor_data(outputs, 1)?;
                (Some(embeddings), logits)
            }
            ModelType::PerchV2 => {
                // Four outputs: embedding at 0, spatial_embedding at 1, spectrogram at 2, predictions at 3
                let embeddings = extract_tensor_data(outputs, 0)?;
                let logits = extract_tensor_data(outputs, 3)?;
                (Some(embeddings), logits)
            }
        };

        let predictions = top_k_predictions(
            &logits,
            &self.inner.labels,
            self.inner.top_k,
            self.inner.min_confidence,
        );

        Ok(PredictionResult {
            model_type,
            predictions,
            embeddings,
            raw_scores: logits,
        })
    }

    /// Process batch inference outputs
    fn process_batch_outputs(
        &self,
        outputs: &ort::session::SessionOutputs,
        batch_size: usize,
    ) -> Result<Vec<PredictionResult>> {
        let model_type = self.inner.config.model_type;
        let num_species = self.inner.config.num_species;

        match model_type {
            ModelType::BirdNetV24 => {
                let logits_flat = extract_tensor_data(outputs, 0)?;

                (0..batch_size)
                    .map(|i| {
                        let start = i * num_species;
                        let end = start + num_species;
                        let logits = &logits_flat[start..end];

                        let predictions = top_k_predictions(
                            logits,
                            &self.inner.labels,
                            self.inner.top_k,
                            self.inner.min_confidence,
                        );

                        Ok(PredictionResult {
                            model_type,
                            predictions,
                            embeddings: None,
                            raw_scores: logits.to_vec(),
                        })
                    })
                    .collect()
            }
            ModelType::BirdNetV30 => {
                let embedding_dim = self.inner.config.embedding_dim.ok_or_else(|| {
                    Error::Inference(
                        "embedding_dim missing for model that requires embeddings".into(),
                    )
                })?;
                let emb_flat = extract_tensor_data(outputs, 0)?;
                let logits_flat = extract_tensor_data(outputs, 1)?;

                (0..batch_size)
                    .map(|i| {
                        let emb_start = i * embedding_dim;
                        let emb_end = emb_start + embedding_dim;
                        let embeddings = emb_flat[emb_start..emb_end].to_vec();

                        let logits_start = i * num_species;
                        let logits_end = logits_start + num_species;
                        let logits = &logits_flat[logits_start..logits_end];

                        let predictions = top_k_predictions(
                            logits,
                            &self.inner.labels,
                            self.inner.top_k,
                            self.inner.min_confidence,
                        );

                        Ok(PredictionResult {
                            model_type,
                            predictions,
                            embeddings: Some(embeddings),
                            raw_scores: logits.to_vec(),
                        })
                    })
                    .collect()
            }
            ModelType::PerchV2 => {
                let embedding_dim = self.inner.config.embedding_dim.ok_or_else(|| {
                    Error::Inference(
                        "embedding_dim missing for model that requires embeddings".into(),
                    )
                })?;
                let emb_flat = extract_tensor_data(outputs, 0)?;
                let logits_flat = extract_tensor_data(outputs, 3)?; // predictions at index 3

                (0..batch_size)
                    .map(|i| {
                        let emb_start = i * embedding_dim;
                        let emb_end = emb_start + embedding_dim;
                        let embeddings = emb_flat[emb_start..emb_end].to_vec();

                        let logits_start = i * num_species;
                        let logits_end = logits_start + num_species;
                        let logits = &logits_flat[logits_start..logits_end];

                        let predictions = top_k_predictions(
                            logits,
                            &self.inner.labels,
                            self.inner.top_k,
                            self.inner.min_confidence,
                        );

                        Ok(PredictionResult {
                            model_type,
                            predictions,
                            embeddings: Some(embeddings),
                            raw_scores: logits.to_vec(),
                        })
                    })
                    .collect()
            }
        }
    }
}

/// Extract tensor data from session outputs by index
fn extract_tensor_data(outputs: &ort::session::SessionOutputs, index: usize) -> Result<Vec<f32>> {
    let output_names: Vec<_> = outputs.keys().collect();
    let name = output_names
        .get(index)
        .ok_or_else(|| Error::Inference(format!("missing output tensor at index {index}")))?;

    let tensor = outputs
        .get(*name)
        .ok_or_else(|| Error::Inference(format!("missing output tensor '{name}'")))?;

    let (_, data) = tensor
        .try_extract_tensor::<f32>()
        .map_err(|e| Error::Inference(e.to_string()))?;

    Ok(data.to_vec())
}

#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use super::*;

    // Builder validation tests

    #[test]
    fn test_builder_missing_model_path() {
        let result = ClassifierBuilder::new()
            .labels(vec!["species1".to_string()])
            .build();

        assert!(matches!(result, Err(Error::ModelPathRequired)));
    }

    #[test]
    fn test_builder_missing_labels() {
        let result = ClassifierBuilder::new().model_path("model.onnx").build();

        assert!(matches!(result, Err(Error::LabelsRequired)));
    }

    #[test]
    fn test_builder_missing_both() {
        let result = ClassifierBuilder::new().build();

        // Should fail on model path first
        assert!(matches!(result, Err(Error::ModelPathRequired)));
    }

    #[test]
    fn test_builder_method_chaining() {
        let builder = ClassifierBuilder::new()
            .model_path("model.onnx")
            .labels_path("labels.txt")
            .top_k(5)
            .min_confidence(0.5)
            .model_type(ModelType::BirdNetV24);

        assert_eq!(builder.top_k, 5);
        assert_eq!(builder.min_confidence, Some(0.5));
        assert_eq!(builder.model_type_override, Some(ModelType::BirdNetV24));
    }

    #[test]
    fn test_builder_default_values() {
        let builder = ClassifierBuilder::new();

        assert_eq!(builder.top_k, 10); // Default
        assert_eq!(builder.min_confidence, None);
        assert_eq!(builder.model_type_override, None);
        assert!(builder.execution_providers.is_empty());
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cpu); // Default to CPU
    }

    #[test]
    fn test_builder_top_k_zero() {
        let builder = ClassifierBuilder::new()
            .model_path("model.onnx")
            .labels(vec!["species1".to_string()])
            .top_k(0);

        assert_eq!(builder.top_k, 0);
    }

    #[test]
    fn test_builder_min_confidence_boundaries() {
        // Note: The builder intentionally doesn't validate min_confidence bounds.
        // Values outside [0.0, 1.0] are allowed because:
        // - Validation happens at runtime during filtering, not at build time
        // - This gives users flexibility to set aggressive thresholds
        // - Values >1.0 will filter out all results (sigmoid output is always <1)
        // - Values <0.0 will filter out nothing (sigmoid output is always >0)

        let builder = ClassifierBuilder::new().min_confidence(0.0);
        assert_eq!(builder.min_confidence, Some(0.0));

        let builder = ClassifierBuilder::new().min_confidence(1.0);
        assert_eq!(builder.min_confidence, Some(1.0));

        let builder = ClassifierBuilder::new().min_confidence(1.5);
        assert_eq!(builder.min_confidence, Some(1.5)); // Will filter all results

        let builder = ClassifierBuilder::new().min_confidence(-0.5);
        assert_eq!(builder.min_confidence, Some(-0.5)); // Will filter nothing
    }

    #[test]
    fn test_builder_labels_path_vs_in_memory() {
        let builder1 = ClassifierBuilder::new().labels_path("labels.txt");

        assert!(matches!(builder1.labels, Some(Labels::Path(_))));

        let builder2 = ClassifierBuilder::new().labels(vec!["species1".to_string()]);

        assert!(matches!(builder2.labels, Some(Labels::InMemory(_))));
    }

    #[test]
    fn test_builder_multiple_execution_providers() {
        use ort::execution_providers::CPUExecutionProvider;

        let builder = ClassifierBuilder::new()
            .execution_provider(CPUExecutionProvider::default())
            .execution_provider(CPUExecutionProvider::default());

        assert_eq!(builder.execution_providers.len(), 2);
    }

    #[test]
    fn test_builder_default_trait() {
        let builder1 = ClassifierBuilder::new();
        let builder2 = ClassifierBuilder::default();

        assert_eq!(builder1.top_k, builder2.top_k);
        assert_eq!(builder1.min_confidence, builder2.min_confidence);
    }

    // Input validation tests (these test predict/predict_batch validation logic)

    #[test]
    fn test_mock_input_size_validation() {
        // These tests verify the input size validation logic
        // without actually creating a full classifier

        let expected_size = 144_000; // BirdNetV24 sample count
        let wrong_size = 160_000; // BirdNetV30 sample count

        // Simulate what predict() does for validation
        let segment = vec![0.0f32; wrong_size];
        if segment.len() != expected_size {
            let err = Error::InputSize {
                expected: expected_size,
                got: segment.len(),
            };
            assert!(matches!(err, Error::InputSize { .. }));
        }
    }

    #[test]
    fn test_mock_batch_input_validation() {
        // Test batch input validation logic
        let expected_size = 144_000;
        let segments = [
            vec![0.0f32; expected_size],
            vec![0.0f32; 160_000], // Wrong size
            vec![0.0f32; expected_size],
        ];

        // Simulate batch validation
        for (i, seg) in segments.iter().enumerate() {
            if seg.len() != expected_size {
                let err = Error::BatchInputSize {
                    index: i,
                    expected: expected_size,
                    got: seg.len(),
                };
                assert!(matches!(err, Error::BatchInputSize { index: 1, .. }));
                assert_eq!(i, 1);
                break;
            }
        }
    }

    // Edge case tests

    #[test]
    fn test_empty_batch_handling() {
        // Verify that empty batch returns empty result
        let segments: Vec<&[f32]> = vec![];
        assert!(segments.is_empty());
        // The actual predict_batch method returns Ok(Vec::new()) for empty input
    }

    #[test]
    fn test_labels_enum_debug() {
        let labels_path = Labels::Path("test.txt".to_string());
        let debug_str = format!("{labels_path:?}");
        assert!(debug_str.contains("Path"));

        let labels_mem = Labels::InMemory(vec!["test".to_string()]);
        let debug_str = format!("{labels_mem:?}");
        assert!(debug_str.contains("InMemory"));
    }

    // Execution provider tests

    #[test]
    fn test_requested_provider_defaults_to_cpu() {
        let builder = ClassifierBuilder::new();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cpu);
    }

    #[test]
    fn test_builder_debug_includes_requested_provider() {
        let builder = ClassifierBuilder::new()
            .model_path("test.onnx")
            .labels(vec!["species1".to_string()]);

        let debug_str = format!("{builder:?}");
        assert!(debug_str.contains("requested_provider"));
        assert!(debug_str.contains("Cpu"));
    }

    // Typed builder method tests

    #[test]
    fn test_with_cuda_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_cuda();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cuda);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_tensorrt_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_tensorrt();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::TensorRt);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_tensorrt_config_sets_requested_provider() {
        use crate::TensorRTConfig;

        let config = TensorRTConfig::new();
        let builder = ClassifierBuilder::new().with_tensorrt_config(config);
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::TensorRt);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_tensorrt_config_custom_settings() {
        use crate::TensorRTConfig;

        let config = TensorRTConfig::new()
            .with_fp16(false)
            .with_builder_optimization_level(5)
            .with_device_id(1);

        let builder = ClassifierBuilder::new().with_tensorrt_config(config);
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::TensorRt);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_tensorrt_config_disable_optimizations() {
        use crate::TensorRTConfig;

        let config = TensorRTConfig::new()
            .with_fp16(false)
            .with_cuda_graph(false)
            .with_engine_cache(false)
            .with_timing_cache(false);

        let builder = ClassifierBuilder::new().with_tensorrt_config(config);
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::TensorRt);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_directml_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_directml();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::DirectMl);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_coreml_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_coreml();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::CoreMl);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_rocm_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_rocm();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Rocm);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_openvino_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_openvino();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::OpenVino);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_onednn_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_onednn();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::OneDnn);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_qnn_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_qnn();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Qnn);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_acl_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_acl();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Acl);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_armnn_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_armnn();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::ArmNn);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_chaining_multiple_providers_first_wins() {
        let builder = ClassifierBuilder::new().with_cuda().with_tensorrt();
        // First non-CPU provider wins (aligns with ort's provider priority)
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cuda);
        // Both providers added to the vec
        assert_eq!(builder.execution_providers.len(), 2);
    }

    #[test]
    fn test_chaining_three_providers_first_wins() {
        let builder = ClassifierBuilder::new()
            .with_cuda()
            .with_tensorrt()
            .with_directml();
        // First non-CPU provider wins (aligns with ort's provider priority)
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cuda);
        // All three providers added
        assert_eq!(builder.execution_providers.len(), 3);
    }

    #[test]
    fn test_provider_methods_can_chain_with_other_builders() {
        let builder = ClassifierBuilder::new()
            .model_path("model.onnx")
            .labels_path("labels.txt")
            .with_cuda()
            .top_k(5)
            .min_confidence(0.8);

        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cuda);
        assert_eq!(builder.top_k, 5);
        assert_eq!(builder.min_confidence, Some(0.8));
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_provider_methods_return_self_for_chaining() {
        // Verify each method returns Self and can be chained
        let builder = ClassifierBuilder::new()
            .with_cuda()
            .with_tensorrt()
            .with_directml()
            .with_coreml()
            .with_rocm()
            .with_openvino()
            .with_onednn()
            .with_qnn()
            .with_acl()
            .with_armnn();

        // First non-CPU provider wins (aligns with ort's provider priority)
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cuda);
        // All 10 providers added
        assert_eq!(builder.execution_providers.len(), 10);
    }
}