use crate::Result;
use rig::providers::openai::Client;
use serde::{Deserialize, Serialize};
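
/// Configuration for an OpenAI-compatible chat-completion endpoint.
/// Defaults are read from the `LLM_API_BASE_URL`, `LLM_API_KEY`, and
/// `LLM_MODEL` environment variables.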
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfig {
pub api_base_url: String,
pub api_key: String,
pub model_efficient: String,
pub temperature: f32,
pub max_tokens: usize,
}
impl Default for LLMConfig {
fn default() -> Self {
Self {
api_base_url: std::env::var("LLM_API_BASE_URL")
.unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
api_key: std::env::var("LLM_API_KEY")
.unwrap_or_else(|_| "".to_string()),
model_efficient: std::env::var("LLM_MODEL")
.unwrap_or_else(|_| "gpt-3.5-turbo".to_string()),
temperature: 0.1,
max_tokens: 4096,
}
}
}
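
/// Everything extracted from a single memory-extraction call.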
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryExtractionResponse {
pub facts: Vec<ExtractedFactRaw>,
pub decisions: Vec<ExtractedDecisionRaw>,
pub entities: Vec<ExtractedEntityRaw>,
}
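
/// A fact as emitted by the model, before downstream processing.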
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedFactRaw {
pub content: String,
#[serde(default)]
pub subject: Option<String>,
pub confidence: f32,
#[serde(default)]
pub importance: Option<String>,
}
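
/// A decision as emitted by the model, with its context and optional rationale.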
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedDecisionRaw {
pub decision: String,
pub context: String,
pub rationale: Option<String>,
pub confidence: f32,
}
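
/// A named entity as emitted by the model.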
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntityRaw {
pub name: String,
pub entity_type: String,
pub description: Option<String>,
pub confidence: f32,
}
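
/// Abstraction over an LLM backend: plain completions plus the
/// structured-extraction entry points.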
#[async_trait::async_trait]
pub trait LLMClient: Send + Sync {
async fn complete(&self, prompt: &str) -> Result<String>;
async fn complete_with_system(&self, system: &str, prompt: &str) -> Result<String>;
async fn extract_memories(&self, prompt: &str) -> Result<MemoryExtractionResponse>;
async fn extract_structured_facts(&self, prompt: &str) -> Result<crate::llm::extractor_types::StructuredFactExtraction>;
async fn extract_detailed_facts(&self, prompt: &str) -> Result<crate::llm::extractor_types::DetailedFactExtraction>;
fn model_name(&self) -> &str;
fn config(&self) -> &LLMConfig;
}
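
/// Default [`LLMClient`] implementation backed by rig's OpenAI provider.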
pub struct LLMClientImpl {
client: Client,
config: LLMConfig,
}
impl LLMClientImpl {
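    /// Builds a client for the configured endpoint.
    ///
    /// Illustrative usage (assumes the `LLM_*` environment variables are set):
    ///
    /// ```ignore
    /// let client = LLMClientImpl::new(LLMConfig::default())?;
    /// let answer = client.complete("Say hello").await?;
    /// ```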
pub fn new(config: LLMConfig) -> Result<Self> {
let client = Client::builder()
.api_key(&config.api_key)
.base_url(&config.api_base_url)
.build()
.map_err(|e| crate::Error::Llm(format!("Failed to build OpenAI client: {:?}", e)))?;
Ok(Self { client, config })
}
pub fn default_config() -> LLMConfig {
LLMConfig::default()
}
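
    /// Builds a rig agent for the configured model with the given system
    /// prompt. Performs no awaiting itself; kept `async` to match its
    /// call sites.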
pub async fn create_agent(&self, system_prompt: &str) -> Result<rig::agent::Agent<rig::providers::openai::CompletionModel>> {
use rig::client::CompletionClient;
let agent = self.client.clone()
            .completions_api()
            .agent(&self.config.model_efficient)
.preamble(system_prompt)
.build();
Ok(agent)
}
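
    /// Sends `prompt` to the model with a generic system prompt and
    /// returns the raw text response.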
pub async fn complete(&self, prompt: &str) -> Result<String> {
use rig::completion::Prompt;
tracing::info!("LLM call started [model: {}]", self.config.model_efficient);
tracing::debug!("Prompt length: {} chars", prompt.len());
let start = std::time::Instant::now();
let agent = self.create_agent("You are a helpful assistant.").await?;
let response = agent
.prompt(prompt)
.await
.map_err(|e| crate::Error::Llm(format!("LLM completion failed: {}", e)))?;
let elapsed = start.elapsed();
tracing::info!("LLM call completed [elapsed: {:.2}s, response: {} chars]", elapsed.as_secs_f64(), response.len());
Ok(response)
}
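
    /// Like [`Self::complete`], but with a caller-supplied system prompt.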
pub async fn complete_with_system(&self, system: &str, prompt: &str) -> Result<String> {
use rig::completion::Prompt;
tracing::info!("LLM call started (with system) [model: {}]", self.config.model_efficient);
tracing::debug!("System: {}..., Prompt length: {} chars",
&system.chars().take(50).collect::<String>(), prompt.len());
let start = std::time::Instant::now();
let agent = self.create_agent(system).await?;
let response = agent
.prompt(prompt)
.await
.map_err(|e| crate::Error::Llm(format!("LLM completion failed: {}", e)))?;
let elapsed = start.elapsed();
tracing::info!("LLM call completed [elapsed: {:.2}s, response: {} chars]", elapsed.as_secs_f64(), response.len());
Ok(response)
}
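
    /// Runs an extraction prompt and parses the reply into a
    /// [`MemoryExtractionResponse`]. Tolerates bare fact/decision/entity
    /// arrays as well as the full object shape; an unparseable reply
    /// yields an empty extraction rather than an error.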
pub async fn extract_memories(&self, prompt: &str) -> Result<MemoryExtractionResponse> {
let response: String = self.complete(prompt).await?;
let json_str = Self::extract_json_from_response_static(&response);
if let Ok(extracted) = serde_json::from_str::<MemoryExtractionResponse>(json_str) {
tracing::debug!("Successfully parsed MemoryExtractionResponse");
return Ok(extracted);
}
if let Ok(facts) = serde_json::from_str::<Vec<ExtractedFactRaw>>(json_str) {
tracing::debug!("Parsed as facts array, found {} facts", facts.len());
return Ok(MemoryExtractionResponse {
facts,
decisions: Vec::new(),
entities: Vec::new(),
});
}
if let Ok(decisions) = serde_json::from_str::<Vec<ExtractedDecisionRaw>>(json_str) {
tracing::debug!("Parsed as decisions array, found {} decisions", decisions.len());
return Ok(MemoryExtractionResponse {
facts: Vec::new(),
decisions,
entities: Vec::new(),
});
}
if let Ok(entities) = serde_json::from_str::<Vec<ExtractedEntityRaw>>(json_str) {
tracing::debug!("Parsed as entities array, found {} entities", entities.len());
return Ok(MemoryExtractionResponse {
facts: Vec::new(),
decisions: Vec::new(),
entities,
});
}
eprintln!("[DEBUG] Failed to parse JSON, returning empty extraction");
Ok(MemoryExtractionResponse {
facts: Vec::new(),
decisions: Vec::new(),
entities: Vec::new(),
})
}
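
    /// Returns the underlying rig OpenAI client.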
pub fn client(&self) -> &Client {
&self.client
}
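
    /// Best-effort extraction of the JSON payload from an LLM reply:
    /// strips Markdown code fences, otherwise returns the first balanced
    /// `{...}` or `[...]` span found in the text.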
    pub fn extract_json_from_response_static(response: &str) -> &str {
        let trimmed = response.trim();
        // Strip a Markdown code fence (```json ... ``` or ``` ... ```).
        if trimmed.starts_with("```") {
            if let Some(json_start) = trimmed.find('\n') {
                let rest = &trimmed[json_start + 1..];
                if let Some(end) = rest.find("```") {
                    return rest[..end].trim();
                }
                return rest.trim();
            }
        }
        // Otherwise take the first balanced `{...}` or `[...]` span,
        // whichever opens first; callers also parse bare top-level arrays.
        // Heuristic only: braces inside JSON string literals are not handled.
        let (open, close) = match (trimmed.find('{'), trimmed.find('[')) {
            (Some(obj), Some(arr)) if arr < obj => ('[', ']'),
            (None, Some(_)) => ('[', ']'),
            _ => ('{', '}'),
        };
        if let Some(start) = trimmed.find(open) {
            let mut depth = 0i32;
            for (i, c) in trimmed[start..].char_indices() {
                if c == open {
                    depth += 1;
                } else if c == close {
                    depth -= 1;
                    if depth == 0 {
                        return &trimmed[start..start + i + 1];
                    }
                }
            }
        }
        trimmed
    }
}
#[async_trait::async_trait]
impl LLMClient for LLMClientImpl {
    async fn complete(&self, prompt: &str) -> Result<String> {
        // Delegate to the inherent methods on `LLMClientImpl`; inherent
        // items win name resolution over trait items, so these calls do
        // not recurse.
        LLMClientImpl::complete(self, prompt).await
    }
    async fn complete_with_system(&self, system: &str, prompt: &str) -> Result<String> {
        LLMClientImpl::complete_with_system(self, system, prompt).await
    }
    async fn extract_memories(&self, prompt: &str) -> Result<MemoryExtractionResponse> {
        LLMClientImpl::extract_memories(self, prompt).await
    }
async fn extract_structured_facts(&self, prompt: &str) -> Result<crate::llm::extractor_types::StructuredFactExtraction> {
let extraction_prompt = format!(
r#"Extract factual information from the text below.
## Instructions
1. Identify all factual statements that can be verified
2. Focus on concrete facts, not opinions or speculations
3. Each fact should be a single, atomic statement
## Output Format
Return ONLY a valid JSON object with this exact structure:
{{
"facts": ["fact 1", "fact 2", "fact 3"]
}}
If no facts are found, return: {{"facts": []}}
## Text to Analyze
{}
## Response (JSON only)"#,
prompt
);
let response = self.complete(&extraction_prompt).await?;
let json_str = Self::extract_json_from_response_static(&response);
match serde_json::from_str::<crate::llm::extractor_types::StructuredFactExtraction>(json_str) {
Ok(facts) => Ok(facts),
Err(e) => {
tracing::warn!("Failed to parse structured facts: {}. Response: {}", e, json_str);
Ok(crate::llm::extractor_types::StructuredFactExtraction {
facts: vec![],
})
}
}
}
async fn extract_detailed_facts(&self, prompt: &str) -> Result<crate::llm::extractor_types::DetailedFactExtraction> {
let extraction_prompt = format!(
r#"Extract detailed factual information from the text below.
## Instructions
1. Identify all factual statements that can be verified
2. For each fact, determine:
- content: The factual statement
- importance: A score from 0.0 to 1.0 (how important/relevant is this fact)
- category: One of "personal", "work", "preference", "event", "knowledge", "other"
- entities: List of named entities mentioned (people, places, organizations, etc.)
- source_role: Either "user", "assistant", or "system" (who stated this fact)
## Output Format
Return ONLY a valid JSON object with this exact structure:
{{
"facts": [
{{
"content": "The factual statement",
"importance": 0.8,
"category": "personal",
"entities": ["John", "New York"],
"source_role": "user"
}}
]
}}
If no facts are found, return: {{"facts": []}}
## Text to Analyze
{}
## Response (JSON only)"#,
prompt
);
let response = self.complete(&extraction_prompt).await?;
let json_str = Self::extract_json_from_response_static(&response);
match serde_json::from_str::<crate::llm::extractor_types::DetailedFactExtraction>(json_str) {
Ok(facts) => Ok(facts),
Err(e) => {
tracing::warn!("Failed to parse detailed facts: {}. Response: {}", e, json_str);
Ok(crate::llm::extractor_types::DetailedFactExtraction {
facts: vec![],
})
}
}
}
fn model_name(&self) -> &str {
&self.config.model_efficient
}
fn config(&self) -> &LLMConfig {
&self.config
}
}
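
// Unit tests for the pure JSON-extraction helper; they exercise the
// fenced-block, embedded-object, and bare-array paths with no network
// access.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_json_from_fenced_block() {
        let response = "```json\n{\"facts\": []}\n```";
        assert_eq!(
            LLMClientImpl::extract_json_from_response_static(response),
            "{\"facts\": []}"
        );
    }

    #[test]
    fn extracts_first_balanced_object_from_prose() {
        let response = "Here you go: {\"facts\": [{\"content\": \"x\"}]} -- done.";
        assert_eq!(
            LLMClientImpl::extract_json_from_response_static(response),
            "{\"facts\": [{\"content\": \"x\"}]}"
        );
    }

    #[test]
    fn extracts_bare_top_level_array() {
        let response = "Result: [{\"name\": \"A\"}, {\"name\": \"B\"}]";
        assert_eq!(
            LLMClientImpl::extract_json_from_response_static(response),
            "[{\"name\": \"A\"}, {\"name\": \"B\"}]"
        );
    }
}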