1use std::borrow::Cow;
2
3use serde::{Deserialize, Serialize};
4
5use super::{BaseContentBlock, ContentBlock, Message, MessageContent, Role};
6
7#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
8pub struct Request {
9 pub model: String,
10 pub messages: Vec<Message>,
11 #[serde(skip_serializing_if = "Option::is_none")]
12 pub system: Option<System>,
13 pub max_tokens: u32,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub metadata: Option<serde_json::Value>,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub stop_sequences: Option<Vec<String>>,
18 #[serde(skip_serializing_if = "Option::is_none")]
19 pub stream: Option<bool>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub temperature: Option<f32>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub top_p: Option<f32>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub top_k: Option<u32>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub tools: Option<Vec<Tool>>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub thinking: Option<Thinking>,
30}
31
32#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
33pub struct Tool {
34 pub name: Cow<'static, str>,
35 pub description: Option<Cow<'static, str>>,
36 pub input_schema: serde_json::Value,
37}
38
39#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
40#[serde(tag = "type")]
41pub enum Thinking {
42 #[serde(rename = "enabled")]
43 Enabled { budget_tokens: u32 },
44}
45
46#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
47#[serde(untagged)]
48pub enum System {
49 Text(String),
50 Blocks(Vec<SystemMessage>),
51}
52
53#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
54pub struct SystemMessage {
55 pub r#type: SystemMessageType,
56 pub text: String,
57 pub cache_control: Option<CacheControl>,
58}
59
60#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
61#[serde(rename_all = "lowercase")]
62pub enum SystemMessageType {
63 #[default]
64 Text,
65}
66
67#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
68pub struct CacheControl {
69 pub r#type: CacheControlType,
70}
71
72#[derive(Debug, Deserialize, Clone, Default, PartialEq, Serialize)]
73#[serde(rename_all = "lowercase")]
74pub enum CacheControlType {
75 #[default]
76 Ephemeral,
77}
78
79pub fn process_messages(messages: &[Message]) -> Vec<Message> {
91 let mut filtered = Vec::with_capacity(messages.len());
92 if messages.is_empty() {
93 return filtered;
94 }
95
96 let mut prev_message: Option<Message> = None;
97 for message in messages {
98 if message.is_all_empty() {
100 continue;
101 }
102 if let Some(prev_msg) = prev_message.as_ref() {
103 if prev_msg.role == message.role {
104 let mut combined_message = prev_msg.clone();
105 match (&mut combined_message.content, &message.content) {
106 (MessageContent::Text(prev), MessageContent::Text(curr)) => {
107 prev.push('\n');
108 prev.push_str(curr);
109 }
110 (MessageContent::Blocks(prev), MessageContent::Blocks(curr)) => {
111 prev.retain(|b| !b.is_empty());
112 let curr_clone: Vec<_> =
113 curr.clone().into_iter().filter(|v| !v.is_empty()).collect();
114 prev.extend(curr_clone);
115 }
116 (MessageContent::Blocks(prev), MessageContent::Text(curr)) => {
117 prev.retain(|v| !v.is_empty());
118 prev.push(ContentBlock::Base(BaseContentBlock::Text {
119 text: curr.clone(),
120 }));
121 }
122 (MessageContent::Text(prev), MessageContent::Blocks(curr)) => {
123 let mut blocks =
124 vec![ContentBlock::Base(BaseContentBlock::Text { text: prev.clone() })];
125 let curr_clone: Vec<_> =
126 curr.clone().into_iter().filter(|v| !v.is_empty()).collect();
127 blocks.extend(curr_clone);
128 combined_message.content = MessageContent::Blocks(blocks);
129 }
130 }
131 filtered.pop();
132 filtered.push(combined_message.clone());
133 prev_message = Some(combined_message);
134 continue;
135 }
136 }
137 filtered.push(message.clone());
138 prev_message = Some(message.clone());
139 }
140
141 if let Some(first) = messages.first() {
143 if first.role == Role::Assistant {
144 filtered.insert(
145 0,
146 Message {
147 role: Role::User,
148 content: MessageContent::Text("Starting the conversation...".to_string()),
149 },
150 );
151 }
152 }
153
154 if let Some(last) = filtered.last_mut() {
158 if last.role == Role::Assistant {
159 match &mut last.content {
160 MessageContent::Text(text) => {
161 *text = text.trim_end().to_string();
162 }
163 MessageContent::Blocks(blocks) => {
164 for block in blocks {
165 if let ContentBlock::Base(BaseContentBlock::Text { text }) = block {
166 *text = text.trim_end().to_string();
167 }
168 }
169 }
170 }
171 }
172 }
173
174 filtered
175}
176
#[cfg(test)]
mod tests {
    use crate::messages::{
        ContentBlock, ImageSource, MessageContent, RequestOnlyContentBlock, Role,
        ToolUseContentBlock,
    };

