1use std::{convert::Infallible, str::FromStr};
2
3use crate::OneOrMany;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use super::CompletionError;
8
9#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
24#[serde(tag = "role", rename_all = "lowercase")]
25pub enum Message {
26    User { content: OneOrMany<UserContent> },
28
29    Assistant {
31        content: OneOrMany<AssistantContent>,
32    },
33}
34
35#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
39#[serde(tag = "type", rename_all = "lowercase")]
40pub enum UserContent {
41    Text(Text),
42    ToolResult(ToolResult),
43    Image(Image),
44    Audio(Audio),
45    Document(Document),
46}
47
48#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
50#[serde(untagged)]
51pub enum AssistantContent {
52    Text(Text),
53    ToolCall(ToolCall),
54}
55
56#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
58pub struct ToolResult {
59    pub id: String,
60    pub content: OneOrMany<ToolResultContent>,
61}
62
63#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
65pub enum ToolResultContent {
66    Text(Text),
67    Image(Image),
68}
69
70#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
72pub struct ToolCall {
73    pub id: String,
74    pub function: ToolFunction,
75}
76
77#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
79pub struct ToolFunction {
80    pub name: String,
81    pub arguments: serde_json::Value,
82}
83
84#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
90pub struct Text {
91    pub text: String,
92}
93
94#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
96pub struct Image {
97    pub data: String,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub format: Option<ContentFormat>,
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub media_type: Option<ImageMediaType>,
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub detail: Option<ImageDetail>,
104}
105
106#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
108pub struct Audio {
109    pub data: String,
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub format: Option<ContentFormat>,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub media_type: Option<AudioMediaType>,
114}
115
116#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
118pub struct Document {
119    pub data: String,
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub format: Option<ContentFormat>,
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub media_type: Option<DocumentMediaType>,
124}
125
126#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
128#[serde(rename_all = "lowercase")]
129pub enum ContentFormat {
130    #[default]
131    Base64,
132    String,
133}
134
135#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
137pub enum MediaType {
138    Image(ImageMediaType),
139    Audio(AudioMediaType),
140    Document(DocumentMediaType),
141}
142
143#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
146#[serde(rename_all = "lowercase")]
147pub enum ImageMediaType {
148    JPEG,
149    PNG,
150    GIF,
151    WEBP,
152    HEIC,
153    HEIF,
154    SVG,
155}
156
157#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
161#[serde(rename_all = "lowercase")]
162pub enum DocumentMediaType {
163    PDF,
164    TXT,
165    RTF,
166    HTML,
167    CSS,
168    MARKDOWN,
169    CSV,
170    XML,
171    Javascript,
172    Python,
173}
174
175#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
178#[serde(rename_all = "lowercase")]
179pub enum AudioMediaType {
180    WAV,
181    MP3,
182    AIFF,
183    AAC,
184    OGG,
185    FLAC,
186}
187
188#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
190#[serde(rename_all = "lowercase")]
191pub enum ImageDetail {
192    Low,
193    High,
194    #[default]
195    Auto,
196}
197
198impl Message {
203    pub(crate) fn rag_text(&self) -> Option<String> {
206        match self {
207            Message::User { content } => {
208                for item in content.iter() {
209                    if let UserContent::Text(Text { text }) = item {
210                        return Some(text.clone());
211                    }
212                }
213                None
214            }
215            _ => None,
216        }
217    }
218
219    pub fn user(text: impl Into<String>) -> Self {
221        Message::User {
222            content: OneOrMany::one(UserContent::text(text)),
223        }
224    }
225
226    pub fn assistant(text: impl Into<String>) -> Self {
228        Message::Assistant {
229            content: OneOrMany::one(AssistantContent::text(text)),
230        }
231    }
232}
233
234impl UserContent {
235    pub fn text(text: impl Into<String>) -> Self {
237        UserContent::Text(text.into().into())
238    }
239
240    pub fn image(
242        data: impl Into<String>,
243        format: Option<ContentFormat>,
244        media_type: Option<ImageMediaType>,
245        detail: Option<ImageDetail>,
246    ) -> Self {
247        UserContent::Image(Image {
248            data: data.into(),
249            format,
250            media_type,
251            detail,
252        })
253    }
254
255    pub fn audio(
257        data: impl Into<String>,
258        format: Option<ContentFormat>,
259        media_type: Option<AudioMediaType>,
260    ) -> Self {
261        UserContent::Audio(Audio {
262            data: data.into(),
263            format,
264            media_type,
265        })
266    }
267
268    pub fn document(
270        data: impl Into<String>,
271        format: Option<ContentFormat>,
272        media_type: Option<DocumentMediaType>,
273    ) -> Self {
274        UserContent::Document(Document {
275            data: data.into(),
276            format,
277            media_type,
278        })
279    }
280
281    pub fn tool_result(id: impl Into<String>, content: OneOrMany<ToolResultContent>) -> Self {
283        UserContent::ToolResult(ToolResult {
284            id: id.into(),
285            content,
286        })
287    }
288}
289
290impl AssistantContent {
291    pub fn text(text: impl Into<String>) -> Self {
293        AssistantContent::Text(text.into().into())
294    }
295
296    pub fn tool_call(
298        id: impl Into<String>,
299        name: impl Into<String>,
300        arguments: serde_json::Value,
301    ) -> Self {
302        AssistantContent::ToolCall(ToolCall {
303            id: id.into(),
304            function: ToolFunction {
305                name: name.into(),
306                arguments,
307            },
308        })
309    }
310}
311
312impl ToolResultContent {
313    pub fn text(text: impl Into<String>) -> Self {
315        ToolResultContent::Text(text.into().into())
316    }
317
318    pub fn image(
320        data: impl Into<String>,
321        format: Option<ContentFormat>,
322        media_type: Option<ImageMediaType>,
323        detail: Option<ImageDetail>,
324    ) -> Self {
325        ToolResultContent::Image(Image {
326            data: data.into(),
327            format,
328            media_type,
329            detail,
330        })
331    }
332}
333
/// Conversion between a media-type enum and its MIME type string.
pub trait MimeType {
    /// Parses `mime_type` into the implementing type.
    ///
    /// Returns `None` when the string is not recognised by the implementor.
    fn from_mime_type(mime_type: &str) -> Option<Self>
    where
        Self: Sized;

