// claude_agent/client/messages/request.rs

1//! Message request and response types.
2
3use serde::{Deserialize, Serialize};
4
5use super::config::{
6    DEFAULT_MAX_TOKENS, EffortLevel, MAX_TOKENS_128K, MIN_MAX_TOKENS, OutputConfig, OutputFormat,
7    ThinkingConfig, TokenValidationError, ToolChoice,
8};
9use super::context::ContextManagement;
10use super::types::{ApiTool, RequestMetadata};
11use crate::types::{
12    Message, SystemPrompt, ToolDefinition, ToolSearchTool, WebFetchTool, WebSearchTool,
13};
14
15#[derive(Debug, Clone, Serialize)]
16pub struct CreateMessageRequest {
17    pub model: String,
18    pub max_tokens: u32,
19    pub messages: Vec<Message>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub system: Option<SystemPrompt>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub tools: Option<Vec<ApiTool>>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub tool_choice: Option<ToolChoice>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub stream: Option<bool>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub stop_sequences: Option<Vec<String>>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub temperature: Option<f32>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub top_p: Option<f32>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub top_k: Option<u32>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub metadata: Option<RequestMetadata>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub thinking: Option<ThinkingConfig>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub output_format: Option<OutputFormat>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub context_management: Option<ContextManagement>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub output_config: Option<OutputConfig>,
46}
47
48impl CreateMessageRequest {
49    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
50        Self {
51            model: model.into(),
52            max_tokens: DEFAULT_MAX_TOKENS,
53            messages,
54            system: None,
55            tools: None,
56            tool_choice: None,
57            stream: None,
58            stop_sequences: None,
59            temperature: None,
60            top_p: None,
61            top_k: None,
62            metadata: None,
63            thinking: None,
64            output_format: None,
65            context_management: None,
66            output_config: None,
67        }
68    }
69
70    pub fn validate(&self) -> Result<(), TokenValidationError> {
71        if self.max_tokens < MIN_MAX_TOKENS {
72            return Err(TokenValidationError::MaxTokensTooLow {
73                min: MIN_MAX_TOKENS,
74                actual: self.max_tokens,
75            });
76        }
77        if self.max_tokens > MAX_TOKENS_128K {
78            return Err(TokenValidationError::MaxTokensTooHigh {
79                max: MAX_TOKENS_128K,
80                actual: self.max_tokens,
81            });
82        }
83        if let Some(thinking) = &self.thinking {
84            thinking.validate_against_max_tokens(self.max_tokens)?;
85        }
86        Ok(())
87    }
88
89    pub fn requires_128k_beta(&self) -> bool {
90        self.max_tokens > DEFAULT_MAX_TOKENS
91    }
92
93    pub fn with_metadata(mut self, metadata: RequestMetadata) -> Self {
94        self.metadata = Some(metadata);
95        self
96    }
97
98    pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
99        self.system = Some(system.into());
100        self
101    }
102
103    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
104        let api_tools: Vec<ApiTool> = tools.into_iter().map(ApiTool::Custom).collect();
105        self.tools = Some(api_tools);
106        self
107    }
108
109    pub fn with_web_search(mut self, config: WebSearchTool) -> Self {
110        let mut tools = self.tools.unwrap_or_default();
111        tools.push(ApiTool::WebSearch(config));
112        self.tools = Some(tools);
113        self
114    }
115
116    pub fn with_web_fetch(mut self, config: WebFetchTool) -> Self {
117        let mut tools = self.tools.unwrap_or_default();
118        tools.push(ApiTool::WebFetch(config));
119        self.tools = Some(tools);
120        self
121    }
122
123    pub fn with_tool_search(mut self, config: ToolSearchTool) -> Self {
124        let mut tools = self.tools.unwrap_or_default();
125        tools.push(ApiTool::ToolSearch(config));
126        self.tools = Some(tools);
127        self
128    }
129
130    pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
131        self.tools = Some(tools);
132        self
133    }
134
135    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
136        self.tool_choice = Some(choice);
137        self
138    }
139
140    pub fn with_tool_choice_auto(mut self) -> Self {
141        self.tool_choice = Some(ToolChoice::Auto);
142        self
143    }
144
145    pub fn with_tool_choice_any(mut self) -> Self {
146        self.tool_choice = Some(ToolChoice::Any);
147        self
148    }
149
150    pub fn with_tool_choice_none(mut self) -> Self {
151        self.tool_choice = Some(ToolChoice::None);
152        self
153    }
154
155    pub fn with_required_tool(mut self, name: impl Into<String>) -> Self {
156        self.tool_choice = Some(ToolChoice::tool(name));
157        self
158    }
159
160    pub fn with_stream(mut self) -> Self {
161        self.stream = Some(true);
162        self
163    }
164
165    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
166        self.max_tokens = max_tokens;
167        self
168    }
169
170    pub fn with_model(mut self, model: impl Into<String>) -> Self {
171        self.model = model.into();
172        self
173    }
174
175    pub fn with_temperature(mut self, temperature: f32) -> Self {
176        self.temperature = Some(temperature);
177        self
178    }
179
180    pub fn with_top_p(mut self, top_p: f32) -> Self {
181        self.top_p = Some(top_p);
182        self
183    }
184
185    pub fn with_top_k(mut self, top_k: u32) -> Self {
186        self.top_k = Some(top_k);
187        self
188    }
189
190    pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
191        self.stop_sequences = Some(sequences);
192        self
193    }
194
195    pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
196        self.thinking = Some(config);
197        self
198    }
199
200    pub fn with_extended_thinking(mut self, budget_tokens: u32) -> Self {
201        self.thinking = Some(ThinkingConfig::enabled(budget_tokens));
202        self
203    }
204
205    pub fn with_output_format(mut self, format: OutputFormat) -> Self {
206        self.output_format = Some(format);
207        self
208    }
209
210    pub fn with_json_schema(mut self, schema: serde_json::Value) -> Self {
211        self.output_format = Some(OutputFormat::json_schema(schema));
212        self
213    }
214
215    pub fn with_context_management(mut self, management: ContextManagement) -> Self {
216        self.context_management = Some(management);
217        self
218    }
219
220    pub fn with_effort(mut self, level: EffortLevel) -> Self {
221        self.output_config = Some(OutputConfig::with_effort(level));
222        self
223    }
224
225    pub fn with_output_config(mut self, config: OutputConfig) -> Self {
226        self.output_config = Some(config);
227        self
228    }
229}
230
231impl From<String> for SystemPrompt {
232    fn from(s: String) -> Self {
233        SystemPrompt::Text(s)
234    }
235}
236
237impl From<&str> for SystemPrompt {
238    fn from(s: &str) -> Self {
239        SystemPrompt::Text(s.to_string())
240    }
241}
242
243#[derive(Debug, Clone, Serialize)]
244pub struct CountTokensRequest {
245    pub model: String,
246    pub messages: Vec<Message>,
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub system: Option<SystemPrompt>,
249    #[serde(skip_serializing_if = "Option::is_none")]
250    pub tools: Option<Vec<ApiTool>>,
251    #[serde(skip_serializing_if = "Option::is_none")]
252    pub thinking: Option<ThinkingConfig>,
253}
254
255impl CountTokensRequest {
256    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
257        Self {
258            model: model.into(),
259            messages,
260            system: None,
261            tools: None,
262            thinking: None,
263        }
264    }
265
266    pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
267        self.system = Some(system.into());
268        self
269    }
270
271    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
272        self.tools = Some(tools.into_iter().map(ApiTool::Custom).collect());
273        self
274    }
275
276    pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
277        self.tools = Some(tools);
278        self
279    }
280
281    pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
282        self.thinking = Some(config);
283        self
284    }
285
286    pub fn from_message_request(request: &CreateMessageRequest) -> Self {
287        Self {
288            model: request.model.clone(),
289            messages: request.messages.clone(),
290            system: request.system.clone(),
291            tools: request.tools.clone(),
292            thinking: request.thinking.clone(),
293        }
294    }
295}
296
297#[derive(Debug, Clone, Deserialize)]
298pub struct CountTokensResponse {
299    pub input_tokens: u32,
300    #[serde(default, skip_serializing_if = "Option::is_none")]
301    pub context_management: Option<CountTokensContextManagement>,
302}
303
304#[derive(Debug, Clone, Default, Deserialize)]
305pub struct CountTokensContextManagement {
306    #[serde(default)]
307    pub original_input_tokens: Option<u32>,
308}
309
#[cfg(test)]
mod tests {
    use super::super::config::MIN_THINKING_BUDGET;
    use super::*;

