claude_agent/client/messages/request.rs

1//! Message request and response types.
2
3use serde::{Deserialize, Serialize};
4
5use super::config::{
6    DEFAULT_MAX_TOKENS, EffortLevel, MAX_TOKENS_128K, MIN_MAX_TOKENS, OutputConfig, OutputFormat,
7    ThinkingConfig, TokenValidationError, ToolChoice,
8};
9use super::context::ContextManagement;
10use super::types::{ApiTool, RequestMetadata};
11use crate::types::{
12    Message, SystemPrompt, ToolDefinition, ToolSearchTool, WebFetchTool, WebSearchTool,
13};
14
15#[derive(Debug, Clone, Serialize)]
16pub struct CreateMessageRequest {
17    pub model: String,
18    pub max_tokens: u32,
19    pub messages: Vec<Message>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub system: Option<SystemPrompt>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub tools: Option<Vec<ApiTool>>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub tool_choice: Option<ToolChoice>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub stream: Option<bool>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub stop_sequences: Option<Vec<String>>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub temperature: Option<f32>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub top_p: Option<f32>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub top_k: Option<u32>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub metadata: Option<RequestMetadata>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub thinking: Option<ThinkingConfig>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub output_format: Option<OutputFormat>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub context_management: Option<ContextManagement>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub output_config: Option<OutputConfig>,
46}
47
48impl CreateMessageRequest {
49    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
50        Self {
51            model: model.into(),
52            max_tokens: DEFAULT_MAX_TOKENS,
53            messages,
54            system: None,
55            tools: None,
56            tool_choice: None,
57            stream: None,
58            stop_sequences: None,
59            temperature: None,
60            top_p: None,
61            top_k: None,
62            metadata: None,
63            thinking: None,
64            output_format: None,
65            context_management: None,
66            output_config: None,
67        }
68    }
69
70    pub fn validate(&self) -> Result<(), TokenValidationError> {
71        if self.max_tokens < MIN_MAX_TOKENS {
72            return Err(TokenValidationError::MaxTokensTooLow {
73                min: MIN_MAX_TOKENS,
74                actual: self.max_tokens,
75            });
76        }
77        if self.max_tokens > MAX_TOKENS_128K {
78            return Err(TokenValidationError::MaxTokensTooHigh {
79                max: MAX_TOKENS_128K,
80                actual: self.max_tokens,
81            });
82        }
83        if let Some(thinking) = &self.thinking {
84            thinking.validate_against_max_tokens(self.max_tokens)?;
85        }
86        Ok(())
87    }
88
89    pub fn requires_128k_beta(&self) -> bool {
90        self.max_tokens > DEFAULT_MAX_TOKENS
91    }
92
93    pub fn with_metadata(mut self, metadata: RequestMetadata) -> Self {
94        self.metadata = Some(metadata);
95        self
96    }
97
98    pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
99        self.system = Some(system.into());
100        self
101    }
102
103    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
104        let api_tools: Vec<ApiTool> = tools.into_iter().map(ApiTool::Custom).collect();
105        self.tools = Some(api_tools);
106        self
107    }
108
109    pub fn with_web_search(mut self, config: WebSearchTool) -> Self {
110        let mut tools = self.tools.unwrap_or_default();
111        tools.push(ApiTool::WebSearch(config));
112        self.tools = Some(tools);
113        self
114    }
115
116    pub fn with_web_fetch(mut self, config: WebFetchTool) -> Self {
117        let mut tools = self.tools.unwrap_or_default();
118        tools.push(ApiTool::WebFetch(config));
119        self.tools = Some(tools);
120        self
121    }
122
123    pub fn with_tool_search(mut self, config: ToolSearchTool) -> Self {
124        let mut tools = self.tools.unwrap_or_default();
125        tools.push(ApiTool::ToolSearch(config));
126        self.tools = Some(tools);
127        self
128    }
129
130    pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
131        self.tools = Some(tools);
132        self
133    }
134
135    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
136        self.tool_choice = Some(choice);
137        self
138    }
139
140    pub fn with_tool_choice_auto(mut self) -> Self {
141        self.tool_choice = Some(ToolChoice::Auto);
142        self
143    }
144
145    pub fn with_tool_choice_any(mut self) -> Self {
146        self.tool_choice = Some(ToolChoice::Any);
147        self
148    }
149
150    pub fn with_tool_choice_none(mut self) -> Self {
151        self.tool_choice = Some(ToolChoice::None);
152        self
153    }
154
155    pub fn with_required_tool(mut self, name: impl Into<String>) -> Self {
156        self.tool_choice = Some(ToolChoice::tool(name));
157        self
158    }
159
160    pub fn with_stream(mut self) -> Self {
161        self.stream = Some(true);
162        self
163    }
164
165    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
166        self.max_tokens = max_tokens;
167        self
168    }
169
170    pub fn with_model(mut self, model: impl Into<String>) -> Self {
171        self.model = model.into();
172        self
173    }
174
175    pub fn with_temperature(mut self, temperature: f32) -> Self {
176        self.temperature = Some(temperature);
177        self
178    }
179
180    pub fn with_top_p(mut self, top_p: f32) -> Self {
181        self.top_p = Some(top_p);
182        self
183    }
184
185    pub fn with_top_k(mut self, top_k: u32) -> Self {
186        self.top_k = Some(top_k);
187        self
188    }
189
190    pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
191        self.stop_sequences = Some(sequences);
192        self
193    }
194
195    pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
196        self.thinking = Some(config);
197        self
198    }
199
200    pub fn with_extended_thinking(mut self, budget_tokens: u32) -> Self {
201        self.thinking = Some(ThinkingConfig::enabled(budget_tokens));
202        self
203    }
204
205    pub fn with_output_format(mut self, format: OutputFormat) -> Self {
206        self.output_format = Some(format);
207        self
208    }
209
210    /// Set JSON schema for structured output.
211    ///
212    /// Automatically transforms the schema for strict mode compatibility:
213    /// - Adds `additionalProperties: false` to all objects
214    /// - Removes unsupported constraints (minimum, maximum, minLength, maxLength, etc.)
215    /// - Ensures `required` fields are present
216    pub fn with_json_schema(mut self, schema: serde_json::Value) -> Self {
217        let strict_schema = crate::client::schema::transform_for_strict(schema);
218        self.output_format = Some(OutputFormat::json_schema(strict_schema));
219        self
220    }
221
222    pub fn with_context_management(mut self, management: ContextManagement) -> Self {
223        self.context_management = Some(management);
224        self
225    }
226
227    pub fn with_effort(mut self, level: EffortLevel) -> Self {
228        self.output_config = Some(OutputConfig::with_effort(level));
229        self
230    }
231
232    pub fn with_output_config(mut self, config: OutputConfig) -> Self {
233        self.output_config = Some(config);
234        self
235    }
236}
237
238impl From<String> for SystemPrompt {
239    fn from(s: String) -> Self {
240        SystemPrompt::Text(s)
241    }
242}
243
244impl From<&str> for SystemPrompt {
245    fn from(s: &str) -> Self {
246        SystemPrompt::Text(s.to_string())
247    }
248}
249
250#[derive(Debug, Clone, Serialize)]
251pub struct CountTokensRequest {
252    pub model: String,
253    pub messages: Vec<Message>,
254    #[serde(skip_serializing_if = "Option::is_none")]
255    pub system: Option<SystemPrompt>,
256    #[serde(skip_serializing_if = "Option::is_none")]
257    pub tools: Option<Vec<ApiTool>>,
258    #[serde(skip_serializing_if = "Option::is_none")]
259    pub thinking: Option<ThinkingConfig>,
260}
261
262impl CountTokensRequest {
263    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
264        Self {
265            model: model.into(),
266            messages,
267            system: None,
268            tools: None,
269            thinking: None,
270        }
271    }
272
273    pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
274        self.system = Some(system.into());
275        self
276    }
277
278    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
279        self.tools = Some(tools.into_iter().map(ApiTool::Custom).collect());
280        self
281    }
282
283    pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
284        self.tools = Some(tools);
285        self
286    }
287
288    pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
289        self.thinking = Some(config);
290        self
291    }
292
293    pub fn from_message_request(request: &CreateMessageRequest) -> Self {
294        Self {
295            model: request.model.clone(),
296            messages: request.messages.clone(),
297            system: request.system.clone(),
298            tools: request.tools.clone(),
299            thinking: request.thinking.clone(),
300        }
301    }
302}
303
304#[derive(Debug, Clone, Deserialize)]
305pub struct CountTokensResponse {
306    pub input_tokens: u32,
307    #[serde(default, skip_serializing_if = "Option::is_none")]
308    pub context_management: Option<CountTokensContextManagement>,
309}
310
311#[derive(Debug, Clone, Default, Deserialize)]
312pub struct CountTokensContextManagement {
313    #[serde(default)]
314    pub original_input_tokens: Option<u32>,
315}
316
#[cfg(test)]
mod tests {
    use super::super::config::MIN_THINKING_BUDGET;
    use super::*;

