use std::sync::Arc;
use async_trait::async_trait;
use entelix_core::ir::{CacheControl, ContentPart, Message, Role};
use entelix_core::{ExecutionContext, Result};
use entelix_runnable::Runnable;
use crate::chunker::Chunker;
use crate::document::Document;
/// Default instruction placed at the start of every contextualisation prompt.
///
/// It asks the model for a one-to-two sentence summary situating the chunk within its
/// document, without repeating the chunk itself. Override it with
/// `ContextualChunkerBuilder::with_instruction`.
pub const CONTEXTUAL_CHUNKER_DEFAULT_INSTRUCTION: &str = concat!(
    "You are an expert assistant that produces a short standalone context for a chunk extracted ",
    "from a document. The context will be prepended to the chunk to improve retrieval accuracy. ",
    "Reply with one or two sentences (50-100 tokens) describing how this chunk relates to the ",
    "overall document. Do not echo the chunk content — produce only the contextual summary.",
);

/// Identifier returned by `Chunker::name` and appended to each processed chunk's lineage.
const CHUNKER_NAME: &str = "contextual";
/// How the contextual chunker reacts when the model call for a chunk fails.
///
/// The default is [`FailurePolicy::KeepOriginal`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FailurePolicy {
    /// Emit the chunk with its original, un-prefixed content. The chunker is still
    /// recorded in the chunk's lineage.
    #[default]
    KeepOriginal,
    /// Drop the chunk from the output entirely.
    Skip,
    /// Stop processing and return the model's error to the caller.
    Abort,
}
/// Step-by-step configuration for a `ContextualChunker`.
///
/// Obtain one from `ContextualChunker::builder`, chain the `with_*` setters, then call
/// [`ContextualChunkerBuilder::build`].
pub struct ContextualChunkerBuilder<M> {
    /// Model invoked once per chunk to produce the contextual prefix.
    model: Arc<M>,
    /// Instruction text placed first in every prompt.
    instruction: String,
    /// Optional cache marker stamped on the parent-document prompt part.
    cache_control: Option<CacheControl>,
    /// Reaction to a failed model call.
    failure_policy: FailurePolicy,
}

impl<M> ContextualChunkerBuilder<M>
where
    M: Runnable<Vec<Message>, Message> + 'static,
{
    /// Replaces the instruction that precedes the document and chunk in each prompt.
    #[must_use]
    pub fn with_instruction(self, instruction: impl Into<String>) -> Self {
        Self {
            instruction: instruction.into(),
            ..self
        }
    }

    /// Attaches `cache` to the parent-document part of every prompt.
    #[must_use]
    pub const fn with_cache_control(self, cache: CacheControl) -> Self {
        let mut next = self;
        next.cache_control = Some(cache);
        next
    }

    /// Chooses how a failed model call affects the chunk being processed.
    #[must_use]
    pub const fn with_failure_policy(self, policy: FailurePolicy) -> Self {
        let mut next = self;
        next.failure_policy = policy;
        next
    }

    /// Finalises the configuration.
    ///
    /// The instruction is converted into a shared `Arc<str>` so the resulting chunker
    /// can be cloned without copying the text.
    #[must_use]
    pub fn build(self) -> ContextualChunker<M> {
        let Self {
            model,
            instruction,
            cache_control,
            failure_policy,
        } = self;
        ContextualChunker {
            model,
            instruction: instruction.into(),
            cache_control,
            failure_policy,
        }
    }
}
/// Chunker that prepends a model-generated context summary to each chunk.
///
/// Construct it with [`ContextualChunker::builder`]. The model and the instruction
/// are held behind `Arc`s.
pub struct ContextualChunker<M> {
    /// Model invoked once per chunk to produce the contextual prefix.
    model: Arc<M>,
    /// Instruction text placed first in every prompt.
    instruction: Arc<str>,
    /// Optional cache marker stamped on the parent-document prompt part.
    cache_control: Option<CacheControl>,
    /// Reaction to a failed model call.
    failure_policy: FailurePolicy,
}

