claude_api/conversation.rs

//! Multi-turn conversation helper.
//!
//! [`Conversation`] holds the system prompt, message history, default
//! request settings, and accumulated usage for a multi-turn exchange. Each
//! call to [`Conversation::send`] runs one turn against the API and
//! appends the assistant response to the history automatically.
//!
//! Optional auto-cache mode (set via [`Conversation::with_auto_cache`] or
//! [`Conversation::with_cache_breakpoint_on_system`]) applies an ephemeral
//! `cache_control` breakpoint to the system prompt and optionally the most
//! recent user turn before each request, so cache hits stay high without
//! the application having to manage breakpoints itself.
//!
//! [`Conversation`] is `Serialize + Deserialize`, so a session can be
//! persisted to disk and resumed later.
//!
//! Gated on the `conversation` feature.
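//!
//! # Example
//!
//! A minimal sketch of a two-turn session (assumes the `conversation` and
//! `async` features, and that the paths below match the crate's public
//! module layout; error handling elided):
//!
//! ```no_run
//! use claude_api::client::Client;
//! use claude_api::conversation::{AutoCacheMode, Conversation};
//! use claude_api::types::ModelId;
//!
//! #[tokio::main]
//! async fn main() -> claude_api::error::Result<()> {
//!     let client = Client::builder().api_key("sk-ant-...").build()?;
//!     let mut convo = Conversation::new(ModelId::SONNET_4_6, 1024)
//!         .system("You are terse.")
//!         .with_auto_cache(AutoCacheMode::System);
//!
//!     convo.push_user("Summarize the borrow checker in one sentence.");
//!     let reply = convo.send(&client).await?;
//!     println!("{:?}", reply.content);
//!
//!     // The assistant turn is already in the history; just ask again.
//!     convo.push_user("Now do lifetimes.");
//!     convo.send(&client).await?;
//!     Ok(())
//! }
//! ```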

use serde::{Deserialize, Serialize};

use crate::messages::cache::CacheControl;
use crate::messages::content::{ContentBlock, KnownBlock};
use crate::messages::input::{MessageContent, MessageInput, SystemPrompt};
use crate::messages::mcp::McpServerConfig;
use crate::messages::metadata::{MessageMetadata, RequestServiceTier};
use crate::messages::request::CreateMessageRequest;
use crate::messages::thinking::ThinkingConfig;
use crate::messages::tools::{Tool, ToolChoice};
use crate::types::{ModelId, Role, Usage};

#[cfg(feature = "async")]
use crate::client::Client;
#[cfg(feature = "async")]
use crate::error::Result;
#[cfg(feature = "async")]
use crate::messages::response::Message;

/// Multi-turn conversation state plus per-request defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Conversation {
    /// Model used for new turns (also recorded with each `UsageRecord`).
    pub model: ModelId,
    /// Maximum output tokens per turn.
    pub max_tokens: u32,

    /// Optional system prompt; survives across turns.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<SystemPrompt>,

    /// Conversation history, oldest first.
    #[serde(default)]
    pub messages: Vec<MessageInput>,

    /// Default sampling temperature.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Default nucleus sampling cutoff.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Default top-k cutoff.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Default stop sequences.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,

    /// Tools made available to every turn.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Default tool-use policy.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Default extended-thinking config.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thinking: Option<ThinkingConfig>,
    /// Default request metadata.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<MessageMetadata>,
    /// Default request-side service tier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<RequestServiceTier>,
    /// MCP servers exposed on every turn.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub mcp_servers: Vec<McpServerConfig>,
    /// Container ID for the code-execution built-in tool.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub container: Option<String>,

    /// Auto-cache configuration applied at request-build time.
    #[serde(default)]
    pub auto_cache: AutoCacheMode,

    /// Optional context-compaction policy. When set, oldest user/assistant
    /// roundtrips are dropped before each `send` once the estimated input
    /// exceeds [`ContextCompactionPolicy::max_input_tokens`]. See
    /// [`Self::compact_if_needed`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub compaction: Option<ContextCompactionPolicy>,

    /// Per-turn `Usage` records, oldest first. Updated by [`Self::send`].
    #[serde(default)]
    pub usage_history: Vec<UsageRecord>,
}

/// Policy controlling when and how [`Conversation`] drops older turns to
/// stay under a token budget.
///
/// The v0.3 first cut implements **truncation**: oldest complete
/// user→assistant roundtrips are dropped until either the estimated input
/// is under [`Self::max_input_tokens`] or only [`Self::keep_recent_turns`]
/// complete roundtrips remain. Tool-use / tool-result pairs are preserved
/// as a unit (an assistant turn with `tool_use` blocks is never dropped
/// without its matching `tool_result` user turn and follow-up assistant
/// text).
///
/// Token estimation is a fast local heuristic (~4 chars/token); treat
/// [`Conversation::estimate_input_tokens`] as a hint rather than an exact
/// count, and configure `max_input_tokens` with some headroom. For
/// billing-quality numbers, call `count_tokens` via the API.
///
/// Future work (v0.4): callback-based summarization that replaces a span
/// of old turns with a single text summary instead of dropping them.
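///
/// # Example
///
/// A sketch of attaching a policy (the struct is `#[non_exhaustive]`, so
/// external code starts from [`Default`] and overrides fields; the values
/// here are illustrative, not recommendations):
///
/// ```
/// use claude_api::conversation::{ContextCompactionPolicy, Conversation};
/// use claude_api::types::ModelId;
///
/// let mut policy = ContextCompactionPolicy::default();
/// policy.max_input_tokens = 50_000; // compact earlier than the default
/// policy.keep_recent_turns = 2;
///
/// let convo = Conversation::new(ModelId::SONNET_4_6, 1024).with_compaction(policy);
/// assert!(convo.compaction.is_some());
/// ```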
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ContextCompactionPolicy {
    /// Compact when the estimated input would exceed this many tokens.
    pub max_input_tokens: u32,
    /// After compaction, keep at least this many complete roundtrips.
    pub keep_recent_turns: usize,
}

impl Default for ContextCompactionPolicy {
    fn default() -> Self {
        Self {
            // Generous default; ~50% of the 200k context window so users
            // hit it before the model does.
            max_input_tokens: 100_000,
            keep_recent_turns: 4,
        }
    }
}

/// One turn's `Usage` paired with the model it ran on.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct UsageRecord {
    /// Model that produced this usage record.
    pub model: ModelId,
    /// Usage as reported by the API.
    pub usage: Usage,
}

/// Automatic cache-breakpoint placement for outgoing requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AutoCacheMode {
    /// No automatic cache breakpoints. Default.
    #[default]
    Off,
    /// Apply ephemeral `cache_control` to the last block of the system prompt.
    System,
    /// Apply ephemeral `cache_control` to the system prompt's last block AND
    /// to the most recent user turn's last block.
    SystemAndLastUser,
}

impl Conversation {
    /// Begin a new conversation with the given model and per-turn `max_tokens`.
    #[must_use]
    pub fn new(model: impl Into<ModelId>, max_tokens: u32) -> Self {
        Self {
            model: model.into(),
            max_tokens,
            system: None,
            messages: Vec::new(),
            temperature: None,
            top_p: None,
            top_k: None,
            stop_sequences: None,
            tools: Vec::new(),
            tool_choice: None,
            thinking: None,
            metadata: None,
            service_tier: None,
            mcp_servers: Vec::new(),
            container: None,
            auto_cache: AutoCacheMode::Off,
            compaction: None,
            usage_history: Vec::new(),
        }
    }

    /// Attach a context-compaction policy. Without one, conversation
    /// history grows unbounded.
    #[must_use]
    pub fn with_compaction(mut self, policy: ContextCompactionPolicy) -> Self {
        self.compaction = Some(policy);
        self
    }

    /// Set the system prompt.
    #[must_use]
    pub fn system(mut self, s: impl Into<SystemPrompt>) -> Self {
        self.system = Some(s.into());
        self
    }

    /// Shorthand for setting [`AutoCacheMode::System`] via
    /// [`Self::with_auto_cache`].
    #[must_use]
    pub fn with_cache_breakpoint_on_system(self) -> Self {
        self.with_auto_cache(AutoCacheMode::System)
    }

    /// Set the auto-cache mode. See [`AutoCacheMode`].
    #[must_use]
    pub fn with_auto_cache(mut self, mode: AutoCacheMode) -> Self {
        self.auto_cache = mode;
        self
    }

    /// Replace the tool list.
    #[must_use]
    pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
        self.tools = tools;
        self
    }

    /// Set the tool-use policy.
    #[must_use]
    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
        self.tool_choice = Some(choice);
        self
    }

    /// Enable extended thinking.
    #[must_use]
    pub fn with_thinking(mut self, t: ThinkingConfig) -> Self {
        self.thinking = Some(t);
        self
    }

    /// Set the sampling temperature default.
    #[must_use]
    pub fn with_temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }

    /// Append a user-authored turn.
    pub fn push_user(&mut self, content: impl Into<MessageContent>) {
        self.messages.push(MessageInput::user(content));
    }

    /// Append an assistant-authored turn (typically used for prefill before
    /// the first send).
    pub fn push_assistant(&mut self, content: impl Into<MessageContent>) {
        self.messages.push(MessageInput::assistant(content));
    }

    /// Remove and return the most recent message. Useful when aborting a
    /// turn before sending.
    pub fn pop(&mut self) -> Option<MessageInput> {
        self.messages.pop()
    }

    /// Number of completed turns (request/response cycles via [`Self::send`]).
    #[must_use]
    pub fn turn_count(&self) -> usize {
        self.usage_history.len()
    }

    /// Sum of every recorded `Usage` for this conversation.
    #[must_use]
    pub fn cumulative_usage(&self) -> Usage {
        self.usage_history
            .iter()
            .fold(Usage::default(), |mut acc, r| {
                acc.input_tokens = acc.input_tokens.saturating_add(r.usage.input_tokens);
                acc.output_tokens = acc.output_tokens.saturating_add(r.usage.output_tokens);
                acc.cache_creation_input_tokens = sum_opt(
                    acc.cache_creation_input_tokens,
                    r.usage.cache_creation_input_tokens,
                );
                acc.cache_read_input_tokens =
                    sum_opt(acc.cache_read_input_tokens, r.usage.cache_read_input_tokens);
                acc
            })
    }

    /// Total cost in USD across all recorded turns, using the given pricing
    /// table to look up rates for each turn's model.
    #[cfg(feature = "pricing")]
    #[cfg_attr(docsrs, doc(cfg(feature = "pricing")))]
    #[must_use]
    pub fn cost(&self, pricing: &crate::pricing::PricingTable) -> f64 {
        self.usage_history
            .iter()
            .map(|r| pricing.cost(&r.model, &r.usage))
            .sum()
    }

    /// Heuristic estimate of how many input tokens this conversation
    /// would consume on the next request.
    ///
    /// Uses a fast local approximation (~4 characters per token), summed
    /// across the system prompt, all messages, and tool definitions.
    /// Adequate for compaction decisions; for exact billing-quality
    /// numbers call `count_tokens` via the API.
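    ///
    /// # Example
    ///
    /// A worked instance of the heuristic: 11 characters round up to 3
    /// tokens, plus ~4 tokens of per-message role overhead:
    ///
    /// ```
    /// use claude_api::conversation::Conversation;
    /// use claude_api::types::ModelId;
    ///
    /// let mut c = Conversation::new(ModelId::SONNET_4_6, 256);
    /// c.push_user("hello there"); // ceil(11 / 4) + 4 = 7
    /// assert_eq!(c.estimate_input_tokens(), 7);
    /// ```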
    #[must_use]
    pub fn estimate_input_tokens(&self) -> u32 {
        let mut total = 0u32;
        if let Some(s) = &self.system {
            total = total.saturating_add(estimate_system_tokens(s));
        }
        for msg in &self.messages {
            total = total.saturating_add(estimate_message_tokens(msg));
        }
        // Each tool's schema serialized to JSON.
        for tool in &self.tools {
            if let Ok(s) = serde_json::to_string(tool) {
                total = total.saturating_add(estimate_text_tokens(&s));
            }
        }
        total
    }

    /// Number of complete user→assistant roundtrips in the history.
    /// A "complete" roundtrip ends with an Assistant turn that has no
    /// outstanding `tool_use` blocks and is not the most recent message.
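    ///
    /// # Example
    ///
    /// ```
    /// use claude_api::conversation::Conversation;
    /// use claude_api::types::ModelId;
    ///
    /// let mut c = Conversation::new(ModelId::SONNET_4_6, 64);
    /// c.push_user("u1");
    /// c.push_assistant("a1"); // closes roundtrip 1
    /// c.push_user("u2"); // trailing partial turn, not counted
    /// assert_eq!(c.complete_roundtrip_count(), 1);
    /// ```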
    #[must_use]
    pub fn complete_roundtrip_count(&self) -> usize {
        let last_idx = self.messages.len().saturating_sub(1);
        self.messages
            .iter()
            .enumerate()
            .filter(|(i, m)| *i < last_idx && m.role == Role::Assistant && !message_has_tool_use(m))
            .count()
    }

    /// If a [`ContextCompactionPolicy`] is set and the estimated input
    /// exceeds the configured budget, drop oldest complete roundtrips
    /// until either the estimate fits or only `keep_recent_turns` complete
    /// roundtrips remain.
    ///
    /// Tool-use / tool-result pairs are preserved as a unit. Returns
    /// `true` if any messages were dropped.
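    ///
    /// # Example
    ///
    /// A sketch forcing compaction with an artificially tiny budget:
    ///
    /// ```
    /// use claude_api::conversation::{ContextCompactionPolicy, Conversation};
    /// use claude_api::types::ModelId;
    ///
    /// let mut policy = ContextCompactionPolicy::default();
    /// policy.max_input_tokens = 10; // far below any real history
    /// policy.keep_recent_turns = 1;
    ///
    /// let mut c = Conversation::new(ModelId::SONNET_4_6, 64).with_compaction(policy);
    /// for i in 0..4 {
    ///     c.push_user(format!("question number {i}, padded out a bit"));
    ///     c.push_assistant(format!("answer number {i}, padded out a bit"));
    /// }
    /// assert!(c.compact_if_needed()); // oldest roundtrips were dropped
    /// assert_eq!(c.complete_roundtrip_count(), 1);
    /// ```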
    pub fn compact_if_needed(&mut self) -> bool {
        let Some(policy) = self.compaction.clone() else {
            return false;
        };

        let initial = self.estimate_input_tokens();
        if initial <= policy.max_input_tokens {
            return false;
        }

        let initial_msg_count = self.messages.len();
        loop {
            if self.estimate_input_tokens() <= policy.max_input_tokens {
                break;
            }
            if self.complete_roundtrip_count() <= policy.keep_recent_turns {
                break;
            }
            if !self.drop_oldest_roundtrip() {
                break;
            }
        }

        let dropped = initial_msg_count - self.messages.len();
        if dropped > 0 {
            tracing::warn!(
                initial_estimate = initial,
                final_estimate = self.estimate_input_tokens(),
                messages_dropped = dropped,
                roundtrips_remaining = self.complete_roundtrip_count(),
                "claude-api: context compaction applied",
            );
            true
        } else {
            false
        }
    }

    /// Internal: drop everything from index 0 through the first
    /// "end-of-roundtrip" assistant message (inclusive). Returns false
    /// if there is no complete roundtrip to drop without breaking
    /// tool-use/tool-result pair integrity.
    fn drop_oldest_roundtrip(&mut self) -> bool {
        let last_idx = self.messages.len().saturating_sub(1);
        let drop_to = self.messages.iter().enumerate().position(|(i, m)| {
            i < last_idx && m.role == Role::Assistant && !message_has_tool_use(m)
        });
        match drop_to {
            Some(idx) => {
                self.messages.drain(0..=idx);
                true
            }
            None => false,
        }
    }

    /// Build the [`CreateMessageRequest`] this conversation would send next,
    /// including any auto-cache breakpoints. Pure -- does not touch state.
    ///
    /// # Panics
    ///
    /// Will not panic in practice: the conversation always carries `model`
    /// and `max_tokens`, so the inner builder's `build()` always succeeds.
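    ///
    /// # Example
    ///
    /// A sketch of the auto-cache rewrite, inspecting the outgoing shape
    /// via `serde_json`:
    ///
    /// ```
    /// use claude_api::conversation::{AutoCacheMode, Conversation};
    /// use claude_api::types::ModelId;
    ///
    /// let mut c = Conversation::new(ModelId::SONNET_4_6, 128)
    ///     .system("be brief")
    ///     .with_auto_cache(AutoCacheMode::System);
    /// c.push_user("hi");
    ///
    /// let v = serde_json::to_value(c.build_request()).unwrap();
    /// assert_eq!(v["system"][0]["cache_control"]["type"], "ephemeral");
    /// // Only the request is rewritten; `c.system` is still plain text.
    /// ```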
    #[must_use]
    pub fn build_request(&self) -> CreateMessageRequest {
        let mut messages = self.messages.clone();
        let mut system = self.system.clone();

        match self.auto_cache {
            AutoCacheMode::Off => {}
            AutoCacheMode::System => {
                cache_breakpoint_on_system(&mut system);
            }
            AutoCacheMode::SystemAndLastUser => {
                cache_breakpoint_on_system(&mut system);
                cache_breakpoint_on_last_user(&mut messages);
            }
        }

        let mut builder = CreateMessageRequest::builder()
            .model(self.model.clone())
            .max_tokens(self.max_tokens)
            .messages(messages);

        if let Some(s) = system {
            builder = builder.system(s);
        }
        if let Some(t) = self.temperature {
            builder = builder.temperature(t);
        }
        if let Some(p) = self.top_p {
            builder = builder.top_p(p);
        }
        if let Some(k) = self.top_k {
            builder = builder.top_k(k);
        }
        if let Some(seqs) = &self.stop_sequences {
            builder = builder.stop_sequences(seqs.clone());
        }
        if !self.tools.is_empty() {
            builder = builder.tools(self.tools.clone());
        }
        if let Some(c) = self.tool_choice.clone() {
            builder = builder.tool_choice(c);
        }
        if let Some(t) = self.thinking {
            builder = builder.thinking(t);
        }
        if let Some(m) = self.metadata.clone() {
            builder = builder.metadata(m);
        }
        if let Some(t) = self.service_tier {
            builder = builder.service_tier(t);
        }
        if !self.mcp_servers.is_empty() {
            builder = builder.mcp_servers(self.mcp_servers.clone());
        }
        if let Some(c) = self.container.clone() {
            builder = builder.container(c);
        }

        builder
            .build()
            .expect("conversation::build_request always provides model + max_tokens")
    }

    /// Drive one turn against the API. Appends the assistant response to
    /// the history and records the usage.
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    pub async fn send(&mut self, client: &Client) -> Result<Message> {
        self.send_with_beta(client, &[]).await
    }

    /// Like [`Self::send`] but with per-request beta headers merged in.
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    pub async fn send_with_beta(&mut self, client: &Client, betas: &[&str]) -> Result<Message> {
        self.compact_if_needed();
        let request = self.build_request();
        let response = client.messages().create_with_beta(request, betas).await?;
        self.usage_history.push(UsageRecord {
            model: self.model.clone(),
            usage: response.usage.clone(),
        });
        // Append the assistant turn so subsequent sends see it.
        self.messages
            .push(MessageInput::assistant(response.content.clone()));
        Ok(response)
    }
}

// ---- Token estimation helpers -----------------------------------------------

fn estimate_text_tokens(s: &str) -> u32 {
    // Anthropic averages ~3.5-4 chars/token for English. Round up so we
    // err on the conservative (over-estimating) side; better to compact
    // a turn early than to overshoot the model's real budget.
    let chars = u32::try_from(s.chars().count()).unwrap_or(u32::MAX);
    chars.div_ceil(4)
}

fn estimate_system_tokens(s: &SystemPrompt) -> u32 {
    match s {
        SystemPrompt::Text(t) => estimate_text_tokens(t),
        SystemPrompt::Blocks(blocks) => blocks.iter().map(estimate_block_tokens).sum(),
    }
}

fn estimate_message_tokens(msg: &MessageInput) -> u32 {
    // ~4 tokens of role overhead per message (heuristic; varies in practice).
    let body = match &msg.content {
        MessageContent::Text(s) => estimate_text_tokens(s),
        MessageContent::Blocks(blocks) => blocks.iter().map(estimate_block_tokens).sum(),
    };
    body.saturating_add(4)
}

fn estimate_block_tokens(block: &ContentBlock) -> u32 {
    use crate::messages::content::ToolResultContent;

    match block {
        ContentBlock::Known(KnownBlock::Text { text, .. }) => estimate_text_tokens(text),
        ContentBlock::Known(KnownBlock::Thinking { thinking, .. }) => {
            estimate_text_tokens(thinking)
        }
        ContentBlock::Known(KnownBlock::ToolUse { name, input, .. }) => {
            // name + JSON-stringified input.
            estimate_text_tokens(name).saturating_add(estimate_text_tokens(&input.to_string()))
        }
        ContentBlock::Known(KnownBlock::ServerToolUse { name, input, .. }) => {
            estimate_text_tokens(name).saturating_add(estimate_text_tokens(&input.to_string()))
        }
        ContentBlock::Known(KnownBlock::ToolResult { content, .. }) => match content {
            ToolResultContent::Text(s) => estimate_text_tokens(s),
            ToolResultContent::Blocks(b) => b.iter().map(estimate_block_tokens).sum(),
        },
        // Images, documents, web_search results: significant per-asset cost
        // not derivable from JSON length alone. Use a flat rough estimate so
        // compaction kicks in even when the conversation is image-heavy.
        ContentBlock::Known(KnownBlock::Image { .. }) => 1500,
        ContentBlock::Known(KnownBlock::Document { .. }) => 2000,
        ContentBlock::Known(KnownBlock::WebSearchToolResult { .. }) => 500,
        ContentBlock::Known(KnownBlock::RedactedThinking { data, .. }) => {
            estimate_text_tokens(data)
        }
        ContentBlock::Other(v) => estimate_text_tokens(&v.to_string()),
    }
}

fn message_has_tool_use(msg: &MessageInput) -> bool {
    match &msg.content {
        MessageContent::Text(_) => false,
        MessageContent::Blocks(blocks) => blocks.iter().any(|b| {
            matches!(
                b,
                ContentBlock::Known(KnownBlock::ToolUse { .. } | KnownBlock::ServerToolUse { .. })
            )
        }),
    }
}

fn sum_opt(a: Option<u32>, b: Option<u32>) -> Option<u32> {
    match (a, b) {
        (None, None) => None,
        (Some(x), None) | (None, Some(x)) => Some(x),
        (Some(x), Some(y)) => Some(x.saturating_add(y)),
    }
}

fn cache_breakpoint_on_system(system: &mut Option<SystemPrompt>) {
    let Some(s) = system.take() else { return };
    let blocks = match s {
        SystemPrompt::Text(text) => vec![ContentBlock::Known(KnownBlock::Text {
            text,
            cache_control: Some(CacheControl::ephemeral()),
            citations: None,
        })],
        SystemPrompt::Blocks(mut blocks) => {
            apply_cache_control_to_last_block(&mut blocks);
            blocks
        }
    };
    *system = Some(SystemPrompt::Blocks(blocks));
}

fn cache_breakpoint_on_last_user(messages: &mut [MessageInput]) {
    let Some(idx) = messages.iter().rposition(|m| m.role == Role::User) else {
        return;
    };
    let target = &mut messages[idx];
    match &mut target.content {
        MessageContent::Text(text) => {
            target.content = MessageContent::Blocks(vec![ContentBlock::Known(KnownBlock::Text {
                text: std::mem::take(text),
                cache_control: Some(CacheControl::ephemeral()),
                citations: None,
            })]);
        }
        MessageContent::Blocks(blocks) => {
            apply_cache_control_to_last_block(blocks);
        }
    }
}

fn apply_cache_control_to_last_block(blocks: &mut [ContentBlock]) {
    let Some(last) = blocks.last_mut() else {
        return;
    };
    // Collapsed `if let ... { match ... }` into a single nested pattern.
    // Variants without a `cache_control` field (ToolUse, Thinking,
    // RedactedThinking, ServerToolUse, WebSearchToolResult) and
    // `ContentBlock::Other` simply don't match -- the cache hint is silently
    // skipped, which is the right behavior for an auto-cache helper.
    if let ContentBlock::Known(
        KnownBlock::Text { cache_control, .. }
        | KnownBlock::Image { cache_control, .. }
        | KnownBlock::Document { cache_control, .. }
        | KnownBlock::ToolResult { cache_control, .. },
    ) = last
    {
        *cache_control = Some(CacheControl::ephemeral());
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    fn convo() -> Conversation {
        Conversation::new(ModelId::SONNET_4_6, 256)
    }

    // ---- basic state + serde -----------------------------------------------

    #[test]
    fn new_starts_empty() {
        let c = convo();
        assert!(c.messages.is_empty());
        assert!(c.usage_history.is_empty());
        assert_eq!(c.turn_count(), 0);
    }

    #[test]
    fn push_appends_to_history() {
        let mut c = convo();
        c.push_user("hi");
        c.push_assistant("hello");
        c.push_user("how are you?");
        assert_eq!(c.messages.len(), 3);
        assert_eq!(c.messages[0].role, Role::User);
        assert_eq!(c.messages[1].role, Role::Assistant);
    }

    #[test]
    fn pop_removes_last() {
        let mut c = convo();
        c.push_user("first");
        c.push_user("second");
        let popped = c.pop().unwrap();
        let MessageContent::Text(t) = popped.content else {
            panic!("expected Text content");
        };
        assert_eq!(t, "second");
        assert_eq!(c.messages.len(), 1);
    }

    #[test]
    fn cumulative_usage_sums_across_turns() {
        let mut c = convo();
        c.usage_history.push(UsageRecord {
            model: ModelId::SONNET_4_6,
            usage: Usage {
                input_tokens: 100,
                output_tokens: 50,
                cache_creation_input_tokens: Some(20),
                cache_read_input_tokens: Some(30),
                ..Usage::default()
            },
        });
        c.usage_history.push(UsageRecord {
            model: ModelId::SONNET_4_6,
            usage: Usage {
                input_tokens: 200,
                output_tokens: 80,
                cache_read_input_tokens: Some(70),
                ..Usage::default()
            },
        });
        let total = c.cumulative_usage();
        assert_eq!(total.input_tokens, 300);
        assert_eq!(total.output_tokens, 130);
        assert_eq!(total.cache_creation_input_tokens, Some(20));
        assert_eq!(total.cache_read_input_tokens, Some(100));
    }

    #[test]
    fn serde_round_trip_preserves_state() {
        let mut original = Conversation::new(ModelId::OPUS_4_7, 512)
            .system("be concise")
            .with_cache_breakpoint_on_system()
            .with_temperature(0.5);
        original.push_user("hi");
        original.push_assistant("hello");
        original.usage_history.push(UsageRecord {
            model: ModelId::OPUS_4_7,
            usage: Usage {
                input_tokens: 5,
                output_tokens: 3,
                ..Usage::default()
            },
        });

        let json = serde_json::to_string(&original).unwrap();
        let parsed: Conversation = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.model, ModelId::OPUS_4_7);
        assert_eq!(parsed.max_tokens, 512);
        assert_eq!(parsed.auto_cache, AutoCacheMode::System);
        assert_eq!(parsed.temperature, Some(0.5));
        assert_eq!(parsed.messages.len(), 2);
        assert_eq!(parsed.usage_history.len(), 1);
        assert_eq!(parsed.turn_count(), 1);
    }

    // ---- request building --------------------------------------------------

    #[test]
    fn build_request_includes_basic_fields() {
        let mut c = convo().system("be concise").with_temperature(0.25);
        c.push_user("hello");
        let req = c.build_request();
        let v = serde_json::to_value(&req).unwrap();
        assert_eq!(v["model"], "claude-sonnet-4-6");
        assert_eq!(v["max_tokens"], 256);
        assert_eq!(v["system"], "be concise");
        assert_eq!(v["temperature"], 0.25);
        assert_eq!(v["messages"][0]["role"], "user");
    }

    #[test]
    fn build_request_with_auto_cache_system() {
        let mut c = convo()
            .system("you are concise")
            .with_cache_breakpoint_on_system();
        c.push_user("hi");
        let v = serde_json::to_value(c.build_request()).unwrap();
        assert_eq!(
            v["system"],
            json!([{
                "type": "text",
                "text": "you are concise",
                "cache_control": {"type": "ephemeral"}
            }])
        );
        // Last user message should NOT be cached in this mode.
        assert_eq!(v["messages"][0]["content"], "hi");
    }

    #[test]
    fn build_request_with_auto_cache_system_and_last_user() {
        let mut c = convo()
            .system("you are concise")
            .with_auto_cache(AutoCacheMode::SystemAndLastUser);
        c.push_user("first");
        c.push_assistant("response");
        c.push_user("follow-up");
        let v = serde_json::to_value(c.build_request()).unwrap();

        // System cached
        assert_eq!(v["system"][0]["cache_control"]["type"], "ephemeral");

        // Last user (index 2) cached as a single text block
        let msgs = v["messages"].as_array().unwrap();
        assert_eq!(msgs.len(), 3);
        assert_eq!(msgs[2]["role"], "user");
        assert_eq!(msgs[2]["content"][0]["type"], "text");
        assert_eq!(msgs[2]["content"][0]["text"], "follow-up");
        assert_eq!(msgs[2]["content"][0]["cache_control"]["type"], "ephemeral");

        // Earlier user message (index 0) untouched.
        assert_eq!(msgs[0]["content"], "first");
    }

    #[test]
    fn build_request_auto_cache_off_does_nothing() {
        let mut c = convo().system("plain");
        c.push_user("hi");
        let v = serde_json::to_value(c.build_request()).unwrap();
        // System remains a plain string.
        assert_eq!(v["system"], "plain");
        // User message remains a plain string.
        assert_eq!(v["messages"][0]["content"], "hi");
    }

    #[test]
    fn build_request_does_not_mutate_self() {
        let mut c = convo().system("orig").with_cache_breakpoint_on_system();
        c.push_user("hi");
        let _ = c.build_request();
        // After build, the conversation's stored system is still the plain
        // text -- auto-cache is applied at request-build time, not stored.
        let Some(SystemPrompt::Text(t)) = &c.system else {
            panic!("system should still be Text, got {:?}", c.system);
        };
        assert_eq!(t, "orig");
        let MessageContent::Text(t) = &c.messages[0].content else {
            panic!(
                "user content should still be Text, got {:?}",
                c.messages[0].content
            );
        };
        assert_eq!(t, "hi");
    }

    // ---- compaction --------------------------------------------------------

    #[test]
    fn estimate_input_tokens_grows_with_message_size() {
        let mut c = convo();
        c.push_user("hi");
        let small = c.estimate_input_tokens();

        let mut c2 = convo();
        c2.push_user("a".repeat(1000));
        let large = c2.estimate_input_tokens();

        assert!(large > small * 10, "{large} should dwarf {small}");
    }
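
    // Pins the documented rounding behavior of the ~4 chars/token heuristic:
    // partial chunks round up, so short strings still cost at least one token.
    #[test]
    fn estimate_text_tokens_rounds_up() {
        assert_eq!(estimate_text_tokens(""), 0);
        assert_eq!(estimate_text_tokens("abcd"), 1); // 4 chars -> exactly 1
        assert_eq!(estimate_text_tokens("abcde"), 2); // 5 chars -> rounds up
    }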

    #[test]
    fn compact_if_needed_no_op_without_policy() {
        let mut c = convo();
        for i in 0..10 {
            c.push_user(format!("user {i}"));
            c.push_assistant(format!("assistant {i}"));
        }
        let before = c.messages.len();
        assert!(!c.compact_if_needed());
        assert_eq!(c.messages.len(), before);
    }

    #[test]
    fn compact_if_needed_no_op_when_under_threshold() {
        let mut c = convo().with_compaction(ContextCompactionPolicy {
            max_input_tokens: 100_000, // huge threshold
            keep_recent_turns: 1,
        });
        c.push_user("short");
        c.push_assistant("short");
        assert!(!c.compact_if_needed());
        assert_eq!(c.messages.len(), 2);
    }

    #[test]
    fn compact_if_needed_drops_oldest_roundtrips_above_threshold() {
        // Tight budget so compaction must fire. Each message below is
        // ~13-14 tokens of text plus 4 tokens of role overhead.
        let mut c = convo().with_compaction(ContextCompactionPolicy {
            max_input_tokens: 60,
            keep_recent_turns: 1,
        });
        for i in 0..6 {
            c.push_user(format!(
                "this is user message number {i} with reasonable length"
            ));
            c.push_assistant(format!(
                "this is assistant response number {i} with similar length"
            ));
        }
        // Add a trailing user (the "next question") so we have a partial roundtrip.
        c.push_user("current question");

        let before_count = c.messages.len();
        assert!(c.compact_if_needed(), "should have compacted");
        assert!(
            c.messages.len() < before_count,
            "expected drop; got {} -> {}",
            before_count,
            c.messages.len()
        );
        // Most recent messages preserved.
        let MessageContent::Text(last_user) = &c.messages.last().unwrap().content else {
            panic!("expected text");
        };
        assert_eq!(last_user, "current question");
    }

    #[test]
    fn compact_if_needed_respects_keep_recent_turns() {
        // Even if over threshold, we must keep at least N complete roundtrips.
        let mut c = convo().with_compaction(ContextCompactionPolicy {
            max_input_tokens: 1, // impossibly tight
            keep_recent_turns: 2,
        });
        for i in 0..5 {
            c.push_user(format!("u{i}"));
            c.push_assistant(format!("a{i}"));
        }
        c.push_user("trailing");

        c.compact_if_needed();
        // Should have exactly 2 complete roundtrips remaining + the trailing user.
        assert_eq!(c.complete_roundtrip_count(), 2);
        let MessageContent::Text(last) = &c.messages.last().unwrap().content else {
            panic!("expected text");
        };
        assert_eq!(last, "trailing");
    }

    #[test]
    fn compact_if_needed_preserves_tool_use_tool_result_pairs() {
        use crate::messages::content::{ContentBlock, KnownBlock, ToolResultContent};
        use serde_json::json;

        let mut c = convo().with_compaction(ContextCompactionPolicy {
            max_input_tokens: 30,
            keep_recent_turns: 0, // free to drop everything droppable
        });

        // Roundtrip 1: simple
        c.push_user("first user".repeat(20)); // padded to push estimate up
        c.push_assistant("first answer".repeat(20));

        // Roundtrip 2: tool sequence
        c.push_user("second user".repeat(20));
        c.messages.push(MessageInput::assistant(vec![
            ContentBlock::text("calling tool"),
            ContentBlock::Known(KnownBlock::ToolUse {
                id: "toolu_1".into(),
                name: "fn".into(),
                input: json!({}),
            }),
        ]));
        c.messages.push(MessageInput::user(vec![ContentBlock::Known(
            KnownBlock::ToolResult {
                tool_use_id: "toolu_1".into(),
                content: ToolResultContent::Text("result".into()),
                is_error: None,
                cache_control: None,
            },
        )]));
        c.push_assistant("here is the answer".repeat(20));

        // Trailing user.
        c.push_user("final");

        c.compact_if_needed();

        // After compaction, no tool_use should be left without its tool_result.
        for (i, m) in c.messages.iter().enumerate() {
            if message_has_tool_use(m) {
                assert!(
                    i + 1 < c.messages.len(),
                    "tool_use at index {i} must be followed by a tool_result"
                );
                let next = &c.messages[i + 1];
                let MessageContent::Blocks(blocks) = &next.content else {
                    panic!("expected blocks");
                };
                assert!(
                    blocks
                        .iter()
                        .any(|b| matches!(b, ContentBlock::Known(KnownBlock::ToolResult { .. }))),
                    "next message after tool_use must contain tool_result"
                );
            }
        }
    }

    #[test]
    fn drop_oldest_roundtrip_returns_false_when_only_partial_remains() {
        let mut c = convo();
        c.push_user("only user, no assistant yet");
        // No complete roundtrip; can't drop.
        assert!(!c.drop_oldest_roundtrip());
        assert_eq!(c.messages.len(), 1);
    }

    #[test]
    fn complete_roundtrip_count_excludes_trailing_partial() {
        let mut c = convo();
        c.push_user("u1");
        c.push_assistant("a1");
        c.push_user("u2");
        c.push_assistant("a2");
        c.push_user("u3"); // trailing partial
        assert_eq!(c.complete_roundtrip_count(), 2);
    }

    #[test]
    fn complete_roundtrip_count_skips_assistant_with_tool_use() {
        use crate::messages::content::{ContentBlock, KnownBlock};
        use serde_json::json;

        let mut c = convo();
        c.push_user("u1");
        c.messages
            .push(MessageInput::assistant(vec![ContentBlock::Known(
                KnownBlock::ToolUse {
                    id: "t".into(),
                    name: "fn".into(),
                    input: json!({}),
                },
            )]));
        // The assistant turn has tool_use; not the end of a roundtrip.
        // Without a follow-up, complete count is 0.
        assert_eq!(c.complete_roundtrip_count(), 0);
    }

    // ---- pricing integration -----------------------------------------------

    #[cfg(feature = "pricing")]
    #[test]
    fn cost_uses_pricing_table_per_turn_model() {
        let pricing = crate::pricing::PricingTable::default();
        let mut c = convo();
        c.usage_history.push(UsageRecord {
            model: ModelId::SONNET_4_6,
            usage: Usage {
                input_tokens: 1_000_000,
                ..Usage::default()
            },
        });
        c.usage_history.push(UsageRecord {
            model: ModelId::HAIKU_4_5,
            usage: Usage {
                input_tokens: 1_000_000,
                ..Usage::default()
            },
        });
        // Sonnet 4.6 = $3/MTok input, Haiku 4.5 = $1/MTok input -> $4.00 total.
        let total = c.cost(&pricing);
        assert!((total - 4.0).abs() < 1e-9, "expected $4.00, got ${total}");
    }

    #[cfg(feature = "pricing")]
    #[test]
    fn cost_routes_through_cache_creation_and_read_pricing() {
        // Regression test: verify Conversation::cost picks up the
        // separate cache_creation / cache_read pricing fields. A
        // cache-heavy turn that drops these would under-report cost by
        // up to ~90% (cache reads are 0.1x input rate).
        use crate::types::CacheCreationBreakdown;
        let pricing = crate::pricing::PricingTable::default();
        let mut c = convo();
        c.usage_history.push(UsageRecord {
            model: ModelId::SONNET_4_6,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
                cache_creation: Some(CacheCreationBreakdown {
                    ephemeral_5m_input_tokens: 1_000_000,
                    ephemeral_1h_input_tokens: 1_000_000,
                }),
                cache_read_input_tokens: Some(1_000_000),
                ..Usage::default()
            },
        });

        // Sonnet 4.6 input rate = $3/MTok. Cache rates derived:
        //   5m create = 1.25x = $3.75/MTok -> $3.75
        //   1h create = 2.0x  = $6.00/MTok -> $6.00
        //   read      = 0.1x  = $0.30/MTok -> $0.30
        // Sum = $10.05.
        let total = c.cost(&pricing);
        assert!(
            (total - 10.05).abs() < 1e-9,
            "expected $10.05 from cache pricing, got ${total} \
             -- if this dropped to ~$0 the cache fields aren't being read",
        );
    }

    #[cfg(feature = "pricing")]
    #[test]
    fn cost_routes_through_server_tool_use_charges() {
        // Regression test: web_search_requests should bill per-request,
        // not get silently dropped. Pairs with the cache test above.
        use crate::types::ServerToolUseUsage;
        let pricing = crate::pricing::PricingTable::default();
        let mut c = convo();
        c.usage_history.push(UsageRecord {
            model: ModelId::SONNET_4_6,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
                server_tool_use: Some(ServerToolUseUsage {
                    web_search_requests: 5,
                }),
                ..Usage::default()
            },
        });
        // Default web_search rate = $0.01/request -> $0.05.
        let total = c.cost(&pricing);
        assert!(
            (total - 0.05).abs() < 1e-9,
            "expected $0.05 from 5 web searches, got ${total}",
        );
    }
}

#[cfg(all(test, feature = "async"))]
mod api_tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use wiremock::matchers::{body_partial_json, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    fn client_for(mock: &MockServer) -> Client {
        Client::builder()
            .api_key("sk-ant-test")
            .base_url(mock.uri())
            .build()
            .unwrap()
    }

    fn fake_response(text: &str, input: u32, output: u32) -> serde_json::Value {
        json!({
            "id": "msg_x",
            "type": "message",
            "role": "assistant",
            "content": [{"type": "text", "text": text}],
            "model": "claude-sonnet-4-6",
            "stop_reason": "end_turn",
            "usage": {"input_tokens": input, "output_tokens": output}
        })
    }

    #[tokio::test]
    async fn send_appends_assistant_turn_and_records_usage() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response("hi back", 5, 2)))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut c = Conversation::new(ModelId::SONNET_4_6, 64);
        c.push_user("hi");

        let r = c.send(&client).await.unwrap();
        assert_eq!(r.id, "msg_x");

        // History now has user + assistant.
        assert_eq!(c.messages.len(), 2);
        assert_eq!(c.messages[1].role, Role::Assistant);

        // Usage was recorded with the conversation's model.
        assert_eq!(c.turn_count(), 1);
        assert_eq!(c.usage_history[0].model, ModelId::SONNET_4_6);
        assert_eq!(c.usage_history[0].usage.input_tokens, 5);
        assert_eq!(c.usage_history[0].usage.output_tokens, 2);
    }

    #[tokio::test]
    async fn second_send_includes_first_assistant_turn_in_history() {
        let mock = MockServer::start().await;
        // First call -- any user prompt OK.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response("first", 5, 3)))
            .up_to_n_times(1)
            .mount(&mock)
            .await;
        // Second call must contain the first assistant turn AND the new user turn.
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(body_partial_json(json!({
                "messages": [
                    {"role": "user", "content": "hi"},
                    {"role": "assistant", "content": [{"type": "text", "text": "first"}]},
                    {"role": "user", "content": "again"}
                ]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response("second", 8, 4)))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut c = Conversation::new(ModelId::SONNET_4_6, 64);
        c.push_user("hi");
        let _ = c.send(&client).await.unwrap();
        c.push_user("again");
        let _ = c.send(&client).await.unwrap();

        assert_eq!(c.turn_count(), 2);
        let total = c.cumulative_usage();
        assert_eq!(total.input_tokens, 13);
        assert_eq!(total.output_tokens, 7);
    }

    #[tokio::test]
    async fn auto_cache_system_sends_cache_control_in_request_body() {
        let mock = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(body_partial_json(json!({
                "system": [{
                    "type": "text",
                    "text": "be concise",
                    "cache_control": {"type": "ephemeral"}
                }]
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(fake_response("ok", 3, 1)))
            .mount(&mock)
            .await;

        let client = client_for(&mock);
        let mut c = Conversation::new(ModelId::SONNET_4_6, 32)
            .system("be concise")
            .with_cache_breakpoint_on_system();
        c.push_user("hello");
        let _ = c.send(&client).await.unwrap();
    }
}
1223}