1use core::fmt;
2use serde::{Deserialize, Serialize};
3use std::fmt::Display;
4
5use crate::models::claude::ClaudeModel;
6
/// A single conversation turn exchanged with the Messages API.
///
/// Serializes as `{"role": ..., "content": [...]}`; serde emits the keys in
/// field declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message {
    /// Author of this turn.
    pub role: Role,
    /// Content blocks making up this turn, in order.
    pub content: Vec<Content>,
}
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
14#[serde(rename_all = "lowercase")]
15pub enum Role {
16 User,
17 Assistant,
18}
19
/// A content block inside a [`Message`] or a [`MessageResponse`].
///
/// Only text blocks are modelled: a block is its `text` plus a type
/// discriminator stored under the JSON key `"type"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Content {
    /// Text carried by this block.
    pub text: String,
    /// Block kind; written to and read from the `"type"` key.
    #[serde(rename = "type")]
    pub content_type: ContentType,
}
26
/// System prompt of a [`MessageRequest`], either plain text or a structured block.
///
/// Because the enum is `untagged`, no variant name appears on the wire:
/// `Text` serializes as a bare JSON string and `Structured` as the inner
/// [`SystemPrompt`] object. When deserializing, serde tries the variants in
/// declaration order. The same reason makes `rename_all` inert here, since
/// variant names are never serialized.
///
/// NOTE(review): the Messages API documents the structured form of `system`
/// as an *array* of text blocks, whereas `Structured` emits a single object.
/// Confirm the API accepts this shape before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(untagged, rename_all = "lowercase")]
pub enum System {
    /// Plain-text system prompt.
    Text(String),
    /// A single structured prompt block (text plus optional cache control).
    Structured(SystemPrompt),
}
33
/// A structured system-prompt block: prompt text plus optional cache control.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SystemPrompt {
    /// Prompt text.
    pub text: String,
    /// Block kind; written to and read from the `"type"` key.
    #[serde(rename = "type")]
    pub content_type: ContentType,
    /// Optional caching directive; the key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}
