Skip to main content

mentra_provider/
model.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::fmt::Display;
4use time::OffsetDateTime;
5
6/// Metadata describing a model available from a provider.
7#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
8pub struct ModelInfo {
9    pub id: String,
10    pub provider: crate::ProviderId,
11    pub display_name: Option<String>,
12    pub description: Option<String>,
13    pub created_at: Option<OffsetDateTime>,
14}
15
16impl ModelInfo {
17    pub fn new(id: impl Into<String>, provider: impl Into<crate::ProviderId>) -> Self {
18        Self {
19            id: id.into(),
20            provider: provider.into(),
21            display_name: None,
22            description: None,
23            created_at: None,
24        }
25    }
26}
27
/// Strategy for choosing which model to resolve from a provider.
///
/// This type only names the strategy; the resolution logic itself is not part
/// of this definition.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelSelector {
    /// Select a model by the given identifier string.
    Id(String),
    /// Select the newest model that is available.
    NewestAvailable,
}
34
35/// Provider-neutral token usage metadata for a completed or in-progress response.
36#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
37pub struct TokenUsage {
38    pub input_tokens: Option<u64>,
39    pub output_tokens: Option<u64>,
40    pub total_tokens: Option<u64>,
41    pub cache_read_input_tokens: Option<u64>,
42    pub cache_creation_input_tokens: Option<u64>,
43    pub reasoning_tokens: Option<u64>,
44    pub thoughts_tokens: Option<u64>,
45    pub tool_input_tokens: Option<u64>,
46}
47
48impl TokenUsage {
49    pub fn is_empty(&self) -> bool {
50        self.input_tokens.is_none()
51            && self.output_tokens.is_none()
52            && self.total_tokens.is_none()
53            && self.cache_read_input_tokens.is_none()
54            && self.cache_creation_input_tokens.is_none()
55            && self.reasoning_tokens.is_none()
56            && self.thoughts_tokens.is_none()
57            && self.tool_input_tokens.is_none()
58    }
59}
60
61/// Provider-neutral chat role labels.
62#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
63#[serde(rename_all = "snake_case")]
64pub enum Role {
65    User,
66    Assistant,
67    Unknown(String),
68}
69
70impl Display for Role {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        let value = match self {
73            Self::User => "user",
74            Self::Assistant => "assistant",
75            Self::Unknown(role) => role.as_str(),
76        };
77        f.write_str(value)
78    }
79}
80
81/// Image payload supported by model providers.
82#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
83pub enum ImageSource {
84    Bytes { media_type: String, data: Vec<u8> },
85    Url { url: String },
86}
87
88impl ImageSource {
89    pub fn bytes(media_type: impl Into<String>, data: impl Into<Vec<u8>>) -> Self {
90        Self::Bytes {
91            media_type: media_type.into(),
92            data: data.into(),
93        }
94    }
95
96    pub fn url(url: impl Into<String>) -> Self {
97        Self::Url { url: url.into() }
98    }
99}
100
101/// Tool result payloads supported by provider streams and history replay.
102#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
103#[serde(untagged)]
104pub enum ToolResultContent {
105    Text(String),
106    Structured(Value),
107}
108
109impl ToolResultContent {
110    pub fn text(value: impl Into<String>) -> Self {
111        Self::Text(value.into())
112    }
113
114    pub fn len(&self) -> usize {
115        match self {
116            Self::Text(text) => text.len(),
117            Self::Structured(value) => value.to_string().len(),
118        }
119    }
120
121    pub fn is_empty(&self) -> bool {
122        self.len() == 0
123    }
124
125    pub fn clear(&mut self) {
126        *self = Self::Text(String::new());
127    }
128
129    pub fn as_str(&self) -> &str {
130        match self {
131            Self::Text(text) => text.as_str(),
132            Self::Structured(_) => panic!("ToolResultContent::as_str requires text content"),
133        }
134    }
135
136    pub fn contains(&self, pattern: &str) -> bool {
137        match self {
138            Self::Text(text) => text.contains(pattern),
139            Self::Structured(value) => value.to_string().contains(pattern),
140        }
141    }
142
143    pub fn starts_with(&self, pattern: &str) -> bool {
144        match self {
145            Self::Text(text) => text.starts_with(pattern),
146            Self::Structured(value) => value.to_string().starts_with(pattern),
147        }
148    }
149
150    pub fn push_str(&mut self, value: &str) {
151        match self {
152            Self::Text(text) => text.push_str(value),
153            Self::Structured(existing) => {
154                let mut text = existing.to_string();
155                text.push_str(value);
156                *self = Self::Text(text);
157            }
158        }
159    }
160
161    pub fn to_display_string(&self) -> String {
162        match self {
163            Self::Text(text) => text.clone(),
164            Self::Structured(value) => value.to_string(),
165        }
166    }
167}
168
169impl Default for ToolResultContent {
170    fn default() -> Self {
171        Self::Text(String::new())
172    }
173}
174
175impl From<String> for ToolResultContent {
176    fn from(value: String) -> Self {
177        Self::Text(value)
178    }
179}
180
181impl Display for ToolResultContent {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        f.write_str(&self.to_display_string())
184    }
185}
186
187impl PartialEq<&str> for ToolResultContent {
188    fn eq(&self, other: &&str) -> bool {
189        self.to_display_string() == *other
190    }
191}
192
193impl PartialEq<str> for ToolResultContent {
194    fn eq(&self, other: &str) -> bool {
195        self.to_display_string() == other
196    }
197}
198
199impl PartialEq<ToolResultContent> for &str {
200    fn eq(&self, other: &ToolResultContent) -> bool {
201        *self == other.to_display_string()
202    }
203}
204
205impl PartialEq<ToolResultContent> for str {
206    fn eq(&self, other: &ToolResultContent) -> bool {
207        self == other.to_display_string()
208    }
209}
210
211impl From<&str> for ToolResultContent {
212    fn from(value: &str) -> Self {
213        Self::Text(value.to_string())
214    }
215}
216
217/// Provider-neutral hosted tool search action.
218#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
219pub struct HostedToolSearchCall {
220    pub id: String,
221    #[serde(default, skip_serializing_if = "Option::is_none")]
222    pub status: Option<String>,
223    #[serde(default, skip_serializing_if = "Option::is_none")]
224    pub query: Option<String>,
225}
226
227/// Provider-neutral hosted web search actions.
228#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
229#[serde(tag = "type", rename_all = "snake_case")]
230pub enum WebSearchAction {
231    Search {
232        #[serde(default, skip_serializing_if = "Option::is_none")]
233        query: Option<String>,
234        #[serde(default, skip_serializing_if = "Option::is_none")]
235        queries: Option<Vec<String>>,
236    },
237    OpenPage {
238        #[serde(default, skip_serializing_if = "Option::is_none")]
239        url: Option<String>,
240    },
241    FindInPage {
242        #[serde(default, skip_serializing_if = "Option::is_none")]
243        url: Option<String>,
244        #[serde(default, skip_serializing_if = "Option::is_none")]
245        pattern: Option<String>,
246    },
247}
248
249/// Provider-neutral hosted web search call.
250#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
251pub struct HostedWebSearchCall {
252    pub id: String,
253    #[serde(default, skip_serializing_if = "Option::is_none")]
254    pub status: Option<String>,
255    #[serde(default, skip_serializing_if = "Option::is_none")]
256    pub action: Option<WebSearchAction>,
257}
258
259/// Provider-neutral image generation result.
260#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
261#[serde(tag = "type", rename_all = "snake_case")]
262pub enum ImageGenerationResult {
263    Image { source: ImageSource },
264    ArtifactRef { artifact_id: String },
265}
266
267/// Provider-neutral image generation call.
268#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
269pub struct ImageGenerationCall {
270    pub id: String,
271    pub status: String,
272    #[serde(default, skip_serializing_if = "Option::is_none")]
273    pub revised_prompt: Option<String>,
274    #[serde(default, skip_serializing_if = "Option::is_none")]
275    pub result: Option<ImageGenerationResult>,
276}
277
278/// A provider-neutral content block exchanged with models.
279#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
280pub enum ContentBlock {
281    Text {
282        text: String,
283    },
284    Image {
285        source: ImageSource,
286    },
287    ToolUse {
288        id: String,
289        name: String,
290        input: Value,
291    },
292    ToolResult {
293        tool_use_id: String,
294        content: ToolResultContent,
295        is_error: bool,
296    },
297    HostedToolSearch {
298        call: HostedToolSearchCall,
299    },
300    HostedWebSearch {
301        call: HostedWebSearchCall,
302    },
303    ImageGeneration {
304        call: ImageGenerationCall,
305    },
306}
307
308impl ContentBlock {
309    pub fn text(text: impl Into<String>) -> Self {
310        Self::Text { text: text.into() }
311    }
312
313    pub fn image_bytes(media_type: impl Into<String>, data: impl Into<Vec<u8>>) -> Self {
314        Self::Image {
315            source: ImageSource::bytes(media_type, data),
316        }
317    }
318
319    pub fn image_url(url: impl Into<String>) -> Self {
320        Self::Image {
321            source: ImageSource::url(url),
322        }
323    }
324}
325
326/// Provider-neutral chat message content.
327#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
328pub struct Message {
329    pub role: Role,
330    pub content: Vec<ContentBlock>,
331}
332
333impl Message {
334    pub fn user(content: ContentBlock) -> Self {
335        Self {
336            role: Role::User,
337            content: vec![content],
338        }
339    }
340
341    pub fn assistant(content: ContentBlock) -> Self {
342        Self {
343            role: Role::Assistant,
344            content: vec![content],
345        }
346    }
347
348    pub fn unknown(role: impl Into<String>, content: ContentBlock) -> Self {
349        Self {
350            role: Role::Unknown(role.into()),
351            content: vec![content],
352        }
353    }
354
355    pub fn text(&self) -> String {
356        self.content
357            .iter()
358            .filter_map(|block| match block {
359                ContentBlock::Text { text } => Some(text.as_str()),
360                _ => None,
361            })
362            .collect::<Vec<_>>()
363            .join("")
364    }
365}
366
367/// Provider-neutral tool choice hint passed to model APIs.
368#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
369pub enum ToolChoice {
370    #[default]
371    Auto,
372    Any,
373    Tool {
374        name: String,
375    },
376}