Skip to main content

claude_agent/client/messages/
request.rs

1//! Message request and response types.
2
3use serde::{Deserialize, Serialize};
4
5use super::config::{
6    DEFAULT_MAX_TOKENS, EffortLevel, MAX_TOKENS_128K, MIN_MAX_TOKENS, OutputConfig, OutputFormat,
7    ThinkingConfig, TokenValidationError, ToolChoice,
8};
9use super::context::ContextManagement;
10use super::types::{ApiTool, RequestMetadata};
11use crate::types::{
12    Message, SystemPrompt, ToolDefinition, ToolSearchTool, WebFetchTool, WebSearchTool,
13};
14
15#[derive(Debug, Clone, Serialize)]
16pub struct CreateMessageRequest {
17    pub model: String,
18    pub max_tokens: u32,
19    pub messages: Vec<Message>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub system: Option<SystemPrompt>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub tools: Option<Vec<ApiTool>>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub tool_choice: Option<ToolChoice>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub stream: Option<bool>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub stop_sequences: Option<Vec<String>>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub temperature: Option<f32>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub top_p: Option<f32>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub top_k: Option<u32>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub metadata: Option<RequestMetadata>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub thinking: Option<ThinkingConfig>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub output_format: Option<OutputFormat>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub context_management: Option<ContextManagement>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub output_config: Option<OutputConfig>,
46}
47
48impl CreateMessageRequest {
49    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
50        Self {
51            model: model.into(),
52            max_tokens: DEFAULT_MAX_TOKENS,
53            messages,
54            system: None,
55            tools: None,
56            tool_choice: None,
57            stream: None,
58            stop_sequences: None,
59            temperature: None,
60            top_p: None,
61            top_k: None,
62            metadata: None,
63            thinking: None,
64            output_format: None,
65            context_management: None,
66            output_config: None,
67        }
68    }
69
70    pub fn validate(&self) -> Result<(), TokenValidationError> {
71        if self.max_tokens < MIN_MAX_TOKENS {
72            return Err(TokenValidationError::MaxTokensTooLow {
73                min: MIN_MAX_TOKENS,
74                actual: self.max_tokens,
75            });
76        }
77        if self.max_tokens > MAX_TOKENS_128K {
78            return Err(TokenValidationError::MaxTokensTooHigh {
79                max: MAX_TOKENS_128K,
80                actual: self.max_tokens,
81            });
82        }
83        if let Some(thinking) = &self.thinking {
84            thinking.validate_against_max_tokens(self.max_tokens)?;
85        }
86        Ok(())
87    }
88
89    pub fn requires_128k_beta(&self) -> bool {
90        self.max_tokens > DEFAULT_MAX_TOKENS
91    }
92
93    pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
94        self.system = Some(system.into());
95        self
96    }
97
98    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
99        let api_tools: Vec<ApiTool> = tools.into_iter().map(ApiTool::Custom).collect();
100        self.tools = Some(api_tools);
101        self
102    }
103
104    pub fn with_web_search(mut self, config: WebSearchTool) -> Self {
105        let mut tools = self.tools.unwrap_or_default();
106        tools.push(ApiTool::WebSearch(config));
107        self.tools = Some(tools);
108        self
109    }
110
111    pub fn with_web_fetch(mut self, config: WebFetchTool) -> Self {
112        let mut tools = self.tools.unwrap_or_default();
113        tools.push(ApiTool::WebFetch(config));
114        self.tools = Some(tools);
115        self
116    }
117
118    pub fn with_tool_search(mut self, config: ToolSearchTool) -> Self {
119        let mut tools = self.tools.unwrap_or_default();
120        tools.push(ApiTool::ToolSearch(config));
121        self.tools = Some(tools);
122        self
123    }
124
125    pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
126        self.tools = Some(tools);
127        self
128    }
129
130    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
131        self.tool_choice = Some(choice);
132        self
133    }
134
135    pub fn with_tool_choice_auto(mut self) -> Self {
136        self.tool_choice = Some(ToolChoice::Auto);
137        self
138    }
139
140    pub fn with_tool_choice_any(mut self) -> Self {
141        self.tool_choice = Some(ToolChoice::Any);
142        self
143    }
144
145    pub fn with_tool_choice_none(mut self) -> Self {
146        self.tool_choice = Some(ToolChoice::None);
147        self
148    }
149
150    pub fn with_required_tool(mut self, name: impl Into<String>) -> Self {
151        self.tool_choice = Some(ToolChoice::tool(name));
152        self
153    }
154
155    pub fn with_stream(mut self) -> Self {
156        self.stream = Some(true);
157        self
158    }
159
160    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
161        self.max_tokens = max_tokens;
162        self
163    }
164
165    pub fn with_model(mut self, model: impl Into<String>) -> Self {
166        self.model = model.into();
167        self
168    }
169
170    pub fn with_temperature(mut self, temperature: f32) -> Self {
171        self.temperature = Some(temperature);
172        self
173    }
174
175    pub fn with_top_p(mut self, top_p: f32) -> Self {
176        self.top_p = Some(top_p);
177        self
178    }
179
180    pub fn with_top_k(mut self, top_k: u32) -> Self {
181        self.top_k = Some(top_k);
182        self
183    }
184
185    pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
186        self.stop_sequences = Some(sequences);
187        self
188    }
189
190    pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
191        self.thinking = Some(config);
192        self
193    }
194
195    pub fn with_extended_thinking(mut self, budget_tokens: u32) -> Self {
196        self.thinking = Some(ThinkingConfig::enabled(budget_tokens));
197        self
198    }
199
200    pub fn with_output_format(mut self, format: OutputFormat) -> Self {
201        self.output_format = Some(format);
202        self
203    }
204
205    /// Set JSON schema for structured output.
206    ///
207    /// Automatically transforms the schema for strict mode compatibility:
208    /// - Adds `additionalProperties: false` to all objects
209    /// - Removes unsupported constraints (minimum, maximum, minLength, maxLength, etc.)
210    /// - Ensures `required` fields are present
211    pub fn with_json_schema(mut self, schema: serde_json::Value) -> Self {
212        let strict_schema = crate::client::schema::transform_for_strict(schema);
213        self.output_format = Some(OutputFormat::json_schema(strict_schema));
214        self
215    }
216
217    pub fn with_context_management(mut self, management: ContextManagement) -> Self {
218        self.context_management = Some(management);
219        self
220    }
221
222    pub fn with_effort(mut self, level: EffortLevel) -> Self {
223        self.output_config = Some(OutputConfig::with_effort(level));
224        self
225    }
226
227    pub fn with_output_config(mut self, config: OutputConfig) -> Self {
228        self.output_config = Some(config);
229        self
230    }
231}
232
233impl From<String> for SystemPrompt {
234    fn from(s: String) -> Self {
235        SystemPrompt::Text(s)
236    }
237}
238
239impl From<&str> for SystemPrompt {
240    fn from(s: &str) -> Self {
241        SystemPrompt::Text(s.to_string())
242    }
243}
244
245#[derive(Debug, Clone, Serialize)]
246pub struct CountTokensRequest {
247    pub model: String,
248    pub messages: Vec<Message>,
249    #[serde(skip_serializing_if = "Option::is_none")]
250    pub system: Option<SystemPrompt>,
251    #[serde(skip_serializing_if = "Option::is_none")]
252    pub tools: Option<Vec<ApiTool>>,
253    #[serde(skip_serializing_if = "Option::is_none")]
254    pub thinking: Option<ThinkingConfig>,
255}
256
257impl CountTokensRequest {
258    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
259        Self {
260            model: model.into(),
261            messages,
262            system: None,
263            tools: None,
264            thinking: None,
265        }
266    }
267
268    pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
269        self.system = Some(system.into());
270        self
271    }
272
273    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
274        self.tools = Some(tools.into_iter().map(ApiTool::Custom).collect());
275        self
276    }
277
278    pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
279        self.tools = Some(tools);
280        self
281    }
282
283    pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
284        self.thinking = Some(config);
285        self
286    }
287
288    pub fn from_message_request(request: &CreateMessageRequest) -> Self {
289        Self {
290            model: request.model.clone(),
291            messages: request.messages.clone(),
292            system: request.system.clone(),
293            tools: request.tools.clone(),
294            thinking: request.thinking.clone(),
295        }
296    }
297}
298
299#[derive(Debug, Clone, Deserialize)]
300pub struct CountTokensResponse {
301    pub input_tokens: u32,
302    #[serde(default, skip_serializing_if = "Option::is_none")]
303    pub context_management: Option<CountTokensContextManagement>,
304}
305
306#[derive(Debug, Clone, Default, Deserialize)]
307pub struct CountTokensContextManagement {
308    #[serde(default)]
309    pub original_input_tokens: Option<u32>,
310}
311
#[cfg(test)]
mod tests {
    use super::super::config::MIN_THINKING_BUDGET;
    use super::*;

