ai_sdk_provider/language_model/prompt.rs
1use super::content::*;
2use crate::Error;
3use serde::{Deserialize, Deserializer, Serialize};
4use serde_json::Value;
5use std::ops::{Deref, DerefMut};
6
7/// A container for conversation messages to send to a language model.
8///
9/// The `Prompt` struct wraps a vector of `Message` objects and provides flexible
10/// conversion traits to support multiple input formats. This enables you to construct
11/// prompts from simple strings, structured message arrays, or individual messages.
12///
13/// # Construction Methods
14///
15/// `Prompt` can be created in several ways:
16/// - From a simple string: `Prompt::from("Hello")`
17/// - From a slice: `Prompt::from("Hello")`
18/// - From a Vec of messages: `Prompt::from(vec![message1, message2])`
19/// - From a single message: `Prompt::from(message)`
20/// - From JSON via TryFrom: `json_value.try_into()`
21///
22/// # Flexibility
23///
24/// The struct implements `Deref` and `DerefMut` to provide Vec-like access to the
25/// underlying messages. You can use standard Vec methods like `len()`, `push()`, and `iter()`.
26///
27/// # Usage
28///
29/// ```ignore
30/// // Simple text prompt
31/// let prompt: Prompt = "What is 2+2?".into();
32///
33/// // Structured messages
34/// let messages = vec![
35/// Message::System { content: "You are helpful".into() },
36/// Message::User { content: vec![UserContentPart::Text { text: "Hello".into() }] },
37/// ];
38/// let prompt: Prompt = messages.into();
39///
40/// // JSON conversion
41/// let json = json!({"role": "user", "content": "Hello"});
42/// let prompt: Prompt = json.try_into()?;
43/// ```
44#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
45pub struct Prompt(Vec<Message>);
46
47impl Prompt {
48 /// Creates a new empty prompt with no messages.
49 ///
50 /// # Returns
51 ///
52 /// A prompt containing an empty message vector that can be populated with `push()`.
53 pub fn new() -> Self {
54 Self(Vec::new())
55 }
56
57 /// Creates a prompt from a single message.
58 ///
59 /// # Arguments
60 ///
61 /// * `message` - The message to initialize the prompt with
62 ///
63 /// # Returns
64 ///
65 /// A prompt containing the provided message as its sole element.
66 ///
67 /// # Example
68 ///
69 /// ```ignore
70 /// let msg = Message::User { content: vec![UserContentPart::Text { text: "Hello".into() }] };
71 /// let prompt = Prompt::from_message(msg);
72 /// ```
73 pub fn from_message(message: Message) -> Self {
74 Self(vec![message])
75 }
76}
77
78/// Enables Vec-like access to the underlying message vector.
79///
80/// Through this implementation, you can use standard Vec methods directly on Prompt,
81/// such as `len()`, `push()`, `iter()`, and indexing operations.
82impl Deref for Prompt {
83 type Target = Vec<Message>;
84
85 fn deref(&self) -> &Self::Target {
86 &self.0
87 }
88}
89
90/// Enables mutable Vec-like access to the underlying message vector.
91impl DerefMut for Prompt {
92 fn deref_mut(&mut self) -> &mut Self::Target {
93 &mut self.0
94 }
95}
96
97/// Creates a prompt from a String, automatically treating it as a user message.
98///
99/// The string is converted into a User message with a single text content part.
100/// This provides a convenient way to create simple text prompts.
101impl From<String> for Prompt {
102 fn from(s: String) -> Self {
103 Prompt(vec![Message::User {
104 content: vec![UserContentPart::Text { text: s }],
105 }])
106 }
107}
108
109/// Creates a prompt from a string slice, automatically treating it as a user message.
110impl From<&str> for Prompt {
111 fn from(s: &str) -> Self {
112 Prompt::from(s.to_string())
113 }
114}
115
116/// Creates a prompt from a vector of messages.
117///
118/// Useful when you have multiple messages with different roles that you want to
119/// form a complete conversation history.
120impl From<Vec<Message>> for Prompt {
121 fn from(v: Vec<Message>) -> Self {
122 Prompt(v)
123 }
124}
125
126/// Creates a prompt from a single message.
127///
128/// Wraps the message in a vector to form a single-message prompt.
129impl From<Message> for Prompt {
130 fn from(m: Message) -> Self {
131 Prompt(vec![m])
132 }
133}
134
135/// Creates a prompt from a JSON value with flexible input formats.
136///
137/// This trait supports converting from various JSON structures:
138/// - A JSON string: converted to a User message with text content
139/// - A JSON object: deserialized as a single Message
140/// - A JSON array: deserialized as an array of Message objects
141///
142/// # Errors
143///
144/// Returns an error if the JSON structure doesn't match one of the supported formats
145/// or if deserialization fails.
146impl TryFrom<Value> for Prompt {
147 type Error = Error;
148
149 fn try_from(value: Value) -> std::result::Result<Self, Self::Error> {
150 match value {
151 Value::String(s) => Ok(Prompt::from(s)),
152
153 Value::Array(arr) => {
154 let messages: Vec<Message> = serde_json::from_value(Value::Array(arr))
155 .map_err(|e| format!("Invalid prompt array: {}", e))?;
156 Ok(Prompt(messages))
157 }
158
159 Value::Object(obj) => {
160 let message: Message = serde_json::from_value(Value::Object(obj))
161 .map_err(|e| format!("Invalid prompt object: {}", e))?;
162 Ok(Prompt(vec![message]))
163 }
164
165 _ => Err("JSON must be a string, object, or array of messages".into()),
166 }
167 }
168}
169
170/// Custom deserializer that handles both string and array representations of user content.
171///
172/// This deserializer provides flexibility when deserializing user message content.
173/// It can handle either a simple string (converted to a text part) or an array of
174/// structured content parts. This allows both `"Hello"` and `[{"type": "text", "text": "Hello"}]`
175/// to deserialize successfully.
176fn deserialize_user_content<'de, D>(
177 deserializer: D,
178) -> std::result::Result<Vec<UserContentPart>, D::Error>
179where
180 D: Deserializer<'de>,
181{
182 let value = Value::deserialize(deserializer)?;
183 match value {
184 Value::String(s) => Ok(vec![UserContentPart::Text { text: s }]),
185 Value::Array(arr) => {
186 serde_json::from_value(Value::Array(arr)).map_err(serde::de::Error::custom)
187 }
188 _ => Err(serde::de::Error::custom(
189 "User content must be a string or an array of content parts",
190 )),
191 }
192}
193
/// A message in a conversation with a language model.
///
/// Messages represent different roles in the conversation: system instructions, user inputs,
/// assistant responses, and tool results. Each message contains content appropriate to its role.
///
/// # Variants
///
/// * `System` - Provides instructions and context for how the model should behave
/// * `User` - Input from the user/human, typically containing text and/or files
/// * `Assistant` - Output from the language model, typically containing text and/or tool calls
/// * `Tool` - Results from executing tools that the model requested
///
/// # Serialization
///
/// Messages are internally tagged. The lowercased variant name (`"system"`, `"user"`,
/// `"assistant"`, `"tool"`) is stored in a `role` field next to the variant's `content`
/// field, e.g. `{"role": "system", "content": "You are helpful"}`.
///
/// `User` content is deserialized flexibly and accepts either a plain string or an array
/// of content parts. Serialization is not symmetric: `User` content is always written as
/// an array of parts. A string input is therefore emitted back as
/// `[{"type": "text", "text": ...}]`.
///
/// # Usage in Prompts
///
/// Messages form the conversation history passed to the model. A typical conversation
/// might follow this pattern:
/// 1. System message (optional) - sets context
/// 2. User message - the human's question
/// 3. Assistant message - the model's response
/// 4. Repeat steps 2-3 as needed for multi-turn conversations
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System message providing instructions and context for model behavior.
    ///
    /// System messages are typically placed at the beginning of a conversation
    /// to establish tone, guidelines, or special instructions for the model.
    System {
        /// Plain text instructions for the model (serialized as a JSON string)
        content: String,
    },
    /// Message from the human user.
    ///
    /// User messages contain the actual input - questions, requests, or statements
    /// from the human. Content can be text, files, or a mix of both.
    User {
        /// Content parts, typically text and/or file references.
        /// Deserialized flexibly to accept string or array formats; only deserialization
        /// is customised, so serialization always produces an array of parts.
        #[serde(deserialize_with = "deserialize_user_content")]
        content: Vec<UserContentPart>,
    },
    /// Message from the assistant (language model).
    ///
    /// Assistant messages contain the model's responses, which can include text,
    /// reasoning, tool calls, or references to files.
    Assistant {
        /// Content parts generated by the model
        content: Vec<AssistantContentPart>,
    },
    /// Message containing tool execution results.
    ///
    /// After the model requests a tool call, tool results are provided in a Tool
    /// message so the model can see what happened and make follow-up decisions.
    Tool {
        /// Results from tool executions
        content: Vec<ToolResultPart>,
    },
}
257
/// Represents file content as either raw binary data or a URL reference.
///
/// This enum provides flexibility in how file content is transmitted. Binary data
/// is useful for small files or when you have the data in memory, while URLs are
/// more efficient for large files or remote resources.
///
/// # Variants
///
/// * `Binary` - Raw file bytes embedded in the message
/// * `Url` - URL pointing to the file location (HTTP/HTTPS or other schemes)
///
/// # Serialization
///
/// The enum is `untagged`, so no variant name is written. `Binary` is emitted as a
/// sequence of byte values (with serde_json, a JSON array of integers) and `Url` as a
/// plain string. When deserializing, serde tries the variants in declaration order:
/// - an array of integers in `0..=255` becomes `Binary`;
/// - a string becomes `Url`;
/// - anything else fails both variants and is an error.
///
/// NOTE(review): `Url` accepts any string; nothing here validates that it is a URL.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum FileData {
    /// Raw file bytes embedded in the message
    Binary(Vec<u8>),
    /// URL pointing to the file location
    Url(String),
}
276
/// A content element within a user message.
///
/// User messages can contain multiple content parts, allowing for complex inputs
/// that mix text with various types of media like images or audio files.
///
/// # Variants
///
/// * `Text` - Plain text content
/// * `File` - Binary or URL-based file content with MIME type
///
/// # Serialization
///
/// Parts are internally tagged by a `type` field holding the lowercased variant name
/// (`"text"` or `"file"`). Field names are not renamed, so a file part is written as
/// `{"type": "file", "data": ..., "media_type": "image/png"}` (snake_case `media_type`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContentPart {
    /// Plain text content from the user
    Text {
        /// The actual text content
        text: String,
    },
    /// File or media content from the user (images, audio, video, documents, etc.)
    File {
        /// The file data, either as binary bytes or as a URL
        data: FileData,
        /// MIME type identifying the file format (e.g., "image/jpeg", "audio/mp3")
        media_type: String,
    },
}
302
/// A content element within an assistant (model) message.
///
/// Assistant messages contain the model's response, which can include various types
/// of content: text answers, reasoning processes, generated files, and tool invocations.
///
/// # Variants
///
/// * `Text` - Plain text response from the model
/// * `Reasoning` - Internal reasoning from specialized models (e.g., o1)
/// * `File` - Generated files or media
/// * `ToolCall` - A request to invoke a tool or function
/// * `ToolResult` - Results from a tool execution
///
/// # Serialization
///
/// Parts are internally tagged by a `type` field whose value is the kebab-case variant
/// name: `"text"`, `"reasoning"`, `"file"`, `"tool-call"`, `"tool-result"`. This differs
/// from [`UserContentPart`], which uses lowercase names (`"text"`, `"file"`).
///
/// Each variant wraps a part type (presumably from `super::content`). Serde merges the
/// `type` tag into that part's own fields.
///
/// NOTE(review): internally tagged newtype variants only (de)serialize when the wrapped
/// type is represented as a struct or map. This presumably holds for `TextPart`,
/// `ToolCallPart`, etc.; confirm in the content module.
///
/// # Processing Assistant Messages
///
/// When receiving an assistant message, you should examine each content part to determine
/// what action to take. For example, if you encounter a ToolCall, you should execute the
/// tool and include the result in a subsequent message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum AssistantContentPart {
    /// Text response from the model
    Text(TextPart),
    /// Reasoning process from the model (for reasoning models like o1)
    Reasoning(ReasoningPart),
    /// File or media generated by the model
    File(FilePart),
    /// Tool invocation request from the model (tagged as `"tool-call"`)
    ToolCall(ToolCallPart),
    /// Tool execution result provided by the model (tagged as `"tool-result"`)
    ToolResult(ToolResultPart),
}
335
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_message_system() {
        let msg = Message::System {
            content: "You are helpful".into(),
        };
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "system");
        assert_eq!(json["content"], "You are helpful");
    }

    #[test]
    fn test_message_user() {
        let msg = Message::User {
            content: vec![UserContentPart::Text {
                text: "Hello".into(),
            }],
        };
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "user");
        assert_eq!(json["content"][0]["type"], "text");
    }

    #[test]
    fn test_user_content_file_binary() {
        let part = UserContentPart::File {
            data: FileData::Binary(vec![1, 2, 3]),
            media_type: "image/png".into(),
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "file");
        assert_eq!(json["media_type"], "image/png");
    }

    #[test]
    fn test_user_content_file_url() {
        let part = UserContentPart::File {
            data: FileData::Url("https://example.com/image.jpg".into()),
            media_type: "image/jpeg".into(),
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["type"], "file");
        assert_eq!(json["media_type"], "image/jpeg");
    }

    #[test]
    fn test_prompt_struct() {
        let mut prompt = Prompt::new();
        assert!(prompt.is_empty());

        prompt.push(Message::System {
            content: "test".into(),
        });
        assert_eq!(prompt.len(), 1);
    }

    #[test]
    fn test_prompt_from_string() {
        let prompt: Prompt = "Hello".into();
        assert_eq!(prompt.len(), 1);
        match &prompt[0] {
            Message::User { content } => match &content[0] {
                UserContentPart::Text { text } => assert_eq!(text, "Hello"),
                _ => panic!("Expected text content"),
            },
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_prompt_from_str_and_string_agree() {
        let from_str: Prompt = "Hello".into();
        let from_string: Prompt = String::from("Hello").into();
        assert_eq!(from_str, from_string);
    }

    #[test]
    fn test_prompt_from_vec() {
        let v = vec![Message::System {
            content: "s".into(),
        }];
        let prompt: Prompt = v.into();
        assert_eq!(prompt.len(), 1);
    }

    #[test]
    fn test_prompt_from_message_matches_from_impl() {
        let msg = Message::System {
            content: "s".into(),
        };
        let via_ctor = Prompt::from_message(msg.clone());
        let via_from: Prompt = msg.into();
        assert_eq!(via_ctor, via_from);
        assert_eq!(via_ctor.len(), 1);
    }

    #[test]
    fn test_prompt_serializes_as_array() {
        // The derived (de)serialization is a transparent newtype over Vec<Message>.
        let prompt: Prompt = "Hi".into();
        let json = serde_json::to_value(&prompt).unwrap();
        assert!(json.is_array());
        let back: Prompt = serde_json::from_value(json).unwrap();
        assert_eq!(back, prompt);
    }

    #[test]
    fn test_try_from_json_string() {
        let json = json!("Hello world");
        let prompt: Prompt = json.try_into().unwrap();
        assert_eq!(prompt.len(), 1);
        match &prompt[0] {
            Message::User { content } => match &content[0] {
                UserContentPart::Text { text } => assert_eq!(text, "Hello world"),
                _ => panic!("Expected text content"),
            },
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_try_from_json_object() {
        let json = json!({
            "role": "system",
            "content": "You are helpful"
        });
        let prompt: Prompt = json.try_into().unwrap();
        assert_eq!(prompt.len(), 1);
        match &prompt[0] {
            Message::System { content } => assert_eq!(content, "You are helpful"),
            _ => panic!("Expected system message"),
        }
    }

    #[test]
    fn test_try_from_json_array() {
        let json = json!([
            {
                "role": "system",
                "content": "System"
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "User"
                    }
                ]
            }
        ]);
        let prompt: Prompt = json.try_into().unwrap();
        assert_eq!(prompt.len(), 2);
    }

    #[test]
    fn test_try_from_invalid_json() {
        let json = json!(123); // Invalid type
        let result: Result<Prompt, _> = json.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_try_from_null_is_rejected() {
        let result: Result<Prompt, _> = Value::Null.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_try_from_invalid_object() {
        // An unknown role cannot be decoded as a Message.
        let json = json!({ "role": "unknown", "content": "x" });
        let result: Result<Prompt, _> = json.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_message_user_string_content() {
        let json = json!({
            "role": "user",
            "content": "Hello world"
        });
        let msg: Message = serde_json::from_value(json).unwrap();
        match msg {
            Message::User { content } => {
                assert_eq!(content.len(), 1);
                match &content[0] {
                    UserContentPart::Text { text } => assert_eq!(text, "Hello world"),
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_user_string_content_serializes_as_array() {
        // String shorthand is accepted on input but always written back as parts.
        let msg: Message = serde_json::from_value(json!({
            "role": "user",
            "content": "Hi"
        }))
        .unwrap();
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["content"][0]["type"], "text");
        assert_eq!(json["content"][0]["text"], "Hi");
    }

    #[test]
    fn test_message_user_array_content() {
        let json = json!({
            "role": "user",
            "content": [
                { "type": "text", "text": "Hello" },
                { "type": "text", "text": "World" }
            ]
        });
        let msg: Message = serde_json::from_value(json).unwrap();
        match msg {
            Message::User { content } => {
                assert_eq!(content.len(), 2);
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_file_data_untagged_deserialization() {
        let binary: FileData = serde_json::from_value(json!([1, 2, 3])).unwrap();
        assert_eq!(binary, FileData::Binary(vec![1, 2, 3]));

        let url: FileData =
            serde_json::from_value(json!("https://example.com/a.png")).unwrap();
        assert_eq!(url, FileData::Url("https://example.com/a.png".into()));
    }
}