dynamo_llm/
model_type.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use bitflags::bitflags;
5use serde::{Deserialize, Serialize};
6use std::fmt;
7use strum::Display;
8
9bitflags! {
10    /// Represents the set of model capabilities (endpoints) a model can support.
11    ///
12    /// This type is implemented using `bitflags` instead of a plain `enum`
13    /// so that multiple capabilities can be combined in a single value:
14    ///
15    /// - `ModelType::Chat`
16    /// - `ModelType::Completions`
17    /// - `ModelType::Embedding`
18    /// - `ModelType::TensorBased`
19    ///
20    /// For example, a model that supports both chat and completions can be
21    /// expressed as:
22    ///
23    /// ```rust
24    /// use dynamo_llm::model_type::ModelType;
25    /// let mt = ModelType::Chat | ModelType::Completions;
26    /// assert!(mt.supports_chat());
27    /// assert!(mt.supports_completions());
28    /// ```
29    ///
30    /// Using bitflags avoids deep branching on a single enum variant,
31    /// simplifies checks like `supports_chat()`, and enables efficient,
32    /// type-safe combinations of multiple endpoint types within a single byte.
33    #[derive(Copy, Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
34    pub struct ModelType: u8 {
35        const Chat = 1 << 0;
36        const Completions = 1 << 1;
37        const Embedding = 1 << 2;
38        const TensorBased = 1 << 3;
39    }
40}
41
42impl ModelType {
43    pub fn as_str(&self) -> String {
44        self.as_vec().join(",")
45    }
46
47    pub fn supports_chat(&self) -> bool {
48        self.contains(ModelType::Chat)
49    }
50    pub fn supports_completions(&self) -> bool {
51        self.contains(ModelType::Completions)
52    }
53    pub fn supports_embedding(&self) -> bool {
54        self.contains(ModelType::Embedding)
55    }
56    pub fn supports_tensor(&self) -> bool {
57        self.contains(ModelType::TensorBased)
58    }
59
60    pub fn as_vec(&self) -> Vec<&'static str> {
61        let mut result = Vec::new();
62        if self.supports_chat() {
63            result.push("chat");
64        }
65        if self.supports_completions() {
66            result.push("completions");
67        }
68        if self.supports_embedding() {
69            result.push("embedding");
70        }
71        if self.supports_tensor() {
72            result.push("tensor");
73        }
74        result
75    }
76
77    /// Returns all endpoint types supported by this model type.
78    /// This properly handles combinations like Chat | Completions.
79    pub fn as_endpoint_types(&self) -> Vec<crate::endpoint_type::EndpointType> {
80        let mut endpoint_types = Vec::new();
81        if self.contains(Self::Chat) {
82            endpoint_types.push(crate::endpoint_type::EndpointType::Chat);
83        }
84        if self.contains(Self::Completions) {
85            endpoint_types.push(crate::endpoint_type::EndpointType::Completion);
86        }
87        if self.contains(Self::Embedding) {
88            endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
89        }
90        // [gluo NOTE] ModelType::Tensor doesn't map to any endpoint type,
91        // current use of endpoint type is LLM specific and so does the HTTP
92        // server that uses it.
93        endpoint_types
94    }
95}
96
97impl fmt::Display for ModelType {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        write!(f, "{}", self.as_str())
100    }
101}
102
103#[derive(Copy, Debug, Clone, Display, Serialize, Deserialize, Eq, PartialEq)]
104pub enum ModelInput {
105    /// Raw text input
106    Text,
107    /// Pre-processed input
108    Tokens,
109    /// Tensor input
110    Tensor,
111}
112
113impl ModelInput {
114    pub fn as_str(&self) -> &str {
115        match self {
116            Self::Text => "text",
117            Self::Tokens => "tokens",
118            Self::Tensor => "tensor",
119        }
120    }
121}