    use super::*;

    /// Model id used by every fixture request.
    const MODEL: &str = "claude-3-opus-20240229";
    /// User question shared by the tool-use fixtures.
    const WEATHER_QUESTION: &str = "What is the weather like in San Francisco?";
    /// Id linking the `tool_use` block to its `tool_result`.
    const TOOL_USE_ID: &str = "toolu_01A09q90qw90lq917835lq9";

    // ----- fixture builders ------------------------------------------------

    fn msg(role: Role, content: MessageContent) -> Message {
        Message { role, content }
    }

    fn user(content: MessageContent) -> Message {
        msg(Role::User, content)
    }

    fn assistant(content: MessageContent) -> Message {
        msg(Role::Assistant, content)
    }

    /// Plain-string message content.
    fn text(value: &str) -> MessageContent {
        MessageContent::Text(value.to_string())
    }

    /// Block-list message content.
    fn blocks(items: Vec<ContentBlock>) -> MessageContent {
        MessageContent::Blocks(items)
    }

    /// Block-list content made only of text blocks, in order.
    fn text_blocks(parts: &[&str]) -> MessageContent {
        blocks(parts.iter().copied().map(text_block).collect())
    }

    fn text_block(value: &str) -> ContentBlock {
        ContentBlock::Base(BaseContentBlock::Text {
            text: value.to_string(),
        })
    }

    fn base64_image(media_type: &str, data: &str) -> ContentBlock {
        ContentBlock::RequestOnly(RequestOnlyContentBlock::Image {
            source: ImageSource::Base64 {
                media_type: media_type.to_string(),
                data: data.to_string(),
            },
        })
    }

    /// The small PNG block used by the `process` fixtures.
    fn png() -> ContentBlock {
        base64_image("img/png", "abcs")
    }

    /// The user turn `process_messages` prepends before a leading assistant turn.
    fn placeholder() -> Message {
        user(text("Starting the conversation..."))
    }

    /// A request for `MODEL` with `max_tokens: 1024` and every optional field unset.
    fn request(messages: Vec<Message>) -> Request {
        Request {
            model: MODEL.to_string(),
            max_tokens: 1024,
            messages,
            ..Default::default()
        }
    }

    fn system_text(value: &str, cache_control: Option<CacheControl>) -> SystemMessage {
        SystemMessage {
            r#type: SystemMessageType::Text,
            text: value.to_string(),
            cache_control,
        }
    }