impl<M> ContextualChunker<M>
where
    M: Runnable<Vec<Message>, Message> + 'static,
{
    /// Starts configuring a chunker around `model`.
    ///
    /// Defaults: `CONTEXTUAL_CHUNKER_DEFAULT_INSTRUCTION` as the instruction, no cache
    /// control, and `FailurePolicy::default()` (`KeepOriginal`).
    #[must_use]
    pub fn builder(model: Arc<M>) -> ContextualChunkerBuilder<M> {
        ContextualChunkerBuilder {
            model,
            instruction: CONTEXTUAL_CHUNKER_DEFAULT_INSTRUCTION.to_owned(),
            cache_control: None,
            failure_policy: FailurePolicy::default(),
        }
    }

    /// Builds the single-message prompt sent to the model for one chunk.
    ///
    /// The user message has three text parts, in order:
    /// 1. the instruction,
    /// 2. `parent_content` wrapped in `<document>` tags (carrying `cache_control` when
    ///    configured),
    /// 3. `chunk_content` wrapped in `<chunk>` tags.
    fn build_prompt(&self, parent_content: &str, chunk_content: &str) -> Vec<Message> {
        let parent_text = format!("<document>\n{parent_content}\n</document>");
        // A `match` moves `parent_text` into whichever arm runs. The previous
        // `map_or_else` needed two clones because both closures captured it.
        let parent_part = match self.cache_control {
            Some(cache) => ContentPart::Text {
                text: parent_text,
                cache_control: Some(cache),
                provider_echoes: Vec::new(),
            },
            None => ContentPart::text(parent_text),
        };
        let user = Message::new(
            Role::User,
            vec![
                ContentPart::text(self.instruction.to_string()),
                parent_part,
                ContentPart::text(format!("<chunk>\n{chunk_content}\n</chunk>")),
            ],
        );
        vec![user]
    }

    /// Returns the configured failure policy.
    #[must_use]
    pub const fn failure_policy(&self) -> FailurePolicy {
        self.failure_policy
    }
}
// Hand-written rather than derived: `#[derive(Clone)]` would require `M: Clone`,
// whereas only the `Arc` handles need cloning here.
impl<M> Clone for ContextualChunker<M> {
    fn clone(&self) -> Self {
        let Self {
            model,
            instruction,
            cache_control,
            failure_policy,
        } = self;
        Self {
            model: Arc::clone(model),
            instruction: Arc::clone(instruction),
            cache_control: *cache_control,
            failure_policy: *failure_policy,
        }
    }
}
// Shows the failure policy and whether cache control is configured. The model and
// the instruction are left out, which is why the output is marked non-exhaustive.
impl<M> std::fmt::Debug for ContextualChunker<M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_cache_control = self.cache_control.is_some();
        let mut dbg = f.debug_struct("ContextualChunker");
        dbg.field("failure_policy", &self.failure_policy);
        dbg.field("cache_control", &has_cache_control);
        dbg.finish_non_exhaustive()
    }
}
#[async_trait]
impl<M> Chunker for ContextualChunker<M>
where
M: Runnable<Vec<Message>, Message> + 'static,
{
fn name(&self) -> &'static str {
CHUNKER_NAME
}
async fn process(
&self,
chunks: Vec<Document>,
ctx: &ExecutionContext,
) -> Result<Vec<Document>> {
let mut out = Vec::with_capacity(chunks.len());
for mut chunk in chunks {
if ctx.is_cancelled() {
return Err(entelix_core::Error::Cancelled);
}
let prompt = self.build_prompt(&chunk.content, &chunk.content);
let outcome = self.model.invoke(prompt, ctx).await;
match outcome {
Ok(reply) => {
let prefix = extract_text(&reply);
if !prefix.is_empty() {
chunk.content = format!("{prefix}\n\n{}", chunk.content);
}
if let Some(lineage) = chunk.lineage.as_mut() {
lineage.push_chunker(CHUNKER_NAME);
}
out.push(chunk);
}
Err(err) => match self.failure_policy {
FailurePolicy::KeepOriginal => {
if let Some(lineage) = chunk.lineage.as_mut() {
lineage.push_chunker(CHUNKER_NAME);
}
out.push(chunk);
}
FailurePolicy::Skip => {
}
FailurePolicy::Abort => return Err(err),
},
}
}
Ok(out)
}
}
/// Collects every text part of `message`, joins them with blank lines and trims the
/// result.
///
/// Non-text parts are ignored, so a reply with no text yields an empty string.
fn extract_text(message: &Message) -> String {
    let texts: Vec<&str> = message
        .content
        .iter()
        .filter_map(|part| match part {
            ContentPart::Text { text, .. } => Some(text.as_str()),
            _ => None,
        })
        .collect();
    // `join` may emit separators before leading empty parts. The trailing `trim`
    // removes them, matching a separator-only-after-content accumulation.
    texts.join("\n\n").trim().to_owned()
}
// Unit tests for `ContextualChunker`, driven by a scripted in-memory model.
#[cfg(test)]
mod tests {
use super::*;
use crate::document::{Lineage, Source};
use entelix_memory::Namespace;
use std::sync::Mutex;
/// Namespace shared by every test document (tenant "acme").
fn ns() -> Namespace {
Namespace::new(entelix_core::TenantId::new("acme"))
}
/// Builds a child document with `content` and split lineage index `idx` (of 3) under a
/// throwaway root document.
fn chunk_with_content(content: &str, idx: u32) -> Document {
let parent = Document::root("doc", "<parent>", Source::now("test://", "test"), ns());
let lineage = Lineage::from_split(parent.id.clone(), idx, 3, "test-splitter");
parent.child(content, lineage)
}
/// Test double whose replies are returned in the order given to `new`.
///
/// Each reply becomes an assistant message with a single text part. It panics when
/// invoked more times than there are scripted replies.
struct ScriptedModel {
// Stored reversed so `pop` yields the next scripted reply.
script: Mutex<Vec<Result<String>>>,
}
impl ScriptedModel {
fn new(script: Vec<Result<String>>) -> Self {
Self {
script: Mutex::new(script.into_iter().rev().collect()),
}
}
}
#[async_trait]
impl Runnable<Vec<Message>, Message> for ScriptedModel {
async fn invoke(&self, _input: Vec<Message>, _ctx: &ExecutionContext) -> Result<Message> {
let next = self
.script
.lock()
.unwrap()
.pop()
.expect("ScriptedModel exhausted");
next.map(|text| Message::new(Role::Assistant, vec![ContentPart::text(text)]))
}
}
// No chunks in means no chunks out, and the model is never called (its script is empty).
#[tokio::test]
async fn empty_input_produces_empty_output() {
let model = Arc::new(ScriptedModel::new(vec![]));
let chunker = ContextualChunker::builder(model).build();
let out = chunker
.process(Vec::new(), &ExecutionContext::new())
.await
.unwrap();
assert!(out.is_empty());
}
// Successful replies are prepended to their chunk, and every chunk records this chunker once.
#[tokio::test]
async fn happy_path_prepends_contextual_prefix_and_records_lineage() {
let model = Arc::new(ScriptedModel::new(vec![
Ok("This chunk explains the alpha case.".into()),
Ok("This chunk covers the beta path.".into()),
]));
let chunker = ContextualChunker::builder(model).build();
let chunks = vec![
chunk_with_content("alpha body", 0),
chunk_with_content("beta body", 1),
];
let out = chunker
.process(chunks, &ExecutionContext::new())
.await
.unwrap();
assert_eq!(out.len(), 2);
assert!(
out[0]
.content
.starts_with("This chunk explains the alpha case."),
"prefix prepended: {:?}",
out[0].content
);
assert!(out[0].content.ends_with("alpha body"));
for chunk in &out {
let chain = &chunk.lineage.as_ref().unwrap().chunker_chain;
assert_eq!(chain.len(), 1);
assert_eq!(chain[0], CHUNKER_NAME);
}
}
// Default policy: the failing middle chunk is emitted unchanged but still gets lineage.
#[tokio::test]
async fn failure_policy_keep_original_passes_through_unmodified_content() {
let model = Arc::new(ScriptedModel::new(vec![
Ok("alpha context.".into()),
Err(entelix_core::Error::provider_http(503, "transient")),
Ok("gamma context.".into()),
]));
let chunker = ContextualChunker::builder(model).build();
let chunks = vec![
chunk_with_content("alpha body", 0),
chunk_with_content("beta body", 1),
chunk_with_content("gamma body", 2),
];
let out = chunker
.process(chunks, &ExecutionContext::new())
.await
.unwrap();
assert_eq!(out.len(), 3);
assert!(out[0].content.starts_with("alpha context."));
assert_eq!(
out[1].content, "beta body",
"failed chunk passes through with original content"
);
assert_eq!(
out[1].lineage.as_ref().unwrap().chunker_chain,
vec![CHUNKER_NAME.to_owned()]
);
assert!(out[2].content.starts_with("gamma context."));
}
// Skip policy: the failing middle chunk is dropped and processing continues.
#[tokio::test]
async fn failure_policy_skip_drops_failed_chunks() {
let model = Arc::new(ScriptedModel::new(vec![
Ok("alpha context.".into()),
Err(entelix_core::Error::provider_http(503, "transient")),
Ok("gamma context.".into()),
]));
let chunker = ContextualChunker::builder(model)
.with_failure_policy(FailurePolicy::Skip)
.build();
let chunks = vec![
chunk_with_content("alpha body", 0),
chunk_with_content("beta body", 1),
chunk_with_content("gamma body", 2),
];
let out = chunker
.process(chunks, &ExecutionContext::new())
.await
.unwrap();
assert_eq!(out.len(), 2, "failed chunk dropped");
assert!(out[0].content.starts_with("alpha context."));
assert!(out[1].content.starts_with("gamma context."));
}
// Abort policy: the second chunk's provider error is returned as-is. The third chunk is
// never processed; its absence from the script would otherwise panic.
#[tokio::test]
async fn failure_policy_abort_returns_first_error() {
let model = Arc::new(ScriptedModel::new(vec![
Ok("alpha context.".into()),
Err(entelix_core::Error::provider_http(503, "transient")),
]));
let chunker = ContextualChunker::builder(model)
.with_failure_policy(FailurePolicy::Abort)
.build();
let chunks = vec![
chunk_with_content("alpha body", 0),
chunk_with_content("beta body", 1),
chunk_with_content("gamma body", 2),
];
let err = chunker
.process(chunks, &ExecutionContext::new())
.await
.unwrap_err();
assert!(matches!(
err,
entelix_core::Error::Provider {
kind: entelix_core::ProviderErrorKind::Http(503),
..
}
));
}
// Inspects the prompt directly: index 1 of the user message is the parent-document part.
#[tokio::test]
async fn cache_control_attached_when_configured() {
let model = Arc::new(ScriptedModel::new(vec![Ok("ok".into())]));
let chunker = ContextualChunker::builder(model)
.with_cache_control(CacheControl::one_hour())
.build();
let prompt = chunker.build_prompt("parent body", "chunk body");
let parent_part = &prompt[0].content[1];
match parent_part {
ContentPart::Text { cache_control, .. } => {
assert!(
cache_control.is_some(),
"cache_control stamped on parent part"
);
}
_ => panic!("parent part must be Text"),
}
}
// The token is cancelled before `process` starts, so the cancellation check returns
// before the model is invoked.
#[tokio::test]
async fn cancellation_short_circuits_between_chunks() {
let model = Arc::new(ScriptedModel::new(vec![Ok("alpha context.".into())]));
let chunker = ContextualChunker::builder(model).build();
let token = entelix_core::cancellation::CancellationToken::new();
let ctx = ExecutionContext::with_cancellation(token.clone());
token.cancel();
let err = chunker
.process(vec![chunk_with_content("alpha", 0)], &ctx)
.await
.unwrap_err();
assert!(matches!(err, entelix_core::Error::Cancelled));
}
}