    /// Model name shared by most tests.
    const SONNET: &str = "claude-sonnet-4-5";

    /// Builds a request for `SONNET` holding a single user message.
    fn sonnet_request(prompt: &'static str) -> CreateMessageRequest {
        CreateMessageRequest::new(SONNET, vec![Message::user(prompt)])
    }

    #[test]
    fn test_create_request_default_max_tokens() {
        assert_eq!(sonnet_request("Hello").max_tokens, DEFAULT_MAX_TOKENS);
    }

    #[test]
    fn test_create_request() {
        let built = sonnet_request("Hello")
            .with_max_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(built.model, SONNET);
        assert_eq!(built.max_tokens, 1000);
        assert_eq!(built.temperature, Some(0.7));
    }

    #[test]
    fn test_request_validate_valid() {
        let outcome = sonnet_request("Hi")
            .with_max_tokens(4000)
            .with_extended_thinking(2000)
            .validate();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_request_validate_max_tokens_too_high() {
        let outcome = sonnet_request("Hi")
            .with_max_tokens(MAX_TOKENS_128K + 1)
            .validate();
        assert!(matches!(
            outcome,
            Err(TokenValidationError::MaxTokensTooHigh { .. })
        ));
    }

    #[test]
    fn test_request_validate_thinking_auto_clamp() {
        // A budget below the minimum is expected to come back as the minimum.
        let built = sonnet_request("Hi").with_extended_thinking(500);
        let thinking = built
            .thinking
            .as_ref()
            .expect("thinking should be configured");

        assert_eq!(thinking.budget(), Some(MIN_THINKING_BUDGET));
        assert!(built.validate().is_ok());
    }

    #[test]
    fn test_request_validate_thinking_exceeds_max() {
        // Budget below `max_tokens`: accepted.
        let accepted = sonnet_request("Hi")
            .with_max_tokens(2000)
            .with_extended_thinking(MIN_THINKING_BUDGET)
            .validate();
        assert!(accepted.is_ok());

        // Budget equal to `max_tokens`: rejected.
        let rejected = sonnet_request("Hi")
            .with_max_tokens(MIN_THINKING_BUDGET)
            .with_extended_thinking(MIN_THINKING_BUDGET)
            .validate();
        assert!(matches!(
            rejected,
            Err(TokenValidationError::ThinkingBudgetExceedsMaxTokens { .. })
        ));
    }

    #[test]
    fn test_request_requires_128k_beta() {
        let at_default = sonnet_request("Hi");
        assert!(!at_default.requires_128k_beta());

        let just_above = sonnet_request("Hi").with_max_tokens(DEFAULT_MAX_TOKENS + 1);
        assert!(just_above.requires_128k_beta());

        let at_ceiling = sonnet_request("Hi").with_max_tokens(MAX_TOKENS_128K);
        assert!(at_ceiling.requires_128k_beta());
    }

    #[test]
    fn test_count_tokens_request() {
        let counted = CountTokensRequest::new(SONNET, vec![Message::user("Hello")])
            .with_system("You are a helpful assistant");

        assert_eq!(counted.model, SONNET);
        assert!(counted.system.is_some());
    }

    #[test]
    fn test_count_tokens_from_message_request() {
        let source = sonnet_request("Hi")
            .with_system("System prompt")
            .with_extended_thinking(10000);

        let counted = CountTokensRequest::from_message_request(&source);

        assert_eq!(counted.model, source.model);
        assert_eq!(counted.messages.len(), source.messages.len());
        assert!(counted.system.is_some());
        assert!(counted.thinking.is_some());
    }

    #[test]
    fn test_request_with_effort() {
        let built = CreateMessageRequest::new("claude-opus-4-5", vec![Message::user("Hi")])
            .with_effort(EffortLevel::Medium);

        let config = built
            .output_config
            .expect("with_effort should set output_config");
        assert_eq!(config.effort, Some(EffortLevel::Medium));
    }

    #[test]
    fn test_request_with_context_management() {
        let management = ContextManagement::new().with_edit(ContextManagement::clear_thinking(2));
        let built = sonnet_request("Hi").with_context_management(management);
        assert!(built.context_management.is_some());
    }

    #[test]
    fn test_request_with_tool_choice() {
        let any_tool = sonnet_request("Hi").with_tool_choice_any();
        assert_eq!(any_tool.tool_choice, Some(ToolChoice::Any));

        let named_tool = sonnet_request("Hi").with_required_tool("Grep");
        assert_eq!(named_tool.tool_choice, Some(ToolChoice::tool("Grep")));
    }
}
440        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
441            .with_required_tool("Grep");
442        assert_eq!(request.tool_choice, Some(ToolChoice::tool("Grep")));
443    }
444}