1use serde::{Deserialize, Serialize};
2
3use super::{BaseContentBlock, ContentBlock, Message, MessageContent, Role};
4
5#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
6pub struct Request {
7 pub model: String,
8 pub messages: Vec<Message>,
9 #[serde(skip_serializing_if = "Option::is_none")]
10 pub system: Option<System>,
11 pub max_tokens: u32,
12 #[serde(skip_serializing_if = "Option::is_none")]
13 pub metadata: Option<serde_json::Value>,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub stop_sequences: Option<Vec<String>>,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub stream: Option<bool>,
18 #[serde(skip_serializing_if = "Option::is_none")]
19 pub temperature: Option<f32>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub top_p: Option<f32>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub top_k: Option<u32>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub tools: Option<Vec<Tool>>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub thinking: Option<Thinking>,
28}
29
30#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
31pub struct Tool {
32 pub name: String,
33 pub description: Option<String>,
34 pub input_schema: String,
35}
36
37#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
38#[serde(tag = "type")]
39pub enum Thinking {
40 #[serde(rename = "enabled")]
41 Enabled { budget_tokens: u32 },
42}
43
44#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
45#[serde(untagged)]
46pub enum System {
47 Text(String),
48 Blocks(Vec<SystemMessage>),
49}
50
51#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
52pub struct SystemMessage {
53 pub r#type: SystemMessageType,
54 pub text: String,
55 pub cache_control: Option<CacheControl>,
56}
57
58#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
59#[serde(rename_all = "lowercase")]
60pub enum SystemMessageType {
61 #[default]
62 Text,
63}
64
65#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
66pub struct CacheControl {
67 pub r#type: CacheControlType,
68}
69
70#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
71#[serde(rename_all = "lowercase")]
72pub enum CacheControlType {
73 #[default]
74 Ephemeral,
75}
76
77pub fn process_messages(messages: &[Message]) -> Vec<Message> {
89 let mut filtered = Vec::with_capacity(messages.len());
90 if messages.is_empty() {
91 return filtered;
92 }
93
94 let mut prev_message: Option<Message> = None;
95 for message in messages {
96 if message.is_all_empty() {
98 continue;
99 }
100 if let Some(prev_msg) = prev_message.as_ref() {
101 if prev_msg.role == message.role {
102 let mut combined_message = prev_msg.clone();
103 match (&mut combined_message.content, &message.content) {
104 (MessageContent::Text(prev), MessageContent::Text(curr)) => {
105 prev.push('\n');
106 prev.push_str(curr);
107 }
108 (MessageContent::Blocks(prev), MessageContent::Blocks(curr)) => {
109 prev.retain(|b| !b.is_empty());
110 let curr_clone: Vec<_> =
111 curr.clone().into_iter().filter(|v| !v.is_empty()).collect();
112 prev.extend(curr_clone);
113 }
114 (MessageContent::Blocks(prev), MessageContent::Text(curr)) => {
115 prev.retain(|v| !v.is_empty());
116 prev.push(ContentBlock::Base(BaseContentBlock::Text {
117 text: curr.clone(),
118 }));
119 }
120 (MessageContent::Text(prev), MessageContent::Blocks(curr)) => {
121 let mut blocks =
122 vec![ContentBlock::Base(BaseContentBlock::Text { text: prev.clone() })];
123 let curr_clone: Vec<_> =
124 curr.clone().into_iter().filter(|v| !v.is_empty()).collect();
125 blocks.extend(curr_clone);
126 combined_message.content = MessageContent::Blocks(blocks);
127 }
128 }
129 filtered.pop();
130 filtered.push(combined_message.clone());
131 prev_message = Some(combined_message);
132 continue;
133 }
134 }
135 filtered.push(message.clone());
136 prev_message = Some(message.clone());
137 }
138
139 if let Some(first) = messages.first() {
141 if first.role == Role::Assistant {
142 filtered.insert(
143 0,
144 Message {
145 role: Role::User,
146 content: MessageContent::Text("Starting the conversation...".to_string()),
147 },
148 );
149 }
150 }
151
152 if let Some(last) = filtered.last_mut() {
156 if last.role == Role::Assistant {
157 match &mut last.content {
158 MessageContent::Text(text) => {
159 *text = text.trim_end().to_string();
160 }
161 MessageContent::Blocks(blocks) => {
162 for block in blocks {
163 if let ContentBlock::Base(BaseContentBlock::Text { text }) = block {
164 *text = text.trim_end().to_string();
165 }
166 }
167 }
168 }
169 }
170 }
171
172 filtered
173}
174
#[cfg(test)]
mod tests {
    use crate::messages::{
        ContentBlock, ImageSource, MessageContent, RequestOnlyContentBlock, Role,
        ToolUseContentBlock,
    };

