use std::collections::BTreeMap;
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use crate::ModelPreset;
/// URL of the cuda-oxide documentation index page.
///
/// Used as the default `documentation_url` of `CudaOxideRuntimeConfig`.
pub const CUDA_OXIDE_BOOK_URL: &str = "https://nvlabs.github.io/cuda-oxide/index.html";
/// The task a model performs.
///
/// Variants are (de)serialized by serde using `rename_all = "snake_case"`.
///
/// NOTE(review): serde's snake_case conversion only inserts `_` before uppercase
/// letters, so `PoseEstimation2d` / `PoseLifting3d` presumably serialize as
/// `pose_estimation2d` / `pose_lifting3d`, while `as_protocol_str()` returns
/// `pose_estimation_2d` / `pose_lifting_3d` — confirm which spelling consumers expect.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelTask {
ObjectDetection,
PoseEstimation2d,
PoseLifting3d,
ImageClassification,
ImageSegmentation,
ImageEmbedding,
FaceDetection,
FaceEmbedding,
Ocr,
AudioClassification,
AudioEventDetection,
AudioEmbedding,
SpeechRecognition,
SpeakerDiarization,
SourceSeparation,
AudioGeneration,
SpeakerConditionedTts,
TextClassification,
TokenClassification,
ZeroShotClassification,
TextEmbedding,
Summarization,
Reranking,
QuestionAnswering,
MultimodalEmbedding,
/// Caller-defined task; `as_protocol_str()` returns the wrapped string verbatim
/// and `default_label()` returns `"custom"`.
Custom(String),
}
impl ModelTask {
    /// Single lookup table pairing each built-in task with its
    /// `(default label, protocol name)`.
    ///
    /// Returns `None` for [`ModelTask::Custom`], which has no fixed protocol name.
    fn builtin_names(&self) -> Option<(&'static str, &'static str)> {
        let pair = match self {
            Self::Custom(_) => return None,
            Self::ObjectDetection => ("object", "object_detection"),
            Self::PoseEstimation2d => ("pose_2d", "pose_estimation_2d"),
            Self::PoseLifting3d => ("pose_3d", "pose_lifting_3d"),
            Self::ImageClassification => ("scene", "image_classification"),
            Self::ImageSegmentation => ("mask", "image_segmentation"),
            Self::ImageEmbedding => ("image_embedding", "image_embedding"),
            Self::FaceDetection => ("face", "face_detection"),
            Self::FaceEmbedding => ("face_embedding", "face_embedding"),
            Self::Ocr => ("ocr", "ocr"),
            Self::AudioClassification => ("audio_class", "audio_classification"),
            Self::AudioEventDetection => ("audio_event", "audio_event_detection"),
            Self::AudioEmbedding => ("audio_embedding", "audio_embedding"),
            Self::SpeechRecognition => ("speech", "speech_recognition"),
            Self::SpeakerDiarization => ("speaker", "speaker_diarization"),
            Self::SourceSeparation => ("stem", "source_separation"),
            Self::AudioGeneration => ("audio_generation", "audio_generation"),
            Self::SpeakerConditionedTts => {
                ("speaker_conditioned_tts", "speaker_conditioned_tts")
            }
            Self::TextClassification => ("semantic", "text_classification"),
            Self::TokenClassification => ("token", "token_classification"),
            Self::ZeroShotClassification => ("zero_shot", "zero_shot_classification"),
            Self::TextEmbedding => ("embedding", "text_embedding"),
            Self::Summarization => ("summary", "summarization"),
            Self::Reranking => ("reranking", "reranking"),
            Self::QuestionAnswering => ("question_answering", "question_answering"),
            Self::MultimodalEmbedding => ("multimodal_embedding", "multimodal_embedding"),
        };
        Some(pair)
    }

