rig/completion/
message.rs

1use std::{convert::Infallible, str::FromStr};
2
3use crate::OneOrMany;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use super::CompletionError;
8
9// ================================================================
10// Message models
11// ================================================================
12
13/// A message represents a run of input (user) and output (assistant).
14/// Each message type (based on it's `role`) can contain a atleast one bit of content such as text,
15///  images, audio, documents, or tool related information. While each message type can contain
16///  multiple content, most often, you'll only see one content type per message
17///  (an image w/ a description, etc).
18///
19/// Each provider is responsible with converting the generic message into it's provider specific
20///  type using `From` or `TryFrom` traits. Since not every provider supports every feature, the
21///  conversion can be lossy (providing an image might be discarded for a non-image supporting
22///  provider) though the message being converted back and forth should always be the same.
23#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
24#[serde(tag = "role", rename_all = "lowercase")]
25pub enum Message {
26    /// User message containing one or more content types defined by `UserContent`.
27    User { content: OneOrMany<UserContent> },
28
29    /// Assistant message containing one or more content types defined by `AssistantContent`.
30    Assistant {
31        content: OneOrMany<AssistantContent>,
32    },
33}
34
35/// Describes the content of a message, which can be text, a tool result, an image, audio, or
36///  a document. Dependent on provider supporting the content type. Multimedia content is generally
37///  base64 (defined by it's format) encoded but additionally supports urls (for some providers).
38#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
39#[serde(tag = "type", rename_all = "lowercase")]
40pub enum UserContent {
41    Text(Text),
42    ToolResult(ToolResult),
43    Image(Image),
44    Audio(Audio),
45    Document(Document),
46}
47
48/// Describes responses from a provider which is either text or a tool call.
49#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
50#[serde(untagged)]
51pub enum AssistantContent {
52    Text(Text),
53    ToolCall(ToolCall),
54}
55
56/// Tool result content containing information about a tool call and it's resulting content.
57#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
58pub struct ToolResult {
59    pub id: String,
60    pub content: OneOrMany<ToolResultContent>,
61}
62
63/// Describes the content of a tool result, which can be text or an image.
64#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
65pub enum ToolResultContent {
66    Text(Text),
67    Image(Image),
68}
69
70/// Describes a tool call with an id and function to call, generally produced by a provider.
71#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
72pub struct ToolCall {
73    pub id: String,
74    pub function: ToolFunction,
75}
76
77/// Describes a tool function to call with a name and arguments, generally produced by a provider.
78#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
79pub struct ToolFunction {
80    pub name: String,
81    pub arguments: serde_json::Value,
82}
83
84// ================================================================
85// Base content models
86// ================================================================
87
88/// Basic text content.
89#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
90pub struct Text {
91    pub text: String,
92}
93
94/// Image content containing image data and metadata about it.
95#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
96pub struct Image {
97    pub data: String,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub format: Option<ContentFormat>,
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub media_type: Option<ImageMediaType>,
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub detail: Option<ImageDetail>,
104}
105
106/// Audio content containing audio data and metadata about it.
107#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
108pub struct Audio {
109    pub data: String,
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub format: Option<ContentFormat>,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub media_type: Option<AudioMediaType>,
114}
115
116/// Document content containing document data and metadata about it.
117#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
118pub struct Document {
119    pub data: String,
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub format: Option<ContentFormat>,
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub media_type: Option<DocumentMediaType>,
124}
125
126/// Describes the format of the content, which can be base64 or string.
127#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
128#[serde(rename_all = "lowercase")]
129pub enum ContentFormat {
130    #[default]
131    Base64,
132    String,
133}
134
135/// Helper enum that tracks the media type of the content.
136#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
137pub enum MediaType {
138    Image(ImageMediaType),
139    Audio(AudioMediaType),
140    Document(DocumentMediaType),
141}
142
143/// Describes the image media type of the content. Not every provider supports every media type.
144/// Convertible to and from MIME type strings.
145#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
146#[serde(rename_all = "lowercase")]
147pub enum ImageMediaType {
148    JPEG,
149    PNG,
150    GIF,
151    WEBP,
152    HEIC,
153    HEIF,
154    SVG,
155}
156
157/// Describes the document media type of the content. Not every provider supports every media type.
158/// Includes also programming languages as document types for providers who support code running.
159/// Convertible to and from MIME type strings.
160#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
161#[serde(rename_all = "lowercase")]
162pub enum DocumentMediaType {
163    PDF,
164    TXT,
165    RTF,
166    HTML,
167    CSS,
168    MARKDOWN,
169    CSV,
170    XML,
171    Javascript,
172    Python,
173}
174
175/// Describes the audio media type of the content. Not every provider supports every media type.
176/// Convertible to and from MIME type strings.
177#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
178#[serde(rename_all = "lowercase")]
179pub enum AudioMediaType {
180    WAV,
181    MP3,
182    AIFF,
183    AAC,
184    OGG,
185    FLAC,
186}
187
188/// Describes the detail of the image content, which can be low, high, or auto (open-ai specific).
189#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
190#[serde(rename_all = "lowercase")]
191pub enum ImageDetail {
192    Low,
193    High,
194    #[default]
195    Auto,
196}
197
198// ================================================================
199// Impl. for message models
200// ================================================================
201
202impl Message {
203    /// This helper method is primarily used to extract the first string prompt from a `Message`.
204    /// Since `Message` might have more than just text content, we need to find the first text.
205    pub(crate) fn rag_text(&self) -> Option<String> {
206        match self {
207            Message::User { content } => {
208                for item in content.iter() {
209                    if let UserContent::Text(Text { text }) = item {
210                        return Some(text.clone());
211                    }
212                }
213                None
214            }
215            _ => None,
216        }
217    }
218
219    /// Helper constructor to make creating user messages easier.
220    pub fn user(text: impl Into<String>) -> Self {
221        Message::User {
222            content: OneOrMany::one(UserContent::text(text)),
223        }
224    }
225
226    /// Helper constructor to make creating assistant messages easier.
227    pub fn assistant(text: impl Into<String>) -> Self {
228        Message::Assistant {
229            content: OneOrMany::one(AssistantContent::text(text)),
230        }
231    }
232
233    /// Helper constructor to make creating tool result messages easier.
234    pub fn tool_result(id: impl Into<String>, content: impl Into<String>) -> Self {
235        Message::User {
236            content: OneOrMany::one(UserContent::ToolResult(ToolResult {
237                id: id.into(),
238                content: OneOrMany::one(ToolResultContent::text(content)),
239            })),
240        }
241    }
242}
243
244impl UserContent {
245    /// Helper constructor to make creating user text content easier.
246    pub fn text(text: impl Into<String>) -> Self {
247        UserContent::Text(text.into().into())
248    }
249
250    /// Helper constructor to make creating user image content easier.
251    pub fn image(
252        data: impl Into<String>,
253        format: Option<ContentFormat>,
254        media_type: Option<ImageMediaType>,
255        detail: Option<ImageDetail>,
256    ) -> Self {
257        UserContent::Image(Image {
258            data: data.into(),
259            format,
260            media_type,
261            detail,
262        })
263    }
264
265    /// Helper constructor to make creating user audio content easier.
266    pub fn audio(
267        data: impl Into<String>,
268        format: Option<ContentFormat>,
269        media_type: Option<AudioMediaType>,
270    ) -> Self {
271        UserContent::Audio(Audio {
272            data: data.into(),
273            format,
274            media_type,
275        })
276    }
277
278    /// Helper constructor to make creating user document content easier.
279    pub fn document(
280        data: impl Into<String>,
281        format: Option<ContentFormat>,
282        media_type: Option<DocumentMediaType>,
283    ) -> Self {
284        UserContent::Document(Document {
285            data: data.into(),
286            format,
287            media_type,
288        })
289    }
290
291    /// Helper constructor to make creating user tool result content easier.
292    pub fn tool_result(id: impl Into<String>, content: OneOrMany<ToolResultContent>) -> Self {
293        UserContent::ToolResult(ToolResult {
294            id: id.into(),
295            content,
296        })
297    }
298}
299
300impl AssistantContent {
301    /// Helper constructor to make creating assistant text content easier.
302    pub fn text(text: impl Into<String>) -> Self {
303        AssistantContent::Text(text.into().into())
304    }
305
306    /// Helper constructor to make creating assistant tool call content easier.
307    pub fn tool_call(
308        id: impl Into<String>,
309        name: impl Into<String>,
310        arguments: serde_json::Value,
311    ) -> Self {
312        AssistantContent::ToolCall(ToolCall {
313            id: id.into(),
314            function: ToolFunction {
315                name: name.into(),
316                arguments,
317            },
318        })
319    }
320}
321
322impl ToolResultContent {
323    /// Helper constructor to make creating tool result text content easier.
324    pub fn text(text: impl Into<String>) -> Self {
325        ToolResultContent::Text(text.into().into())
326    }
327
328    /// Helper constructor to make creating tool result image content easier.
329    pub fn image(
330        data: impl Into<String>,
331        format: Option<ContentFormat>,
332        media_type: Option<ImageMediaType>,
333        detail: Option<ImageDetail>,
334    ) -> Self {
335        ToolResultContent::Image(Image {
336            data: data.into(),
337            format,
338            media_type,
339            detail,
340        })
341    }
342}
343
/// Two-way conversion between MIME type strings and media types.
pub trait MimeType {
    /// Parses a MIME type string, yielding `None` when the implementor has no matching
    /// value for it.
    fn from_mime_type(mime_type: &str) -> Option<Self>
    where
        Self: Sized;

