use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::stream::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
/// Minimal async interface over an LLM backend: send one prompt, receive the
/// full completion text. Implemented by both the real HTTP client and the mock.
#[async_trait]
pub trait LLMClientTrait: Send + Sync {
    /// Send `prompt` and return the model's complete response text.
    async fn call(&self, prompt: &str) -> Result<String>;
}
/// Shared, dynamically-dispatched client handle for dependency injection.
pub type DynLLMClient = Arc<dyn LLMClientTrait>;
/// Default per-request HTTP timeout, in seconds (overridable via `with_timeout`).
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Provider-routing preferences serialized into the request body's `provider`
/// field. Every field is optional and omitted from the JSON when unset.
/// NOTE(review): the field set matches OpenRouter-style provider routing —
/// TODO confirm the target API honors these keys.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ProviderPreferences {
    /// Preferred providers, tried in order.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub order: Option<Vec<String>>,
    /// Whether falling back to other providers is allowed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allow_fallbacks: Option<bool>,
    /// Require providers to support all request parameters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub require_parameters: Option<bool>,
    /// Data-collection policy string (e.g. allow/deny — verify accepted values).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data_collection: Option<String>,
    /// Restrict routing to exactly these providers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub only: Option<Vec<String>>,
    /// Providers to exclude from routing.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ignore: Option<Vec<String>>,
    /// Acceptable model quantization levels.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantizations: Option<Vec<String>>,
    /// Provider sort criterion (e.g. by price or throughput — verify values).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sort: Option<String>,
}
impl ProviderPreferences {
    /// Empty preferences: every field is `None`, so nothing is serialized.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the preferred provider priority order.
    pub fn with_order(self, order: Vec<String>) -> Self {
        Self { order: Some(order), ..self }
    }

    /// Allow or forbid falling back to non-preferred providers.
    pub fn with_allow_fallbacks(self, allow: bool) -> Self {
        Self { allow_fallbacks: Some(allow), ..self }
    }

    /// Require providers to support every request parameter.
    pub fn with_require_parameters(self, require: bool) -> Self {
        Self { require_parameters: Some(require), ..self }
    }

    /// Set the data-collection policy string.
    pub fn with_data_collection(self, policy: impl Into<String>) -> Self {
        Self { data_collection: Some(policy.into()), ..self }
    }

    /// Restrict routing to exactly these providers.
    pub fn with_only(self, providers: Vec<String>) -> Self {
        Self { only: Some(providers), ..self }
    }

    /// Exclude these providers from routing.
    pub fn with_ignore(self, providers: Vec<String>) -> Self {
        Self { ignore: Some(providers), ..self }
    }

    /// Set the acceptable quantization levels.
    pub fn with_quantizations(self, quantizations: Vec<String>) -> Self {
        Self { quantizations: Some(quantizations), ..self }
    }

