1use serde::{Deserialize, Serialize};
4
5use super::ContentBlock;
6use super::document::DocumentBlock;
7use super::search::SearchResultBlock;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
11#[serde(rename_all = "lowercase")]
12pub enum Role {
13 User,
15 Assistant,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Message {
22 pub role: Role,
24 pub content: Vec<ContentBlock>,
26}
27
28impl Message {
29 pub fn user(text: impl Into<String>) -> Self {
30 Self {
31 role: Role::User,
32 content: vec![ContentBlock::text(text)],
33 }
34 }
35
36 pub fn assistant(text: impl Into<String>) -> Self {
37 Self {
38 role: Role::Assistant,
39 content: vec![ContentBlock::text(text)],
40 }
41 }
42
43 pub fn tool_results(results: Vec<super::ToolResultBlock>) -> Self {
44 Self {
45 role: Role::User,
46 content: results.into_iter().map(ContentBlock::ToolResult).collect(),
47 }
48 }
49
50 pub fn user_with_content(content: Vec<ContentBlock>) -> Self {
51 Self {
52 role: Role::User,
53 content,
54 }
55 }
56
57 pub fn user_with_document(text: impl Into<String>, doc: DocumentBlock) -> Self {
58 Self {
59 role: Role::User,
60 content: vec![ContentBlock::Document(doc), ContentBlock::text(text)],
61 }
62 }
63
64 pub fn user_with_documents(text: impl Into<String>, docs: Vec<DocumentBlock>) -> Self {
65 let mut content: Vec<ContentBlock> = docs.into_iter().map(ContentBlock::Document).collect();
66 content.push(ContentBlock::text(text));
67 Self {
68 role: Role::User,
69 content,
70 }
71 }
72
73 pub fn user_with_search_results(
74 text: impl Into<String>,
75 results: Vec<SearchResultBlock>,
76 ) -> Self {
77 let mut content: Vec<ContentBlock> = results
78 .into_iter()
79 .map(ContentBlock::SearchResult)
80 .collect();
81 content.push(ContentBlock::text(text));
82 Self {
83 role: Role::User,
84 content,
85 }
86 }
87
88 pub fn text(&self) -> String {
89 self.content
90 .iter()
91 .filter_map(|block| block.as_text())
92 .collect::<Vec<_>>()
93 .join("")
94 }
95
96 pub fn has_tool_use(&self) -> bool {
97 self.content
98 .iter()
99 .any(|block| matches!(block, ContentBlock::ToolUse { .. }))
100 }
101
102 pub fn tool_uses(&self) -> Vec<&super::ToolUseBlock> {
103 self.content
104 .iter()
105 .filter_map(|block| match block {
106 ContentBlock::ToolUse(tool_use) => Some(tool_use),
107 _ => None,
108 })
109 .collect()
110 }
111
112 pub fn documents(&self) -> Vec<&DocumentBlock> {
113 self.content
114 .iter()
115 .filter_map(|block| block.as_document())
116 .collect()
117 }
118
119 pub fn search_results(&self) -> Vec<&SearchResultBlock> {
120 self.content
121 .iter()
122 .filter_map(|block| block.as_search_result())
123 .collect()
124 }
125
126 pub fn with_cache_on_last_block(mut self) -> Self {
127 if let Some(last) = self.content.pop() {
128 self.content
129 .push(last.with_cache_control(CacheControl::ephemeral()));
130 }
131 self
132 }
133
134 pub fn set_cache_on_last_block(&mut self, cache: CacheControl) {
135 if let Some(last) = self.content.last_mut() {
136 last.set_cache_control(Some(cache));
137 }
138 }
139
140 pub fn clear_cache_control(&mut self) {
141 for block in &mut self.content {
142 block.set_cache_control(None);
143 }
144 }
145
146 pub fn has_cache_control(&self) -> bool {
147 self.content.iter().any(|b| b.is_cached())
148 }
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
153#[serde(untagged)]
154pub enum SystemPrompt {
155 Text(String),
157 Blocks(Vec<SystemBlock>),
159}
160
161impl Default for SystemPrompt {
162 fn default() -> Self {
163 Self::Text(String::new())
164 }
165}
166
167impl SystemPrompt {
168 pub fn is_empty(&self) -> bool {
169 match self {
170 Self::Text(s) => s.is_empty(),
171 Self::Blocks(b) => b.is_empty(),
172 }
173 }
174
175 pub fn as_text(&self) -> String {
176 match self {
177 Self::Text(s) => s.clone(),
178 Self::Blocks(b) => b
179 .iter()
180 .map(|block| block.text.as_str())
181 .collect::<Vec<_>>()
182 .join("\n\n"),
183 }
184 }
185}
186
187impl std::fmt::Display for SystemPrompt {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 write!(f, "{}", self.as_text())
190 }
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct SystemBlock {
196 #[serde(rename = "type")]
198 pub block_type: String,
199 pub text: String,
201 #[serde(skip_serializing_if = "Option::is_none")]
203 pub cache_control: Option<CacheControl>,
204}
205
206impl SystemBlock {
207 pub fn cached(text: impl Into<String>) -> Self {
209 Self {
210 block_type: "text".to_string(),
211 text: text.into(),
212 cache_control: Some(CacheControl::ephemeral()),
213 }
214 }
215
216 pub fn cached_with_ttl(text: impl Into<String>, ttl: CacheTtl) -> Self {
222 Self {
223 block_type: "text".to_string(),
224 text: text.into(),
225 cache_control: Some(CacheControl::ephemeral().with_ttl(ttl)),
226 }
227 }
228
229 pub fn uncached(text: impl Into<String>) -> Self {
231 Self {
232 block_type: "text".to_string(),
233 text: text.into(),
234 cache_control: None,
235 }
236 }
237}
238
239#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
241pub struct CacheControl {
242 #[serde(rename = "type")]
243 pub cache_type: CacheType,
244 #[serde(skip_serializing_if = "Option::is_none")]
245 pub ttl: Option<CacheTtl>,
246}
247
248#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
249#[serde(rename_all = "snake_case")]
250pub enum CacheType {
251 Ephemeral,
252}
253
/// Lifetime of a cache marker.
///
/// Serialized as the short strings `"5m"` / `"1h"` by this type's hand-written
/// `Serialize` / `Deserialize` impls.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum CacheTtl {
    /// Five minutes (`"5m"`).
    FiveMinutes,
    /// One hour (`"1h"`).
    OneHour,
}
259
260impl Serialize for CacheTtl {
261 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
262 where
263 S: serde::Serializer,
264 {
265 match self {
266 CacheTtl::FiveMinutes => serializer.serialize_str("5m"),
267 CacheTtl::OneHour => serializer.serialize_str("1h"),
268 }
269 }
270}
271
272impl<'de> Deserialize<'de> for CacheTtl {
273 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
274 where
275 D: serde::Deserializer<'de>,
276 {
277 let s = String::deserialize(deserializer)?;
278 match s.as_str() {
279 "5m" => Ok(CacheTtl::FiveMinutes),
280 "1h" => Ok(CacheTtl::OneHour),
281 _ => Err(serde::de::Error::custom(format!("unknown TTL: {}", s))),
282 }
283 }
284}
285
286impl CacheControl {
287 pub fn ephemeral() -> Self {
288 Self {
289 cache_type: CacheType::Ephemeral,
290 ttl: None,
291 }
292 }
293
294 pub fn ephemeral_5m() -> Self {
295 Self {
296 cache_type: CacheType::Ephemeral,
297 ttl: Some(CacheTtl::FiveMinutes),
298 }
299 }
300
301 pub fn ephemeral_1h() -> Self {
302 Self {
303 cache_type: CacheType::Ephemeral,
304 ttl: Some(CacheTtl::OneHour),
305 }
306 }
307
308 pub fn with_ttl(mut self, ttl: CacheTtl) -> Self {
309 self.ttl = Some(ttl);
310 self
311 }
312}
313
314impl SystemPrompt {
315 pub fn text(prompt: impl Into<String>) -> Self {
317 Self::Text(prompt.into())
318 }
319
320 pub fn cached(prompt: impl Into<String>) -> Self {
322 Self::Blocks(vec![SystemBlock {
323 block_type: "text".to_string(),
324 text: prompt.into(),
325 cache_control: Some(CacheControl {
326 cache_type: CacheType::Ephemeral,
327 ttl: None,
328 }),
329 }])
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_user_message() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text(), "Hello");
    }

    #[test]
    fn test_assistant_message() {
        let msg = Message::assistant("Hi there!");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.text(), "Hi there!");
    }

    #[test]
    fn test_empty_content_message() {
        let msg = Message::user_with_content(Vec::new());
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text(), "");
        assert!(!msg.has_tool_use());
        assert!(msg.tool_uses().is_empty());
        assert!(msg.documents().is_empty());
        assert!(msg.search_results().is_empty());
        assert!(!msg.has_cache_control());
        // Caching a message with no blocks is a no-op, not a panic.
        assert!(msg.with_cache_on_last_block().content.is_empty());
    }

    #[test]
    fn test_system_prompt_emptiness() {
        assert!(SystemPrompt::default().is_empty());
        assert!(SystemPrompt::Blocks(Vec::new()).is_empty());
        assert!(!SystemPrompt::text("be brief").is_empty());
        assert!(!SystemPrompt::cached("be brief").is_empty());
    }

    #[test]
    fn test_system_prompt_text_joins_blocks() {
        let prompt = SystemPrompt::Blocks(vec![
            SystemBlock::uncached("first"),
            SystemBlock::cached("second"),
        ]);
        assert_eq!(prompt.as_text(), "first\n\nsecond");
        assert_eq!(prompt.to_string(), "first\n\nsecond");
        assert_eq!(SystemPrompt::text("solo").to_string(), "solo");
    }

    #[test]
    fn test_system_prompt_cached_block() {
        match SystemPrompt::cached("p") {
            SystemPrompt::Blocks(blocks) => {
                assert_eq!(blocks.len(), 1);
                assert_eq!(blocks[0].block_type, "text");
                assert_eq!(blocks[0].text, "p");
                assert_eq!(blocks[0].cache_control, Some(CacheControl::ephemeral()));
            }
            SystemPrompt::Text(_) => panic!("expected a block-based prompt"),
        }
    }

    #[test]
    fn test_cache_control_constructors() {
        assert_eq!(CacheControl::ephemeral().cache_type, CacheType::Ephemeral);
        assert!(CacheControl::ephemeral().ttl.is_none());
        assert_eq!(
            CacheControl::ephemeral_5m(),
            CacheControl::ephemeral().with_ttl(CacheTtl::FiveMinutes)
        );
        assert_eq!(CacheControl::ephemeral_1h().ttl, Some(CacheTtl::OneHour));
        assert_eq!(
            SystemBlock::cached_with_ttl("t", CacheTtl::OneHour).cache_control,
            Some(CacheControl::ephemeral_1h())
        );
        assert!(SystemBlock::uncached("t").cache_control.is_none());
    }
}