mod sse;
mod types;
use async_stream::stream;
use futures::Stream;
use crate::client::error::LlmError;
use crate::client::http::HttpClient;
use crate::client::models::{Message, MessageOptions, StreamEvent};
use crate::client::traits::{LlmProvider, StreamMsgFuture};
use std::future::Future;
use std::pin::Pin;
/// Credentials and model selection for talking to the Anthropic API.
///
/// `Debug` is intentionally not derived so the API key cannot end up in
/// logs through a stray `{:?}`.
pub struct AnthropicProvider {
    /// Secret key handed to `types::get_request_headers` for every request.
    pub api_key: String,

    /// Model identifier passed to the request-body builders.
    pub model: String,
}
impl AnthropicProvider {
pub fn new(api_key: String, model: String) -> Self {
Self { api_key, model }
}
}
impl LlmProvider for AnthropicProvider {
    /// Sends `messages` as a single (non-streaming) request and resolves to
    /// the parsed reply.
    ///
    /// All inputs are cloned up front so the returned future owns its data
    /// and does not borrow from `self` or the arguments.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the request body, from the HTTP
    /// POST, or from `types::parse_response`.
    fn send_msg(
        &self,
        client: &HttpClient,
        messages: &[Message],
        options: &MessageOptions,
    ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
        let client = client.clone();
        let api_key = self.api_key.clone();
        let model = self.model.clone();
        let messages = messages.to_vec();
        let options = options.clone();
        Box::pin(async move {
            let body = types::build_request_body(&messages, &options, &model)?;
            let headers = types::get_request_headers(&api_key);
            // The HTTP client takes borrowed `(&str, &str)` pairs.
            let headers_ref: Vec<(&str, &str)> =
                headers.iter().map(|(k, v)| (*k, v.as_str())).collect();
            let response = client
                .post(types::get_api_url(), &headers_ref, &body)
                .await?;
            types::parse_response(&response)
        })
    }

    /// Sends `messages` as a streaming request and resolves to a stream of
    /// parsed events.
    ///
    /// The outer future performs the POST; the stream it resolves to reads
    /// the response body chunk by chunk, splits it into SSE events with
    /// `sse::parse_sse_chunk`, and yields every event for which
    /// `sse::parse_stream_event` returns `Ok(Some(_))`.
    ///
    /// # Errors
    ///
    /// The outer future fails if the request body cannot be built or the
    /// POST fails. The stream yields one `Err` and then ends when:
    /// * the transport reports an error for a chunk,
    /// * the body contains invalid UTF-8, or ends in the middle of a
    ///   multi-byte character (`SSE_DECODE_ERROR`),
    /// * an SSE event cannot be parsed.
    fn send_msg_stream(
        &self,
        client: &HttpClient,
        messages: &[Message],
        options: &MessageOptions,
    ) -> StreamMsgFuture {
        let client = client.clone();
        let api_key = self.api_key.clone();
        let model = self.model.clone();
        let messages = messages.to_vec();
        let options = options.clone();
        Box::pin(async move {
            let body = types::build_streaming_request_body(&messages, &options, &model)?;
            let headers = types::get_request_headers(&api_key);
            // The HTTP client takes borrowed `(&str, &str)` pairs.
            let headers_ref: Vec<(&str, &str)> =
                headers.iter().map(|(k, v)| (*k, v.as_str())).collect();
            let byte_stream = client
                .post_stream(types::get_api_url(), &headers_ref, &body)
                .await?;
            use futures::StreamExt;
            let event_stream = stream! {
                // Decoded text that has not yet been consumed as SSE events.
                let mut buffer = String::new();
                // Raw bytes received but not yet decoded. Chunk boundaries are
                // arbitrary, so a multi-byte UTF-8 character may be split
                // across two chunks; its leading bytes wait here until the
                // rest arrives.
                let mut pending: Vec<u8> = Vec::new();
                let mut byte_stream = byte_stream;
                while let Some(chunk_result) = byte_stream.next().await {
                    let bytes = match chunk_result {
                        Ok(bytes) => bytes,
                        Err(e) => {
                            yield Err(e);
                            return;
                        }
                    };
                    pending.extend_from_slice(&bytes);

                    // Decode the longest complete, valid prefix of `pending`.
                    // `None` means the bytes are invalid, not merely cut off.
                    let consumed = match std::str::from_utf8(&pending) {
                        Ok(text) => {
                            buffer.push_str(text);
                            Some(pending.len())
                        }
                        // `error_len() == None`: the input ends inside a
                        // character, and everything before it is valid.
                        Err(e) if e.error_len().is_none() => {
                            let complete = e.valid_up_to();
                            buffer.push_str(&String::from_utf8_lossy(&pending[..complete]));
                            Some(complete)
                        }
                        Err(_) => None,
                    };
                    match consumed {
                        Some(len) => {
                            pending.drain(..len);
                        }
                        None => {
                            yield Err(LlmError::new("SSE_DECODE_ERROR", "Invalid UTF-8 in stream"));
                            return;
                        }
                    }

                    let (events, remaining) = sse::parse_sse_chunk(&buffer);
                    buffer = remaining;
                    for sse_event in events {
                        match sse::parse_stream_event(&sse_event) {
                            Ok(Some(stream_event)) => yield Ok(stream_event),
                            // Events with nothing to report are skipped.
                            Ok(None) => {}
                            Err(e) => {
                                yield Err(e);
                                return;
                            }
                        }
                    }
                }
                // The body ended in the middle of a multi-byte character.
                if !pending.is_empty() {
                    yield Err(LlmError::new("SSE_DECODE_ERROR", "Invalid UTF-8 in stream"));
                }
            };
            Ok(Box::pin(event_stream)
                as Pin<
                    Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>,
                >)
        })
    }
}