use async_trait::async_trait;
use tracing::info;
use super::base::{LlmBase, LlmResponse, Message, ResponseFormat, Tool};
use super::openai_compat::{OpenAICompatClient, OpenAICompatConfig};
use crate::config::OpenAILlmConfig;
use crate::error::{NeomemxError, Result};
/// OpenAI chat-completion backend implementing [`LlmBase`].
///
/// A thin wrapper: it configures an [`OpenAICompatClient`] for the OpenAI API
/// and forwards every request to it.
pub struct OpenAILlm {
    // OpenAI-compatible client configured in `OpenAILlm::new`; all requests go through it.
    client: OpenAICompatClient,
}
impl OpenAILlm {
    /// Creates an OpenAI LLM backend from `config`.
    ///
    /// The API key is resolved by `OpenAILlmConfig::get_api_key`. Per the error
    /// message below, it may come from the config or the `OPENAI_API_KEY`
    /// environment variable.
    ///
    /// When `config.model` is a reasoning model (see [`Self::is_reasoning_model`]),
    /// the client is built with `with_skip_sampling_params(true)`.
    /// NOTE(review): presumably this omits temperature/top_p, which reasoning
    /// models reject. Confirm in `openai_compat`.
    ///
    /// # Errors
    ///
    /// Returns [`NeomemxError::LlmError`] when no API key is available.
    pub fn new(config: OpenAILlmConfig) -> Result<Self> {
        let api_key = config.get_api_key().ok_or_else(|| {
            NeomemxError::LlmError(
                "OpenAI API key not found. Set OPENAI_API_KEY or provide in config.".to_string(),
            )
        })?;

        info!("Creating OpenAI LLM with model: {}", config.model);

        // Decide before `config.model` is moved into the client config below.
        let is_reasoning = Self::is_reasoning_model(&config.model);

        Ok(Self {
            client: OpenAICompatClient::new(OpenAICompatConfig {
                base_url: config.base_url,
                api_key,
                model: config.model,
                temperature: config.temperature,
                max_tokens: config.max_tokens,
                top_p: config.top_p,
                provider_name: "OpenAI",
            })
            .with_skip_sampling_params(is_reasoning),
        })
    }

    /// Returns `true` when `model` names an OpenAI reasoning model.
    ///
    /// Two families match:
    /// - the o-series (`o1`, `o1-preview`, `o3-mini`, `o4-mini`, ...), i.e. a base
    ///   name that starts with `o` followed by a digit;
    /// - the `gpt-5` family.
    ///
    /// Matching is case-insensitive. It is done on the *base* model name only:
    /// - a routing prefix such as `openai/` is removed first;
    /// - the fine-tune wrapper `ft:<base>:<org>:<suffix>:<id>` is then reduced
    ///   to `<base>`.
    ///
    /// This way an org name, suffix or job id that merely contains `o1`/`o3` does
    /// not turn sampling parameters off.
    fn is_reasoning_model(model: &str) -> bool {
        let model = model.trim().to_ascii_lowercase();

        // `openrouter/openai/o3-mini` -> `o3-mini`.
        let name = model.rsplit('/').next().unwrap_or(&model);

        // `ft:o4-mini-2025-04-16:acme::abc123` -> `o4-mini-2025-04-16`.
        let base = match name.strip_prefix("ft:") {
            Some(rest) => rest.split(':').next().unwrap_or(rest),
            None => name,
        };

        let mut chars = base.chars();
        let is_o_series =
            chars.next() == Some('o') && matches!(chars.next(), Some(c) if c.is_ascii_digit());

        is_o_series || base.starts_with("gpt-5")
    }
}
#[async_trait]
impl LlmBase for OpenAILlm {
    /// Generates a chat completion for `messages`.
    ///
    /// Every argument is forwarded unchanged to the wrapped client's
    /// `chat_completion`. Its result, success or error, is returned as-is.
    async fn generate_response(
        &self,
        messages: Vec<Message>,
        response_format: Option<ResponseFormat>,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<String>,
    ) -> Result<LlmResponse> {
        let backend = &self.client;
        backend
            .chat_completion(messages, response_format, tools, tool_choice)
            .await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reasoning-model detection: o-series and gpt-5 names are recognised
    /// (case-insensitively); ordinary chat models are not.
    #[test]
    fn test_is_reasoning_model() {
        let cases = [
            ("o1-preview", true),
            ("o1", true),
            ("O1-MINI", true),
            ("o3-mini", true),
            ("openai/o3-mini", true),
            ("gpt-5", true),
            ("gpt-5-mini", true),
            ("gpt-4", false),
            ("gpt-4o", false),
            ("gpt-4o-mini", false),
            ("gpt-3.5-turbo", false),
        ];
        for &(model, expected) in &cases {
            assert_eq!(
                OpenAILlm::is_reasoning_model(model),
                expected,
                "unexpected result for model {:?}",
                model
            );
        }
    }
}