use std::collections::{HashMap, HashSet};
use regex::Regex;
use serde::Deserialize;
use crate::content::angles::{EvidenceItem, EvidenceType};
use crate::error::LlmError;
use crate::llm::{GenerationParams, LlmProvider};
#[cfg(test)]
mod tests;
// Maximum citation length before truncation. NOTE(review): despite the name,
// this is compared against `String::len()` in `validate_evidence`, i.e. it is
// a UTF-8 *byte* budget, not a character count.
const MAX_CITATION_CHARS: usize = 120;
// Evidence items whose confidence falls below this floor are discarded
// during validation.
const MIN_CONFIDENCE_FLOOR: f64 = 0.1;
/// A text snippet from a graph-neighboring note, used as source material
/// for evidence extraction.
#[derive(Debug, Clone)]
pub struct NeighborContent {
    /// Identifier of the graph node the snippet came from.
    pub node_id: i64,
    /// Title of the note containing the snippet.
    pub note_title: String,
    /// Heading breadcrumb inside the note, when one is known.
    pub heading_path: Option<String>,
    /// The raw excerpt that is scanned and quoted from.
    pub snippet: String,
}
/// A quantitative-looking fragment found by the regex pre-filter, offered to
/// the LLM as a hint for `data_point` evidence.
#[derive(Debug, Clone)]
pub struct CandidateDataPoint {
    /// The matched value plus a little surrounding context.
    pub text: String,
    /// Node the fragment was found in.
    pub source_node_id: i64,
    /// Title of the note the fragment was found in.
    pub source_note_title: String,
}
/// Scans neighbor snippets for quantitative-looking data (percentages,
/// dollar amounts, "Nx" multipliers, ISO dates, counted nouns) and returns
/// each match wrapped in roughly 20 bytes of leading and 30 bytes of
/// trailing context.
pub fn pre_filter_data_points(neighbors: &[NeighborContent]) -> Vec<CandidateDataPoint> {
    let patterns = [
        r"\d+(?:\.\d+)?%",
        r"\$[\d,]+(?:\.\d+)?",
        r"\d+(?:\.\d+)?x\b",
        r"\d{4}-\d{2}-\d{2}",
        r"\d+\s+(?:users|customers|downloads|revenue|sales|companies|teams|projects)",
    ];
    let combined = patterns.join("|");
    let re = Regex::new(&combined).expect("pre-filter regex is valid");
    let mut candidates = Vec::new();
    for neighbor in neighbors {
        let snippet = neighbor.snippet.as_str();
        for mat in re.find_iter(snippet) {
            // BUG FIX: the context window is computed in raw byte offsets, and
            // `start - 20` / `end + 30` can land inside a multi-byte UTF-8
            // sequence — slicing there panics. Snap both edges outward to the
            // nearest char boundary (0 and snippet.len() are always
            // boundaries, so both loops terminate).
            let mut start = mat.start().saturating_sub(20);
            while !snippet.is_char_boundary(start) {
                start -= 1;
            }
            let mut end = (mat.end() + 30).min(snippet.len());
            while !snippet.is_char_boundary(end) {
                end += 1;
            }
            let context = snippet[start..end].trim().to_string();
            candidates.push(CandidateDataPoint {
                text: context,
                source_node_id: neighbor.node_id,
                source_note_title: neighbor.note_title.clone(),
            });
        }
    }
    candidates
}
/// Wire shape of one evidence item as emitted by the LLM, before validation
/// and enrichment with source metadata.
#[derive(Debug, Deserialize)]
struct RawEvidenceItem {
    /// Expected to be "contradiction" | "data_point" | "aha_moment";
    /// anything else is dropped in `convert_raw_evidence`.
    evidence_type: String,
    /// Quote or close paraphrase grounding the evidence.
    citation_text: String,
    /// Node id the model claims the citation came from (verified later
    /// against the accepted node set).
    source_node_id: i64,
    /// Model-reported confidence; the prompt asks for 0.0-1.0. Defaults to
    /// 0.5 when the field is missing from the JSON.
    #[serde(default = "default_confidence")]
    confidence: f64,
}
/// Neutral default used when the model omits `confidence`.
fn default_confidence() -> f64 {
    0.5
}
/// Asks the LLM provider to extract evidence items for `topic` from the
/// given neighbor snippets (plus pre-filtered candidate data points).
///
/// # Errors
/// Returns `Err` only when the provider call itself fails; an unparseable
/// response is logged and yields `Ok(vec![])` via `parse_evidence_response`.
pub async fn extract_evidence(
    provider: &dyn LlmProvider,
    topic: &str,
    neighbors: &[NeighborContent],
    candidates: &[CandidateDataPoint],
) -> Result<Vec<EvidenceItem>, LlmError> {
    let system = build_extraction_prompt(topic, neighbors, candidates);
    let user_message = "Extract evidence items as a JSON array.".to_string();
    // Low temperature: we want grounded extraction, not creative generation.
    let params = GenerationParams {
        max_tokens: 500,
        temperature: 0.3,
        ..Default::default()
    };
    // BUG FIX: this argument read `¶ms` (mojibake for `&params`, an
    // `&para;` entity corruption) and did not compile.
    let resp = provider.complete(&system, &user_message, &params).await?;
    tracing::debug!(
        raw_response = %resp.text,
        "Raw LLM response for evidence extraction"
    );
    parse_evidence_response(&resp.text, neighbors)
}
/// Assembles the system prompt: topic, numbered neighbor snippets, optional
/// pre-filtered data points, and the JSON output contract with its rules.
fn build_extraction_prompt(
    topic: &str,
    neighbors: &[NeighborContent],
    candidates: &[CandidateDataPoint],
) -> String {
    let mut prompt = format!(
        "You are an evidence mining engine. Given a topic and related note snippets, \
        extract evidence items that could support social media content angles.\n\n\
        Topic: {topic}\n\n\
        Related notes:\n"
    );
    // Numbered list so the model can cross-reference notes by node_id.
    for (idx, neighbor) in neighbors.iter().enumerate() {
        let entry = format!(
            "[{}] (node_id: {}, title: \"{}\") \"{}\"\n",
            idx + 1,
            neighbor.node_id,
            neighbor.note_title,
            neighbor.snippet
        );
        prompt.push_str(&entry);
    }
    // The candidate section is omitted entirely when the pre-filter found nothing.
    if !candidates.is_empty() {
        prompt.push_str("\nCandidate data points found by scanning:\n");
        for candidate in candidates {
            let entry = format!(
                "- \"{}\" from \"{}\"\n",
                candidate.text, candidate.source_note_title
            );
            prompt.push_str(&entry);
        }
    }
    // Output contract: schema first, then grounding rules.
    prompt.push_str(
        "\nExtract evidence as a JSON array. Each item:\n\
        {\n\
        \x20 \"evidence_type\": \"contradiction\" | \"data_point\" | \"aha_moment\",\n\
        \x20 \"citation_text\": \"exact quote or close paraphrase, max 120 chars\",\n\
        \x20 \"source_node_id\": <integer from the list above>,\n\
        \x20 \"confidence\": <0.0-1.0>\n\
        }\n\n\
        Rules:\n\
        - Only reference node_ids from the list above.\n\
        - citation_text must be grounded in the source snippet.\n\
        - For contradictions, identify opposing claims across different notes.\n\
        - For aha_moments, identify non-obvious connections.\n\
        - For data_points, confirm the candidate data points are relevant to the topic.\n\
        - Return [] if no meaningful evidence found.",
    );
    prompt
}
/// Parses the LLM's evidence-extraction response, tolerating common
/// formatting quirks.
///
/// Attempts, in order: the whole trimmed response as a JSON array, the
/// contents of a ```json fenced block, and finally the outermost `[...]`
/// span. If nothing parses, the raw response is logged at `warn` and an
/// empty list is returned — a malformed reply is not treated as an error.
pub fn parse_evidence_response(
    text: &str,
    neighbors: &[NeighborContent],
) -> Result<Vec<EvidenceItem>, LlmError> {
    let trimmed = text.trim();
    if let Ok(raw_items) = serde_json::from_str::<Vec<RawEvidenceItem>>(trimmed) {
        return Ok(convert_raw_evidence(raw_items, neighbors));
    }
    if let Some(json_str) = extract_json_from_code_block(trimmed) {
        if let Ok(raw_items) = serde_json::from_str::<Vec<RawEvidenceItem>>(json_str) {
            return Ok(convert_raw_evidence(raw_items, neighbors));
        }
    }
    if let (Some(start), Some(end)) = (trimmed.find('['), trimmed.rfind(']')) {
        // BUG FIX: with input like "] text [", rfind(']') lands *before*
        // find('['), and slicing start..=end panics on the inverted range.
        if start < end {
            let slice = &trimmed[start..=end];
            if let Ok(raw_items) = serde_json::from_str::<Vec<RawEvidenceItem>>(slice) {
                return Ok(convert_raw_evidence(raw_items, neighbors));
            }
        }
    }
    tracing::warn!(
        raw_response = %text,
        "Could not parse evidence extraction response as JSON"
    );
    Ok(vec![])
}
/// Returns the trimmed payload of the first ```json fenced block in `text`,
/// or `None` when no opening fence or no closing ``` is present.
fn extract_json_from_code_block(text: &str) -> Option<&str> {
    let open = "```json";
    let body = {
        let fence_at = text.find(open)?;
        &text[fence_at + open.len()..]
    };
    let close_at = body.find("```")?;
    Some(body[..close_at].trim())
}
/// Converts raw LLM items into `EvidenceItem`s, dropping items with an
/// unrecognized `evidence_type` and enriching each survivor with the title
/// and heading path of its source neighbor (empty/None when the claimed
/// node_id is unknown).
fn convert_raw_evidence(
    raw: Vec<RawEvidenceItem>,
    neighbors: &[NeighborContent],
) -> Vec<EvidenceItem> {
    let by_id: HashMap<i64, &NeighborContent> =
        neighbors.iter().map(|n| (n.node_id, n)).collect();
    let mut items = Vec::with_capacity(raw.len());
    for entry in raw {
        let evidence_type = match entry.evidence_type.as_str() {
            "contradiction" => EvidenceType::Contradiction,
            "data_point" => EvidenceType::DataPoint,
            "aha_moment" => EvidenceType::AhaMoment,
            _ => continue,
        };
        let source = by_id.get(&entry.source_node_id).copied();
        items.push(EvidenceItem {
            evidence_type,
            citation_text: entry.citation_text,
            source_node_id: entry.source_node_id,
            source_note_title: source.map(|n| n.note_title.clone()).unwrap_or_default(),
            source_heading_path: source.and_then(|n| n.heading_path.clone()),
            confidence: entry.confidence,
        });
    }
    items
}
/// Validates extracted evidence in a single pass: drops items whose source
/// node is not in `accepted_node_ids` or whose confidence is below the
/// floor, truncates over-long citations with a "..." suffix, and dedupes by
/// (evidence_type, source_node_id), keeping the highest-confidence item
/// while preserving first-seen order.
pub fn validate_evidence(
    evidence: Vec<EvidenceItem>,
    accepted_node_ids: &HashSet<i64>,
) -> Vec<EvidenceItem> {
    let mut slot_for: HashMap<(EvidenceType, i64), usize> = HashMap::new();
    let mut kept: Vec<EvidenceItem> = Vec::new();
    for mut item in evidence {
        // Reject hallucinated sources and low-confidence items outright.
        if !accepted_node_ids.contains(&item.source_node_id)
            || item.confidence < MIN_CONFIDENCE_FLOOR
        {
            continue;
        }
        // Clamp the citation to the byte budget, reserving room for "...".
        if item.citation_text.len() > MAX_CITATION_CHARS {
            let head = truncate_at_char_boundary(&item.citation_text, MAX_CITATION_CHARS - 3);
            item.citation_text = format!("{head}...");
        }
        let key = (item.evidence_type, item.source_node_id);
        match slot_for.get(&key) {
            Some(&slot) => {
                // Duplicate key: keep whichever item is more confident.
                if item.confidence > kept[slot].confidence {
                    kept[slot] = item;
                }
            }
            None => {
                slot_for.insert(key, kept.len());
                kept.push(item);
            }
        }
    }
    kept
}
/// Returns `s` unchanged when it fits in `max_len` bytes; otherwise returns
/// the longest prefix of at most `max_len` bytes that ends on a UTF-8 char
/// boundary (possibly the empty string).
fn truncate_at_char_boundary(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        s
    } else {
        // Walk backwards from max_len to the nearest boundary; index 0 is
        // always a boundary, so the search cannot fail.
        let cut = (0..=max_len)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0);
        &s[..cut]
    }
}