use crate::fact::MemoryTier;
use serde::{Deserialize, Serialize};
/// One chat turn: a role label plus the text of that turn.
///
/// `role` is a free-form string; the constructors in this module use
/// `"user"` and `"assistant"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    /// Who authored this turn (e.g. `"user"` or `"assistant"`).
    pub role: String,
    /// The raw text of the turn.
    pub content: String,
}
impl Message {
    /// Builds a message whose role is `"user"`.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Builds a message whose role is `"assistant"`.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }

    /// Shared constructor: pairs a fixed role label with owned content.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_owned(),
            content: content.into(),
        }
    }
}
/// A single fact produced by extraction.
///
/// Only `text` is required when deserializing. The entity and relationship
/// lists default to empty, `confidence` defaults to `1.0`, and `category`
/// defaults to `None`. `category` is also omitted from serialized output when
/// it is `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ExtractedFact {
    /// The fact as a natural-language statement.
    pub text: String,
    /// Entities mentioned by the fact; each may be a bare name or a full object.
    #[serde(default)]
    pub entities: Vec<ExtractedEntity>,
    /// Relationships (`source`, `relation`, `target`) mentioned by the fact.
    #[serde(default)]
    pub relationships: Vec<ExtractedRelationship>,
    /// Confidence score for the fact; `1.0` when the key is absent.
    #[serde(default = "default_confidence")]
    pub confidence: f64,
    /// Optional category label; skipped on serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<String>,
}
/// Serde default for `ExtractedFact::confidence`: full confidence when the
/// field is missing from the input.
fn default_confidence() -> f64 {
    1.0_f64
}
/// An entity referenced by an extracted fact.
///
/// Deserialization goes through `ExtractedEntityInput`, so both a bare JSON
/// string (`"alice"`) and an object (`{"name": ..., "entity_type": ...}`) are
/// accepted. Serialization always emits the object form and drops
/// `entity_type` when it is `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(from = "ExtractedEntityInput")]
pub struct ExtractedEntity {
    /// The entity's name.
    pub name: String,
    /// Optional type label, such as `"place"` or `"thing"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub entity_type: Option<String>,
}
/// Input shapes accepted when deserializing an `ExtractedEntity`.
///
/// With `untagged`, serde tries each variant in declaration order. A bare
/// string therefore matches `Name`, and an object falls through to `Full`.
/// Keep `Name` first.
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
enum ExtractedEntityInput {
    /// Just the entity name, e.g. `"alice"`.
    Name(String),
    /// Object form: a required name plus an optional type. A missing
    /// `entity_type` key yields `None`.
    Full {
        name: String,
        #[serde(default)]
        entity_type: Option<String>,
    },
}
/// Normalizes either accepted input shape into an `ExtractedEntity`.
impl From<ExtractedEntityInput> for ExtractedEntity {
    fn from(input: ExtractedEntityInput) -> Self {
        // A bare name carries no type information, so its type is `None`.
        let (name, entity_type) = match input {
            ExtractedEntityInput::Name(bare) => (bare, None),
            ExtractedEntityInput::Full { name, entity_type } => (name, entity_type),
        };
        Self { name, entity_type }
    }
}
/// A relationship expressed as a `source` / `relation` / `target` triple.
///
/// All three fields are required when deserializing.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ExtractedRelationship {
    /// One end of the relationship (presumably an entity name — confirm).
    pub source: String,
    /// Label describing how `source` relates to `target`.
    pub relation: String,
    /// The other end of the relationship (presumably an entity name — confirm).
    pub target: String,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExtractionResult {
pub facts: Vec<ExtractedFact>,
}
/// Tunables for the extraction pipeline.
///
/// See the `Default` impl for the baseline values.
#[derive(Clone, Debug)]
pub struct ExtractionConfig {
    /// Replacement extraction prompt. NOTE(review): `None` presumably means
    /// "use the built-in prompt" — confirm at the call site.
    pub custom_prompt: Option<String>,
    /// Category names to skip. NOTE(review): assumed to be matched against
    /// `ExtractedFact::category` — confirm at the call site.
    pub skip_categories: Vec<String>,
    /// Per-category rules (priority, TTL, PII flag).
    pub rules: Vec<ExtractionRule>,
    /// Similarity threshold used for deduplication. Assumed to be in
    /// `0.0..=1.0` — TODO confirm against the dedup implementation.
    pub dedup_threshold: f32,
    /// Memory tier presumably assigned to facts when no rule overrides it.
    pub default_tier: MemoryTier,
}
impl Default for ExtractionConfig {
    /// No custom prompt, no skipped categories, no rules, a `0.92` dedup
    /// threshold, and [`MemoryTier::Conversation`] as the default tier.
    fn default() -> Self {
        const DEDUP_THRESHOLD: f32 = 0.92;
        Self {
            default_tier: MemoryTier::Conversation,
            dedup_threshold: DEDUP_THRESHOLD,
            rules: Vec::default(),
            skip_categories: Vec::default(),
            custom_prompt: None,
        }
    }
}
/// Handling rule for one fact category.
///
/// When deserialized, `priority` defaults to `1.0`, `ttl` to `None`, and
/// `pii` to `false`. `ttl` is omitted from serialized output when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ExtractionRule {
    /// Category this rule applies to.
    pub category: String,
    /// Priority for facts in this category; `1.0` when absent.
    #[serde(default = "default_priority")]
    pub priority: f64,
    /// Optional time-to-live, kept as an unparsed string. The format is not
    /// validated here — TODO confirm the expected syntax.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<String>,
    /// Whether facts in this category are flagged as PII (presumably
    /// personally identifiable information); `false` when absent.
    #[serde(default)]
    pub pii: bool,
}
/// Serde default for `ExtractionRule::priority` when the field is missing.
fn default_priority() -> f64 {
    1.0_f64
}
/// Verdict from comparing two facts (presumably a new fact against a stored
/// one — confirm against the conflict-resolution code).
///
/// This is a fieldless enum, so it derives `Copy` and `Hash`. Verdicts can be
/// passed by value and used as set/map keys without cloning. The meaning of
/// each variant below is inferred from its name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConflictVerdict {
    /// The facts say the same thing.
    Duplicate,
    /// The facts contradict each other.
    Contradicts,
    /// One fact refines (adds detail to) the other.
    Refines,
    /// The facts do not conflict.
    NoConflict,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extraction_config_defaults() {
        let cfg = ExtractionConfig::default();
        assert!((cfg.dedup_threshold - 0.92).abs() < f32::EPSILON);
        assert_eq!(cfg.default_tier, MemoryTier::Conversation);
        assert!(cfg.skip_categories.is_empty());
        assert!(cfg.rules.is_empty());
        assert!(cfg.custom_prompt.is_none());
    }

    #[test]
    fn message_constructors() {
        let m = Message::user("hello");
        assert_eq!(m.role, "user");
        assert_eq!(m.content, "hello");
        let m = Message::assistant("hi");
        assert_eq!(m.role, "assistant");
        assert_eq!(m.content, "hi");
    }

    #[test]
    fn extracted_fact_deserializes_with_defaults() {
        let json = r#"{"text": "User likes pizza"}"#;
        let fact: ExtractedFact = serde_json::from_str(json).unwrap();
        assert_eq!(fact.text, "User likes pizza");
        assert_eq!(fact.confidence, 1.0);
        assert!(fact.entities.is_empty());
        assert!(fact.relationships.is_empty());
        assert!(fact.category.is_none());
    }

    #[test]
    fn extracted_fact_omits_missing_category_when_serialized() {
        let fact: ExtractedFact = serde_json::from_str(r#"{"text": "x"}"#).unwrap();
        let value = serde_json::to_value(&fact).unwrap();
        assert!(value.get("category").is_none());
        assert_eq!(value["text"], "x");
    }

    #[test]
    fn extracted_entity_accepts_bare_string() {
        let json = r#""alice""#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.name, "alice");
        assert!(entity.entity_type.is_none());
    }

    #[test]
    fn extracted_entity_accepts_full_struct() {
        let json = r#"{"name": "Bangalore", "entity_type": "place"}"#;
        let entity: ExtractedEntity = serde_json::from_str(json).unwrap();
        assert_eq!(entity.name, "Bangalore");
        assert_eq!(entity.entity_type.as_deref(), Some("place"));
    }

    #[test]
    fn extracted_entity_round_trips_without_type() {
        let entity = ExtractedEntity {
            name: "alice".into(),
            entity_type: None,
        };
        // `entity_type` is skipped when `None`, so only the name is emitted.
        let json = serde_json::to_string(&entity).unwrap();
        assert_eq!(json, r#"{"name":"alice"}"#);
        let back: ExtractedEntity = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, "alice");
        assert!(back.entity_type.is_none());
    }

    #[test]
    fn extracted_fact_accepts_mixed_entity_shapes() {
        let json = r#"{
            "text": "The user is allergic to peanuts",
            "entities": ["user", {"name": "peanuts", "entity_type": "thing"}],
            "confidence": 0.95,
            "category": "health"
        }"#;
        let fact: ExtractedFact = serde_json::from_str(json).unwrap();
        assert_eq!(fact.entities.len(), 2);
        assert_eq!(fact.entities[0].name, "user");
        assert!(fact.entities[0].entity_type.is_none());
        assert_eq!(fact.entities[1].name, "peanuts");
        assert_eq!(fact.entities[1].entity_type.as_deref(), Some("thing"));
    }

    #[test]
    fn extracted_fact_parses_relationships() {
        let json = r#"{
            "text": "User likes pizza",
            "relationships": [{"source": "user", "relation": "likes", "target": "pizza"}]
        }"#;
        let fact: ExtractedFact = serde_json::from_str(json).unwrap();
        assert_eq!(fact.relationships.len(), 1);
        assert_eq!(fact.relationships[0].source, "user");
        assert_eq!(fact.relationships[0].relation, "likes");
        assert_eq!(fact.relationships[0].target, "pizza");
    }

    #[test]
    fn extraction_rule_deserializes_with_defaults() {
        let rule: ExtractionRule = serde_json::from_str(r#"{"category": "health"}"#).unwrap();
        assert_eq!(rule.category, "health");
        assert_eq!(rule.priority, 1.0);
        assert!(rule.ttl.is_none());
        assert!(!rule.pii);
    }
}