Skip to main content

model_runtime/
spec.rs

1use std::collections::BTreeMap;
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
6use crate::ModelPreset;
7
/// Entry page of the cuda-oxide book, which documents the cuda-oxide compiler and runtime stack.
pub const CUDA_OXIDE_BOOK_URL: &str = "https://nvlabs.github.io/cuda-oxide/index.html";
10
11#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13/// Variants describing model task.
14pub enum ModelTask {
15    /// The object detection variant.
16    ObjectDetection,
17    /// The pose estimation2d variant.
18    PoseEstimation2d,
19    /// The pose lifting3d variant.
20    PoseLifting3d,
21    /// The image classification variant.
22    ImageClassification,
23    /// The image segmentation variant.
24    ImageSegmentation,
25    /// The image embedding variant.
26    ImageEmbedding,
27    /// The face detection variant.
28    FaceDetection,
29    /// The face embedding variant.
30    FaceEmbedding,
31    /// The OCR variant.
32    Ocr,
33    /// The audio classification variant.
34    AudioClassification,
35    /// The audio event detection variant.
36    AudioEventDetection,
37    /// The audio embedding variant.
38    AudioEmbedding,
39    /// The speech recognition variant.
40    SpeechRecognition,
41    /// The speaker diarization variant.
42    SpeakerDiarization,
43    /// The source separation variant.
44    SourceSeparation,
45    /// The audio generation variant.
46    AudioGeneration,
47    /// The speaker-conditioned text-to-speech variant.
48    SpeakerConditionedTts,
49    /// The text classification variant.
50    TextClassification,
51    /// The token classification variant.
52    TokenClassification,
53    /// The zero shot classification variant.
54    ZeroShotClassification,
55    /// The text embedding variant.
56    TextEmbedding,
57    /// The summarization variant.
58    Summarization,
59    /// The reranking variant.
60    Reranking,
61    /// The question answering variant.
62    QuestionAnswering,
63    /// The multimodal embedding variant.
64    MultimodalEmbedding,
65    /// The custom variant.
66    Custom(String),
67}
68
69impl ModelTask {
70    /// Returns default label.
71    pub fn default_label(&self) -> &'static str {
72        match self {
73            Self::ObjectDetection => "object",
74            Self::PoseEstimation2d => "pose_2d",
75            Self::PoseLifting3d => "pose_3d",
76            Self::ImageClassification => "scene",
77            Self::ImageSegmentation => "mask",
78            Self::ImageEmbedding => "image_embedding",
79            Self::FaceDetection => "face",
80            Self::FaceEmbedding => "face_embedding",
81            Self::Ocr => "ocr",
82            Self::AudioClassification => "audio_class",
83            Self::AudioEventDetection => "audio_event",
84            Self::AudioEmbedding => "audio_embedding",
85            Self::SpeechRecognition => "speech",
86            Self::SpeakerDiarization => "speaker",
87            Self::SourceSeparation => "stem",
88            Self::AudioGeneration => "audio_generation",
89            Self::SpeakerConditionedTts => "speaker_conditioned_tts",
90            Self::TextClassification => "semantic",
91            Self::TokenClassification => "token",
92            Self::ZeroShotClassification => "zero_shot",
93            Self::TextEmbedding => "embedding",
94            Self::Summarization => "summary",
95            Self::Reranking => "reranking",
96            Self::QuestionAnswering => "question_answering",
97            Self::MultimodalEmbedding => "multimodal_embedding",
98            Self::Custom(_) => "custom",
99        }
100    }
101
102    /// Returns a stable protocol string for this task.
103    pub fn as_protocol_str(&self) -> &str {
104        match self {
105            Self::ObjectDetection => "object_detection",
106            Self::PoseEstimation2d => "pose_estimation_2d",
107            Self::PoseLifting3d => "pose_lifting_3d",
108            Self::ImageClassification => "image_classification",
109            Self::ImageSegmentation => "image_segmentation",
110            Self::ImageEmbedding => "image_embedding",
111            Self::FaceDetection => "face_detection",
112            Self::FaceEmbedding => "face_embedding",
113            Self::Ocr => "ocr",
114            Self::AudioClassification => "audio_classification",
115            Self::AudioEventDetection => "audio_event_detection",
116            Self::AudioEmbedding => "audio_embedding",
117            Self::SpeechRecognition => "speech_recognition",
118            Self::SpeakerDiarization => "speaker_diarization",
119            Self::SourceSeparation => "source_separation",
120            Self::AudioGeneration => "audio_generation",
121            Self::SpeakerConditionedTts => "speaker_conditioned_tts",
122            Self::TextClassification => "text_classification",
123            Self::TokenClassification => "token_classification",
124            Self::ZeroShotClassification => "zero_shot_classification",
125            Self::TextEmbedding => "text_embedding",
126            Self::Summarization => "summarization",
127            Self::Reranking => "reranking",
128            Self::QuestionAnswering => "question_answering",
129            Self::MultimodalEmbedding => "multimodal_embedding",
130            Self::Custom(kind) => kind.as_str(),
131        }
132    }
133}
134
135#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
136#[serde(rename_all = "snake_case")]
137/// Variants describing model file request.
138pub enum ModelFileRequest {
139    /// The required variant.
140    Required(String),
141    /// The optional variant.
142    Optional(String),
143    /// The first available variant.
144    FirstAvailable(Vec<String>),
145}
146
147impl ModelFileRequest {
148    /// Returns required.
149    pub fn required(path: impl Into<String>) -> Self {
150        Self::Required(path.into())
151    }
152
153    /// Returns optional.
154    pub fn optional(path: impl Into<String>) -> Self {
155        Self::Optional(path.into())
156    }
157
158    /// Returns first available.
159    pub fn first_available(paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
160        Self::FirstAvailable(paths.into_iter().map(Into::into).collect())
161    }
162}
163
164/// Model source.
165#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
166#[serde(rename_all = "snake_case", tag = "kind")]
167pub enum ModelSource {
168    /// Hugging Face model repository.
169    HuggingFace { repo_id: String, revision: String },
170    /// Local model path.
171    LocalPath { path: PathBuf },
172    /// External command model runner.
173    ExternalCommand { command: PathBuf },
174    /// ComfyUI inventory root.
175    ComfyUiInventory { root: PathBuf },
176    /// Caller-defined source.
177    Custom(String),
178}
179
180impl ModelSource {
181    /// Returns a stable source kind.
182    pub fn kind(&self) -> &str {
183        match self {
184            Self::HuggingFace { .. } => "hugging_face",
185            Self::LocalPath { .. } => "local_path",
186            Self::ExternalCommand { .. } => "external_command",
187            Self::ComfyUiInventory { .. } => "comfyui_inventory",
188            Self::Custom(kind) => kind.as_str(),
189        }
190    }
191}
192
193#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
194/// Generic model spec.
195pub struct ModelSpec {
196    /// Human-readable name for this value.
197    pub name: String,
198    /// The task value.
199    pub task: ModelTask,
200    /// The model source value.
201    pub source: ModelSource,
202    /// The files value.
203    pub files: Vec<ModelFileRequest>,
204    /// Caller-defined model metadata.
205    #[serde(default)]
206    pub metadata: BTreeMap<String, String>,
207    /// Deprecated compatibility field for Hugging Face model identifiers.
208    #[serde(default)]
209    #[deprecated(note = "use ModelSpec::source or ModelSpec::repo_id() instead")]
210    pub repo_id: String,
211    /// Deprecated compatibility field for Hugging Face revisions.
212    #[serde(default)]
213    #[deprecated(note = "use ModelSpec::source or ModelSpec::revision() instead")]
214    pub revision: String,
215}
216
217/// Hugging Face-oriented compatibility alias for callers migrating to `ModelSpec`.
218pub type HuggingFaceModelSpec = ModelSpec;
219
220#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
221#[serde(rename_all = "snake_case")]
222/// Variants describing model runtime backend.
223pub enum ModelRuntimeBackend {
224    /// The CPU variant.
225    Cpu,
226    /// The ONNX variant.
227    Onnx,
228    /// The candle variant.
229    Candle,
230    /// The cuda oxide variant.
231    CudaOxide,
232    /// The whisper.cpp variant.
233    WhisperCpp,
234    /// The Demucs variant.
235    Demucs,
236    /// The OpenCV variant.
237    OpenCv,
238    /// The ComfyUI variant.
239    ComfyUi,
240    /// The external variant.
241    External,
242    /// The heuristic variant.
243    Heuristic,
244    /// Imported caller predictions.
245    Imported,
246    /// The custom variant.
247    Custom(String),
248}
249
250impl ModelRuntimeBackend {
251    /// Borrows this value as a str.
252    pub fn as_str(&self) -> &str {
253        match self {
254            Self::Cpu => "cpu",
255            Self::Onnx => "onnx",
256            Self::Candle => "candle",
257            Self::CudaOxide => "cuda_oxide",
258            Self::WhisperCpp => "whisper_cpp",
259            Self::Demucs => "demucs",
260            Self::OpenCv => "opencv",
261            Self::ComfyUi => "comfyui",
262            Self::External => "external",
263            Self::Heuristic => "heuristic",
264            Self::Imported => "imported",
265            Self::Custom(value) => value.as_str(),
266        }
267    }
268}
269
270#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
271#[serde(rename_all = "snake_case")]
272/// General runtime preference for domain APIs.
273pub enum RuntimePreference {
274    /// Choose the best available runtime.
275    Auto,
276    /// Prefer native execution.
277    Native,
278    /// Prefer external execution.
279    External,
280    /// Force heuristic execution.
281    Heuristic,
282    /// Use imported predictions.
283    Imported,
284}
285
286#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
287#[serde(rename_all = "snake_case")]
288/// General fallback behavior for model-backed operations.
289pub enum FallbackPolicy {
290    /// Return a typed error.
291    #[default]
292    Error,
293    /// Use a fast deterministic fallback.
294    FastFallback,
295    /// Use a heuristic fallback.
296    HeuristicFallback,
297}
298
299#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
300#[serde(rename_all = "camelCase")]
301/// Advanced generic runtime selection.
302pub struct ModelRuntimeSelection {
303    /// Optional concrete model spec.
304    #[serde(default)]
305    pub model: Option<ModelSpec>,
306    /// Optional preferred backend.
307    #[serde(default)]
308    pub backend: Option<ModelRuntimeBackend>,
309    /// Optional bundle directory.
310    #[serde(default)]
311    pub bundle_dir: Option<PathBuf>,
312    /// Fallback policy.
313    #[serde(default)]
314    pub fallback_policy: FallbackPolicy,
315}
316
317#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
318/// Data type for cuda-oxide runtime config.
319pub struct CudaOxideRuntimeConfig {
320    /// The CUDA device index value.
321    pub device_index: u32,
322    /// Optional target SM architecture, such as `sm_80`.
323    pub target_sm: Option<String>,
324    /// Cargo subcommand used to build and run cuda-oxide kernels.
325    pub cargo_oxide_command: String,
326    /// The cuda-oxide documentation URL used for operator setup.
327    pub documentation_url: String,
328}
329
330impl Default for CudaOxideRuntimeConfig {
331    fn default() -> Self {
332        Self {
333            device_index: 0,
334            target_sm: None,
335            cargo_oxide_command: "cargo oxide".to_string(),
336            documentation_url: CUDA_OXIDE_BOOK_URL.to_string(),
337        }
338    }
339}
340
341impl CudaOxideRuntimeConfig {
342    /// Creates a new value.
343    pub fn new() -> Self {
344        Self::default()
345    }
346
347    /// Returns device index.
348    pub fn device_index(mut self, value: u32) -> Self {
349        self.device_index = value;
350        self
351    }
352
353    /// Returns target SM.
354    pub fn target_sm(mut self, value: impl Into<String>) -> Self {
355        self.target_sm = Some(value.into());
356        self
357    }
358
359    /// Returns cargo oxide command.
360    pub fn cargo_oxide_command(mut self, value: impl Into<String>) -> Self {
361        self.cargo_oxide_command = value.into();
362        self
363    }
364
365    /// Returns runtime attributes for prediction metadata and traces.
366    pub fn attributes(&self) -> BTreeMap<String, String> {
367        let mut attributes = BTreeMap::new();
368        attributes.insert(
369            "runtime.backend".to_string(),
370            ModelRuntimeBackend::CudaOxide.as_str().to_string(),
371        );
372        attributes.insert(
373            "runtime.cuda.device_index".to_string(),
374            self.device_index.to_string(),
375        );
376        attributes.insert(
377            "runtime.cuda_oxide.command".to_string(),
378            self.cargo_oxide_command.clone(),
379        );
380        attributes.insert(
381            "runtime.cuda_oxide.docs".to_string(),
382            self.documentation_url.clone(),
383        );
384        if let Some(target_sm) = &self.target_sm {
385            attributes.insert("runtime.cuda.target_sm".to_string(), target_sm.clone());
386        }
387        attributes
388    }
389}
390
391#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
392/// Data type for cuda-oxide model plan.
393pub struct CudaOxideModelPlan {
394    /// The model spec value.
395    pub spec: ModelSpec,
396    /// The runtime value.
397    pub runtime: CudaOxideRuntimeConfig,
398    /// The cuda-oxide module name value.
399    pub module_name: String,
400    /// The kernel names value.
401    pub kernel_names: Vec<String>,
402}
403
404impl CudaOxideModelPlan {
405    /// Creates a new value.
406    pub fn new(
407        spec: ModelSpec,
408        module_name: impl Into<String>,
409        kernel_names: impl IntoIterator<Item = impl Into<String>>,
410    ) -> Self {
411        Self {
412            spec,
413            runtime: CudaOxideRuntimeConfig::default(),
414            module_name: module_name.into(),
415            kernel_names: kernel_names.into_iter().map(Into::into).collect(),
416        }
417    }
418
419    /// Returns runtime.
420    pub fn runtime(mut self, runtime: CudaOxideRuntimeConfig) -> Self {
421        self.runtime = runtime;
422        self
423    }
424
425    /// Returns attributes for prediction metadata and traces.
426    pub fn attributes(&self) -> BTreeMap<String, String> {
427        let mut attributes = self.runtime.attributes();
428        attributes.insert(
429            "runtime.cuda_oxide.module".to_string(),
430            self.module_name.clone(),
431        );
432        if !self.kernel_names.is_empty() {
433            attributes.insert(
434                "runtime.cuda_oxide.kernels".to_string(),
435                self.kernel_names.join(","),
436            );
437        }
438        attributes
439    }
440}
441
442#[allow(deprecated)]
443impl ModelSpec {
444    /// Creates a Hugging Face-backed model spec.
445    pub fn new(repo_id: impl Into<String>, task: ModelTask) -> Self {
446        let repo_id = repo_id.into();
447        let revision = "main".to_string();
448        Self {
449            name: repo_id.clone(),
450            task,
451            source: ModelSource::HuggingFace {
452                repo_id: repo_id.clone(),
453                revision: revision.clone(),
454            },
455            files: Vec::new(),
456            metadata: BTreeMap::new(),
457            repo_id,
458            revision,
459        }
460    }
461
462    /// Creates a model spec from explicit source.
463    pub fn from_source(name: impl Into<String>, task: ModelTask, source: ModelSource) -> Self {
464        let name = name.into();
465        let (repo_id, revision) = match &source {
466            ModelSource::HuggingFace { repo_id, revision } => (repo_id.clone(), revision.clone()),
467            _ => (String::new(), String::new()),
468        };
469        Self {
470            name,
471            task,
472            source,
473            files: Vec::new(),
474            metadata: BTreeMap::new(),
475            repo_id,
476            revision,
477        }
478    }
479
480    /// Returns the Hugging Face repo id when this spec has one.
481    pub fn repo_id_value(&self) -> Option<&str> {
482        match &self.source {
483            ModelSource::HuggingFace { repo_id, .. } => Some(repo_id.as_str()),
484            _ if !self.repo_id.is_empty() => Some(self.repo_id.as_str()),
485            _ => None,
486        }
487    }
488
489    /// Returns the revision when this spec has one.
490    pub fn revision_value(&self) -> Option<&str> {
491        match &self.source {
492            ModelSource::HuggingFace { revision, .. } => Some(revision.as_str()),
493            _ if !self.revision.is_empty() => Some(self.revision.as_str()),
494            _ => None,
495        }
496    }
497
498    /// Returns a filesystem-safe model name segment.
499    pub fn safe_name(&self) -> String {
500        self.name
501            .chars()
502            .map(|ch| {
503                if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.') {
504                    ch
505                } else {
506                    '_'
507                }
508            })
509            .collect()
510    }
511
512    /// Builds this value from preset.
513    pub fn from_preset(preset: ModelPreset) -> Self {
514        preset.spec()
515    }
516
517    /// Returns name.
518    pub fn name(mut self, value: impl Into<String>) -> Self {
519        self.name = value.into();
520        self
521    }
522
523    /// Returns revision.
524    pub fn revision(mut self, value: impl Into<String>) -> Self {
525        let revision = value.into();
526        if let ModelSource::HuggingFace {
527            revision: source_revision,
528            ..
529        } = &mut self.source
530        {
531            *source_revision = revision.clone();
532        }
533        self.revision = revision;
534        self
535    }
536
537    /// Returns file.
538    pub fn file(mut self, path: impl Into<String>) -> Self {
539        self.files.push(ModelFileRequest::required(path));
540        self
541    }
542
543    /// Returns optional file.
544    pub fn optional_file(mut self, path: impl Into<String>) -> Self {
545        self.files.push(ModelFileRequest::optional(path));
546        self
547    }
548
549    /// Returns first available file.
550    pub fn first_available_file(
551        mut self,
552        paths: impl IntoIterator<Item = impl Into<String>>,
553    ) -> Self {
554        self.files.push(ModelFileRequest::first_available(paths));
555        self
556    }
557
558    /// Builds a cuda-oxide runtime plan for this model spec.
559    pub fn cuda_oxide_plan(
560        self,
561        module_name: impl Into<String>,
562        kernel_names: impl IntoIterator<Item = impl Into<String>>,
563    ) -> CudaOxideModelPlan {
564        CudaOxideModelPlan::new(self, module_name, kernel_names)
565    }
566}