1use std::{convert::Infallible, str::FromStr};
2
3use crate::OneOrMany;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use super::CompletionError;
8
/// A single chat message.
///
/// Serialized as an internally tagged object: a `role` field holds the lowercased
/// variant name (`"user"` or `"assistant"`) next to the variant's `content` field.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// User-role message: text, tool results, images, audio or documents.
    User { content: OneOrMany<UserContent> },

    /// Assistant-role message: text and/or tool calls.
    Assistant {
        content: OneOrMany<AssistantContent>,
    },
}

/// One item of [`Message::User`] content.
///
/// Serialized with a `type` field holding the lowercased variant name
/// (`"text"`, `"toolresult"`, `"image"`, `"audio"`, `"document"`).
// NOTE(review): `rename_all = "lowercase"` yields `"toolresult"` with no separator;
// confirm that is the intended wire name rather than `"tool_result"`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text(Text),
    ToolResult(ToolResult),
    Image(Image),
    Audio(Audio),
    Document(Document),
}

/// One item of [`Message::Assistant`] content.
///
/// `untagged`: no discriminator is written. On deserialization serde tries the
/// variants in declaration order (`Text`, then `ToolCall`), so keep that order stable.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum AssistantContent {
    Text(Text),
    ToolCall(ToolCall),
}
55
/// The result of a tool invocation, carried inside [`UserContent::ToolResult`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolResult {
    /// Identifier of the result; presumably matches the `id` of the
    /// corresponding [`ToolCall`] — confirm against providers.
    pub id: String,
    /// One or more text/image items produced by the tool.
    pub content: OneOrMany<ToolResultContent>,
}

/// One item of tool-result content: text or an image.
// NOTE(review): unlike `UserContent`, this enum has no `#[serde(tag = ...)]`, so serde
// uses its default externally tagged form (e.g. `{"Text": {"text": ...}}`); confirm
// that is intended.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum ToolResultContent {
    Text(Text),
    Image(Image),
}

/// A tool invocation, carried inside [`AssistantContent::ToolCall`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// The function to invoke and its arguments.
    pub function: ToolFunction,
}

/// Name and arguments of the function referenced by a [`ToolCall`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolFunction {
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub arguments: serde_json::Value,
}

/// Plain text content; serialized as an object with a single `text` field.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Text {
    pub text: String,
}
93
/// Image content. Optional fields that are `None` are omitted when serializing.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Image {
    /// Image payload; presumably interpreted according to `format` — confirm.
    pub data: String,
    /// Representation of `data` (see [`ContentFormat`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<ContentFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<ImageMediaType>,
    /// Requested level of detail, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<ImageDetail>,
}

/// Audio content. Optional fields that are `None` are omitted when serializing.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Audio {
    /// Audio payload; presumably interpreted according to `format` — confirm.
    pub data: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<ContentFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<AudioMediaType>,
}

/// Document content. Optional fields that are `None` are omitted when serializing.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Document {
    /// Document payload; presumably interpreted according to `format` — confirm.
    pub data: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<ContentFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<DocumentMediaType>,
}

/// Representation tag for the `data` field of [`Image`], [`Audio`] and [`Document`]
/// (presumably base64-encoded bytes vs. a plain string — confirm).
///
/// Defaults to `Base64`; serialized as `"base64"` or `"string"`.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ContentFormat {
    #[default]
    Base64,
    String,
}
134
/// Any supported media type: image, audio or document.
// No `#[serde(...)]` attribute: serde's default externally tagged form applies
// (e.g. `{"Image": "jpeg"}`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum MediaType {
    Image(ImageMediaType),
    Audio(AudioMediaType),
    Document(DocumentMediaType),
}

/// Supported image formats, serialized in lowercase (e.g. `"jpeg"`, `"svg"`).
/// MIME strings come from this type's `MimeType` impl.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageMediaType {
    JPEG,
    PNG,
    GIF,
    WEBP,
    HEIC,
    HEIF,
    SVG,
}

/// Supported document/text formats, serialized in lowercase
/// (e.g. `"pdf"`, `"markdown"`, `"javascript"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentMediaType {
    PDF,
    TXT,
    RTF,
    HTML,
    CSS,
    MARKDOWN,
    CSV,
    XML,
    Javascript,
    Python,
}

/// Supported audio formats, serialized in lowercase (e.g. `"wav"`, `"mp3"`).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum AudioMediaType {
    WAV,
    MP3,
    AIFF,
    AAC,
    OGG,
    FLAC,
}