    /// Returns the short static label associated with this task.
    ///
    /// Every `Custom` task maps to `"custom"`, regardless of its payload.
    pub fn default_label(&self) -> &'static str {
        match self.builtin_names() {
            Some((label, _)) => label,
            None => "custom",
        }
    }

    /// Returns the protocol identifier for this task.
    ///
    /// Built-in tasks yield a fixed snake_case name; `Custom` yields the string
    /// it wraps, unchanged.
    pub fn as_protocol_str(&self) -> &str {
        match self {
            Self::Custom(kind) => kind.as_str(),
            // `builtin_names` is `Some` for every non-`Custom` variant, so the
            // fallback literal below is never selected.
            builtin => builtin
                .builtin_names()
                .map_or("custom", |(_, protocol)| protocol),
        }
    }
}
/// A request for one file (or one of several candidate files) belonging to a model.
///
/// Serialized by serde with snake_case variant names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelFileRequest {
/// A file path marked as required.
Required(String),
/// A file path marked as optional.
Optional(String),
/// A list of candidate paths.
/// NOTE(review): presumably the first candidate that exists is used; the
/// resolution logic is not in this file — confirm against the consumer.
FirstAvailable(Vec<String>),
}
impl ModelFileRequest {
pub fn required(path: impl Into<String>) -> Self {
Self::Required(path.into())
}
pub fn optional(path: impl Into<String>) -> Self {
Self::Optional(path.into())
}
pub fn first_available(paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
Self::FirstAvailable(paths.into_iter().map(Into::into).collect())
}
}
/// Where a model comes from.
///
/// Serialized by serde as an internally tagged enum: the variant name (snake_case)
/// is stored under the `kind` key alongside the variant's fields.
///
/// NOTE(review): serde's internally tagged representation does not support newtype
/// variants that wrap a plain string, so (de)serializing `Custom(String)` is expected
/// to fail at runtime — confirm, and consider a struct variant if it must round-trip.
///
/// NOTE(review): with `rename_all = "snake_case"`, `ComfyUiInventory` presumably gets
/// the tag `comfy_ui_inventory`, whereas `ModelSource::kind()` returns
/// `comfyui_inventory` — confirm which spelling is intended.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "kind")]
pub enum ModelSource {
/// A Hugging Face repository identified by `repo_id` at a given `revision`.
HuggingFace { repo_id: String, revision: String },
/// A filesystem path.
LocalPath { path: PathBuf },
/// A command, given as a path.
ExternalCommand { command: PathBuf },
/// An inventory rooted at `root`.
ComfyUiInventory { root: PathBuf },
/// Caller-defined source; `kind()` returns the wrapped string verbatim.
Custom(String),
}
impl ModelSource {
    /// Returns a string naming the kind of source.
    ///
    /// Built-in variants map to fixed names; `Custom` returns the string it wraps.
    pub fn kind(&self) -> &str {
        let builtin: &'static str = match self {
            // Custom sources name themselves.
            Self::Custom(label) => return label.as_str(),
            Self::ComfyUiInventory { .. } => "comfyui_inventory",
            Self::ExternalCommand { .. } => "external_command",
            Self::LocalPath { .. } => "local_path",
            Self::HuggingFace { .. } => "hugging_face",
        };
        builtin
    }
}
/// Describes a model: its name, task, source, requested files and free-form metadata.
///
/// `repo_id` and `revision` are deprecated mirrors of the Hugging Face source
/// fields. The constructors in `impl ModelSpec` keep them in sync with `source`
/// for Hugging Face specs and leave them empty for other sources.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelSpec {
    /// Name of the model; `ModelSpec::new` initialises it to the repo id.
    pub name: String,
    /// Task the model performs.
    pub task: ModelTask,
    /// Where the model comes from.
    pub source: ModelSource,
    /// Files requested for this model, in insertion order.
    pub files: Vec<ModelFileRequest>,
    /// Free-form key/value metadata; defaults to empty when absent on deserialize.
    #[serde(default)]
    pub metadata: BTreeMap<String, String>,
    /// Deprecated mirror of the Hugging Face repo id; defaults to empty when absent.
    // The note names `repo_id_value()`, the read accessor defined in `impl ModelSpec`.
    #[serde(default)]
    #[deprecated(note = "use ModelSpec::source or ModelSpec::repo_id_value() instead")]
    pub repo_id: String,
    /// Deprecated mirror of the Hugging Face revision; defaults to empty when absent.
    // The note names `revision_value()`: `ModelSpec::revision()` is a consuming
    // builder setter, not a getter.
    #[serde(default)]
    #[deprecated(note = "use ModelSpec::source or ModelSpec::revision_value() instead")]
    pub revision: String,
}
/// Alias for [`ModelSpec`].
///
/// NOTE(review): presumably kept for source compatibility with an older,
/// Hugging Face-only name — confirm before removing.
pub type HuggingFaceModelSpec = ModelSpec;
/// Backend used to run a model.
///
/// Serialized by serde with snake_case variant names.
///
/// NOTE(review): with `rename_all = "snake_case"`, `OpenCv` / `ComfyUi` presumably
/// serialize as `open_cv` / `comfy_ui`, whereas `ModelRuntimeBackend::as_str()`
/// returns `opencv` / `comfyui` — confirm which spelling consumers expect.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelRuntimeBackend {
Cpu,
Onnx,
Candle,
CudaOxide,
WhisperCpp,
Demucs,
OpenCv,
ComfyUi,
External,
Heuristic,
Imported,
/// Caller-defined backend; `as_str()` returns the wrapped string verbatim.
Custom(String),
}
impl ModelRuntimeBackend {
    /// Returns the string name of this backend.
    ///
    /// Built-in backends map to fixed lowercase names; `Custom` returns the
    /// string it wraps.
    pub fn as_str(&self) -> &str {
        let builtin: &'static str = match self {
            // Custom backends name themselves.
            Self::Custom(name) => return name.as_str(),
            Self::Imported => "imported",
            Self::Heuristic => "heuristic",
            Self::External => "external",
            Self::ComfyUi => "comfyui",
            Self::OpenCv => "opencv",
            Self::Demucs => "demucs",
            Self::WhisperCpp => "whisper_cpp",
            Self::CudaOxide => "cuda_oxide",
            Self::Candle => "candle",
            Self::Onnx => "onnx",
            Self::Cpu => "cpu",
        };
        builtin
    }
}
/// Preference for which class of runtime should be used.
///
/// Serialized by serde with snake_case variant names. How each preference is
/// resolved to a concrete backend is not defined in this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RuntimePreference {
Auto,
Native,
External,
Heuristic,
Imported,
}
/// Policy selector for fallback behaviour; the default is [`FallbackPolicy::Error`].
///
/// Serialized by serde with snake_case variant names. The behaviour attached to
/// each policy is implemented outside this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum FallbackPolicy {
/// Default policy.
#[default]
Error,
FastFallback,
HeuristicFallback,
}
/// Optional model, backend and bundle-directory choices plus a fallback policy.
///
/// Every field falls back to its `Default` when absent on deserialize.
///
/// NOTE(review): this struct uses `rename_all = "camelCase"` (fields serialize as
/// `bundleDir`, `fallbackPolicy`), while the enums in this file use snake_case —
/// confirm the mixed casing is intentional.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct ModelRuntimeSelection {
/// Selected model spec, if any.
#[serde(default)]
pub model: Option<ModelSpec>,
/// Selected runtime backend, if any.
#[serde(default)]
pub backend: Option<ModelRuntimeBackend>,
/// Bundle directory, if any.
#[serde(default)]
pub bundle_dir: Option<PathBuf>,
/// Fallback policy; defaults to `FallbackPolicy::Error`.
#[serde(default)]
pub fallback_policy: FallbackPolicy,
}
/// Configuration for the cuda-oxide runtime backend.
///
/// The `Default` impl yields device index `0`, no target SM, the command
/// `"cargo oxide"` and `CUDA_OXIDE_BOOK_URL` as the documentation URL.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CudaOxideRuntimeConfig {
/// Device index; exported as the `runtime.cuda.device_index` attribute.
pub device_index: u32,
/// Optional target SM string; exported as `runtime.cuda.target_sm` when set.
pub target_sm: Option<String>,
/// Command string; exported as `runtime.cuda_oxide.command`.
pub cargo_oxide_command: String,
/// Documentation URL; exported as `runtime.cuda_oxide.docs`.
pub documentation_url: String,
}
impl Default for CudaOxideRuntimeConfig {
fn default() -> Self {
Self {
device_index: 0,
target_sm: None,
cargo_oxide_command: "cargo oxide".to_string(),
documentation_url: CUDA_OXIDE_BOOK_URL.to_string(),
}
}
}
impl CudaOxideRuntimeConfig {
    /// Creates a configuration with the default values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Builder: sets the device index.
    pub fn device_index(self, value: u32) -> Self {
        Self {
            device_index: value,
            ..self
        }
    }