    /// Set the provider sort criterion.
    pub fn with_sort(self, sort: impl Into<String>) -> Self {
        Self { sort: Some(sort.into()), ..self }
    }
}
/// Configuration for one OpenAI-compatible chat-completions endpoint.
/// `Debug` is implemented manually below so `api_key` is never printed.
#[derive(Clone)]
pub struct LLMClient {
    // Bearer token sent in the Authorization header.
    pub api_key: String,
    // API root; "/chat/completions" is appended per request.
    pub base_url: String,
    // Model identifier placed in the request body.
    pub model: String,
    // Optional completion-length cap; omitted from JSON when None.
    pub max_tokens: Option<u32>,
    // Optional sampling temperature; omitted from JSON when None.
    pub temperature: Option<f32>,
    // Per-request HTTP timeout (defaults to DEFAULT_TIMEOUT_SECS).
    pub timeout: Duration,
    // Optional provider-routing preferences serialized into the body.
    pub provider: Option<ProviderPreferences>,
}
impl std::fmt::Debug for LLMClient {
    /// Manual `Debug` so the API key can never leak into logs: the key is
    /// replaced by a fixed "[REDACTED]" marker; all other fields print as-is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("LLMClient");
        out.field("api_key", &"[REDACTED]");
        out.field("base_url", &self.base_url);
        out.field("model", &self.model);
        out.field("max_tokens", &self.max_tokens);
        out.field("temperature", &self.temperature);
        out.field("timeout", &self.timeout);
        out.field("provider", &self.provider);
        out.finish()
    }
}
impl LLMClient {
pub fn openai(api_key: String, model: String) -> Self {
Self {
api_key,
base_url: "https://api.openai.com/v1".to_string(),
model,
max_tokens: None,
temperature: None,
timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
provider: None,
}
}
pub fn anthropic(api_key: String, model: String) -> Self {
Self {
api_key,
base_url: "https://api.anthropic.com/v1".to_string(),
model,
max_tokens: None,
temperature: None,
timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
provider: None,
}
}
pub fn custom(api_key: String, base_url: String, model: String) -> Self {
Self {
api_key,
base_url,
model,
max_tokens: None,
temperature: None,
timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
provider: None,
}
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn with_provider(mut self, provider: ProviderPreferences) -> Self {
self.provider = Some(provider);
self
}
pub async fn call_direct(&self, prompt: &str) -> Result<String> {
if std::env::var("BAML_DEBUG").is_ok() {
eprintln!("[BAML DEBUG] Prompt:\n{}", prompt);
}
let request_body = ChatCompletionRequest {
model: self.model.clone(),
messages: vec![
Message {
role: "user".to_string(),
content: prompt.to_string(),
}
],
max_tokens: self.max_tokens,
temperature: self.temperature,
stream: false,
provider: self.provider.clone(),
};
let client = reqwest::Client::builder()
.timeout(self.timeout)
.build()
.context("Failed to build HTTP client")?;
let response = client
.post(format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.await
.context("Failed to send request to LLM API")?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
anyhow::bail!("LLM API error ({}): {}", status, error_text);
}
let response_body: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse LLM API response")?;
if std::env::var("BAML_DEBUG").is_ok() {
if let Some(choice) = response_body.choices.first() {
eprintln!("[BAML DEBUG] LLM Response:\n{}", choice.message.content);
}
}
response_body
.choices
.first()
.and_then(|choice| Some(choice.message.content.clone()))
.ok_or_else(|| anyhow::anyhow!("No response from LLM"))
}
pub async fn call_stream(
&self,
prompt: &str,
) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
let request_body = ChatCompletionRequest {
model: self.model.clone(),
messages: vec![Message {
role: "user".to_string(),
content: prompt.to_string(),
}],
max_tokens: self.max_tokens,
temperature: self.temperature,
stream: true,
provider: self.provider.clone(),
};
let client = reqwest::Client::builder()
.timeout(self.timeout)
.build()
.context("Failed to build HTTP client")?;
let response = client
.post(format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.await
.context("Failed to send streaming request to LLM API")?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
anyhow::bail!("LLM API error ({}): {}", status, error_text);
}
let stream = response.bytes_stream().map(move |chunk_result| {
let chunk = chunk_result.context("Failed to read stream chunk")?;
let text = String::from_utf8_lossy(&chunk);
let mut content = String::new();
for line in text.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if data.trim() == "[DONE]" {
continue;
}
if let Ok(sse_event) = serde_json::from_str::<StreamChatCompletionChunk>(data) {
if let Some(choice) = sse_event.choices.first() {
if let Some(delta_content) = &choice.delta.content {
content.push_str(delta_content);
}
}
}
}
}
Ok(content)
});
Ok(Box::pin(stream))
}
}
#[async_trait]
impl LLMClientTrait for LLMClient {
    /// Trait entry point; delegates to the non-streaming `call_direct`.
    async fn call(&self, prompt: &str) -> Result<String> {
        self.call_direct(prompt).await
    }
}
/// JSON body for an OpenAI-style `POST /chat/completions` request.
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
    model: String,
    messages: Vec<Message>,
    // Omitted when unset so the provider's defaults apply.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Always serialized; true selects the SSE streaming response format.
    stream: bool,
    // Provider-routing preferences (OpenRouter-style — TODO confirm API).
    #[serde(skip_serializing_if = "Option::is_none")]
    provider: Option<ProviderPreferences>,
}
/// One chat message; used both in requests and in parsed responses.
#[derive(Debug, Serialize, Deserialize)]
struct Message {
    // e.g. "user" (the only role this module sends) or "assistant".
    role: String,
    content: String,
}
/// Non-streaming response body; only `choices` is decoded, other fields
/// (usage, id, ...) are ignored by serde.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    choices: Vec<Choice>,
}
/// One completion alternative in a non-streaming response.
#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
}
/// One decoded SSE `data:` event in a streaming response.
#[derive(Debug, Deserialize)]
struct StreamChatCompletionChunk {
    choices: Vec<StreamChoice>,
}
/// One choice within a streaming chunk; carries an incremental delta.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    delta: Delta,
}
/// Incremental content; `None` for bookkeeping chunks (e.g. role-only).
#[derive(Debug, Deserialize)]
struct Delta {
    content: Option<String>,
}
/// Test double for `LLMClientTrait`: returns canned responses keyed by
/// substring patterns instead of making HTTP calls.
pub struct MockLLMClient {
    // pattern substring -> canned response text
    responses: HashMap<String, String>,
}
impl MockLLMClient {
    /// Mock with no canned responses registered.
    pub fn new() -> Self {
        Self { responses: HashMap::new() }
    }

