use crate::{CommunitySummary, GraphRAGResult, Triple};
use serde::{Deserialize, Serialize};
/// Settings that control how [`ContextBuilder`] assembles a context string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextConfig {
/// Length budget for the assembled context. Sections are measured with
/// `String::len`, i.e. in bytes rather than characters.
pub max_length: usize,
/// Whether the community-summaries section is rendered.
pub include_communities: bool,
/// Whether the triples ("facts") section is rendered.
pub include_triples: bool,
/// Textual representation used for each triple in the facts section.
pub triple_format: TripleFormat,
/// NOTE(review): no code visible in this file reads this flag — confirm
/// where (or whether) score weighting is actually applied.
pub score_weighted: bool,
}
impl Default for ContextConfig {
fn default() -> Self {
Self {
max_length: 8000,
include_communities: true,
include_triples: true,
triple_format: TripleFormat::NaturalLanguage,
score_weighted: true,
}
}
}
/// Textual representation used when rendering a triple into the context.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum TripleFormat {
/// Bullet sentence built from local names and a predicate phrase,
/// e.g. `- Battery1 is part of Pack1`.
NaturalLanguage,
/// Bullet with the local names joined by arrows: `- subject → predicate → object`.
Structured,
/// Turtle-like statement with every term wrapped in angle brackets:
/// `<subject> <predicate> <object> .`
Turtle,
/// Single-line JSON object of the form `{ "@id": subject, predicate: object }`
/// built from local names.
JsonLd,
}
/// Renders a query, community summaries and triples into a single context
/// string according to a [`ContextConfig`].
///
/// `Debug` and `Clone` are derived so the builder can be logged and shared
/// cheaply; both are available because `ContextConfig` implements them.
#[derive(Debug, Clone)]
pub struct ContextBuilder {
    /// Settings consulted on every call that renders a section.
    config: ContextConfig,
}
impl Default for ContextBuilder {
fn default() -> Self {
Self::new(ContextConfig::default())
}
}
impl ContextBuilder {
    /// Creates a builder that renders contexts according to `config`.
    pub fn new(config: ContextConfig) -> Self {
        Self { config }
    }