    /// Model used by most tests below.
    const SONNET: &str = "claude-sonnet-4-5";

    /// Builds a Sonnet request holding a single user message.
    fn sonnet_request(text: &'static str) -> CreateMessageRequest {
        CreateMessageRequest::new(SONNET, vec![Message::user(text)])
    }

    #[test]
    fn test_create_request_default_max_tokens() {
        let request = sonnet_request("Hello");
        assert_eq!(request.max_tokens, DEFAULT_MAX_TOKENS);
    }

    #[test]
    fn test_create_request() {
        let request = sonnet_request("Hello")
            .with_max_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(request.model, SONNET);
        assert_eq!(request.max_tokens, 1000);
        assert_eq!(request.temperature, Some(0.7));
    }

    #[test]
    fn test_request_validate_valid() {
        let request = sonnet_request("Hi")
            .with_max_tokens(4000)
            .with_extended_thinking(2000);
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_max_tokens_too_high() {
        let result = sonnet_request("Hi")
            .with_max_tokens(MAX_TOKENS_128K + 1)
            .validate();
        assert!(matches!(
            result,
            Err(TokenValidationError::MaxTokensTooHigh { .. })
        ));
    }

    #[test]
    fn test_request_validate_thinking_auto_clamp() {
        let request = sonnet_request("Hi").with_extended_thinking(500);
        let thinking = request
            .thinking
            .as_ref()
            .expect("with_extended_thinking sets thinking");
        assert_eq!(thinking.budget(), Some(MIN_THINKING_BUDGET));
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_thinking_exceeds_max() {
        let fits = sonnet_request("Hi")
            .with_max_tokens(2000)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        assert!(fits.validate().is_ok());

        let result = sonnet_request("Hi")
            .with_max_tokens(MIN_THINKING_BUDGET)
            .with_extended_thinking(MIN_THINKING_BUDGET)
            .validate();
        assert!(matches!(
            result,
            Err(TokenValidationError::ThinkingBudgetExceedsMaxTokens { .. })
        ));
    }

    #[test]
    fn test_request_requires_128k_beta() {
        assert!(!sonnet_request("Hi").requires_128k_beta());

        for max_tokens in [DEFAULT_MAX_TOKENS + 1, MAX_TOKENS_128K] {
            let request = sonnet_request("Hi").with_max_tokens(max_tokens);
            assert!(request.requires_128k_beta());
        }
    }

    #[test]
    fn test_count_tokens_request() {
        let request = CountTokensRequest::new(SONNET, vec![Message::user("Hello")])
            .with_system("You are a helpful assistant");

        assert_eq!(request.model, SONNET);
        assert!(request.system.is_some());
    }

    #[test]
    fn test_count_tokens_from_message_request() {
        let source = sonnet_request("Hi")
            .with_system("System prompt")
            .with_extended_thinking(10000);

        let counted = CountTokensRequest::from_message_request(&source);

        assert_eq!(counted.model, source.model);
        assert_eq!(counted.messages.len(), source.messages.len());
        assert!(counted.system.is_some());
        assert!(counted.thinking.is_some());
    }

    #[test]
    fn test_request_with_effort() {
        let request = CreateMessageRequest::new("claude-opus-4-5", vec![Message::user("Hi")])
            .with_effort(EffortLevel::Medium);
        let config = request
            .output_config
            .expect("with_effort sets output_config");
        assert_eq!(config.effort, Some(EffortLevel::Medium));
    }

    #[test]
    fn test_request_with_context_management() {
        let mgmt = ContextManagement::new().with_edit(ContextManagement::clear_thinking(2));
        let request = sonnet_request("Hi").with_context_management(mgmt);
        assert!(request.context_management.is_some());
    }

    #[test]
    fn test_request_with_tool_choice() {
        let any = sonnet_request("Hi").with_tool_choice_any();
        assert_eq!(any.tool_choice, Some(ToolChoice::Any));

        let grep = sonnet_request("Hi").with_required_tool("Grep");
        assert_eq!(grep.tool_choice, Some(ToolChoice::tool("Grep")));
    }
}