claude_agent/client/messages/
request.rs1use serde::{Deserialize, Serialize};
4
5use super::config::{
6 DEFAULT_MAX_TOKENS, EffortLevel, MAX_TOKENS_128K, MIN_MAX_TOKENS, OutputConfig, OutputFormat,
7 ThinkingConfig, TokenValidationError, ToolChoice,
8};
9use super::context::ContextManagement;
10use super::types::{ApiTool, RequestMetadata};
11use crate::types::{Message, SystemPrompt, ToolDefinition, WebFetchTool, WebSearchTool};
12
/// Request body for creating a message.
///
/// The struct is serialized with serde. Every `Option` field is omitted from the
/// serialized output while it is `None`, so only explicitly configured settings are sent.
/// Construct it with [`CreateMessageRequest::new`], refine it with the `with_*` builder
/// methods, and check it with [`CreateMessageRequest::validate`].
#[derive(Debug, Clone, Serialize)]
pub struct CreateMessageRequest {
    /// Model identifier, stored as given to `new` / `with_model`.
    pub model: String,
    /// Token cap sent as `max_tokens`. `new` initializes it to `DEFAULT_MAX_TOKENS`;
    /// `validate` rejects values outside `MIN_MAX_TOKENS..=MAX_TOKENS_128K`.
    pub max_tokens: u32,
    /// Conversation messages; always serialized.
    pub messages: Vec<Message>,
    /// System prompt, set by `with_system`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<SystemPrompt>,
    /// Tool list. `with_tools` / `with_api_tools` replace it, while
    /// `with_web_search` / `with_web_fetch` append to it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ApiTool>>,
    /// Tool-choice setting, set by the `with_tool_choice*` methods and `with_required_tool`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Streaming flag; `with_stream` sets it to `Some(true)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Stop sequences, set by `with_stop_sequences`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    /// Set by `with_temperature`; not range-checked by `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Set by `with_top_p`; not range-checked by `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Set by `with_top_k`; not range-checked by `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Request metadata, set by `with_metadata`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<RequestMetadata>,
    /// Thinking configuration. When present, `validate` checks it against `max_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<ThinkingConfig>,
    /// Output format, set by `with_output_format` / `with_json_schema`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_format: Option<OutputFormat>,
    /// Context-management settings, set by `with_context_management`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_management: Option<ContextManagement>,
    /// Output configuration, set by `with_output_config` / `with_effort`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_config: Option<OutputConfig>,
}
45
46impl CreateMessageRequest {
47 pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
48 Self {
49 model: model.into(),
50 max_tokens: DEFAULT_MAX_TOKENS,
51 messages,
52 system: None,
53 tools: None,
54 tool_choice: None,
55 stream: None,
56 stop_sequences: None,
57 temperature: None,
58 top_p: None,
59 top_k: None,
60 metadata: None,
61 thinking: None,
62 output_format: None,
63 context_management: None,
64 output_config: None,
65 }
66 }
67
68 pub fn validate(&self) -> Result<(), TokenValidationError> {
69 if self.max_tokens < MIN_MAX_TOKENS {
70 return Err(TokenValidationError::MaxTokensTooLow {
71 min: MIN_MAX_TOKENS,
72 actual: self.max_tokens,
73 });
74 }
75 if self.max_tokens > MAX_TOKENS_128K {
76 return Err(TokenValidationError::MaxTokensTooHigh {
77 max: MAX_TOKENS_128K,
78 actual: self.max_tokens,
79 });
80 }
81 if let Some(thinking) = &self.thinking {
82 thinking.validate_against_max_tokens(self.max_tokens)?;
83 }
84 Ok(())
85 }
86
87 pub fn requires_128k_beta(&self) -> bool {
88 self.max_tokens > DEFAULT_MAX_TOKENS
89 }
90
91 pub fn with_metadata(mut self, metadata: RequestMetadata) -> Self {
92 self.metadata = Some(metadata);
93 self
94 }
95
96 pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
97 self.system = Some(system.into());
98 self
99 }
100
101 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
102 let api_tools: Vec<ApiTool> = tools.into_iter().map(ApiTool::Custom).collect();
103 self.tools = Some(api_tools);
104 self
105 }
106
107 pub fn with_web_search(mut self, config: WebSearchTool) -> Self {
108 let mut tools = self.tools.unwrap_or_default();
109 tools.push(ApiTool::WebSearch(config));
110 self.tools = Some(tools);
111 self
112 }
113
114 pub fn with_web_fetch(mut self, config: WebFetchTool) -> Self {
115 let mut tools = self.tools.unwrap_or_default();
116 tools.push(ApiTool::WebFetch(config));
117 self.tools = Some(tools);
118 self
119 }
120
121 pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
122 self.tools = Some(tools);
123 self
124 }
125
126 pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
127 self.tool_choice = Some(choice);
128 self
129 }
130
131 pub fn with_tool_choice_auto(mut self) -> Self {
132 self.tool_choice = Some(ToolChoice::Auto);
133 self
134 }
135
136 pub fn with_tool_choice_any(mut self) -> Self {
137 self.tool_choice = Some(ToolChoice::Any);
138 self
139 }
140
141 pub fn with_tool_choice_none(mut self) -> Self {
142 self.tool_choice = Some(ToolChoice::None);
143 self
144 }
145
146 pub fn with_required_tool(mut self, name: impl Into<String>) -> Self {
147 self.tool_choice = Some(ToolChoice::tool(name));
148 self
149 }
150
151 pub fn with_stream(mut self) -> Self {
152 self.stream = Some(true);
153 self
154 }
155
156 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
157 self.max_tokens = max_tokens;
158 self
159 }
160
161 pub fn with_model(mut self, model: impl Into<String>) -> Self {
162 self.model = model.into();
163 self
164 }
165
166 pub fn with_temperature(mut self, temperature: f32) -> Self {
167 self.temperature = Some(temperature);
168 self
169 }
170
171 pub fn with_top_p(mut self, top_p: f32) -> Self {
172 self.top_p = Some(top_p);
173 self
174 }
175
176 pub fn with_top_k(mut self, top_k: u32) -> Self {
177 self.top_k = Some(top_k);
178 self
179 }
180
181 pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
182 self.stop_sequences = Some(sequences);
183 self
184 }
185
186 pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
187 self.thinking = Some(config);
188 self
189 }
190
191 pub fn with_extended_thinking(mut self, budget_tokens: u32) -> Self {
192 self.thinking = Some(ThinkingConfig::enabled(budget_tokens));
193 self
194 }
195
196 pub fn with_output_format(mut self, format: OutputFormat) -> Self {
197 self.output_format = Some(format);
198 self
199 }
200
201 pub fn with_json_schema(mut self, schema: serde_json::Value) -> Self {
202 self.output_format = Some(OutputFormat::json_schema(schema));
203 self
204 }
205
206 pub fn with_context_management(mut self, management: ContextManagement) -> Self {
207 self.context_management = Some(management);
208 self
209 }
210
211 pub fn with_effort(mut self, level: EffortLevel) -> Self {
212 self.output_config = Some(OutputConfig::with_effort(level));
213 self
214 }
215
216 pub fn with_output_config(mut self, config: OutputConfig) -> Self {
217 self.output_config = Some(config);
218 self
219 }
220}
221
222impl From<String> for SystemPrompt {
223 fn from(s: String) -> Self {
224 SystemPrompt::Text(s)
225 }
226}
227
228impl From<&str> for SystemPrompt {
229 fn from(s: &str) -> Self {
230 SystemPrompt::Text(s.to_string())
231 }
232}
233
/// Request body for counting tokens.
///
/// Carries the subset of [`CreateMessageRequest`] fields copied by
/// [`CountTokensRequest::from_message_request`]. `Option` fields are omitted from the
/// serialized output while they are `None`.
#[derive(Debug, Clone, Serialize)]
pub struct CountTokensRequest {
    /// Model identifier.
    pub model: String,
    /// Messages to count; always serialized.
    pub messages: Vec<Message>,
    /// System prompt, set by `with_system`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<SystemPrompt>,
    /// Tool list, set (replaced) by `with_tools` / `with_api_tools`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ApiTool>>,
    /// Thinking configuration, set by `with_thinking`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<ThinkingConfig>,
}
245
246impl CountTokensRequest {
247 pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
248 Self {
249 model: model.into(),
250 messages,
251 system: None,
252 tools: None,
253 thinking: None,
254 }
255 }
256
257 pub fn with_system(mut self, system: impl Into<SystemPrompt>) -> Self {
258 self.system = Some(system.into());
259 self
260 }
261
262 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
263 self.tools = Some(tools.into_iter().map(ApiTool::Custom).collect());
264 self
265 }
266
267 pub fn with_api_tools(mut self, tools: Vec<ApiTool>) -> Self {
268 self.tools = Some(tools);
269 self
270 }
271
272 pub fn with_thinking(mut self, config: ThinkingConfig) -> Self {
273 self.thinking = Some(config);
274 self
275 }
276
277 pub fn from_message_request(request: &CreateMessageRequest) -> Self {
278 Self {
279 model: request.model.clone(),
280 messages: request.messages.clone(),
281 system: request.system.clone(),
282 tools: request.tools.clone(),
283 thinking: request.thinking.clone(),
284 }
285 }
286}
287
288#[derive(Debug, Clone, Deserialize)]
289pub struct CountTokensResponse {
290 pub input_tokens: u32,
291 #[serde(default, skip_serializing_if = "Option::is_none")]
292 pub context_management: Option<CountTokensContextManagement>,
293}
294
/// Context-management section of a [`CountTokensResponse`].
#[derive(Debug, Clone, Default, Deserialize)]
pub struct CountTokensContextManagement {
    /// Value of the `original_input_tokens` key; `None` when the key is absent.
    #[serde(default)]
    pub original_input_tokens: Option<u32>,
}
300
#[cfg(test)]
mod tests {
    use super::super::config::MIN_THINKING_BUDGET;
    use super::*;

