rig/completion/
message.rs

1use std::{convert::Infallible, str::FromStr};
2
3use crate::OneOrMany;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use super::CompletionError;
8
9// ================================================================
10// Message models
11// ================================================================
12
/// A single turn of a conversation: either input (user) or output (assistant).
///
/// Each variant carries at least one piece of content such as text, images, audio, documents,
/// or tool-related information. A message can hold several content items, but most often
/// you'll only see one or two per message (an image with a description, etc).
///
/// Each provider is responsible for converting this generic message into its
/// provider-specific type using the `From` or `TryFrom` traits. Since not every provider
/// supports every feature, the conversion can be lossy (an image might be discarded by a
/// provider without image support), though a message converted back and forth should always
/// come out the same.
///
/// With serde the variant is selected by a `role` field holding the lowercased variant name
/// (`"user"` or `"assistant"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// User message containing one or more content items defined by `UserContent`.
    User { content: OneOrMany<UserContent> },

    /// Assistant message containing one or more content items defined by `AssistantContent`.
    Assistant {
        /// Optional identifier of the assistant message; the helper constructors that take no
        /// id set this to `None`.
        id: Option<String>,
        content: OneOrMany<AssistantContent>,
    },
}
35
/// Describes the content of a user message, which can be text, a tool result, an image, audio,
///  or a document. Whether a content type is usable depends on the provider. Multimedia content
///  is generally base64 encoded (as indicated by its `format`) but urls are additionally
///  supported by some providers.
///
/// With serde the variant is selected by a `type` field holding the lowercased variant name
///  (so `ToolResult` becomes `"toolresult"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text(Text),
    ToolResult(ToolResult),
    Image(Image),
    Audio(Audio),
    Document(Document),
}
48
/// Describes responses from a provider, which are text, a tool call, or reasoning.
///
/// The enum is `untagged` for serde: no discriminator is written, and on deserialization the
/// variants are tried in declaration order until one matches the data's shape.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum AssistantContent {
    Text(Text),
    ToolCall(ToolCall),
    Reasoning(Reasoning),
}
57
/// Reasoning content carried by an assistant message (see `AssistantContent::Reasoning`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Reasoning {
    /// The reasoning text.
    pub reasoning: String,
}
62
/// Tool result content containing information about a tool call and its resulting content.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolResult {
    /// Identifier of the tool result.
    pub id: String,
    /// Optional secondary call identifier; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub call_id: Option<String>,
    /// One or more pieces of content produced by the tool.
    pub content: OneOrMany<ToolResultContent>,
}
71
/// Describes the content of a tool result, which can be text or an image.
///
/// With serde the variant is selected by a `type` field holding the lowercased variant name
/// (`"text"` or `"image"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolResultContent {
    Text(Text),
    Image(Image),
}
79
/// Describes a tool call with an id and a function to call, generally produced by a provider.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolCall {
    /// Identifier of the tool call.
    pub id: String,
    /// Optional secondary call identifier; `None` when built through
    /// `AssistantContent::tool_call`.
    pub call_id: Option<String>,
    /// The function to invoke, with its arguments.
    pub function: ToolFunction,
}
87
/// Describes a tool function to call with a name and arguments, generally produced by a provider.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolFunction {
    /// Name of the function to call.
    pub name: String,
    /// Arguments for the call, kept as an arbitrary JSON value.
    pub arguments: serde_json::Value,
}
94
95// ================================================================
96// Base content models
97// ================================================================
98
/// Basic text content: a thin wrapper around an owned `String`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Text {
    /// The text itself.
    pub text: String,
}
104
/// Image content containing image data and metadata about it.
///
/// All metadata fields are optional and are left out of the serialized form when `None`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Image {
    /// The image payload as a string; how to interpret it is described by `format`.
    pub data: String,
    /// How `data` is encoded (see `ContentFormat`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<ContentFormat>,
    /// The image's media type (see `ImageMediaType`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<ImageMediaType>,
    /// Requested level of detail (see `ImageDetail`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<ImageDetail>,
}
116
/// Audio content containing audio data and metadata about it.
///
/// The metadata fields are optional and are left out of the serialized form when `None`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Audio {
    /// The audio payload as a string; how to interpret it is described by `format`.
    pub data: String,
    /// How `data` is encoded (see `ContentFormat`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<ContentFormat>,
    /// The audio's media type (see `AudioMediaType`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<AudioMediaType>,
}
126
/// Document content containing document data and metadata about it.
///
/// The metadata fields are optional and are left out of the serialized form when `None`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Document {
    /// The document payload as a string; how to interpret it is described by `format`.
    pub data: String,
    /// How `data` is encoded (see `ContentFormat`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<ContentFormat>,
    /// The document's media type (see `DocumentMediaType`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<DocumentMediaType>,
}
136
/// Describes the format of the content, which can be base64 or string.
///
/// Serialized as the lowercased variant name (`"base64"` or `"string"`); `Base64` is the
/// `Default`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ContentFormat {
    #[default]
    Base64,
    String,
}
145
/// Helper enum that tracks the media type of the content by wrapping the image, audio, or
/// document media type it belongs to.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum MediaType {
    Image(ImageMediaType),
    Audio(AudioMediaType),
    Document(DocumentMediaType),
}
153
/// Describes the image media type of the content. Not every provider supports every media type.
/// Convertible to and from MIME type strings through the `MimeType` trait.
///
/// Serialized as the lowercased variant name (e.g. `"jpeg"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageMediaType {
    JPEG,
    PNG,
    GIF,
    WEBP,
    HEIC,
    HEIF,
    SVG,
}
167
/// Describes the document media type of the content. Not every provider supports every media
/// type. Programming languages are also included as document types, for providers that support
/// running code. Convertible to and from MIME type strings through the `MimeType` trait.
///
/// Serialized as the lowercased variant name (e.g. `"markdown"`, `"javascript"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentMediaType {
    PDF,
    TXT,
    RTF,
    HTML,
    CSS,
    MARKDOWN,
    CSV,
    XML,
    Javascript,
    Python,
}
185
/// Describes the audio media type of the content. Not every provider supports every media type.
/// Convertible to and from MIME type strings through the `MimeType` trait.
///
/// Serialized as the lowercased variant name (e.g. `"wav"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum AudioMediaType {
    WAV,
    MP3,
    AIFF,
    AAC,
    OGG,
    FLAC,
}
198
/// Describes the detail of the image content, which can be low, high, or auto (open-ai specific).
///
/// Serialized as the lowercased variant name; `Auto` is the `Default`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageDetail {
    Low,
    High,
    #[default]
    Auto,
}
208
209// ================================================================
210// Impl. for message models
211// ================================================================
212
213impl Message {
214    /// This helper method is primarily used to extract the first string prompt from a `Message`.
215    /// Since `Message` might have more than just text content, we need to find the first text.
216    pub(crate) fn rag_text(&self) -> Option<String> {
217        match self {
218            Message::User { content } => {
219                for item in content.iter() {
220                    if let UserContent::Text(Text { text }) = item {
221                        return Some(text.clone());
222                    }
223                }
224                None
225            }
226            _ => None,
227        }
228    }
229
230    /// Helper constructor to make creating user messages easier.
231    pub fn user(text: impl Into<String>) -> Self {
232        Message::User {
233            content: OneOrMany::one(UserContent::text(text)),
234        }
235    }
236
237    /// Helper constructor to make creating assistant messages easier.
238    pub fn assistant(text: impl Into<String>) -> Self {
239        Message::Assistant {
240            id: None,
241            content: OneOrMany::one(AssistantContent::text(text)),
242        }
243    }
244
245    /// Helper constructor to make creating assistant messages easier.
246    pub fn assistant_with_id(id: String, text: impl Into<String>) -> Self {
247        Message::Assistant {
248            id: Some(id),
249            content: OneOrMany::one(AssistantContent::text(text)),
250        }
251    }
252
253    /// Helper constructor to make creating tool result messages easier.
254    pub fn tool_result(id: impl Into<String>, content: impl Into<String>) -> Self {
255        Message::User {
256            content: OneOrMany::one(UserContent::ToolResult(ToolResult {
257                id: id.into(),
258                call_id: None,
259                content: OneOrMany::one(ToolResultContent::text(content)),
260            })),
261        }
262    }
263
264    pub fn tool_result_with_call_id(
265        id: impl Into<String>,
266        call_id: Option<String>,
267        content: impl Into<String>,
268    ) -> Self {
269        Message::User {
270            content: OneOrMany::one(UserContent::ToolResult(ToolResult {
271                id: id.into(),
272                call_id,
273                content: OneOrMany::one(ToolResultContent::text(content)),
274            })),
275        }
276    }
277}
278
279impl UserContent {
280    /// Helper constructor to make creating user text content easier.
281    pub fn text(text: impl Into<String>) -> Self {
282        UserContent::Text(text.into().into())
283    }
284
285    /// Helper constructor to make creating user image content easier.
286    pub fn image(
287        data: impl Into<String>,
288        format: Option<ContentFormat>,
289        media_type: Option<ImageMediaType>,
290        detail: Option<ImageDetail>,
291    ) -> Self {
292        UserContent::Image(Image {
293            data: data.into(),
294            format,
295            media_type,
296            detail,
297        })
298    }
299
300    /// Helper constructor to make creating user audio content easier.
301    pub fn audio(
302        data: impl Into<String>,
303        format: Option<ContentFormat>,
304        media_type: Option<AudioMediaType>,
305    ) -> Self {
306        UserContent::Audio(Audio {
307            data: data.into(),
308            format,
309            media_type,
310        })
311    }
312
313    /// Helper constructor to make creating user document content easier.
314    pub fn document(
315        data: impl Into<String>,
316        format: Option<ContentFormat>,
317        media_type: Option<DocumentMediaType>,
318    ) -> Self {
319        UserContent::Document(Document {
320            data: data.into(),
321            format,
322            media_type,
323        })
324    }
325
326    /// Helper constructor to make creating user tool result content easier.
327    pub fn tool_result(id: impl Into<String>, content: OneOrMany<ToolResultContent>) -> Self {
328        UserContent::ToolResult(ToolResult {
329            id: id.into(),
330            call_id: None,
331            content,
332        })
333    }
334
335    /// Helper constructor to make creating user tool result content easier.
336    pub fn tool_result_with_call_id(
337        id: impl Into<String>,
338        call_id: String,
339        content: OneOrMany<ToolResultContent>,
340    ) -> Self {
341        UserContent::ToolResult(ToolResult {
342            id: id.into(),
343            call_id: Some(call_id),
344            content,
345        })
346    }
347}
348
349impl AssistantContent {
350    /// Helper constructor to make creating assistant text content easier.
351    pub fn text(text: impl Into<String>) -> Self {
352        AssistantContent::Text(text.into().into())
353    }
354
355    /// Helper constructor to make creating assistant tool call content easier.
356    pub fn tool_call(
357        id: impl Into<String>,
358        name: impl Into<String>,
359        arguments: serde_json::Value,
360    ) -> Self {
361        AssistantContent::ToolCall(ToolCall {
362            id: id.into(),
363            call_id: None,
364            function: ToolFunction {
365                name: name.into(),
366                arguments,
367            },
368        })
369    }
370
371    pub fn tool_call_with_call_id(
372        id: impl Into<String>,
373        call_id: String,
374        name: impl Into<String>,
375        arguments: serde_json::Value,
376    ) -> Self {
377        AssistantContent::ToolCall(ToolCall {
378            id: id.into(),
379            call_id: Some(call_id),
380            function: ToolFunction {
381                name: name.into(),
382                arguments,
383            },
384        })
385    }
386}
387
388impl ToolResultContent {
389    /// Helper constructor to make creating tool result text content easier.
390    pub fn text(text: impl Into<String>) -> Self {
391        ToolResultContent::Text(text.into().into())
392    }
393
394    /// Helper constructor to make creating tool result image content easier.
395    pub fn image(
396        data: impl Into<String>,
397        format: Option<ContentFormat>,
398        media_type: Option<ImageMediaType>,
399        detail: Option<ImageDetail>,
400    ) -> Self {
401        ToolResultContent::Image(Image {
402            data: data.into(),
403            format,
404            media_type,
405            detail,
406        })
407    }
408}
409
/// Trait for converting between MIME type strings and media types.
pub trait MimeType {
    /// Parses a MIME type string, returning `None` when the implementor has no matching value.
    fn from_mime_type(mime_type: &str) -> Option<Self>
    where
        Self: Sized;
    /// Returns the MIME type string for this value.
    fn to_mime_type(&self) -> &'static str;
}
417
418impl MimeType for MediaType {
419    fn from_mime_type(mime_type: &str) -> Option<Self> {
420        ImageMediaType::from_mime_type(mime_type)
421            .map(MediaType::Image)
422            .or_else(|| {
423                DocumentMediaType::from_mime_type(mime_type)
424                    .map(MediaType::Document)
425                    .or_else(|| AudioMediaType::from_mime_type(mime_type).map(MediaType::Audio))
426            })
427    }
428
429    fn to_mime_type(&self) -> &'static str {
430        match self {
431            MediaType::Image(media_type) => media_type.to_mime_type(),
432            MediaType::Audio(media_type) => media_type.to_mime_type(),
433            MediaType::Document(media_type) => media_type.to_mime_type(),
434        }
435    }
436}
437
438impl MimeType for ImageMediaType {
439    fn from_mime_type(mime_type: &str) -> Option<Self> {
440        match mime_type {
441            "image/jpeg" => Some(ImageMediaType::JPEG),
442            "image/png" => Some(ImageMediaType::PNG),
443            "image/gif" => Some(ImageMediaType::GIF),
444            "image/webp" => Some(ImageMediaType::WEBP),
445            "image/heic" => Some(ImageMediaType::HEIC),
446            "image/heif" => Some(ImageMediaType::HEIF),
447            "image/svg+xml" => Some(ImageMediaType::SVG),
448            _ => None,
449        }
450    }
451
452    fn to_mime_type(&self) -> &'static str {
453        match self {
454            ImageMediaType::JPEG => "image/jpeg",
455            ImageMediaType::PNG => "image/png",
456            ImageMediaType::GIF => "image/gif",
457            ImageMediaType::WEBP => "image/webp",
458            ImageMediaType::HEIC => "image/heic",
459            ImageMediaType::HEIF => "image/heif",
460            ImageMediaType::SVG => "image/svg+xml",
461        }
462    }
463}
464
465impl MimeType for DocumentMediaType {
466    fn from_mime_type(mime_type: &str) -> Option<Self> {
467        match mime_type {
468            "application/pdf" => Some(DocumentMediaType::PDF),
469            "text/plain" => Some(DocumentMediaType::TXT),
470            "text/rtf" => Some(DocumentMediaType::RTF),
471            "text/html" => Some(DocumentMediaType::HTML),
472            "text/css" => Some(DocumentMediaType::CSS),
473            "text/md" | "text/markdown" => Some(DocumentMediaType::MARKDOWN),
474            "text/csv" => Some(DocumentMediaType::CSV),
475            "text/xml" => Some(DocumentMediaType::XML),
476            "application/x-javascript" | "text/x-javascript" => Some(DocumentMediaType::Javascript),
477            "application/x-python" | "text/x-python" => Some(DocumentMediaType::Python),
478            _ => None,
479        }
480    }
481
482    fn to_mime_type(&self) -> &'static str {
483        match self {
484            DocumentMediaType::PDF => "application/pdf",
485            DocumentMediaType::TXT => "text/plain",
486            DocumentMediaType::RTF => "text/rtf",
487            DocumentMediaType::HTML => "text/html",
488            DocumentMediaType::CSS => "text/css",
489            DocumentMediaType::MARKDOWN => "text/markdown",
490            DocumentMediaType::CSV => "text/csv",
491            DocumentMediaType::XML => "text/xml",
492            DocumentMediaType::Javascript => "application/x-javascript",
493            DocumentMediaType::Python => "application/x-python",
494        }
495    }
496}
497
498impl MimeType for AudioMediaType {
499    fn from_mime_type(mime_type: &str) -> Option<Self> {
500        match mime_type {
501            "audio/wav" => Some(AudioMediaType::WAV),
502            "audio/mp3" => Some(AudioMediaType::MP3),
503            "audio/aiff" => Some(AudioMediaType::AIFF),
504            "audio/aac" => Some(AudioMediaType::AAC),
505            "audio/ogg" => Some(AudioMediaType::OGG),
506            "audio/flac" => Some(AudioMediaType::FLAC),
507            _ => None,
508        }
509    }
510
511    fn to_mime_type(&self) -> &'static str {
512        match self {
513            AudioMediaType::WAV => "audio/wav",
514            AudioMediaType::MP3 => "audio/mp3",
515            AudioMediaType::AIFF => "audio/aiff",
516            AudioMediaType::AAC => "audio/aac",
517            AudioMediaType::OGG => "audio/ogg",
518            AudioMediaType::FLAC => "audio/flac",
519        }
520    }
521}
522
523impl std::str::FromStr for ImageDetail {
524    type Err = ();
525
526    fn from_str(s: &str) -> Result<Self, Self::Err> {
527        match s.to_lowercase().as_str() {
528            "low" => Ok(ImageDetail::Low),
529            "high" => Ok(ImageDetail::High),
530            "auto" => Ok(ImageDetail::Auto),
531            _ => Err(()),
532        }
533    }
534}
535
536// ================================================================
537// FromStr, From<String>, and From<&str> impls
538// ================================================================
539
540impl From<String> for Text {
541    fn from(text: String) -> Self {
542        Text { text }
543    }
544}
545
546impl From<&String> for Text {
547    fn from(text: &String) -> Self {
548        text.to_owned().into()
549    }
550}
551
552impl From<&str> for Text {
553    fn from(text: &str) -> Self {
554        text.to_owned().into()
555    }
556}
557
558impl FromStr for Text {
559    type Err = Infallible;
560
561    fn from_str(s: &str) -> Result<Self, Self::Err> {
562        Ok(s.into())
563    }
564}
565
566impl From<String> for Message {
567    fn from(text: String) -> Self {
568        Message::User {
569            content: OneOrMany::one(UserContent::Text(text.into())),
570        }
571    }
572}
573
574impl From<&str> for Message {
575    fn from(text: &str) -> Self {
576        Message::User {
577            content: OneOrMany::one(UserContent::Text(text.into())),
578        }
579    }
580}
581
582impl From<&String> for Message {
583    fn from(text: &String) -> Self {
584        Message::User {
585            content: OneOrMany::one(UserContent::Text(text.into())),
586        }
587    }
588}
589
590impl From<Text> for Message {
591    fn from(text: Text) -> Self {
592        Message::User {
593            content: OneOrMany::one(UserContent::Text(text)),
594        }
595    }
596}
597
598impl From<Image> for Message {
599    fn from(image: Image) -> Self {
600        Message::User {
601            content: OneOrMany::one(UserContent::Image(image)),
602        }
603    }
604}
605
606impl From<Audio> for Message {
607    fn from(audio: Audio) -> Self {
608        Message::User {
609            content: OneOrMany::one(UserContent::Audio(audio)),
610        }
611    }
612}
613
614impl From<Document> for Message {
615    fn from(document: Document) -> Self {
616        Message::User {
617            content: OneOrMany::one(UserContent::Document(document)),
618        }
619    }
620}
621
622impl From<String> for ToolResultContent {
623    fn from(text: String) -> Self {
624        ToolResultContent::text(text)
625    }
626}
627
628impl From<String> for AssistantContent {
629    fn from(text: String) -> Self {
630        AssistantContent::text(text)
631    }
632}
633
634impl From<String> for UserContent {
635    fn from(text: String) -> Self {
636        UserContent::text(text)
637    }
638}
639
640impl From<AssistantContent> for Message {
641    fn from(content: AssistantContent) -> Self {
642        Message::Assistant {
643            id: None,
644            content: OneOrMany::one(content),
645        }
646    }
647}
648
649impl From<UserContent> for Message {
650    fn from(content: UserContent) -> Self {
651        Message::User {
652            content: OneOrMany::one(content),
653        }
654    }
655}
656
657impl From<OneOrMany<AssistantContent>> for Message {
658    fn from(content: OneOrMany<AssistantContent>) -> Self {
659        Message::Assistant { id: None, content }
660    }
661}
662
663impl From<OneOrMany<UserContent>> for Message {
664    fn from(content: OneOrMany<UserContent>) -> Self {
665        Message::User { content }
666    }
667}
668
669impl From<ToolCall> for Message {
670    fn from(tool_call: ToolCall) -> Self {
671        Message::Assistant {
672            id: None,
673            content: OneOrMany::one(AssistantContent::ToolCall(tool_call)),
674        }
675    }
676}
677
678impl From<ToolResult> for Message {
679    fn from(tool_result: ToolResult) -> Self {
680        Message::User {
681            content: OneOrMany::one(UserContent::ToolResult(tool_result)),
682        }
683    }
684}
685
686impl From<ToolResultContent> for Message {
687    fn from(tool_result_content: ToolResultContent) -> Self {
688        Message::User {
689            content: OneOrMany::one(UserContent::ToolResult(ToolResult {
690                id: String::new(),
691                call_id: None,
692                content: OneOrMany::one(tool_result_content),
693            })),
694        }
695    }
696}
697
698// ================================================================
699// Error types
700// ================================================================
701
/// Error type to represent issues with converting messages to and from specific provider messages.
#[derive(Debug, Error)]
pub enum MessageError {
    /// A conversion failed; the payload is a description that is appended to the
    /// "Message conversion error: " prefix when the error is displayed.
    #[error("Message conversion error: {0}")]
    ConversionError(String),
}
708
709impl From<MessageError> for CompletionError {
710    fn from(error: MessageError) -> Self {
711        CompletionError::RequestError(error.into())
712    }
713}