//! `claude_agent/client/messages/request.rs`
//!
//! Request builders for message creation and token counting, plus the
//! token-count response types.

use serde::{Deserialize, Serialize};

use super::config::{
    DEFAULT_MAX_TOKENS, EffortLevel, MAX_TOKENS_128K, MIN_MAX_TOKENS, OutputConfig, OutputFormat,
    ThinkingConfig, TokenValidationError, ToolChoice,
};
use super::context::ContextManagement;
use super::types::{ApiTool, RequestMetadata};
use crate::types::{
    Message, SystemPrompt, ToolDefinition, ToolSearchTool, WebFetchTool, WebSearchTool,
};

15#[derive(Debug, Clone, Serialize)]
16pub struct CreateMessageRequest {
17 pub model: String,
18 pub max_tokens: u32,
19 pub messages: Vec<Message>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub system: Option<SystemPrompt>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub tools: Option<Vec<ApiTool>>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub tool_choice: Option<ToolChoice>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub stream: Option<bool>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub stop_sequences: Option<Vec<String>>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub temperature: Option<f32>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 pub top_p: Option<f32>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub top_k: Option<u32>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub metadata: Option<RequestMetadata>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub thinking: Option<ThinkingConfig>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub output_format: Option<OutputFormat>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub context_management: Option<ContextManagement>,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 pub output_config: Option<OutputConfig>,
46}
47
48impl CreateMessageRequest {
49 pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
50 Self {
51 model: model.into(),
52 max_tokens: DEFAULT_MAX_TOKENS,
53 messages,
54 system: None,
55 tools: None,
56 tool_choice: None,
57 stream: None,
58 stop_sequences: None,
59 temperature: None,
60 top_p: None,
61 top_k: None,
62 metadata: None,
63 thinking: None,
64 output_format: None,
65 context_management: None,
66 output_config: None,
67 }
68 }
69
70 pub fn validate(&self) -> Result<(), TokenValidationError> {
71 if self.max_tokens < MIN_MAX_TOKENS {
72 return Err(TokenValidationError::MaxTokensTooLow {
73 min: MIN_MAX_TOKENS,
74 actual: self.max_tokens,
75 });
76 }
77 if self.max_tokens > MAX_TOKENS_128K {
78 return Err(TokenValidationError::MaxTokensTooHigh {
79 max: MAX_TOKENS_128K,
80 actual: self.max_tokens,
81 });
82 }
83 if let Some(thinking) = &self.thinking {
84 thinking.validate_against_max_tokens(self.max_tokens)?;
85 }
86 Ok(())
87 }
88
89 pub fn requires_128k_beta(&self) -> bool {
90 self.max_tokens > DEFAULT_MAX_TOKENS
91 }
92
93 pub fn with_metadata(mut self, metadata: RequestMetadata) -> Self {
94 self.metadata = Some(metadata);
95 self
96 }
97
98 pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
99 self.system = Some(system.into());
100 self
101 }
102
103 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
104 let api_tools: Vec<ApiTool> = tools.into_iter().map(ApiTool::Custom).collect();
105 self.tools = Some(api_tools);
106 self
107 }
108
109 pub fn with_web_search(mut self, config: WebSearchTool) -> Self {
110 let mut tools = self.tools.unwrap_or_default();
111 tools.push(ApiTool::WebSearch(config));
112 self.tools = Some(tools);
113 self
114 }
115
116 pub fn with_web_fetch(mut self, config: WebFetchTool) -> Self {
117 let mut tools = self.tools.unwrap_or_default();
118 tools.push(ApiTool::WebFetch(config));
119 self.tools = Some(tools);
120 self
121 }
122
123 pub fn with_tool_search(mut self, config: ToolSearchTool) -> Self {
124 let mut tools = self.tools.unwrap_or_default();
125 tools.push(ApiTool::ToolSearch(config));
126 self.tools = Some(tools);
127 self
128 }
129
130 pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
131 self.tools = Some(tools);
132 self
133 }
134
135 pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
136 self.tool_choice = Some(choice);
137 self
138 }
139
140 pub fn with_tool_choice_auto(mut self) -> Self {
141 self.tool_choice = Some(ToolChoice::Auto);
142 self
143 }
144
145 pub fn with_tool_choice_any(mut self) -> Self {
146 self.tool_choice = Some(ToolChoice::Any);
147 self
148 }
149
150 pub fn with_tool_choice_none(mut self) -> Self {
151 self.tool_choice = Some(ToolChoice::None);
152 self
153 }
154
155 pub fn with_required_tool(mut self, name: impl Into<String>) -> Self {
156 self.tool_choice = Some(ToolChoice::tool(name));
157 self
158 }
159
160 pub fn with_stream(mut self) -> Self {
161 self.stream = Some(true);
162 self
163 }
164
165 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
166 self.max_tokens = max_tokens;
167 self
168 }
169
170 pub fn with_model(mut self, model: impl Into<String>) -> Self {
171 self.model = model.into();
172 self
173 }
174
175 pub fn with_temperature(mut self, temperature: f32) -> Self {
176 self.temperature = Some(temperature);
177 self
178 }
179
180 pub fn with_top_p(mut self, top_p: f32) -> Self {
181 self.top_p = Some(top_p);
182 self
183 }
184
185 pub fn with_top_k(mut self, top_k: u32) -> Self {
186 self.top_k = Some(top_k);
187 self
188 }
189
190 pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
191 self.stop_sequences = Some(sequences);
192 self
193 }
194
195 pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
196 self.thinking = Some(config);
197 self
198 }
199
200 pub fn with_extended_thinking(mut self, budget_tokens: u32) -> Self {
201 self.thinking = Some(ThinkingConfig::enabled(budget_tokens));
202 self
203 }
204
205 pub fn with_output_format(mut self, format: OutputFormat) -> Self {
206 self.output_format = Some(format);
207 self
208 }
209
210 pub fn with_json_schema(mut self, schema: serde_json::Value) -> Self {
217 let strict_schema = crate::client::schema::transform_for_strict(schema);
218 self.output_format = Some(OutputFormat::json_schema(strict_schema));
219 self
220 }
221
222 pub fn with_context_management(mut self, management: ContextManagement) -> Self {
223 self.context_management = Some(management);
224 self
225 }
226
227 pub fn with_effort(mut self, level: EffortLevel) -> Self {
228 self.output_config = Some(OutputConfig::with_effort(level));
229 self
230 }
231
232 pub fn with_output_config(mut self, config: OutputConfig) -> Self {
233 self.output_config = Some(config);
234 self
235 }
236}
237
238impl From<String> for SystemPrompt {
239 fn from(s: String) -> Self {
240 SystemPrompt::Text(s)
241 }
242}
243
244impl From<&str> for SystemPrompt {
245 fn from(s: &str) -> Self {
246 SystemPrompt::Text(s.to_string())
247 }
248}
249
250#[derive(Debug, Clone, Serialize)]
251pub struct CountTokensRequest {
252 pub model: String,
253 pub messages: Vec<Message>,
254 #[serde(skip_serializing_if = "Option::is_none")]
255 pub system: Option<SystemPrompt>,
256 #[serde(skip_serializing_if = "Option::is_none")]
257 pub tools: Option<Vec<ApiTool>>,
258 #[serde(skip_serializing_if = "Option::is_none")]
259 pub thinking: Option<ThinkingConfig>,
260}
261
262impl CountTokensRequest {
263 pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
264 Self {
265 model: model.into(),
266 messages,
267 system: None,
268 tools: None,
269 thinking: None,
270 }
271 }
272
273 pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
274 self.system = Some(system.into());
275 self
276 }
277
278 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
279 self.tools = Some(tools.into_iter().map(ApiTool::Custom).collect());
280 self
281 }
282
283 pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
284 self.tools = Some(tools);
285 self
286 }
287
288 pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
289 self.thinking = Some(config);
290 self
291 }
292
293 pub fn from_message_request(request: &CreateMessageRequest) -> Self {
294 Self {
295 model: request.model.clone(),
296 messages: request.messages.clone(),
297 system: request.system.clone(),
298 tools: request.tools.clone(),
299 thinking: request.thinking.clone(),
300 }
301 }
302}
303
304#[derive(Debug, Clone, Deserialize)]
305pub struct CountTokensResponse {
306 pub input_tokens: u32,
307 #[serde(default, skip_serializing_if = "Option::is_none")]
308 pub context_management: Option<CountTokensContextManagement>,
309}
310
311#[derive(Debug, Clone, Default, Deserialize)]
312pub struct CountTokensContextManagement {
313 #[serde(default)]
314 pub original_input_tokens: Option<u32>,
315}
316
#[cfg(test)]
mod tests {
    use super::super::config::MIN_THINKING_BUDGET;
    use super::*;