    /// Model name shared by most tests.
    const SONNET: &str = "claude-sonnet-4-5";

    /// Minimal request on the shared model with a single "Hi" user message.
    fn hi_request() -> CreateMessageRequest {
        CreateMessageRequest::new(SONNET, vec![Message::user("Hi")])
    }

    /// Minimal request on the shared model with a single "Hello" user message.
    fn hello_request() -> CreateMessageRequest {
        CreateMessageRequest::new(SONNET, vec![Message::user("Hello")])
    }

    #[test]
    fn test_create_request_default_max_tokens() {
        assert_eq!(hello_request().max_tokens, DEFAULT_MAX_TOKENS);
    }

    #[test]
    fn test_create_request() {
        let request = hello_request().with_max_tokens(1000).with_temperature(0.7);

        assert_eq!(request.model, "claude-sonnet-4-5");
        assert_eq!(request.max_tokens, 1000);
        assert_eq!(request.temperature, Some(0.7));
    }

    #[test]
    fn test_request_validate_valid() {
        let outcome = hi_request()
            .with_max_tokens(4000)
            .with_extended_thinking(2000)
            .validate();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_request_validate_max_tokens_too_high() {
        let outcome = hi_request().with_max_tokens(MAX_TOKENS_128K + 1).validate();
        assert!(matches!(
            outcome,
            Err(TokenValidationError::MaxTokensTooHigh { .. })
        ));
    }

    #[test]
    fn test_request_validate_thinking_auto_clamp() {
        // A budget below the minimum is expected to be raised to the minimum.
        let request = hi_request().with_extended_thinking(500);
        let thinking = request.thinking.as_ref().unwrap();

        assert_eq!(thinking.budget(), Some(MIN_THINKING_BUDGET));
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_request_validate_thinking_exceeds_max() {
        // Budget below max_tokens: accepted.
        let accepted = hi_request()
            .with_max_tokens(2000)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        assert!(accepted.validate().is_ok());

        // Budget equal to max_tokens: rejected.
        let rejected = hi_request()
            .with_max_tokens(MIN_THINKING_BUDGET)
            .with_extended_thinking(MIN_THINKING_BUDGET);
        assert!(matches!(
            rejected.validate(),
            Err(TokenValidationError::ThinkingBudgetExceedsMaxTokens { .. })
        ));
    }

    #[test]
    fn test_request_requires_128k_beta() {
        assert!(!hi_request().requires_128k_beta());

        for max_tokens in [DEFAULT_MAX_TOKENS + 1, MAX_TOKENS_128K] {
            let request = hi_request().with_max_tokens(max_tokens);
            assert!(request.requires_128k_beta());
        }
    }

    #[test]
    fn test_count_tokens_request() {
        let request = CountTokensRequest::new(SONNET, vec![Message::user("Hello")])
            .with_system("You are a helpful assistant");

        assert_eq!(request.model, "claude-sonnet-4-5");
        assert!(request.system.is_some());
    }

    #[test]
    fn test_count_tokens_from_message_request() {
        let source = hi_request()
            .with_system("System prompt")
            .with_extended_thinking(10000);

        let counted = CountTokensRequest::from_message_request(&source);

        assert_eq!(counted.model, source.model);
        assert_eq!(counted.messages.len(), source.messages.len());
        assert!(counted.system.is_some());
        assert!(counted.thinking.is_some());
    }

    #[test]
    fn test_request_with_effort() {
        let request = CreateMessageRequest::new("claude-opus-4-5", vec![Message::user("Hi")])
            .with_effort(EffortLevel::Medium);

        let config = request
            .output_config
            .expect("with_effort must populate output_config");
        assert_eq!(config.effort, Some(EffortLevel::Medium));
    }

    #[test]
    fn test_request_with_context_management() {
        let management = ContextManagement::new().with_edit(ContextManagement::clear_thinking(2));
        let request = hi_request().with_context_management(management);
        assert!(request.context_management.is_some());
    }

    #[test]
    fn test_request_with_tool_choice() {
        let any = hi_request().with_tool_choice_any();
        assert_eq!(any.tool_choice, Some(ToolChoice::Any));

        let required = hi_request().with_required_tool("Grep");
        assert_eq!(required.tool_choice, Some(ToolChoice::tool("Grep")));
    }
}