    const MODEL: &str = "claude-sonnet-4-5";

    /// Builds a single-user-turn request for `MODEL`, the starting point of most tests.
    fn user_request(text: &'static str) -> CreateMessageRequest {
        CreateMessageRequest::new(MODEL, vec![Message::user(text)])
    }

    #[test]
    fn test_create_request_default_max_tokens() {
        assert_eq!(user_request("Hello").max_tokens, DEFAULT_MAX_TOKENS);
    }

    #[test]
    fn test_create_request() {
        let request = user_request("Hello")
            .with_max_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(request.model, MODEL);
        assert_eq!(request.max_tokens, 1000);
        assert_eq!(request.temperature, Some(0.7));
    }

    #[test]
    fn test_request_validate_valid() {
        let request = user_request("Hi")
            .with_max_tokens(4000)
            .with_extended_thinking(2000);
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_boundaries() {
        // Both inclusive limits are accepted when no thinking is configured.
        assert!(user_request("Hi").with_max_tokens(MIN_MAX_TOKENS).validate().is_ok());
        assert!(user_request("Hi").with_max_tokens(MAX_TOKENS_128K).validate().is_ok());
    }

    #[test]
    fn test_request_validate_max_tokens_too_low() {
        // Guarded so the test stays meaningful (and never underflows) if the minimum is 0.
        if let Some(below_min) = MIN_MAX_TOKENS.checked_sub(1) {
            let err = user_request("Hi")
                .with_max_tokens(below_min)
                .validate()
                .unwrap_err();
            assert!(matches!(err, TokenValidationError::MaxTokensTooLow { .. }));
        }
    }

    #[test]
    fn test_request_validate_max_tokens_too_high() {
        let err = user_request("Hi")
            .with_max_tokens(MAX_TOKENS_128K + 1)
            .validate()
            .unwrap_err();
        assert!(matches!(err, TokenValidationError::MaxTokensTooHigh { .. }));
    }

    #[test]
    fn test_request_validate_thinking_auto_clamp() {
        let request = user_request("Hi").with_extended_thinking(500);
        let budget = request.thinking.as_ref().map(|thinking| thinking.budget());
        assert_eq!(budget, Some(Some(MIN_THINKING_BUDGET)));
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_thinking_exceeds_max() {
        let fits = user_request("Hi")
            .with_max_tokens(2000)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        assert!(fits.validate().is_ok());

        let too_tight = user_request("Hi")
            .with_max_tokens(MIN_THINKING_BUDGET)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        let err = too_tight.validate().unwrap_err();
        assert!(matches!(
            err,
            TokenValidationError::ThinkingBudgetExceedsMaxTokens { .. }
        ));
    }

    #[test]
    fn test_request_requires_128k_beta() {
        assert!(!user_request("Hi").requires_128k_beta());
        assert!(user_request("Hi")
            .with_max_tokens(DEFAULT_MAX_TOKENS + 1)
            .requires_128k_beta());
        assert!(user_request("Hi")
            .with_max_tokens(MAX_TOKENS_128K)
            .requires_128k_beta());
    }

    #[test]
    fn test_request_serialization_omits_unset_options() {
        let json = serde_json::to_value(user_request("Hi")).expect("request serializes");
        let object = json.as_object().expect("request serializes to a JSON object");

        assert_eq!(
            object.get("model").and_then(serde_json::Value::as_str),
            Some(MODEL)
        );
        assert_eq!(
            object.get("max_tokens").and_then(serde_json::Value::as_u64),
            Some(u64::from(DEFAULT_MAX_TOKENS))
        );
        for key in [
            "system",
            "tools",
            "tool_choice",
            "stream",
            "stop_sequences",
            "temperature",
            "top_p",
            "top_k",
            "metadata",
            "thinking",
            "output_format",
            "context_management",
            "output_config",
        ] {
            assert!(!object.contains_key(key), "unset `{}` should be omitted", key);
        }
    }

    #[test]
    fn test_system_prompt_conversions() {
        let borrowed = SystemPrompt::from("be concise");
        assert!(matches!(borrowed, SystemPrompt::Text(ref s) if s == "be concise"));

        let owned = SystemPrompt::from(String::from("be concise"));
        assert!(matches!(owned, SystemPrompt::Text(ref s) if s == "be concise"));
    }

    #[test]
    fn test_count_tokens_request() {
        let request = CountTokensRequest::new(MODEL, vec![Message::user("Hello")])
            .with_system("You are a helpful assistant");

        assert_eq!(request.model, MODEL);
        assert!(request.system.is_some());
    }

    #[test]
    fn test_count_tokens_from_message_request() {
        let source = user_request("Hi")
            .with_system("System prompt")
            .with_extended_thinking(10000);

        let counted = CountTokensRequest::from_message_request(&source);

        assert_eq!(counted.model, source.model);
        assert_eq!(counted.messages.len(), source.messages.len());
        assert!(counted.system.is_some());
        assert!(counted.thinking.is_some());
    }

    #[test]
    fn test_count_tokens_response_deserialize() {
        let minimal: CountTokensResponse =
            serde_json::from_str(r#"{"input_tokens": 42}"#).expect("minimal response parses");
        assert_eq!(minimal.input_tokens, 42);
        assert!(minimal.context_management.is_none());

        let detailed: CountTokensResponse = serde_json::from_str(
            r#"{"input_tokens": 10, "context_management": {"original_input_tokens": 25}}"#,
        )
        .expect("response with context management parses");
        assert_eq!(detailed.input_tokens, 10);
        assert_eq!(
            detailed
                .context_management
                .and_then(|details| details.original_input_tokens),
            Some(25)
        );
    }

    #[test]
    fn test_request_with_effort() {
        let request = CreateMessageRequest::new("claude-opus-4-5", vec![Message::user("Hi")])
            .with_effort(EffortLevel::Medium);
        let config = request
            .output_config
            .expect("with_effort sets an output config");
        assert_eq!(config.effort, Some(EffortLevel::Medium));
    }

    #[test]
    fn test_request_with_context_management() {
        let mgmt = ContextManagement::new().with_edit(ContextManagement::clear_thinking(2));
        let request = user_request("Hi").with_context_management(mgmt);
        assert!(request.context_management.is_some());
    }

    #[test]
    fn test_request_with_tool_choice() {
        let any = user_request("Hi").with_tool_choice_any();
        assert_eq!(any.tool_choice, Some(ToolChoice::Any));

        let required = user_request("Hi").with_required_tool("Grep");
        assert_eq!(required.tool_choice, Some(ToolChoice::tool("Grep")));
    }
}