    /// Returns the MIME type string for this value.
    fn to_mime_type(&self) -> &'static str;
}
341
342impl MimeType for MediaType {
343    fn from_mime_type(mime_type: &str) -> Option<Self> {
344        ImageMediaType::from_mime_type(mime_type)
345            .map(MediaType::Image)
346            .or_else(|| {
347                DocumentMediaType::from_mime_type(mime_type)
348                    .map(MediaType::Document)
349                    .or_else(|| AudioMediaType::from_mime_type(mime_type).map(MediaType::Audio))
350            })
351    }
352
353    fn to_mime_type(&self) -> &'static str {
354        match self {
355            MediaType::Image(media_type) => media_type.to_mime_type(),
356            MediaType::Audio(media_type) => media_type.to_mime_type(),
357            MediaType::Document(media_type) => media_type.to_mime_type(),
358        }
359    }
360}
361
362impl MimeType for ImageMediaType {
363    fn from_mime_type(mime_type: &str) -> Option<Self> {
364        match mime_type {
365            "image/jpeg" => Some(ImageMediaType::JPEG),
366            "image/png" => Some(ImageMediaType::PNG),
367            "image/gif" => Some(ImageMediaType::GIF),
368            "image/webp" => Some(ImageMediaType::WEBP),
369            "image/heic" => Some(ImageMediaType::HEIC),
370            "image/heif" => Some(ImageMediaType::HEIF),
371            "image/svg+xml" => Some(ImageMediaType::SVG),
372            _ => None,
373        }
374    }
375
376    fn to_mime_type(&self) -> &'static str {
377        match self {
378            ImageMediaType::JPEG => "image/jpeg",
379            ImageMediaType::PNG => "image/png",
380            ImageMediaType::GIF => "image/gif",
381            ImageMediaType::WEBP => "image/webp",
382            ImageMediaType::HEIC => "image/heic",
383            ImageMediaType::HEIF => "image/heif",
384            ImageMediaType::SVG => "image/svg+xml",
385        }
386    }
387}
388
389impl MimeType for DocumentMediaType {
390    fn from_mime_type(mime_type: &str) -> Option<Self> {
391        match mime_type {
392            "application/pdf" => Some(DocumentMediaType::PDF),
393            "text/plain" => Some(DocumentMediaType::TXT),
394            "text/rtf" => Some(DocumentMediaType::RTF),
395            "text/html" => Some(DocumentMediaType::HTML),
396            "text/css" => Some(DocumentMediaType::CSS),
397            "text/md" | "text/markdown" => Some(DocumentMediaType::MARKDOWN),
398            "text/csv" => Some(DocumentMediaType::CSV),
399            "text/xml" => Some(DocumentMediaType::XML),
400            "application/x-javascript" | "text/x-javascript" => Some(DocumentMediaType::Javascript),
401            "application/x-python" | "text/x-python" => Some(DocumentMediaType::Python),
402            _ => None,
403        }
404    }
405
406    fn to_mime_type(&self) -> &'static str {
407        match self {
408            DocumentMediaType::PDF => "application/pdf",
409            DocumentMediaType::TXT => "text/plain",
410            DocumentMediaType::RTF => "text/rtf",
411            DocumentMediaType::HTML => "text/html",
412            DocumentMediaType::CSS => "text/css",
413            DocumentMediaType::MARKDOWN => "text/markdown",
414            DocumentMediaType::CSV => "text/csv",
415            DocumentMediaType::XML => "text/xml",
416            DocumentMediaType::Javascript => "application/x-javascript",
417            DocumentMediaType::Python => "application/x-python",
418        }
419    }
420}
421
422impl MimeType for AudioMediaType {
423    fn from_mime_type(mime_type: &str) -> Option<Self> {
424        match mime_type {
425            "audio/wav" => Some(AudioMediaType::WAV),
426            "audio/mp3" => Some(AudioMediaType::MP3),
427            "audio/aiff" => Some(AudioMediaType::AIFF),
428            "audio/aac" => Some(AudioMediaType::AAC),
429            "audio/ogg" => Some(AudioMediaType::OGG),
430            "audio/flac" => Some(AudioMediaType::FLAC),
431            _ => None,
432        }
433    }
434
435    fn to_mime_type(&self) -> &'static str {
436        match self {
437            AudioMediaType::WAV => "audio/wav",
438            AudioMediaType::MP3 => "audio/mp3",
439            AudioMediaType::AIFF => "audio/aiff",
440            AudioMediaType::AAC => "audio/aac",
441            AudioMediaType::OGG => "audio/ogg",
442            AudioMediaType::FLAC => "audio/flac",
443        }
444    }
445}
446
447impl std::str::FromStr for ImageDetail {
448    type Err = ();
449
450    fn from_str(s: &str) -> Result<Self, Self::Err> {
451        match s.to_lowercase().as_str() {
452            "low" => Ok(ImageDetail::Low),
453            "high" => Ok(ImageDetail::High),
454            "auto" => Ok(ImageDetail::Auto),
455            _ => Err(()),
456        }
457    }
458}
459
460impl From<String> for Text {
465    fn from(text: String) -> Self {
466        Text { text }
467    }
468}
469
470impl From<&str> for Text {
471    fn from(text: &str) -> Self {
472        text.to_owned().into()
473    }
474}
475
476impl FromStr for Text {
477    type Err = Infallible;
478
479    fn from_str(s: &str) -> Result<Self, Self::Err> {
480        Ok(s.into())
481    }
482}
483
484impl From<String> for Message {
485    fn from(text: String) -> Self {
486        Message::User {
487            content: OneOrMany::one(UserContent::Text(text.into())),
488        }
489    }
490}
491
492impl From<&str> for Message {
493    fn from(text: &str) -> Self {
494        Message::User {
495            content: OneOrMany::one(UserContent::Text(text.into())),
496        }
497    }
498}
499
500impl From<Text> for Message {
501    fn from(text: Text) -> Self {
502        Message::User {
503            content: OneOrMany::one(UserContent::Text(text)),
504        }
505    }
506}
507
508impl From<Image> for Message {
509    fn from(image: Image) -> Self {
510        Message::User {
511            content: OneOrMany::one(UserContent::Image(image)),
512        }
513    }
514}
515
516impl From<Audio> for Message {
517    fn from(audio: Audio) -> Self {
518        Message::User {
519            content: OneOrMany::one(UserContent::Audio(audio)),
520        }
521    }
522}
523
524impl From<Document> for Message {
525    fn from(document: Document) -> Self {
526        Message::User {
527            content: OneOrMany::one(UserContent::Document(document)),
528        }
529    }
530}
531
532impl From<String> for ToolResultContent {
533    fn from(text: String) -> Self {
534        ToolResultContent::text(text)
535    }
536}
537
538impl From<String> for AssistantContent {
539    fn from(text: String) -> Self {
540        AssistantContent::text(text)
541    }
542}
543
544impl From<String> for UserContent {
545    fn from(text: String) -> Self {
546        UserContent::text(text)
547    }
548}
549
550#[derive(Debug, Error)]
556pub enum MessageError {
557    #[error("Message conversion error: {0}")]
558    ConversionError(String),
559}
560
561impl From<MessageError> for CompletionError {
562    fn from(error: MessageError) -> Self {
563        CompletionError::RequestError(error.into())
564    }
565}