    /// Returns the MIME type string for this value.
    fn to_mime_type(&self) -> &'static str;
}
351
352impl MimeType for MediaType {
353    fn from_mime_type(mime_type: &str) -> Option<Self> {
354        ImageMediaType::from_mime_type(mime_type)
355            .map(MediaType::Image)
356            .or_else(|| {
357                DocumentMediaType::from_mime_type(mime_type)
358                    .map(MediaType::Document)
359                    .or_else(|| AudioMediaType::from_mime_type(mime_type).map(MediaType::Audio))
360            })
361    }
362
363    fn to_mime_type(&self) -> &'static str {
364        match self {
365            MediaType::Image(media_type) => media_type.to_mime_type(),
366            MediaType::Audio(media_type) => media_type.to_mime_type(),
367            MediaType::Document(media_type) => media_type.to_mime_type(),
368        }
369    }
370}
371
372impl MimeType for ImageMediaType {
373    fn from_mime_type(mime_type: &str) -> Option<Self> {
374        match mime_type {
375            "image/jpeg" => Some(ImageMediaType::JPEG),
376            "image/png" => Some(ImageMediaType::PNG),
377            "image/gif" => Some(ImageMediaType::GIF),
378            "image/webp" => Some(ImageMediaType::WEBP),
379            "image/heic" => Some(ImageMediaType::HEIC),
380            "image/heif" => Some(ImageMediaType::HEIF),
381            "image/svg+xml" => Some(ImageMediaType::SVG),
382            _ => None,
383        }
384    }
385
386    fn to_mime_type(&self) -> &'static str {
387        match self {
388            ImageMediaType::JPEG => "image/jpeg",
389            ImageMediaType::PNG => "image/png",
390            ImageMediaType::GIF => "image/gif",
391            ImageMediaType::WEBP => "image/webp",
392            ImageMediaType::HEIC => "image/heic",
393            ImageMediaType::HEIF => "image/heif",
394            ImageMediaType::SVG => "image/svg+xml",
395        }
396    }
397}
398
399impl MimeType for DocumentMediaType {
400    fn from_mime_type(mime_type: &str) -> Option<Self> {
401        match mime_type {
402            "application/pdf" => Some(DocumentMediaType::PDF),
403            "text/plain" => Some(DocumentMediaType::TXT),
404            "text/rtf" => Some(DocumentMediaType::RTF),
405            "text/html" => Some(DocumentMediaType::HTML),
406            "text/css" => Some(DocumentMediaType::CSS),
407            "text/md" | "text/markdown" => Some(DocumentMediaType::MARKDOWN),
408            "text/csv" => Some(DocumentMediaType::CSV),
409            "text/xml" => Some(DocumentMediaType::XML),
410            "application/x-javascript" | "text/x-javascript" => Some(DocumentMediaType::Javascript),
411            "application/x-python" | "text/x-python" => Some(DocumentMediaType::Python),
412            _ => None,
413        }
414    }
415
416    fn to_mime_type(&self) -> &'static str {
417        match self {
418            DocumentMediaType::PDF => "application/pdf",
419            DocumentMediaType::TXT => "text/plain",
420            DocumentMediaType::RTF => "text/rtf",
421            DocumentMediaType::HTML => "text/html",
422            DocumentMediaType::CSS => "text/css",
423            DocumentMediaType::MARKDOWN => "text/markdown",
424            DocumentMediaType::CSV => "text/csv",
425            DocumentMediaType::XML => "text/xml",
426            DocumentMediaType::Javascript => "application/x-javascript",
427            DocumentMediaType::Python => "application/x-python",
428        }
429    }
430}
431
432impl MimeType for AudioMediaType {
433    fn from_mime_type(mime_type: &str) -> Option<Self> {
434        match mime_type {
435            "audio/wav" => Some(AudioMediaType::WAV),
436            "audio/mp3" => Some(AudioMediaType::MP3),
437            "audio/aiff" => Some(AudioMediaType::AIFF),
438            "audio/aac" => Some(AudioMediaType::AAC),
439            "audio/ogg" => Some(AudioMediaType::OGG),
440            "audio/flac" => Some(AudioMediaType::FLAC),
441            _ => None,
442        }
443    }
444
445    fn to_mime_type(&self) -> &'static str {
446        match self {
447            AudioMediaType::WAV => "audio/wav",
448            AudioMediaType::MP3 => "audio/mp3",
449            AudioMediaType::AIFF => "audio/aiff",
450            AudioMediaType::AAC => "audio/aac",
451            AudioMediaType::OGG => "audio/ogg",
452            AudioMediaType::FLAC => "audio/flac",
453        }
454    }
455}
456
457impl std::str::FromStr for ImageDetail {
458    type Err = ();
459
460    fn from_str(s: &str) -> Result<Self, Self::Err> {
461        match s.to_lowercase().as_str() {
462            "low" => Ok(ImageDetail::Low),
463            "high" => Ok(ImageDetail::High),
464            "auto" => Ok(ImageDetail::Auto),
465            _ => Err(()),
466        }
467    }
468}
469
470// ================================================================
471// FromStr, From<String>, and From<&str> impls
472// ================================================================
473
474impl From<String> for Text {
475    fn from(text: String) -> Self {
476        Text { text }
477    }
478}
479
480impl From<&String> for Text {
481    fn from(text: &String) -> Self {
482        text.to_owned().into()
483    }
484}
485
486impl From<&str> for Text {
487    fn from(text: &str) -> Self {
488        text.to_owned().into()
489    }
490}
491
492impl FromStr for Text {
493    type Err = Infallible;
494
495    fn from_str(s: &str) -> Result<Self, Self::Err> {
496        Ok(s.into())
497    }
498}
499
500impl From<String> for Message {
501    fn from(text: String) -> Self {
502        Message::User {
503            content: OneOrMany::one(UserContent::Text(text.into())),
504        }
505    }
506}
507
508impl From<&str> for Message {
509    fn from(text: &str) -> Self {
510        Message::User {
511            content: OneOrMany::one(UserContent::Text(text.into())),
512        }
513    }
514}
515
516impl From<&String> for Message {
517    fn from(text: &String) -> Self {
518        Message::User {
519            content: OneOrMany::one(UserContent::Text(text.into())),
520        }
521    }
522}
523
524impl From<Text> for Message {
525    fn from(text: Text) -> Self {
526        Message::User {
527            content: OneOrMany::one(UserContent::Text(text)),
528        }
529    }
530}
531
532impl From<Image> for Message {
533    fn from(image: Image) -> Self {
534        Message::User {
535            content: OneOrMany::one(UserContent::Image(image)),
536        }
537    }
538}
539
540impl From<Audio> for Message {
541    fn from(audio: Audio) -> Self {
542        Message::User {
543            content: OneOrMany::one(UserContent::Audio(audio)),
544        }
545    }
546}
547
548impl From<Document> for Message {
549    fn from(document: Document) -> Self {
550        Message::User {
551            content: OneOrMany::one(UserContent::Document(document)),
552        }
553    }
554}
555
556impl From<String> for ToolResultContent {
557    fn from(text: String) -> Self {
558        ToolResultContent::text(text)
559    }
560}
561
562impl From<String> for AssistantContent {
563    fn from(text: String) -> Self {
564        AssistantContent::text(text)
565    }
566}
567
568impl From<String> for UserContent {
569    fn from(text: String) -> Self {
570        UserContent::text(text)
571    }
572}
573
574impl From<AssistantContent> for Message {
575    fn from(content: AssistantContent) -> Self {
576        Message::Assistant {
577            content: OneOrMany::one(content),
578        }
579    }
580}
581
582impl From<UserContent> for Message {
583    fn from(content: UserContent) -> Self {
584        Message::User {
585            content: OneOrMany::one(content),
586        }
587    }
588}
589
590impl From<OneOrMany<AssistantContent>> for Message {
591    fn from(content: OneOrMany<AssistantContent>) -> Self {
592        Message::Assistant { content }
593    }
594}
595
596impl From<OneOrMany<UserContent>> for Message {
597    fn from(content: OneOrMany<UserContent>) -> Self {
598        Message::User { content }
599    }
600}
601
602impl From<ToolCall> for Message {
603    fn from(tool_call: ToolCall) -> Self {
604        Message::Assistant {
605            content: OneOrMany::one(AssistantContent::ToolCall(tool_call)),
606        }
607    }
608}
609
610impl From<ToolResult> for Message {
611    fn from(tool_result: ToolResult) -> Self {
612        Message::User {
613            content: OneOrMany::one(UserContent::ToolResult(tool_result)),
614        }
615    }
616}
617
618impl From<ToolResultContent> for Message {
619    fn from(tool_result_content: ToolResultContent) -> Self {
620        Message::User {
621            content: OneOrMany::one(UserContent::ToolResult(ToolResult {
622                id: String::new(),
623                content: OneOrMany::one(tool_result_content),
624            })),
625        }
626    }
627}
628
629// ================================================================
630// Error types
631// ================================================================
632
633/// Error type to represent issues with converting messages to and from specific provider messages.
634#[derive(Debug, Error)]
635pub enum MessageError {
636    #[error("Message conversion error: {0}")]
637    ConversionError(String),
638}
639
640impl From<MessageError> for CompletionError {
641    fn from(error: MessageError) -> Self {
642        CompletionError::RequestError(error.into())
643    }
644}