use crate::error::{Error, Result};
use crate::models::tool::ToolType;
use crate::types::chat::{
ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, ChatRole, Message,
MessageContent,
};
use crate::utils::{
retry::execute_with_retry_builder, retry::handle_response_json,
retry::operations::CHAT_COMPLETION, security::create_safe_error_message, validation,
};
use async_stream::try_stream;
use futures::stream::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use reqwest::Client;
use serde_json;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio_util::codec::{FramedRead, LinesCodec};
use tokio_util::io::StreamReader;
/// Upper bound, in bytes, on a single line of a streamed response (64 KiB).
///
/// Passed to `LinesCodec::new_with_max_length` in `chat_completion_stream`.
const MAX_LINE_LENGTH: usize = 64 * 1024;

/// Maximum number of non-empty lines accepted from one streamed response
/// before `chat_completion_stream` aborts with a streaming error.
const MAX_TOTAL_CHUNKS: usize = 10_000;
/// Client for the `chat/completions` endpoint.
///
/// Offers a one-shot completion, a streaming completion and a plain-text
/// convenience wrapper; see the methods on this type.
pub struct ChatApi {
/// HTTP client used to send every request issued by this API.
pub(crate) client: Client,
/// Resolved API settings; the methods on this type read `base_url`,
/// `headers` and `retry_config` from it.
pub(crate) config: crate::client::ApiConfig,
}
impl ChatApi {
#[must_use = "returns an API client that should be used for chat operations"]
pub fn new(client: Client, config: &crate::client::ClientConfig) -> Result<Self> {
Ok(Self {
client,
config: config.to_api_config()?,
})
}
#[must_use = "returns the chat completion response that should be processed"]
pub async fn chat_completion(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse> {
validation::validate_chat_request(&request)?;
validation::check_token_limits(&request)?;
let url = self
.config
.base_url
.join("chat/completions")
.map_err(|e| Error::ApiError {
code: 400,
message: format!("Invalid URL: {e}"),
metadata: None,
})?;
let response =
execute_with_retry_builder(&self.config.retry_config, CHAT_COMPLETION, || {
self.client
.post(url.clone())
.headers((*self.config.headers).clone())
.json(&request)
})
.await?;
let chat_response: ChatCompletionResponse =
handle_response_json::<ChatCompletionResponse>(response, CHAT_COMPLETION).await?;
for choice in &chat_response.choices {
if let Some(tool_calls) = &choice.message.tool_calls {
for tc in tool_calls {
if tc.kind != ToolType::Function {
return Err(Error::SchemaValidationError(format!(
"Invalid tool call kind: {}. Expected 'function'",
tc.kind
)));
}
}
}
}
Ok(chat_response)
}
#[must_use = "returns a stream that should be consumed to receive completion chunks"]
pub fn chat_completion_stream(
&self,
request: ChatCompletionRequest,
) -> Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send + '_>> {
let client = self.client.clone();
let headers = Arc::clone(&self.config.headers);
if let Err(e) = validation::validate_chat_request(&request) {
return Box::pin(futures::stream::once(async { Err(e) }));
}
if let Err(e) = validation::check_token_limits(&request) {
return Box::pin(futures::stream::once(async { Err(e) }));
}
let chunk_count = AtomicUsize::new(0);
let url = match self.config.base_url.join("chat/completions") {
Ok(url) => url,
Err(e) => {
return Box::pin(futures::stream::once(async move {
Err(Error::ApiError {
code: 400,
message: format!("Invalid URL: {e}"),
metadata: None,
})
}));
}
};
let mut req_body = match serde_json::to_value(&request) {
Ok(body) => body,
Err(e) => {
return Box::pin(futures::stream::once(async move {
Err(Error::ApiError {
code: 500,
message: format!("Request serialization error: {e}"),
metadata: None,
})
}));
}
};
req_body["stream"] = serde_json::Value::Bool(true);
let stream = try_stream! {
let response = client
.post(url)
.headers((*headers).clone())
.json(&req_body)
.send()
.await
.map_err(|e| {
Error::ApiError {
code: 500,
message: format!("Request failed: {e}"),
metadata: None,
}
})?;
let response = response.error_for_status().map_err(|e| {
Error::ApiError {
code: e.status().map(|s| s.as_u16()).unwrap_or(500),
message: e.to_string(),
metadata: None,
}
})?;
let byte_stream = response.bytes_stream().map_err(std::io::Error::other);
let stream_reader = StreamReader::new(byte_stream);
let mut lines = FramedRead::new(stream_reader, LinesCodec::new_with_max_length(MAX_LINE_LENGTH));
while let Some(line_result) = lines.next().await {
let line = line_result.map_err(|e| Error::StreamingError(format!("Failed to read stream line: {e}")))?;
if line.trim().is_empty() {
continue;
}
let current_chunk = chunk_count.fetch_add(1, Ordering::Relaxed) + 1;
if current_chunk > MAX_TOTAL_CHUNKS {
Err(Error::StreamingError(format!(
"Too many chunks: {current_chunk} (max: {MAX_TOTAL_CHUNKS})"
)))?;
}
if line.starts_with("data:") {
let data_part = line.trim_start_matches("data:").trim();
if data_part == "[DONE]" {
break;
}
match serde_json::from_str::<ChatCompletionChunk>(data_part) {
Ok(chunk) => {
yield chunk;
},
Err(e) => {
let error_msg = create_safe_error_message(
&format!("Failed to parse streaming chunk: {e}. Data: {data_part}"),
"Streaming chunk parse error"
);
#[cfg(feature = "tracing")]
tracing::error!("Streaming parse error: {}", error_msg);
let _ = error_msg; continue;
}
}
} else if line.starts_with(":") {
continue;
} else {
match serde_json::from_str::<ChatCompletionChunk>(&line) {
Ok(chunk) => {
yield chunk;
},
Err(_) => continue,
}
}
}
};
Box::pin(stream)
}
pub async fn simple_completion(&self, model: &str, user_message: &str) -> Result<String> {
let request = ChatCompletionRequest {
model: model.to_string(),
messages: vec![Message::text(ChatRole::User, user_message)],
..Default::default()
};
let response = self.chat_completion(request).await?;
let choice = response.choices.first().ok_or_else(|| Error::ApiError {
code: 500,
message: "API returned no choices".into(),
metadata: None,
})?;
match &choice.message.content {
MessageContent::Text(content) => Ok(content.clone()),
MessageContent::Parts(_) => Err(Error::ConfigError(
"Unexpected multimodal content in simple completion response".into(),
)),
}
}
}