claude_agent/client/messages/
request.rs1use serde::{Deserialize, Serialize};
4
5use super::config::{
6 DEFAULT_MAX_TOKENS, EffortLevel, MAX_TOKENS_128K, MIN_MAX_TOKENS, OutputConfig, OutputFormat,
7 ThinkingConfig, TokenValidationError, ToolChoice,
8};
9use super::context::ContextManagement;
10use super::types::{ApiTool, RequestMetadata};
11use crate::types::{
12 Message, SystemPrompt, ToolDefinition, ToolSearchTool, WebFetchTool, WebSearchTool,
13};
14
15#[derive(Debug, Clone, Serialize)]
16pub struct CreateMessageRequest {
17 pub model: String,
18 pub max_tokens: u32,
19 pub messages: Vec<Message>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub system: Option<SystemPrompt>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub tools: Option<Vec<ApiTool>>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub tool_choice: Option<ToolChoice>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub stream: Option<bool>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub stop_sequences: Option<Vec<String>>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub temperature: Option<f32>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 pub top_p: Option<f32>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub top_k: Option<u32>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub metadata: Option<RequestMetadata>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub thinking: Option<ThinkingConfig>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub output_format: Option<OutputFormat>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub context_management: Option<ContextManagement>,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 pub output_config: Option<OutputConfig>,
46}
47
48impl CreateMessageRequest {
49 pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
50 Self {
51 model: model.into(),
52 max_tokens: DEFAULT_MAX_TOKENS,
53 messages,
54 system: None,
55 tools: None,
56 tool_choice: None,
57 stream: None,
58 stop_sequences: None,
59 temperature: None,
60 top_p: None,
61 top_k: None,
62 metadata: None,
63 thinking: None,
64 output_format: None,
65 context_management: None,
66 output_config: None,
67 }
68 }
69
70 pub fn validate(&self) -> Result<(), TokenValidationError> {
71 if self.max_tokens < MIN_MAX_TOKENS {
72 return Err(TokenValidationError::MaxTokensTooLow {
73 min: MIN_MAX_TOKENS,
74 actual: self.max_tokens,
75 });
76 }
77 if self.max_tokens > MAX_TOKENS_128K {
78 return Err(TokenValidationError::MaxTokensTooHigh {
79 max: MAX_TOKENS_128K,
80 actual: self.max_tokens,
81 });
82 }
83 if let Some(thinking) = &self.thinking {
84 thinking.validate_against_max_tokens(self.max_tokens)?;
85 }
86 Ok(())
87 }
88
89 pub fn requires_128k_beta(&self) -> bool {
90 self.max_tokens > DEFAULT_MAX_TOKENS
91 }
92
93 pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
94 self.system = Some(system.into());
95 self
96 }
97
98 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
99 let api_tools: Vec<ApiTool> = tools.into_iter().map(ApiTool::Custom).collect();
100 self.tools = Some(api_tools);
101 self
102 }
103
104 pub fn with_web_search(mut self, config: WebSearchTool) -> Self {
105 let mut tools = self.tools.unwrap_or_default();
106 tools.push(ApiTool::WebSearch(config));
107 self.tools = Some(tools);
108 self
109 }
110
111 pub fn with_web_fetch(mut self, config: WebFetchTool) -> Self {
112 let mut tools = self.tools.unwrap_or_default();
113 tools.push(ApiTool::WebFetch(config));
114 self.tools = Some(tools);
115 self
116 }
117
118 pub fn with_tool_search(mut self, config: ToolSearchTool) -> Self {
119 let mut tools = self.tools.unwrap_or_default();
120 tools.push(ApiTool::ToolSearch(config));
121 self.tools = Some(tools);
122 self
123 }
124
125 pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
126 self.tools = Some(tools);
127 self
128 }
129
130 pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
131 self.tool_choice = Some(choice);
132 self
133 }
134
135 pub fn with_tool_choice_auto(mut self) -> Self {
136 self.tool_choice = Some(ToolChoice::Auto);
137 self
138 }
139
140 pub fn with_tool_choice_any(mut self) -> Self {
141 self.tool_choice = Some(ToolChoice::Any);
142 self
143 }
144
145 pub fn with_tool_choice_none(mut self) -> Self {
146 self.tool_choice = Some(ToolChoice::None);
147 self
148 }
149
150 pub fn with_required_tool(mut self, name: impl Into<String>) -> Self {
151 self.tool_choice = Some(ToolChoice::tool(name));
152 self
153 }
154
155 pub fn with_stream(mut self) -> Self {
156 self.stream = Some(true);
157 self
158 }
159
160 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
161 self.max_tokens = max_tokens;
162 self
163 }
164
165 pub fn with_model(mut self, model: impl Into<String>) -> Self {
166 self.model = model.into();
167 self
168 }
169
170 pub fn with_temperature(mut self, temperature: f32) -> Self {
171 self.temperature = Some(temperature);
172 self
173 }
174
175 pub fn with_top_p(mut self, top_p: f32) -> Self {
176 self.top_p = Some(top_p);
177 self
178 }
179
180 pub fn with_top_k(mut self, top_k: u32) -> Self {
181 self.top_k = Some(top_k);
182 self
183 }
184
185 pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
186 self.stop_sequences = Some(sequences);
187 self
188 }
189
190 pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
191 self.thinking = Some(config);
192 self
193 }
194
195 pub fn with_extended_thinking(mut self, budget_tokens: u32) -> Self {
196 self.thinking = Some(ThinkingConfig::enabled(budget_tokens));
197 self
198 }
199
200 pub fn with_output_format(mut self, format: OutputFormat) -> Self {
201 self.output_format = Some(format);
202 self
203 }
204
205 pub fn with_json_schema(mut self, schema: serde_json::Value) -> Self {
212 let strict_schema = crate::client::schema::transform_for_strict(schema);
213 self.output_format = Some(OutputFormat::json_schema(strict_schema));
214 self
215 }
216
217 pub fn with_context_management(mut self, management: ContextManagement) -> Self {
218 self.context_management = Some(management);
219 self
220 }
221
222 pub fn with_effort(mut self, level: EffortLevel) -> Self {
223 self.output_config = Some(OutputConfig::with_effort(level));
224 self
225 }
226
227 pub fn with_output_config(mut self, config: OutputConfig) -> Self {
228 self.output_config = Some(config);
229 self
230 }
231}
232
233impl From<String> for SystemPrompt {
234 fn from(s: String) -> Self {
235 SystemPrompt::Text(s)
236 }
237}
238
239impl From<&str> for SystemPrompt {
240 fn from(s: &str) -> Self {
241 SystemPrompt::Text(s.to_string())
242 }
243}
244
245#[derive(Debug, Clone, Serialize)]
246pub struct CountTokensRequest {
247 pub model: String,
248 pub messages: Vec<Message>,
249 #[serde(skip_serializing_if = "Option::is_none")]
250 pub system: Option<SystemPrompt>,
251 #[serde(skip_serializing_if = "Option::is_none")]
252 pub tools: Option<Vec<ApiTool>>,
253 #[serde(skip_serializing_if = "Option::is_none")]
254 pub thinking: Option<ThinkingConfig>,
255}
256
257impl CountTokensRequest {
258 pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
259 Self {
260 model: model.into(),
261 messages,
262 system: None,
263 tools: None,
264 thinking: None,
265 }
266 }
267
268 pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
269 self.system = Some(system.into());
270 self
271 }
272
273 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
274 self.tools = Some(tools.into_iter().map(ApiTool::Custom).collect());
275 self
276 }
277
278 pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
279 self.tools = Some(tools);
280 self
281 }
282
283 pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
284 self.thinking = Some(config);
285 self
286 }
287
288 pub fn from_message_request(request: &CreateMessageRequest) -> Self {
289 Self {
290 model: request.model.clone(),
291 messages: request.messages.clone(),
292 system: request.system.clone(),
293 tools: request.tools.clone(),
294 thinking: request.thinking.clone(),
295 }
296 }
297}
298
299#[derive(Debug, Clone, Deserialize)]
300pub struct CountTokensResponse {
301 pub input_tokens: u32,
302 #[serde(default, skip_serializing_if = "Option::is_none")]
303 pub context_management: Option<CountTokensContextManagement>,
304}
305
306#[derive(Debug, Clone, Default, Deserialize)]
307pub struct CountTokensContextManagement {
308 #[serde(default)]
309 pub original_input_tokens: Option<u32>,
310}
311
#[cfg(test)]
mod tests {
    use super::super::config::MIN_THINKING_BUDGET;
    use super::*;

