claude_agent/client/messages/
request.rs

//! Message request and response types.
2
3use serde::{Deserialize, Serialize};
4
5use super::config::{
6    DEFAULT_MAX_TOKENS, EffortLevel, MAX_TOKENS_128K, MIN_MAX_TOKENS, OutputConfig, OutputFormat,
7    ThinkingConfig, TokenValidationError, ToolChoice,
8};
9use super::context::ContextManagement;
10use super::types::{ApiTool, RequestMetadata};
11use crate::types::{Message, SystemPrompt, ToolDefinition, WebFetchTool, WebSearchTool};
12
13#[derive(Debug, Clone, Serialize)]
14pub struct CreateMessageRequest {
15    pub model: String,
16    pub max_tokens: u32,
17    pub messages: Vec<Message>,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub system: Option<SystemPrompt>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub tools: Option<Vec<ApiTool>>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub tool_choice: Option<ToolChoice>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub stream: Option<bool>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub stop_sequences: Option<Vec<String>>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub temperature: Option<f32>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub top_p: Option<f32>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub top_k: Option<u32>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub metadata: Option<RequestMetadata>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub thinking: Option<ThinkingConfig>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub output_format: Option<OutputFormat>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub context_management: Option<ContextManagement>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub output_config: Option<OutputConfig>,
44}
45
46impl CreateMessageRequest {
47    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
48        Self {
49            model: model.into(),
50            max_tokens: DEFAULT_MAX_TOKENS,
51            messages,
52            system: None,
53            tools: None,
54            tool_choice: None,
55            stream: None,
56            stop_sequences: None,
57            temperature: None,
58            top_p: None,
59            top_k: None,
60            metadata: None,
61            thinking: None,
62            output_format: None,
63            context_management: None,
64            output_config: None,
65        }
66    }
67
68    pub fn validate(&self) -> Result<(), TokenValidationError> {
69        if self.max_tokens < MIN_MAX_TOKENS {
70            return Err(TokenValidationError::MaxTokensTooLow {
71                min: MIN_MAX_TOKENS,
72                actual: self.max_tokens,
73            });
74        }
75        if self.max_tokens > MAX_TOKENS_128K {
76            return Err(TokenValidationError::MaxTokensTooHigh {
77                max: MAX_TOKENS_128K,
78                actual: self.max_tokens,
79            });
80        }
81        if let Some(thinking) = &self.thinking {
82            thinking.validate_against_max_tokens(self.max_tokens)?;
83        }
84        Ok(())
85    }
86
87    pub fn requires_128k_beta(&self) -> bool {
88        self.max_tokens > DEFAULT_MAX_TOKENS
89    }
90
91    pub fn with_metadata(mut self, metadata: RequestMetadata) -> Self {
92        self.metadata = Some(metadata);
93        self
94    }
95
96    pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
97        self.system = Some(system.into());
98        self
99    }
100
101    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
102        let api_tools: Vec<ApiTool> = tools.into_iter().map(ApiTool::Custom).collect();
103        self.tools = Some(api_tools);
104        self
105    }
106
107    pub fn with_web_search(mut self, config: WebSearchTool) -> Self {
108        let mut tools = self.tools.unwrap_or_default();
109        tools.push(ApiTool::WebSearch(config));
110        self.tools = Some(tools);
111        self
112    }
113
114    pub fn with_web_fetch(mut self, config: WebFetchTool) -> Self {
115        let mut tools = self.tools.unwrap_or_default();
116        tools.push(ApiTool::WebFetch(config));
117        self.tools = Some(tools);
118        self
119    }
120
121    pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
122        self.tools = Some(tools);
123        self
124    }
125
126    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
127        self.tool_choice = Some(choice);
128        self
129    }
130
131    pub fn with_tool_choice_auto(mut self) -> Self {
132        self.tool_choice = Some(ToolChoice::Auto);
133        self
134    }
135
136    pub fn with_tool_choice_any(mut self) -> Self {
137        self.tool_choice = Some(ToolChoice::Any);
138        self
139    }
140
141    pub fn with_tool_choice_none(mut self) -> Self {
142        self.tool_choice = Some(ToolChoice::None);
143        self
144    }
145
146    pub fn with_required_tool(mut self, name: impl Into<String>) -> Self {
147        self.tool_choice = Some(ToolChoice::tool(name));
148        self
149    }
150
151    pub fn with_stream(mut self) -> Self {
152        self.stream = Some(true);
153        self
154    }
155
156    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
157        self.max_tokens = max_tokens;
158        self
159    }
160
161    pub fn with_model(mut self, model: impl Into<String>) -> Self {
162        self.model = model.into();
163        self
164    }
165
166    pub fn with_temperature(mut self, temperature: f32) -> Self {
167        self.temperature = Some(temperature);
168        self
169    }
170
171    pub fn with_top_p(mut self, top_p: f32) -> Self {
172        self.top_p = Some(top_p);
173        self
174    }
175
176    pub fn with_top_k(mut self, top_k: u32) -> Self {
177        self.top_k = Some(top_k);
178        self
179    }
180
181    pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
182        self.stop_sequences = Some(sequences);
183        self
184    }
185
186    pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
187        self.thinking = Some(config);
188        self
189    }
190
191    pub fn with_extended_thinking(mut self, budget_tokens: u32) -> Self {
192        self.thinking = Some(ThinkingConfig::enabled(budget_tokens));
193        self
194    }
195
196    pub fn with_output_format(mut self, format: OutputFormat) -> Self {
197        self.output_format = Some(format);
198        self
199    }
200
201    pub fn with_json_schema(mut self, schema: serde_json::Value) -> Self {
202        self.output_format = Some(OutputFormat::json_schema(schema));
203        self
204    }
205
206    pub fn with_context_management(mut self, management: ContextManagement) -> Self {
207        self.context_management = Some(management);
208        self
209    }
210
211    pub fn with_effort(mut self, level: EffortLevel) -> Self {
212        self.output_config = Some(OutputConfig::with_effort(level));
213        self
214    }
215
216    pub fn with_output_config(mut self, config: OutputConfig) -> Self {
217        self.output_config = Some(config);
218        self
219    }
220}
221
222impl From<String> for SystemPrompt {
223    fn from(s: String) -> Self {
224        SystemPrompt::Text(s)
225    }
226}
227
228impl From<&str> for SystemPrompt {
229    fn from(s: &str) -> Self {
230        SystemPrompt::Text(s.to_string())
231    }
232}
233
234#[derive(Debug, Clone, Serialize)]
235pub struct CountTokensRequest {
236    pub model: String,
237    pub messages: Vec<Message>,
238    #[serde(skip_serializing_if = "Option::is_none")]
239    pub system: Option<SystemPrompt>,
240    #[serde(skip_serializing_if = "Option::is_none")]
241    pub tools: Option<Vec<ApiTool>>,
242    #[serde(skip_serializing_if = "Option::is_none")]
243    pub thinking: Option<ThinkingConfig>,
244}
245
246impl CountTokensRequest {
247    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
248        Self {
249            model: model.into(),
250            messages,
251            system: None,
252            tools: None,
253            thinking: None,
254        }
255    }
256
257    pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
258        self.system = Some(system.into());
259        self
260    }
261
262    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
263        self.tools = Some(tools.into_iter().map(ApiTool::Custom).collect());
264        self
265    }
266
267    pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
268        self.tools = Some(tools);
269        self
270    }
271
272    pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
273        self.thinking = Some(config);
274        self
275    }
276
277    pub fn from_message_request(request: &CreateMessageRequest) -> Self {
278        Self {
279            model: request.model.clone(),
280            messages: request.messages.clone(),
281            system: request.system.clone(),
282            tools: request.tools.clone(),
283            thinking: request.thinking.clone(),
284        }
285    }
286}
287
288#[derive(Debug, Clone, Deserialize)]
289pub struct CountTokensResponse {
290    pub input_tokens: u32,
291    #[serde(default, skip_serializing_if = "Option::is_none")]
292    pub context_management: Option<CountTokensContextManagement>,
293}
294
295#[derive(Debug, Clone, Default, Deserialize)]
296pub struct CountTokensContextManagement {
297    #[serde(default)]
298    pub original_input_tokens: Option<u32>,
299}
300
#[cfg(test)]
mod tests {
    use super::super::config::MIN_THINKING_BUDGET;
    use super::*;