    /// Register a canned `response` returned when a prompt contains `pattern`.
    ///
    /// NOTE(review): patterns are checked in HashMap iteration order, so if a
    /// prompt matches several patterns the winner is nondeterministic (same as
    /// the original behavior).
    pub fn add_response(&mut self, pattern: &str, response: &str) {
        self.responses.insert(pattern.to_string(), response.to_string());
    }

    /// Return the canned response for the first matching pattern, or a fixed
    /// JSON fallback when nothing matches. Never fails.
    pub async fn call_direct(&self, prompt: &str) -> Result<String> {
        let matched = self
            .responses
            .iter()
            .find(|(pattern, _)| prompt.contains(pattern.as_str()))
            .map(|(_, response)| response.clone());
        Ok(matched.unwrap_or_else(|| r#"{"name": "Mock Response", "age": 25}"#.to_string()))
    }
}
#[async_trait]
impl LLMClientTrait for MockLLMClient {
    /// Trait entry point; delegates to the pattern-matching `call_direct`.
    async fn call(&self, prompt: &str) -> Result<String> {
        self.call_direct(prompt).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered pattern response wins over the default mock payload.
    #[tokio::test]
    async fn test_mock_client() {
        let mut mock = MockLLMClient::new();
        mock.add_response("Extract person", r#"{"name": "John", "age": 30}"#);
        let reply = mock.call("Extract person info from text").await.unwrap();
        assert_eq!(reply, r#"{"name": "John", "age": 30}"#);
    }

    /// The OpenAI constructor pins the expected model and base URL.
    #[test]
    fn test_client_configuration() {
        let configured = LLMClient::openai("test-key".to_string(), "gpt-4".to_string());
        assert_eq!(configured.model, "gpt-4");
        assert_eq!(configured.base_url, "https://api.openai.com/v1");
    }

    /// Debug formatting must never leak the API key.
    #[test]
    fn test_api_key_redacted_in_debug() {
        let secret_key = "sk-super-secret-api-key-12345";
        let configured = LLMClient::openai(secret_key.to_string(), "gpt-4".to_string());
        let rendered = format!("{:?}", configured);
        assert!(!rendered.contains(secret_key), "API key should not appear in debug output");
        assert!(rendered.contains("[REDACTED]"), "Debug output should show [REDACTED]");
        assert!(rendered.contains("gpt-4"), "Model should appear in debug output");
    }

    /// New clients start with the 60-second default timeout.
    #[test]
    fn test_default_timeout() {
        let configured = LLMClient::openai("test-key".to_string(), "gpt-4".to_string());
        assert_eq!(configured.timeout, Duration::from_secs(60));
    }

    /// `with_timeout` replaces the default.
    #[test]
    fn test_custom_timeout() {
        let configured = LLMClient::openai("test-key".to_string(), "gpt-4".to_string())
            .with_timeout(Duration::from_secs(120));
        assert_eq!(configured.timeout, Duration::from_secs(120));
    }

    /// The timeout field participates in Debug output.
    #[test]
    fn test_timeout_in_debug_output() {
        let configured = LLMClient::openai("test-key".to_string(), "gpt-4".to_string());
        let rendered = format!("{:?}", configured);
        assert!(rendered.contains("timeout"), "Debug output should include timeout");
    }
}