rig/completion/
message.rs

1use std::{convert::Infallible, str::FromStr};
2
3use crate::OneOrMany;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use super::CompletionError;
8
9// ================================================================
10// Message models
11// ================================================================
12
13/// A message represents a run of input (user) and output (assistant).
14/// Each message type (based on it's `role`) can contain a atleast one bit of content such as text,
15///  images, audio, documents, or tool related information. While each message type can contain
16///  multiple content, most often, you'll only see one content type per message
17///  (an image w/ a description, etc).
18///
19/// Each provider is responsible with converting the generic message into it's provider specific
20///  type using `From` or `TryFrom` traits. Since not every provider supports every feature, the
21///  conversion can be lossy (providing an image might be discarded for a non-image supporting
22///  provider) though the message being converted back and forth should always be the same.
23#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
24#[serde(tag = "role", rename_all = "lowercase")]
25pub enum Message {
26    /// User message containing one or more content types defined by `UserContent`.
27    User { content: OneOrMany<UserContent> },
28
29    /// Assistant message containing one or more content types defined by `AssistantContent`.
30    Assistant {
31        content: OneOrMany<AssistantContent>,
32    },
33}
34
35/// Describes the content of a message, which can be text, a tool result, an image, audio, or
36///  a document. Dependent on provider supporting the content type. Multimedia content is generally
37///  base64 (defined by it's format) encoded but additionally supports urls (for some providers).
38#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
39#[serde(tag = "type", rename_all = "lowercase")]
40pub enum UserContent {
41    Text(Text),
42    ToolResult(ToolResult),
43    Image(Image),
44    Audio(Audio),
45    Document(Document),
46}
47
48/// Describes responses from a provider which is either text or a tool call.
49#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
50#[serde(untagged)]
51pub enum AssistantContent {
52    Text(Text),
53    ToolCall(ToolCall),
54}
55
56/// Tool result content containing information about a tool call and it's resulting content.
57#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
58pub struct ToolResult {
59    pub id: String,
60    pub content: OneOrMany<ToolResultContent>,
61}
62
63/// Describes the content of a tool result, which can be text or an image.
64#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
65pub enum ToolResultContent {
66    Text(Text),
67    Image(Image),
68}
69
70/// Describes a tool call with an id and function to call, generally produced by a provider.
71#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
72pub struct ToolCall {
73    pub id: String,
74    pub function: ToolFunction,
75}
76
77/// Describes a tool function to call with a name and arguments, generally produced by a provider.
78#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
79pub struct ToolFunction {
80    pub name: String,
81    pub arguments: serde_json::Value,
82}
83
84// ================================================================
85// Base content models
86// ================================================================
87
88/// Basic text content.
89#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
90pub struct Text {
91    pub text: String,
92}
93
94/// Image content containing image data and metadata about it.
95#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
96pub struct Image {
97    pub data: String,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub format: Option<ContentFormat>,
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub media_type: Option<ImageMediaType>,
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub detail: Option<ImageDetail>,
104}
105
106/// Audio content containing audio data and metadata about it.
107#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
108pub struct Audio {
109    pub data: String,
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub format: Option<ContentFormat>,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub media_type: Option<AudioMediaType>,
114}
115
116/// Document content containing document data and metadata about it.
117#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
118pub struct Document {
119    pub data: String,
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub format: Option<ContentFormat>,
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub media_type: Option<DocumentMediaType>,
124}
125
126/// Describes the format of the content, which can be base64 or string.
127#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
128#[serde(rename_all = "lowercase")]
129pub enum ContentFormat {
130    #[default]
131    Base64,
132    String,
133}
134
135/// Helper enum that tracks the media type of the content.
136#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
137pub enum MediaType {
138    Image(ImageMediaType),
139    Audio(AudioMediaType),
140    Document(DocumentMediaType),
141}
142
143/// Describes the image media type of the content. Not every provider supports every media type.
144/// Convertible to and from MIME type strings.
145#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
146#[serde(rename_all = "lowercase")]
147pub enum ImageMediaType {
148    JPEG,
149    PNG,
150    GIF,
151    WEBP,
152    HEIC,
153    HEIF,
154    SVG,
155}
156
157/// Describes the document media type of the content. Not every provider supports every media type.
158/// Includes also programming languages as document types for providers who support code running.
159/// Convertible to and from MIME type strings.
160#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
161#[serde(rename_all = "lowercase")]
162pub enum DocumentMediaType {
163    PDF,
164    TXT,
165    RTF,
166    HTML,
167    CSS,
168    MARKDOWN,
169    CSV,
170    XML,
171    Javascript,
172    Python,
173}
174
175/// Describes the audio media type of the content. Not every provider supports every media type.
176/// Convertible to and from MIME type strings.
177#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
178#[serde(rename_all = "lowercase")]
179pub enum AudioMediaType {
180    WAV,
181    MP3,
182    AIFF,
183    AAC,
184    OGG,
185    FLAC,
186}
187
188/// Describes the detail of the image content, which can be low, high, or auto (open-ai specific).
189#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
190#[serde(rename_all = "lowercase")]
191pub enum ImageDetail {
192    Low,
193    High,
194    #[default]
195    Auto,
196}
197
198// ================================================================
199// Impl. for message models
200// ================================================================
201
202impl Message {
203    /// This helper method is primarily used to extract the first string prompt from a `Message`.
204    /// Since `Message` might have more than just text content, we need to find the first text.
205    pub(crate) fn rag_text(&self) -> Option<String> {
206        match self {
207            Message::User { content } => {
208                for item in content.iter() {
209                    if let UserContent::Text(Text { text }) = item {
210                        return Some(text.clone());
211                    }
212                }
213                None
214            }
215            _ => None,
216        }
217    }
218
219    /// Helper constructor to make creating user messages easier.
220    pub fn user(text: impl Into<String>) -> Self {
221        Message::User {
222            content: OneOrMany::one(UserContent::text(text)),
223        }
224    }
225
226    /// Helper constructor to make creating assistant messages easier.
227    pub fn assistant(text: impl Into<String>) -> Self {
228        Message::Assistant {
229            content: OneOrMany::one(AssistantContent::text(text)),
230        }
231    }
232}
233
234impl UserContent {
235    /// Helper constructor to make creating user text content easier.
236    pub fn text(text: impl Into<String>) -> Self {
237        UserContent::Text(text.into().into())
238    }
239
240    /// Helper constructor to make creating user image content easier.
241    pub fn image(
242        data: impl Into<String>,
243        format: Option<ContentFormat>,
244        media_type: Option<ImageMediaType>,
245        detail: Option<ImageDetail>,
246    ) -> Self {
247        UserContent::Image(Image {
248            data: data.into(),
249            format,
250            media_type,
251            detail,
252        })
253    }
254
255    /// Helper constructor to make creating user audio content easier.
256    pub fn audio(
257        data: impl Into<String>,
258        format: Option<ContentFormat>,
259        media_type: Option<AudioMediaType>,
260    ) -> Self {
261        UserContent::Audio(Audio {
262            data: data.into(),
263            format,
264            media_type,
265        })
266    }
267
268    /// Helper constructor to make creating user document content easier.
269    pub fn document(
270        data: impl Into<String>,
271        format: Option<ContentFormat>,
272        media_type: Option<DocumentMediaType>,
273    ) -> Self {
274        UserContent::Document(Document {
275            data: data.into(),
276            format,
277            media_type,
278        })
279    }
280
281    /// Helper constructor to make creating user tool result content easier.
282    pub fn tool_result(id: impl Into<String>, content: OneOrMany<ToolResultContent>) -> Self {
283        UserContent::ToolResult(ToolResult {
284            id: id.into(),
285            content,
286        })
287    }
288}
289
290impl AssistantContent {
291    /// Helper constructor to make creating assistant text content easier.
292    pub fn text(text: impl Into<String>) -> Self {
293        AssistantContent::Text(text.into().into())
294    }
295
296    /// Helper constructor to make creating assistant tool call content easier.
297    pub fn tool_call(
298        id: impl Into<String>,
299        name: impl Into<String>,
300        arguments: serde_json::Value,
301    ) -> Self {
302        AssistantContent::ToolCall(ToolCall {
303            id: id.into(),
304            function: ToolFunction {
305                name: name.into(),
306                arguments,
307            },
308        })
309    }
310}
311
312impl ToolResultContent {
313    /// Helper constructor to make creating tool result text content easier.
314    pub fn text(text: impl Into<String>) -> Self {
315        ToolResultContent::Text(text.into().into())
316    }
317
318    /// Helper constructor to make creating tool result image content easier.
319    pub fn image(
320        data: impl Into<String>,
321        format: Option<ContentFormat>,
322        media_type: Option<ImageMediaType>,
323        detail: Option<ImageDetail>,
324    ) -> Self {
325        ToolResultContent::Image(Image {
326            data: data.into(),
327            format,
328            media_type,
329            detail,
330        })
331    }
332}
333
/// Conversion between media-type values and their MIME type strings.
pub trait MimeType {
    /// Parses `mime_type`, returning `None` when it is not recognized.
    fn from_mime_type(mime_type: &str) -> Option<Self>
    where
        Self: Sized;

