use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::config::LlmConfig;
use crate::error::{MemeError, Result};
use crate::llm::json::extract_json_from_text;
/// Author of a chat [`Message`].
///
/// Variants are (de)serialized as lowercase strings: `"system"`, `"user"`
/// and `"assistant"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub(crate) enum Role {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
}
/// One entry of a chat conversation: who said it and what was said.
///
/// Serializes to an object with a `role` and a `content` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Message {
    /// Author of the message.
    pub role: Role,
    /// Text of the message.
    pub content: String,
}
impl Message {
    /// Creates a message whose role is [`Role::System`].
    ///
    /// `content` is converted into an owned `String`.
    #[must_use]
    pub(crate) fn system(content: impl Into<String>) -> Self {
        let content = content.into();
        Self {
            role: Role::System,
            content,
        }
    }

    /// Creates a message whose role is [`Role::User`].
    ///
    /// `content` is converted into an owned `String`.
    #[must_use]
    pub(crate) fn user(content: impl Into<String>) -> Self {
        let content = content.into();
        Self {
            role: Role::User,
            content,
        }
    }
}
/// Per-request tuning knobs for a chat call.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ChatOptions {
    /// Sampling temperature forwarded with the request.
    pub temperature: f32,
    /// Whether the request should ask for a JSON-object response.
    pub json_mode: bool,
}

impl Default for ChatOptions {
    /// Returns the defaults: temperature `0.1` with JSON mode enabled.
    fn default() -> Self {
        let temperature = 0.1;
        let json_mode = true;
        Self {
            temperature,
            json_mode,
        }
    }
}
/// HTTP client for a chat-completions style LLM API.
///
/// Holds the shared `reqwest` client together with the endpoint, credentials,
/// model name and attempt budget used for every request.
///
/// `Debug` is implemented by hand (instead of derived) so that the API key is
/// never written out by `{:?}` formatting, e.g. in logs or panic messages.
#[derive(Clone)]
pub(crate) struct LlmClient {
    /// Underlying HTTP client used to send requests.
    http: reqwest::Client,
    /// Base URL of the API; request paths are appended to it.
    base_url: String,
    /// Secret bearer token sent with every request.
    api_key: String,
    /// Model identifier sent in the request body.
    model: String,
    /// Upper bound on the number of attempts per structured chat call.
    max_retries: u32,
}

impl std::fmt::Debug for LlmClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlmClient")
            .field("http", &self.http)
            .field("base_url", &self.base_url)
            // Never print the secret itself.
            .field("api_key", &format_args!("<redacted>"))
            .field("model", &self.model)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}
impl LlmClient {
    /// Builds a client from an existing HTTP client and the LLM configuration.
    ///
    /// Trailing `/` characters are stripped from the configured base URL so
    /// request paths can be appended with a single separator.
    ///
    /// # Errors
    ///
    /// Returns [`MemeError::Config`] when the configuration has no API key.
    pub(crate) fn new(http: reqwest::Client, config: &LlmConfig) -> Result<Self> {
        let api_key = config
            .api_key
            .clone()
            .ok_or_else(|| MemeError::Config("LLM API key is required".to_owned()))?;
        Ok(Self {
            http,
            base_url: config.base_url.trim_end_matches('/').to_owned(),
            api_key,
            model: config.model.clone(),
            max_retries: config.max_retries,
        })
    }

    /// Sends `messages` to the model and deserializes the reply into `T`.
    ///
    /// The reply is first parsed as JSON directly; if that fails, a JSON value
    /// is extracted from the surrounding text and deserialized instead. Parse
    /// failures and retryable API errors trigger another attempt, with an
    /// exponential pause (1s, 2s, 4s, … capped at 30s) between attempts.
    ///
    /// `max_retries` is the total number of attempts. A configured value of
    /// `0` still results in one attempt, so the API is always called at least
    /// once.
    ///
    /// # Errors
    ///
    /// Returns a non-retryable API error immediately. Otherwise returns the
    /// error of the last attempt: either the API error or a
    /// [`MemeError::JsonParse`] describing why the reply could not be parsed.
    #[tracing::instrument(skip(self, messages, opts), fields(model = %self.model))]
    pub(crate) async fn chat_structured<T: serde::de::DeserializeOwned>(
        &self,
        messages: &[Message],
        opts: &ChatOptions,
    ) -> Result<T> {
        // Always try at least once, even when the configured budget is 0.
        let attempts = self.max_retries.max(1);
        let mut last_err = None;
        for attempt in 0..attempts {
            match self.call_api(messages, opts).await {
                Ok(content) => match serde_json::from_str::<T>(&content) {
                    Ok(parsed) => return Ok(parsed),
                    Err(e) => {
                        tracing::warn!(attempt = attempt + 1, error = %e, "JSON parse failed");
                        // Fall back to pulling a JSON value out of the reply
                        // text before giving up on this attempt.
                        if let Ok(v) = extract_json_from_text(&content)
                            && let Ok(parsed) = serde_json::from_value::<T>(v)
                        {
                            return Ok(parsed);
                        }
                        last_err = Some(MemeError::JsonParse(e.to_string()));
                    }
                },
                Err(e) => {
                    if !e.is_retryable() {
                        return Err(e);
                    }
                    tracing::warn!(attempt = attempt + 1, error = %e, "LLM API call failed");
                    last_err = Some(e);
                }
            }
            // No pause after the final attempt.
            if attempt + 1 < attempts {
                // 2^attempt seconds, capped at 30.
                let wait = 2u64.saturating_pow(attempt).min(30);
                tokio::time::sleep(Duration::from_secs(wait)).await;
            }
        }
        Err(last_err.unwrap_or_else(|| MemeError::llm("all retries exhausted")))
    }

    /// Performs one `POST {base_url}/chat/completions` request and returns the
    /// text found at `choices[0].message.content` in the JSON response.
    ///
    /// When `opts.json_mode` is set, the request additionally carries
    /// `response_format: {"type": "json_object"}`.
    ///
    /// # Errors
    ///
    /// Returns an error when the request cannot be sent, when the response
    /// status is not a success (the status code and response body are included
    /// in the error), when the response body is not valid JSON, or when the
    /// expected content string is missing from it.
    async fn call_api(&self, messages: &[Message], opts: &ChatOptions) -> Result<String> {
        let url = format!("{}/chat/completions", self.base_url);
        let mut body = serde_json::json!({
            "model": self.model,
            "messages": messages,
            "temperature": opts.temperature,
            "stream": false,
        });
        if opts.json_mode
            && let Some(obj) = body.as_object_mut()
        {
            obj.insert(
                "response_format".to_owned(),
                serde_json::json!({"type": "json_object"}),
            );
        }
        let resp = self
            .http
            .post(&url)
            // Sends `Authorization: Bearer <key>` and flags the header value
            // as sensitive so it is kept out of debug output.
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await?;
        if !resp.status().is_success() {
            let status = resp.status();
            // Best effort: an unreadable error body is reported as empty.
            let text = resp.text().await.unwrap_or_default();
            return Err(MemeError::llm_with_status(
                status.as_u16(),
                format!("API returned {status}: {text}"),
            ));
        }
        let data: serde_json::Value = resp.json().await?;
        data.get("choices")
            .and_then(|c| c.get(0))
            .and_then(|c| c.get("message"))
            .and_then(|m| m.get("content"))
            .and_then(serde_json::Value::as_str)
            .map(String::from)
            .ok_or_else(|| MemeError::llm("missing content in API response"))
    }
}