    #[test]
    fn test_create_request_default_max_tokens() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")]);
        assert_eq!(request.max_tokens, DEFAULT_MAX_TOKENS);
    }

    #[test]
    fn test_create_request() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")])
            .with_max_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(request.model, "claude-sonnet-4-5");
        assert_eq!(request.max_tokens, 1000);
        assert_eq!(request.temperature, Some(0.7));
    }

    #[test]
    fn test_request_serialization_omits_unset_fields() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")]);
        let value = serde_json::to_value(&request).unwrap();

        assert_eq!(value["model"], "claude-sonnet-4-5");
        assert_eq!(value["max_tokens"], DEFAULT_MAX_TOKENS);
        assert!(value["messages"].is_array());
        for key in [
            "system",
            "tools",
            "tool_choice",
            "stream",
            "stop_sequences",
            "temperature",
            "top_p",
            "top_k",
            "metadata",
            "thinking",
            "output_format",
            "context_management",
            "output_config",
        ] {
            assert!(value.get(key).is_none(), "`{}` should be omitted when unset", key);
        }

        let value = serde_json::to_value(request.with_stream().with_temperature(0.5)).unwrap();
        assert_eq!(value["stream"], true);
        assert_eq!(value["temperature"], 0.5_f64);
    }

    #[test]
    fn test_request_validate_valid() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(4000)
            .with_extended_thinking(2000);
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_max_tokens_too_low() {
        // Trivially passes if the minimum is 0, because no u32 lies below it.
        if let Some(below_min) = MIN_MAX_TOKENS.checked_sub(1) {
            let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
                .with_max_tokens(below_min);
            let err = request.validate().unwrap_err();
            assert!(matches!(err, TokenValidationError::MaxTokensTooLow { .. }));
        }
    }

    #[test]
    fn test_request_validate_max_tokens_too_high() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MAX_TOKENS_128K + 1);
        let err = request.validate().unwrap_err();
        assert!(matches!(err, TokenValidationError::MaxTokensTooHigh { .. }));
    }

    #[test]
    fn test_request_validate_thinking_auto_clamp() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_extended_thinking(500);
        assert_eq!(
            request.thinking.as_ref().unwrap().budget(),
            Some(MIN_THINKING_BUDGET)
        );
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_thinking_exceeds_max() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(2000)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        assert!(request.validate().is_ok());

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MIN_THINKING_BUDGET)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        let err = request.validate().unwrap_err();
        assert!(matches!(
            err,
            TokenValidationError::ThinkingBudgetExceedsMaxTokens { .. }
        ));
    }

    #[test]
    fn test_request_requires_128k_beta() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")]);
        assert!(!request.requires_128k_beta());

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(DEFAULT_MAX_TOKENS + 1);
        assert!(request.requires_128k_beta());

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MAX_TOKENS_128K);
        assert!(request.requires_128k_beta());
    }

    #[test]
    fn test_count_tokens_request() {
        let request = CountTokensRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")])
            .with_system("You are a helpful assistant");

        assert_eq!(request.model, "claude-sonnet-4-5");
        assert!(request.system.is_some());
    }

    #[test]
    fn test_count_tokens_from_message_request() {
        let msg_request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_system("System prompt")
            .with_extended_thinking(10000);

        let count_request = CountTokensRequest::from_message_request(&msg_request);

        assert_eq!(count_request.model, msg_request.model);
        assert_eq!(count_request.messages.len(), msg_request.messages.len());
        assert!(count_request.system.is_some());
        assert!(count_request.thinking.is_some());
        assert!(count_request.tools.is_none());
    }

    #[test]
    fn test_count_tokens_response_deserialize() {
        let response: CountTokensResponse =
            serde_json::from_str(r#"{"input_tokens": 42}"#).unwrap();
        assert_eq!(response.input_tokens, 42);
        assert!(response.context_management.is_none());

        let response: CountTokensResponse = serde_json::from_str(
            r#"{"input_tokens": 10, "context_management": {"original_input_tokens": 50}}"#,
        )
        .unwrap();
        assert_eq!(response.input_tokens, 10);
        assert_eq!(
            response
                .context_management
                .and_then(|cm| cm.original_input_tokens),
            Some(50)
        );
    }

    #[test]
    fn test_request_with_effort() {
        let request = CreateMessageRequest::new("claude-opus-4-5", vec![Message::user("Hi")])
            .with_effort(EffortLevel::Medium);
        assert!(request.output_config.is_some());
        assert_eq!(
            request.output_config.unwrap().effort,
            Some(EffortLevel::Medium)
        );
    }

    #[test]
    fn test_request_with_context_management() {
        let mgmt = ContextManagement::new().with_edit(ContextManagement::clear_thinking(2));
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_context_management(mgmt);
        assert!(request.context_management.is_some());
    }

    #[test]
    fn test_request_with_tool_choice() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_tool_choice_any();
        assert_eq!(request.tool_choice, Some(ToolChoice::Any));

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_required_tool("Grep");
        assert_eq!(request.tool_choice, Some(ToolChoice::tool("Grep")));
    }
}