// anthropic_rs/completion/message.rs

1use core::fmt;
2use serde::{Deserialize, Serialize};
3use std::fmt::Display;
4
5use crate::models::claude::ClaudeModel;
6
/// A single conversation turn: who authored it and its content blocks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message {
    /// Author of the message (`user` or `assistant`).
    pub role: Role,
    /// Content blocks making up the turn; each carries its own `type` tag.
    pub content: Vec<Content>,
}
12
/// Author of a request message, serialized in lowercase
/// (`"user"` / `"assistant"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
19
/// A single content block within a message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Content {
    /// Text payload of the block.
    pub text: String,
    /// Serialized as the JSON field `"type"` (`type` is a Rust keyword).
    #[serde(rename = "type")]
    pub content_type: ContentType,
}
26
27#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
28#[serde(untagged, rename_all = "lowercase")]
29pub enum System {
30    Text(String),
31    Structured(SystemPrompt),
32}
33
/// A structured system-prompt block.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SystemPrompt {
    /// Prompt text.
    pub text: String,
    /// Serialized as the JSON field `"type"`.
    #[serde(rename = "type")]
    pub content_type: ContentType,
    /// Optional cache-control marker; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,
}
42
/// Cache-control settings attached to a [`SystemPrompt`] block.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CacheControl {
    /// Serialized as the JSON field `"type"`.
    #[serde(rename = "type")]
    pub cache_type: CacheType,
}
48
/// Cache strategy; serialized lowercase (`"ephemeral"` is the only
/// variant modeled here).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum CacheType {
    Ephemeral,
}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct MessageRequest {
57    /// The model that will complete your prompt e.g. Claude 3.5 Sonnet
58    pub model: ClaudeModel,
59
60    /// The maximum number of tokens to generate before stopping.
61    ///
62    /// Defaults to 1000 tokens.
63    ///
64    /// Note that models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate.
65    pub max_tokens: u32,
66
67    /// Input messages.
68    pub messages: Vec<Message>,
69
70    /// An object describing metadata about the request.
71    pub metadata: Option<MessageMetadata>,
72
73    /// Custom text sequences that will cause the model to stop generating.
74    pub stop_sequences: Option<Vec<String>>,
75
76    /// Whether to incrementally stream the response using server-sent events.
77    pub stream: bool,
78
79    /// System prompt.
80    ///
81    /// A system prompt is a way of providing context and instructions to Claude, such as specifying a particular goal or role.
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub system: Option<System>,
84
85    /// Amount of randomness injected into the response.
86    ///
87    /// Defaults to 1.0. Ranges from 0.0 to 1.0.
88    ///
89    /// Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks.
90    /// Note that even with temperature of 0.0, the results will not be fully deterministic.
91    #[serde(skip_serializing_if = "Option::is_none")]
92    pub temperature: Option<f32>,
93
94    /// Only sample from the top K options for each subsequent token.
95    ///
96    /// Used to remove "long tail" low probability responses. Learn more technical details here.
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub top_k: Option<i8>,
99
100    /// Use nucleus sampling.
101    ///
102    /// In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both.
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub top_p: Option<i8>,
105}
106
107impl MessageRequest {
108    pub fn new(model: ClaudeModel, max_tokens: u32, messages: Vec<Message>) -> Self {
109        Self {
110            model,
111            max_tokens,
112            messages,
113            ..Default::default()
114        }
115    }
116
117    pub fn with_metadata(mut self, metadata: MessageMetadata) -> Self {
118        self.metadata = Some(metadata);
119        self
120    }
121
122    pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
123        self.stop_sequences = Some(stop_sequences);
124        self
125    }
126
127    pub fn with_stream(mut self, stream: bool) -> Self {
128        self.stream = stream;
129        self
130    }
131
132    pub fn with_system(mut self, system: System) -> Self {
133        self.system = Some(system);
134        self
135    }
136
137    pub fn with_temperature(mut self, temperature: f32) -> Self {
138        self.temperature = Some(temperature);
139        self
140    }
141
142    pub fn with_top_k(mut self, top_k: i8) -> Self {
143        self.top_k = Some(top_k);
144        self
145    }
146
147    pub fn with_top_p(mut self, top_p: i8) -> Self {
148        self.top_p = Some(top_p);
149        self
150    }
151}
152
153impl Default for MessageRequest {
154    fn default() -> Self {
155        Self {
156            model: ClaudeModel::Claude35Sonnet,
157            max_tokens: 1000,
158            messages: Vec::new(),
159            metadata: None,
160            stop_sequences: None,
161            stream: false,
162            system: None,
163            temperature: None,
164            top_k: None,
165            top_p: None,
166        }
167    }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
171pub struct MessageMetadata {
172    pub user_id: Option<String>,
173}
174
/// A response body from the Messages API.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageResponse {
    /// Unique object identifier.
    pub id: String,
    /// Serialized as the JSON field `"type"`; only `message` is modeled.
    #[serde(rename = "type")]
    pub message_type: MessageType,
    /// Role of the author; only `assistant` is modeled for responses.
    pub role: RoleResponse,
    /// Generated content blocks.
    pub content: Vec<Content>,
    /// The model that produced the response.
    pub model: ClaudeModel,
    /// Why generation stopped; `None` when the API does not provide one.
    pub stop_reason: Option<StopReason>,
    /// Which custom stop sequence was hit, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request/response pair.
    pub usage: TokenUsage,
}
187
/// The response object's `type` tag; serialized lowercase (`"message"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageType {
    Message,
}
193
/// Role field of a response message, serialized lowercase
/// (`"assistant"` is the only variant modeled).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum RoleResponse {
    Assistant,
}
199
200impl RoleResponse {
201    pub fn as_str(&self) -> &'static str {
202        match self {
203            Self::Assistant => "assistant",
204        }
205    }
206}
207
208impl Display for RoleResponse {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        write!(f, "{:?}", self)
211    }
212}
213
/// Reason the model stopped generating, serialized in snake_case
/// (`end_turn`, `max_tokens`, `stop_sequence`, `tool_use`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    StopSequence,
    ToolUse,
}
222
223impl fmt::Display for StopReason {
224    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225        match self {
226            Self::EndTurn => write!(f, "end_turn"),
227            Self::MaxTokens => write!(f, "max_tokens"),
228            Self::StopSequence => write!(f, "stop_sequence"),
229            Self::ToolUse => write!(f, "tool_use"),
230        }
231    }
232}
233
/// Discriminator for [`Content`] / [`SystemPrompt`] blocks; serialized
/// lowercase. Only `text` blocks are currently modeled.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ContentType {
    Text,
}
239
/// Token counts reported in a response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TokenUsage {
    /// Tokens consumed by the input prompt.
    pub input_tokens: u32,
    /// Tokens generated for the response.
    pub output_tokens: u32,
    // NOTE(review): names suggest these are prompt-caching counters;
    // they are `Option` because the API may omit the fields — confirm.
    pub cache_creation_input_tokens: Option<u32>,
    pub cache_read_input_tokens: Option<u32>,
}
247
#[cfg(test)]
mod tests {
    //! Unit tests for the `MessageRequest` builder methods and for the
    //! serde round-trip of `Message`. Uses `pretty_assertions` for
    //! readable assertion diffs.
    use super::*;
    use pretty_assertions::assert_eq;

