use serde::{Deserialize, Serialize};

use crate::messages::cache::CacheControl;
use crate::messages::content::{ContentBlock, KnownBlock};
use crate::messages::input::{MessageContent, MessageInput, SystemPrompt};
use crate::messages::mcp::McpServerConfig;
use crate::messages::metadata::{MessageMetadata, RequestServiceTier};
use crate::messages::request::CreateMessageRequest;
use crate::messages::thinking::ThinkingConfig;
use crate::messages::tools::{Tool, ToolChoice};
use crate::types::{ModelId, Role, Usage};

#[cfg(feature = "async")]
use crate::client::Client;
#[cfg(feature = "async")]
use crate::error::Result;
#[cfg(feature = "async")]
use crate::messages::response::Message;

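/// A stateful multi-turn conversation: it accumulates message history,
/// records per-turn usage, and rebuilds a fresh request for every turn.
///
/// A minimal sketch of the intended flow. The crate and module paths here
/// are assumptions (adjust them to match your `Cargo.toml` and re-exports),
/// so the doctest is marked `ignore`:
///
/// ```ignore
/// use claude_api::messages::conversation::Conversation;
/// use claude_api::types::ModelId;
///
/// let mut convo = Conversation::new(ModelId::SONNET_4_6, 1024)
///     .system("You are terse.")
///     .with_cache_breakpoint_on_system();
/// convo.push_user("Summarize Rust's ownership model in one sentence.");
/// let reply = convo.send(&client).await?; // requires the `async` feature
/// println!("{:?}", reply.content);
/// ```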
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Conversation {
    /// Model used for every turn of this conversation.
    pub model: ModelId,
    /// Maximum number of tokens to generate per turn.
    pub max_tokens: u32,

    /// Optional system prompt sent with every request.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<SystemPrompt>,

    /// Accumulated message history (user and assistant turns).
    #[serde(default)]
    pub messages: Vec<MessageInput>,

    /// Sampling temperature.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus sampling parameter.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Top-k sampling parameter.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Custom stop sequences.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,

    /// Tools available to the model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// How the model should choose among the tools.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Extended-thinking configuration.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thinking: Option<ThinkingConfig>,
    /// Request metadata (e.g. user identifiers).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<MessageMetadata>,
    /// Requested service tier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<RequestServiceTier>,
    /// MCP servers attached to the request.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub mcp_servers: Vec<McpServerConfig>,
    /// Container identifier to reuse across requests.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub container: Option<String>,

    /// Automatic cache-breakpoint insertion applied by [`Self::build_request`].
    #[serde(default)]
    pub auto_cache: AutoCacheMode,

    /// Optional policy for dropping old roundtrips when the estimated input
    /// size exceeds a threshold. Applied by [`Self::compact_if_needed`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub compaction: Option<ContextCompactionPolicy>,

    /// Per-turn usage records accumulated by [`Self::send`].
    #[serde(default)]
    pub usage_history: Vec<UsageRecord>,
}

/// Policy for automatic context compaction.
///
/// When [`Conversation::estimate_input_tokens`] exceeds `max_input_tokens`,
/// the oldest complete roundtrips are dropped until the estimate fits or
/// only `keep_recent_turns` complete roundtrips remain.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ContextCompactionPolicy {
    /// Estimated input-token threshold that triggers compaction.
    pub max_input_tokens: u32,
    /// Minimum number of complete roundtrips to keep, regardless of size.
    pub keep_recent_turns: usize,
}

impl Default for ContextCompactionPolicy {
    fn default() -> Self {
        Self {
            max_input_tokens: 100_000,
            keep_recent_turns: 4,
        }
    }
}

/// Usage recorded for a single completed turn.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct UsageRecord {
    /// The model that served the turn.
    pub model: ModelId,
    /// Token usage reported by the API for the turn.
    pub usage: Usage,
}

/// Where [`Conversation::build_request`] inserts ephemeral cache breakpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AutoCacheMode {
    /// No automatic cache breakpoints.
    #[default]
    Off,
    /// Mark the last system-prompt block as an ephemeral cache breakpoint.
    System,
    /// Mark the last system-prompt block and the last user message.
    SystemAndLastUser,
}

impl Conversation {
    /// Creates an empty conversation for `model` with the given per-turn
    /// output-token budget.
    #[must_use]
    pub fn new(model: impl Into<ModelId>, max_tokens: u32) -> Self {
        Self {
            model: model.into(),
            max_tokens,
            system: None,
            messages: Vec::new(),
            temperature: None,
            top_p: None,
            top_k: None,
            stop_sequences: None,
            tools: Vec::new(),
            tool_choice: None,
            thinking: None,
            metadata: None,
            service_tier: None,
            mcp_servers: Vec::new(),
            container: None,
            auto_cache: AutoCacheMode::Off,
            compaction: None,
            usage_history: Vec::new(),
        }
    }

    /// Enables automatic context compaction with the given policy.
    #[must_use]
    pub fn with_compaction(mut self, policy: ContextCompactionPolicy) -> Self {
        self.compaction = Some(policy);
        self
    }

    /// Sets the system prompt.
    #[must_use]
    pub fn system(mut self, s: impl Into<SystemPrompt>) -> Self {
        self.system = Some(s.into());
        self
    }

    /// Convenience for [`Self::with_auto_cache`] with [`AutoCacheMode::System`].
    #[must_use]
    pub fn with_cache_breakpoint_on_system(self) -> Self {
        self.with_auto_cache(AutoCacheMode::System)
    }

    /// Sets the automatic cache-breakpoint mode.
    #[must_use]
    pub fn with_auto_cache(mut self, mode: AutoCacheMode) -> Self {
        self.auto_cache = mode;
        self
    }

    /// Sets the tools available to the model.
    #[must_use]
    pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
        self.tools = tools;
        self
    }

    /// Sets the tool-choice strategy.
    #[must_use]
    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
        self.tool_choice = Some(choice);
        self
    }

    /// Sets the extended-thinking configuration.
    #[must_use]
    pub fn with_thinking(mut self, t: ThinkingConfig) -> Self {
        self.thinking = Some(t);
        self
    }

    /// Sets the sampling temperature.
    #[must_use]
    pub fn with_temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }

    /// Appends a user message to the history.
    pub fn push_user(&mut self, content: impl Into<MessageContent>) {
        self.messages.push(MessageInput::user(content));
    }

    /// Appends an assistant message to the history.
    pub fn push_assistant(&mut self, content: impl Into<MessageContent>) {
        self.messages.push(MessageInput::assistant(content));
    }

    /// Removes and returns the most recent message, if any.
    pub fn pop(&mut self) -> Option<MessageInput> {
        self.messages.pop()
    }

    /// Number of completed API turns (one per [`UsageRecord`]).
    #[must_use]
    pub fn turn_count(&self) -> usize {
        self.usage_history.len()
    }

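    /// Sums the token counters across all recorded turns, saturating on
    /// overflow. The optional cache counters stay `None` only if no turn
    /// reported them.
    ///
    /// A small sketch of reading the totals (paths assumed, hence `ignore`):
    ///
    /// ```ignore
    /// let total = convo.cumulative_usage();
    /// println!("{} in / {} out", total.input_tokens, total.output_tokens);
    /// ```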
    #[must_use]
    pub fn cumulative_usage(&self) -> Usage {
        self.usage_history
            .iter()
            .fold(Usage::default(), |mut acc, r| {
                acc.input_tokens = acc.input_tokens.saturating_add(r.usage.input_tokens);
                acc.output_tokens = acc.output_tokens.saturating_add(r.usage.output_tokens);
                acc.cache_creation_input_tokens = sum_opt(
                    acc.cache_creation_input_tokens,
                    r.usage.cache_creation_input_tokens,
                );
                acc.cache_read_input_tokens =
                    sum_opt(acc.cache_read_input_tokens, r.usage.cache_read_input_tokens);
                acc
            })
    }

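    /// Total cost in USD across all recorded turns, priced per turn by the
    /// model that served it.
    ///
    /// A sketch, assuming the `pricing` feature is enabled and
    /// `PricingTable` is in scope (hence `ignore`):
    ///
    /// ```ignore
    /// let table = PricingTable::default();
    /// println!("spent ${:.4} so far", convo.cost(&table));
    /// ```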
    #[cfg(feature = "pricing")]
    #[cfg_attr(docsrs, doc(cfg(feature = "pricing")))]
    #[must_use]
    pub fn cost(&self, pricing: &crate::pricing::PricingTable) -> f64 {
        self.usage_history
            .iter()
            .map(|r| pricing.cost(&r.model, &r.usage))
            .sum()
    }

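    /// Rough estimate of the input tokens the next request will carry:
    /// system prompt, message history, and serialized tool definitions.
    ///
    /// This is a heuristic (roughly four characters per token, plus flat
    /// allowances for media blocks and message framing), intended for
    /// compaction thresholds rather than billing. A sketch (`ignore`d,
    /// since the import paths are assumptions):
    ///
    /// ```ignore
    /// let mut c = Conversation::new(ModelId::SONNET_4_6, 256);
    /// c.push_user("a".repeat(400));
    /// assert!(c.estimate_input_tokens() >= 100); // ~400 chars / 4
    /// ```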
303 #[must_use]
311 pub fn estimate_input_tokens(&self) -> u32 {
312 let mut total = 0u32;
313 if let Some(s) = &self.system {
314 total = total.saturating_add(estimate_system_tokens(s));
315 }
316 for msg in &self.messages {
317 total = total.saturating_add(estimate_message_tokens(msg));
318 }
319 for tool in &self.tools {
321 if let Ok(s) = serde_json::to_string(tool) {
322 total = total.saturating_add(estimate_text_tokens(&s));
323 }
324 }
325 total
326 }
327
    /// Number of assistant messages that close a complete roundtrip: they
    /// are not the last message, and they carry no `tool_use` block (so a
    /// `tool_use` is never separated from its `tool_result`).
    #[must_use]
    pub fn complete_roundtrip_count(&self) -> usize {
        let last_idx = self.messages.len().saturating_sub(1);
        self.messages
            .iter()
            .enumerate()
            .filter(|(i, m)| *i < last_idx && m.role == Role::Assistant && !message_has_tool_use(m))
            .count()
    }

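    /// Applies the configured [`ContextCompactionPolicy`], dropping the
    /// oldest complete roundtrips while the estimate exceeds the threshold.
    /// Returns `true` if any messages were dropped.
    ///
    /// A sketch of opting in (the policy struct is `#[non_exhaustive]`, so
    /// start from `Default`; paths assumed, hence `ignore`):
    ///
    /// ```ignore
    /// // Defaults: max_input_tokens = 100_000, keep_recent_turns = 4.
    /// let mut c = Conversation::new(ModelId::SONNET_4_6, 256)
    ///     .with_compaction(ContextCompactionPolicy::default());
    /// // ... many push_user / push_assistant turns later ...
    /// let compacted = c.compact_if_needed();
    /// ```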
    pub fn compact_if_needed(&mut self) -> bool {
        let Some(policy) = self.compaction.clone() else {
            return false;
        };

        let initial = self.estimate_input_tokens();
        if initial <= policy.max_input_tokens {
            return false;
        }

        let initial_msg_count = self.messages.len();
        loop {
            if self.estimate_input_tokens() <= policy.max_input_tokens {
                break;
            }
            if self.complete_roundtrip_count() <= policy.keep_recent_turns {
                break;
            }
            if !self.drop_oldest_roundtrip() {
                break;
            }
        }

        let dropped = initial_msg_count - self.messages.len();
        if dropped > 0 {
            tracing::warn!(
                initial_estimate = initial,
                final_estimate = self.estimate_input_tokens(),
                messages_dropped = dropped,
                roundtrips_remaining = self.complete_roundtrip_count(),
                "claude-api: context compaction applied",
            );
            true
        } else {
            false
        }
    }

    /// Drops messages from the front, up to and including the oldest
    /// assistant message that closes a complete roundtrip. Returns `false`
    /// if no such roundtrip exists.
    fn drop_oldest_roundtrip(&mut self) -> bool {
        let last_idx = self.messages.len().saturating_sub(1);
        let drop_to = self.messages.iter().enumerate().position(|(i, m)| {
            i < last_idx && m.role == Role::Assistant && !message_has_tool_use(m)
        });
        match drop_to {
            Some(idx) => {
                self.messages.drain(0..=idx);
                true
            }
            None => false,
        }
    }

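    /// Builds a [`CreateMessageRequest`] from the current state. The
    /// configured [`AutoCacheMode`] is applied to *copies* of the system
    /// prompt and messages, so `self` is never mutated.
    ///
    /// A sketch (paths assumed, hence `ignore`):
    ///
    /// ```ignore
    /// let mut c = Conversation::new(ModelId::SONNET_4_6, 256).system("be concise");
    /// c.push_user("hello");
    /// let request = c.build_request(); // hand this to the messages API
    /// ```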
    #[must_use]
    pub fn build_request(&self) -> CreateMessageRequest {
        let mut messages = self.messages.clone();
        let mut system = self.system.clone();

        match self.auto_cache {
            AutoCacheMode::Off => {}
            AutoCacheMode::System => {
                cache_breakpoint_on_system(&mut system);
            }
            AutoCacheMode::SystemAndLastUser => {
                cache_breakpoint_on_system(&mut system);
                cache_breakpoint_on_last_user(&mut messages);
            }
        }

        let mut builder = CreateMessageRequest::builder()
            .model(self.model.clone())
            .max_tokens(self.max_tokens)
            .messages(messages);

        if let Some(s) = system {
            builder = builder.system(s);
        }
        if let Some(t) = self.temperature {
            builder = builder.temperature(t);
        }
        if let Some(p) = self.top_p {
            builder = builder.top_p(p);
        }
        if let Some(k) = self.top_k {
            builder = builder.top_k(k);
        }
        if let Some(seqs) = &self.stop_sequences {
            builder = builder.stop_sequences(seqs.clone());
        }
        if !self.tools.is_empty() {
            builder = builder.tools(self.tools.clone());
        }
        if let Some(c) = self.tool_choice.clone() {
            builder = builder.tool_choice(c);
        }
        if let Some(t) = self.thinking {
            builder = builder.thinking(t);
        }
        if let Some(m) = self.metadata.clone() {
            builder = builder.metadata(m);
        }
        if let Some(t) = self.service_tier {
            builder = builder.service_tier(t);
        }
        if !self.mcp_servers.is_empty() {
            builder = builder.mcp_servers(self.mcp_servers.clone());
        }
        if let Some(c) = self.container.clone() {
            builder = builder.container(c);
        }

        builder
            .build()
            .expect("conversation::build_request always provides model + max_tokens")
    }

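    /// Compacts if needed, sends the current state, and records the
    /// assistant reply plus usage into the history.
    ///
    /// A sketch of a full turn (network call and assumed paths, hence
    /// `ignore`):
    ///
    /// ```ignore
    /// let mut c = Conversation::new(ModelId::SONNET_4_6, 256);
    /// c.push_user("hello");
    /// let reply = c.send(&client).await?;
    /// assert_eq!(c.turn_count(), 1);
    /// ```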
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    pub async fn send(&mut self, client: &Client) -> Result<Message> {
        self.send_with_beta(client, &[]).await
    }

    /// Like [`Self::send`], but passes the given beta headers through.
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    pub async fn send_with_beta(&mut self, client: &Client, betas: &[&str]) -> Result<Message> {
        self.compact_if_needed();
        let request = self.build_request();
        let response = client.messages().create_with_beta(request, betas).await?;
        self.usage_history.push(UsageRecord {
            model: self.model.clone(),
            usage: response.usage.clone(),
        });
        self.messages
            .push(MessageInput::assistant(response.content.clone()));
        Ok(response)
    }
}

/// Crude token estimate: roughly one token per four characters, rounded up.
fn estimate_text_tokens(s: &str) -> u32 {
    let chars = u32::try_from(s.chars().count()).unwrap_or(u32::MAX);
    chars.div_ceil(4)
}

fn estimate_system_tokens(s: &SystemPrompt) -> u32 {
    match s {
        SystemPrompt::Text(t) => estimate_text_tokens(t),
        SystemPrompt::Blocks(blocks) => blocks.iter().map(estimate_block_tokens).sum(),
    }
}

fn estimate_message_tokens(msg: &MessageInput) -> u32 {
    let body = match &msg.content {
        MessageContent::Text(s) => estimate_text_tokens(s),
        MessageContent::Blocks(blocks) => blocks.iter().map(estimate_block_tokens).sum(),
    };
    // Flat allowance for per-message framing (role markers and the like).
    body.saturating_add(4)
}

fn estimate_block_tokens(block: &ContentBlock) -> u32 {
    use crate::messages::content::ToolResultContent;

    match block {
        ContentBlock::Known(KnownBlock::Text { text, .. }) => estimate_text_tokens(text),
        ContentBlock::Known(KnownBlock::Thinking { thinking, .. }) => {
            estimate_text_tokens(thinking)
        }
        ContentBlock::Known(KnownBlock::ToolUse { name, input, .. }) => {
            estimate_text_tokens(name).saturating_add(estimate_text_tokens(&input.to_string()))
        }
        ContentBlock::Known(KnownBlock::ServerToolUse { name, input, .. }) => {
            estimate_text_tokens(name).saturating_add(estimate_text_tokens(&input.to_string()))
        }
        ContentBlock::Known(KnownBlock::ToolResult { content, .. }) => match content {
            ToolResultContent::Text(s) => estimate_text_tokens(s),
            ToolResultContent::Blocks(b) => b.iter().map(estimate_block_tokens).sum(),
        },
        // Flat allowances for blocks whose true cost is size-dependent.
        ContentBlock::Known(KnownBlock::Image { .. }) => 1500,
        ContentBlock::Known(KnownBlock::Document { .. }) => 2000,
        ContentBlock::Known(KnownBlock::WebSearchToolResult { .. }) => 500,
        ContentBlock::Known(KnownBlock::RedactedThinking { data, .. }) => {
            estimate_text_tokens(data)
        }
        ContentBlock::Other(v) => estimate_text_tokens(&v.to_string()),
    }
}

fn message_has_tool_use(msg: &MessageInput) -> bool {
    match &msg.content {
        MessageContent::Text(_) => false,
        MessageContent::Blocks(blocks) => blocks.iter().any(|b| {
            matches!(
                b,
                ContentBlock::Known(KnownBlock::ToolUse { .. } | KnownBlock::ServerToolUse { .. })
            )
        }),
    }
}

/// Adds two optional counters, saturating; `None` only when both are `None`.
fn sum_opt(a: Option<u32>, b: Option<u32>) -> Option<u32> {
    match (a, b) {
        (None, None) => None,
        (Some(x), None) | (None, Some(x)) => Some(x),
        (Some(x), Some(y)) => Some(x.saturating_add(y)),
    }
}

/// Converts a text system prompt to blocks if needed and marks the last
/// block with an ephemeral cache breakpoint.
fn cache_breakpoint_on_system(system: &mut Option<SystemPrompt>) {
    let Some(s) = system.take() else { return };
    let blocks = match s {
        SystemPrompt::Text(text) => vec![ContentBlock::Known(KnownBlock::Text {
            text,
            cache_control: Some(CacheControl::ephemeral()),
            citations: None,
        })],
        SystemPrompt::Blocks(mut blocks) => {
            apply_cache_control_to_last_block(&mut blocks);
            blocks
        }
    };
    *system = Some(SystemPrompt::Blocks(blocks));
}

/// Marks the last block of the most recent user message with an ephemeral
/// cache breakpoint, converting text content to blocks if needed.
fn cache_breakpoint_on_last_user(messages: &mut [MessageInput]) {
    let Some(idx) = messages.iter().rposition(|m| m.role == Role::User) else {
        return;
    };
    let target = &mut messages[idx];
    match &mut target.content {
        MessageContent::Text(text) => {
            target.content = MessageContent::Blocks(vec![ContentBlock::Known(KnownBlock::Text {
                text: std::mem::take(text),
                cache_control: Some(CacheControl::ephemeral()),
                citations: None,
            })]);
        }
        MessageContent::Blocks(blocks) => {
            apply_cache_control_to_last_block(blocks);
        }
    }
}

/// Sets an ephemeral cache breakpoint on the last block, but only for block
/// types that accept `cache_control`; other block types are left untouched.
fn apply_cache_control_to_last_block(blocks: &mut [ContentBlock]) {
    let Some(last) = blocks.last_mut() else {
        return;
    };
    if let ContentBlock::Known(
        KnownBlock::Text { cache_control, .. }
        | KnownBlock::Image { cache_control, .. }
        | KnownBlock::Document { cache_control, .. }
        | KnownBlock::ToolResult { cache_control, .. },
    ) = last
    {
        *cache_control = Some(CacheControl::ephemeral());
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    fn convo() -> Conversation {
        Conversation::new(ModelId::SONNET_4_6, 256)
    }

    #[test]
    fn new_starts_empty() {
        let c = convo();
        assert!(c.messages.is_empty());
        assert!(c.usage_history.is_empty());
        assert_eq!(c.turn_count(), 0);
    }

    #[test]
    fn push_appends_to_history() {
        let mut c = convo();
        c.push_user("hi");
        c.push_assistant("hello");
        c.push_user("how are you?");
        assert_eq!(c.messages.len(), 3);
        assert_eq!(c.messages[0].role, Role::User);
        assert_eq!(c.messages[1].role, Role::Assistant);
    }

    #[test]
    fn pop_removes_last() {
        let mut c = convo();
        c.push_user("first");
        c.push_user("second");
        let popped = c.pop().unwrap();
        let MessageContent::Text(t) = popped.content else {
            panic!("expected Text content");
        };
        assert_eq!(t, "second");
        assert_eq!(c.messages.len(), 1);
    }

    #[test]
    fn cumulative_usage_sums_across_turns() {
        let mut c = convo();
        c.usage_history.push(UsageRecord {
            model: ModelId::SONNET_4_6,
            usage: Usage {
                input_tokens: 100,
                output_tokens: 50,
                cache_creation_input_tokens: Some(20),
                cache_read_input_tokens: Some(30),
                ..Usage::default()
            },
        });
        c.usage_history.push(UsageRecord {
            model: ModelId::SONNET_4_6,
            usage: Usage {
                input_tokens: 200,
                output_tokens: 80,
                cache_read_input_tokens: Some(70),
                ..Usage::default()
            },
        });
        let total = c.cumulative_usage();
        assert_eq!(total.input_tokens, 300);
        assert_eq!(total.output_tokens, 130);
        assert_eq!(total.cache_creation_input_tokens, Some(20));
        assert_eq!(total.cache_read_input_tokens, Some(100));
    }

    #[test]
    fn serde_round_trip_preserves_state() {
        let mut original = Conversation::new(ModelId::OPUS_4_7, 512)
            .system("be concise")
            .with_cache_breakpoint_on_system()
            .with_temperature(0.5);
        original.push_user("hi");
        original.push_assistant("hello");
        original.usage_history.push(UsageRecord {
            model: ModelId::OPUS_4_7,
            usage: Usage {
                input_tokens: 5,
                output_tokens: 3,
                ..Usage::default()
            },
        });

        let json = serde_json::to_string(&original).unwrap();
        let parsed: Conversation = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.model, ModelId::OPUS_4_7);
        assert_eq!(parsed.max_tokens, 512);
        assert_eq!(parsed.auto_cache, AutoCacheMode::System);
        assert_eq!(parsed.temperature, Some(0.5));
        assert_eq!(parsed.messages.len(), 2);
        assert_eq!(parsed.usage_history.len(), 1);
        assert_eq!(parsed.turn_count(), 1);
    }

    #[test]
    fn build_request_includes_basic_fields() {
        let mut c = convo().system("be concise").with_temperature(0.25);
        c.push_user("hello");
        let req = c.build_request();
        let v = serde_json::to_value(&req).unwrap();
        assert_eq!(v["model"], "claude-sonnet-4-6");
        assert_eq!(v["max_tokens"], 256);
        assert_eq!(v["system"], "be concise");
        assert_eq!(v["temperature"], 0.25);
        assert_eq!(v["messages"][0]["role"], "user");
    }

    #[test]
    fn build_request_with_auto_cache_system() {
        let mut c = convo()
            .system("you are concise")
            .with_cache_breakpoint_on_system();
        c.push_user("hi");
        let v = serde_json::to_value(c.build_request()).unwrap();
        assert_eq!(
            v["system"],
            json!([{
                "type": "text",
                "text": "you are concise",
                "cache_control": {"type": "ephemeral"}
            }])
        );
        // The user message is left untouched in `System` mode.
        assert_eq!(v["messages"][0]["content"], "hi");
    }

    #[test]
    fn build_request_with_auto_cache_system_and_last_user() {
        let mut c = convo()
            .system("you are concise")
            .with_auto_cache(AutoCacheMode::SystemAndLastUser);
        c.push_user("first");
        c.push_assistant("response");
        c.push_user("follow-up");
        let v = serde_json::to_value(c.build_request()).unwrap();

        assert_eq!(v["system"][0]["cache_control"]["type"], "ephemeral");

        // Only the *last* user message gets the breakpoint.
        let msgs = v["messages"].as_array().unwrap();
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[2]["role"], "user");
        assert_eq!(msgs[2]["content"][0]["type"], "text");
        assert_eq!(msgs[2]["content"][0]["text"], "follow-up");
        assert_eq!(msgs[2]["content"][0]["cache_control"]["type"], "ephemeral");

        // Earlier user messages stay as plain text.
        assert_eq!(msgs[0]["content"], "first");
    }

    #[test]
    fn build_request_auto_cache_off_does_nothing() {
        let mut c = convo().system("plain");
        c.push_user("hi");
        let v = serde_json::to_value(c.build_request()).unwrap();
        assert_eq!(v["system"], "plain");
        assert_eq!(v["messages"][0]["content"], "hi");
    }

    #[test]
    fn build_request_does_not_mutate_self() {
        let mut c = convo().system("orig").with_cache_breakpoint_on_system();
        c.push_user("hi");
        let _ = c.build_request();
        let Some(SystemPrompt::Text(t)) = &c.system else {
            panic!("system should still be Text, got {:?}", c.system);
        };
        assert_eq!(t, "orig");
        let MessageContent::Text(t) = &c.messages[0].content else {
            panic!(
                "user content should still be Text, got {:?}",
                c.messages[0].content
            );
        };
        assert_eq!(t, "hi");
    }

    #[test]
    fn estimate_input_tokens_grows_with_message_size() {
        let mut c = convo();
        c.push_user("hi");
        let small = c.estimate_input_tokens();

        let mut c2 = convo();
        c2.push_user("a".repeat(1000));
        let large = c2.estimate_input_tokens();

        assert!(large > small * 10, "{large} should dwarf {small}");
    }

    #[test]
    fn compact_if_needed_no_op_without_policy() {
        let mut c = convo();
        for i in 0..10 {
            c.push_user(format!("user {i}"));
            c.push_assistant(format!("assistant {i}"));
        }
        let before = c.messages.len();
        assert!(!c.compact_if_needed());
        assert_eq!(c.messages.len(), before);
    }

    #[test]
    fn compact_if_needed_no_op_when_under_threshold() {
        let mut c = convo().with_compaction(ContextCompactionPolicy {
            max_input_tokens: 100_000,
            keep_recent_turns: 1,
        });
        c.push_user("short");
        c.push_assistant("short");
        assert!(!c.compact_if_needed());
        assert_eq!(c.messages.len(), 2);
    }

    #[test]
    fn compact_if_needed_drops_oldest_roundtrips_above_threshold() {
        let mut c = convo().with_compaction(ContextCompactionPolicy {
            max_input_tokens: 60,
            keep_recent_turns: 1,
        });
        for i in 0..6 {
            c.push_user(format!(
                "this is user message number {i} with reasonable length"
            ));
            c.push_assistant(format!(
                "this is assistant response number {i} with similar length"
            ));
        }
        c.push_user("current question");

        let before_count = c.messages.len();
        assert!(c.compact_if_needed(), "should have compacted");
        assert!(
            c.messages.len() < before_count,
            "expected drop; got {} -> {}",
            before_count,
            c.messages.len()
        );
        // The trailing user message must survive compaction.
        let MessageContent::Text(last_user) = &c.messages.last().unwrap().content else {
            panic!("expected text");
        };
        assert_eq!(last_user, "current question");
    }

    #[test]
    fn compact_if_needed_respects_keep_recent_turns() {
        let mut c = convo().with_compaction(ContextCompactionPolicy {
            max_input_tokens: 1,
            keep_recent_turns: 2,
        });
        for i in 0..5 {
            c.push_user(format!("u{i}"));
            c.push_assistant(format!("a{i}"));
        }
        c.push_user("trailing");

        c.compact_if_needed();
        // Even with an impossible 1-token budget, two roundtrips survive.
        assert_eq!(c.complete_roundtrip_count(), 2);
        let MessageContent::Text(last) = &c.messages.last().unwrap().content else {
            panic!("expected text");
        };
        assert_eq!(last, "trailing");
    }

    #[test]
    fn compact_if_needed_preserves_tool_use_tool_result_pairs() {
        use crate::messages::content::{ContentBlock, KnownBlock, ToolResultContent};
        use serde_json::json;

        let mut c = convo().with_compaction(ContextCompactionPolicy {
            max_input_tokens: 30,
            keep_recent_turns: 0,
        });

        // A droppable complete roundtrip.
        c.push_user("first user".repeat(20));
        c.push_assistant("first answer".repeat(20));

        // A roundtrip that spans a tool call: user, assistant tool_use,
        // user tool_result, assistant answer.
        c.push_user("second user".repeat(20));
        c.messages.push(MessageInput::assistant(vec![
            ContentBlock::text("calling tool"),
            ContentBlock::Known(KnownBlock::ToolUse {
                id: "toolu_1".into(),
                name: "fn".into(),
                input: json!({}),
            }),
        ]));
        c.messages.push(MessageInput::user(vec![ContentBlock::Known(
            KnownBlock::ToolResult {
                tool_use_id: "toolu_1".into(),
                content: ToolResultContent::Text("result".into()),
                is_error: None,
                cache_control: None,
            },
        )]));
        c.push_assistant("here is the answer".repeat(20));

        c.push_user("final");

        c.compact_if_needed();

        // Whatever was dropped, no tool_use may be orphaned from its result.
        for (i, m) in c.messages.iter().enumerate() {
            if message_has_tool_use(m) {
                assert!(
                    i + 1 < c.messages.len(),
                    "tool_use at index {i} must be followed by a tool_result"
                );
                let next = &c.messages[i + 1];
                let MessageContent::Blocks(blocks) = &next.content else {
                    panic!("expected blocks");
                };
                assert!(
                    blocks
                        .iter()
                        .any(|b| matches!(b, ContentBlock::Known(KnownBlock::ToolResult { .. }))),
                    "next message after tool_use must contain tool_result"
                );
            }
        }
    }

    #[test]
    fn drop_oldest_roundtrip_returns_false_when_only_partial_remains() {
        let mut c = convo();
        c.push_user("only user, no assistant yet");
        assert!(!c.drop_oldest_roundtrip());
        assert_eq!(c.messages.len(), 1);
    }

    #[test]
    fn complete_roundtrip_count_excludes_trailing_partial() {
        let mut c = convo();
        c.push_user("u1");
        c.push_assistant("a1");
        c.push_user("u2");
        c.push_assistant("a2");
        c.push_user("u3");
        assert_eq!(c.complete_roundtrip_count(), 2);
    }

    #[test]
    fn complete_roundtrip_count_skips_assistant_with_tool_use() {
        use crate::messages::content::{ContentBlock, KnownBlock};
        use serde_json::json;

        let mut c = convo();
        c.push_user("u1");
        c.messages
            .push(MessageInput::assistant(vec![ContentBlock::Known(
                KnownBlock::ToolUse {
                    id: "t".into(),
                    name: "fn".into(),
                    input: json!({}),
                },
            )]));
        // An assistant message carrying tool_use never closes a roundtrip.
        assert_eq!(c.complete_roundtrip_count(), 0);
    }

    #[cfg(feature = "pricing")]
    #[test]
    fn cost_uses_pricing_table_per_turn_model() {
        let pricing = crate::pricing::PricingTable::default();
        let mut c = convo();
        c.usage_history.push(UsageRecord {
            model: ModelId::SONNET_4_6,
            usage: Usage {
                input_tokens: 1_000_000,
                ..Usage::default()
            },
        });
        c.usage_history.push(UsageRecord {
            model: ModelId::HAIKU_4_5,
            usage: Usage {
                input_tokens: 1_000_000,
                ..Usage::default()
            },
        });
        let total = c.cost(&pricing);
        // 1M Sonnet input tokens ($3) + 1M Haiku input tokens ($1).
        assert!((total - 4.0).abs() < 1e-9, "expected $4.00, got ${total}");
    }

    #[cfg(feature = "pricing")]
    #[test]
    fn cost_routes_through_cache_creation_and_read_pricing() {
        use crate::types::CacheCreationBreakdown;

        let pricing = crate::pricing::PricingTable::default();
        let mut c = convo();
        c.usage_history.push(UsageRecord {
            model: ModelId::SONNET_4_6,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
                cache_creation: Some(CacheCreationBreakdown {
                    ephemeral_5m_input_tokens: 1_000_000,
                    ephemeral_1h_input_tokens: 1_000_000,
                }),
                cache_read_input_tokens: Some(1_000_000),
                ..Usage::default()
            },
        });

        // For Sonnet: 1M 5m-cache writes ($3.75) + 1M 1h-cache writes
        // ($6.00) + 1M cache reads ($0.30) = $10.05.
        let total = c.cost(&pricing);
        assert!(
            (total - 10.05).abs() < 1e-9,
            "expected $10.05 from cache pricing, got ${total} \
             -- if this dropped to ~$0 the cache fields aren't being read",
        );
    }

    #[cfg(feature = "pricing")]
    #[test]
    fn cost_routes_through_server_tool_use_charges() {
        use crate::types::ServerToolUseUsage;

        let pricing = crate::pricing::PricingTable::default();
        let mut c = convo();
        c.usage_history.push(UsageRecord {
            model: ModelId::SONNET_4_6,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
                server_tool_use: Some(ServerToolUseUsage {
                    web_search_requests: 5,
                }),
                ..Usage::default()
            },
        });
        let total = c.cost(&pricing);
        // Web search is billed per request ($10 per 1,000), so 5 => $0.05.
        assert!(
            (total - 0.05).abs() < 1e-9,
            "expected $0.05 from 5 web searches, got ${total}",
        );
    }
}

#[cfg(all(test, feature = "async"))]
mod api_tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use wiremock::matchers::{body_partial_json, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    fn client_for(mock: &MockServer) -> Client {
        Client::builder()
            .api_key("sk-ant-test")
            .base_url(mock.uri())
            .build()
            .unwrap()
    }

    fn fake_response(text: &str, input: u32, output: u32) -> serde_json::Value {
        json!({
            "id": "msg_x",
            "type": "message",
            "role": "assistant",
            "content": [{"type": "text", "text": text}],
            "model": "claude-sonnet-4-6",
            "stop_reason": "end_turn",
            "usage": {"input_tokens": input, "output_tokens": output}
        })
    }

    #[tokio::test]
    async fn send_appends_assistant_turn_and_records_usage() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response("hi back", 5, 2)))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut c = Conversation::new(ModelId::SONNET_4_6, 64);
        c.push_user("hi");

        let r = c.send(&client).await.unwrap();
        assert_eq!(r.id, "msg_x");

        // The assistant reply was appended to the history.
        assert_eq!(c.messages.len(), 2);
        assert_eq!(c.messages[1].role, Role::Assistant);

        // Usage from the response was recorded against the model.
        assert_eq!(c.turn_count(), 1);
        assert_eq!(c.usage_history[0].model, ModelId::SONNET_4_6);
        assert_eq!(c.usage_history[0].usage.input_tokens, 5);
        assert_eq!(c.usage_history[0].usage.output_tokens, 2);
    }

    #[tokio::test]
    async fn second_send_includes_first_assistant_turn_in_history() {
        let mock = MockServer::start().await;
        // First call: any POST matches, consumed once.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response("first", 5, 3)))
            .up_to_n_times(1)
            .mount(&mock)
            .await;
        // Second call must carry the full history, including turn one.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(body_partial_json(json!({
                "messages": [
                    {"role": "user", "content": "hi"},
                    {"role": "assistant", "content": [{"type": "text", "text": "first"}]},
                    {"role": "user", "content": "again"}
                ]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response("second", 8, 4)))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut c = Conversation::new(ModelId::SONNET_4_6, 64);
        c.push_user("hi");
        let _ = c.send(&client).await.unwrap();
        c.push_user("again");
        let _ = c.send(&client).await.unwrap();

        assert_eq!(c.turn_count(), 2);
        let total = c.cumulative_usage();
        assert_eq!(total.input_tokens, 13);
        assert_eq!(total.output_tokens, 7);
    }

    #[tokio::test]
    async fn auto_cache_system_sends_cache_control_in_request_body() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(body_partial_json(json!({
                "system": [{
                    "type": "text",
                    "text": "be concise",
                    "cache_control": {"type": "ephemeral"}
                }]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response("ok", 3, 1)))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut c = Conversation::new(ModelId::SONNET_4_6, 32)
            .system("be concise")
            .with_cache_breakpoint_on_system();
        c.push_user("hello");
        let _ = c.send(&client).await.unwrap();
    }
}