    /// Returns the MIME type string for this value.
    fn to_mime_type(&self) -> &'static str;
}
341
342impl MimeType for MediaType {
343    fn from_mime_type(mime_type: &str) -> Option<Self> {
344        ImageMediaType::from_mime_type(mime_type)
345            .map(MediaType::Image)
346            .or_else(|| {
347                DocumentMediaType::from_mime_type(mime_type)
348                    .map(MediaType::Document)
349                    .or_else(|| AudioMediaType::from_mime_type(mime_type).map(MediaType::Audio))
350            })
351    }
352
353    fn to_mime_type(&self) -> &'static str {
354        match self {
355            MediaType::Image(media_type) => media_type.to_mime_type(),
356            MediaType::Audio(media_type) => media_type.to_mime_type(),
357            MediaType::Document(media_type) => media_type.to_mime_type(),
358        }
359    }
360}
361
362impl MimeType for ImageMediaType {
363    fn from_mime_type(mime_type: &str) -> Option<Self> {
364        match mime_type {
365            "image/jpeg" => Some(ImageMediaType::JPEG),
366            "image/png" => Some(ImageMediaType::PNG),
367            "image/gif" => Some(ImageMediaType::GIF),
368            "image/webp" => Some(ImageMediaType::WEBP),
369            "image/heic" => Some(ImageMediaType::HEIC),
370            "image/heif" => Some(ImageMediaType::HEIF),
371            "image/svg+xml" => Some(ImageMediaType::SVG),
372            _ => None,
373        }
374    }
375
376    fn to_mime_type(&self) -> &'static str {
377        match self {
378            ImageMediaType::JPEG => "image/jpeg",
379            ImageMediaType::PNG => "image/png",
380            ImageMediaType::GIF => "image/gif",
381            ImageMediaType::WEBP => "image/webp",
382            ImageMediaType::HEIC => "image/heic",
383            ImageMediaType::HEIF => "image/heif",
384            ImageMediaType::SVG => "image/svg+xml",
385        }
386    }
387}
388
389impl MimeType for DocumentMediaType {
390    fn from_mime_type(mime_type: &str) -> Option<Self> {
391        match mime_type {
392            "application/pdf" => Some(DocumentMediaType::PDF),
393            "text/plain" => Some(DocumentMediaType::TXT),
394            "text/rtf" => Some(DocumentMediaType::RTF),
395            "text/html" => Some(DocumentMediaType::HTML),
396            "text/css" => Some(DocumentMediaType::CSS),
397            "text/md" | "text/markdown" => Some(DocumentMediaType::MARKDOWN),
398            "text/csv" => Some(DocumentMediaType::CSV),
399            "text/xml" => Some(DocumentMediaType::XML),
400            "application/x-javascript" | "text/x-javascript" => Some(DocumentMediaType::Javascript),
401            "application/x-python" | "text/x-python" => Some(DocumentMediaType::Python),
402            _ => None,
403        }
404    }
405
406    fn to_mime_type(&self) -> &'static str {
407        match self {
408            DocumentMediaType::PDF => "application/pdf",
409            DocumentMediaType::TXT => "text/plain",
410            DocumentMediaType::RTF => "text/rtf",
411            DocumentMediaType::HTML => "text/html",
412            DocumentMediaType::CSS => "text/css",
413            DocumentMediaType::MARKDOWN => "text/markdown",
414            DocumentMediaType::CSV => "text/csv",
415            DocumentMediaType::XML => "text/xml",
416            DocumentMediaType::Javascript => "application/x-javascript",
417            DocumentMediaType::Python => "application/x-python",
418        }
419    }
420}
421
422impl MimeType for AudioMediaType {
423    fn from_mime_type(mime_type: &str) -> Option<Self> {
424        match mime_type {
425            "audio/wav" => Some(AudioMediaType::WAV),
426            "audio/mp3" => Some(AudioMediaType::MP3),
427            "audio/aiff" => Some(AudioMediaType::AIFF),
428            "audio/aac" => Some(AudioMediaType::AAC),
429            "audio/ogg" => Some(AudioMediaType::OGG),
430            "audio/flac" => Some(AudioMediaType::FLAC),
431            _ => None,
432        }
433    }
434
435    fn to_mime_type(&self) -> &'static str {
436        match self {
437            AudioMediaType::WAV => "audio/wav",
438            AudioMediaType::MP3 => "audio/mp3",
439            AudioMediaType::AIFF => "audio/aiff",
440            AudioMediaType::AAC => "audio/aac",
441            AudioMediaType::OGG => "audio/ogg",
442            AudioMediaType::FLAC => "audio/flac",
443        }
444    }
445}
446
447impl std::str::FromStr for ImageDetail {
448    type Err = ();
449
450    fn from_str(s: &str) -> Result<Self, Self::Err> {
451        match s.to_lowercase().as_str() {
452            "low" => Ok(ImageDetail::Low),
453            "high" => Ok(ImageDetail::High),
454            "auto" => Ok(ImageDetail::Auto),
455            _ => Err(()),
456        }
457    }
458}
459
460// ================================================================
461// FromStr, From<String>, and From<&str> impls
462// ================================================================
463
464impl From<String> for Text {
465    fn from(text: String) -> Self {
466        Text { text }
467    }
468}
469
470impl From<&str> for Text {
471    fn from(text: &str) -> Self {
472        text.to_owned().into()
473    }
474}
475
476impl FromStr for Text {
477    type Err = Infallible;
478
479    fn from_str(s: &str) -> Result<Self, Self::Err> {
480        Ok(s.into())
481    }
482}
483
484impl From<String> for Message {
485    fn from(text: String) -> Self {
486        Message::User {
487            content: OneOrMany::one(UserContent::Text(text.into())),
488        }
489    }
490}
491
492impl From<&str> for Message {
493    fn from(text: &str) -> Self {
494        Message::User {
495            content: OneOrMany::one(UserContent::Text(text.into())),
496        }
497    }
498}
499
500impl From<Text> for Message {
501    fn from(text: Text) -> Self {
502        Message::User {
503            content: OneOrMany::one(UserContent::Text(text)),
504        }
505    }
506}
507
508impl From<Image> for Message {
509    fn from(image: Image) -> Self {
510        Message::User {
511            content: OneOrMany::one(UserContent::Image(image)),
512        }
513    }
514}
515
516impl From<Audio> for Message {
517    fn from(audio: Audio) -> Self {
518        Message::User {
519            content: OneOrMany::one(UserContent::Audio(audio)),
520        }
521    }
522}
523
524impl From<Document> for Message {
525    fn from(document: Document) -> Self {
526        Message::User {
527            content: OneOrMany::one(UserContent::Document(document)),
528        }
529    }
530}
531
532impl From<String> for ToolResultContent {
533    fn from(text: String) -> Self {
534        ToolResultContent::text(text)
535    }
536}
537
538impl From<String> for AssistantContent {
539    fn from(text: String) -> Self {
540        AssistantContent::text(text)
541    }
542}
543
544impl From<String> for UserContent {
545    fn from(text: String) -> Self {
546        UserContent::text(text)
547    }
548}
549
550// ================================================================
551// Error types
552// ================================================================
553
554/// Error type to represent issues with converting messages to and from specific provider messages.
555#[derive(Debug, Error)]
556pub enum MessageError {
557    #[error("Message conversion error: {0}")]
558    ConversionError(String),
559}
560
561impl From<MessageError> for CompletionError {
562    fn from(error: MessageError) -> Self {
563        CompletionError::RequestError(error.into())
564    }
565}