llm_stream/
anthropic.rs

1use eventsource_client::{Client as EsClient, ClientBuilder, ReconnectOptions, SSE};
2use futures::stream::{Stream, TryStreamExt};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::Duration;
6
7use crate::error::Error;
8
// Messages API
/// Path of the Messages endpoint; it is appended to the client's `api_url`
/// to form the full request URL.
const MESSAGES_CREATE: &str = "/messages";
11
12#[derive(Debug, Serialize, Deserialize)]
13pub struct Usage {
14    pub input_tokens: Option<u32>,
15    pub output_tokens: Option<u32>,
16}
17
18#[derive(Debug, Serialize, Deserialize)]
19pub struct Content {
20    /// Determines the content shape.
21    pub r#type: String,
22    /// Response content
23    pub text: Option<String>,
24}
25
26#[derive(Debug, Serialize, Deserialize, Clone)]
27pub struct Message {
28    pub role: Role,
29    pub content: String,
30}
31
32#[derive(Debug, Serialize, Deserialize, Clone)]
33#[serde(rename_all = "lowercase")]
34pub enum Role {
35    Assistant,
36    User,
37}
38
39#[derive(Debug, Serialize, Deserialize, Default)]
40pub struct MessageBody {
41    /// The model that will complete your prompt.
42    /// See this link for additional details and options: https://docs.anthropic.com/claude/docs/models-overview
43    pub model: String,
44    /// Input messages.
45    pub messages: Vec<Message>,
46    /// The maximum number of tokens to generate before stopping.
47    pub max_tokens: u32,
48    /// An object describing metadata about the request.
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub metadata: Option<HashMap<String, String>>,
51    /// Custom text sequences that will cause the model to stop generating.
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub stop_sequences: Option<Vec<String>>,
54    /// Whether to incrementally stream the response using server-sent events.
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub stream: Option<bool>,
57    /// System prompt
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub system: Option<String>,
60    /// Amount of randomness injected into the response.
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub temperature: Option<f32>,
63    /// Only sample from the top K options for each subsequent token.
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub top_k: Option<u32>,
66    /// Use nucleus sampling.
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub top_p: Option<f32>,
69}
70
71impl MessageBody {
72    /// Creates a new `MessageBody`
73    #[must_use]
74    pub fn new(model: &str, messages: Vec<Message>, max_tokens: u32) -> Self {
75        Self {
76            model: model.into(),
77            messages,
78            max_tokens,
79            stream: Some(true),
80            ..Default::default()
81        }
82    }
83}
84
85#[derive(Debug, Serialize, Deserialize)]
86pub struct MessageResponse {
87    /// Unique object identifier.
88    pub id: String,
89    /// Object type.
90    pub r#type: String,
91    /// Conversational role of the generated message.
92    pub role: String,
93    /// Content generated by the model.
94    pub content: Vec<Content>,
95    /// The model that handled the request.
96    pub model: String,
97    /// The reason that the model stopped.
98    pub stop_reason: Option<String>,
99    /// Which custom stop sequence was generated, if any.
100    pub stop_sequence: Option<String>,
101    /// Billing and rate-limit usage.
102    pub usage: Usage,
103}
104
105#[derive(Debug, Serialize, Deserialize)]
106struct MessageEventResponse {
107    /// Unique object identifier.
108    pub id: String,
109    /// Object type.
110    pub r#type: String,
111    /// Conversational role of the generated message.
112    pub role: String,
113    /// Content messages.
114    pub content: Vec<Content>,
115    /// The model that handled the request.
116    pub model: String,
117    /// The reason that the model stopped.
118    pub stop_reason: Option<String>,
119    /// Which custom stop sequence was generated, if any.
120    pub stop_sequence: Option<String>,
121    /// Billing and rate-limit usage.
122    pub usage: Usage,
123}
124
125#[derive(Debug, Serialize, Deserialize)]
126struct Delta {
127    /// Determines the content shape.
128    pub r#type: Option<String>,
129    /// Response content
130    pub text: Option<String>,
131    pub stop_reason: Option<String>,
132    pub end_turn: Option<String>,
133}
134
135#[derive(Debug, Serialize, Deserialize, Default)]
136#[serde(rename_all = "snake_case")]
137enum MessageEventType {
138    #[default]
139    Error,
140    MessageStart,
141    MessageDelta,
142    MessageStop,
143    Ping,
144    ContentBlockStart,
145    ContentBlockDelta,
146    ContentBlockStop,
147    Comment,
148}
149
150#[derive(Debug, Serialize, Deserialize, Default)]
151struct MessageEvent {
152    /// Event type
153    pub r#type: MessageEventType,
154    /// Init message
155    pub message: Option<MessageEventResponse>,
156    /// Event index
157    pub index: Option<i32>,
158    /// Content block
159    pub content_block: Option<Content>,
160    /// Delta block
161    pub delta: Option<Delta>,
162    /// Usage
163    pub usage: Option<Usage>,
164    /// Comment
165    pub comment: Option<String>,
166}
167
168#[derive(Debug, Serialize, Deserialize, Clone)]
169pub struct Auth {
170    pub api_key: String,
171    pub version: Option<String>,
172}
173
174impl Auth {
175    #[must_use]
176    pub fn new(api_key: String, version: Option<String>) -> Self {
177        Self { api_key, version }
178    }
179
180    pub fn from_env() -> Result<Self, Error> {
181        let api_key = match std::env::var("ANTHROPIC_API_KEY") {
182            Ok(key) => key,
183            Err(_) => return Err(Error::AuthError("ANTHROPIC_API_KEY not found".to_string())),
184        };
185        let version = std::env::var("ANTHROPIC_API_VERSION").ok();
186        Ok(Self { api_key, version })
187    }
188}
189
190#[derive(Debug, Clone)]
191pub struct Client {
192    pub auth: Auth,
193    pub api_url: String,
194}
195
196impl Client {
197    pub fn new(auth: Auth, api_url: impl Into<String>) -> Self {
198        Self {
199            auth,
200            api_url: api_url.into(),
201        }
202    }
203}
204
205impl Client {
206    pub fn delta<'a>(
207        &'a self,
208        message_body: &'a MessageBody,
209    ) -> Result<impl Stream<Item = Result<String, Error>> + 'a, Error> {
210        log::debug!("message_body: {:#?}", message_body);
211
212        let request_body = match serde_json::to_value(message_body) {
213            Ok(body) => body,
214            Err(e) => return Err(Error::Serde(e)),
215        };
216        log::debug!("request_body: {:#?}", request_body);
217
218        let anthropic_version = self.auth.version.as_deref().unwrap_or("2023-06-01");
219
220        let client = ClientBuilder::for_url(&(self.api_url.clone() + MESSAGES_CREATE))?
221            .header("anthropic-version", anthropic_version)?
222            .header("content-type", "application/json")?
223            .header("x-api-key", &self.auth.api_key)?
224            .method("POST".into())
225            .body(request_body.to_string())
226            .reconnect(
227                ReconnectOptions::reconnect(true)
228                    .retry_initial(false)
229                    .delay(Duration::from_secs(1))
230                    .backoff_factor(2)
231                    .delay_max(Duration::from_secs(60))
232                    .build(),
233            )
234            .build();
235
236        let stream = Box::pin(client.stream())
237            .map_err(Error::from)
238            .map_ok(|event| match event {
239                SSE::Connected(_) => String::default(),
240                SSE::Event(ev) => match serde_json::from_str::<MessageEvent>(&ev.data) {
241                    Ok(ev) => {
242                        if matches!(ev.r#type, MessageEventType::ContentBlockDelta) {
243                            if let Some(delta) = ev.delta {
244                                delta.text.map_or_else(String::default, |text| text)
245                            } else {
246                                String::default()
247                            }
248                        } else {
249                            String::default()
250                        }
251                    }
252                    Err(e) => {
253                        log::error!("Error parsing event: {:#?}", ev);
254                        log::error!("Error: {:#?}", e);
255                        String::default()
256                    }
257                },
258                SSE::Comment(comment) => {
259                    log::debug!("Comment: {:#?}", comment);
260                    String::default()
261                }
262            });
263
264        Ok(stream)
265    }
266}