ragit_pdl/
message.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
use std::fmt;
use crate::image::ImageType;
use crate::role::Role;
use crate::util::encode_base64;

/// A single chat message: who is speaking (`role`) and what they said (`content`).
///
/// `content` is an ordered list of parts, so one message may combine text and
/// image parts.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Message {
    /// The speaker of this message.
    pub role: Role,
    /// The body of the message, as an ordered sequence of content parts.
    pub content: Vec<MessageContent>,
}

impl Message {
    /// Builds a message spoken by `role` whose content is exactly one text
    /// part containing `message`.
    pub fn simple_message(role: Role, message: String) -> Self {
        let content = vec![MessageContent::String(message)];
        Self { role, content }
    }

    /// Returns `true` when this message is usable as a system prompt: the
    /// role is `Role::System` and the content consists of exactly one text
    /// part (no images and no additional parts).
    pub fn is_valid_system_prompt(&self) -> bool {
        self.role == Role::System
            && matches!(self.content.as_slice(), [MessageContent::String(_)])
    }

    /// Returns `true` when the speaker of this message is `Role::User`.
    pub fn is_user_prompt(&self) -> bool {
        let role = &self.role;
        *role == Role::User
    }

    /// Returns `true` when the speaker of this message is `Role::Assistant`.
    pub fn is_assistant_prompt(&self) -> bool {
        let role = &self.role;
        *role == Role::Assistant
    }
}

/// One part of a message's content: either a run of text or an image.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum MessageContent {
    /// A plain text segment.
    String(String),
    /// An image part.
    Image {
        /// The format of the image.
        image_type: ImageType,
        /// The image data as bytes.
        bytes: Vec<u8>,
    },
}

impl MessageContent {
    /// Borrows the text of a [`MessageContent::String`] part.
    ///
    /// # Panics
    ///
    /// Panics if `self` is a [`MessageContent::Image`]. The panic message embeds
    /// the `Debug` representation of `self` (which, for an image, includes its
    /// bytes). Because of `#[track_caller]`, the reported panic location is the
    /// call site, mirroring `Option::unwrap`.
    #[track_caller]
    pub fn unwrap_str(&self) -> &str {
        // Match every variant explicitly (no `_` arm) so that adding a new
        // variant forces this accessor to be revisited instead of silently
        // panicking for it.
        match self {
            MessageContent::String(s) => s.as_str(),
            MessageContent::Image { .. } => panic!("{self:?} is not a string"),
        }
    }

    /// Wraps `s` as a content list holding a single text part.
    pub fn simple_message(s: String) -> Vec<Self> {
        vec![MessageContent::String(s)]
    }
}

impl fmt::Display for MessageContent {
    /// Writes text parts verbatim. Writes image parts as an inline
    /// `<|raw_media(<extension>:<base64 data>)|>` tag, built from
    /// `ImageType::to_extension` and `encode_base64` of the image bytes.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        match self {
            MessageContent::String(text) => f.write_str(text),
            MessageContent::Image { image_type, bytes } => {
                // Compute the extension before the payload, as the original did.
                let extension = image_type.to_extension();
                let payload = encode_base64(bytes);
                write!(f, "<|raw_media({extension}:{payload})|>")
            }
        }
    }
}