rig/completion/
message.rs

1use std::{convert::Infallible, str::FromStr};
2
3use crate::OneOrMany;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use super::CompletionError;
8
9// ================================================================
10// Message models
11// ================================================================
12
/// A message represents a run of input (user) and output (assistant).
/// Each message type (based on its `role`) can contain at least one piece of content such as
///  text, images, audio, documents, or tool related information. While each message type can
///  contain multiple pieces of content, most often you'll only see one content type per message
///  (an image w/ a description, etc).
///
/// Each provider is responsible for converting the generic message into its provider specific
///  type using `From` or `TryFrom` traits. Since not every provider supports every feature, the
///  conversion can be lossy (providing an image might be discarded for a non-image supporting
///  provider) though the message being converted back and forth should always be the same.
///
/// Serialized by serde as an internally tagged enum: the variant is recorded in a `role` field
///  holding the lowercased variant name (`"user"` or `"assistant"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// User message containing one or more content types defined by `UserContent`.
    User { content: OneOrMany<UserContent> },

    /// Assistant message containing one or more content types defined by `AssistantContent`.
    Assistant {
        /// Optional identifier of the message; the `Message::assistant` helper leaves it `None`.
        id: Option<String>,
        content: OneOrMany<AssistantContent>,
    },
}
35
/// Describes the content of a message, which can be text, a tool result, an image, audio, or
///  a document. Dependent on provider supporting the content type. Multimedia content is generally
///  base64 encoded (as defined by its format) but additionally supports urls (for some providers).
///
/// Serialized by serde as an internally tagged enum: the variant is recorded in a `type` field
///  holding the lowercased variant name (e.g. `"text"`, `"toolresult"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text.
    Text(Text),
    /// Result of a tool invocation.
    ToolResult(ToolResult),
    /// Image payload plus its metadata.
    Image(Image),
    /// Audio payload plus its metadata.
    Audio(Audio),
    /// Document payload plus its metadata.
    Document(Document),
}
48
/// Describes responses from a provider which is either text or a tool call.
///
/// The enum is `#[serde(untagged)]`: no discriminator is written, and on deserialization serde
/// attempts the variants in declaration order (`Text` first, then `ToolCall`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum AssistantContent {
    /// Plain text produced by the provider.
    Text(Text),
    /// A request from the provider to invoke a tool.
    ToolCall(ToolCall),
}
56
/// Tool result content containing information about a tool call and its resulting content.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolResult {
    /// Identifier of this tool result.
    // NOTE(review): presumably mirrors the `id` of the originating `ToolCall` — confirm.
    pub id: String,
    /// Optional additional call identifier; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub call_id: Option<String>,
    /// The content returned by the tool.
    pub content: OneOrMany<ToolResultContent>,
}
65
/// Describes the content of a tool result, which can be text or an image.
///
/// No serde tagging attribute is set here, so serde's default (externally tagged) enum
/// representation applies.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum ToolResultContent {
    /// Textual tool output.
    Text(Text),
    /// Image tool output.
    Image(Image),
}
72
/// Describes a tool call with an id and function to call, generally produced by a provider.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolCall {
    /// Identifier of this tool call.
    pub id: String,
    /// Optional additional call identifier. Unlike `ToolResult::call_id`, this field carries no
    /// `skip_serializing_if`, so it is still written when `None`.
    pub call_id: Option<String>,
    /// The function to invoke, with its arguments.
    pub function: ToolFunction,
}
80
/// Describes a tool function to call with a name and arguments, generally produced by a provider.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolFunction {
    /// Name of the function to call.
    pub name: String,
    /// Arguments for the call, held as an arbitrary JSON value.
    pub arguments: serde_json::Value,
}
87
88// ================================================================
89// Base content models
90// ================================================================
91
/// Basic text content.
///
/// Can be built from `String`, `&String`, and `&str` via `From`, or parsed infallibly via
/// `FromStr`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Text {
    /// The text itself.
    pub text: String,
}
97
/// Image content containing image data and metadata about it.
///
/// Every optional field is left out of the serialized form when it is `None`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Image {
    /// The image payload as a string.
    pub data: String,
    /// How `data` is encoded (see `ContentFormat`), when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<ContentFormat>,
    /// The image's media type, when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<ImageMediaType>,
    /// Requested level of detail (see `ImageDetail`), when specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<ImageDetail>,
}
109
/// Audio content containing audio data and metadata about it.
///
/// Every optional field is left out of the serialized form when it is `None`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Audio {
    /// The audio payload as a string.
    pub data: String,
    /// How `data` is encoded (see `ContentFormat`), when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<ContentFormat>,
    /// The audio's media type, when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<AudioMediaType>,
}
119
/// Document content containing document data and metadata about it.
///
/// Every optional field is left out of the serialized form when it is `None`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Document {
    /// The document payload as a string.
    pub data: String,
    /// How `data` is encoded (see `ContentFormat`), when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<ContentFormat>,
    /// The document's media type, when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<DocumentMediaType>,
}
129
/// Describes the format of the content, which can be base64 or string.
///
/// `Base64` is the `Default`. Serialized as the lowercased variant name (`"base64"` or
/// `"string"`).
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ContentFormat {
    /// The payload is base64 encoded.
    #[default]
    Base64,
    /// The payload is a plain string.
    String,
}
138
/// Helper enum that tracks the media type of the content.
///
/// Wraps the per-category media type enums so a single value can describe an image, audio, or
/// document type.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum MediaType {
    /// An image media type.
    Image(ImageMediaType),
    /// An audio media type.
    Audio(AudioMediaType),
    /// A document media type.
    Document(DocumentMediaType),
}
146
/// Describes the image media type of the content. Not every provider supports every media type.
/// Convertible to and from MIME type strings.
///
/// The serde representation is the lowercased variant name (e.g. `"jpeg"`), which is distinct
/// from the MIME string (e.g. `"image/jpeg"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageMediaType {
    JPEG,
    PNG,
    GIF,
    WEBP,
    HEIC,
    HEIF,
    SVG,
}
160
/// Describes the document media type of the content. Not every provider supports every media type.
/// Also includes programming languages as document types for providers that support running code.
/// Convertible to and from MIME type strings.
///
/// The serde representation is the lowercased variant name (e.g. `"pdf"`, `"javascript"`), which
/// is distinct from the MIME string (e.g. `"application/pdf"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentMediaType {
    PDF,
    TXT,
    RTF,
    HTML,
    CSS,
    MARKDOWN,
    CSV,
    XML,
    Javascript,
    Python,
}
178
/// Describes the audio media type of the content. Not every provider supports every media type.
/// Convertible to and from MIME type strings.
///
/// The serde representation is the lowercased variant name (e.g. `"mp3"`), which is distinct
/// from the MIME string (e.g. `"audio/mp3"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum AudioMediaType {
    WAV,
    MP3,
    AIFF,
    AAC,
    OGG,
    FLAC,
}
191
/// Describes the detail of the image content, which can be low, high, or auto (open-ai specific).
///
/// `Auto` is the `Default`. Serialized as the lowercased variant name; also parseable from a
/// string via `FromStr`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageDetail {
    Low,
    High,
    #[default]
    Auto,
}
201
202// ================================================================
203// Impl. for message models
204// ================================================================
205
206impl Message {
207    /// This helper method is primarily used to extract the first string prompt from a `Message`.
208    /// Since `Message` might have more than just text content, we need to find the first text.
209    pub(crate) fn rag_text(&self) -> Option<String> {
210        match self {
211            Message::User { content } => {
212                for item in content.iter() {
213                    if let UserContent::Text(Text { text }) = item {
214                        return Some(text.clone());
215                    }
216                }
217                None
218            }
219            _ => None,
220        }
221    }
222
223    /// Helper constructor to make creating user messages easier.
224    pub fn user(text: impl Into<String>) -> Self {
225        Message::User {
226            content: OneOrMany::one(UserContent::text(text)),
227        }
228    }
229
230    /// Helper constructor to make creating assistant messages easier.
231    pub fn assistant(text: impl Into<String>) -> Self {
232        Message::Assistant {
233            id: None,
234            content: OneOrMany::one(AssistantContent::text(text)),
235        }
236    }
237
238    /// Helper constructor to make creating assistant messages easier.
239    pub fn assistant_with_id(id: String, text: impl Into<String>) -> Self {
240        Message::Assistant {
241            id: Some(id),
242            content: OneOrMany::one(AssistantContent::text(text)),
243        }
244    }
245
246    /// Helper constructor to make creating tool result messages easier.
247    pub fn tool_result(id: impl Into<String>, content: impl Into<String>) -> Self {
248        Message::User {
249            content: OneOrMany::one(UserContent::ToolResult(ToolResult {
250                id: id.into(),
251                call_id: None,
252                content: OneOrMany::one(ToolResultContent::text(content)),
253            })),
254        }
255    }
256
257    pub fn tool_result_with_call_id(
258        id: impl Into<String>,
259        call_id: Option<String>,
260        content: impl Into<String>,
261    ) -> Self {
262        Message::User {
263            content: OneOrMany::one(UserContent::ToolResult(ToolResult {
264                id: id.into(),
265                call_id,
266                content: OneOrMany::one(ToolResultContent::text(content)),
267            })),
268        }
269    }
270}
271
272impl UserContent {
273    /// Helper constructor to make creating user text content easier.
274    pub fn text(text: impl Into<String>) -> Self {
275        UserContent::Text(text.into().into())
276    }
277
278    /// Helper constructor to make creating user image content easier.
279    pub fn image(
280        data: impl Into<String>,
281        format: Option<ContentFormat>,
282        media_type: Option<ImageMediaType>,
283        detail: Option<ImageDetail>,
284    ) -> Self {
285        UserContent::Image(Image {
286            data: data.into(),
287            format,
288            media_type,
289            detail,
290        })
291    }
292
293    /// Helper constructor to make creating user audio content easier.
294    pub fn audio(
295        data: impl Into<String>,
296        format: Option<ContentFormat>,
297        media_type: Option<AudioMediaType>,
298    ) -> Self {
299        UserContent::Audio(Audio {
300            data: data.into(),
301            format,
302            media_type,
303        })
304    }
305
306    /// Helper constructor to make creating user document content easier.
307    pub fn document(
308        data: impl Into<String>,
309        format: Option<ContentFormat>,
310        media_type: Option<DocumentMediaType>,
311    ) -> Self {
312        UserContent::Document(Document {
313            data: data.into(),
314            format,
315            media_type,
316        })
317    }
318
319    /// Helper constructor to make creating user tool result content easier.
320    pub fn tool_result(id: impl Into<String>, content: OneOrMany<ToolResultContent>) -> Self {
321        UserContent::ToolResult(ToolResult {
322            id: id.into(),
323            call_id: None,
324            content,
325        })
326    }
327
328    /// Helper constructor to make creating user tool result content easier.
329    pub fn tool_result_with_call_id(
330        id: impl Into<String>,
331        call_id: String,
332        content: OneOrMany<ToolResultContent>,
333    ) -> Self {
334        UserContent::ToolResult(ToolResult {
335            id: id.into(),
336            call_id: Some(call_id),
337            content,
338        })
339    }
340}
341
342impl AssistantContent {
343    /// Helper constructor to make creating assistant text content easier.
344    pub fn text(text: impl Into<String>) -> Self {
345        AssistantContent::Text(text.into().into())
346    }
347
348    /// Helper constructor to make creating assistant tool call content easier.
349    pub fn tool_call(
350        id: impl Into<String>,
351        name: impl Into<String>,
352        arguments: serde_json::Value,
353    ) -> Self {
354        AssistantContent::ToolCall(ToolCall {
355            id: id.into(),
356            call_id: None,
357            function: ToolFunction {
358                name: name.into(),
359                arguments,
360            },
361        })
362    }
363
364    pub fn tool_call_with_call_id(
365        id: impl Into<String>,
366        call_id: String,
367        name: impl Into<String>,
368        arguments: serde_json::Value,
369    ) -> Self {
370        AssistantContent::ToolCall(ToolCall {
371            id: id.into(),
372            call_id: Some(call_id),
373            function: ToolFunction {
374                name: name.into(),
375                arguments,
376            },
377        })
378    }
379}
380
381impl ToolResultContent {
382    /// Helper constructor to make creating tool result text content easier.
383    pub fn text(text: impl Into<String>) -> Self {
384        ToolResultContent::Text(text.into().into())
385    }
386
387    /// Helper constructor to make creating tool result image content easier.
388    pub fn image(
389        data: impl Into<String>,
390        format: Option<ContentFormat>,
391        media_type: Option<ImageMediaType>,
392        detail: Option<ImageDetail>,
393    ) -> Self {
394        ToolResultContent::Image(Image {
395            data: data.into(),
396            format,
397            media_type,
398            detail,
399        })
400    }
401}
402
/// Trait for converting between MIME types and media types.
pub trait MimeType {
    /// Looks up the media type corresponding to `mime_type`, returning `None` when the
    /// implementor does not recognize the string.
    fn from_mime_type(mime_type: &str) -> Option<Self>
    where
        Self: Sized;
    /// Returns the MIME type string for this media type.
    fn to_mime_type(&self) -> &'static str;
}
410
411impl MimeType for MediaType {
412    fn from_mime_type(mime_type: &str) -> Option<Self> {
413        ImageMediaType::from_mime_type(mime_type)
414            .map(MediaType::Image)
415            .or_else(|| {
416                DocumentMediaType::from_mime_type(mime_type)
417                    .map(MediaType::Document)
418                    .or_else(|| AudioMediaType::from_mime_type(mime_type).map(MediaType::Audio))
419            })
420    }
421
422    fn to_mime_type(&self) -> &'static str {
423        match self {
424            MediaType::Image(media_type) => media_type.to_mime_type(),
425            MediaType::Audio(media_type) => media_type.to_mime_type(),
426            MediaType::Document(media_type) => media_type.to_mime_type(),
427        }
428    }
429}
430
431impl MimeType for ImageMediaType {
432    fn from_mime_type(mime_type: &str) -> Option<Self> {
433        match mime_type {
434            "image/jpeg" => Some(ImageMediaType::JPEG),
435            "image/png" => Some(ImageMediaType::PNG),
436            "image/gif" => Some(ImageMediaType::GIF),
437            "image/webp" => Some(ImageMediaType::WEBP),
438            "image/heic" => Some(ImageMediaType::HEIC),
439            "image/heif" => Some(ImageMediaType::HEIF),
440            "image/svg+xml" => Some(ImageMediaType::SVG),
441            _ => None,
442        }
443    }
444
445    fn to_mime_type(&self) -> &'static str {
446        match self {
447            ImageMediaType::JPEG => "image/jpeg",
448            ImageMediaType::PNG => "image/png",
449            ImageMediaType::GIF => "image/gif",
450            ImageMediaType::WEBP => "image/webp",
451            ImageMediaType::HEIC => "image/heic",
452            ImageMediaType::HEIF => "image/heif",
453            ImageMediaType::SVG => "image/svg+xml",
454        }
455    }
456}
457
458impl MimeType for DocumentMediaType {
459    fn from_mime_type(mime_type: &str) -> Option<Self> {
460        match mime_type {
461            "application/pdf" => Some(DocumentMediaType::PDF),
462            "text/plain" => Some(DocumentMediaType::TXT),
463            "text/rtf" => Some(DocumentMediaType::RTF),
464            "text/html" => Some(DocumentMediaType::HTML),
465            "text/css" => Some(DocumentMediaType::CSS),
466            "text/md" | "text/markdown" => Some(DocumentMediaType::MARKDOWN),
467            "text/csv" => Some(DocumentMediaType::CSV),
468            "text/xml" => Some(DocumentMediaType::XML),
469            "application/x-javascript" | "text/x-javascript" => Some(DocumentMediaType::Javascript),
470            "application/x-python" | "text/x-python" => Some(DocumentMediaType::Python),
471            _ => None,
472        }
473    }
474
475    fn to_mime_type(&self) -> &'static str {
476        match self {
477            DocumentMediaType::PDF => "application/pdf",
478            DocumentMediaType::TXT => "text/plain",
479            DocumentMediaType::RTF => "text/rtf",
480            DocumentMediaType::HTML => "text/html",
481            DocumentMediaType::CSS => "text/css",
482            DocumentMediaType::MARKDOWN => "text/markdown",
483            DocumentMediaType::CSV => "text/csv",
484            DocumentMediaType::XML => "text/xml",
485            DocumentMediaType::Javascript => "application/x-javascript",
486            DocumentMediaType::Python => "application/x-python",
487        }
488    }
489}
490
491impl MimeType for AudioMediaType {
492    fn from_mime_type(mime_type: &str) -> Option<Self> {
493        match mime_type {
494            "audio/wav" => Some(AudioMediaType::WAV),
495            "audio/mp3" => Some(AudioMediaType::MP3),
496            "audio/aiff" => Some(AudioMediaType::AIFF),
497            "audio/aac" => Some(AudioMediaType::AAC),
498            "audio/ogg" => Some(AudioMediaType::OGG),
499            "audio/flac" => Some(AudioMediaType::FLAC),
500            _ => None,
501        }
502    }
503
504    fn to_mime_type(&self) -> &'static str {
505        match self {
506            AudioMediaType::WAV => "audio/wav",
507            AudioMediaType::MP3 => "audio/mp3",
508            AudioMediaType::AIFF => "audio/aiff",
509            AudioMediaType::AAC => "audio/aac",
510            AudioMediaType::OGG => "audio/ogg",
511            AudioMediaType::FLAC => "audio/flac",
512        }
513    }
514}
515
516impl std::str::FromStr for ImageDetail {
517    type Err = ();
518
519    fn from_str(s: &str) -> Result<Self, Self::Err> {
520        match s.to_lowercase().as_str() {
521            "low" => Ok(ImageDetail::Low),
522            "high" => Ok(ImageDetail::High),
523            "auto" => Ok(ImageDetail::Auto),
524            _ => Err(()),
525        }
526    }
527}
528
529// ================================================================
530// FromStr, From<String>, and From<&str> impls
531// ================================================================
532
533impl From<String> for Text {
534    fn from(text: String) -> Self {
535        Text { text }
536    }
537}
538
539impl From<&String> for Text {
540    fn from(text: &String) -> Self {
541        text.to_owned().into()
542    }
543}
544
545impl From<&str> for Text {
546    fn from(text: &str) -> Self {
547        text.to_owned().into()
548    }
549}
550
551impl FromStr for Text {
552    type Err = Infallible;
553
554    fn from_str(s: &str) -> Result<Self, Self::Err> {
555        Ok(s.into())
556    }
557}
558
559impl From<String> for Message {
560    fn from(text: String) -> Self {
561        Message::User {
562            content: OneOrMany::one(UserContent::Text(text.into())),
563        }
564    }
565}
566
567impl From<&str> for Message {
568    fn from(text: &str) -> Self {
569        Message::User {
570            content: OneOrMany::one(UserContent::Text(text.into())),
571        }
572    }
573}
574
575impl From<&String> for Message {
576    fn from(text: &String) -> Self {
577        Message::User {
578            content: OneOrMany::one(UserContent::Text(text.into())),
579        }
580    }
581}
582
583impl From<Text> for Message {
584    fn from(text: Text) -> Self {
585        Message::User {
586            content: OneOrMany::one(UserContent::Text(text)),
587        }
588    }
589}
590
591impl From<Image> for Message {
592    fn from(image: Image) -> Self {
593        Message::User {
594            content: OneOrMany::one(UserContent::Image(image)),
595        }
596    }
597}
598
599impl From<Audio> for Message {
600    fn from(audio: Audio) -> Self {
601        Message::User {
602            content: OneOrMany::one(UserContent::Audio(audio)),
603        }
604    }
605}
606
607impl From<Document> for Message {
608    fn from(document: Document) -> Self {
609        Message::User {
610            content: OneOrMany::one(UserContent::Document(document)),
611        }
612    }
613}
614
615impl From<String> for ToolResultContent {
616    fn from(text: String) -> Self {
617        ToolResultContent::text(text)
618    }
619}
620
621impl From<String> for AssistantContent {
622    fn from(text: String) -> Self {
623        AssistantContent::text(text)
624    }
625}
626
627impl From<String> for UserContent {
628    fn from(text: String) -> Self {
629        UserContent::text(text)
630    }
631}
632
633impl From<AssistantContent> for Message {
634    fn from(content: AssistantContent) -> Self {
635        Message::Assistant {
636            id: None,
637            content: OneOrMany::one(content),
638        }
639    }
640}
641
642impl From<UserContent> for Message {
643    fn from(content: UserContent) -> Self {
644        Message::User {
645            content: OneOrMany::one(content),
646        }
647    }
648}
649
650impl From<OneOrMany<AssistantContent>> for Message {
651    fn from(content: OneOrMany<AssistantContent>) -> Self {
652        Message::Assistant { id: None, content }
653    }
654}
655
656impl From<OneOrMany<UserContent>> for Message {
657    fn from(content: OneOrMany<UserContent>) -> Self {
658        Message::User { content }
659    }
660}
661
662impl From<ToolCall> for Message {
663    fn from(tool_call: ToolCall) -> Self {
664        Message::Assistant {
665            id: None,
666            content: OneOrMany::one(AssistantContent::ToolCall(tool_call)),
667        }
668    }
669}
670
671impl From<ToolResult> for Message {
672    fn from(tool_result: ToolResult) -> Self {
673        Message::User {
674            content: OneOrMany::one(UserContent::ToolResult(tool_result)),
675        }
676    }
677}
678
679impl From<ToolResultContent> for Message {
680    fn from(tool_result_content: ToolResultContent) -> Self {
681        Message::User {
682            content: OneOrMany::one(UserContent::ToolResult(ToolResult {
683                id: String::new(),
684                call_id: None,
685                content: OneOrMany::one(tool_result_content),
686            })),
687        }
688    }
689}
690
691// ================================================================
692// Error types
693// ================================================================
694
/// Error type to represent issues with converting messages to and from specific provider messages.
#[derive(Debug, Error)]
pub enum MessageError {
    /// A conversion failed. The payload is a description that is appended to the
    /// `Message conversion error: ` prefix in the `Display` output.
    #[error("Message conversion error: {0}")]
    ConversionError(String),
}
701
702impl From<MessageError> for CompletionError {
703    fn from(error: MessageError) -> Self {
704        CompletionError::RequestError(error.into())
705    }
706}