1use crate::detection::detect_model_type;
4use crate::error::{Error, Result};
5use crate::labels::load_labels_from_file;
6use crate::postprocess::top_k_predictions;
7use crate::types::{ExecutionProviderInfo, ModelConfig, ModelType, PredictionResult};
8use ndarray::Array2;
9use ort::session::Session;
10use ort::value::Value;
11use std::sync::{Arc, Mutex};
12
/// Generates a `with_<provider>` convenience method for the builder.
///
/// The generated method registers `$provider_struct` (with its default
/// settings) as an execution provider and, only when no accelerator has
/// been requested yet (i.e. `requested_provider` is still `Cpu`), records
/// `$provider_enum` as the requested provider — so the first accelerator
/// requested "wins".
macro_rules! with_provider_method {
    ($fn_name:ident, $provider_struct:ident, $provider_enum:ident, $doc:expr) => {
        #[doc = $doc]
        #[must_use]
        pub fn $fn_name(mut self) -> Self {
            use ort::execution_providers::$provider_struct;
            self.execution_providers
                .push($provider_struct::default().into());
            // First-wins: only replace the default CPU marker.
            if self.requested_provider == ExecutionProviderInfo::Cpu {
                self.requested_provider = ExecutionProviderInfo::$provider_enum;
            }
            self
        }
    };
}
33
/// Source of species labels: either a file to be loaded at build time or
/// an already-materialized list supplied by the caller.
#[derive(Debug)]
enum Labels {
    // Path to a labels file; parsed by `load_labels_from_file` in `build()`.
    Path(String),
    // Labels provided directly, one per model output class.
    InMemory(Vec<String>),
}
40
/// Builder for [`Classifier`]. Configure via the chainable methods, then
/// call [`ClassifierBuilder::build`].
#[derive(Debug)]
pub struct ClassifierBuilder {
    // Path to the ONNX model file; required at build time.
    model_path: Option<String>,
    // Label source; required at build time.
    labels: Option<Labels>,
    // Explicit model type, bypassing shape-based auto-detection.
    model_type_override: Option<ModelType>,
    // Providers to register with the session, in registration order.
    execution_providers: Vec<ort::execution_providers::ExecutionProviderDispatch>,
    // First accelerator requested via a `with_*` method; `Cpu` by default.
    requested_provider: ExecutionProviderInfo,
    // Number of top predictions to return per segment (default 10).
    top_k: usize,
    // Optional confidence cutoff; stored as-is, not range-checked.
    min_confidence: Option<f32>,
}
52
53impl Default for ClassifierBuilder {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
impl ClassifierBuilder {
    /// Creates a builder with defaults: CPU execution, `top_k = 10`, and no
    /// confidence threshold.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            model_path: None,
            labels: None,
            model_type_override: None,
            execution_providers: Vec::new(),
            requested_provider: ExecutionProviderInfo::Cpu,
            top_k: 10,
            min_confidence: None,
        }
    }

    /// Sets the path to the ONNX model file (required).
    #[must_use]
    pub fn model_path(mut self, path: impl Into<String>) -> Self {
        self.model_path = Some(path.into());
        self
    }

    /// Sets the path to a labels file, loaded during [`Self::build`].
    ///
    /// Replaces any labels previously set via [`Self::labels`].
    #[must_use]
    pub fn labels_path(mut self, path: impl Into<String>) -> Self {
        self.labels = Some(Labels::Path(path.into()));
        self
    }

    /// Supplies labels directly, one per model output class.
    ///
    /// Replaces any path previously set via [`Self::labels_path`].
    #[must_use]
    pub fn labels(mut self, labels: Vec<String>) -> Self {
        self.labels = Some(Labels::InMemory(labels));
        self
    }

    /// Forces the model type instead of auto-detecting it from tensor shapes.
    #[must_use]
    pub const fn model_type(mut self, model_type: ModelType) -> Self {
        self.model_type_override = Some(model_type);
        self
    }

    /// Registers a custom execution provider.
    ///
    /// Unlike the `with_*` convenience methods, this does not update
    /// `requested_provider`.
    #[must_use]
    pub fn execution_provider(
        mut self,
        provider: impl Into<ort::execution_providers::ExecutionProviderDispatch>,
    ) -> Self {
        self.execution_providers.push(provider.into());
        self
    }

    /// Sets how many top predictions to return per segment (default 10).
    #[must_use]
    pub const fn top_k(mut self, k: usize) -> Self {
        self.top_k = k;
        self
    }

    /// Sets a minimum confidence threshold for returned predictions.
    ///
    /// The value is stored as-is; it is not clamped or range-checked here.
    #[must_use]
    pub const fn min_confidence(mut self, threshold: f32) -> Self {
        self.min_confidence = Some(threshold);
        self
    }

    with_provider_method!(
        with_cuda,
        CUDAExecutionProvider,
        Cuda,
        "Request CUDA execution provider (NVIDIA GPU)"
    );

    /// Requests the `TensorRT` execution provider with default
    /// [`crate::tensorrt_config::TensorRTConfig`] settings.
    #[must_use]
    pub fn with_tensorrt(mut self) -> Self {
        use ort::execution_providers::TensorRTExecutionProvider;

        let config = crate::tensorrt_config::TensorRTConfig::new();
        let provider = config.apply_to(TensorRTExecutionProvider::default());

        self.execution_providers.push(provider.into());

        // First-wins: keep an earlier accelerator request if one was made.
        if self.requested_provider == ExecutionProviderInfo::Cpu {
            self.requested_provider = ExecutionProviderInfo::TensorRt;
        }

        self
    }

    /// Requests the `TensorRT` execution provider with custom settings.
    #[must_use]
    pub fn with_tensorrt_config(mut self, config: crate::tensorrt_config::TensorRTConfig) -> Self {
        use ort::execution_providers::TensorRTExecutionProvider;

        let provider = config.apply_to(TensorRTExecutionProvider::default());
        self.execution_providers.push(provider.into());

        // First-wins: keep an earlier accelerator request if one was made.
        if self.requested_provider == ExecutionProviderInfo::Cpu {
            self.requested_provider = ExecutionProviderInfo::TensorRt;
        }

        self
    }

    // Convenience `with_*` methods for the remaining providers; each
    // registers the provider with its default settings.
    with_provider_method!(
        with_directml,
        DirectMLExecutionProvider,
        DirectMl,
        "Request `DirectML` execution provider (Windows GPU)"
    );
    with_provider_method!(
        with_coreml,
        CoreMLExecutionProvider,
        CoreMl,
        "Request `CoreML` execution provider (Apple Neural Engine)"
    );
    with_provider_method!(
        with_rocm,
        ROCmExecutionProvider,
        Rocm,
        "Request `ROCm` execution provider (AMD GPU)"
    );
    with_provider_method!(
        with_openvino,
        OpenVINOExecutionProvider,
        OpenVino,
        "Request `OpenVINO` execution provider (Intel accelerator)"
    );
    with_provider_method!(
        with_onednn,
        OneDNNExecutionProvider,
        OneDnn,
        "Request oneDNN execution provider (Intel accelerator)"
    );
    with_provider_method!(
        with_qnn,
        QNNExecutionProvider,
        Qnn,
        "Request QNN execution provider (Qualcomm NPU)"
    );
    with_provider_method!(
        with_acl,
        ACLExecutionProvider,
        Acl,
        "Request ACL execution provider (Arm Compute Library)"
    );
    with_provider_method!(
        with_armnn,
        ArmNNExecutionProvider,
        ArmNn,
        "Request `ArmNN` execution provider (Arm Neural Network)"
    );

    /// Loads the model, detects its type, loads and validates labels, and
    /// constructs the [`Classifier`].
    ///
    /// # Errors
    ///
    /// - [`Error::ModelPathRequired`] / [`Error::LabelsRequired`] if a
    ///   required input was not set.
    /// - [`Error::ModelLoad`] if session creation or model loading fails.
    /// - [`Error::ModelDetection`] if input/output shapes cannot be read.
    /// - [`Error::LabelCount`] if the label count does not match the model.
    pub fn build(self) -> Result<Classifier> {
        let model_path = self.model_path.ok_or(Error::ModelPathRequired)?;
        let labels_source = self.labels.ok_or(Error::LabelsRequired)?;

        let mut session_builder = Session::builder().map_err(Error::ModelLoad)?;

        // Register providers one at a time, preserving registration order.
        for provider in self.execution_providers {
            session_builder = session_builder
                .with_execution_providers([provider])
                .map_err(Error::ModelLoad)?;
        }

        let session = session_builder
            .commit_from_file(&model_path)
            .map_err(Error::ModelLoad)?;

        let input_shape = extract_input_shape(&session)?;
        let output_shapes = extract_output_shapes(&session)?;

        let config = detect_model_type(&input_shape, &output_shapes, self.model_type_override)?;

        let labels = match labels_source {
            Labels::Path(path) => load_labels_from_file(&path, config.model_type)?,
            Labels::InMemory(labels) => labels,
        };

        // The model's output width must match the label count exactly.
        if labels.len() != config.num_species {
            return Err(Error::LabelCount {
                expected: config.num_species,
                got: labels.len(),
            });
        }

        Ok(Classifier {
            inner: Arc::new(ClassifierInner {
                session: Mutex::new(session),
                config,
                labels,
                requested_provider: self.requested_provider,
                top_k: self.top_k,
                min_confidence: self.min_confidence,
            }),
        })
    }
}
332
333fn extract_input_shape(session: &Session) -> Result<Vec<i64>> {
335 let inputs = session
336 .inputs
337 .first()
338 .ok_or_else(|| Error::ModelDetection {
339 reason: "model has no inputs".to_string(),
340 })?;
341
342 let shape = inputs
343 .input_type
344 .tensor_shape()
345 .ok_or_else(|| Error::ModelDetection {
346 reason: "input is not a tensor".to_string(),
347 })?;
348
349 Ok(shape.iter().copied().collect())
350}
351
352fn extract_output_shapes(session: &Session) -> Result<Vec<Vec<i64>>> {
354 session
355 .outputs
356 .iter()
357 .map(|output| {
358 let shape = output
359 .output_type
360 .tensor_shape()
361 .ok_or_else(|| Error::ModelDetection {
362 reason: "output is not a tensor".to_string(),
363 })?;
364 Ok(shape.iter().copied().collect())
365 })
366 .collect()
367}
368
/// Shared state behind a [`Classifier`]; held in an `Arc` so clones are cheap.
struct ClassifierInner {
    // The ONNX Runtime session; Mutex-guarded because inference needs
    // mutable access and clones of `Classifier` share this instance.
    session: Mutex<Session>,
    // Detected (or explicitly overridden) model configuration.
    config: ModelConfig,
    // Species labels, one per model output class.
    labels: Vec<String>,
    // Provider requested at build time — NOTE(review): this records the
    // request, not necessarily the provider ONNX Runtime actually selected.
    requested_provider: ExecutionProviderInfo,
    // Number of top predictions returned per segment.
    top_k: usize,
    // Optional minimum-confidence filter applied to predictions.
    min_confidence: Option<f32>,
}
378
/// Handle to a loaded classification model.
///
/// Cloning is cheap: clones share the same session and configuration via
/// an internal `Arc`.
#[derive(Clone)]
pub struct Classifier {
    inner: Arc<ClassifierInner>,
}
386
387impl std::fmt::Debug for Classifier {
388 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389 f.debug_struct("Classifier")
390 .field("config", &self.inner.config)
391 .field("labels_count", &self.inner.labels.len())
392 .field("requested_provider", &self.inner.requested_provider)
393 .field("top_k", &self.inner.top_k)
394 .field("min_confidence", &self.inner.min_confidence)
395 .finish_non_exhaustive()
396 }
397}
398
399impl Classifier {
400 #[must_use]
402 pub const fn builder() -> ClassifierBuilder {
403 ClassifierBuilder::new()
404 }
405
406 #[must_use]
408 pub fn config(&self) -> &ModelConfig {
409 &self.inner.config
410 }
411
412 #[must_use]
414 pub fn labels(&self) -> &[String] {
415 &self.inner.labels
416 }
417
418 #[must_use]
431 pub fn requested_provider(&self) -> ExecutionProviderInfo {
432 self.inner.requested_provider
433 }
434
435 #[allow(clippy::significant_drop_tightening)]
450 pub fn predict(&self, segment: &[f32]) -> Result<PredictionResult> {
451 let expected = self.inner.config.sample_count;
453 if segment.len() != expected {
454 return Err(Error::InputSize {
455 expected,
456 got: segment.len(),
457 });
458 }
459
460 let input_array = Array2::from_shape_vec((1, segment.len()), segment.to_vec())
462 .map_err(|e| Error::Inference(format!("failed to create input array: {e}")))?;
463
464 let input_value = Value::from_array(input_array)
465 .map_err(|e| Error::Inference(format!("failed to create input tensor: {e}")))?;
466
467 let mut session = self
473 .inner
474 .session
475 .lock()
476 .map_err(|e| Error::Inference(format!("session lock poisoned: {e}")))?;
477
478 let outputs = session
479 .run(ort::inputs![input_value])
480 .map_err(|e| Error::Inference(e.to_string()))?;
481
482 self.process_outputs(&outputs)
484 }
485
486 #[allow(clippy::significant_drop_tightening)]
501 pub fn predict_batch(&self, segments: &[&[f32]]) -> Result<Vec<PredictionResult>> {
502 if segments.is_empty() {
503 return Ok(Vec::new());
504 }
505
506 let expected = self.inner.config.sample_count;
507
508 for (i, seg) in segments.iter().enumerate() {
510 if seg.len() != expected {
511 return Err(Error::BatchInputSize {
512 index: i,
513 expected,
514 got: seg.len(),
515 });
516 }
517 }
518
519 let batch_size = segments.len();
520
521 let mut batch_data = Vec::with_capacity(batch_size * expected);
523 for seg in segments {
524 batch_data.extend_from_slice(seg);
525 }
526
527 let input_array = Array2::from_shape_vec((batch_size, expected), batch_data)
528 .map_err(|e| Error::Inference(format!("failed to create batch array: {e}")))?;
529
530 let input_value = Value::from_array(input_array)
531 .map_err(|e| Error::Inference(format!("failed to create input tensor: {e}")))?;
532
533 let mut session = self
539 .inner
540 .session
541 .lock()
542 .map_err(|e| Error::Inference(format!("session lock poisoned: {e}")))?;
543
544 let outputs = session
545 .run(ort::inputs![input_value])
546 .map_err(|e| Error::Inference(e.to_string()))?;
547
548 self.process_batch_outputs(&outputs, batch_size)
550 }
551
552 fn process_outputs(&self, outputs: &ort::session::SessionOutputs) -> Result<PredictionResult> {
554 let model_type = self.inner.config.model_type;
555
556 let (embeddings, logits) = match model_type {
557 ModelType::BirdNetV24 => {
558 let logits = extract_tensor_data(outputs, 0)?;
560 (None, logits)
561 }
562 ModelType::BirdNetV30 => {
563 let embeddings = extract_tensor_data(outputs, 0)?;
565 let logits = extract_tensor_data(outputs, 1)?;
566 (Some(embeddings), logits)
567 }
568 ModelType::PerchV2 => {
569 let embeddings = extract_tensor_data(outputs, 0)?;
571 let logits = extract_tensor_data(outputs, 3)?;
572 (Some(embeddings), logits)
573 }
574 };
575
576 let predictions = top_k_predictions(
577 &logits,
578 &self.inner.labels,
579 self.inner.top_k,
580 self.inner.min_confidence,
581 );
582
583 Ok(PredictionResult {
584 model_type,
585 predictions,
586 embeddings,
587 raw_scores: logits,
588 })
589 }
590
591 fn process_batch_outputs(
593 &self,
594 outputs: &ort::session::SessionOutputs,
595 batch_size: usize,
596 ) -> Result<Vec<PredictionResult>> {
597 let model_type = self.inner.config.model_type;
598 let num_species = self.inner.config.num_species;
599
600 match model_type {
601 ModelType::BirdNetV24 => {
602 let logits_flat = extract_tensor_data(outputs, 0)?;
603
604 (0..batch_size)
605 .map(|i| {
606 let start = i * num_species;
607 let end = start + num_species;
608 let logits = &logits_flat[start..end];
609
610 let predictions = top_k_predictions(
611 logits,
612 &self.inner.labels,
613 self.inner.top_k,
614 self.inner.min_confidence,
615 );
616
617 Ok(PredictionResult {
618 model_type,
619 predictions,
620 embeddings: None,
621 raw_scores: logits.to_vec(),
622 })
623 })
624 .collect()
625 }
626 ModelType::BirdNetV30 => {
627 let embedding_dim = self.inner.config.embedding_dim.ok_or_else(|| {
628 Error::Inference(
629 "embedding_dim missing for model that requires embeddings".into(),
630 )
631 })?;
632 let emb_flat = extract_tensor_data(outputs, 0)?;
633 let logits_flat = extract_tensor_data(outputs, 1)?;
634
635 (0..batch_size)
636 .map(|i| {
637 let emb_start = i * embedding_dim;
638 let emb_end = emb_start + embedding_dim;
639 let embeddings = emb_flat[emb_start..emb_end].to_vec();
640
641 let logits_start = i * num_species;
642 let logits_end = logits_start + num_species;
643 let logits = &logits_flat[logits_start..logits_end];
644
645 let predictions = top_k_predictions(
646 logits,
647 &self.inner.labels,
648 self.inner.top_k,
649 self.inner.min_confidence,
650 );
651
652 Ok(PredictionResult {
653 model_type,
654 predictions,
655 embeddings: Some(embeddings),
656 raw_scores: logits.to_vec(),
657 })
658 })
659 .collect()
660 }
661 ModelType::PerchV2 => {
662 let embedding_dim = self.inner.config.embedding_dim.ok_or_else(|| {
663 Error::Inference(
664 "embedding_dim missing for model that requires embeddings".into(),
665 )
666 })?;
667 let emb_flat = extract_tensor_data(outputs, 0)?;
668 let logits_flat = extract_tensor_data(outputs, 3)?; (0..batch_size)
671 .map(|i| {
672 let emb_start = i * embedding_dim;
673 let emb_end = emb_start + embedding_dim;
674 let embeddings = emb_flat[emb_start..emb_end].to_vec();
675
676 let logits_start = i * num_species;
677 let logits_end = logits_start + num_species;
678 let logits = &logits_flat[logits_start..logits_end];
679
680 let predictions = top_k_predictions(
681 logits,
682 &self.inner.labels,
683 self.inner.top_k,
684 self.inner.min_confidence,
685 );
686
687 Ok(PredictionResult {
688 model_type,
689 predictions,
690 embeddings: Some(embeddings),
691 raw_scores: logits.to_vec(),
692 })
693 })
694 .collect()
695 }
696 }
697 }
698}
699
700fn extract_tensor_data(outputs: &ort::session::SessionOutputs, index: usize) -> Result<Vec<f32>> {
702 let output_names: Vec<_> = outputs.keys().collect();
703 let name = output_names
704 .get(index)
705 .ok_or_else(|| Error::Inference(format!("missing output tensor at index {index}")))?;
706
707 let tensor = outputs
708 .get(*name)
709 .ok_or_else(|| Error::Inference(format!("missing output tensor '{name}'")))?;
710
711 let (_, data) = tensor
712 .try_extract_tensor::<f32>()
713 .map_err(|e| Error::Inference(e.to_string()))?;
714
715 Ok(data.to_vec())
716}
717
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use super::*;

    #[test]
    fn test_builder_missing_model_path() {
        let result = ClassifierBuilder::new()
            .labels(vec!["species1".to_string()])
            .build();

        assert!(matches!(result, Err(Error::ModelPathRequired)));
    }

    #[test]
    fn test_builder_missing_labels() {
        let result = ClassifierBuilder::new().model_path("model.onnx").build();

        assert!(matches!(result, Err(Error::LabelsRequired)));
    }

    #[test]
    fn test_builder_missing_both() {
        let result = ClassifierBuilder::new().build();

        // The model-path check runs first, so its error takes precedence.
        assert!(matches!(result, Err(Error::ModelPathRequired)));
    }

    #[test]
    fn test_builder_method_chaining() {
        let builder = ClassifierBuilder::new()
            .model_path("model.onnx")
            .labels_path("labels.txt")
            .top_k(5)
            .min_confidence(0.5)
            .model_type(ModelType::BirdNetV24);

        assert_eq!(builder.top_k, 5);
        assert_eq!(builder.min_confidence, Some(0.5));
        assert_eq!(builder.model_type_override, Some(ModelType::BirdNetV24));
    }

    #[test]
    fn test_builder_default_values() {
        let builder = ClassifierBuilder::new();

        assert_eq!(builder.top_k, 10);
        assert_eq!(builder.min_confidence, None);
        assert_eq!(builder.model_type_override, None);
        assert!(builder.execution_providers.is_empty());
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cpu);
    }

    #[test]
    fn test_builder_top_k_zero() {
        let builder = ClassifierBuilder::new()
            .model_path("model.onnx")
            .labels(vec!["species1".to_string()])
            .top_k(0);

        assert_eq!(builder.top_k, 0);
    }

    #[test]
    fn test_builder_min_confidence_boundaries() {
        let builder = ClassifierBuilder::new().min_confidence(0.0);
        assert_eq!(builder.min_confidence, Some(0.0));

        let builder = ClassifierBuilder::new().min_confidence(1.0);
        assert_eq!(builder.min_confidence, Some(1.0));

        // Out-of-range values are stored as-is: the builder does not clamp.
        let builder = ClassifierBuilder::new().min_confidence(1.5);
        assert_eq!(builder.min_confidence, Some(1.5));

        let builder = ClassifierBuilder::new().min_confidence(-0.5);
        assert_eq!(builder.min_confidence, Some(-0.5));
    }

    #[test]
    fn test_builder_labels_path_vs_in_memory() {
        let builder1 = ClassifierBuilder::new().labels_path("labels.txt");

        assert!(matches!(builder1.labels, Some(Labels::Path(_))));

        let builder2 = ClassifierBuilder::new().labels(vec!["species1".to_string()]);

        assert!(matches!(builder2.labels, Some(Labels::InMemory(_))));
    }

    #[test]
    fn test_builder_multiple_execution_providers() {
        use ort::execution_providers::CPUExecutionProvider;

        let builder = ClassifierBuilder::new()
            .execution_provider(CPUExecutionProvider::default())
            .execution_provider(CPUExecutionProvider::default());

        assert_eq!(builder.execution_providers.len(), 2);
    }

    #[test]
    fn test_builder_default_trait() {
        let builder1 = ClassifierBuilder::new();
        let builder2 = ClassifierBuilder::default();

        assert_eq!(builder1.top_k, builder2.top_k);
        assert_eq!(builder1.min_confidence, builder2.min_confidence);
    }

    #[test]
    fn test_mock_input_size_validation() {
        let expected_size = 144_000;
        let wrong_size = 160_000;
        let segment = vec![0.0f32; wrong_size];

        if segment.len() != expected_size {
            let err = Error::InputSize {
                expected: expected_size,
                got: segment.len(),
            };
            assert!(matches!(err, Error::InputSize { .. }));
        }
    }

    #[test]
    fn test_mock_batch_input_validation() {
        let expected_size = 144_000;
        let segments = [
            vec![0.0f32; expected_size],
            vec![0.0f32; 160_000], // wrong-sized segment at index 1
            vec![0.0f32; expected_size],
        ];

        for (i, seg) in segments.iter().enumerate() {
            if seg.len() != expected_size {
                let err = Error::BatchInputSize {
                    index: i,
                    expected: expected_size,
                    got: seg.len(),
                };
                assert!(matches!(err, Error::BatchInputSize { index: 1, .. }));
                assert_eq!(i, 1);
                break;
            }
        }
    }

    #[test]
    fn test_empty_batch_handling() {
        let segments: Vec<&[f32]> = vec![];
        assert!(segments.is_empty());
    }

    #[test]
    fn test_labels_enum_debug() {
        let labels_path = Labels::Path("test.txt".to_string());
        let debug_str = format!("{labels_path:?}");
        assert!(debug_str.contains("Path"));

        let labels_mem = Labels::InMemory(vec!["test".to_string()]);
        let debug_str = format!("{labels_mem:?}");
        assert!(debug_str.contains("InMemory"));
    }

    #[test]
    fn test_requested_provider_defaults_to_cpu() {
        let builder = ClassifierBuilder::new();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cpu);
    }

    #[test]
    fn test_builder_debug_includes_requested_provider() {
        let builder = ClassifierBuilder::new()
            .model_path("test.onnx")
            .labels(vec!["species1".to_string()]);

        let debug_str = format!("{builder:?}");
        assert!(debug_str.contains("requested_provider"));
        assert!(debug_str.contains("Cpu"));
    }

    #[test]
    fn test_with_cuda_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_cuda();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cuda);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_tensorrt_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_tensorrt();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::TensorRt);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_tensorrt_config_sets_requested_provider() {
        use crate::TensorRTConfig;

        let config = TensorRTConfig::new();
        let builder = ClassifierBuilder::new().with_tensorrt_config(config);
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::TensorRt);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_tensorrt_config_custom_settings() {
        use crate::TensorRTConfig;

        let config = TensorRTConfig::new()
            .with_fp16(false)
            .with_builder_optimization_level(5)
            .with_device_id(1);

        let builder = ClassifierBuilder::new().with_tensorrt_config(config);
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::TensorRt);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_tensorrt_config_disable_optimizations() {
        use crate::TensorRTConfig;

        let config = TensorRTConfig::new()
            .with_fp16(false)
            .with_cuda_graph(false)
            .with_engine_cache(false)
            .with_timing_cache(false);

        let builder = ClassifierBuilder::new().with_tensorrt_config(config);
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::TensorRt);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_directml_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_directml();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::DirectMl);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_coreml_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_coreml();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::CoreMl);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_rocm_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_rocm();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Rocm);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_openvino_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_openvino();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::OpenVino);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_onednn_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_onednn();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::OneDnn);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_qnn_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_qnn();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Qnn);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_acl_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_acl();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Acl);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_with_armnn_sets_requested_provider() {
        let builder = ClassifierBuilder::new().with_armnn();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::ArmNn);
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_chaining_multiple_providers_first_wins() {
        let builder = ClassifierBuilder::new().with_cuda().with_tensorrt();
        // The first requested accelerator wins; both providers are registered.
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cuda);
        assert_eq!(builder.execution_providers.len(), 2);
    }

    #[test]
    fn test_chaining_three_providers_first_wins() {
        let builder = ClassifierBuilder::new()
            .with_cuda()
            .with_tensorrt()
            .with_directml();
        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cuda);
        assert_eq!(builder.execution_providers.len(), 3);
    }

    #[test]
    fn test_provider_methods_can_chain_with_other_builders() {
        let builder = ClassifierBuilder::new()
            .model_path("model.onnx")
            .labels_path("labels.txt")
            .with_cuda()
            .top_k(5)
            .min_confidence(0.8);

        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cuda);
        assert_eq!(builder.top_k, 5);
        assert_eq!(builder.min_confidence, Some(0.8));
        assert_eq!(builder.execution_providers.len(), 1);
    }

    #[test]
    fn test_provider_methods_return_self_for_chaining() {
        let builder = ClassifierBuilder::new()
            .with_cuda()
            .with_tensorrt()
            .with_directml()
            .with_coreml()
            .with_rocm()
            .with_openvino()
            .with_onednn()
            .with_qnn()
            .with_acl()
            .with_armnn();

        assert_eq!(builder.requested_provider, ExecutionProviderInfo::Cuda);
        assert_eq!(builder.execution_providers.len(), 10);
    }
}