    /// Builder: sets the target SM string.
    pub fn target_sm(self, value: impl Into<String>) -> Self {
        Self {
            target_sm: Some(value.into()),
            ..self
        }
    }

    /// Builder: replaces the cargo-oxide command string.
    pub fn cargo_oxide_command(self, value: impl Into<String>) -> Self {
        Self {
            cargo_oxide_command: value.into(),
            ..self
        }
    }

    /// Exports this configuration as `runtime.*` string attributes.
    ///
    /// Backend, device index, command and docs URL are always present;
    /// `runtime.cuda.target_sm` is included only when a target SM is set.
    pub fn attributes(&self) -> BTreeMap<String, String> {
        let backend_name = ModelRuntimeBackend::CudaOxide.as_str();
        let mut exported = BTreeMap::from([
            ("runtime.backend".to_string(), backend_name.to_string()),
            (
                "runtime.cuda.device_index".to_string(),
                self.device_index.to_string(),
            ),
            (
                "runtime.cuda_oxide.command".to_string(),
                self.cargo_oxide_command.clone(),
            ),
            (
                "runtime.cuda_oxide.docs".to_string(),
                self.documentation_url.clone(),
            ),
        ]);
        if let Some(sm) = self.target_sm.as_ref() {
            exported.insert("runtime.cuda.target_sm".to_string(), sm.clone());
        }
        exported
    }
}
/// A model spec paired with a cuda-oxide runtime configuration, a module name
/// and a list of kernel names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CudaOxideModelPlan {
/// The model this plan is for.
pub spec: ModelSpec,
/// Runtime configuration; `CudaOxideModelPlan::new` starts from the default.
pub runtime: CudaOxideRuntimeConfig,
/// Module name; exported as the `runtime.cuda_oxide.module` attribute.
pub module_name: String,
/// Kernel names; exported comma-joined as `runtime.cuda_oxide.kernels` when non-empty.
pub kernel_names: Vec<String>,
}
impl CudaOxideModelPlan {
pub fn new(
spec: ModelSpec,
module_name: impl Into<String>,
kernel_names: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
Self {
spec,
runtime: CudaOxideRuntimeConfig::default(),
module_name: module_name.into(),
kernel_names: kernel_names.into_iter().map(Into::into).collect(),
}
}
pub fn runtime(mut self, runtime: CudaOxideRuntimeConfig) -> Self {
self.runtime = runtime;
self
}
pub fn attributes(&self) -> BTreeMap<String, String> {
let mut attributes = self.runtime.attributes();
attributes.insert(
"runtime.cuda_oxide.module".to_string(),
self.module_name.clone(),
);
if !self.kernel_names.is_empty() {
attributes.insert(
"runtime.cuda_oxide.kernels".to_string(),
self.kernel_names.join(","),
);
}
attributes
}
}
#[allow(deprecated)]
impl ModelSpec {
pub fn new(repo_id: impl Into<String>, task: ModelTask) -> Self {
let repo_id = repo_id.into();
let revision = "main".to_string();
Self {
name: repo_id.clone(),
task,
source: ModelSource::HuggingFace {
repo_id: repo_id.clone(),
revision: revision.clone(),
},
files: Vec::new(),
metadata: BTreeMap::new(),
repo_id,
revision,
}
}
pub fn from_source(name: impl Into<String>, task: ModelTask, source: ModelSource) -> Self {
let name = name.into();
let (repo_id, revision) = match &source {
ModelSource::HuggingFace { repo_id, revision } => (repo_id.clone(), revision.clone()),
_ => (String::new(), String::new()),
};
Self {
name,
task,
source,
files: Vec::new(),
metadata: BTreeMap::new(),
repo_id,
revision,
}
}
pub fn repo_id_value(&self) -> Option<&str> {
match &self.source {
ModelSource::HuggingFace { repo_id, .. } => Some(repo_id.as_str()),
_ if !self.repo_id.is_empty() => Some(self.repo_id.as_str()),
_ => None,
}
}
pub fn revision_value(&self) -> Option<&str> {
match &self.source {
ModelSource::HuggingFace { revision, .. } => Some(revision.as_str()),
_ if !self.revision.is_empty() => Some(self.revision.as_str()),
_ => None,
}
}
pub fn safe_name(&self) -> String {
self.name
.chars()
.map(|ch| {
if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.') {
ch
} else {
'_'
}
})
.collect()
}
pub fn from_preset(preset: ModelPreset) -> Self {
preset.spec()
}
pub fn name(mut self, value: impl Into<String>) -> Self {
self.name = value.into();
self
}
pub fn revision(mut self, value: impl Into<String>) -> Self {
let revision = value.into();
if let ModelSource::HuggingFace {
revision: source_revision,
..
} = &mut self.source
{
*source_revision = revision.clone();
}
self.revision = revision;
self
}
pub fn file(mut self, path: impl Into<String>) -> Self {
self.files.push(ModelFileRequest::required(path));
self
}
pub fn optional_file(mut self, path: impl Into<String>) -> Self {
self.files.push(ModelFileRequest::optional(path));
self
}
pub fn first_available_file(
mut self,
paths: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.files.push(ModelFileRequest::first_available(paths));
self
}
pub fn cuda_oxide_plan(
self,
module_name: impl Into<String>,
kernel_names: impl IntoIterator<Item = impl Into<String>>,
) -> CudaOxideModelPlan {
CudaOxideModelPlan::new(self, module_name, kernel_names)
}
}