use crate::fact::MemoryTier;
use serde::{Deserialize, Serialize};
/// A single chat message: a role label plus its text content.
///
/// The constructors on this type produce the roles `"user"` and `"assistant"`.
/// `PartialEq`/`Eq` are derived so messages can be compared as whole values
/// (e.g. in tests or when checking conversation history).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
    /// Speaker role, e.g. `"user"` or `"assistant"`.
    pub role: String,
    /// Message text.
    pub content: String,
}
impl Message {
    /// Creates a message with role `"user"` and the given content.
    pub fn user(content: impl Into<String>) -> Self {
        let content = content.into();
        Self {
            role: String::from("user"),
            content,
        }
    }

    /// Creates a message with role `"assistant"` and the given content.
    pub fn assistant(content: impl Into<String>) -> Self {
        let content = content.into();
        Self {
            role: String::from("assistant"),
            content,
        }
    }
}
/// One fact produced by extraction, with optional entities and relationships.
///
/// Deserialization is lenient: `entities` and `relationships` default to empty
/// vectors, `confidence` defaults to `1.0` (via `default_confidence`), and a
/// missing `category` becomes `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ExtractedFact {
    /// The fact text itself.
    pub text: String,
    /// Entities associated with this fact; empty when absent from the input.
    #[serde(default)]
    pub entities: Vec<ExtractedEntity>,
    /// Relationships associated with this fact; empty when absent from the input.
    #[serde(default)]
    pub relationships: Vec<ExtractedRelationship>,
    /// Confidence score; `1.0` when absent from the input.
    /// NOTE(review): the valid range is not enforced here — presumably 0.0..=1.0, confirm.
    #[serde(default = "default_confidence")]
    pub confidence: f64,
    /// Optional category label; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<String>,
}
/// Serde default for `ExtractedFact::confidence` when the key is absent.
///
/// Always returns `1.0`.
fn default_confidence() -> f64 {
    const UNSTATED_CONFIDENCE: f64 = 1.0;
    UNSTATED_CONFIDENCE
}
/// A named entity attached to an extracted fact.
///
/// `PartialEq`, `Eq` and `Hash` are derived so entities can be compared and
/// collected into `HashSet`/`HashMap` keys (e.g. to de-duplicate entities across
/// facts). Two entities are equal only if both `name` and `entity_type` match.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ExtractedEntity {
    /// Entity name as extracted.
    pub name: String,
    /// Optional type label; omitted from serialized output when `None`,
    /// and `None` when absent from the input.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub entity_type: Option<String>,
}
/// A `source` / `relation` / `target` triple attached to an extracted fact.
///
/// `PartialEq`, `Eq` and `Hash` are derived so identical triples can be
/// compared and de-duplicated via hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ExtractedRelationship {
    /// Origin of the relation — presumably an entity name; confirm against producers.
    pub source: String,
    /// Relation label connecting `source` to `target`.
    pub relation: String,
    /// Destination of the relation — presumably an entity name; confirm against producers.
    pub target: String,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExtractionResult {
pub facts: Vec<ExtractedFact>,
}
/// Settings that control fact extraction.
///
/// See the `Default` impl for the baseline values (no prompt override, no
/// skipped categories, no rules, dedup threshold `0.92`, `MemoryTier::Conversation`).
#[derive(Clone, Debug)]
pub struct ExtractionConfig {
    /// Optional replacement prompt text.
    /// NOTE(review): how it interacts with any built-in prompt is decided by callers — confirm.
    pub custom_prompt: Option<String>,
    /// Category names to skip — presumably matched against `ExtractedFact::category`; confirm.
    pub skip_categories: Vec<String>,
    /// Per-category extraction rules.
    pub rules: Vec<ExtractionRule>,
    /// De-duplication threshold.
    /// NOTE(review): presumably a similarity score above which facts count as duplicates — confirm.
    pub dedup_threshold: f32,
    /// Memory tier assigned to extracted facts by default.
    pub default_tier: MemoryTier,
}
impl Default for ExtractionConfig {
    /// Baseline configuration: no prompt override, no skipped categories,
    /// no rules, a dedup threshold of `0.92`, and `MemoryTier::Conversation`.
    fn default() -> Self {
        let dedup_threshold: f32 = 0.92;
        Self {
            custom_prompt: None,
            skip_categories: vec![],
            rules: vec![],
            dedup_threshold,
            default_tier: MemoryTier::Conversation,
        }
    }
}
/// An extraction rule keyed by category.
///
/// Deserialization defaults:
/// - `priority` is `1.0` (via `default_priority`).
/// - `ttl` is `None`.
/// - `pii` is `false`.
///
/// `PartialEq` is derived so rules can be compared as values. `Eq` is not
/// derived because `priority` is an `f64`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ExtractionRule {
    /// Category this rule applies to.
    pub category: String,
    /// Rule priority; `1.0` when absent from the input.
    #[serde(default = "default_priority")]
    pub priority: f64,
    /// Optional time-to-live, kept as an unparsed string; omitted from serialized
    /// output when `None`. NOTE(review): expected format is not validated here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<String>,
    /// Whether facts of this category are flagged as PII; `false` when absent.
    #[serde(default)]
    pub pii: bool,
}
/// Serde default for `ExtractionRule::priority` when the key is absent.
///
/// Always returns `1.0`.
fn default_priority() -> f64 {
    const BASELINE_PRIORITY: f64 = 1.0;
    BASELINE_PRIORITY
}
/// Verdict describing how one fact relates to another.
///
/// NOTE(review): presumably produced when comparing a newly extracted fact
/// against a stored one — confirm against callers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ConflictVerdict {
    /// The facts say the same thing.
    Duplicate,
    /// The facts disagree.
    Contradicts,
    /// One fact adds detail to the other — semantics assumed from the name; confirm.
    Refines,
    /// The facts do not conflict.
    NoConflict,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default config has its documented baseline value.
    #[test]
    fn extraction_config_defaults() {
        let cfg = ExtractionConfig::default();
        assert!(cfg.custom_prompt.is_none());
        assert!(cfg.skip_categories.is_empty());
        assert!(cfg.rules.is_empty());
        assert!((cfg.dedup_threshold - 0.92).abs() < f32::EPSILON);
        assert_eq!(cfg.default_tier, MemoryTier::Conversation);
    }

    /// Constructors set both the role label and the content.
    #[test]
    fn message_constructors() {
        let m = Message::user("hello");
        assert_eq!(m.role, "user");
        assert_eq!(m.content, "hello");

        // `impl Into<String>` also accepts an owned `String`.
        let m = Message::assistant(String::from("hi"));
        assert_eq!(m.role, "assistant");
        assert_eq!(m.content, "hi");
    }

    /// Only `text` is required; every other field falls back to its serde default.
    #[test]
    fn extracted_fact_deserializes_with_defaults() {
        let json = r#"{"text": "User likes pizza"}"#;
        let fact: ExtractedFact = serde_json::from_str(json).unwrap();
        assert_eq!(fact.text, "User likes pizza");
        assert!((fact.confidence - 1.0).abs() < f64::EPSILON);
        assert!(fact.entities.is_empty());
        assert!(fact.relationships.is_empty());
        assert!(fact.category.is_none());
    }

    /// `None` optionals are omitted on output and come back as `None` on input.
    #[test]
    fn extracted_fact_serialization_omits_none_fields() {
        let fact = ExtractedFact {
            text: "User works at Acme".into(),
            entities: vec![ExtractedEntity {
                name: "Acme".into(),
                entity_type: None,
            }],
            relationships: vec![ExtractedRelationship {
                source: "User".into(),
                relation: "works_at".into(),
                target: "Acme".into(),
            }],
            confidence: 0.5,
            category: None,
        };

        let value = serde_json::to_value(&fact).unwrap();
        assert!(value.get("category").is_none());
        assert!(value["entities"][0].get("entity_type").is_none());
        assert_eq!(value["relationships"][0]["relation"], "works_at");

        let back: ExtractedFact = serde_json::from_value(value).unwrap();
        assert_eq!(back.text, fact.text);
        assert!(back.category.is_none());
        assert_eq!(back.entities.len(), 1);
        assert!(back.entities[0].entity_type.is_none());
        assert!((back.confidence - 0.5).abs() < f64::EPSILON);
    }

    /// Only `category` is required for a rule; the rest use serde defaults.
    #[test]
    fn extraction_rule_deserializes_with_defaults() {
        let rule: ExtractionRule =
            serde_json::from_str(r#"{"category": "preference"}"#).unwrap();
        assert_eq!(rule.category, "preference");
        assert!((rule.priority - 1.0).abs() < f64::EPSILON);
        assert!(rule.ttl.is_none());
        assert!(!rule.pii);
    }
}