use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, error};
use super::base::{LlmResponse, Message, ResponseFormat, Tool, ToolCall};
use crate::error::{NeomemxError, Result};
use crate::llm::utils::extract_json;

/// Configuration for an OpenAI-compatible chat completions endpoint.
#[derive(Debug, Clone)]
pub struct OpenAICompatConfig {
    pub base_url: String,
    pub api_key: String,
    pub model: String,
    pub temperature: f32,
    pub max_tokens: u32,
    pub top_p: f32,
    pub provider_name: &'static str,
}

/// HTTP client for providers that expose an OpenAI-compatible API.
pub struct OpenAICompatClient {
    config: OpenAICompatConfig,
    client: Client,
    skip_sampling_params: bool,
}

/// Request body for the `chat/completions` endpoint. Optional fields are
/// skipped when `None` so they never appear in the serialized JSON.
#[derive(Debug, Serialize)]
pub(crate) struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
}

/// Successful response from the `chat/completions` endpoint.
#[derive(Debug, Deserialize)]
pub(crate) struct ChatCompletionResponse {
    pub choices: Vec<Choice>,
}

#[derive(Debug, Deserialize)]
pub(crate) struct Choice {
    pub message: ResponseMessage,
}

#[derive(Debug, Deserialize)]
pub(crate) struct ResponseMessage {
    pub content: Option<Content>,
    #[serde(default)]
    pub tool_calls: Option<Vec<ApiToolCall>>,
}

/// Message content, which providers may return either as a plain string or as
/// a list of typed parts; `untagged` lets serde accept both shapes.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum Content {
    Text(String),
    Parts(Vec<ContentPart>),
}

#[derive(Debug, Deserialize)]
pub(crate) struct ContentPart {
    #[serde(rename = "type")]
    pub kind: String,
    pub text: Option<String>,
}

#[derive(Debug, Deserialize)]
pub(crate) struct ApiToolCall {
    pub function: ApiFunction,
}

#[derive(Debug, Deserialize)]
pub(crate) struct ApiFunction {
    pub name: String,
    pub arguments: String,
}

/// Error envelope returned by OpenAI-compatible APIs on non-2xx responses.
#[derive(Debug, Deserialize)]
pub(crate) struct ErrorResponse {
    pub error: ApiError,
}

#[derive(Debug, Deserialize)]
pub(crate) struct ApiError {
    pub message: String,
    #[serde(rename = "type")]
    #[allow(dead_code)]
    pub error_type: Option<String>,
}

impl OpenAICompatClient {
    /// Build a client with connection pooling and TCP keep-alive tuned for
    /// repeated calls; falls back to a default `Client` if the builder fails.
    pub fn new(config: OpenAICompatConfig) -> Self {
        let client = Client::builder()
            .pool_max_idle_per_host(8)
            .pool_idle_timeout(std::time::Duration::from_secs(90))
            .tcp_keepalive(std::time::Duration::from_secs(60))
            .no_proxy()
            .build()
            .unwrap_or_else(|_| Client::new());
        Self {
            config,
            client,
            skip_sampling_params: false,
        }
    }

    /// When `skip` is true, omit `temperature`, `max_tokens`, and `top_p`
    /// from outgoing requests.
    pub fn with_skip_sampling_params(mut self, skip: bool) -> Self {
        self.skip_sampling_params = skip;
        self
    }

    /// Send a chat completion request and map the first choice into an
    /// [`LlmResponse`], returning tool calls when the provider produces any.
    pub async fn chat_completion(
        &self,
        messages: Vec<Message>,
        response_format: Option<ResponseFormat>,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<String>,
    ) -> Result<LlmResponse> {
        let request = ChatCompletionRequest {
            model: self.config.model.clone(),
            messages,
            temperature: if self.skip_sampling_params {
                None
            } else {
                Some(self.config.temperature)
            },
            max_tokens: if self.skip_sampling_params {
                None
            } else {
                Some(self.config.max_tokens)
            },
            top_p: if self.skip_sampling_params {
                None
            } else {
                Some(self.config.top_p)
            },
            response_format,
            tools,
            tool_choice,
        };

        let url = format!("{}/chat/completions", self.config.base_url);
        debug!("Sending request to {}: {}", self.config.provider_name, url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        let status = response.status();
        let body = response.text().await?;

        if !status.is_success() {
            // Fall back to the raw body when the error payload is not valid JSON.
            let error: ErrorResponse = serde_json::from_str(&body).unwrap_or(ErrorResponse {
                error: ApiError {
                    message: body.clone(),
                    error_type: None,
                },
            });
            error!(
                "{} API error: {}",
                self.config.provider_name, error.error.message
            );
            return Err(NeomemxError::LlmError(format!(
                "{}: {}",
                self.config.provider_name, error.error.message
            )));
        }

        let completion: ChatCompletionResponse = serde_json::from_str(&body).map_err(|e| {
            NeomemxError::LlmError(format!(
                "Failed to parse {} response: {}",
                self.config.provider_name, e
            ))
        })?;

        let choice = completion.choices.into_iter().next().ok_or_else(|| {
            NeomemxError::LlmError(format!(
                "No choices in {} response",
                self.config.provider_name
            ))
        })?;

        // Normalize string-or-parts content into a single string.
        let content_text = choice
            .message
            .content
            .map(|c| match c {
                Content::Text(t) => t,
                Content::Parts(parts) => parts
                    .into_iter()
                    .filter_map(|p| p.text)
                    .collect::<Vec<String>>()
                    .join(""),
            })
            .unwrap_or_default();

        if let Some(api_tool_calls) = choice.message.tool_calls {
            if !api_tool_calls.is_empty() {
                // Tool calls with unparseable arguments are dropped rather than
                // failing the whole response.
                let tool_calls: Vec<ToolCall> = api_tool_calls
                    .iter()
                    .filter_map(|tc| Self::parse_tool_call(tc).ok())
                    .collect();
                return Ok(LlmResponse::WithToolCalls {
                    content: Some(content_text.clone()).filter(|s| !s.is_empty()),
                    tool_calls,
                });
            }
        }

        Ok(LlmResponse::Text(content_text))
    }

    /// Convert an API tool call into a [`ToolCall`], parsing its `arguments`
    /// string into a JSON map.
    fn parse_tool_call(api_call: &ApiToolCall) -> Result<ToolCall> {
        let arguments: HashMap<String, serde_json::Value> =
            extract_json(&api_call.function.arguments)?;
        Ok(ToolCall {
            name: api_call.function.name.clone(),
            arguments,
        })
    }
}
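
// A minimal test sketch covering only behavior visible in this file: request
// serialization and `Content` deserialization. It deliberately avoids any
// assumption about the shape of `Message` by using an empty message list.
#[cfg(test)]
mod tests {
    use super::*;

    // With `skip_serializing_if = "Option::is_none"`, sampling parameters set
    // to `None` (as when `skip_sampling_params` is enabled) never reach the wire.
    #[test]
    fn sampling_params_are_omitted_when_none() {
        let request = ChatCompletionRequest {
            model: "test-model".to_string(),
            messages: vec![],
            temperature: None,
            max_tokens: None,
            top_p: None,
            response_format: None,
            tools: None,
            tool_choice: None,
        };
        let json = serde_json::to_string(&request).expect("request should serialize");
        assert!(!json.contains("temperature"));
        assert!(!json.contains("max_tokens"));
        assert!(!json.contains("top_p"));
    }

    // The untagged `Content` enum accepts both a plain string and an array of
    // typed parts, matching the two shapes providers return.
    #[test]
    fn content_accepts_string_and_parts() {
        let text: Content = serde_json::from_str(r#""hello""#).expect("string content");
        assert!(matches!(text, Content::Text(t) if t == "hello"));

        let parts: Content = serde_json::from_str(r#"[{"type": "text", "text": "hi"}]"#)
            .expect("parts content");
        assert!(matches!(parts, Content::Parts(p) if p.len() == 1));
    }
}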