    use super::*;

    /// Model name used by every request fixture.
    const MODEL: &str = "claude-3-opus-20240229";

    /// Input schema of the weather tool fixture, as the string stored in `Tool::input_schema`.
    const WEATHER_SCHEMA: &str = "{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"unit\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The unit of temperature, either \\\"celsius\\\" or \\\"fahrenheit\\\"\"}},\"required\":[\"location\"]}";

    /// Message with plain text content.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Message with block content.
    fn blocks_message(role: Role, blocks: Vec<ContentBlock>) -> Message {
        Message {
            role,
            content: MessageContent::Blocks(blocks),
        }
    }

    fn user_text(text: &str) -> Message {
        text_message(Role::User, text)
    }

    fn assistant_text(text: &str) -> Message {
        text_message(Role::Assistant, text)
    }

    fn user_blocks(blocks: Vec<ContentBlock>) -> Message {
        blocks_message(Role::User, blocks)
    }

    fn assistant_blocks(blocks: Vec<ContentBlock>) -> Message {
        blocks_message(Role::Assistant, blocks)
    }

    /// Text content block.
    fn text_block(text: &str) -> ContentBlock {
        ContentBlock::Base(BaseContentBlock::Text {
            text: text.to_string(),
        })
    }

    /// Base64 image content block.
    fn image_block(media_type: &str, data: &str) -> ContentBlock {
        ContentBlock::RequestOnly(RequestOnlyContentBlock::Image {
            source: ImageSource::Base64 {
                media_type: media_type.to_string(),
                data: data.to_string(),
            },
        })
    }

    /// The user message `process_messages` puts in front of a leading assistant message.
    fn opener() -> Message {
        user_text("Starting the conversation...")
    }

    /// Request with the shared model and token limit and no optional fields set.
    fn request(messages: Vec<Message>) -> Request {
        Request {
            model: MODEL.to_string(),
            max_tokens: 1024,
            messages,
            ..Default::default()
        }
    }

    /// Text system block with the given cache-control marker.
    fn system_block(text: &str, cache_control: Option<CacheControl>) -> SystemMessage {
        SystemMessage {
            r#type: SystemMessageType::Text,
            text: text.to_string(),
            cache_control,
        }
    }

    /// The `get_weather` tool fixture.
    fn weather_tool() -> Tool {
        Tool {
            name: "get_weather".to_string(),
            description: Some("Get the current weather in a given location".to_string()),
            input_schema: WEATHER_SCHEMA.to_string(),
        }
    }

    /// Checks that `json` deserializes to `expected` and that `expected`
    /// survives a serialize/deserialize round trip.
    fn assert_round_trip(name: &str, json: &str, expected: &Request) {
        let decoded: Request = serde_json::from_str(json).unwrap();
        assert_eq!(&decoded, expected, "deserialize test failed: {}", name);

        let encoded = serde_json::to_string(expected).unwrap();
        let decoded_again: Request = serde_json::from_str(&encoded).unwrap();
        assert_eq!(&decoded_again, expected, "serialize test failed: {}", name);
    }