    /// Assembles the context string for `query`.
    ///
    /// The output is made of up to three sections, in this order:
    ///
    /// 1. `## Query` — the query text (skipped if it does not fit the budget).
    /// 2. `## Knowledge Graph Communities` — only when `include_communities`
    ///    is set and `communities` is non-empty; it may use at most one third
    ///    of the budget left after the query.
    /// 3. `## Knowledge Graph Facts` — only when `include_triples` is set and
    ///    `triples` is non-empty; it may use whatever budget is left.
    ///
    /// The returned string never exceeds `max_length` bytes, and a section
    /// whose header fits but whose first entry does not is omitted entirely
    /// rather than emitted as a bare header.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn build(
        &self,
        query: &str,
        triples: &[Triple],
        communities: &[CommunitySummary],
    ) -> GraphRAGResult<String> {
        let mut context = String::new();
        let mut remaining_length = self.config.max_length;

        let query_section = format!("## Query\n{}\n\n", query);
        if query_section.len() <= remaining_length {
            context.push_str(&query_section);
            remaining_length -= query_section.len();
        }

        if self.config.include_communities && !communities.is_empty() {
            // Cap communities at a third of the remaining budget so the facts
            // section is never starved. `format_communities` honours the cap,
            // so the section can be appended without a second length check.
            let community_section = self.format_communities(communities, remaining_length / 3);
            context.push_str(&community_section);
            remaining_length = remaining_length.saturating_sub(community_section.len());
        }

        if self.config.include_triples && !triples.is_empty() {
            // `format_triples` honours `remaining_length`, so the overall
            // budget still holds after this append.
            let triples_section = self.format_triples(triples, remaining_length);
            context.push_str(&triples_section);
        }

        Ok(context)
    }

    /// Appends `entries` to `header`, in order, for as long as the section
    /// stays within `max_length` bytes, stopping at the first entry that does
    /// not fit.
    ///
    /// Returns an empty string when not even one entry fits, so callers never
    /// emit a header without content.
    fn fill_section<I>(header: &str, entries: I, max_length: usize) -> String
    where
        I: IntoIterator<Item = String>,
    {
        let mut section = String::from(header);
        for entry in entries {
            if section.len() + entry.len() > max_length {
                break;
            }
            section.push_str(&entry);
        }

        if section.len() > header.len() {
            section
        } else {
            String::new()
        }
    }

    /// Renders the community section within `max_length` bytes.
    ///
    /// Each community contributes its id, its summary and at most five of its
    /// entities. Returns an empty string if no community fits.
    fn format_communities(&self, communities: &[CommunitySummary], max_length: usize) -> String {
        // Entries are produced lazily, so communities past the budget are
        // never formatted.
        let entries = communities.iter().map(|community| {
            let entities = community
                .entities
                .iter()
                .take(5)
                .map(String::as_str)
                .collect::<Vec<_>>()
                .join(", ");
            format!(
                "### {}\n{}\n**Entities:** {}\n\n",
                community.id, community.summary, entities
            )
        });

        Self::fill_section("## Knowledge Graph Communities\n\n", entries, max_length)
    }

    /// Renders the facts section within `max_length` bytes, one triple per
    /// line in the configured [`TripleFormat`].
    ///
    /// Returns an empty string if no triple fits.
    fn format_triples(&self, triples: &[Triple], max_length: usize) -> String {
        let entries = triples.iter().map(|triple| {
            let mut line = match self.config.triple_format {
                TripleFormat::NaturalLanguage => self.triple_to_natural_language(triple),
                TripleFormat::Structured => self.triple_to_structured(triple),
                TripleFormat::Turtle => self.triple_to_turtle(triple),
                TripleFormat::JsonLd => self.triple_to_jsonld(triple),
            };
            // The newline is part of the entry so that it counts against the
            // budget as well.
            line.push('\n');
            line
        });

        Self::fill_section("## Knowledge Graph Facts\n\n", entries, max_length)
    }

    /// Renders a triple as a bullet sentence: `- subject phrase object`.
    fn triple_to_natural_language(&self, triple: &Triple) -> String {
        let subject = self.extract_local_name(&triple.subject);
        let predicate = self.predicate_to_phrase(&triple.predicate);
        let object = self.extract_local_name(&triple.object);
        format!("- {} {} {}", subject, predicate, object)
    }

    /// Renders a triple as `- subject → predicate → object` using local names.
    fn triple_to_structured(&self, triple: &Triple) -> String {
        let subject = self.extract_local_name(&triple.subject);
        let predicate = self.extract_local_name(&triple.predicate);
        let object = self.extract_local_name(&triple.object);
        format!("- {} → {} → {}", subject, predicate, object)
    }

    /// Renders a triple as a Turtle-like statement using the full terms.
    ///
    /// NOTE(review): every term, including the object, is wrapped in angle
    /// brackets, so an object that is a literal (e.g. `85`) is rendered like
    /// an IRI — confirm whether literals need separate handling.
    fn triple_to_turtle(&self, triple: &Triple) -> String {
        format!(
            "<{}> <{}> <{}> .",
            triple.subject, triple.predicate, triple.object
        )
    }

    /// Renders a triple as a one-line JSON object
    /// `{ "@id": subject, predicate: object }` built from local names.
    ///
    /// All three values are JSON-escaped so that quotes, backslashes or
    /// control characters in a term cannot produce malformed JSON.
    fn triple_to_jsonld(&self, triple: &Triple) -> String {
        let subject = self.extract_local_name(&triple.subject);
        let predicate = self.extract_local_name(&triple.predicate);
        let object = self.extract_local_name(&triple.object);
        format!(
            "{{ \"@id\": \"{}\", \"{}\": \"{}\" }}",
            Self::escape_json(&subject),
            Self::escape_json(&predicate),
            Self::escape_json(&object)
        )
    }

    /// Escapes `value` for use inside a double-quoted JSON string.
    fn escape_json(value: &str) -> String {
        let mut escaped = String::with_capacity(value.len());
        for c in value.chars() {
            match c {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                // Remaining control characters must use the \uXXXX form.
                c if (c as u32) < 0x20 => escaped.push_str(&format!("\\u{:04x}", c as u32)),
                c => escaped.push(c),
            }
        }
        escaped
    }

    /// Extracts the local name of an IRI: the text after the last `#`, or,
    /// when there is no `#`, the text after the last `/`.
    ///
    /// Trailing `#` and `/` separators are ignored so that namespace-style
    /// IRIs such as `http://example.org/ns#` yield `ns` instead of an empty
    /// string. Input without any separator is returned unchanged, as is
    /// input that consists only of separators.
    fn extract_local_name(&self, uri: &str) -> String {
        let trimmed = uri.trim_end_matches(|c: char| matches!(c, '#' | '/'));
        let local = match trimmed.rsplit_once('#') {
            Some((_, fragment)) => fragment,
            None => trimmed.rsplit('/').next().unwrap_or(trimmed),
        };

        if local.is_empty() {
            uri.to_string()
        } else {
            local.to_string()
        }
    }

    /// Turns a predicate IRI into a short English phrase.
    ///
    /// Well-known predicates map to fixed phrases (e.g. `type` → `is a`);
    /// any other predicate has its camelCase local name split into
    /// lower-case words (e.g. `hasTemperature` → `has temperature`).
    fn predicate_to_phrase(&self, predicate: &str) -> String {
        let local = self.extract_local_name(predicate);
        let phrase = match local.as_str() {
            "type" | "rdf:type" => "is a",
            "label" | "rdfs:label" => "is labeled",
            "subClassOf" => "is a subclass of",
            "partOf" => "is part of",
            "hasPart" => "has part",
            "relatedTo" => "is related to",
            "sameAs" => "is the same as",
            "knows" => "knows",
            "worksFor" => "works for",
            "locatedIn" => "is located in",
            _ => return Self::split_camel_case(&local),
        };
        phrase.to_string()
    }

    /// Splits a camelCase identifier into lower-case, space-separated words.
    ///
    /// Runs of capitals are kept together as one word, so `hasURL` becomes
    /// `has url` and `isHTTPServer` becomes `is http server`.
    fn split_camel_case(identifier: &str) -> String {
        let chars: Vec<char> = identifier.chars().collect();
        let mut phrase = String::with_capacity(identifier.len() + 4);

        for (i, &c) in chars.iter().enumerate() {
            if i > 0 && c.is_uppercase() {
                let prev_is_upper = chars[i - 1].is_uppercase();
                let next_is_lower = chars.get(i + 1).map_or(false, |n| n.is_lowercase());
                // A word starts at a capital that follows a non-capital, or
                // at the last capital of an acronym directly before a
                // lower-case letter.
                if !prev_is_upper || next_is_lower {
                    phrase.push(' ');
                }
            }
            // `to_lowercase` may yield more than one char; keep all of them.
            phrase.extend(c.to_lowercase());
        }

        phrase
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context from two battery triples and one community and checks
    /// that both the query header and an entity name appear in the output.
    #[test]
    fn test_context_building() {
        let communities = vec![CommunitySummary {
            id: "community_0".to_string(),
            summary: "Battery monitoring entities".to_string(),
            entities: vec!["Battery1".to_string(), "Sensor1".to_string()],
            representative_triples: vec![],
            level: 0,
            modularity: 0.5,
        }];

        let status = Triple::new(
            "http://example.org/Battery1",
            "http://example.org/hasStatus",
            "http://example.org/Critical",
        );
        let temperature = Triple::new(
            "http://example.org/Battery1",
            "http://example.org/temperature",
            "85",
        );
        let triples = vec![status, temperature];

        let builder = ContextBuilder::default();
        let context = builder
            .build("What is the battery status?", &triples, &communities)
            .expect("should succeed");

        assert!(context.contains("Query"));
        assert!(context.contains("Battery1"));
    }

    /// Checks a `#`-style IRI, a fixed-phrase predicate and the camelCase
    /// fallback of `predicate_to_phrase`.
    #[test]
    fn test_predicate_to_phrase() {
        let builder = ContextBuilder::default();
        let cases = [
            ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", "is a"),
            ("http://example.org/partOf", "is part of"),
            ("http://example.org/hasTemperature", "has temperature"),
        ];

        for (predicate, expected) in cases.iter() {
            assert_eq!(builder.predicate_to_phrase(predicate), *expected);
        }
    }
}