use serde::de::DeserializeOwned;
use std::borrow::Cow;
use std::sync::Arc;
use tracing::{debug, instrument};
use super::config::LlmConfig;
use super::error::{LlmError, LlmResult};
use super::executor::LlmExecutor;
use super::fallback::FallbackChain;
use crate::throttle::ConcurrencyController;
/// High-level client for issuing LLM completions.
///
/// Thin, cloneable wrapper around [`LlmExecutor`]; every builder method and
/// accessor below delegates to the executor, which owns the configuration,
/// optional throttle, and optional fallback chain.
/// NOTE(review): cloning cost depends on `LlmExecutor`'s `Clone` impl — confirm
/// it is cheap before cloning per-request.
#[derive(Clone)]
pub struct LlmClient {
    // Sole field: all state lives in the executor (see accessors on the impl).
    executor: LlmExecutor,
}
/// Manual `Debug`: prints a concise summary (model, endpoint, throttle,
/// whether a fallback chain is installed) rather than the executor internals.
impl std::fmt::Debug for LlmClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let config = self.executor.config();
        f.debug_struct("LlmClient")
            .field("model", &config.model)
            .field("endpoint", &config.endpoint)
            // Print the controller's own Debug output directly. The previous
            // `.map(|c| format!("{:?}", c))` allocated a String per call and
            // produced doubly-escaped output like `Some("Controller { .. }")`.
            .field("concurrency", &self.executor.throttle())
            .field("fallback_enabled", &self.executor.fallback().is_some())
            .finish()
    }
}
impl LlmClient {
    /// Builds a client from an explicit configuration.
    pub fn new(config: LlmConfig) -> Self {
        Self {
            executor: LlmExecutor::new(config),
        }
    }