    #[test]
    fn test_create_request_default_max_tokens() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")]);
        assert_eq!(request.max_tokens, DEFAULT_MAX_TOKENS);
    }

    #[test]
    fn test_create_request() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")])
            .with_max_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(request.model, "claude-sonnet-4-5");
        assert_eq!(request.max_tokens, 1000);
        assert_eq!(request.temperature, Some(0.7));
    }

    #[test]
    fn test_create_request_omits_unset_optional_fields() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")]);
        let json = serde_json::to_value(&request).expect("request serializes");
        let object = json.as_object().expect("request serializes to a JSON object");
        let mut keys: Vec<&str> = object.keys().map(String::as_str).collect();
        keys.sort_unstable();
        // Every `None` field is skipped, so only the three required keys remain.
        assert_eq!(keys, ["max_tokens", "messages", "model"]);

        let json = serde_json::to_value(request.with_stream()).expect("request serializes");
        assert_eq!(json["stream"], true);
    }

    #[test]
    fn test_request_validate_valid() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(4000)
            .with_extended_thinking(2000);
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_max_tokens_bounds_inclusive() {
        for max_tokens in [MIN_MAX_TOKENS, MAX_TOKENS_128K] {
            let request =
                CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
                    .with_max_tokens(max_tokens);
            assert!(request.validate().is_ok(), "max_tokens = {max_tokens}");
        }
    }

    #[test]
    fn test_request_validate_max_tokens_too_low() {
        // If MIN_MAX_TOKENS is 0 there is no value below it to test.
        let Some(too_low) = MIN_MAX_TOKENS.checked_sub(1) else {
            return;
        };
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(too_low);
        let err = request.validate().unwrap_err();
        assert!(matches!(err, TokenValidationError::MaxTokensTooLow { .. }));
    }

    #[test]
    fn test_request_validate_max_tokens_too_high() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MAX_TOKENS_128K + 1);
        let err = request.validate().unwrap_err();
        assert!(matches!(err, TokenValidationError::MaxTokensTooHigh { .. }));
    }

    #[test]
    fn test_request_validate_thinking_auto_clamp() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_extended_thinking(500);
        assert_eq!(
            request.thinking.as_ref().unwrap().budget(),
            Some(MIN_THINKING_BUDGET)
        );
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_thinking_exceeds_max() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(2000)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        assert!(request.validate().is_ok());

        // A thinking budget equal to max_tokens is rejected.
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MIN_THINKING_BUDGET)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        let err = request.validate().unwrap_err();
        assert!(matches!(
            err,
            TokenValidationError::ThinkingBudgetExceedsMaxTokens { .. }
        ));
    }

    #[test]
    fn test_request_requires_128k_beta() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")]);
        assert!(!request.requires_128k_beta());

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(DEFAULT_MAX_TOKENS + 1);
        assert!(request.requires_128k_beta());

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MAX_TOKENS_128K);
        assert!(request.requires_128k_beta());
    }

    #[test]
    fn test_count_tokens_request() {
        let request = CountTokensRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")])
            .with_system("You are a helpful assistant");

        assert_eq!(request.model, "claude-sonnet-4-5");
        assert!(request.system.is_some());
    }

    #[test]
    fn test_count_tokens_from_message_request() {
        let msg_request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_system("System prompt")
            .with_extended_thinking(10000);

        let count_request = CountTokensRequest::from_message_request(&msg_request);

        assert_eq!(count_request.model, msg_request.model);
        assert_eq!(count_request.messages.len(), msg_request.messages.len());
        assert!(count_request.system.is_some());
        assert!(count_request.thinking.is_some());
    }

    #[test]
    fn test_count_tokens_response_deserialize() {
        let response: CountTokensResponse =
            serde_json::from_str(r#"{"input_tokens": 42}"#).expect("minimal response parses");
        assert_eq!(response.input_tokens, 42);
        assert!(response.context_management.is_none());

        let response: CountTokensResponse = serde_json::from_str(
            r#"{"input_tokens": 10, "context_management": {"original_input_tokens": 25}}"#,
        )
        .expect("response with context management parses");
        assert_eq!(response.input_tokens, 10);
        assert_eq!(
            response
                .context_management
                .and_then(|cm| cm.original_input_tokens),
            Some(25)
        );
    }

    #[test]
    fn test_request_with_effort() {
        let request = CreateMessageRequest::new("claude-opus-4-5", vec![Message::user("Hi")])
            .with_effort(EffortLevel::Medium);
        let config = request
            .output_config
            .expect("with_effort sets output_config");
        assert_eq!(config.effort, Some(EffortLevel::Medium));
    }

    #[test]
    fn test_request_with_context_management() {
        let mgmt = ContextManagement::new().with_edit(ContextManagement::clear_thinking(2));
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_context_management(mgmt);
        assert!(request.context_management.is_some());
    }

    #[test]
    fn test_request_with_tool_choice() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_tool_choice_any();
        assert_eq!(request.tool_choice, Some(ToolChoice::Any));

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_required_tool("Grep");
        assert_eq!(request.tool_choice, Some(ToolChoice::tool("Grep")));
    }
}