/// Requested image detail level; defaults to `Auto`.
///
/// Serialized as `"low"`, `"high"` or `"auto"`; `FromStr` also parses those
/// strings case-insensitively.
#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageDetail {
    Low,
    High,
    #[default]
    Auto,
}
197
198impl Message {
203 pub(crate) fn rag_text(&self) -> Option<String> {
206 match self {
207 Message::User { content } => {
208 for item in content.iter() {
209 if let UserContent::Text(Text { text }) = item {
210 return Some(text.clone());
211 }
212 }
213 None
214 }
215 _ => None,
216 }
217 }
218
219 pub fn user(text: impl Into<String>) -> Self {
221 Message::User {
222 content: OneOrMany::one(UserContent::text(text)),
223 }
224 }
225
226 pub fn assistant(text: impl Into<String>) -> Self {
228 Message::Assistant {
229 content: OneOrMany::one(AssistantContent::text(text)),
230 }
231 }
232}
233
234impl UserContent {
235 pub fn text(text: impl Into<String>) -> Self {
237 UserContent::Text(text.into().into())
238 }
239
240 pub fn image(
242 data: impl Into<String>,
243 format: Option<ContentFormat>,
244 media_type: Option<ImageMediaType>,
245 detail: Option<ImageDetail>,
246 ) -> Self {
247 UserContent::Image(Image {
248 data: data.into(),
249 format,
250 media_type,
251 detail,
252 })
253 }
254
255 pub fn audio(
257 data: impl Into<String>,
258 format: Option<ContentFormat>,
259 media_type: Option<AudioMediaType>,
260 ) -> Self {
261 UserContent::Audio(Audio {
262 data: data.into(),
263 format,
264 media_type,
265 })
266 }
267
268 pub fn document(
270 data: impl Into<String>,
271 format: Option<ContentFormat>,
272 media_type: Option<DocumentMediaType>,
273 ) -> Self {
274 UserContent::Document(Document {
275 data: data.into(),
276 format,
277 media_type,
278 })
279 }
280
281 pub fn tool_result(id: impl Into<String>, content: OneOrMany<ToolResultContent>) -> Self {
283 UserContent::ToolResult(ToolResult {
284 id: id.into(),
285 content,
286 })
287 }
288}
289
290impl AssistantContent {
291 pub fn text(text: impl Into<String>) -> Self {
293 AssistantContent::Text(text.into().into())
294 }
295
296 pub fn tool_call(
298 id: impl Into<String>,
299 name: impl Into<String>,
300 arguments: serde_json::Value,
301 ) -> Self {
302 AssistantContent::ToolCall(ToolCall {
303 id: id.into(),
304 function: ToolFunction {
305 name: name.into(),
306 arguments,
307 },
308 })
309 }
310}
311
312impl ToolResultContent {
313 pub fn text(text: impl Into<String>) -> Self {
315 ToolResultContent::Text(text.into().into())
316 }
317
318 pub fn image(
320 data: impl Into<String>,
321 format: Option<ContentFormat>,
322 media_type: Option<ImageMediaType>,
323 detail: Option<ImageDetail>,
324 ) -> Self {
325 ToolResultContent::Image(Image {
326 data: data.into(),
327 format,
328 media_type,
329 detail,
330 })
331 }
332}
333
/// Two-way mapping between a media-type value and a MIME type string.
pub trait MimeType {
    /// Returns the MIME type string this value maps to.
    fn to_mime_type(&self) -> &'static str;

