claude_agent/client/messages/
request.rs1use serde::{Deserialize, Serialize};
4
5use super::config::{
6 DEFAULT_MAX_TOKENS, EffortLevel, MAX_TOKENS_128K, MIN_MAX_TOKENS, OutputConfig, OutputFormat,
7 ThinkingConfig, TokenValidationError, ToolChoice,
8};
9use super::context::ContextManagement;
10use super::types::{ApiTool, RequestMetadata};
11use crate::types::{
12 Message, SystemPrompt, ToolDefinition, ToolSearchTool, WebFetchTool, WebSearchTool,
13};
14
15#[derive(Debug, Clone, Serialize)]
16pub struct CreateMessageRequest {
17 pub model: String,
18 pub max_tokens: u32,
19 pub messages: Vec<Message>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub system: Option<SystemPrompt>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub tools: Option<Vec<ApiTool>>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub tool_choice: Option<ToolChoice>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub stream: Option<bool>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub stop_sequences: Option<Vec<String>>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub temperature: Option<f32>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 pub top_p: Option<f32>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub top_k: Option<u32>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub metadata: Option<RequestMetadata>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub thinking: Option<ThinkingConfig>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub output_format: Option<OutputFormat>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub context_management: Option<ContextManagement>,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 pub output_config: Option<OutputConfig>,
46}
47
48impl CreateMessageRequest {
49 pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
50 Self {
51 model: model.into(),
52 max_tokens: DEFAULT_MAX_TOKENS,
53 messages,
54 system: None,
55 tools: None,
56 tool_choice: None,
57 stream: None,
58 stop_sequences: None,
59 temperature: None,
60 top_p: None,
61 top_k: None,
62 metadata: None,
63 thinking: None,
64 output_format: None,
65 context_management: None,
66 output_config: None,
67 }
68 }
69
70 pub fn validate(&self) -> Result<(), TokenValidationError> {
71 if self.max_tokens < MIN_MAX_TOKENS {
72 return Err(TokenValidationError::MaxTokensTooLow {
73 min: MIN_MAX_TOKENS,
74 actual: self.max_tokens,
75 });
76 }
77 if self.max_tokens > MAX_TOKENS_128K {
78 return Err(TokenValidationError::MaxTokensTooHigh {
79 max: MAX_TOKENS_128K,
80 actual: self.max_tokens,
81 });
82 }
83 if let Some(thinking) = &self.thinking {
84 thinking.validate_against_max_tokens(self.max_tokens)?;
85 }
86 Ok(())
87 }
88
89 pub fn requires_128k_beta(&self) -> bool {
90 self.max_tokens > DEFAULT_MAX_TOKENS
91 }
92
93 pub fn with_metadata(mut self, metadata: RequestMetadata) -> Self {
94 self.metadata = Some(metadata);
95 self
96 }
97
98 pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
99 self.system = Some(system.into());
100 self
101 }
102
103 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
104 let api_tools: Vec<ApiTool> = tools.into_iter().map(ApiTool::Custom).collect();
105 self.tools = Some(api_tools);
106 self
107 }
108
109 pub fn with_web_search(mut self, config: WebSearchTool) -> Self {
110 let mut tools = self.tools.unwrap_or_default();
111 tools.push(ApiTool::WebSearch(config));
112 self.tools = Some(tools);
113 self
114 }
115
116 pub fn with_web_fetch(mut self, config: WebFetchTool) -> Self {
117 let mut tools = self.tools.unwrap_or_default();
118 tools.push(ApiTool::WebFetch(config));
119 self.tools = Some(tools);
120 self
121 }
122
123 pub fn with_tool_search(mut self, config: ToolSearchTool) -> Self {
124 let mut tools = self.tools.unwrap_or_default();
125 tools.push(ApiTool::ToolSearch(config));
126 self.tools = Some(tools);
127 self
128 }
129
130 pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
131 self.tools = Some(tools);
132 self
133 }
134
135 pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
136 self.tool_choice = Some(choice);
137 self
138 }
139
140 pub fn with_tool_choice_auto(mut self) -> Self {
141 self.tool_choice = Some(ToolChoice::Auto);
142 self
143 }
144
145 pub fn with_tool_choice_any(mut self) -> Self {
146 self.tool_choice = Some(ToolChoice::Any);
147 self
148 }
149
150 pub fn with_tool_choice_none(mut self) -> Self {
151 self.tool_choice = Some(ToolChoice::None);
152 self
153 }
154
155 pub fn with_required_tool(mut self, name: impl Into<String>) -> Self {
156 self.tool_choice = Some(ToolChoice::tool(name));
157 self
158 }
159
160 pub fn with_stream(mut self) -> Self {
161 self.stream = Some(true);
162 self
163 }
164
165 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
166 self.max_tokens = max_tokens;
167 self
168 }
169
170 pub fn with_model(mut self, model: impl Into<String>) -> Self {
171 self.model = model.into();
172 self
173 }
174
175 pub fn with_temperature(mut self, temperature: f32) -> Self {
176 self.temperature = Some(temperature);
177 self
178 }
179
180 pub fn with_top_p(mut self, top_p: f32) -> Self {
181 self.top_p = Some(top_p);
182 self
183 }
184
185 pub fn with_top_k(mut self, top_k: u32) -> Self {
186 self.top_k = Some(top_k);
187 self
188 }
189
190 pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
191 self.stop_sequences = Some(sequences);
192 self
193 }
194
195 pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
196 self.thinking = Some(config);
197 self
198 }
199
200 pub fn with_extended_thinking(mut self, budget_tokens: u32) -> Self {
201 self.thinking = Some(ThinkingConfig::enabled(budget_tokens));
202 self
203 }
204
205 pub fn with_output_format(mut self, format: OutputFormat) -> Self {
206 self.output_format = Some(format);
207 self
208 }
209
210 pub fn with_json_schema(mut self, schema: serde_json::Value) -> Self {
211 self.output_format = Some(OutputFormat::json_schema(schema));
212 self
213 }
214
215 pub fn with_context_management(mut self, management: ContextManagement) -> Self {
216 self.context_management = Some(management);
217 self
218 }
219
220 pub fn with_effort(mut self, level: EffortLevel) -> Self {
221 self.output_config = Some(OutputConfig::with_effort(level));
222 self
223 }
224
225 pub fn with_output_config(mut self, config: OutputConfig) -> Self {
226 self.output_config = Some(config);
227 self
228 }
229}
230
231impl From<String> for SystemPrompt {
232 fn from(s: String) -> Self {
233 SystemPrompt::Text(s)
234 }
235}
236
237impl From<&str> for SystemPrompt {
238 fn from(s: &str) -> Self {
239 SystemPrompt::Text(s.to_string())
240 }
241}
242
243#[derive(Debug, Clone, Serialize)]
244pub struct CountTokensRequest {
245 pub model: String,
246 pub messages: Vec<Message>,
247 #[serde(skip_serializing_if = "Option::is_none")]
248 pub system: Option<SystemPrompt>,
249 #[serde(skip_serializing_if = "Option::is_none")]
250 pub tools: Option<Vec<ApiTool>>,
251 #[serde(skip_serializing_if = "Option::is_none")]
252 pub thinking: Option<ThinkingConfig>,
253}
254
255impl CountTokensRequest {
256 pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
257 Self {
258 model: model.into(),
259 messages,
260 system: None,
261 tools: None,
262 thinking: None,
263 }
264 }
265
266 pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
267 self.system = Some(system.into());
268 self
269 }
270
271 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
272 self.tools = Some(tools.into_iter().map(ApiTool::Custom).collect());
273 self
274 }
275
276 pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
277 self.tools = Some(tools);
278 self
279 }
280
281 pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
282 self.thinking = Some(config);
283 self
284 }
285
286 pub fn from_message_request(request: &CreateMessageRequest) -> Self {
287 Self {
288 model: request.model.clone(),
289 messages: request.messages.clone(),
290 system: request.system.clone(),
291 tools: request.tools.clone(),
292 thinking: request.thinking.clone(),
293 }
294 }
295}
296
297#[derive(Debug, Clone, Deserialize)]
298pub struct CountTokensResponse {
299 pub input_tokens: u32,
300 #[serde(default, skip_serializing_if = "Option::is_none")]
301 pub context_management: Option<CountTokensContextManagement>,
302}
303
304#[derive(Debug, Clone, Default, Deserialize)]
305pub struct CountTokensContextManagement {
306 #[serde(default)]
307 pub original_input_tokens: Option<u32>,
308}
309
#[cfg(test)]
mod tests {
    use super::super::config::MIN_THINKING_BUDGET;
    use super::*;

