ai_sdk_provider/language_model/
prompt.rs

1use super::content::*;
2use crate::Error;
3use serde::{Deserialize, Deserializer, Serialize};
4use serde_json::Value;
5use std::ops::{Deref, DerefMut};
6
/// An ordered list of conversation [`Message`]s to send to a language model.
///
/// `Prompt` is a newtype around `Vec<Message>`. The conversion impls in this
/// module let callers build one from a plain string, a single message, a vector
/// of messages, or a `serde_json::Value`.
///
/// # Construction Methods
///
/// `Prompt` can be created in several ways:
/// - Empty: `Prompt::new()` or `Prompt::default()`
/// - From a string slice: `Prompt::from("Hello")` (becomes one user text message)
/// - From an owned `String`: `Prompt::from(String::from("Hello"))`
/// - From a Vec of messages: `Prompt::from(vec![message1, message2])`
/// - From a single message: `Prompt::from(message)` or `Prompt::from_message(message)`
/// - From JSON via `TryFrom<Value>`: `json_value.try_into()`
///
/// # Vec-like access
///
/// The struct implements `Deref` and `DerefMut` with `Target = Vec<Message>`, so
/// `Vec` methods such as `len()`, `push()`, `iter()` and indexing are available
/// directly on a `Prompt`.
///
/// # Usage
///
/// ```ignore
/// // Simple text prompt
/// let prompt: Prompt = "What is 2+2?".into();
///
/// // Structured messages
/// let messages = vec![
///     Message::System { content: "You are helpful".into() },
///     Message::User { content: vec![UserContentPart::Text { text: "Hello".into() }] },
/// ];
/// let prompt: Prompt = messages.into();
///
/// // JSON conversion (a single message object; strings and arrays also work)
/// let json = json!({"role": "user", "content": "Hello"});
/// let prompt: Prompt = json.try_into()?;
/// ```
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct Prompt(Vec<Message>);
46
47impl Prompt {
48    /// Creates a new empty prompt with no messages.
49    ///
50    /// # Returns
51    ///
52    /// A prompt containing an empty message vector that can be populated with `push()`.
53    pub fn new() -> Self {
54        Self(Vec::new())
55    }
56
57    /// Creates a prompt from a single message.
58    ///
59    /// # Arguments
60    ///
61    /// * `message` - The message to initialize the prompt with
62    ///
63    /// # Returns
64    ///
65    /// A prompt containing the provided message as its sole element.
66    ///
67    /// # Example
68    ///
69    /// ```ignore
70    /// let msg = Message::User { content: vec![UserContentPart::Text { text: "Hello".into() }] };
71    /// let prompt = Prompt::from_message(msg);
72    /// ```
73    pub fn from_message(message: Message) -> Self {
74        Self(vec![message])
75    }
76}
77
78/// Enables Vec-like access to the underlying message vector.
79///
80/// Through this implementation, you can use standard Vec methods directly on Prompt,
81/// such as `len()`, `push()`, `iter()`, and indexing operations.
82impl Deref for Prompt {
83    type Target = Vec<Message>;
84
85    fn deref(&self) -> &Self::Target {
86        &self.0
87    }
88}
89
90/// Enables mutable Vec-like access to the underlying message vector.
91impl DerefMut for Prompt {
92    fn deref_mut(&mut self) -> &mut Self::Target {
93        &mut self.0
94    }
95}
96
97/// Creates a prompt from a String, automatically treating it as a user message.
98///
99/// The string is converted into a User message with a single text content part.
100/// This provides a convenient way to create simple text prompts.
101impl From<String> for Prompt {
102    fn from(s: String) -> Self {
103        Prompt(vec![Message::User {
104            content: vec![UserContentPart::Text { text: s }],
105        }])
106    }
107}
108
109/// Creates a prompt from a string slice, automatically treating it as a user message.
110impl From<&str> for Prompt {
111    fn from(s: &str) -> Self {
112        Prompt::from(s.to_string())
113    }
114}
115
116/// Creates a prompt from a vector of messages.
117///
118/// Useful when you have multiple messages with different roles that you want to
119/// form a complete conversation history.
120impl From<Vec<Message>> for Prompt {
121    fn from(v: Vec<Message>) -> Self {
122        Prompt(v)
123    }
124}
125
126/// Creates a prompt from a single message.
127///
128/// Wraps the message in a vector to form a single-message prompt.
129impl From<Message> for Prompt {
130    fn from(m: Message) -> Self {
131        Prompt(vec![m])
132    }
133}
134
135/// Creates a prompt from a JSON value with flexible input formats.
136///
137/// This trait supports converting from various JSON structures:
138/// - A JSON string: converted to a User message with text content
139/// - A JSON object: deserialized as a single Message
140/// - A JSON array: deserialized as an array of Message objects
141///
142/// # Errors
143///
144/// Returns an error if the JSON structure doesn't match one of the supported formats
145/// or if deserialization fails.
146impl TryFrom<Value> for Prompt {
147    type Error = Error;
148
149    fn try_from(value: Value) -> std::result::Result<Self, Self::Error> {
150        match value {
151            Value::String(s) => Ok(Prompt::from(s)),
152
153            Value::Array(arr) => {
154                let messages: Vec<Message> = serde_json::from_value(Value::Array(arr))
155                    .map_err(|e| format!("Invalid prompt array: {}", e))?;
156                Ok(Prompt(messages))
157            }
158
159            Value::Object(obj) => {
160                let message: Message = serde_json::from_value(Value::Object(obj))
161                    .map_err(|e| format!("Invalid prompt object: {}", e))?;
162                Ok(Prompt(vec![message]))
163            }
164
165            _ => Err("JSON must be a string, object, or array of messages".into()),
166        }
167    }
168}
169
170/// Custom deserializer that handles both string and array representations of user content.
171///
172/// This deserializer provides flexibility when deserializing user message content.
173/// It can handle either a simple string (converted to a text part) or an array of
174/// structured content parts. This allows both `"Hello"` and `[{"type": "text", "text": "Hello"}]`
175/// to deserialize successfully.
176fn deserialize_user_content<'de, D>(
177    deserializer: D,
178) -> std::result::Result<Vec<UserContentPart>, D::Error>
179where
180    D: Deserializer<'de>,
181{
182    let value = Value::deserialize(deserializer)?;
183    match value {
184        Value::String(s) => Ok(vec![UserContentPart::Text { text: s }]),
185        Value::Array(arr) => {
186            serde_json::from_value(Value::Array(arr)).map_err(serde::de::Error::custom)
187        }
188        _ => Err(serde::de::Error::custom(
189            "User content must be a string or an array of content parts",
190        )),
191    }
192}
193
/// A message in a conversation with a language model.
///
/// Each variant corresponds to one conversation role and carries the content
/// type appropriate to that role.
///
/// # Variants
///
/// * `System` - Instructions and context for how the model should behave
/// * `User` - Input from the user/human, as a list of content parts
/// * `Assistant` - Output from the language model, as a list of content parts
/// * `Tool` - Results from executing tools that the model requested
///
/// # Serialization
///
/// The enum is internally tagged: the variant is stored in a `role` field whose
/// value is the lowercase variant name (`"system"`, `"user"`, `"assistant"`,
/// `"tool"`), next to the `content` field.
///
/// User content is asymmetric: on deserialization it may be either a plain
/// string or an array of content parts (see `deserialize_user_content`), while
/// serialization always writes the array form.
///
/// # Usage in Prompts
///
/// Messages form the conversation history passed to the model. A typical conversation
/// might follow this pattern:
/// 1. System message (optional) - sets context
/// 2. User message - the human's question
/// 3. Assistant message - the model's response
/// 4. Repeat steps 2-3 as needed for multi-turn conversations
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System message providing instructions and context for model behavior.
    ///
    /// System messages are typically placed at the beginning of a conversation
    /// to establish tone, guidelines, or special instructions for the model.
    System {
        /// Plain text instructions for the model
        content: String,
    },
    /// Message from the human user.
    ///
    /// User messages contain the actual input - questions, requests, or statements
    /// from the human. Content can be text, files, or a mix of both.
    User {
        /// Content parts (text and/or files).
        /// Accepts a bare string or an array when deserializing; a bare string
        /// becomes a single text part.
        #[serde(deserialize_with = "deserialize_user_content")]
        content: Vec<UserContentPart>,
    },
    /// Message from the assistant (language model).
    ///
    /// Assistant messages contain the model's responses, which can include text,
    /// reasoning, tool calls, or references to files.
    Assistant {
        /// Content parts generated by the model
        content: Vec<AssistantContentPart>,
    },
    /// Message containing tool execution results.
    ///
    /// After the model requests a tool call, tool results are provided in a Tool
    /// message so the model can see what happened and make follow-up decisions.
    Tool {
        /// Results from tool executions
        content: Vec<ToolResultPart>,
    },
}
257
/// File content carried either inline as raw bytes or by reference as a URL.
///
/// # Variants
///
/// * `Binary` - Raw file bytes embedded in the message
/// * `Url` - A string pointing to the file location
///
/// # Serialization
///
/// The enum is `#[serde(untagged)]`: no variant name is written, and on
/// deserialization serde tries the variants in declaration order (`Binary`
/// first, then `Url`) and keeps the first one that matches the input.
///
/// NOTE(review): the URL string is not validated or restricted to particular
/// schemes anywhere in this file — confirm whether consumers expect HTTP(S) only.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum FileData {
    /// Raw file bytes embedded in the message
    Binary(Vec<u8>),
    /// URL pointing to the file location
    Url(String),
}
276
/// A content element within a user message.
///
/// A user message holds a list of these parts, so text and files can be mixed
/// within a single message.
///
/// # Variants
///
/// * `Text` - Plain text content
/// * `File` - Binary or URL-based file content with a MIME type
///
/// # Serialization
///
/// Internally tagged with a `type` field holding the lowercase variant name
/// (`"text"` or `"file"`). `rename_all` here applies to the variant names only;
/// field names such as `media_type` are written as declared.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContentPart {
    /// Plain text content from the user
    Text {
        /// The actual text content
        text: String,
    },
    /// File or media content from the user (images, audio, video, documents, etc.)
    File {
        /// The file data, either as binary bytes or as a URL
        data: FileData,
        /// MIME type identifying the file format (e.g., "image/jpeg", "audio/mp3")
        media_type: String,
    },
}
302
/// A content element within an assistant (model) message.
///
/// Assistant messages contain the model's response, which can include various types
/// of content: text answers, reasoning processes, generated files, and tool invocations.
///
/// # Variants
///
/// * `Text` - Plain text response from the model
/// * `Reasoning` - Reasoning output from models that expose it (e.g., o1)
/// * `File` - Generated files or media
/// * `ToolCall` - A request to invoke a tool or function
/// * `ToolResult` - Results from a tool execution
///
/// # Serialization
///
/// Internally tagged with a `type` field holding the kebab-case variant name
/// (`"text"`, `"reasoning"`, `"file"`, `"tool-call"`, `"tool-result"`). The
/// remaining fields come from the wrapped payload type.
///
/// NOTE(review): the payload types (`TextPart`, `ReasoningPart`, `FilePart`,
/// `ToolCallPart`, `ToolResultPart`) are not defined in this file — presumably
/// they come from `super::content`. Confirm none of them declares its own
/// `type` field, which would clash with the tag.
///
/// # Processing Assistant Messages
///
/// When receiving an assistant message, you should examine each content part to determine
/// what action to take. For example, if you encounter a ToolCall, you should execute the
/// tool and include the result in a subsequent message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum AssistantContentPart {
    /// Text response from the model
    Text(TextPart),
    /// Reasoning process from the model (for reasoning models like o1)
    Reasoning(ReasoningPart),
    /// File or media generated by the model
    File(FilePart),
    /// Tool invocation request from the model
    ToolCall(ToolCallPart),
    /// Tool execution result provided by the model
    ToolResult(ToolResultPart),
}
335
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Returns the text of the first content part of a user message.
    ///
    /// Panics if `message` is not a user message or its first part is not text.
    fn first_user_text(message: &Message) -> &str {
        let parts = match message {
            Message::User { content } => content,
            _ => panic!("Expected user message"),
        };
        match &parts[0] {
            UserContentPart::Text { text } => text.as_str(),
            _ => panic!("Expected text content"),
        }
    }

    /// Builds a file part with the given data and MIME type.
    fn file_part(data: FileData, media_type: &str) -> UserContentPart {
        UserContentPart::File {
            data,
            media_type: media_type.into(),
        }
    }

    // System messages serialize with role "system" and a plain string content.
    #[test]
    fn test_message_system() {
        let system = Message::System {
            content: "You are helpful".into(),
        };
        let encoded = serde_json::to_value(&system).unwrap();
        assert_eq!(encoded["role"], "system");
        assert_eq!(encoded["content"], "You are helpful");
    }

    // User messages serialize with role "user" and an array of tagged parts.
    #[test]
    fn test_message_user() {
        let text_part = UserContentPart::Text {
            text: "Hello".into(),
        };
        let user = Message::User {
            content: vec![text_part],
        };
        let encoded = serde_json::to_value(&user).unwrap();
        assert_eq!(encoded["role"], "user");
        assert_eq!(encoded["content"][0]["type"], "text");
    }

    // File parts carrying raw bytes keep their tag and media type.
    #[test]
    fn test_user_content_file_binary() {
        let part = file_part(FileData::Binary(vec![1, 2, 3]), "image/png");
        let encoded = serde_json::to_value(&part).unwrap();
        assert_eq!(encoded["type"], "file");
        assert_eq!(encoded["media_type"], "image/png");
    }

    // File parts carrying a URL keep their tag and media type.
    #[test]
    fn test_user_content_file_url() {
        let part = file_part(
            FileData::Url("https://example.com/image.jpg".into()),
            "image/jpeg",
        );
        let encoded = serde_json::to_value(&part).unwrap();
        assert_eq!(encoded["type"], "file");
        assert_eq!(encoded["media_type"], "image/jpeg");
    }

    // A new prompt is empty and grows through the Vec API exposed by DerefMut.
    #[test]
    fn test_prompt_struct() {
        let mut conversation = Prompt::new();
        assert!(conversation.is_empty());

        let system = Message::System {
            content: "test".into(),
        };
        conversation.push(system);
        assert_eq!(conversation.len(), 1);
    }

    // A string slice becomes a single user text message.
    #[test]
    fn test_prompt_from_string() {
        let prompt = Prompt::from("Hello");
        assert_eq!(prompt.len(), 1);
        assert_eq!(first_user_text(&prompt[0]), "Hello");
    }

    // A vector of messages is wrapped unchanged.
    #[test]
    fn test_prompt_from_vec() {
        let messages = vec![Message::System {
            content: "s".into(),
        }];
        let prompt = Prompt::from(messages);
        assert_eq!(prompt.len(), 1);
    }

    // A JSON string converts like a plain string prompt.
    #[test]
    fn test_try_from_json_string() {
        let prompt = Prompt::try_from(json!("Hello world")).unwrap();
        assert_eq!(prompt.len(), 1);
        assert_eq!(first_user_text(&prompt[0]), "Hello world");
    }

    // A JSON object converts to a single message.
    #[test]
    fn test_try_from_json_object() {
        let raw = json!({
            "role": "system",
            "content": "You are helpful"
        });
        let prompt = Prompt::try_from(raw).unwrap();
        assert_eq!(prompt.len(), 1);
        match prompt.first() {
            Some(Message::System { content }) => assert_eq!(content, "You are helpful"),
            _ => panic!("Expected system message"),
        }
    }

    // A JSON array converts to one message per element.
    #[test]
    fn test_try_from_json_array() {
        let raw = json!([
            { "role": "system", "content": "System" },
            {
                "role": "user",
                "content": [
                    { "type": "text", "text": "User" }
                ]
            }
        ]);
        let prompt = Prompt::try_from(raw).unwrap();
        assert_eq!(prompt.len(), 2);
    }

    // Any other JSON type (here a number) is rejected.
    #[test]
    fn test_try_from_invalid_json() {
        let outcome = Prompt::try_from(json!(123));
        assert!(outcome.is_err());
    }

    // User content given as a bare string deserializes to one text part.
    #[test]
    fn test_message_user_string_content() {
        let raw = json!({
            "role": "user",
            "content": "Hello world"
        });
        let message: Message = serde_json::from_value(raw).unwrap();
        match &message {
            Message::User { content } => assert_eq!(content.len(), 1),
            _ => panic!("Expected user message"),
        }
        assert_eq!(first_user_text(&message), "Hello world");
    }

    // User content given as an array keeps every part.
    #[test]
    fn test_message_user_array_content() {
        let raw = json!({
            "role": "user",
            "content": [
                { "type": "text", "text": "Hello" },
                { "type": "text", "text": "World" }
            ]
        });
        let message: Message = serde_json::from_value(raw).unwrap();
        match message {
            Message::User { content } => assert_eq!(content.len(), 2),
            _ => panic!("Expected user message"),
        }
    }
}