use std::sync::Arc;
use async_trait::async_trait;
use serde::Deserialize;
use super::types::{EntityKind, ExtractedEntities, ExtractedEntity, ExtractedTopic};
use super::EntityExtractor;
#[path = "llm_prompt.rs"]
mod prompt;
use prompt::build_system_prompt;
/// Output-token cap requested for every extraction call; `build_prompt`
/// passes it as `ChatPrompt::max_tokens`.
const EXTRACTION_MAX_OUTPUT_TOKENS: u32 = 8192;
/// One system + user chat request handed to a [`ChatProvider`].
#[derive(Clone, Debug)]
pub struct ChatPrompt {
/// System message text.
pub system: String,
/// User message text.
pub user: String,
/// Sampling temperature; the extractor in this file always sends `0.0`.
pub temperature: f32,
/// Static label for the call site (this file uses `"memory_tree::extract"`).
/// NOTE(review): presumably consumed by providers for routing or telemetry — confirm.
pub kind: &'static str,
/// Optional cap on generated tokens.
/// NOTE(review): `None` presumably defers to the provider's own default — confirm.
pub max_tokens: Option<u32>,
}
/// Chat backend used by the extractor to run a single completion.
///
/// Implementors must be `Send + Sync`; the extractor stores one behind an
/// `Arc<dyn ChatProvider>`.
#[async_trait]
pub trait ChatProvider: Send + Sync {
/// Identifier of this provider.
/// NOTE(review): not read anywhere in this file; presumably for logging — confirm.
fn name(&self) -> &str;
/// Sends `prompt` and returns the raw response text.
///
/// The extractor in this file feeds the returned string to
/// `serde_json::from_str`, and classifies an `Err` with `is_non_retryable`
/// to decide whether another attempt is worthwhile.
async fn chat_for_json(&self, prompt: &ChatPrompt) -> anyhow::Result<String>;
}
/// Tunables for [`LlmEntityExtractor`].
#[derive(Clone, Debug)]
pub struct LlmExtractorConfig {
/// Model identifier.
/// NOTE(review): not read by any code visible in this file; presumably used
/// by whoever constructs the provider — confirm.
pub model: String,
/// Entity kinds that are kept as-is. A kind outside this list is either
/// remapped to `EntityKind::Misc` or dropped, depending on `strict_kinds`.
pub allowed_kinds: Vec<EntityKind>,
/// When `true`, entities whose kind is unrecognised or not in
/// `allowed_kinds` are dropped instead of being remapped to `Misc`.
pub strict_kinds: bool,
/// Forwarded to `build_system_prompt` when the system prompt is built.
pub emit_topics: bool,
/// Optional output language, forwarded to `build_system_prompt`.
pub output_language: Option<String>,
}
impl Default for LlmExtractorConfig {
fn default() -> Self {
Self {
model: "qwen2.5:0.5b".to_string(),
allowed_kinds: vec![
EntityKind::Person,
EntityKind::Organization,
EntityKind::Location,
EntityKind::Event,
EntityKind::Product,
EntityKind::Datetime,
EntityKind::Technology,
EntityKind::Artifact,
EntityKind::Quantity,
],
strict_kinds: false,
emit_topics: false,
output_language: None,
}
}
}
/// Entity extractor that asks a chat model for a JSON list of entities and
/// maps each one back onto a span of the input text.
pub struct LlmEntityExtractor {
/// Extraction settings (allowed kinds, strictness, prompt options).
cfg: LlmExtractorConfig,
/// Shared chat backend used for every extraction attempt.
provider: Arc<dyn ChatProvider>,
}
impl LlmEntityExtractor {
pub fn new(cfg: LlmExtractorConfig, provider: Arc<dyn ChatProvider>) -> Self {
Self { cfg, provider }
}
fn build_prompt(&self, text: &str) -> ChatPrompt {
ChatPrompt {
system: build_system_prompt(self.cfg.emit_topics, self.cfg.output_language.as_deref()),
user: format!("Text:\n{text}\n\nReturn JSON only."),
temperature: 0.0,
kind: "memory_tree::extract",
max_tokens: Some(EXTRACTION_MAX_OUTPUT_TOKENS),
}
}
}
#[async_trait]
impl EntityExtractor for LlmEntityExtractor {
    /// Fixed identifier of this extractor implementation.
    fn name(&self) -> &'static str {
        "llm"
    }

    /// Extracts entities from `text`, making at most three attempts.
    ///
    /// This never returns `Err`: a permanent provider failure, or running out
    /// of attempts, yields an empty `ExtractedEntities` instead.
    async fn extract(&self, text: &str) -> anyhow::Result<ExtractedEntities> {
        const MAX_ATTEMPTS: u32 = 3;
        let mut attempts_left = MAX_ATTEMPTS;
        while attempts_left > 0 {
            attempts_left -= 1;
            match self.try_extract(text).await {
                AttemptOutcome::Done(extracted) => return Ok(extracted),
                // Retrying cannot help; fall through to the empty result.
                AttemptOutcome::Permanent => break,
                AttemptOutcome::Retryable => {}
            }
        }
        Ok(ExtractedEntities::default())
    }
}
/// Result of a single extraction attempt, as seen by the retry loop in `extract`.
enum AttemptOutcome {
/// The attempt produced a final result (possibly empty); stop and return it.
Done(ExtractedEntities),
/// The attempt failed in a way that may succeed on another try.
Retryable,
/// The attempt failed in a way retrying cannot fix; `extract` returns an empty result.
Permanent,
}
impl LlmEntityExtractor {
    /// Runs one provider call and classifies what happened.
    ///
    /// * Provider error: `Permanent` when `is_non_retryable` recognises it,
    ///   otherwise `Retryable`.
    /// * Response that ends before the JSON is complete (serde reports EOF):
    ///   `Retryable`.
    /// * Any other JSON error: `Done` with an empty result.
    /// * Valid JSON: `Done` with the entities mapped onto `text`.
    async fn try_extract(&self, text: &str) -> AttemptOutcome {
        let prompt = self.build_prompt(text);
        let raw = match self.provider.chat_for_json(&prompt).await {
            Ok(body) => body,
            Err(err) if is_non_retryable(&err) => return AttemptOutcome::Permanent,
            Err(_) => return AttemptOutcome::Retryable,
        };
        match serde_json::from_str::<LlmExtractionOutput>(&raw) {
            Ok(parsed) => AttemptOutcome::Done(parsed.into_extracted_entities(text, &self.cfg)),
            Err(err) if err.is_eof() => AttemptOutcome::Retryable,
            Err(_) => AttemptOutcome::Done(ExtractedEntities::default()),
        }
    }
}
/// Reports whether `err` looks like a failure that retrying cannot fix
/// (billing/quota exhaustion or an authentication/authorisation problem).
///
/// The whole error chain is rendered with `{:#}`, lower-cased, and searched
/// for a fixed list of markers. Matching is plain substring search, so a
/// numeric marker such as `402` matches wherever those digits appear in the
/// message.
fn is_non_retryable(err: &anyhow::Error) -> bool {
    const MARKERS: &[&str] = &[
        "402",
        "payment required",
        "requires more credits",
        "insufficient",
        "monthly_request_count",
        "monthly request",
        "quota",
        "401",
        "403",
        "unauthorized",
        "forbidden",
        "invalid api key",
        "incorrect api key",
    ];
    let rendered = format!("{err:#}").to_lowercase();
    MARKERS.iter().any(|marker| rendered.contains(*marker))
}
/// JSON shape expected from the model. Every field is `#[serde(default)]`,
/// so a key missing from the response deserializes to an empty list or `None`.
#[derive(Debug, Deserialize)]
struct LlmExtractionOutput {
/// Entities reported by the model, in response order.
#[serde(default)]
entities: Vec<LlmEntity>,
/// Topic labels reported by the model; blank labels are dropped during conversion.
#[serde(default)]
topics: Vec<String>,
/// Optional importance value; clamped to `[0.0, 1.0]` during conversion.
#[serde(default)]
importance: Option<f32>,
/// Optional free-text reason accompanying `importance`; passed through unchanged.
#[serde(default)]
importance_reason: Option<String>,
}
/// One entity as reported by the model. Both fields are required: neither has
/// a serde default, so an entry missing either one fails deserialization.
#[derive(Debug, Deserialize)]
struct LlmEntity {
/// Kind label as written by the model; interpreted by `parse_kind`.
kind: String,
/// Surface text of the entity; trimmed and then searched for in the source text.
text: String,
}
impl LlmExtractionOutput {
    /// Converts the model's raw output into `ExtractedEntities` anchored to
    /// `source_text`.
    ///
    /// For every reported entity, in order:
    /// 1. the surface text is trimmed; blank entries are skipped;
    /// 2. the kind label is resolved with `parse_kind`. A kind that is
    ///    unknown or not in `cfg.allowed_kinds` becomes `EntityKind::Misc`,
    ///    or causes the entity to be skipped when `cfg.strict_kinds` is set;
    /// 3. the surface is located in `source_text`, resuming after the previous
    ///    hit for the same surface so repeated mentions get successive spans.
    ///    Entities that cannot be located are skipped.
    ///
    /// Spans are character offsets. Entities and topics both receive a fixed
    /// score of `0.85`; importance is clamped to `[0.0, 1.0]`.
    fn into_extracted_entities(
        self,
        source_text: &str,
        cfg: &LlmExtractorConfig,
    ) -> ExtractedEntities {
        use std::collections::HashMap;

        let Self {
            entities: reported,
            topics: reported_topics,
            importance,
            importance_reason,
        } = self;

        // Per-surface resume position: (byte offset, char offset) just past
        // the most recent match of that exact surface string.
        let mut resume_at: HashMap<String, (usize, u32)> = HashMap::new();
        let mut entities = Vec::with_capacity(reported.len());

        for candidate in reported {
            let surface = candidate.text.trim();
            if surface.is_empty() {
                continue;
            }

            let accepted =
                parse_kind(&candidate.kind).filter(|k| cfg.allowed_kinds.contains(k));
            let kind = match accepted {
                Some(k) => k,
                None if cfg.strict_kinds => continue,
                None => EntityKind::Misc,
            };

            let (byte_from, char_from) = resume_at.get(surface).copied().unwrap_or((0, 0));
            let located = find_char_span_from(source_text, surface, byte_from, char_from);
            let (span_start, span_end, byte_after) = match located {
                Some(hit) => hit,
                None => continue,
            };

            let text = surface.to_string();
            resume_at.insert(text.clone(), (byte_after, span_end));
            entities.push(ExtractedEntity {
                kind,
                text,
                span_start,
                span_end,
                score: 0.85,
            });
        }

        let topics = reported_topics
            .iter()
            .map(|raw| raw.trim())
            .filter(|label| !label.is_empty())
            .map(|label| ExtractedTopic {
                label: label.to_string(),
                score: 0.85,
            })
            .collect();

        ExtractedEntities {
            entities,
            topics,
            llm_importance: importance.map(|v| v.clamp(0.0, 1.0)),
            llm_importance_reason: importance_reason,
        }
    }
}
/// Maps a model-supplied kind label onto an `EntityKind`.
///
/// The label is trimmed and lower-cased before matching, and several common
/// synonyms are accepted per kind. Returns `None` for labels that match
/// nothing.
fn parse_kind(s: &str) -> Option<EntityKind> {
    let label = s.trim().to_lowercase();
    let kind = match label.as_str() {
        "person" | "people" => EntityKind::Person,
        "organization" | "organisation" | "org" => EntityKind::Organization,
        "location" | "place" | "loc" => EntityKind::Location,
        "event" => EntityKind::Event,
        "product" => EntityKind::Product,
        "datetime" | "date" | "time" | "timestamp" => EntityKind::Datetime,
        "technology" | "tech" | "tool" | "framework" | "library" | "language" | "service" => {
            EntityKind::Technology
        }
        "artifact" | "reference" | "ref" | "pr" | "ticket" | "file" | "commit" => {
            EntityKind::Artifact
        }
        "quantity" | "amount" | "metric" | "number" | "money" => EntityKind::Quantity,
        "misc" | "miscellaneous" | "other" => EntityKind::Misc,
        _ => return None,
    };
    Some(kind)
}
/// Locates the first occurrence of `needle` in `haystack` and returns its
/// `(start, end)` character offsets, or `None` when `needle` is empty or
/// absent. Thin wrapper over `find_char_span_from` starting at offset zero.
#[allow(dead_code)]
fn find_char_span(haystack: &str, needle: &str) -> Option<(u32, u32)> {
    let (start, end, _byte_end) = find_char_span_from(haystack, needle, 0, 0)?;
    Some((start, end))
}
/// Searches `haystack` for `needle`, starting at byte offset `byte_from`.
///
/// `char_from` must be the character offset that corresponds to `byte_from`;
/// it lets the function report character positions without rescanning the
/// prefix of the haystack.
///
/// Returns `(char_start, char_end, byte_end)` for the first match at or after
/// `byte_from`, where `byte_end` is the byte offset just past the match
/// (suitable as the next `byte_from`). Returns `None` when `needle` is empty,
/// when `byte_from` is past the end or not on a character boundary, or when
/// there is no match.
fn find_char_span_from(
    haystack: &str,
    needle: &str,
    byte_from: usize,
    char_from: u32,
) -> Option<(u32, u32, usize)> {
    if needle.is_empty() {
        return None;
    }
    // `str::get` yields `None` for an out-of-range start as well as for a
    // start that splits a multi-byte character.
    let tail = haystack.get(byte_from..)?;
    let offset = tail.find(needle)?;
    let skipped_chars = tail[..offset].chars().count() as u32;
    let char_start = char_from + skipped_chars;
    let char_end = char_start + needle.chars().count() as u32;
    let byte_end = byte_from + offset + needle.len();
    Some((char_start, char_end, byte_end))
}
/// Returns `s` shortened to at most `max_chars` characters for logging.
///
/// Strings that already fit are returned unchanged; longer ones are cut after
/// `max_chars` characters and suffixed with `…`. The cut always falls on a
/// character boundary, so multi-byte text is never split.
#[allow(dead_code)]
fn truncate_for_log(s: &str, max_chars: usize) -> String {
    // Probe for the character just past the limit: one pass that stops early,
    // instead of counting every character and then collecting a second time.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}…", &s[..cut]),
        None => s.to_string(),
    }
}
#[cfg(test)]
#[path = "llm_tests.rs"]
mod tests;