    #[test]
    fn test_create_request_default_max_tokens() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")]);
        assert_eq!(request.max_tokens, DEFAULT_MAX_TOKENS);
    }

    #[test]
    fn test_create_request() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")])
            .with_max_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(request.model, "claude-sonnet-4-5");
        assert_eq!(request.max_tokens, 1000);
        assert_eq!(request.temperature, Some(0.7));
    }

    #[test]
    fn test_request_validate_valid() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(4000)
            .with_extended_thinking(2000);
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_max_tokens_too_low() {
        // `checked_sub` keeps the test meaningful (and panic-free) even if the
        // minimum were ever zero, in which case there is no too-low value to try.
        if let Some(too_low) = MIN_MAX_TOKENS.checked_sub(1) {
            let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
                .with_max_tokens(too_low);
            let err = request.validate().unwrap_err();
            assert!(matches!(err, TokenValidationError::MaxTokensTooLow { .. }));
        }
    }

    #[test]
    fn test_request_validate_accepts_inclusive_bounds() {
        for max_tokens in [MIN_MAX_TOKENS, MAX_TOKENS_128K] {
            let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
                .with_max_tokens(max_tokens);
            assert!(
                request.validate().is_ok(),
                "max_tokens = {} should be accepted",
                max_tokens
            );
        }
    }

    #[test]
    fn test_request_validate_max_tokens_too_high() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MAX_TOKENS_128K + 1);
        let err = request.validate().unwrap_err();
        assert!(matches!(err, TokenValidationError::MaxTokensTooHigh { .. }));
    }

    #[test]
    fn test_request_validate_thinking_auto_clamp() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_extended_thinking(500);
        assert_eq!(
            request.thinking.as_ref().unwrap().budget(),
            Some(MIN_THINKING_BUDGET)
        );
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_thinking_exceeds_max() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(2000)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        assert!(request.validate().is_ok());

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MIN_THINKING_BUDGET)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        let err = request.validate().unwrap_err();
        assert!(matches!(
            err,
            TokenValidationError::ThinkingBudgetExceedsMaxTokens { .. }
        ));
    }

    #[test]
    fn test_request_requires_128k_beta() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")]);
        assert!(!request.requires_128k_beta());

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(DEFAULT_MAX_TOKENS + 1);
        assert!(request.requires_128k_beta());

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MAX_TOKENS_128K);
        assert!(request.requires_128k_beta());
    }

    #[test]
    fn test_count_tokens_request() {
        let request = CountTokensRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")])
            .with_system("You are a helpful assistant");

        assert_eq!(request.model, "claude-sonnet-4-5");
        assert!(request.system.is_some());
    }

    #[test]
    fn test_count_tokens_from_message_request() {
        let msg_request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_system("System prompt")
            .with_extended_thinking(10000);

        let count_request = CountTokensRequest::from_message_request(&msg_request);

        assert_eq!(count_request.model, msg_request.model);
        assert_eq!(count_request.messages.len(), msg_request.messages.len());
        assert!(count_request.system.is_some());
        assert!(count_request.thinking.is_some());
    }

    #[test]
    fn test_count_tokens_response_without_context_management() {
        let response: CountTokensResponse =
            serde_json::from_str(r#"{"input_tokens": 42}"#).expect("valid response JSON");
        assert_eq!(response.input_tokens, 42);
        assert!(response.context_management.is_none());
    }

    #[test]
    fn test_count_tokens_response_with_context_management() {
        let response: CountTokensResponse = serde_json::from_str(
            r#"{"input_tokens": 100, "context_management": {"original_input_tokens": 150}}"#,
        )
        .expect("valid response JSON");
        assert_eq!(response.input_tokens, 100);
        assert_eq!(
            response
                .context_management
                .and_then(|cm| cm.original_input_tokens),
            Some(150)
        );

        // An empty context-management object still deserializes; the count is absent.
        let response: CountTokensResponse =
            serde_json::from_str(r#"{"input_tokens": 1, "context_management": {}}"#)
                .expect("valid response JSON");
        assert_eq!(
            response
                .context_management
                .map(|cm| cm.original_input_tokens),
            Some(None)
        );
    }

    #[test]
    fn test_request_with_effort() {
        let request = CreateMessageRequest::new("claude-opus-4-5", vec![Message::user("Hi")])
            .with_effort(EffortLevel::Medium);
        assert!(request.output_config.is_some());
        assert_eq!(
            request.output_config.unwrap().effort,
            Some(EffortLevel::Medium)
        );
    }

    #[test]
    fn test_request_with_context_management() {
        let mgmt = ContextManagement::new().with_edit(ContextManagement::clear_thinking(2));
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_context_management(mgmt);
        assert!(request.context_management.is_some());
    }

    #[test]
    fn test_request_with_tool_choice() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_tool_choice_any();
        assert_eq!(request.tool_choice, Some(ToolChoice::Any));

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_required_tool("Grep");
        assert_eq!(request.tool_choice, Some(ToolChoice::tool("Grep")));
    }
}