    #[test]
    fn test_create_request_default_max_tokens() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")]);
        assert_eq!(request.max_tokens, DEFAULT_MAX_TOKENS);
    }

    #[test]
    fn test_create_request() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")])
            .with_max_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(request.model, "claude-sonnet-4-5");
        assert_eq!(request.max_tokens, 1000);
        assert_eq!(request.temperature, Some(0.7));
    }

    #[test]
    fn test_request_validate_valid() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(4000)
            .with_extended_thinking(2000);
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_max_tokens_too_high() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MAX_TOKENS_128K + 1);
        let err = request.validate().unwrap_err();
        assert!(matches!(err, TokenValidationError::MaxTokensTooHigh { .. }));
    }

    #[test]
    fn test_request_validate_max_tokens_too_low() {
        // Guarded so the test stays meaningful (and cannot underflow) if the minimum is ever 0.
        if let Some(below_min) = MIN_MAX_TOKENS.checked_sub(1) {
            let request =
                CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
                    .with_max_tokens(below_min);
            let err = request.validate().unwrap_err();
            assert!(matches!(err, TokenValidationError::MaxTokensTooLow { .. }));
        }
    }

    #[test]
    fn test_request_validate_max_tokens_bounds_inclusive() {
        let at_min = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MIN_MAX_TOKENS);
        assert!(at_min.validate().is_ok());

        let at_max = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MAX_TOKENS_128K);
        assert!(at_max.validate().is_ok());
    }

    #[test]
    fn test_request_validate_thinking_auto_clamp() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_extended_thinking(500);
        assert_eq!(
            request.thinking.as_ref().unwrap().budget(),
            Some(MIN_THINKING_BUDGET)
        );
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_thinking_exceeds_max() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(2000)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        assert!(request.validate().is_ok());

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MIN_THINKING_BUDGET)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        let err = request.validate().unwrap_err();
        assert!(matches!(
            err,
            TokenValidationError::ThinkingBudgetExceedsMaxTokens { .. }
        ));
    }

    #[test]
    fn test_request_requires_128k_beta() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")]);
        assert!(!request.requires_128k_beta());

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(DEFAULT_MAX_TOKENS + 1);
        assert!(request.requires_128k_beta());

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_max_tokens(MAX_TOKENS_128K);
        assert!(request.requires_128k_beta());
    }

    #[test]
    fn test_request_serialization_omits_unset_fields() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")]);
        let json = serde_json::to_value(&request).unwrap();

        assert_eq!(json["model"], "claude-sonnet-4-5");
        assert_eq!(json["max_tokens"], DEFAULT_MAX_TOKENS);
        assert!(json.get("messages").is_some());

        let unset_keys = [
            "system",
            "tools",
            "tool_choice",
            "stream",
            "stop_sequences",
            "temperature",
            "top_p",
            "top_k",
            "metadata",
            "thinking",
            "output_format",
            "context_management",
            "output_config",
        ];
        for key in &unset_keys {
            assert!(
                json.get(*key).is_none(),
                "unset field `{}` should be omitted",
                key
            );
        }
    }

    #[test]
    fn test_request_with_stream_serializes_flag() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_stream();
        assert_eq!(request.stream, Some(true));

        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["stream"], true);
    }

    #[test]
    fn test_count_tokens_request() {
        let request = CountTokensRequest::new("claude-sonnet-4-5", vec![Message::user("Hello")])
            .with_system("You are a helpful assistant");

        assert_eq!(request.model, "claude-sonnet-4-5");
        assert!(request.system.is_some());
    }

    #[test]
    fn test_count_tokens_from_message_request() {
        let msg_request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_system("System prompt")
            .with_extended_thinking(10000);

        let count_request = CountTokensRequest::from_message_request(&msg_request);

        assert_eq!(count_request.model, msg_request.model);
        assert_eq!(count_request.messages.len(), msg_request.messages.len());
        assert!(count_request.system.is_some());
        assert!(count_request.thinking.is_some());
        assert!(count_request.tools.is_none());
    }

    #[test]
    fn test_count_tokens_response_deserialize_minimal() {
        let response: CountTokensResponse =
            serde_json::from_str(r#"{"input_tokens": 42}"#).unwrap();
        assert_eq!(response.input_tokens, 42);
        assert!(response.context_management.is_none());
    }

    #[test]
    fn test_count_tokens_response_deserialize_context_management() {
        let response: CountTokensResponse = serde_json::from_str(
            r#"{"input_tokens": 10, "context_management": {"original_input_tokens": 25}}"#,
        )
        .unwrap();
        assert_eq!(response.input_tokens, 10);
        assert_eq!(
            response
                .context_management
                .and_then(|cm| cm.original_input_tokens),
            Some(25)
        );

        let empty: CountTokensResponse =
            serde_json::from_str(r#"{"input_tokens": 3, "context_management": {}}"#).unwrap();
        assert_eq!(
            empty.context_management.map(|cm| cm.original_input_tokens),
            Some(None)
        );
    }

    #[test]
    fn test_request_with_effort() {
        let request = CreateMessageRequest::new("claude-opus-4-5", vec![Message::user("Hi")])
            .with_effort(EffortLevel::Medium);
        assert!(request.output_config.is_some());
        assert_eq!(
            request.output_config.unwrap().effort,
            Some(EffortLevel::Medium)
        );
    }

    #[test]
    fn test_request_with_context_management() {
        let mgmt = ContextManagement::new().with_edit(ContextManagement::clear_thinking(2));
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_context_management(mgmt);
        assert!(request.context_management.is_some());
    }

    #[test]
    fn test_request_with_tool_choice() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_tool_choice_any();
        assert_eq!(request.tool_choice, Some(ToolChoice::Any));

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_required_tool("Grep");
        assert_eq!(request.tool_choice, Some(ToolChoice::tool("Grep")));
    }

    #[test]
    fn test_request_with_tool_choice_auto_and_none() {
        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_tool_choice_auto();
        assert_eq!(request.tool_choice, Some(ToolChoice::Auto));

        let request = CreateMessageRequest::new("claude-sonnet-4-5", vec![Message::user("Hi")])
            .with_tool_choice_none();
        assert_eq!(request.tool_choice, Some(ToolChoice::None));
    }
}