1use eventsource_client::{Client as EsClient, ClientBuilder, ReconnectOptions, SSE};
2use futures::stream::{Stream, TryStreamExt};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::Duration;
6
7use crate::error::Error;
8
/// Path suffix appended to the client's base `api_url` to form the request URL.
const MESSAGES_CREATE: &str = "/messages";
11
12#[derive(Debug, Serialize, Deserialize)]
13pub struct Usage {
14 pub input_tokens: Option<u32>,
15 pub output_tokens: Option<u32>,
16}
17
18#[derive(Debug, Serialize, Deserialize)]
19pub struct Content {
20 pub r#type: String,
22 pub text: Option<String>,
24}
25
26#[derive(Debug, Serialize, Deserialize, Clone)]
27pub struct Message {
28 pub role: Role,
29 pub content: String,
30}
31
32#[derive(Debug, Serialize, Deserialize, Clone)]
33#[serde(rename_all = "lowercase")]
34pub enum Role {
35 Assistant,
36 User,
37}
38
39#[derive(Debug, Serialize, Deserialize, Default)]
40pub struct MessageBody {
41 pub model: String,
44 pub messages: Vec<Message>,
46 pub max_tokens: u32,
48 #[serde(skip_serializing_if = "Option::is_none")]
50 pub metadata: Option<HashMap<String, String>>,
51 #[serde(skip_serializing_if = "Option::is_none")]
53 pub stop_sequences: Option<Vec<String>>,
54 #[serde(skip_serializing_if = "Option::is_none")]
56 pub stream: Option<bool>,
57 #[serde(skip_serializing_if = "Option::is_none")]
59 pub system: Option<String>,
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub temperature: Option<f32>,
63 #[serde(skip_serializing_if = "Option::is_none")]
65 pub top_k: Option<u32>,
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub top_p: Option<f32>,
69}
70
71impl MessageBody {
72 #[must_use]
74 pub fn new(model: &str, messages: Vec<Message>, max_tokens: u32) -> Self {
75 Self {
76 model: model.into(),
77 messages,
78 max_tokens,
79 stream: Some(true),
80 ..Default::default()
81 }
82 }
83}
84
85#[derive(Debug, Serialize, Deserialize)]
86pub struct MessageResponse {
87 pub id: String,
89 pub r#type: String,
91 pub role: String,
93 pub content: Vec<Content>,
95 pub model: String,
97 pub stop_reason: Option<String>,
99 pub stop_sequence: Option<String>,
101 pub usage: Usage,
103}
104
105#[derive(Debug, Serialize, Deserialize)]
106struct MessageEventResponse {
107 pub id: String,
109 pub r#type: String,
111 pub role: String,
113 pub content: Vec<Content>,
115 pub model: String,
117 pub stop_reason: Option<String>,
119 pub stop_sequence: Option<String>,
121 pub usage: Usage,
123}
124
125#[derive(Debug, Serialize, Deserialize)]
126struct Delta {
127 pub r#type: Option<String>,
129 pub text: Option<String>,
131 pub stop_reason: Option<String>,
132 pub end_turn: Option<String>,
133}
134
135#[derive(Debug, Serialize, Deserialize, Default)]
136#[serde(rename_all = "snake_case")]
137enum MessageEventType {
138 #[default]
139 Error,
140 MessageStart,
141 MessageDelta,
142 MessageStop,
143 Ping,
144 ContentBlockStart,
145 ContentBlockDelta,
146 ContentBlockStop,
147 Comment,
148}
149
150#[derive(Debug, Serialize, Deserialize, Default)]
151struct MessageEvent {
152 pub r#type: MessageEventType,
154 pub message: Option<MessageEventResponse>,
156 pub index: Option<i32>,
158 pub content_block: Option<Content>,
160 pub delta: Option<Delta>,
162 pub usage: Option<Usage>,
164 pub comment: Option<String>,
166}
167
168#[derive(Debug, Serialize, Deserialize, Clone)]
169pub struct Auth {
170 pub api_key: String,
171 pub version: Option<String>,
172}
173
174impl Auth {
175 #[must_use]
176 pub fn new(api_key: String, version: Option<String>) -> Self {
177 Self { api_key, version }
178 }
179
180 pub fn from_env() -> Result<Self, Error> {
181 let api_key = match std::env::var("ANTHROPIC_API_KEY") {
182 Ok(key) => key,
183 Err(_) => return Err(Error::AuthError("ANTHROPIC_API_KEY not found".to_string())),
184 };
185 let version = std::env::var("ANTHROPIC_API_VERSION").ok();
186 Ok(Self { api_key, version })
187 }
188}
189
190#[derive(Debug, Clone)]
191pub struct Client {
192 pub auth: Auth,
193 pub api_url: String,
194}
195
196impl Client {
197 pub fn new(auth: Auth, api_url: impl Into<String>) -> Self {
198 Self {
199 auth,
200 api_url: api_url.into(),
201 }
202 }
203}
204
205impl Client {
206 pub fn delta<'a>(
207 &'a self,
208 message_body: &'a MessageBody,
209 ) -> Result<impl Stream<Item = Result<String, Error>> + 'a, Error> {
210 log::debug!("message_body: {:#?}", message_body);
211
212 let request_body = match serde_json::to_value(message_body) {
213 Ok(body) => body,
214 Err(e) => return Err(Error::Serde(e)),
215 };
216 log::debug!("request_body: {:#?}", request_body);
217
218 let anthropic_version = self.auth.version.as_deref().unwrap_or("2023-06-01");
219
220 let client = ClientBuilder::for_url(&(self.api_url.clone() + MESSAGES_CREATE))?
221 .header("anthropic-version", anthropic_version)?
222 .header("content-type", "application/json")?
223 .header("x-api-key", &self.auth.api_key)?
224 .method("POST".into())
225 .body(request_body.to_string())
226 .reconnect(
227 ReconnectOptions::reconnect(true)
228 .retry_initial(false)
229 .delay(Duration::from_secs(1))
230 .backoff_factor(2)
231 .delay_max(Duration::from_secs(60))
232 .build(),
233 )
234 .build();
235
236 let stream = Box::pin(client.stream())
237 .map_err(Error::from)
238 .map_ok(|event| match event {
239 SSE::Connected(_) => String::default(),
240 SSE::Event(ev) => match serde_json::from_str::<MessageEvent>(&ev.data) {
241 Ok(ev) => {
242 if matches!(ev.r#type, MessageEventType::ContentBlockDelta) {
243 if let Some(delta) = ev.delta {
244 delta.text.map_or_else(String::default, |text| text)
245 } else {
246 String::default()
247 }
248 } else {
249 String::default()
250 }
251 }
252 Err(e) => {
253 log::error!("Error parsing event: {:#?}", ev);
254 log::error!("Error: {:#?}", e);
255 String::default()
256 }
257 },
258 SSE::Comment(comment) => {
259 log::debug!("Comment: {:#?}", comment);
260 String::default()
261 }
262 });
263
264 Ok(stream)
265 }
266}