use std::sync::mpsc::Sender;
use futures::StreamExt;
use reqwest::Client;
use tokio_util::sync::CancellationToken;
use super::AiError;
use super::sse::{AnthropicEventParser, SseParser};
use crate::ai::ai_state::AiResponse;
/// API version string sent in the `anthropic-version` header of every request.
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Endpoint of the Anthropic Messages API; requests are `POST`ed here.
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Async client for the Anthropic Messages API.
///
/// `Debug` is implemented by hand (not derived) so that the API key is never
/// written out when the client is logged or printed with `{:?}`.
#[derive(Clone)]
pub struct AsyncAnthropicClient {
    /// HTTP client used to send requests.
    client: Client,
    /// Secret sent in the `x-api-key` request header.
    api_key: String,
    /// Model identifier sent as the request body's `model` field.
    model: String,
    /// Value sent as the request body's `max_tokens` field.
    max_tokens: u32,
}

impl std::fmt::Debug for AsyncAnthropicClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncAnthropicClient")
            .field("client", &self.client)
            // Never print the credential itself.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .finish()
    }
}
impl AsyncAnthropicClient {
pub fn new(api_key: String, model: String, max_tokens: u32) -> Self {
Self {
client: Client::new(),
api_key,
model,
max_tokens,
}
}
pub async fn stream_with_cancel(
&self,
prompt: &str,
request_id: u64,
cancel_token: CancellationToken,
response_tx: Sender<AiResponse>,
) -> Result<(), AiError> {
if cancel_token.is_cancelled() {
return Err(AiError::Cancelled);
}
let request_body = serde_json::json!({
"model": self.model,
"max_tokens": self.max_tokens,
"stream": true,
"messages": [
{
"role": "user",
"content": prompt
}
]
});
let body = serde_json::to_string(&request_body).map_err(|e| AiError::Parse {
provider: "Anthropic".to_string(),
message: e.to_string(),
})?;
let response = self
.client
.post(ANTHROPIC_API_URL)
.header("x-api-key", &self.api_key)
.header("anthropic-version", ANTHROPIC_VERSION)
.header("content-type", "application/json")
.body(body)
.send()
.await
.map_err(|e| AiError::Network {
provider: "Anthropic".to_string(),
message: e.to_string(),
})?;
if !response.status().is_success() {
let code = response.status().as_u16();
let message = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(AiError::Api {
provider: "Anthropic".to_string(),
code,
message,
});
}
let mut stream = response.bytes_stream();
let mut sse_parser = SseParser::new(AnthropicEventParser);
loop {
tokio::select! {
biased;
_ = cancel_token.cancelled() => {
return Err(AiError::Cancelled);
}
chunk = stream.next() => {
match chunk {
Some(Ok(bytes)) => {
for text in sse_parser.parse_chunk(&bytes) {
if response_tx
.send(AiResponse::Chunk {
text,
request_id,
})
.is_err()
{
return Ok(());
}
}
}
Some(Err(e)) => {
return Err(AiError::Network {
provider: "Anthropic".to_string(),
message: e.to_string(),
});
}
None => {
break;
}
}
}
}
}
Ok(())
}
}
// Unit tests for this client are kept in a separate file,
// `async_anthropic_tests.rs`, pulled in via `#[path]` and compiled only for
// test builds.
#[cfg(test)]
#[path = "async_anthropic_tests.rs"]
mod async_anthropic_tests;