    // Each builder test follows the same shape: a default request has the
    // field unset; the with_* method sets it and returns the request.
    #[test]
    fn should_set_metadata() {
        let request = MessageRequest::default();
        assert_eq!(request.metadata, None);

        let metadata = MessageMetadata {
            user_id: Some("user-id".to_string()),
        };
        let request = request.with_metadata(metadata.clone());
        assert_eq!(request.metadata, Some(metadata));
    }

    #[test]
    fn should_set_stop_sequences() {
        let request = MessageRequest::default();
        assert_eq!(request.stop_sequences, None);

        let stop_sequences: Vec<String> = vec!["foo".to_string(), "bar".to_string()];
        let request = request.with_stop_sequences(stop_sequences.clone());
        assert_eq!(request.stop_sequences, Some(stop_sequences));
    }

    // `stream` is a plain bool (not Option), so also verify it can be
    // toggled back off.
    #[test]
    fn should_set_stream() {
        let request = MessageRequest::default();
        assert_eq!(request.stream, false);

        let request = request.with_stream(true);
        assert_eq!(request.stream, true);

        let request = request.with_stream(false);
        assert_eq!(request.stream, false);
    }

    // Uses the structured variant to also exercise SystemPrompt with
    // cache_control populated.
    #[test]
    fn should_set_system() {
        let request = MessageRequest::default();
        assert_eq!(request.system, None);

        let system = System::Structured(SystemPrompt {
            text: "You are an experienced software engineer".into(),
            content_type: ContentType::Text,
            cache_control: Some(CacheControl {
                cache_type: CacheType::Ephemeral,
            }),
        });
        let request = request.with_system(system.clone());
        assert_eq!(request.system, Some(system));
    }