    /// Parses a MIME type string, returning `None` when it is not recognised.
    fn from_mime_type(mime_type: &str) -> Option<Self>
    where
        Self: Sized;
}
341
342impl MimeType for MediaType {
343 fn from_mime_type(mime_type: &str) -> Option<Self> {
344 ImageMediaType::from_mime_type(mime_type)
345 .map(MediaType::Image)
346 .or_else(|| {
347 DocumentMediaType::from_mime_type(mime_type)
348 .map(MediaType::Document)
349 .or_else(|| AudioMediaType::from_mime_type(mime_type).map(MediaType::Audio))
350 })
351 }
352
353 fn to_mime_type(&self) -> &'static str {
354 match self {
355 MediaType::Image(media_type) => media_type.to_mime_type(),
356 MediaType::Audio(media_type) => media_type.to_mime_type(),
357 MediaType::Document(media_type) => media_type.to_mime_type(),
358 }
359 }
360}
361
362impl MimeType for ImageMediaType {
363 fn from_mime_type(mime_type: &str) -> Option<Self> {
364 match mime_type {
365 "image/jpeg" => Some(ImageMediaType::JPEG),
366 "image/png" => Some(ImageMediaType::PNG),
367 "image/gif" => Some(ImageMediaType::GIF),
368 "image/webp" => Some(ImageMediaType::WEBP),
369 "image/heic" => Some(ImageMediaType::HEIC),
370 "image/heif" => Some(ImageMediaType::HEIF),
371 "image/svg+xml" => Some(ImageMediaType::SVG),
372 _ => None,
373 }
374 }
375
376 fn to_mime_type(&self) -> &'static str {
377 match self {
378 ImageMediaType::JPEG => "image/jpeg",
379 ImageMediaType::PNG => "image/png",
380 ImageMediaType::GIF => "image/gif",
381 ImageMediaType::WEBP => "image/webp",
382 ImageMediaType::HEIC => "image/heic",
383 ImageMediaType::HEIF => "image/heif",
384 ImageMediaType::SVG => "image/svg+xml",
385 }
386 }
387}
388
389impl MimeType for DocumentMediaType {
390 fn from_mime_type(mime_type: &str) -> Option<Self> {
391 match mime_type {
392 "application/pdf" => Some(DocumentMediaType::PDF),
393 "text/plain" => Some(DocumentMediaType::TXT),
394 "text/rtf" => Some(DocumentMediaType::RTF),
395 "text/html" => Some(DocumentMediaType::HTML),
396 "text/css" => Some(DocumentMediaType::CSS),
397 "text/md" | "text/markdown" => Some(DocumentMediaType::MARKDOWN),
398 "text/csv" => Some(DocumentMediaType::CSV),
399 "text/xml" => Some(DocumentMediaType::XML),
400 "application/x-javascript" | "text/x-javascript" => Some(DocumentMediaType::Javascript),
401 "application/x-python" | "text/x-python" => Some(DocumentMediaType::Python),
402 _ => None,
403 }
404 }
405
406 fn to_mime_type(&self) -> &'static str {
407 match self {
408 DocumentMediaType::PDF => "application/pdf",
409 DocumentMediaType::TXT => "text/plain",
410 DocumentMediaType::RTF => "text/rtf",
411 DocumentMediaType::HTML => "text/html",
412 DocumentMediaType::CSS => "text/css",
413 DocumentMediaType::MARKDOWN => "text/markdown",
414 DocumentMediaType::CSV => "text/csv",
415 DocumentMediaType::XML => "text/xml",
416 DocumentMediaType::Javascript => "application/x-javascript",
417 DocumentMediaType::Python => "application/x-python",
418 }
419 }
420}
421
422impl MimeType for AudioMediaType {
423 fn from_mime_type(mime_type: &str) -> Option<Self> {
424 match mime_type {
425 "audio/wav" => Some(AudioMediaType::WAV),
426 "audio/mp3" => Some(AudioMediaType::MP3),
427 "audio/aiff" => Some(AudioMediaType::AIFF),
428 "audio/aac" => Some(AudioMediaType::AAC),
429 "audio/ogg" => Some(AudioMediaType::OGG),
430 "audio/flac" => Some(AudioMediaType::FLAC),
431 _ => None,
432 }
433 }
434
435 fn to_mime_type(&self) -> &'static str {
436 match self {
437 AudioMediaType::WAV => "audio/wav",
438 AudioMediaType::MP3 => "audio/mp3",
439 AudioMediaType::AIFF => "audio/aiff",
440 AudioMediaType::AAC => "audio/aac",
441 AudioMediaType::OGG => "audio/ogg",
442 AudioMediaType::FLAC => "audio/flac",
443 }
444 }
445}
446
447impl std::str::FromStr for ImageDetail {
448 type Err = ();
449
450 fn from_str(s: &str) -> Result<Self, Self::Err> {
451 match s.to_lowercase().as_str() {
452 "low" => Ok(ImageDetail::Low),
453 "high" => Ok(ImageDetail::High),
454 "auto" => Ok(ImageDetail::Auto),
455 _ => Err(()),
456 }
457 }
458}
459
460impl From<String> for Text {
465 fn from(text: String) -> Self {
466 Text { text }
467 }
468}
469
470impl From<&str> for Text {
471 fn from(text: &str) -> Self {
472 text.to_owned().into()
473 }
474}
475
476impl FromStr for Text {
477 type Err = Infallible;
478
479 fn from_str(s: &str) -> Result<Self, Self::Err> {
480 Ok(s.into())
481 }
482}
483
484impl From<String> for Message {
485 fn from(text: String) -> Self {
486 Message::User {
487 content: OneOrMany::one(UserContent::Text(text.into())),
488 }
489 }
490}
491
492impl From<&str> for Message {
493 fn from(text: &str) -> Self {
494 Message::User {
495 content: OneOrMany::one(UserContent::Text(text.into())),
496 }
497 }
498}
499
500impl From<Text> for Message {
501 fn from(text: Text) -> Self {
502 Message::User {
503 content: OneOrMany::one(UserContent::Text(text)),
504 }
505 }
506}
507
508impl From<Image> for Message {
509 fn from(image: Image) -> Self {
510 Message::User {
511 content: OneOrMany::one(UserContent::Image(image)),
512 }
513 }
514}
515
516impl From<Audio> for Message {
517 fn from(audio: Audio) -> Self {
518 Message::User {
519 content: OneOrMany::one(UserContent::Audio(audio)),
520 }
521 }
522}
523
524impl From<Document> for Message {
525 fn from(document: Document) -> Self {
526 Message::User {
527 content: OneOrMany::one(UserContent::Document(document)),
528 }
529 }
530}
531
532impl From<String> for ToolResultContent {
533 fn from(text: String) -> Self {
534 ToolResultContent::text(text)
535 }
536}
537
538impl From<String> for AssistantContent {
539 fn from(text: String) -> Self {
540 AssistantContent::text(text)
541 }
542}
543
544impl From<String> for UserContent {
545 fn from(text: String) -> Self {
546 UserContent::text(text)
547 }
548}
549
550#[derive(Debug, Error)]
556pub enum MessageError {
557 #[error("Message conversion error: {0}")]
558 ConversionError(String),
559}
560
561impl From<MessageError> for CompletionError {
562 fn from(error: MessageError) -> Self {
563 CompletionError::RequestError(error.into())
564 }
565}