    fn weather_tool() -> Tool {
        Tool {
            name: "get_weather".into(),
            description: Some("Get the current weather in a given location".into()),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The unit of temperature, either \"celsius\" or \"fahrenheit\""
                    }
                },
                "required": ["location"]
            }),
        }
    }

    /// For every `(name, json, expected)` case:
    /// * `json` must parse to `expected`;
    /// * `expected` must survive a serialize-then-parse round trip unchanged.
    fn assert_round_trips(cases: Vec<(&str, &str, Request)>) {
        for (name, json, expected) in cases {
            let decoded: Request = serde_json::from_str(json).unwrap();
            assert_eq!(decoded, expected, "deserialize test failed: {}", name);

            let encoded = serde_json::to_string(&expected).unwrap();
            let redecoded: Request = serde_json::from_str(&encoded).unwrap();
            assert_eq!(redecoded, expected, "serialize test failed: {}", name);
        }
    }

    // ----- tests -------------------------------------------------------------

    #[test]
    fn serde() {
        assert_round_trips(vec![
            (
                "simple",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "messages": [
                        {"role": "user", "content": "Hello, world"}
                    ]
                }"#,
                request(vec![user(text("Hello, world"))]),
            ),
            (
                "with thinking enabled",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "thinking": {
                        "type": "enabled",
                        "budget_tokens": 5000
                    },
                    "messages": [
                        {"role": "user", "content": "Solve this complex math problem step by step"}
                    ]
                }"#,
                Request {
                    thinking: Some(Thinking::Enabled {
                        budget_tokens: 5000,
                    }),
                    ..request(vec![user(text(
                        "Solve this complex math problem step by step",
                    ))])
                },
            ),
            (
                "system messages",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "system":[
                        {"type":"text","text":"You are a helpful assistant."},
                        {"type":"text","text":"You are a really helpful assistant."}
                    ],
                    "messages": [
                        {"role": "user", "content": "Hello, world"}
                    ]
                }"#,
                Request {
                    system: Some(System::Blocks(vec![
                        system_text("You are a helpful assistant.", None),
                        system_text("You are a really helpful assistant.", None),
                    ])),
                    ..request(vec![user(text("Hello, world"))])
                },
            ),
            (
                "system message with cache control",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "system":[
                        {"type":"text","text":"You are a helpful assistant."},
                        {"type":"text","text":"You are a really helpful assistant.", "cache_control": {"type":"ephemeral"}}
                    ],
                    "messages": [
                        {"role": "user", "content": "Hello, world"}
                    ]
                }"#,
                Request {
                    system: Some(System::Blocks(vec![
                        system_text("You are a helpful assistant.", None),
                        system_text(
                            "You are a really helpful assistant.",
                            Some(CacheControl {
                                r#type: CacheControlType::Ephemeral,
                            }),
                        ),
                    ])),
                    ..request(vec![user(text("Hello, world"))])
                },
            ),
            (
                "system message string",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "system":"You are a helpful assistant.",
                    "messages": [
                        {"role": "user", "content": "Hello, world"}
                    ]
                }"#,
                Request {
                    system: Some(System::Text("You are a helpful assistant.".to_string())),
                    ..request(vec![user(text("Hello, world"))])
                },
            ),
            (
                "multiple conversation",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "messages": [
                        {"role": "user", "content": "Hello there."},
                        {"role": "assistant", "content": "Hi, I'm Claude. How can I help you?"},
                        {"role": "user", "content": "Can you explain LLMs in plain English?"}
                    ]
                }"#,
                request(vec![
                    user(text("Hello there.")),
                    assistant(text("Hi, I'm Claude. How can I help you?")),
                    user(text("Can you explain LLMs in plain English?")),
                ]),
            ),
            (
                "image content",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "messages": [
                        {"role": "user", "content": [
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/jpeg",
                                    "data": "/9j/4AAQSkZJRg..."
                                }
                            },
                            {"type": "text", "text": "What is in this image?"}
                        ]}
                    ]
                }"#,
                request(vec![user(blocks(vec![
                    base64_image("image/jpeg", "/9j/4AAQSkZJRg..."),
                    text_block("What is in this image?"),
                ]))]),
            ),
        ]);
    }

    #[test]
    fn process() {
        let tests: Vec<(&str, Vec<Message>, Vec<Message>)> = vec![
            (
                "[(assistant, text)]",
                vec![assistant(text("hi"))],
                vec![placeholder(), assistant(text("hi"))],
            ),
            (
                "[(assistant, blocks)]",
                vec![assistant(text_blocks(&["hi"]))],
                vec![placeholder(), assistant(text_blocks(&["hi"]))],
            ),
            (
                "[(assistant, blocks)]-2",
                vec![assistant(blocks(vec![text_block("hi"), png()]))],
                vec![
                    placeholder(),
                    assistant(blocks(vec![text_block("hi"), png()])),
                ],
            ),
            (
                "[(assistant, text), (user, text)]",
                vec![assistant(text("hi")), user(text("hi"))],
                vec![placeholder(), assistant(text("hi")), user(text("hi"))],
            ),
            (
                "[(assistant, text), (user, blocks)]",
                vec![
                    assistant(text("hi")),
                    user(blocks(vec![text_block("hi"), png()])),
                ],
                vec![
                    placeholder(),
                    assistant(text("hi")),
                    user(blocks(vec![text_block("hi"), png()])),
                ],
            ),
            (
                "[(user, text), (user, text)]",
                vec![user(text("Hi,")), user(text("how are you"))],
                vec![user(text("Hi,\nhow are you"))],
            ),
            (
                "[(user, text), (user, blocks)]",
                vec![user(text("Hi,")), user(text_blocks(&["how are you"]))],
                vec![user(text_blocks(&["Hi,", "how are you"]))],
            ),
            (
                "[(user, blocks), (user, text)]",
                vec![user(text_blocks(&["how are you"])), user(text("Hi,"))],
                vec![user(text_blocks(&["how are you", "Hi,"]))],
            ),
            (
                "[(assistant, text), (assistant, text)]",
                vec![assistant(text("Hi,")), assistant(text("how are you"))],
                vec![placeholder(), assistant(text("Hi,\nhow are you"))],
            ),
            (
                "[(assistant, blocks), (assistant, text)]",
                vec![
                    assistant(text_blocks(&["how are you"])),
                    assistant(text("Hi,")),
                ],
                vec![
                    placeholder(),
                    assistant(text_blocks(&["how are you", "Hi,"])),
                ],
            ),
            (
                "[(user, blocks), (user, text), (user, blocks)]",
                vec![
                    user(text_blocks(&["how are you"])),
                    user(text("Hi,")),
                    user(text_blocks(&["who are you"])),
                ],
                vec![user(text_blocks(&["how are you", "Hi,", "who are you"]))],
            ),
            (
                "[(user, blocks), (user, text), (user, blocks), (user, text)]",
                vec![
                    user(text_blocks(&["how are you"])),
                    user(text("Hi,")),
                    user(text_blocks(&["who are you"])),
                    user(text("ho")),
                ],
                vec![user(text_blocks(&[
                    "how are you",
                    "Hi,",
                    "who are you",
                    "ho",
                ]))],
            ),
            (
                "[(user, blocks), (user, text), (user, blocks), (assistant, text)]",
                vec![
                    user(text_blocks(&["how are you"])),
                    user(text("Hi,")),
                    user(text_blocks(&["who are you"])),
                    assistant(text("ho")),
                ],
                vec![
                    user(text_blocks(&["how are you", "Hi,", "who are you"])),
                    assistant(text("ho")),
                ],
            ),
            (
                "[(user, blocks), (user, text), (assistant, blocks), (user, text)]",
                vec![
                    user(text_blocks(&["how are you"])),
                    user(text("Hi,")),
                    assistant(text_blocks(&["who are you"])),
                    user(text("ho")),
                ],
                vec![
                    user(text_blocks(&["how are you", "Hi,"])),
                    assistant(text_blocks(&["who are you"])),
                    user(text("ho")),
                ],
            ),
            (
                "[(user, blocks), (assistant, text), (user, blocks), (user, text)]",
                vec![
                    user(text_blocks(&["how are you"])),
                    assistant(text("Hi,")),
                    user(text_blocks(&["who are you"])),
                    user(text("ho")),
                ],
                vec![
                    user(text_blocks(&["how are you"])),
                    assistant(text("Hi,")),
                    user(text_blocks(&["who are you", "ho"])),
                ],
            ),
            (
                "[(assistant, blocks), (user, text), (user, blocks), (user, text)]",
                vec![
                    assistant(text_blocks(&["how are you"])),
                    user(text("Hi,")),
                    user(text_blocks(&["who are you"])),
                    user(text("ho")),
                ],
                vec![
                    placeholder(),
                    assistant(text_blocks(&["how are you"])),
                    user(text_blocks(&["Hi,", "who are you", "ho"])),
                ],
            ),
            (
                "[(user, text), empty, (assistant, text with trailing space), empty]",
                vec![
                    user(text("hi")),
                    user(text("")),
                    assistant(text("hi ")),
                    user(text("")),
                ],
                vec![user(text("hi")), assistant(text("hi"))],
            ),
            (
                "last one",
                vec![
                    assistant(text("hi")),
                    user(text("")),
                    user(text_blocks(&[" "])),
                    user(text_blocks(&[" ", "hi"])),
                    user(text("how are you")),
                    assistant(text("hi ")),
                    assistant(text_blocks(&["who are you "])),
                    user(text(" ")),
                    user(text("")),
                ],
                vec![
                    placeholder(),
                    assistant(text("hi")),
                    user(text_blocks(&["hi", "how are you"])),
                    assistant(text_blocks(&["hi", "who are you"])),
                ],
            ),
        ];

        for (name, messages, expected) in tests {
            let got = process_messages(&messages);
            assert_eq!(got, expected, "test failed: {}", name);
        }
    }

    #[test]
    fn tool_use() {
        assert_round_trips(vec![
            (
                "simple tool",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "tools": [{
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "input_schema": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "The city and state, e.g. San Francisco, CA"
                                },
                                "unit": {
                                    "type": "string",
                                    "enum": ["celsius", "fahrenheit"],
                                    "description": "The unit of temperature, either \"celsius\" or \"fahrenheit\""
                                }
                            },
                            "required": ["location"]
                        }
                    }],
                    "messages": [{"role": "user", "content": "What is the weather like in San Francisco?"}]
                }"#,
                Request {
                    tools: Some(vec![weather_tool()]),
                    ..request(vec![user(text(WEATHER_QUESTION))])
                },
            ),
            (
                "sequencial",
                r#"{
                    "model": "claude-3-opus-20240229",
                    "max_tokens": 1024,
                    "tools": [
                        {
                            "name": "get_weather",
                            "description": "Get the current weather in a given location",
                            "input_schema": {
                                "type": "object",
                                "properties": {
                                    "location": {
                                        "type": "string",
                                        "description": "The city and state, e.g. San Francisco, CA"
                                    },
                                    "unit": {
                                        "type": "string",
                                        "enum": ["celsius", "fahrenheit"],
                                        "description": "The unit of temperature, either \"celsius\" or \"fahrenheit\""
                                    }
                                },
                                "required": ["location"]
                            }
                        }
                    ],
                    "messages": [
                        {
                            "role": "user",
                            "content": "What is the weather like in San Francisco?"
                        },
                        {
                            "role": "assistant",
                            "content": [
                                {
                                    "type": "text",
                                    "text": "<thinking>I need to use get_weather, and the user wants SF, which is likely San Francisco, CA.</thinking>"
                                },
                                {
                                    "type": "tool_use",
                                    "id": "toolu_01A09q90qw90lq917835lq9",
                                    "name": "get_weather",
                                    "input": {
                                        "location": "San Francisco, CA",
                                        "unit": "celsius"
                                    }
                                }
                            ]
                        },
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "tool_result",
                                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                                    "content": "15 degrees"
                                }
                            ]
                        }
                    ]
                }"#,
                Request {
                    tools: Some(vec![weather_tool()]),
                    ..request(vec![
                        user(text(WEATHER_QUESTION)),
                        assistant(blocks(vec![
                            text_block("<thinking>I need to use get_weather, and the user wants SF, which is likely San Francisco, CA.</thinking>"),
                            ContentBlock::Base(BaseContentBlock::ToolUse(ToolUseContentBlock {
                                id: TOOL_USE_ID.to_string(),
                                name: "get_weather".to_string(),
                                input: serde_json::json!({
                                    "location": "San Francisco, CA",
                                    "unit": "celsius"
                                }),
                            })),
                        ])),
                        user(blocks(vec![ContentBlock::RequestOnly(
                            RequestOnlyContentBlock::ToolResult {
                                tool_use_id: TOOL_USE_ID.to_string(),
                                content: "15 degrees".to_string(),
                            },
                        )])),
                    ])
                },
            ),
        ]);
    }
}