    #[test]
    fn serde() {
        let cases = vec![
            (
                "simple",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "messages": [
                        {"role": "user", "content": "Hello, world"}
                    ]
                }"#,
                request(vec![user_text("Hello, world")]),
            ),
            (
                "with thinking enabled",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "thinking": {
                        "type": "enabled",
                        "budget_tokens": 5000
                    },
                    "messages": [
                        {"role": "user", "content": "Solve this complex math problem step by step"}
                    ]
                }"#,
                Request {
                    thinking: Some(Thinking::Enabled {
                        budget_tokens: 5000,
                    }),
                    ..request(vec![user_text(
                        "Solve this complex math problem step by step",
                    )])
                },
            ),
            (
                "system messages",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "system":[
                        {"type":"text","text":"You are a helpful assistant."},
                        {"type":"text","text":"You are a really helpful assistant."}
                    ],
                    "messages": [
                        {"role": "user", "content": "Hello, world"}
                    ]
                }"#,
                Request {
                    system: Some(System::Blocks(vec![
                        system_block("You are a helpful assistant.", None),
                        system_block("You are a really helpful assistant.", None),
                    ])),
                    ..request(vec![user_text("Hello, world")])
                },
            ),
            (
                "system message with cache control",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "system":[
                        {"type":"text","text":"You are a helpful assistant."},
                        {"type":"text","text":"You are a really helpful assistant.", "cache_control": {"type":"ephemeral"}}
                    ],
                    "messages": [
                        {"role": "user", "content": "Hello, world"}
                    ]
                }"#,
                Request {
                    system: Some(System::Blocks(vec![
                        system_block("You are a helpful assistant.", None),
                        system_block(
                            "You are a really helpful assistant.",
                            Some(CacheControl {
                                r#type: CacheControlType::Ephemeral,
                            }),
                        ),
                    ])),
                    ..request(vec![user_text("Hello, world")])
                },
            ),
            (
                "system message string",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "system":"You are a helpful assistant.",
                    "messages": [
                        {"role": "user", "content": "Hello, world"}
                    ]
                }"#,
                Request {
                    system: Some(System::Text("You are a helpful assistant.".to_string())),
                    ..request(vec![user_text("Hello, world")])
                },
            ),
            (
                "multiple conversation",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "messages": [
                        {"role": "user", "content": "Hello there."},
                        {"role": "assistant", "content": "Hi, I'm Claude. How can I help you?"},
                        {"role": "user", "content": "Can you explain LLMs in plain English?"}
                    ]
                }"#,
                request(vec![
                    user_text("Hello there."),
                    assistant_text("Hi, I'm Claude. How can I help you?"),
                    user_text("Can you explain LLMs in plain English?"),
                ]),
            ),
            (
                "image content",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "messages": [
                        {"role": "user", "content": [
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/jpeg",
                                    "data": "/9j/4AAQSkZJRg..."
                                }
                            },
                            {"type": "text", "text": "What is in this image?"}
                        ]}
                    ]
                }"#,
                request(vec![user_blocks(vec![
                    image_block("image/jpeg", "/9j/4AAQSkZJRg..."),
                    text_block("What is in this image?"),
                ])]),
            ),
        ];

        for (name, json, expected) in cases {
            assert_round_trip(name, json, &expected);
        }
    }

    #[test]
    fn process() {
        let cases = vec![
            (
                "[(assistant, text)]",
                vec![assistant_text("hi")],
                vec![opener(), assistant_text("hi")],
            ),
            (
                "[(assistant, blocks)]",
                vec![assistant_blocks(vec![text_block("hi")])],
                vec![opener(), assistant_blocks(vec![text_block("hi")])],
            ),
            (
                "[(assistant, blocks)]-2",
                vec![assistant_blocks(vec![
                    text_block("hi"),
                    image_block("img/png", "abcs"),
                ])],
                vec![
                    opener(),
                    assistant_blocks(vec![text_block("hi"), image_block("img/png", "abcs")]),
                ],
            ),
            (
                "[(assistant, text), (user, text)]",
                vec![assistant_text("hi"), user_text("hi")],
                vec![opener(), assistant_text("hi"), user_text("hi")],
            ),
            (
                "[(assistant, text), (user, blocks)]",
                vec![
                    assistant_text("hi"),
                    user_blocks(vec![text_block("hi"), image_block("img/png", "abcs")]),
                ],
                vec![
                    opener(),
                    assistant_text("hi"),
                    user_blocks(vec![text_block("hi"), image_block("img/png", "abcs")]),
                ],
            ),
            (
                "[(user, text), (user, text)]",
                vec![user_text("Hi,"), user_text("how are you")],
                vec![user_text("Hi,\nhow are you")],
            ),
            (
                "[(user, text), (user, blocks)]",
                vec![
                    user_text("Hi,"),
                    user_blocks(vec![text_block("how are you")]),
                ],
                vec![user_blocks(vec![
                    text_block("Hi,"),
                    text_block("how are you"),
                ])],
            ),
            (
                "[(user, blocks), (user, text)]",
                vec![
                    user_blocks(vec![text_block("how are you")]),
                    user_text("Hi,"),
                ],
                vec![user_blocks(vec![
                    text_block("how are you"),
                    text_block("Hi,"),
                ])],
            ),
            (
                "[(assistant, text), (assistant, text)]",
                vec![assistant_text("Hi,"), assistant_text("how are you")],
                vec![opener(), assistant_text("Hi,\nhow are you")],
            ),
            (
                "[(assistant, blocks), (assistant, text)]",
                vec![
                    assistant_blocks(vec![text_block("how are you")]),
                    assistant_text("Hi,"),
                ],
                vec![
                    opener(),
                    assistant_blocks(vec![text_block("how are you"), text_block("Hi,")]),
                ],
            ),
            (
                "[(user, blocks), (user, text), (user, blocks)]",
                vec![
                    user_blocks(vec![text_block("how are you")]),
                    user_text("Hi,"),
                    user_blocks(vec![text_block("who are you")]),
                ],
                vec![user_blocks(vec![
                    text_block("how are you"),
                    text_block("Hi,"),
                    text_block("who are you"),
                ])],
            ),
            (
                "[(user, blocks), (user, text), (user, blocks), (user, text)]",
                vec![
                    user_blocks(vec![text_block("how are you")]),
                    user_text("Hi,"),
                    user_blocks(vec![text_block("who are you")]),
                    user_text("ho"),
                ],
                vec![user_blocks(vec![
                    text_block("how are you"),
                    text_block("Hi,"),
                    text_block("who are you"),
                    text_block("ho"),
                ])],
            ),
            (
                "[(user, blocks), (user, text), (user, blocks), (assistant, text)]",
                vec![
                    user_blocks(vec![text_block("how are you")]),
                    user_text("Hi,"),
                    user_blocks(vec![text_block("who are you")]),
                    assistant_text("ho"),
                ],
                vec![
                    user_blocks(vec![
                        text_block("how are you"),
                        text_block("Hi,"),
                        text_block("who are you"),
                    ]),
                    assistant_text("ho"),
                ],
            ),
            (
                "[(user, blocks), (user, text), (assistant, blocks), (user, text)]",
                vec![
                    user_blocks(vec![text_block("how are you")]),
                    user_text("Hi,"),
                    assistant_blocks(vec![text_block("who are you")]),
                    user_text("ho"),
                ],
                vec![
                    user_blocks(vec![text_block("how are you"), text_block("Hi,")]),
                    assistant_blocks(vec![text_block("who are you")]),
                    user_text("ho"),
                ],
            ),
            (
                "[(user, blocks), (assistant, text), (user, blocks), (user, text)]",
                vec![
                    user_blocks(vec![text_block("how are you")]),
                    assistant_text("Hi,"),
                    user_blocks(vec![text_block("who are you")]),
                    user_text("ho"),
                ],
                vec![
                    user_blocks(vec![text_block("how are you")]),
                    assistant_text("Hi,"),
                    user_blocks(vec![text_block("who are you"), text_block("ho")]),
                ],
            ),
            (
                "[(assistant, blocks), (user, text), (user, blocks), (user, text)]",
                vec![
                    assistant_blocks(vec![text_block("how are you")]),
                    user_text("Hi,"),
                    user_blocks(vec![text_block("who are you")]),
                    user_text("ho"),
                ],
                vec![
                    opener(),
                    assistant_blocks(vec![text_block("how are you")]),
                    user_blocks(vec![
                        text_block("Hi,"),
                        text_block("who are you"),
                        text_block("ho"),
                    ]),
                ],
            ),
            (
                "[(user, text), empty, (assistant, text with trailing space), empty]",
                vec![
                    user_text("hi"),
                    user_text(""),
                    assistant_text("hi "),
                    user_text(""),
                ],
                vec![user_text("hi"), assistant_text("hi")],
            ),
            (
                "last one",
                vec![
                    assistant_text("hi"),
                    user_text(""),
                    user_blocks(vec![text_block(" ")]),
                    user_blocks(vec![text_block(" "), text_block("hi")]),
                    user_text("how are you"),
                    assistant_text("hi "),
                    assistant_blocks(vec![text_block("who are you ")]),
                    user_text(" "),
                    user_text(""),
                ],
                vec![
                    opener(),
                    assistant_text("hi"),
                    user_blocks(vec![text_block("hi"), text_block("how are you")]),
                    assistant_blocks(vec![text_block("hi"), text_block("who are you")]),
                ],
            ),
        ];

        for (name, input, expected) in cases {
            let got = process_messages(&input);
            assert_eq!(got, expected, "test failed: {}", name);
        }
    }

    #[test]
    fn tool_use() {
        let cases = vec![
            (
                "simple tool",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "tools": [{
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "input_schema": "{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"unit\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The unit of temperature, either \\\"celsius\\\" or \\\"fahrenheit\\\"\"}},\"required\":[\"location\"]}"
                    }],
                    "messages": [{"role": "user", "content": "What is the weather like in San Francisco?"}]
                }"#,
                Request {
                    tools: Some(vec![weather_tool()]),
                    ..request(vec![user_text(
                        "What is the weather like in San Francisco?",
                    )])
                },
            ),
            (
                "sequencial",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "tools": [
                        {
                            "name": "get_weather",
                            "description": "Get the current weather in a given location",
                            "input_schema": "{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"unit\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The unit of temperature, either \\\"celsius\\\" or \\\"fahrenheit\\\"\"}},\"required\":[\"location\"]}"
                        }
                    ],
                    "messages": [
                        {
                            "role": "user",
                            "content": "What is the weather like in San Francisco?"
                        },
                        {
                            "role": "assistant",
                            "content": [
                                {
                                    "type": "text",
                                    "text": "<thinking>I need to use get_weather, and the user wants SF, which is likely San Francisco, CA.</thinking>"
                                },
                                {
                                    "type": "tool_use",
                                    "id": "toolu_01A09q90qw90lq917835lq9",
                                    "name": "get_weather",
                                    "input": {
                                        "location": "San Francisco, CA",
                                        "unit": "celsius"
                                    }
                                }
                            ]
                        },
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "tool_result",
                                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                                    "content": "15 degrees"
                                }
                            ]
                        }
                    ]
                }"#,
                Request {
                    tools: Some(vec![weather_tool()]),
                    ..request(vec![
                        user_text("What is the weather like in San Francisco?"),
                        assistant_blocks(vec![
                            text_block("<thinking>I need to use get_weather, and the user wants SF, which is likely San Francisco, CA.</thinking>"),
                            ContentBlock::Base(BaseContentBlock::ToolUse(ToolUseContentBlock {
                                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                                name: "get_weather".to_string(),
                                input: serde_json::json!({
                                    "location": "San Francisco, CA",
                                    "unit": "celsius"
                                }),
                            })),
                        ]),
                        user_blocks(vec![ContentBlock::RequestOnly(
                            RequestOnlyContentBlock::ToolResult {
                                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                                content: "15 degrees".to_string(),
                            },
                        )]),
                    ])
                },
            ),
        ];

        for (name, json, expected) in cases {
            assert_round_trip(name, json, &expected);
        }
    }
}