use crate::core::config::ProgramConfig;
use crate::core::logger::SQLiteLogger;
use crate::analytics::AnalyticsEngine;
use serde_json::Value;
use std::collections::HashMap;
use std::fs;
use tera::{Context, Tera};
use jsonschema::JSONSchema;
use reqwest::Client;
use std::env;
use redis::Client as RedisClient;
use redis::AsyncCommands;
use std::time::Instant;
use futures_util::StreamExt;
use anyhow::{Result, Context as AnyhowContext};
/// Request body for an OpenAI-style `/chat/completions` call.
#[derive(serde::Serialize)]
struct ChatCompletionRequest {
    /// Model identifier (taken from `config.model.name`).
    model: String,
    /// Conversation messages; this file sends a system + user pair.
    messages: Vec<ChatMessage>,
    /// Sampling temperature from the program configuration.
    temperature: f32,
    /// Maximum number of completion tokens to generate.
    max_tokens: u32,
    /// Desired response format, serialized as `response_format.type`.
    response_format: ResponseFormat,
    /// `true` for streaming (`stream`), `false` for one-shot (`run`).
    stream: bool,
}
/// A single chat message: a role ("system" / "user" / "assistant") plus text.
#[derive(serde::Serialize, serde::Deserialize)]
struct ChatMessage {
    role: String,
    content: String,
}
/// Wrapper for the API's `response_format` object.
#[derive(serde::Serialize)]
struct ResponseFormat {
    // `type` is a Rust keyword, hence the serde rename.
    #[serde(rename = "type")]
    format_type: String,
}
/// Response body from the chat-completions endpoint.
#[derive(serde::Deserialize)]
struct ChatCompletionResponse {
    /// Completion choices; `run` reads the first entry only.
    choices: Vec<Choice>,
    /// Token accounting; absent in some responses, so optional.
    usage: Option<Usage>,
}
/// One completion choice.
#[derive(serde::Deserialize, serde::Serialize)]
struct Choice {
    // Full message for non-streaming responses.
    // NOTE(review): this field is mandatory, but streaming chunks typically
    // carry only `delta` — so stream chunks will fail to deserialize as
    // `ChatCompletionResponse` and `stream()` falls back to its "raw"
    // branch; confirm against the wire format.
    message: ChatMessage,
    // Incremental content for streaming responses (when present).
    delta: Option<ChatMessage>,
}
/// Token usage accounting reported by the API.
#[derive(serde::Deserialize, serde::Serialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
/// A YAML-defined LLM "program": a prompt template plus model settings,
/// executed against an OpenAI-compatible chat-completions API with optional
/// Redis response caching, SQLite execution logging, and analytics tracking.
pub struct LLMProgram {
    /// Path to the YAML program definition this instance was loaded from.
    pub program_path: String,
    /// Parsed program configuration (template, schemas, model settings, ...).
    pub config: ProgramConfig,
    /// Tera engine holding the single named "template" used to render prompts.
    tera: Tera,
    /// Shared HTTP client for API calls.
    client: Client,
    /// Bearer token sent in the Authorization header.
    api_key: String,
    /// Full URL of the chat-completions endpoint.
    base_url: String,
    /// Redis client; present only when caching was requested at construction.
    redis_client: Option<RedisClient>,
    /// Whether responses are cached in Redis.
    enable_cache: bool,
    /// Per-program execution logger backed by SQLite.
    logger: SQLiteLogger,
    /// Analytics sink (LLM calls, token usage, program usage).
    analytics_engine: AnalyticsEngine,
}
impl LLMProgram {
/// Loads an LLM program from `program_path` with default options:
/// API key resolved from the environment, the public OpenAI endpoint,
/// and caching enabled against a local Redis instance.
pub fn new(program_path: &str) -> Result<Self> {
    const DEFAULT_REDIS_URL: &str = "redis://localhost:6379";
    Self::new_with_options(program_path, None, None, true, DEFAULT_REDIS_URL)
}
/// Builds an `LLMProgram` with explicit overrides.
///
/// `api_key` falls back to the `OPENAI_API_KEY` environment variable and
/// then to a placeholder; `base_url` defaults to the public OpenAI
/// chat-completions endpoint. When `enable_cache` is true a Redis client is
/// created from `redis_url`. The execution-log database defaults to the
/// program file's path with a `.db` extension.
///
/// # Errors
/// Fails if the YAML config cannot be loaded, the template is invalid, the
/// Redis URL cannot be parsed, or the logger/analytics stores cannot be
/// created.
pub fn new_with_options(
    program_path: &str,
    api_key: Option<String>,
    base_url: Option<String>,
    enable_cache: bool,
    redis_url: &str,
) -> Result<Self> {
    let config = Self::load_config(program_path)?;

    let mut tera = Tera::default();
    tera.add_raw_template("template", &config.template)
        .context("Failed to add template to Tera")?;

    let api_key = api_key
        .or_else(|| env::var("OPENAI_API_KEY").ok())
        .unwrap_or_else(|| "YOUR_API_KEY_HERE".to_string());
    let base_url = base_url
        .unwrap_or_else(|| "https://api.openai.com/v1/chat/completions".to_string());

    // `Client::open` only parses/validates the URL — it does not connect —
    // so a failure here is a configuration error, not a connectivity error.
    let redis_client = if enable_cache {
        Some(
            RedisClient::open(redis_url)
                .map_err(|e| anyhow::anyhow!("Failed to create Redis client: {}", e))?,
        )
    } else {
        None
    };

    // Default the log DB to the program path with a `.db` extension.
    // `Path::with_extension` rewrites only the file name, so a '.' in a
    // parent directory name (e.g. "dir.v2/prog") cannot truncate the path
    // the way `rfind('.')` would.
    let db_path = config.database.path.clone().unwrap_or_else(|| {
        std::path::Path::new(program_path)
            .with_extension("db")
            .to_string_lossy()
            .into_owned()
    });
    let logger = SQLiteLogger::new(&db_path).context("Failed to create logger")?;

    let analytics_engine = AnalyticsEngine::new("llmprogram_analytics.db")
        .context("Failed to create analytics engine")?;

    Ok(LLMProgram {
        program_path: program_path.to_string(),
        config,
        tera,
        client: Client::new(),
        api_key,
        base_url,
        redis_client,
        enable_cache,
        logger,
        analytics_engine,
    })
}
/// Reads the program file at `program_path` and parses it as YAML into a
/// `ProgramConfig`.
fn load_config(program_path: &str) -> Result<ProgramConfig> {
    let raw = fs::read_to_string(program_path)
        .context("Failed to read program file")?;
    serde_yaml::from_str(&raw).context("Failed to parse YAML configuration")
}
/// Validates `input` against the program's input JSON schema.
///
/// Returns every validation message on failure. The schema is recompiled
/// on each call from `config.input_schema`.
pub fn validate_input(&self, input: &Value) -> Result<(), Vec<String>> {
    let schema = JSONSchema::compile(&self.config.input_schema)
        .map_err(|e| vec![format!("Schema compilation error: {:?}", e)])?;
    schema.validate(input).map_err(|errors| {
        errors
            .map(|e| format!("Validation error: {}", e))
            .collect::<Vec<String>>()
    })
}
/// Validates `output` against the program's output JSON schema.
///
/// Returns every validation message on failure. The schema is recompiled
/// on each call from `config.output_schema`.
pub fn validate_output(&self, output: &Value) -> Result<(), Vec<String>> {
    let schema = JSONSchema::compile(&self.config.output_schema)
        .map_err(|e| vec![format!("Schema compilation error: {:?}", e)])?;
    schema.validate(output).map_err(|errors| {
        errors
            .map(|e| format!("Validation error: {}", e))
            .collect::<Vec<String>>()
    })
}
/// Renders the program's prompt template with the given Tera `context`.
pub fn render_template(&self, context: &Context) -> Result<String> {
    self.tera
        .render("template", context)
        .map_err(|e| anyhow::anyhow!("Template rendering failed: {}", e))
}
/// Derives the Redis cache key for one execution.
///
/// The key covers everything that affects the completion — model name,
/// system prompt, config version, and the fully rendered user prompt — so
/// a configuration change cannot serve results cached under the old
/// configuration. (`_inputs` is unused because the rendered prompt already
/// encodes the input values.)
///
/// NOTE: `DefaultHasher` is not guaranteed stable across Rust releases, so
/// cache entries may be invalidated by a toolchain upgrade.
fn generate_cache_key(&self, user_prompt: &str, _inputs: &HashMap<String, Value>) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut hasher = DefaultHasher::new();
    self.config.model.name.hash(&mut hasher);
    self.config.system_prompt.hash(&mut hasher);
    self.config.version.hash(&mut hasher);
    user_prompt.hash(&mut hasher);
    format!("{:x}", hasher.finish())
}
/// Looks up `cache_key` in Redis.
///
/// Returns `Ok(None)` when caching is disabled, no client is configured, or
/// the key is absent.
///
/// # Errors
/// Fails on connection errors, lookup errors, or when a cached payload is
/// not valid JSON.
async fn get_from_cache(&self, cache_key: &str) -> Result<Option<Value>> {
    // Borrow the client up front instead of checking `is_none()` and then
    // unwrapping later.
    let client = match &self.redis_client {
        Some(client) if self.enable_cache => client,
        _ => return Ok(None),
    };
    let mut conn = client
        .get_async_connection()
        .await
        .map_err(|e| anyhow::anyhow!("Failed to get Redis connection: {}", e))?;
    let cached: Option<String> = conn
        .get(cache_key)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to get value from Redis: {}", e))?;
    // A hit must decode as JSON; propagate decode failures rather than
    // silently treating a corrupt entry as a miss.
    cached
        .map(|raw| serde_json::from_str(&raw).context("Failed to parse cached value as JSON"))
        .transpose()
}
/// Stores `value` under `cache_key` in Redis with a one-hour TTL.
///
/// A no-op when caching is disabled or no client is configured.
///
/// # Errors
/// Fails on connection errors, serialization errors, or write errors.
async fn save_to_cache(&self, cache_key: &str, value: &Value) -> Result<()> {
    // Borrow the client up front instead of checking `is_none()` and then
    // unwrapping later.
    let client = match &self.redis_client {
        Some(client) if self.enable_cache => client,
        _ => return Ok(()),
    };
    let mut conn = client
        .get_async_connection()
        .await
        .map_err(|e| anyhow::anyhow!("Failed to get Redis connection: {}", e))?;
    let payload = serde_json::to_string(value)
        .context("Failed to serialize value to JSON")?;
    // Entries expire after 3600 seconds (one hour).
    let _: () = conn
        .set_ex(cache_key, payload, 3600)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to set value in Redis: {}", e))?;
    Ok(())
}
/// Executes the program once: validates `inputs`, renders the prompt,
/// serves a cached result when available, otherwise calls the
/// chat-completions API, validates the output against the output schema,
/// caches it, and records logging/analytics either way.
///
/// # Errors
/// Fails on schema violations, template errors, HTTP/API errors, a missing
/// or non-JSON completion, or logging/analytics failures.
pub async fn run(&self, inputs: &HashMap<String, Value>) -> Result<Value> {
    let start_time = Instant::now();

    // Validate inputs against the configured JSON schema before any work.
    let input_value = serde_json::to_value(inputs)
        .context("Failed to convert inputs to JSON value")?;
    self.validate_input(&input_value)
        .map_err(|errors| anyhow::anyhow!("Input validation failed: {:?}", errors))?;

    // Render the user prompt from the template and the provided inputs.
    let mut context = Context::new();
    for (key, value) in inputs {
        context.insert(key, &value);
    }
    let user_prompt = self.render_template(&context)?;

    // File stem of the program path, used for analytics attribution
    // (computed once; both the cache-hit and miss paths need it).
    let program_name = std::path::Path::new(&self.program_path)
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or("unknown");

    // Fast path: serve a previously cached result for this prompt/config.
    let cache_key = self.generate_cache_key(&user_prompt, inputs);
    if let Some(cached_result) = self.get_from_cache(&cache_key).await? {
        let execution_time = start_time.elapsed().as_secs_f64();
        let execution_time_ms = (execution_time * 1000.0) as u32;
        let response_metadata = serde_json::json!({
            "cache_hit": true,
            "cache_source": "redis"
        });
        self.logger.log_execution(
            &input_value,
            &cached_result,
            &user_prompt,
            &serde_json::to_string(&cached_result)
                .context("Failed to serialize cached result")?,
            &self.config.version,
            &response_metadata,
            execution_time,
        ).context("Failed to log cached execution")?;
        // Token counts are unknown for cached responses, hence the Nones.
        self.analytics_engine.track_llm_call(
            program_name,
            &self.config.model.name,
            None,
            None,
            None,
            execution_time_ms,
            true,
            "unknown",
        ).context("Failed to track LLM call analytics")?;
        self.analytics_engine.track_program_usage(
            program_name,
            execution_time_ms,
            true,
            None,
            "unknown",
            &serde_json::to_string(&input_value)
                .context("Failed to serialize input params")?,
        ).context("Failed to track program usage analytics")?;
        return Ok(cached_result);
    }

    // Cache miss: issue a one-shot (non-streaming) completion request.
    let request = ChatCompletionRequest {
        model: self.config.model.name.clone(),
        messages: vec![
            ChatMessage {
                role: "system".to_string(),
                content: self.config.system_prompt.clone(),
            },
            ChatMessage {
                role: "user".to_string(),
                content: user_prompt.clone(),
            },
        ],
        temperature: self.config.model.temperature,
        max_tokens: self.config.model.max_tokens,
        response_format: ResponseFormat {
            format_type: self.config.model.response_format.clone(),
        },
        stream: false,
    };
    let response = self.client
        .post(&self.base_url)
        .header("Authorization", format!("Bearer {}", self.api_key))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await
        .context("Failed to send request to OpenAI API")?
        // Surface HTTP-level failures directly instead of letting them show
        // up as a confusing body-parse error.
        .error_for_status()
        .context("OpenAI API returned an error status")?
        .json::<ChatCompletionResponse>()
        .await
        .context("Failed to parse response from OpenAI API")?;

    // Guard against an empty `choices` array instead of panicking on [0].
    let content = &response
        .choices
        .first()
        .ok_or_else(|| anyhow::anyhow!("OpenAI API response contained no choices"))?
        .message
        .content;
    let response_json: Value = serde_json::from_str(content)
        .context("Failed to parse response content as JSON")?;
    self.validate_output(&response_json)
        .map_err(|errors| anyhow::anyhow!("Output validation failed: {:?}", errors))?;
    self.save_to_cache(&cache_key, &response_json).await?;

    let execution_time = start_time.elapsed().as_secs_f64();
    let execution_time_ms = (execution_time * 1000.0) as u32;
    // This path is only reached on a cache miss.
    let response_metadata = serde_json::json!({
        "cache_hit": false,
        "usage": response.usage
    });
    self.logger.log_execution(
        &input_value,
        &response_json,
        &user_prompt,
        content,
        &self.config.version,
        &response_metadata,
        execution_time,
    ).context("Failed to log execution")?;

    if let Some(usage) = &response.usage {
        self.analytics_engine.track_llm_call(
            program_name,
            &self.config.model.name,
            Some(usage.prompt_tokens),
            Some(usage.completion_tokens),
            Some(usage.total_tokens),
            execution_time_ms,
            false,
            "unknown",
        ).context("Failed to track LLM call analytics")?;
        // NOTE(review): $0.03/$0.06 per 1K tokens are hard-coded GPT-4-era
        // prices applied regardless of the configured model — confirm.
        let cost_estimate = (usage.prompt_tokens as f64 / 1000.0) * 0.03 +
            (usage.completion_tokens as f64 / 1000.0) * 0.06;
        self.analytics_engine.track_token_usage(
            program_name,
            &self.config.model.name,
            usage.prompt_tokens,
            usage.completion_tokens,
            usage.total_tokens,
            "unknown",
            cost_estimate,
        ).context("Failed to track token usage analytics")?;
    }
    self.analytics_engine.track_program_usage(
        program_name,
        execution_time_ms,
        true,
        None,
        "unknown",
        &serde_json::to_string(&input_value)
            .context("Failed to serialize input params")?,
    ).context("Failed to track program usage analytics")?;
    Ok(response_json)
}
/// Streams a chat completion for the given inputs.
///
/// Validates `inputs` against the input schema, renders the prompt
/// template, and issues a streaming request (`stream: true`). Each emitted
/// item is either `{"type": "content", "data": ...}` when a chunk parses as
/// a `ChatCompletionResponse` carrying a `delta`, or `{"type": "raw",
/// "data": ...}` with the chunk bytes as lossy UTF-8 when it does not.
///
/// NOTE(review): OpenAI streaming responses are Server-Sent Events
/// ("data: {...}" frames), and stream chunks carry `delta` but usually no
/// `message` field — since `Choice::message` is required here, parsing raw
/// chunk bytes as `ChatCompletionResponse` is expected to fail and fall
/// through to the "raw" branch; confirm against the wire format before
/// relying on the "content" branch. Also note: unlike `run`, this method
/// performs no HTTP status check, caching, logging, or analytics.
pub async fn stream(
    &self,
    inputs: &HashMap<String, Value>,
) -> Result<impl futures_util::Stream<Item = Result<Value>>> {
    // Validate inputs before making any network call.
    let input_value = serde_json::to_value(inputs)
        .context("Failed to convert inputs to JSON value")?;
    self.validate_input(&input_value)
        .map_err(|errors| anyhow::anyhow!("Input validation failed: {:?}", errors))?;
    // Render the user prompt from the template and the provided inputs.
    let mut context = Context::new();
    for (key, value) in inputs {
        context.insert(key, &value);
    }
    let user_prompt = self.render_template(&context)?;
    // Same request shape as `run`, but with streaming enabled.
    let request = ChatCompletionRequest {
        model: self.config.model.name.clone(),
        messages: vec![
            ChatMessage {
                role: "system".to_string(),
                content: self.config.system_prompt.clone(),
            },
            ChatMessage {
                role: "user".to_string(),
                content: user_prompt.clone(),
            },
        ],
        temperature: self.config.model.temperature,
        max_tokens: self.config.model.max_tokens,
        response_format: ResponseFormat {
            format_type: self.config.model.response_format.clone(),
        },
        stream: true,
    };
    let response = self.client
        .post(&self.base_url)
        .header("Authorization", format!("Bearer {}", self.api_key))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await
        .context("Failed to send request to OpenAI API")?;
    // Map each raw byte chunk into a JSON item; transport errors are
    // propagated as stream errors.
    let stream = response.bytes_stream()
        .map(|result| {
            match result {
                Ok(bytes) => {
                    match serde_json::from_slice::<ChatCompletionResponse>(&bytes) {
                        Ok(response) => {
                            // Parsed chunk: forward the first choice's delta
                            // content, if any.
                            if let Some(delta) = response.choices.first().and_then(|c| c.delta.as_ref()) {
                                Ok(serde_json::json!({
                                    "type": "content",
                                    "data": delta.content
                                }))
                            } else {
                                Err(anyhow::anyhow!("Invalid response format"))
                            }
                        }
                        Err(_) => {
                            // Unparseable chunk: pass the bytes through so the
                            // caller can handle SSE framing itself.
                            Ok(serde_json::json!({
                                "type": "raw",
                                "data": String::from_utf8_lossy(&bytes)
                            }))
                        }
                    }
                }
                Err(e) => Err(anyhow::anyhow!(e)),
            }
        });
    Ok(stream)
}
/// Runs the program once per entry in `inputs_list`, with up to 4 requests
/// in flight concurrently.
///
/// Results are returned in the same order as `inputs_list` — `buffered`
/// (rather than `buffer_unordered`) preserves input order, which callers
/// need to match each output back to its input.
///
/// # Errors
/// Returns the first error from any individual `run` call.
pub async fn batch_process(
    &self,
    inputs_list: &[HashMap<String, Value>],
) -> Result<Vec<Value>> {
    use futures_util::stream::{self, StreamExt};
    stream::iter(inputs_list)
        .map(|inputs| self.run(inputs))
        .buffered(4)
        .collect::<Vec<Result<Value>>>()
        .await
        .into_iter()
        .collect()
}
}