// omni_llm_kit/model/types.rs

1use crate::common::{SharedString, is_default};
2use crate::model::language_provider::LanguageModelProvider;
3use crate::model::model::LanguageModel;
4use schemars::_private::serde_json;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use std::fmt;
7use std::ops::{Add, Sub};
8use std::sync::Arc;
9use uuid::Uuid;
10
11#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
12pub struct LanguageModelId(pub SharedString);
13impl From<String> for LanguageModelId {
14    fn from(value: String) -> Self {
15        Self(SharedString::from(value))
16    }
17}
18#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
19pub struct LanguageModelName(pub SharedString);
20impl From<String> for LanguageModelName {
21    fn from(value: String) -> Self {
22        Self(SharedString::from(value))
23    }
24}
25
26
/// Identifier of a language-model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
impl LanguageModelProviderId {
    /// Creates a provider id from a static string; usable in `const` contexts.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
34
35#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
36pub struct LanguageModelProviderName(pub SharedString);
37
38impl LanguageModelProviderName {
39    pub const fn new(id: &'static str) -> Self {
40        Self(SharedString::new_static(id))
41    }
42}
43impl fmt::Display for LanguageModelProviderName {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        write!(f, "{}", self.0)
46    }
47}
48
49
50#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
51pub struct LanguageModelToolUseId(Arc<str>);
52impl fmt::Display for LanguageModelToolUseId {
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        write!(f, "{}", self.0)
55    }
56}
57impl<T> From<T> for LanguageModelToolUseId
58where
59    T: Into<Arc<str>>,
60{
61    fn from(value: T) -> Self {
62        Self(value.into())
63    }
64}
65
/// A tool (function) invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    // Id correlating this tool use with its `LanguageModelToolResult`.
    pub id: LanguageModelToolUseId,
    // Name of the tool being invoked.
    pub name: Arc<str>,
    // The raw JSON string the model produced for the input.
    pub raw_input: String,
    // Parsed form of `raw_input`.
    pub input: serde_json::Value,
    // Presumably false while the input is still streaming in — confirm with producer.
    pub is_input_complete: bool,
}
74
/// The outcome of a tool invocation, sent back to the model.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    // Id of the `LanguageModelToolUse` this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    // Name of the tool that produced this result.
    pub tool_name: Arc<str>,
    // Whether the invocation failed.
    pub is_error: bool,
    // The content returned by the tool.
    pub content: LanguageModelToolResultContent,
    // Optional structured output carried alongside `content`.
    pub output: Option<serde_json::Value>,
}
83
84#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
85pub enum LanguageModelToolResultContent {
86    Text(Arc<str>),
87    // Image(LanguageModelImage),
88}
89impl From<&str> for LanguageModelToolResultContent {
90    fn from(value: &str) -> Self {
91        Self::Text(Arc::from(value))
92    }
93}
94
95impl From<String> for LanguageModelToolResultContent {
96    fn from(value: String) -> Self {
97        Self::Text(Arc::from(value))
98    }
99}
100
101impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
102    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
103    where
104        D: serde::Deserializer<'de>,
105    {
106        use serde::de::Error;
107
108        let value = serde_json::Value::deserialize(deserializer)?;
109
110        // Models can provide these responses in several styles. Try each in order.
111
112        // 1. Try as plain string
113        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
114            return Ok(Self::Text(Arc::from(text)));
115        }
116
117        // 2. Try as object
118        if let Some(obj) = value.as_object() {
119            // get a JSON field case-insensitively
120            fn get_field<'a>(
121                obj: &'a serde_json::Map<String, serde_json::Value>,
122                field: &str,
123            ) -> Option<&'a serde_json::Value> {
124                obj.iter()
125                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
126                    .map(|(_, v)| v)
127            }
128
129            // Accept wrapped text format: { "type": "text", "text": "..." }
130            if let (Some(type_value), Some(text_value)) =
131                (get_field(&obj, "type"), get_field(&obj, "text"))
132            {
133                if let Some(type_str) = type_value.as_str() {
134                    if type_str.to_lowercase() == "text" {
135                        if let Some(text) = text_value.as_str() {
136                            return Ok(Self::Text(Arc::from(text)));
137                        }
138                    }
139                }
140            }
141
142            // Check for wrapped Text variant: { "text": "..." }
143            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") {
144                if obj.len() == 1 {
145                    // Only one field, and it's "text" (case-insensitive)
146                    if let Some(text) = value.as_str() {
147                        return Ok(Self::Text(Arc::from(text)));
148                    }
149                }
150            }
151
152            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
153            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") {
154                if obj.len() == 1 {
155                    // Only one field, and it's "image" (case-insensitive)
156                    // Try to parse the nested image object
157                    if let Some(image_obj) = value.as_object() {
158                        // if let Some(image) = LanguageModelImage::from_json(image_obj) {
159                        //     return Ok(Self::Image(image));
160                        // }
161                        todo!()
162                    }
163                }
164            }
165
166            // Try as direct Image (object with "source" and "size" fields)
167            // if let Some(image) = LanguageModelImage::from_json(&obj) {
168            //     return Ok(Self::Image(image));
169            // }
170        }
171
172        // If none of the variants match, return an error with the problematic JSON
173        Err(D::Error::custom(format!(
174            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
175             an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
176            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
177        )))
178    }
179}
180
181impl LanguageModelToolResultContent {
182    pub fn to_str(&self) -> Option<&str> {
183        match self {
184            Self::Text(text) => Some(&text),
185            // Self::Image(_) => None,
186        }
187    }
188
189    pub fn is_empty(&self) -> bool {
190        match self {
191            Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
192            // Self::Image(_) => false,
193        }
194    }
195}
/// Definition of a tool the model may call, supplied with a completion request.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    // Tool name the model references when invoking it.
    pub name: String,
    // Natural-language description of what the tool does.
    pub description: String,
    // JSON schema describing the tool's expected input.
    pub input_schema: serde_json::Value,
}
202
/// Streamed events produced while a completion request runs.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// Lifecycle/status change of the underlying request.
    StatusUpdate(CompletionRequestStatus),
    /// Generation stopped, with the reason.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's reasoning ("thinking") output.
    Thinking {
        text: String,
        // Provider-supplied signature for the thinking block, if any.
        signature: Option<String>,
    },
    /// Opaque redacted thinking payload.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The model emitted tool-use input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        // The raw input string that failed to parse.
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new assistant message began.
    StartMessage {
        message_id: String,
    },
    /// Updated token accounting for the request.
    UsageUpdate(TokenUsage),
}
227
/// Lifecycle status of a completion request.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    /// Waiting in a queue at the given position.
    Queued {
        position: usize,
    },
    /// The request has started running.
    Started,
    /// The request failed.
    Failed {
        code: String,
        message: String,
        request_id: Uuid,
        /// Retry duration in seconds.
        retry_after: Option<f64>,
    },
    /// Account usage changed to `amount` against `limit`.
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    /// A configured tool-use limit was reached.
    ToolUseLimitReached,
}
248
/// Token accounting for a request/response exchange.
///
/// All fields default to zero and are skipped during serialization when zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    // Tokens consumed by the prompt/input.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    // Tokens produced in the response.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    // Input tokens spent creating a prompt-cache entry.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    // Input tokens served from a prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
260impl TokenUsage {
261    pub fn total_tokens(&self) -> u64 {
262        self.input_tokens
263            + self.output_tokens
264            + self.cache_read_input_tokens
265            + self.cache_creation_input_tokens
266    }
267}
268impl Add<TokenUsage> for TokenUsage {
269    type Output = Self;
270
271    fn add(self, other: Self) -> Self {
272        Self {
273            input_tokens: self.input_tokens + other.input_tokens,
274            output_tokens: self.output_tokens + other.output_tokens,
275            cache_creation_input_tokens: self.cache_creation_input_tokens
276                + other.cache_creation_input_tokens,
277            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
278        }
279    }
280}
281
282impl Sub<TokenUsage> for TokenUsage {
283    type Output = Self;
284
285    fn sub(self, other: Self) -> Self {
286        Self {
287            input_tokens: self.input_tokens - other.input_tokens,
288            output_tokens: self.output_tokens - other.output_tokens,
289            cache_creation_input_tokens: self.cache_creation_input_tokens
290                - other.cache_creation_input_tokens,
291            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
292        }
293    }
294}
/// A usage allowance, either capped or unbounded.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
    /// Capped at the contained amount.
    Limited(i32),
    /// No cap.
    Unlimited,
}
301
/// Why the model stopped generating.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation hit the output-token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model refused to respond.
    Refusal,
}
310
/// The feature on whose behalf a completion request is issued.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
    /// Direct response to a user prompt.
    UserPrompt,
    /// Continuation after tool results were fed back.
    ToolResults,
    /// Summarizing a conversation thread.
    ThreadSummarization,
    /// Summarizing a thread's context.
    ThreadContextSummarization,
    /// Creating a new file.
    CreateFile,
    /// Editing an existing file.
    EditFile,
    /// Inline assistance in the editor.
    InlineAssist,
    /// Inline assistance in the terminal.
    TerminalInlineAssist,
    /// Generating a git commit message.
    GenerateGitCommitMessage,
}
324
/// Effort/capacity mode for a completion request.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
    /// Default mode.
    Normal,
    /// Maximum-capability mode.
    Max,
}
331
/// Constrains how the model may choose to use tools.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    /// The model decides whether to call a tool.
    Auto,
    /// The model must call some tool.
    Any,
    /// The model must not call any tool.
    None,
}
338
/// The JSON-schema dialect a provider accepts for tool input schemas.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}
346
/// A selected language model paired with the provider that serves it.
#[derive(Clone)]
pub struct ConfiguredModel {
    // The provider hosting `model`.
    pub provider: Arc<dyn LanguageModelProvider + Send + Sync>,
    // The selected model.
    pub model: Arc<dyn LanguageModel + Send + Sync>,
}
352
353impl ConfiguredModel {
354    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
355        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
356    }
357}