ai_lib_rust/types/
message.rs1use serde::{Deserialize, Serialize};
4use base64::Engine as _;
5use std::path::Path;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Message {
10 pub role: MessageRole,
11 pub content: MessageContent,
12}
13
14impl Message {
15 pub fn system(text: impl Into<String>) -> Self {
16 Self {
17 role: MessageRole::System,
18 content: MessageContent::Text(text.into()),
19 }
20 }
21
22 pub fn user(text: impl Into<String>) -> Self {
23 Self {
24 role: MessageRole::User,
25 content: MessageContent::Text(text.into()),
26 }
27 }
28
29 pub fn assistant(text: impl Into<String>) -> Self {
30 Self {
31 role: MessageRole::Assistant,
32 content: MessageContent::Text(text.into()),
33 }
34 }
35
36 pub fn with_content(role: MessageRole, content: MessageContent) -> Self {
37 Self { role, content }
38 }
39
40 pub fn contains_image(&self) -> bool {
41 match &self.content {
42 MessageContent::Text(_) => false,
43 MessageContent::Blocks(bs) => bs.iter().any(|b| matches!(b, ContentBlock::Image { .. })),
44 }
45 }
46
47 pub fn contains_audio(&self) -> bool {
48 match &self.content {
49 MessageContent::Text(_) => false,
50 MessageContent::Blocks(bs) => bs.iter().any(|b| matches!(b, ContentBlock::Audio { .. })),
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(rename_all = "lowercase")]
58pub enum MessageRole {
59 System,
60 User,
61 Assistant,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(untagged)]
67pub enum MessageContent {
68 Text(String),
69 Blocks(Vec<ContentBlock>),
70}
71
72impl MessageContent {
73 pub fn text(text: impl Into<String>) -> Self {
74 MessageContent::Text(text.into())
75 }
76
77 pub fn blocks(blocks: Vec<ContentBlock>) -> Self {
78 MessageContent::Blocks(blocks)
79 }
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(tag = "type")]
85pub enum ContentBlock {
86 #[serde(rename = "text")]
87 Text { text: String },
88 #[serde(rename = "image")]
89 Image { source: ImageSource },
90 #[serde(rename = "audio")]
91 Audio { source: AudioSource },
92 #[serde(rename = "tool_use")]
93 ToolUse {
94 id: String,
95 name: String,
96 input: serde_json::Value,
97 },
98 #[serde(rename = "tool_result")]
99 ToolResult {
100 tool_use_id: String,
101 content: serde_json::Value,
102 },
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct ImageSource {
107 #[serde(rename = "type")]
108 pub source_type: String,
109 #[serde(default, skip_serializing_if = "Option::is_none")]
110 pub media_type: Option<String>,
111 pub data: String, }
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct AudioSource {
116 #[serde(rename = "type")]
117 pub source_type: String,
118 #[serde(default, skip_serializing_if = "Option::is_none")]
119 pub media_type: Option<String>,
120 pub data: String, }
122
123impl ContentBlock {
124 pub fn text(text: impl Into<String>) -> Self {
125 ContentBlock::Text { text: text.into() }
126 }
127
128 pub fn image_base64(data: String, media_type: Option<String>) -> Self {
129 ContentBlock::Image {
130 source: ImageSource {
131 source_type: "base64".to_string(),
132 media_type,
133 data,
134 },
135 }
136 }
137
138 pub fn audio_base64(data: String, media_type: Option<String>) -> Self {
139 ContentBlock::Audio {
140 source: AudioSource {
141 source_type: "base64".to_string(),
142 media_type,
143 data,
144 },
145 }
146 }
147
148 pub fn image_from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
149 let path = path.as_ref();
150 let bytes = std::fs::read(path)?;
151 let media_type = guess_media_type(path);
152 let data = base64::engine::general_purpose::STANDARD.encode(bytes);
153 Ok(Self::image_base64(data, media_type))
154 }
155
156 pub fn audio_from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
157 let path = path.as_ref();
158 let bytes = std::fs::read(path)?;
159 let media_type = guess_media_type(path);
160 let data = base64::engine::general_purpose::STANDARD.encode(bytes);
161 Ok(Self::audio_base64(data, media_type))
162 }
163}
164
/// Guesses a MIME media type from the file extension of `path`.
///
/// Matching is ASCII case-insensitive (`photo.PNG` gives `image/png`).
/// Returns `None` when the path has no extension, the extension is not valid
/// UTF-8, or the extension is not in the table below.
fn guess_media_type(path: &Path) -> Option<String> {
    // (extension, media type) pairs. They are compared case-insensitively, so
    // no lowercase copy of the extension has to be allocated.
    const KNOWN_TYPES: &[(&str, &str)] = &[
        ("png", "image/png"),
        ("jpg", "image/jpeg"),
        ("jpeg", "image/jpeg"),
        ("webp", "image/webp"),
        ("gif", "image/gif"),
        ("mp3", "audio/mpeg"),
        ("wav", "audio/wav"),
        ("ogg", "audio/ogg"),
        ("m4a", "audio/mp4"),
        ("flac", "audio/flac"),
        ("aac", "audio/aac"),
    ];

    let ext = path.extension()?.to_str()?;
    KNOWN_TYPES
        .iter()
        .find(|(known, _)| known.eq_ignore_ascii_case(ext))
        .map(|&(_, media_type)| media_type.to_string())
}
