use crate::config::Config;
use anyhow::Result;
use serde::de::DeserializeOwned;
pub use octolib::llm::{
AiProvider, ChatCompletionParams, Message, MessageBuilder, ProviderFactory, ProviderResponse,
StructuredOutputRequest, TokenUsage,
};
/// Thin client around an `octolib` LLM provider.
///
/// Bundles the provider resolved for a model string together with the
/// request settings taken from [`Config`], so callers only supply messages.
pub struct LlmClient {
/// Provider backend returned by `ProviderFactory::get_provider_for_model`.
provider: Box<dyn AiProvider>,
/// Model identifier returned by the factory alongside the provider.
model: String,
/// Default temperature (from `config.llm.temperature`) sent with requests.
temperature: f32,
/// Token limit (from `config.llm.max_tokens`) sent with requests.
max_tokens: usize,
}
impl LlmClient {
/// Builds a client for the model named in `config.llm.model`.
///
/// Temperature and token limit are copied from `config.llm`.
///
/// # Errors
///
/// Propagates the error from `ProviderFactory::get_provider_for_model`
/// when no provider can be resolved for the configured model.
pub fn from_config(config: &Config) -> Result<Self> {
    let llm = &config.llm;
    let (provider, model) = ProviderFactory::get_provider_for_model(&llm.model)?;

    let client = Self {
        temperature: llm.temperature,
        max_tokens: llm.max_tokens,
        provider,
        model,
    };
    Ok(client)
}
/// Builds a client for an explicitly chosen model, ignoring `config.llm.model`.
///
/// Temperature and token limit are still copied from `config.llm`.
///
/// # Errors
///
/// Propagates the error from `ProviderFactory::get_provider_for_model`
/// when no provider can be resolved for `model_str`.
pub fn with_model(config: &Config, model_str: &str) -> Result<Self> {
    let llm = &config.llm;
    let (provider, model) = ProviderFactory::get_provider_for_model(model_str)?;

    let client = Self {
        temperature: llm.temperature,
        max_tokens: llm.max_tokens,
        provider,
        model,
    };
    Ok(client)
}
/// Sends `messages` to the model using the client's default temperature and
/// returns the text of the reply.
///
/// This is `chat_completion_with_temperature` with `self.temperature`; the
/// request construction and token/cost debug logging live there so the two
/// entry points cannot drift apart.
///
/// # Errors
///
/// Propagates the provider's error when the request fails.
pub async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
    self.chat_completion_with_temperature(messages, self.temperature)
        .await
}
/// Requests a JSON-structured completion and deserializes it into `T`.
///
/// The provider's `structured_output` value is used when present; otherwise
/// the raw response text is parsed as JSON.
///
/// # Errors
///
/// Returns an error when:
/// - the provider reports no structured-output support for this model,
/// - the provider request fails,
/// - the response cannot be deserialized into `T`.
pub async fn chat_completion_structured<T: DeserializeOwned>(
    &self,
    messages: Vec<Message>,
) -> Result<T> {
    if !self.provider.supports_structured_output(&self.model) {
        anyhow::bail!(
            "Provider does not support structured output for model: {}",
            self.model
        );
    }

    // Saturate rather than `as u32`, which would silently wrap a limit
    // larger than `u32::MAX` into a small (possibly zero) value.
    let max_tokens = self.max_tokens.min(u32::MAX as usize) as u32;

    // 1.0 and 50 are fixed sampling arguments (presumably top_p / top_k --
    // confirm against the octolib `ChatCompletionParams::new` signature).
    let params = ChatCompletionParams::new(
        &messages,
        &self.model,
        self.temperature,
        1.0,
        50,
        max_tokens,
    )
    .with_structured_output(StructuredOutputRequest::json());

    let response = self.provider.chat_completion(params).await?;

    if let Some(usage) = &response.exchange.usage {
        tracing::debug!(
            "LLM tokens (structured): input={}, output={}, total={}",
            usage.prompt_tokens,
            usage.output_tokens,
            usage.total_tokens
        );
        if let Some(cost) = usage.cost {
            tracing::debug!("LLM cost: ${:.6}", cost);
        }
    }

    let parsed: T = match response.structured_output {
        Some(structured) => serde_json::from_value(structured)?,
        // Provider returned no parsed value: fall back to the raw text.
        None => serde_json::from_str(&response.content)?,
    };
    Ok(parsed)
}
/// Sends `messages` to the model with an explicit `temperature` (overriding
/// the client's default) and returns the text of the reply.
///
/// Token usage and, when reported, cost are emitted at `debug` level.
///
/// # Errors
///
/// Propagates the provider's error when the request fails.
pub async fn chat_completion_with_temperature(
    &self,
    messages: Vec<Message>,
    temperature: f32,
) -> Result<String> {
    // Saturate rather than `as u32`, which would silently wrap a limit
    // larger than `u32::MAX` into a small (possibly zero) value.
    let max_tokens = self.max_tokens.min(u32::MAX as usize) as u32;

    // 1.0 and 50 are fixed sampling arguments (presumably top_p / top_k --
    // confirm against the octolib `ChatCompletionParams::new` signature).
    let params = ChatCompletionParams::new(
        &messages,
        &self.model,
        temperature,
        1.0,
        50,
        max_tokens,
    );

    let response = self.provider.chat_completion(params).await?;

    if let Some(usage) = &response.exchange.usage {
        tracing::debug!(
            "LLM tokens: input={}, output={}, total={}",
            usage.prompt_tokens,
            usage.output_tokens,
            usage.total_tokens
        );
        if let Some(cost) = usage.cost {
            tracing::debug!("LLM cost: ${:.6}", cost);
        }
    }

    Ok(response.content)
}
/// Returns the model identifier this client sends with every request.
pub fn model(&self) -> &str {
    self.model.as_str()
}
pub fn supports_structured_output(&self) -> bool {
self.provider.supports_structured_output(&self.model)
}
/// Requests a completion and returns it as a JSON value.
///
/// Strategy:
/// 1. If the provider supports structured output for this model, send a
///    structured request and return its `structured_output` when present.
/// 2. If that response carries no structured value but does carry text,
///    extract JSON from that text instead of paying for a second request.
/// 3. Otherwise (no structured support, or a blank structured response) send
///    a plain request and extract JSON from the reply.
///
/// JSON extraction never fails: when nothing parses, the returned value is
/// an object with `error` and `raw_content` fields.
///
/// # Errors
///
/// Propagates the provider's error when a request fails.
pub async fn chat_completion_json(&self, messages: Vec<Message>) -> Result<serde_json::Value> {
    let supports_structured = self.provider.supports_structured_output(&self.model);
    tracing::debug!(
        "Provider {} supports structured output for model {}: {}",
        self.provider.name(),
        self.model,
        supports_structured
    );

    // Text of the structured response, kept only when it is worth parsing.
    let reusable_content: Option<String> = if supports_structured {
        // Saturate rather than `as u32`, which would silently wrap a limit
        // larger than `u32::MAX` into a small (possibly zero) value.
        let max_tokens = self.max_tokens.min(u32::MAX as usize) as u32;

        // 1.0 and 50 are fixed sampling arguments (presumably top_p / top_k
        // -- confirm against the octolib `ChatCompletionParams::new` signature).
        let params = ChatCompletionParams::new(
            &messages,
            &self.model,
            self.temperature,
            1.0,
            50,
            max_tokens,
        )
        .with_structured_output(StructuredOutputRequest::json());

        let response = self.provider.chat_completion(params).await?;

        if let Some(usage) = &response.exchange.usage {
            tracing::debug!(
                "LLM tokens (structured): input={}, output={}, total={}",
                usage.prompt_tokens,
                usage.output_tokens,
                usage.total_tokens
            );
            if let Some(cost) = usage.cost {
                tracing::debug!("LLM cost: ${:.6}", cost);
            }
        }

        tracing::debug!(
            "Response has structured_output: {}",
            response.structured_output.is_some()
        );
        tracing::debug!("Response content length: {}", response.content.len());
        tracing::debug!(
            "Response content preview: {}",
            response.content.chars().take(200).collect::<String>()
        );

        if let Some(structured) = response.structured_output {
            tracing::debug!("Using structured output from provider");
            return Ok(structured);
        }

        tracing::debug!("No structured output, falling back to content parsing");
        // Reuse the text we already received; only a blank reply justifies
        // issuing another request below.
        Some(response.content).filter(|text| !text.trim().is_empty())
    } else {
        tracing::debug!("Provider does not support structured output, using markdown fallback");
        None
    };

    let content = match reusable_content {
        Some(text) => text,
        None => self.chat_completion(messages).await?,
    };

    tracing::debug!("Raw content length: {}", content.len());
    tracing::debug!(
        "Raw content preview: {}",
        content.chars().take(200).collect::<String>()
    );

    let json = Self::strip_json_from_markdown(&content);
    tracing::debug!(
        "Parsed JSON has error field: {}",
        json.get("error").is_some()
    );
    Ok(json)
}
/// Extracts a JSON value from a model reply that may wrap it in Markdown
/// fences or surrounding prose.
///
/// Strategies are tried in order and the first that parses wins:
/// 1. the whole reply (trimmed);
/// 2. the body of the first "```json" fence;
/// 3. the body of the first fenced block of any kind (fence lines are lines
///    whose trimmed text starts with "```");
/// 4. everything from the first `{` to the end, then from the first `[` to
///    the end;
/// 5. the first complete JSON value starting at the first `{` or `[`
///    (earlier position first), ignoring whatever text follows it.
///
/// Never fails: when nothing parses, returns an object with an `error`
/// message and the original text under `raw_content`.
fn strip_json_from_markdown(content: &str) -> serde_json::Value {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    let parse = |text: &str| serde_json::from_str::<serde_json::Value>(text.trim()).ok();

    // 1. The reply is already plain JSON.
    if let Some(value) = parse(content) {
        return value;
    }

    // 2. First explicitly tagged JSON fence.
    let tagged = content
        .split_once(JSON_FENCE)
        .and_then(|(_, rest)| rest.split_once(FENCE))
        .and_then(|(body, _)| parse(body));
    if let Some(value) = tagged {
        return value;
    }

    // 3. First fenced block of any kind. Offsets are accumulated from the
    // real line lengths (terminator included), so CRLF input slices at the
    // right place and always on a line start, i.e. a char boundary.
    let mut offset = 0usize;
    let mut body_start: Option<usize> = None;
    for line in content.split_inclusive('\n') {
        let line_start = offset;
        offset += line.len();
        if !line.trim().starts_with(FENCE) {
            continue;
        }
        match body_start {
            // Opening fence: the body begins right after this line.
            None => body_start = Some(offset),
            // Closing fence: the body ends where this line begins.
            Some(start) => {
                if let Some(value) = parse(&content[start..line_start]) {
                    return value;
                }
                // Only the first fenced block is considered.
                break;
            }
        }
    }

    let first_brace = content.find('{');
    let first_bracket = content.find('[');

    // 4. JSON that runs from the first `{` (then the first `[`) to the end.
    for start in first_brace.into_iter().chain(first_bracket) {
        if let Some(value) = parse(&content[start..]) {
            return value;
        }
    }

    // 5. JSON followed by trailing prose: read one complete value and ignore
    // the rest. The earlier opener goes first so an array of objects is not
    // mistaken for its first element.
    let mut starts: Vec<usize> = first_brace.into_iter().chain(first_bracket).collect();
    starts.sort_unstable();
    for start in starts {
        let leading = serde_json::Deserializer::from_str(&content[start..])
            .into_iter::<serde_json::Value>()
            .next();
        if let Some(Ok(value)) = leading {
            return value;
        }
    }

    serde_json::json!({
        "error": "Failed to parse JSON from response",
        "raw_content": content
    })
}
}