1use crate::common::{SharedString, is_default};
2use crate::model::language_provider::LanguageModelProvider;
3use crate::model::model::LanguageModel;
4use schemars::_private::serde_json;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use std::fmt;
7use std::ops::{Add, Sub};
8use std::sync::Arc;
9use uuid::Uuid;
10
11#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
12pub struct LanguageModelId(pub SharedString);
13impl From<String> for LanguageModelId {
14 fn from(value: String) -> Self {
15 Self(SharedString::from(value))
16 }
17}
18#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
19pub struct LanguageModelName(pub SharedString);
20impl From<String> for LanguageModelName {
21 fn from(value: String) -> Self {
22 Self(SharedString::from(value))
23 }
24}
25
26
27#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
28pub struct LanguageModelProviderId(pub SharedString);
29impl LanguageModelProviderId {
30 pub const fn new(id: &'static str) -> Self {
31 Self(SharedString::new_static(id))
32 }
33}
34
35#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
36pub struct LanguageModelProviderName(pub SharedString);
37
38impl LanguageModelProviderName {
39 pub const fn new(id: &'static str) -> Self {
40 Self(SharedString::new_static(id))
41 }
42}
43impl fmt::Display for LanguageModelProviderName {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 write!(f, "{}", self.0)
46 }
47}
48
49
50#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
51pub struct LanguageModelToolUseId(Arc<str>);
52impl fmt::Display for LanguageModelToolUseId {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 write!(f, "{}", self.0)
55 }
56}
57impl<T> From<T> for LanguageModelToolUseId
58where
59 T: Into<Arc<str>>,
60{
61 fn from(value: T) -> Self {
62 Self(value.into())
63 }
64}
65
66#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
67pub struct LanguageModelToolUse {
68 pub id: LanguageModelToolUseId,
69 pub name: Arc<str>,
70 pub raw_input: String,
71 pub input: serde_json::Value,
72 pub is_input_complete: bool,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
76pub struct LanguageModelToolResult {
77 pub tool_use_id: LanguageModelToolUseId,
78 pub tool_name: Arc<str>,
79 pub is_error: bool,
80 pub content: LanguageModelToolResultContent,
81 pub output: Option<serde_json::Value>,
82}
83
84#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
85pub enum LanguageModelToolResultContent {
86 Text(Arc<str>),
87 }
89impl From<&str> for LanguageModelToolResultContent {
90 fn from(value: &str) -> Self {
91 Self::Text(Arc::from(value))
92 }
93}
94
95impl From<String> for LanguageModelToolResultContent {
96 fn from(value: String) -> Self {
97 Self::Text(Arc::from(value))
98 }
99}
100
101impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
102 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
103 where
104 D: serde::Deserializer<'de>,
105 {
106 use serde::de::Error;
107
108 let value = serde_json::Value::deserialize(deserializer)?;
109
110 if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
114 return Ok(Self::Text(Arc::from(text)));
115 }
116
117 if let Some(obj) = value.as_object() {
119 fn get_field<'a>(
121 obj: &'a serde_json::Map<String, serde_json::Value>,
122 field: &str,
123 ) -> Option<&'a serde_json::Value> {
124 obj.iter()
125 .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
126 .map(|(_, v)| v)
127 }
128
129 if let (Some(type_value), Some(text_value)) =
131 (get_field(&obj, "type"), get_field(&obj, "text"))
132 {
133 if let Some(type_str) = type_value.as_str() {
134 if type_str.to_lowercase() == "text" {
135 if let Some(text) = text_value.as_str() {
136 return Ok(Self::Text(Arc::from(text)));
137 }
138 }
139 }
140 }
141
142 if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") {
144 if obj.len() == 1 {
145 if let Some(text) = value.as_str() {
147 return Ok(Self::Text(Arc::from(text)));
148 }
149 }
150 }
151
152 if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") {
154 if obj.len() == 1 {
155 if let Some(image_obj) = value.as_object() {
158 todo!()
162 }
163 }
164 }
165
166 }
171
172 Err(D::Error::custom(format!(
174 "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
175 an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
176 serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
177 )))
178 }
179}
180
181impl LanguageModelToolResultContent {
182 pub fn to_str(&self) -> Option<&str> {
183 match self {
184 Self::Text(text) => Some(&text),
185 }
187 }
188
189 pub fn is_empty(&self) -> bool {
190 match self {
191 Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
192 }
194 }
195}
196#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
197pub struct LanguageModelRequestTool {
198 pub name: String,
199 pub description: String,
200 pub input_schema: serde_json::Value,
201}
202
203#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
204pub enum LanguageModelCompletionEvent {
205 StatusUpdate(CompletionRequestStatus),
206 Stop(StopReason),
207 Text(String),
208 Thinking {
209 text: String,
210 signature: Option<String>,
211 },
212 RedactedThinking {
213 data: String,
214 },
215 ToolUse(LanguageModelToolUse),
216 ToolUseJsonParseError {
217 id: LanguageModelToolUseId,
218 tool_name: Arc<str>,
219 raw_input: Arc<str>,
220 json_parse_error: String,
221 },
222 StartMessage {
223 message_id: String,
224 },
225 UsageUpdate(TokenUsage),
226}
227
228#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
229#[serde(rename_all = "snake_case")]
230pub enum CompletionRequestStatus {
231 Queued {
232 position: usize,
233 },
234 Started,
235 Failed {
236 code: String,
237 message: String,
238 request_id: Uuid,
239 retry_after: Option<f64>,
241 },
242 UsageUpdated {
243 amount: usize,
244 limit: UsageLimit,
245 },
246 ToolUseLimitReached,
247}
248
249#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
250pub struct TokenUsage {
251 #[serde(default, skip_serializing_if = "is_default")]
252 pub input_tokens: u64,
253 #[serde(default, skip_serializing_if = "is_default")]
254 pub output_tokens: u64,
255 #[serde(default, skip_serializing_if = "is_default")]
256 pub cache_creation_input_tokens: u64,
257 #[serde(default, skip_serializing_if = "is_default")]
258 pub cache_read_input_tokens: u64,
259}
260impl TokenUsage {
261 pub fn total_tokens(&self) -> u64 {
262 self.input_tokens
263 + self.output_tokens
264 + self.cache_read_input_tokens
265 + self.cache_creation_input_tokens
266 }
267}
268impl Add<TokenUsage> for TokenUsage {
269 type Output = Self;
270
271 fn add(self, other: Self) -> Self {
272 Self {
273 input_tokens: self.input_tokens + other.input_tokens,
274 output_tokens: self.output_tokens + other.output_tokens,
275 cache_creation_input_tokens: self.cache_creation_input_tokens
276 + other.cache_creation_input_tokens,
277 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
278 }
279 }
280}
281
282impl Sub<TokenUsage> for TokenUsage {
283 type Output = Self;
284
285 fn sub(self, other: Self) -> Self {
286 Self {
287 input_tokens: self.input_tokens - other.input_tokens,
288 output_tokens: self.output_tokens - other.output_tokens,
289 cache_creation_input_tokens: self.cache_creation_input_tokens
290 - other.cache_creation_input_tokens,
291 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
292 }
293 }
294}
295#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
296#[serde(rename_all = "snake_case")]
297pub enum UsageLimit {
298 Limited(i32),
299 Unlimited,
300}
301
302#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
303#[serde(rename_all = "snake_case")]
304pub enum StopReason {
305 EndTurn,
306 MaxTokens,
307 ToolUse,
308 Refusal,
309}
310
311#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
312#[serde(rename_all = "snake_case")]
313pub enum CompletionIntent {
314 UserPrompt,
315 ToolResults,
316 ThreadSummarization,
317 ThreadContextSummarization,
318 CreateFile,
319 EditFile,
320 InlineAssist,
321 TerminalInlineAssist,
322 GenerateGitCommitMessage,
323}
324
325#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
326#[serde(rename_all = "snake_case")]
327pub enum CompletionMode {
328 Normal,
329 Max,
330}
331
332#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
333pub enum LanguageModelToolChoice {
334 Auto,
335 Any,
336 None,
337}
338
/// Flavor of JSON Schema used to describe tool inputs.
///
/// NOTE(review): `JsonSchemaSubset` presumably denotes a restricted subset of
/// JSON Schema that some providers require — confirm which keywords it excludes.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// Full JSON Schema.
    JsonSchema,
    /// A restricted subset of JSON Schema.
    JsonSchemaSubset,
}
346
347#[derive(Clone)]
348pub struct ConfiguredModel {
349 pub provider: Arc<dyn LanguageModelProvider + Send + Sync>,
350 pub model: Arc<dyn LanguageModel + Send + Sync>,
351}
352
353impl ConfiguredModel {
354 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
355 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
356 }
357}