1use serde::{Deserialize, Serialize};
4
5use super::ContentBlock;
6use super::document::DocumentBlock;
7use super::search::SearchResultBlock;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
11#[serde(rename_all = "lowercase")]
12pub enum Role {
13 User,
15 Assistant,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Message {
22 pub role: Role,
24 pub content: Vec<ContentBlock>,
26}
27
28impl Message {
29 pub fn user(text: impl Into<String>) -> Self {
30 Self {
31 role: Role::User,
32 content: vec![ContentBlock::text(text)],
33 }
34 }
35
36 pub fn assistant(text: impl Into<String>) -> Self {
37 Self {
38 role: Role::Assistant,
39 content: vec![ContentBlock::text(text)],
40 }
41 }
42
43 pub fn tool_results(results: Vec<super::ToolResultBlock>) -> Self {
44 Self {
45 role: Role::User,
46 content: results.into_iter().map(ContentBlock::ToolResult).collect(),
47 }
48 }
49
50 pub fn user_with_content(content: Vec<ContentBlock>) -> Self {
51 Self {
52 role: Role::User,
53 content,
54 }
55 }
56
57 pub fn user_with_document(text: impl Into<String>, doc: DocumentBlock) -> Self {
58 Self {
59 role: Role::User,
60 content: vec![ContentBlock::Document(doc), ContentBlock::text(text)],
61 }
62 }
63
64 pub fn user_with_documents(text: impl Into<String>, docs: Vec<DocumentBlock>) -> Self {
65 let mut content: Vec<ContentBlock> = docs.into_iter().map(ContentBlock::Document).collect();
66 content.push(ContentBlock::text(text));
67 Self {
68 role: Role::User,
69 content,
70 }
71 }
72
73 pub fn user_with_search_results(
74 text: impl Into<String>,
75 results: Vec<SearchResultBlock>,
76 ) -> Self {
77 let mut content: Vec<ContentBlock> = results
78 .into_iter()
79 .map(ContentBlock::SearchResult)
80 .collect();
81 content.push(ContentBlock::text(text));
82 Self {
83 role: Role::User,
84 content,
85 }
86 }
87
88 pub fn text(&self) -> String {
89 self.content
90 .iter()
91 .filter_map(|block| block.as_text())
92 .collect::<Vec<_>>()
93 .join("")
94 }
95
96 pub fn has_tool_use(&self) -> bool {
97 self.content
98 .iter()
99 .any(|block| matches!(block, ContentBlock::ToolUse { .. }))
100 }
101
102 pub fn tool_uses(&self) -> Vec<&super::ToolUseBlock> {
103 self.content
104 .iter()
105 .filter_map(|block| match block {
106 ContentBlock::ToolUse(tool_use) => Some(tool_use),
107 _ => None,
108 })
109 .collect()
110 }
111
112 pub fn documents(&self) -> Vec<&DocumentBlock> {
113 self.content
114 .iter()
115 .filter_map(|block| block.as_document())
116 .collect()
117 }
118
119 pub fn search_results(&self) -> Vec<&SearchResultBlock> {
120 self.content
121 .iter()
122 .filter_map(|block| block.as_search_result())
123 .collect()
124 }
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize)]
129#[serde(untagged)]
130pub enum SystemPrompt {
131 Text(String),
133 Blocks(Vec<SystemBlock>),
135}
136
137impl Default for SystemPrompt {
138 fn default() -> Self {
139 Self::Text(String::new())
140 }
141}
142
143impl SystemPrompt {
144 pub fn is_empty(&self) -> bool {
145 match self {
146 Self::Text(s) => s.is_empty(),
147 Self::Blocks(b) => b.is_empty(),
148 }
149 }
150
151 pub fn as_text(&self) -> String {
152 match self {
153 Self::Text(s) => s.clone(),
154 Self::Blocks(b) => b
155 .iter()
156 .map(|block| block.text.as_str())
157 .collect::<Vec<_>>()
158 .join("\n\n"),
159 }
160 }
161}
162
163impl std::fmt::Display for SystemPrompt {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 write!(f, "{}", self.as_text())
166 }
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct SystemBlock {
172 #[serde(rename = "type")]
174 pub block_type: String,
175 pub text: String,
177 #[serde(skip_serializing_if = "Option::is_none")]
179 pub cache_control: Option<CacheControl>,
180}
181
182impl SystemBlock {
183 pub fn cached(text: impl Into<String>) -> Self {
185 Self {
186 block_type: "text".to_string(),
187 text: text.into(),
188 cache_control: Some(CacheControl::ephemeral()),
189 }
190 }
191
192 pub fn uncached(text: impl Into<String>) -> Self {
194 Self {
195 block_type: "text".to_string(),
196 text: text.into(),
197 cache_control: None,
198 }
199 }
200}
201
202#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
204pub struct CacheControl {
205 #[serde(rename = "type")]
206 pub cache_type: CacheType,
207 #[serde(skip_serializing_if = "Option::is_none")]
208 pub ttl: Option<CacheTtl>,
209}
210
211#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
212#[serde(rename_all = "snake_case")]
213pub enum CacheType {
214 Ephemeral,
215}
216
/// Cache lifetime.
///
/// Serialized as `"5m"` / `"1h"` by this type's hand-written `Serialize` and
/// `Deserialize` impls (no serde derive here on purpose).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CacheTtl {
    /// Wire value `"5m"`.
    FiveMinutes,
    /// Wire value `"1h"`.
    OneHour,
}
222
223impl Serialize for CacheTtl {
224 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
225 where
226 S: serde::Serializer,
227 {
228 match self {
229 CacheTtl::FiveMinutes => serializer.serialize_str("5m"),
230 CacheTtl::OneHour => serializer.serialize_str("1h"),
231 }
232 }
233}
234
235impl<'de> Deserialize<'de> for CacheTtl {
236 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
237 where
238 D: serde::Deserializer<'de>,
239 {
240 let s = String::deserialize(deserializer)?;
241 match s.as_str() {
242 "5m" => Ok(CacheTtl::FiveMinutes),
243 "1h" => Ok(CacheTtl::OneHour),
244 _ => Err(serde::de::Error::custom(format!("unknown TTL: {}", s))),
245 }
246 }
247}
248
249impl CacheControl {
250 pub fn ephemeral() -> Self {
251 Self {
252 cache_type: CacheType::Ephemeral,
253 ttl: None,
254 }
255 }
256
257 pub fn ephemeral_5m() -> Self {
258 Self {
259 cache_type: CacheType::Ephemeral,
260 ttl: Some(CacheTtl::FiveMinutes),
261 }
262 }
263
264 pub fn ephemeral_1h() -> Self {
265 Self {
266 cache_type: CacheType::Ephemeral,
267 ttl: Some(CacheTtl::OneHour),
268 }
269 }
270
271 pub fn with_ttl(mut self, ttl: CacheTtl) -> Self {
272 self.ttl = Some(ttl);
273 self
274 }
275}
276
277impl SystemPrompt {
278 pub fn text(prompt: impl Into<String>) -> Self {
280 Self::Text(prompt.into())
281 }
282
283 pub fn cached(prompt: impl Into<String>) -> Self {
285 Self::Blocks(vec![SystemBlock {
286 block_type: "text".to_string(),
287 text: prompt.into(),
288 cache_control: Some(CacheControl {
289 cache_type: CacheType::Ephemeral,
290 ttl: None,
291 }),
292 }])
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_user_message() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text(), "Hello");
    }

    #[test]
    fn test_assistant_message() {
        let msg = Message::assistant("Hi there!");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.text(), "Hi there!");
    }

    #[test]
    fn test_text_concatenates_blocks_without_separator() {
        let msg =
            Message::user_with_content(vec![ContentBlock::text("a"), ContentBlock::text("b")]);
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text(), "ab");
    }

    #[test]
    fn test_plain_message_has_no_tool_use() {
        let msg = Message::user("Hello");
        assert!(!msg.has_tool_use());
        assert!(msg.tool_uses().is_empty());
    }

    #[test]
    fn test_system_prompt_emptiness() {
        assert!(SystemPrompt::default().is_empty());
        assert!(SystemPrompt::text("").is_empty());
        assert!(SystemPrompt::Blocks(Vec::new()).is_empty());
        assert!(!SystemPrompt::text("be brief").is_empty());
        assert!(!SystemPrompt::cached("be brief").is_empty());
    }

    #[test]
    fn test_system_prompt_as_text_and_display() {
        let prompt = SystemPrompt::Blocks(vec![
            SystemBlock::uncached("first"),
            SystemBlock::cached("second"),
        ]);
        assert_eq!(prompt.as_text(), "first\n\nsecond");
        assert_eq!(prompt.to_string(), "first\n\nsecond");
        assert_eq!(SystemPrompt::text("plain").as_text(), "plain");
        assert_eq!(SystemPrompt::text("plain").to_string(), "plain");
    }

    #[test]
    fn test_cached_system_prompt_builds_one_cached_text_block() {
        match SystemPrompt::cached("sys") {
            SystemPrompt::Blocks(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert_eq!(blocks[0].block_type, "text");
                assert_eq!(blocks[0].text, "sys");
                assert_eq!(blocks[0].cache_control, Some(CacheControl::ephemeral()));
            }
            SystemPrompt::Text(_) => panic!("expected SystemPrompt::Blocks"),
        }
    }

    #[test]
    fn test_cache_control_constructors() {
        let base = CacheControl::ephemeral();
        assert_eq!(base.cache_type, CacheType::Ephemeral);
        assert_eq!(base.ttl, None);
        assert_eq!(CacheControl::ephemeral_5m().ttl, Some(CacheTtl::FiveMinutes));
        assert_eq!(CacheControl::ephemeral_1h().ttl, Some(CacheTtl::OneHour));
        assert_eq!(
            CacheControl::ephemeral().with_ttl(CacheTtl::OneHour),
            CacheControl::ephemeral_1h()
        );
        assert_eq!(SystemBlock::uncached("x").cache_control, None);
        assert_eq!(
            SystemBlock::cached("x").cache_control,
            Some(CacheControl::ephemeral())
        );
    }
}
313}