1use std::collections::BTreeMap;
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
6use crate::ModelPreset;
7
/// URL of the cuda-oxide documentation.
///
/// Used as the default `documentation_url` of [`CudaOxideRuntimeConfig`].
pub const CUDA_OXIDE_BOOK_URL: &str = "https://nvlabs.github.io/cuda-oxide/index.html";
10
11#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum ModelTask {
15 ObjectDetection,
17 PoseEstimation2d,
19 PoseLifting3d,
21 ImageClassification,
23 ImageSegmentation,
25 ImageEmbedding,
27 FaceDetection,
29 FaceEmbedding,
31 Ocr,
33 AudioClassification,
35 AudioEventDetection,
37 AudioEmbedding,
39 SpeechRecognition,
41 SpeakerDiarization,
43 SourceSeparation,
45 AudioGeneration,
47 SpeakerConditionedTts,
49 TextClassification,
51 TokenClassification,
53 ZeroShotClassification,
55 TextEmbedding,
57 Summarization,
59 Reranking,
61 QuestionAnswering,
63 MultimodalEmbedding,
65 Custom(String),
67}
68
69impl ModelTask {
70 pub fn default_label(&self) -> &'static str {
72 match self {
73 Self::ObjectDetection => "object",
74 Self::PoseEstimation2d => "pose_2d",
75 Self::PoseLifting3d => "pose_3d",
76 Self::ImageClassification => "scene",
77 Self::ImageSegmentation => "mask",
78 Self::ImageEmbedding => "image_embedding",
79 Self::FaceDetection => "face",
80 Self::FaceEmbedding => "face_embedding",
81 Self::Ocr => "ocr",
82 Self::AudioClassification => "audio_class",
83 Self::AudioEventDetection => "audio_event",
84 Self::AudioEmbedding => "audio_embedding",
85 Self::SpeechRecognition => "speech",
86 Self::SpeakerDiarization => "speaker",
87 Self::SourceSeparation => "stem",
88 Self::AudioGeneration => "audio_generation",
89 Self::SpeakerConditionedTts => "speaker_conditioned_tts",
90 Self::TextClassification => "semantic",
91 Self::TokenClassification => "token",
92 Self::ZeroShotClassification => "zero_shot",
93 Self::TextEmbedding => "embedding",
94 Self::Summarization => "summary",
95 Self::Reranking => "reranking",
96 Self::QuestionAnswering => "question_answering",
97 Self::MultimodalEmbedding => "multimodal_embedding",
98 Self::Custom(_) => "custom",
99 }
100 }
101
102 pub fn as_protocol_str(&self) -> &str {
104 match self {
105 Self::ObjectDetection => "object_detection",
106 Self::PoseEstimation2d => "pose_estimation_2d",
107 Self::PoseLifting3d => "pose_lifting_3d",
108 Self::ImageClassification => "image_classification",
109 Self::ImageSegmentation => "image_segmentation",
110 Self::ImageEmbedding => "image_embedding",
111 Self::FaceDetection => "face_detection",
112 Self::FaceEmbedding => "face_embedding",
113 Self::Ocr => "ocr",
114 Self::AudioClassification => "audio_classification",
115 Self::AudioEventDetection => "audio_event_detection",
116 Self::AudioEmbedding => "audio_embedding",
117 Self::SpeechRecognition => "speech_recognition",
118 Self::SpeakerDiarization => "speaker_diarization",
119 Self::SourceSeparation => "source_separation",
120 Self::AudioGeneration => "audio_generation",
121 Self::SpeakerConditionedTts => "speaker_conditioned_tts",
122 Self::TextClassification => "text_classification",
123 Self::TokenClassification => "token_classification",
124 Self::ZeroShotClassification => "zero_shot_classification",
125 Self::TextEmbedding => "text_embedding",
126 Self::Summarization => "summarization",
127 Self::Reranking => "reranking",
128 Self::QuestionAnswering => "question_answering",
129 Self::MultimodalEmbedding => "multimodal_embedding",
130 Self::Custom(kind) => kind.as_str(),
131 }
132 }
133}
134
135#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
136#[serde(rename_all = "snake_case")]
137pub enum ModelFileRequest {
139 Required(String),
141 Optional(String),
143 FirstAvailable(Vec<String>),
145}
146
147impl ModelFileRequest {
148 pub fn required(path: impl Into<String>) -> Self {
150 Self::Required(path.into())
151 }
152
153 pub fn optional(path: impl Into<String>) -> Self {
155 Self::Optional(path.into())
156 }
157
158 pub fn first_available(paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
160 Self::FirstAvailable(paths.into_iter().map(Into::into).collect())
161 }
162}
163
164#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
166#[serde(rename_all = "snake_case", tag = "kind")]
167pub enum ModelSource {
168 HuggingFace { repo_id: String, revision: String },
170 LocalPath { path: PathBuf },
172 ExternalCommand { command: PathBuf },
174 ComfyUiInventory { root: PathBuf },
176 Custom(String),
178}
179
180impl ModelSource {
181 pub fn kind(&self) -> &str {
183 match self {
184 Self::HuggingFace { .. } => "hugging_face",
185 Self::LocalPath { .. } => "local_path",
186 Self::ExternalCommand { .. } => "external_command",
187 Self::ComfyUiInventory { .. } => "comfyui_inventory",
188 Self::Custom(kind) => kind.as_str(),
189 }
190 }
191}
192
193#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
194pub struct ModelSpec {
196 pub name: String,
198 pub task: ModelTask,
200 pub source: ModelSource,
202 pub files: Vec<ModelFileRequest>,
204 #[serde(default)]
206 pub metadata: BTreeMap<String, String>,
207 #[serde(default)]
209 #[deprecated(note = "use ModelSpec::source or ModelSpec::repo_id() instead")]
210 pub repo_id: String,
211 #[serde(default)]
213 #[deprecated(note = "use ModelSpec::source or ModelSpec::revision() instead")]
214 pub revision: String,
215}
216
217pub type HuggingFaceModelSpec = ModelSpec;
219
220#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
221#[serde(rename_all = "snake_case")]
222pub enum ModelRuntimeBackend {
224 Cpu,
226 Onnx,
228 Candle,
230 CudaOxide,
232 WhisperCpp,
234 Demucs,
236 OpenCv,
238 ComfyUi,
240 External,
242 Heuristic,
244 Imported,
246 Custom(String),
248}
249
250impl ModelRuntimeBackend {
251 pub fn as_str(&self) -> &str {
253 match self {
254 Self::Cpu => "cpu",
255 Self::Onnx => "onnx",
256 Self::Candle => "candle",
257 Self::CudaOxide => "cuda_oxide",
258 Self::WhisperCpp => "whisper_cpp",
259 Self::Demucs => "demucs",
260 Self::OpenCv => "opencv",
261 Self::ComfyUi => "comfyui",
262 Self::External => "external",
263 Self::Heuristic => "heuristic",
264 Self::Imported => "imported",
265 Self::Custom(value) => value.as_str(),
266 }
267 }
268}
269
270#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
271#[serde(rename_all = "snake_case")]
272pub enum RuntimePreference {
274 Auto,
276 Native,
278 External,
280 Heuristic,
282 Imported,
284}
285
286#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
287#[serde(rename_all = "snake_case")]
288pub enum FallbackPolicy {
290 #[default]
292 Error,
293 FastFallback,
295 HeuristicFallback,
297}
298
299#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
300#[serde(rename_all = "camelCase")]
301pub struct ModelRuntimeSelection {
303 #[serde(default)]
305 pub model: Option<ModelSpec>,
306 #[serde(default)]
308 pub backend: Option<ModelRuntimeBackend>,
309 #[serde(default)]
311 pub bundle_dir: Option<PathBuf>,
312 #[serde(default)]
314 pub fallback_policy: FallbackPolicy,
315}
316
317#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
318pub struct CudaOxideRuntimeConfig {
320 pub device_index: u32,
322 pub target_sm: Option<String>,
324 pub cargo_oxide_command: String,
326 pub documentation_url: String,
328}
329
330impl Default for CudaOxideRuntimeConfig {
331 fn default() -> Self {
332 Self {
333 device_index: 0,
334 target_sm: None,
335 cargo_oxide_command: "cargo oxide".to_string(),
336 documentation_url: CUDA_OXIDE_BOOK_URL.to_string(),
337 }
338 }
339}
340
341impl CudaOxideRuntimeConfig {
342 pub fn new() -> Self {
344 Self::default()
345 }
346
347 pub fn device_index(mut self, value: u32) -> Self {
349 self.device_index = value;
350 self
351 }
352
353 pub fn target_sm(mut self, value: impl Into<String>) -> Self {
355 self.target_sm = Some(value.into());
356 self
357 }
358
359 pub fn cargo_oxide_command(mut self, value: impl Into<String>) -> Self {
361 self.cargo_oxide_command = value.into();
362 self
363 }
364
365 pub fn attributes(&self) -> BTreeMap<String, String> {
367 let mut attributes = BTreeMap::new();
368 attributes.insert(
369 "runtime.backend".to_string(),
370 ModelRuntimeBackend::CudaOxide.as_str().to_string(),
371 );
372 attributes.insert(
373 "runtime.cuda.device_index".to_string(),
374 self.device_index.to_string(),
375 );
376 attributes.insert(
377 "runtime.cuda_oxide.command".to_string(),
378 self.cargo_oxide_command.clone(),
379 );
380 attributes.insert(
381 "runtime.cuda_oxide.docs".to_string(),
382 self.documentation_url.clone(),
383 );
384 if let Some(target_sm) = &self.target_sm {
385 attributes.insert("runtime.cuda.target_sm".to_string(), target_sm.clone());
386 }
387 attributes
388 }
389}
390
391#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
392pub struct CudaOxideModelPlan {
394 pub spec: ModelSpec,
396 pub runtime: CudaOxideRuntimeConfig,
398 pub module_name: String,
400 pub kernel_names: Vec<String>,
402}
403
404impl CudaOxideModelPlan {
405 pub fn new(
407 spec: ModelSpec,
408 module_name: impl Into<String>,
409 kernel_names: impl IntoIterator<Item = impl Into<String>>,
410 ) -> Self {
411 Self {
412 spec,
413 runtime: CudaOxideRuntimeConfig::default(),
414 module_name: module_name.into(),
415 kernel_names: kernel_names.into_iter().map(Into::into).collect(),
416 }
417 }
418
419 pub fn runtime(mut self, runtime: CudaOxideRuntimeConfig) -> Self {
421 self.runtime = runtime;
422 self
423 }
424
425 pub fn attributes(&self) -> BTreeMap<String, String> {
427 let mut attributes = self.runtime.attributes();
428 attributes.insert(
429 "runtime.cuda_oxide.module".to_string(),
430 self.module_name.clone(),
431 );
432 if !self.kernel_names.is_empty() {
433 attributes.insert(
434 "runtime.cuda_oxide.kernels".to_string(),
435 self.kernel_names.join(","),
436 );
437 }
438 attributes
439 }
440}
441
442#[allow(deprecated)]
443impl ModelSpec {
444 pub fn new(repo_id: impl Into<String>, task: ModelTask) -> Self {
446 let repo_id = repo_id.into();
447 let revision = "main".to_string();
448 Self {
449 name: repo_id.clone(),
450 task,
451 source: ModelSource::HuggingFace {
452 repo_id: repo_id.clone(),
453 revision: revision.clone(),
454 },
455 files: Vec::new(),
456 metadata: BTreeMap::new(),
457 repo_id,
458 revision,
459 }
460 }
461
462 pub fn from_source(name: impl Into<String>, task: ModelTask, source: ModelSource) -> Self {
464 let name = name.into();
465 let (repo_id, revision) = match &source {
466 ModelSource::HuggingFace { repo_id, revision } => (repo_id.clone(), revision.clone()),
467 _ => (String::new(), String::new()),
468 };
469 Self {
470 name,
471 task,
472 source,
473 files: Vec::new(),
474 metadata: BTreeMap::new(),
475 repo_id,
476 revision,
477 }
478 }
479
480 pub fn repo_id_value(&self) -> Option<&str> {
482 match &self.source {
483 ModelSource::HuggingFace { repo_id, .. } => Some(repo_id.as_str()),
484 _ if !self.repo_id.is_empty() => Some(self.repo_id.as_str()),
485 _ => None,
486 }
487 }
488
489 pub fn revision_value(&self) -> Option<&str> {
491 match &self.source {
492 ModelSource::HuggingFace { revision, .. } => Some(revision.as_str()),
493 _ if !self.revision.is_empty() => Some(self.revision.as_str()),
494 _ => None,
495 }
496 }
497
498 pub fn safe_name(&self) -> String {
500 self.name
501 .chars()
502 .map(|ch| {
503 if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.') {
504 ch
505 } else {
506 '_'
507 }
508 })
509 .collect()
510 }
511
512 pub fn from_preset(preset: ModelPreset) -> Self {
514 preset.spec()
515 }
516
517 pub fn name(mut self, value: impl Into<String>) -> Self {
519 self.name = value.into();
520 self
521 }
522
523 pub fn revision(mut self, value: impl Into<String>) -> Self {
525 let revision = value.into();
526 if let ModelSource::HuggingFace {
527 revision: source_revision,
528 ..
529 } = &mut self.source
530 {
531 *source_revision = revision.clone();
532 }
533 self.revision = revision;
534 self
535 }
536
537 pub fn file(mut self, path: impl Into<String>) -> Self {
539 self.files.push(ModelFileRequest::required(path));
540 self
541 }
542
543 pub fn optional_file(mut self, path: impl Into<String>) -> Self {
545 self.files.push(ModelFileRequest::optional(path));
546 self
547 }
548
549 pub fn first_available_file(
551 mut self,
552 paths: impl IntoIterator<Item = impl Into<String>>,
553 ) -> Self {
554 self.files.push(ModelFileRequest::first_available(paths));
555 self
556 }
557
558 pub fn cuda_oxide_plan(
560 self,
561 module_name: impl Into<String>,
562 kernel_names: impl IntoIterator<Item = impl Into<String>>,
563 ) -> CudaOxideModelPlan {
564 CudaOxideModelPlan::new(self, module_name, kernel_names)
565 }
566}