Skip to main content

rig_ai_sdk/
message.rs

1//! AI SDK UIMessage type definitions
2//!
3//! Implements the AI SDK UIMessage format for receiving messages from frontend clients
4//! like assistant-ui.
5//!
6//! # Overview
7//!
8//! This module provides types for deserializing AI SDK messages from frontend clients.
9//! The format supports:
10//!
11//! - Rich text content with streaming states
12//! - Multi-modal messages (text, images, files)
13//! - Tool calls and tool results (both legacy and AI SDK 5.x formats)
14//! - Reasoning/model thinking blocks
15//! - Source references (URLs, documents)
16//! - Custom data attachments
17//!
18//! Reference: [AI SDK Transport](https://ai-sdk.dev/docs/ai-sdk-ui/transport)
19//!
20//! # Examples
21//!
22//! ```ignore
//! // NOTE(review): assumes `TextPart` is re-exported from the crate root
//! // alongside `UIMessage` and `UIMessagePart`.
//! use rig_ai_sdk::{TextPart, UIMessage, UIMessagePart};
//!
//! let msg = UIMessage {
//!     id: "msg-1".to_string(),
//!     role: "user".to_string(),
//!     parts: vec![
//!         UIMessagePart::Text(TextPart {
//!             text: "Hello!".to_string(),
//!             state: None,
//!             provider_metadata: None,
//!         })
//!     ],
//!     metadata: None,
//! };
37//! ```
38
39use serde::{Deserialize, Deserializer, Serialize};
40use serde_json::Value;
41
42// ============================================================================
43// Shared types
44// ============================================================================
45
46/// Provider metadata from AI SDK (simplified version).
47///
48/// Contains optional metadata fields that may be included by the AI provider.
49#[derive(Debug, Clone, Deserialize, Serialize, Default)]
50#[serde(default)]
51pub struct ProviderMetadata {
52    /// Model ID being used
53    pub model_id: Option<String>,
54
55    /// Request ID from the provider
56    pub request_id: Option<String>,
57
58    /// Timestamp of the request
59    pub timestamp: Option<String>,
60
61    /// Additional provider-specific metadata fields
62    #[serde(flatten)]
63    pub extra: Value,
64}
65
66/// Streaming state for message parts.
67///
68/// Indicates whether content is being streamed or is complete.
69#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
70#[serde(rename_all = "lowercase")]
71pub enum PartState {
72    /// Content is currently being streamed
73    Streaming,
74
75    /// Content is complete
76    #[default]
77    Done,
78}
79
80// ============================================================================
81// AI SDK UIMessage Part types
82// ============================================================================
83
84/// Text content part
85///
86/// Plain text content with optional streaming state and provider metadata.
87#[derive(Debug, Clone, Deserialize, Serialize)]
88#[serde(rename_all = "camelCase")]
89pub struct TextPart {
90    pub text: String,
91    pub state: Option<PartState>,
92    pub provider_metadata: Option<ProviderMetadata>,
93}
94
95/// Reasoning/thinking part
96///
97/// Model reasoning or thinking block content (e.g., for o1-style models).
98#[derive(Debug, Clone, Deserialize, Serialize)]
99#[serde(rename_all = "camelCase")]
100pub struct ReasoningPart {
101    /// Reasoning/thinking content
102    pub text: String,
103    pub state: Option<PartState>,
104    pub provider_metadata: Option<ProviderMetadata>,
105}
106
107/// File attachment part (supports any media type)
108///
109/// Represents a file attachment with URL, media type, and optional filename.
110#[derive(Debug, Clone, Deserialize, Serialize)]
111#[serde(rename_all = "camelCase")]
112pub struct FilePart {
113    pub media_type: String,
114    pub url: String,
115    pub filename: Option<String>,
116    pub provider_metadata: Option<ProviderMetadata>,
117}
118
119/// Tool state as string literal (AI SDK 5.x format)
120///
121/// Represents the current state of a tool invocation.
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
123#[serde(rename_all = "kebab-case")]
124pub enum ToolState {
125    /// Input is being streamed
126    InputStreaming,
127    /// Input is available
128    InputAvailable,
129    /// Output is available
130    OutputAvailable,
131    /// Output error
132    OutputError,
133}
134
135/// Tool invocation part (AI SDK 5.x ToolUIPart format)
136///
137/// Modern tool call format with streaming state support and provider execution info.
138/// The `type` field uses dynamic format: "tool-{toolName}"
139#[derive(Debug, Clone, Deserialize, Serialize)]
140#[serde(rename_all = "camelCase")]
141pub struct ToolPart {
142    pub tool_call_id: String,
143    #[serde(default)]
144    pub tool_name: String,
145    pub state: ToolState,
146    pub title: Option<String>,
147    pub provider_executed: Option<bool>,
148    pub call_provider_metadata: Option<ProviderMetadata>,
149    pub preliminary: Option<bool>,
150    pub input: Option<Value>,
151    pub output: Option<Value>,
152    pub raw_input: Option<Value>,
153    pub error_text: Option<String>,
154}
155
156/// URL source reference
157///
158/// References a URL as a source for the response.
159#[derive(Debug, Clone, Deserialize, Serialize)]
160#[serde(rename_all = "camelCase")]
161pub struct SourceUrlPart {
162    pub source_id: String,
163    pub url: String,
164    pub title: Option<String>,
165    pub provider_metadata: Option<ProviderMetadata>,
166}
167
168/// Document source reference
169///
170/// References a document as a source with media type and title.
171#[derive(Debug, Clone, Deserialize, Serialize)]
172#[serde(rename_all = "camelCase")]
173pub struct SourceDocumentPart {
174    pub source_id: String,
175    pub media_type: String,
176    pub title: String,
177    pub filename: Option<String>,
178    pub provider_metadata: Option<ProviderMetadata>,
179}
180
181/// Custom data part (supports dynamic `data-{name}` tags)
182///
183/// Arbitrary custom data attachment with a type name.
184#[derive(Debug, Clone, Deserialize, Serialize)]
185#[serde(rename_all = "camelCase")]
186pub struct DataPart {
187    /// Data type name (the `{name}` in `data-{name}`)
188    pub data_type: String,
189    pub id: Option<String>,
190    pub data: Value,
191}
192
193/// AI SDK UIMessage part types.
194///
195/// A message consists of one or more parts, each representing a different
196/// type of content (text, file, tool call, reasoning, etc.).
197///
198/// The format uses a `type` field to distinguish between different part types:
199/// ```json
200/// { "type": "text", "text": "Hello" }
201/// { "type": "file", "mediaType": "image/png", "url": "..." }
202/// { "type": "tool-get_weather", "toolCallId": "...", "state": "input-available" }
203/// { "type": "data-usage", "data": {...} }
204/// ```
205#[derive(Debug, Clone, Serialize)]
206pub enum UIMessagePart {
207    /// Text content
208    Text(TextPart),
209    /// Reasoning/thinking content
210    Reasoning(ReasoningPart),
211    /// File attachment
212    File(FilePart),
213    /// Tool invocation (AI SDK 5.x ToolUIPart format)
214    Tool(ToolPart),
215    /// URL source
216    SourceUrl(SourceUrlPart),
217    /// Document source
218    SourceDocument(SourceDocumentPart),
219    /// Step start marker
220    StepStart,
221    /// Dynamic data part (data-{name} pattern)
222    Data(DataPart),
223}
224
225/// Helper struct for deserialization of UIMessagePart with type field
226///
227/// Uses `tag = "type"` to handle standard types: text, reasoning, file,
228/// source-url, source-document, step-start. Tool parts are handled separately
229/// due to dynamic type field (tool-{NAME}).
230#[derive(Debug, Clone, Deserialize)]
231#[serde(tag = "type", rename_all = "kebab-case")]
232enum UIMessagePartTagged {
233    Text(TextPart),
234    Reasoning(ReasoningPart),
235    File(FilePart),
236    SourceUrl(SourceUrlPart),
237    SourceDocument(SourceDocumentPart),
238    StepStart,
239}
240
241impl<'de> Deserialize<'de> for UIMessagePart {
242    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
243    where
244        D: Deserializer<'de>,
245    {
246        let raw = serde_json::Value::deserialize(deserializer)?;
247
248        // Check if this is a data-{name} or tool-{NAME} type
249        if let Some(t) = raw.get("type").and_then(|v| v.as_str()) {
250            if t.starts_with("data-") {
251                let data_part = DataPart {
252                    data_type: t.strip_prefix("data-").unwrap_or(t).to_string(),
253                    id: raw.get("id").and_then(|v| v.as_str()).map(String::from),
254                    data: raw.get("data").cloned().unwrap_or(Value::Null),
255                };
256                return Ok(UIMessagePart::Data(data_part));
257            }
258
259            // Handle tool-{NAME} type (AI SDK 5.x ToolUIPart format)
260            if t.starts_with("tool-") {
261                let tool_name = t.strip_prefix("tool-").unwrap_or(t).to_string();
262                let mut tool_part: ToolPart =
263                    serde_json::from_value(raw.clone()).map_err(serde::de::Error::custom)?;
264                // Set tool_name from the type field
265                tool_part.tool_name = tool_name;
266                return Ok(UIMessagePart::Tool(tool_part));
267            }
268        }
269
270        // Otherwise, try tagged deserialization
271        let tagged: Result<UIMessagePartTagged, _> =
272            serde_json::from_value(raw.clone()).map_err(serde::de::Error::custom);
273
274        match tagged {
275            Ok(tagged) => Ok(match tagged {
276                UIMessagePartTagged::Text(v) => UIMessagePart::Text(v),
277                UIMessagePartTagged::Reasoning(v) => UIMessagePart::Reasoning(v),
278                UIMessagePartTagged::File(v) => UIMessagePart::File(v),
279                UIMessagePartTagged::SourceUrl(v) => UIMessagePart::SourceUrl(v),
280                UIMessagePartTagged::SourceDocument(v) => UIMessagePart::SourceDocument(v),
281                UIMessagePartTagged::StepStart => UIMessagePart::StepStart,
282            }),
283            Err(e) => Err(e),
284        }
285    }
286}
287
/// Coarse classification of a media type string.
///
/// Groups concrete media types into a handful of broad categories so callers
/// can branch on the kind of content without inspecting the raw string.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MediaType {
    /// Image content.
    Image,

    /// Audio content.
    Audio,

    /// Video content.
    Video,

    /// Document content.
    Document,

    /// Anything that does not fall into the categories above.
    Other,
}
308
309impl UIMessagePart {
310    /// Gets the text content if this is a `Text` part.
311    ///
312    /// # Example
313    ///
314    /// ```ignore
315    /// use rig_ai_sdk::UIMessagePart;
316    ///
317    /// let part = UIMessagePart::Text {
318    ///     text: "Hello".to_string(),
319    ///     state: None,
320    ///     provider_metadata: None,
321    /// };
322    /// assert_eq!(part.as_text(), Some("Hello"));
323    /// ```
324    pub fn as_text(&self) -> Option<&str> {
325        match self {
326            UIMessagePart::Text(p) => Some(&p.text),
327            _ => None,
328        }
329    }
330
331    /// Returns `true` if this is a `Text` part.
332    pub fn is_text(&self) -> bool {
333        matches!(self, UIMessagePart::Text(_))
334    }
335
336    /// Returns `true` if this is a `Reasoning` part.
337    pub fn is_reasoning(&self) -> bool {
338        matches!(self, UIMessagePart::Reasoning(_))
339    }
340
341    /// Returns `true` if this is a tool invocation.
342    pub fn is_tool(&self) -> bool {
343        matches!(self, UIMessagePart::Tool(_))
344    }
345
346    /// Gets the tool part if this is a Tool.
347    pub fn as_tool(&self) -> Option<&ToolPart> {
348        match self {
349            UIMessagePart::Tool(p) => Some(p),
350            _ => None,
351        }
352    }
353
354    /// Returns `true` if this is a `Data` part.
355    pub fn is_data(&self) -> bool {
356        matches!(self, UIMessagePart::Data(_))
357    }
358
359    /// Gets the data content if this is a `Data` part.
360    pub fn as_data(&self) -> Option<&DataPart> {
361        match self {
362            UIMessagePart::Data(p) => Some(p),
363            _ => None,
364        }
365    }
366
367    /// Gets the file content if this is a `File` part.
368    ///
369    /// Returns a tuple of `(media_type, url, optional_filename)`.
370    pub fn as_file(&self) -> Option<(&str, &str, Option<&String>)> {
371        match self {
372            UIMessagePart::File(p) => Some((&p.media_type, &p.url, p.filename.as_ref())),
373            _ => None,
374        }
375    }
376
377    /// Gets the streaming state if applicable.
378    ///
379    /// Returns `Some(state)` for `Text` and `Reasoning` parts, `None` otherwise.
380    pub fn state(&self) -> Option<PartState> {
381        match self {
382            UIMessagePart::Text(p) => p.state,
383            UIMessagePart::Reasoning(p) => p.state,
384            _ => None,
385        }
386    }
387
388    /// Parses the media type into a [`MediaType`] category.
389    ///
390    /// Returns `Some(MediaType)` for `File` and `SourceDocument` parts, `None` otherwise.
391    pub fn media_type_kind(&self) -> Option<MediaType> {
392        match self {
393            UIMessagePart::File(p) => {
394                if p.media_type.starts_with("image/") {
395                    Some(MediaType::Image)
396                } else if p.media_type.starts_with("audio/") {
397                    Some(MediaType::Audio)
398                } else if p.media_type.starts_with("video/") {
399                    Some(MediaType::Video)
400                } else if matches!(
401                    p.media_type.as_str(),
402                    "application/pdf"
403                        | "application/msword"
404                        | "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
405                        | "text/plain"
406                        | "text/csv"
407                        | "application/json"
408                ) {
409                    Some(MediaType::Document)
410                } else {
411                    Some(MediaType::Other)
412                }
413            }
414            UIMessagePart::SourceDocument(p) => {
415                if p.media_type.starts_with("image/") {
416                    Some(MediaType::Image)
417                } else if p.media_type.starts_with("audio/") {
418                    Some(MediaType::Audio)
419                } else if p.media_type.starts_with("video/") {
420                    Some(MediaType::Video)
421                } else if matches!(
422                    p.media_type.as_str(),
423                    "application/pdf"
424                        | "application/msword"
425                        | "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
426                        | "text/plain"
427                        | "text/csv"
428                        | "application/json"
429                ) {
430                    Some(MediaType::Document)
431                } else {
432                    Some(MediaType::Other)
433                }
434            }
435            _ => None,
436        }
437    }
438}
439
440// ============================================================================
441// AI SDK UIMessage
442// ============================================================================
443
444/// AI SDK UIMessage format
445///
446/// Represents a message in the AI SDK format with role, parts, and metadata.
447///
448/// Reference: [AI SDK Transport](https://ai-sdk.dev/docs/ai-sdk-ui/transport)
449#[derive(Debug, Clone, Deserialize)]
450pub struct UIMessage {
451    /// Unique message ID
452    pub id: String,
453
454    /// Message role: "user", "assistant", or "system"
455    pub role: String,
456
457    /// Optional message metadata
458    pub metadata: Option<Value>,
459
460    /// Message parts (content items)
461    pub parts: Vec<UIMessagePart>,
462}
463
464impl UIMessage {
465    /// Gets the concatenated text content from all `Text` parts.
466    ///
467    /// # Example
468    ///
469    /// ```ignore
470    /// use rig_ai_sdk::UIMessage;
471    ///
472    /// let msg = UIMessage {
473    ///     id: "1".to_string(),
474    ///     role: "user".to_string(),
475    ///     parts: vec![
476    ///         UIMessagePart::Text { text: "Hello, ".to_string(), state: None, provider_metadata: None },
477    ///         UIMessagePart::Text { text: "world!".to_string(), state: None, provider_metadata: None },
478    ///     ],
479    ///     metadata: None,
480    /// };
481    /// assert_eq!(msg.text(), "Hello, world!");
482    /// ```
483    pub fn text(&self) -> String {
484        self.parts
485            .iter()
486            .filter_map(|p| p.as_text())
487            .collect::<Vec<_>>()
488            .join("")
489    }
490
491    /// Returns `true` if the message role is "user".
492    pub fn is_user(&self) -> bool {
493        self.role == "user"
494    }
495
496    /// Returns `true` if the message role is "assistant".
497    pub fn is_assistant(&self) -> bool {
498        self.role == "assistant"
499    }
500
501    /// Returns `true` if the message role is "system".
502    pub fn is_system(&self) -> bool {
503        self.role == "system"
504    }
505
506    /// Gets all parts matching the given predicate.
507    ///
508    /// # Example
509    ///
510    /// ```ignore
511    /// use rig_ai_sdk::UIMessage;
512    ///
513    /// let tool_parts = msg.get_parts_by_type(|p| p.is_tool_call());
514    /// ```
515    pub fn get_parts_by_type<F>(&self, predicate: F) -> Vec<&UIMessagePart>
516    where
517        F: Fn(&UIMessagePart) -> bool,
518    {
519        self.parts.iter().filter(|p| predicate(p)).collect()
520    }
521
522    /// Returns `true` if the message contains streaming content.
523    ///
524    /// Checks if any part has `PartState::Streaming`.
525    pub fn has_streaming_content(&self) -> bool {
526        self.parts
527            .iter()
528            .any(|p| p.state() == Some(PartState::Streaming))
529    }
530
531    /// Returns `true` if the message contains tool invocations.
532    pub fn has_tool_calls(&self) -> bool {
533        self.parts.iter().any(|p| p.is_tool())
534    }
535
536    /// Returns `true` if the message contains file attachments.
537    pub fn has_files(&self) -> bool {
538        self.parts.iter().any(|p| p.as_file().is_some())
539    }
540}