42
/// Prompt-caching directive attached to a [`SystemPrompt`].
///
/// Serializes as `{"type": <cache_type>}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CacheControl {
    /// Cache kind; written to and read from the `"type"` key.
    #[serde(rename = "type")]
    pub cache_type: CacheType,
}
48
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
50#[serde(rename_all = "lowercase")]
51pub enum CacheType {
52 Ephemeral,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct MessageRequest {
57 pub model: ClaudeModel,
59
60 pub max_tokens: u32,
66
67 pub messages: Vec<Message>,
69
70 pub metadata: Option<MessageMetadata>,
72
73 pub stop_sequences: Option<Vec<String>>,
75
76 pub stream: bool,
78
79 #[serde(skip_serializing_if = "Option::is_none")]
83 pub system: Option<System>,
84
85 #[serde(skip_serializing_if = "Option::is_none")]
92 pub temperature: Option<f32>,
93
94 #[serde(skip_serializing_if = "Option::is_none")]
98 pub top_k: Option<i8>,
99
100 #[serde(skip_serializing_if = "Option::is_none")]
104 pub top_p: Option<i8>,
105}
106
107impl MessageRequest {
108 pub fn new(model: ClaudeModel, max_tokens: u32, messages: Vec<Message>) -> Self {
109 Self {
110 model,
111 max_tokens,
112 messages,
113 ..Default::default()
114 }
115 }
116
117 pub fn with_metadata(mut self, metadata: MessageMetadata) -> Self {
118 self.metadata = Some(metadata);
119 self
120 }
121
122 pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
123 self.stop_sequences = Some(stop_sequences);
124 self
125 }
126
127 pub fn with_stream(mut self, stream: bool) -> Self {
128 self.stream = stream;
129 self
130 }
131
132 pub fn with_system(mut self, system: System) -> Self {
133 self.system = Some(system);
134 self
135 }
136
137 pub fn with_temperature(mut self, temperature: f32) -> Self {
138 self.temperature = Some(temperature);
139 self
140 }
141
142 pub fn with_top_k(mut self, top_k: i8) -> Self {
143 self.top_k = Some(top_k);
144 self
145 }
146
147 pub fn with_top_p(mut self, top_p: i8) -> Self {
148 self.top_p = Some(top_p);
149 self
150 }
151}
152
153impl Default for MessageRequest {
154 fn default() -> Self {
155 Self {
156 model: ClaudeModel::Claude35Sonnet,
157 max_tokens: 1000,
158 messages: Vec::new(),
159 metadata: None,
160 stop_sequences: None,
161 stream: false,
162 system: None,
163 temperature: None,
164 top_k: None,
165 top_p: None,
166 }
167 }
168}
169
/// Optional metadata attached to a [`MessageRequest`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageMetadata {
    /// Caller-supplied user identifier. There is no skip attribute, so `None`
    /// serializes as `"user_id": null`.
    pub user_id: Option<String>,
}
174
/// A response message returned by the Messages API.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageResponse {
    /// Identifier of the response message.
    pub id: String,
    /// Object type, read from the `"type"` key; only `"message"` is accepted.
    #[serde(rename = "type")]
    pub message_type: MessageType,
    /// Author of the response; only `"assistant"` is accepted.
    pub role: RoleResponse,
    /// Generated content blocks.
    // NOTE(review): `Content` only models text blocks, so a response containing
    // other block kinds (e.g. tool use, which `StopReason::ToolUse` suggests
    // can occur) would presumably fail to deserialize — confirm.
    pub content: Vec<Content>,
    /// Model that produced the response.
    pub model: ClaudeModel,
    /// Why generation stopped; `None` when the key is absent or null.
    pub stop_reason: Option<StopReason>,
    /// The stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for the request/response.
    pub usage: TokenUsage,
}
187
188#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
189#[serde(rename_all = "lowercase")]
190pub enum MessageType {
191 Message,
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
195#[serde(rename_all = "lowercase")]
196pub enum RoleResponse {
197 Assistant,
198}
199
200impl RoleResponse {
201 pub fn as_str(&self) -> &'static str {
202 match self {
203 Self::Assistant => "assistant",
204 }
205 }
206}
207
208impl Display for RoleResponse {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 write!(f, "{:?}", self)
211 }
212}
213
214#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
215#[serde(rename_all = "snake_case")]
216pub enum StopReason {
217 EndTurn,
218 MaxTokens,
219 StopSequence,
220 ToolUse,
221}
222
223impl fmt::Display for StopReason {
224 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225 match self {
226 Self::EndTurn => write!(f, "end_turn"),
227 Self::MaxTokens => write!(f, "max_tokens"),
228 Self::StopSequence => write!(f, "stop_sequence"),
229 Self::ToolUse => write!(f, "tool_use"),
230 }
231 }
232}
233
234#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
235#[serde(rename_all = "lowercase")]
236pub enum ContentType {
237 Text,
238}
239
/// Token accounting reported alongside a [`MessageResponse`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TokenUsage {
    /// Input token count reported by the API.
    pub input_tokens: u32,
    /// Output token count reported by the API.
    pub output_tokens: u32,
    /// Presumably input tokens written to the prompt cache (verify against
    /// the API docs); `None` when the key is absent or null.
    pub cache_creation_input_tokens: Option<u32>,
    /// Presumably input tokens read from the prompt cache (verify against
    /// the API docs); `None` when the key is absent or null.
    pub cache_read_input_tokens: Option<u32>,
}
247
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Builds a single text content block with the given text.
    fn text_block(text: &str) -> Content {
        Content {
            content_type: ContentType::Text,
            text: text.to_string(),
        }
    }

    #[test]
    fn should_set_metadata() {
        let request = MessageRequest::default();
        assert_eq!(request.metadata, None);

        let metadata = MessageMetadata {
            user_id: Some("user-id".to_string()),
        };
        let request = request.with_metadata(metadata.clone());
        assert_eq!(request.metadata, Some(metadata));
    }

    #[test]
    fn should_set_stop_sequences() {
        let request = MessageRequest::default();
        assert_eq!(request.stop_sequences, None);

        let stop_sequences: Vec<String> = vec!["foo".to_string(), "bar".to_string()];
        let request = request.with_stop_sequences(stop_sequences.clone());
        assert_eq!(request.stop_sequences, Some(stop_sequences));
    }

    #[test]
    fn should_set_stream() {
        // `assert!` instead of `assert_eq!(x, bool)` (clippy::bool_assert_comparison).
        let request = MessageRequest::default();
        assert!(!request.stream);

        let request = request.with_stream(true);
        assert!(request.stream);

        let request = request.with_stream(false);
        assert!(!request.stream);
    }

    #[test]
    fn should_set_system() {
        let request = MessageRequest::default();
        assert_eq!(request.system, None);

        let system = System::Structured(SystemPrompt {
            text: "You are an experienced software engineer".into(),
            content_type: ContentType::Text,
            cache_control: Some(CacheControl {
                cache_type: CacheType::Ephemeral,
            }),
        });
        let request = request.with_system(system.clone());
        assert_eq!(request.system, Some(system));
    }

    #[test]
    fn should_set_temperature() {
        let request = MessageRequest::default();
        assert_eq!(request.temperature, None);

        let temperature: f32 = 0.9;
        let request = request.with_temperature(temperature);
        assert_eq!(request.temperature, Some(temperature));
    }

    #[test]
    fn should_set_top_k() {
        let request = MessageRequest::default();
        assert_eq!(request.top_k, None);

        let top_k: i8 = 1;
        let request = request.with_top_k(top_k);
        assert_eq!(request.top_k, Some(top_k));
    }

    #[test]
    fn should_set_top_p() {
        let request = MessageRequest::default();
        assert_eq!(request.top_p, None);

        let top_p: i8 = 1;
        let request = request.with_top_p(top_p);
        assert_eq!(request.top_p, Some(top_p));
    }

    #[test]
    fn should_serialize_message() {
        // Table-driven: each role paired with its expected wire name.
        for (role, wire_role) in [(Role::User, "user"), (Role::Assistant, "assistant")] {
            let message = Message {
                role,
                content: vec![text_block("Hello World")],
            };
            assert_eq!(
                serde_json::to_value(&message).unwrap(),
                serde_json::json!({
                    "role": wire_role,
                    "content": [{
                        "type": "text",
                        "text": "Hello World"
                    }],
                })
            );
        }
    }

    #[test]
    fn should_deserialize_message() {
        for (wire_role, role) in [("user", Role::User), ("assistant", Role::Assistant)] {
            let json = serde_json::json!({
                "role": wire_role,
                "content": [{
                    "type": "text",
                    "text": "Hello World",
                }]
            });
            let message: Message = serde_json::from_value(json).unwrap();
            assert_eq!(message.role, role);
            assert_eq!(message.content, vec![text_block("Hello World")]);
        }
    }
}