    /// Builds a client from `LlmConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(LlmConfig::default())
    }

    /// Builds a client targeting `model`, with defaults for everything else.
    pub fn for_model(model: impl Into<String>) -> Self {
        Self::new(LlmConfig::new(model))
    }

    /// Builder: installs an owned concurrency controller.
    pub fn with_concurrency(mut self, controller: ConcurrencyController) -> Self {
        self.executor = self.executor.with_throttle(controller);
        self
    }

    /// Builder: installs a concurrency controller shared with other clients.
    pub fn with_shared_concurrency(mut self, controller: Arc<ConcurrencyController>) -> Self {
        self.executor = self.executor.with_shared_throttle(controller);
        self
    }

    /// Builder: installs an owned fallback chain.
    pub fn with_fallback(mut self, chain: FallbackChain) -> Self {
        self.executor = self.executor.with_fallback(chain);
        self
    }

    /// Builder: installs a fallback chain shared with other clients.
    pub fn with_shared_fallback(mut self, chain: Arc<FallbackChain>) -> Self {
        self.executor = self.executor.with_shared_fallback(chain);
        self
    }

    /// The active configuration.
    pub fn config(&self) -> &LlmConfig {
        self.executor.config()
    }

    /// The concurrency controller, if one was installed.
    pub fn concurrency(&self) -> Option<&ConcurrencyController> {
        self.executor.throttle()
    }

    /// The fallback chain, if one was installed.
    pub fn fallback(&self) -> Option<&FallbackChain> {
        self.executor.fallback()
    }

    /// Direct access to the underlying executor.
    pub fn executor(&self) -> &LlmExecutor {
        &self.executor
    }

    /// Runs a completion with the configured defaults.
    ///
    /// # Errors
    /// Propagates any [`LlmError`] returned by the executor.
    #[instrument(skip(self, system, user), fields(model = %self.executor.config().model))]
    pub async fn complete(&self, system: &str, user: &str) -> LlmResult<String> {
        debug!(
            system_len = system.len(),
            user_len = user.len(),
            "Starting LLM completion"
        );
        self.executor.complete(system, user).await
    }

    /// Runs a completion capped at `max_tokens` output tokens.
    ///
    /// # Errors
    /// Propagates any [`LlmError`] returned by the executor.
    // Instrumented for parity with `complete` (was previously untraced).
    #[instrument(skip(self, system, user), fields(model = %self.executor.config().model))]
    pub async fn complete_with_max_tokens(
        &self,
        system: &str,
        user: &str,
        max_tokens: u16,
    ) -> LlmResult<String> {
        debug!(
            system_len = system.len(),
            user_len = user.len(),
            max_tokens = max_tokens,
            "Starting LLM completion with max tokens"
        );
        self.executor
            .complete_with_max_tokens(system, user, max_tokens)
            .await
    }

    /// Runs a completion and deserializes the response as JSON into `T`.
    ///
    /// # Errors
    /// Returns [`LlmError::Parse`] when the response is not valid JSON for `T`,
    /// plus any executor error from the completion itself.
    pub async fn complete_json<T: DeserializeOwned>(
        &self,
        system: &str,
        user: &str,
    ) -> LlmResult<T> {
        let response = self.complete(system, user).await?;
        self.parse_json(&response)
    }

    /// Token-capped variant of [`Self::complete_json`].
    pub async fn complete_json_with_max_tokens<T: DeserializeOwned>(
        &self,
        system: &str,
        user: &str,
        max_tokens: u16,
    ) -> LlmResult<T> {
        let response = self
            .complete_with_max_tokens(system, user, max_tokens)
            .await?;
        self.parse_json(&response)
    }

    /// Extracts the JSON payload from `text` and deserializes it into `T`.
    /// The raw response is embedded in the error for diagnostics.
    fn parse_json<T: DeserializeOwned>(&self, text: &str) -> LlmResult<T> {
        let json_text = self.extract_json(text);
        serde_json::from_str(&json_text).map_err(|e| {
            LlmError::Parse(format!("Failed to parse JSON: {}. Response: {}", e, text))
        })
    }

    /// Best-effort extraction of a JSON value from an LLM response.
    ///
    /// Handles three shapes, in order:
    /// 1. A Markdown code fence (```` ```json … ``` ````): returns the fenced body.
    /// 2. Text starting with `{` or `[`: returns the balanced top-level value,
    ///    dropping any trailing prose.
    /// 3. Anything else: returned unchanged, letting serde report the failure.
    fn extract_json<'a>(&self, text: &'a str) -> Cow<'a, str> {
        let text = text.trim();
        // Case 1: strip a Markdown code fence.
        if text.starts_with("```") {
            if let Some(start) = text.find('\n') {
                let rest = &text[start + 1..];
                if let Some(end) = rest.find("```") {
                    return Cow::Borrowed(rest[..end].trim());
                }
            }
        }
        // Case 2: match the outermost delimiter pair. Delimiters that occur
        // inside JSON string literals (including escaped quotes) must be
        // ignored — the previous version counted them, so a value such as
        // `{"a": "}"} trailing prose` was truncated to invalid JSON.
        if text.starts_with('[') || text.starts_with('{') {
            let open = if text.starts_with('[') { '[' } else { '{' };
            let close = if open == '[' { ']' } else { '}' };
            let mut depth = 0i32;
            let mut in_string = false;
            let mut escaped = false;
            for (i, ch) in text.char_indices() {
                if in_string {
                    if escaped {
                        escaped = false;
                    } else if ch == '\\' {
                        escaped = true;
                    } else if ch == '"' {
                        in_string = false;
                    }
                    continue;
                }
                match ch {
                    '"' => in_string = true,
                    c if c == open => depth += 1,
                    c if c == close => {
                        depth -= 1;
                        if depth == 0 {
                            return Cow::Borrowed(&text[..=i]);
                        }
                    }
                    _ => {}
                }
            }
        }
        // Case 3: nothing recognizable — hand back the trimmed text as-is.
        Cow::Borrowed(text)
    }
}
impl Default for LlmClient {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand: a default-configured client for exercising private helpers.
    fn client() -> LlmClient {
        LlmClient::with_defaults()
    }

    #[test]
    fn test_extract_json_plain() {
        // Already-bare JSON passes through untouched.
        let input = r#"{"key": "value"}"#;
        assert_eq!(client().extract_json(input), input);
    }

    #[test]
    fn test_extract_json_code_block() {
        // A Markdown fence is stripped down to its body.
        let fenced = "```json\n{\"key\": \"value\"}\n```";
        assert_eq!(client().extract_json(fenced), r#"{"key": "value"}"#);
    }

    #[test]
    fn test_extract_json_array() {
        // Top-level arrays are supported, not just objects.
        let input = "[1, 2, 3]";
        assert_eq!(client().extract_json(input), input);
    }

    #[test]
    fn test_extract_json_nested() {
        // Inner braces must not terminate the outer match early.
        let input = r#"{"outer": {"inner": 1}}"#;
        assert_eq!(client().extract_json(input), input);
    }

    #[test]
    fn test_client_creation() {
        assert_eq!(LlmClient::for_model("gpt-4o").config().model, "gpt-4o");
    }

    #[test]
    fn test_client_with_concurrency() {
        use crate::throttle::ConcurrencyConfig;
        let throttled = LlmClient::for_model("gpt-4o-mini")
            .with_concurrency(ConcurrencyController::new(ConcurrencyConfig::conservative()));
        assert!(throttled.concurrency().is_some());
    }
}