    #[test]
    fn should_set_temperature() {
        let request = MessageRequest::default();
        assert_eq!(request.temperature, None);

        let temperature: f32 = 0.9;
        let request = request.with_temperature(temperature);
        assert_eq!(request.temperature, Some(temperature));
    }

    #[test]
    fn should_set_top_k() {
        let request = MessageRequest::default();
        assert_eq!(request.top_k, None);

        let top_k: i8 = 1;
        let request = request.with_top_k(top_k);
        assert_eq!(request.top_k, Some(top_k));
    }

    #[test]
    fn should_set_top_p() {
        let request = MessageRequest::default();
        assert_eq!(request.top_p, None);

        let top_p: i8 = 1;
        let request = request.with_top_p(top_p);
        assert_eq!(request.top_p, Some(top_p));
    }

    // Pins the wire format: lowercase role names and the `content_type`
    // field renamed to `"type"`.
    #[test]
    fn should_serialize_message() {
        let message = Message {
            role: Role::User,
            content: vec![Content {
                content_type: ContentType::Text,
                text: "Hello World".to_string(),
            }],
        };
        assert_eq!(
            serde_json::to_value(&message).unwrap(),
            serde_json::json!({
                "role": "user",
                "content": [{
                    "type": "text",
                    "text": "Hello World"
                }],
            })
        );

        let message = Message {
            role: Role::Assistant,
            content: vec![Content {
                content_type: ContentType::Text,
                text: "Hello World".to_string(),
            }],
        };
        assert_eq!(
            serde_json::to_value(&message).unwrap(),
            serde_json::json!({
                "role": "assistant",
                "content": [{
                    "type": "text",
                    "text": "Hello World"
                }],
            })
        );
    }

    // Inverse of the serialization test: the same JSON shapes parse back
    // into the expected Rust values for both roles.
    #[test]
    fn should_deserialize_message() {
        let json = serde_json::json!({
            "role": "user",
            "content": [{
                "type": "text",
                "text": "Hello World",
            }]
        });
        let message: Message = serde_json::from_value(json).unwrap();
        assert_eq!(message.role, Role::User);
        assert_eq!(
            message.content,
            vec![Content {
                content_type: ContentType::Text,
                text: "Hello World".to_string(),
            }]
        );

        let json = serde_json::json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello World",
            }]
        });
        let message: Message = serde_json::from_value(json).unwrap();
        assert_eq!(message.role, Role::Assistant);
        assert_eq!(
            message.content,
            vec![Content {
                content_type: ContentType::Text,
                text: "Hello World".to_string(),
            }]
        );
    }
}