use std::collections::HashMap;
use cupel::{
ContextBudget, ContextItemBuilder, ContextKind, DiagnosticTraceCollector, ExclusionReason,
GreedySlice, Pipeline, PriorityScorer, RecencyScorer, TraceDetailLevel, UShapedPlacer,
};
use cupel_testing::SelectionReportAssertions;
fn make_pipeline() -> Pipeline {
Pipeline::builder()
.scorer(Box::new(RecencyScorer))
.slicer(Box::new(GreedySlice))
.placer(Box::new(UShapedPlacer))
.build()
.expect("pipeline build failed")
}
fn make_priority_pipeline() -> Pipeline {
Pipeline::builder()
.scorer(Box::new(PriorityScorer))
.slicer(Box::new(GreedySlice))
.placer(Box::new(UShapedPlacer))
.build()
.expect("priority pipeline build failed")
}
fn budget(max_tokens: i64) -> ContextBudget {
ContextBudget::new(max_tokens, max_tokens, 0, HashMap::new(), 0.0)
.expect("budget construction failed")
}
fn run(
pipeline: &Pipeline,
items: &[cupel::ContextItem],
budget: &ContextBudget,
) -> cupel::SelectionReport {
let mut collector = DiagnosticTraceCollector::new(TraceDetailLevel::Item);
pipeline
.run_traced(items, budget, &mut collector)
.expect("pipeline run failed");
collector.into_report()
}
#[test]
fn include_item_with_kind_passes() {
let pipeline = make_pipeline();
let kind = ContextKind::new("Message").unwrap();
let items = vec![
ContextItemBuilder::new("hello world", 5)
.kind(kind.clone())
.build()
.unwrap(),
];
let report = run(&pipeline, &items, &budget(100));
report.should().include_item_with_kind(kind);
}
#[test]
#[should_panic(expected = "include_item_with_kind(Document) failed")]
fn include_item_with_kind_panics() {
let pipeline = make_pipeline();
let msg_kind = ContextKind::new("Message").unwrap();
let items = vec![
ContextItemBuilder::new("hello world", 5)
.kind(msg_kind)
.build()
.unwrap(),
];
let report = run(&pipeline, &items, &budget(100));
report
.should()
.include_item_with_kind(ContextKind::new("Document").unwrap());
}
#[test]
fn include_item_matching_passes() {
let pipeline = make_pipeline();
let items = vec![
ContextItemBuilder::new("special-content", 5)
.build()
.unwrap(),
];
let report = run(&pipeline, &items, &budget(100));
report
.should()
.include_item_matching(|i| i.item.content() == "special-content");
}
#[test]
#[should_panic(expected = "include_item_matching failed")]
fn include_item_matching_panics() {
let pipeline = make_pipeline();
let items = vec![ContextItemBuilder::new("hello", 5).build().unwrap()];
let report = run(&pipeline, &items, &budget(100));
report
.should()
.include_item_matching(|i| i.item.content() == "not-here");
}
#[test]
fn include_exact_n_items_with_kind_passes() {
let pipeline = make_pipeline();
let kind = ContextKind::new("Message").unwrap();
let items = vec![
ContextItemBuilder::new("msg1", 5)
.kind(kind.clone())
.build()
.unwrap(),
ContextItemBuilder::new("msg2", 5)
.kind(kind.clone())
.build()
.unwrap(),
];
let report = run(&pipeline, &items, &budget(100));
report.should().include_exact_n_items_with_kind(kind, 2);
}
#[test]
#[should_panic(expected = "include_exact_n_items_with_kind(Message, 5) failed")]
fn include_exact_n_items_with_kind_panics() {
let pipeline = make_pipeline();
let kind = ContextKind::new("Message").unwrap();
let items = vec![
ContextItemBuilder::new("msg1", 5)
.kind(kind.clone())
.build()
.unwrap(),
];
let report = run(&pipeline, &items, &budget(100));
report.should().include_exact_n_items_with_kind(kind, 5);
}
#[test]
fn exclude_item_with_reason_passes() {
let pipeline = make_pipeline();
let items = vec![
ContextItemBuilder::new("fits", 10).build().unwrap(),
ContextItemBuilder::new("too-big", 1000).build().unwrap(),
];
let report = run(&pipeline, &items, &budget(20));
report
.should()
.exclude_item_with_reason(ExclusionReason::BudgetExceeded {
item_tokens: 0,
available_tokens: 0,
});
}
#[test]
#[should_panic(expected = "exclude_item_with_reason(Deduplicated")]
fn exclude_item_with_reason_panics() {
let pipeline = make_pipeline();
let items = vec![
ContextItemBuilder::new("fits", 5).build().unwrap(),
ContextItemBuilder::new("too-big", 1000).build().unwrap(),
];
let report = run(&pipeline, &items, &budget(20));
report
.should()
.exclude_item_with_reason(ExclusionReason::Deduplicated {
deduplicated_against: String::new(),
});
}
#[test]
fn exclude_item_matching_with_reason_passes() {
let pipeline = make_pipeline();
let items = vec![
ContextItemBuilder::new("small", 5).build().unwrap(),
ContextItemBuilder::new("giant", 9999).build().unwrap(),
];
let report = run(&pipeline, &items, &budget(20));
report.should().exclude_item_matching_with_reason(
|e| e.item.content() == "giant",
ExclusionReason::BudgetExceeded {
item_tokens: 0,
available_tokens: 0,
},
);
}
#[test]
#[should_panic(expected = "exclude_item_matching_with_reason(reason=Deduplicated")]
fn exclude_item_matching_with_reason_panics() {
let pipeline = make_pipeline();
let items = vec![
ContextItemBuilder::new("small", 5).build().unwrap(),
ContextItemBuilder::new("giant", 9999).build().unwrap(),
];
let report = run(&pipeline, &items, &budget(20));
report.should().exclude_item_matching_with_reason(
|e| e.item.content() == "giant",
ExclusionReason::Deduplicated {
deduplicated_against: String::new(),
},
);
}
#[test]
fn have_excluded_item_with_budget_details_passes() {
let pipeline = make_pipeline();
let items = vec![
ContextItemBuilder::new("small", 5).build().unwrap(),
ContextItemBuilder::new("giant", 500).build().unwrap(),
];
let report = run(&pipeline, &items, &budget(10));
let excluded = &report.excluded;
let budget_item = excluded
.iter()
.find(|e| matches!(e.reason, ExclusionReason::BudgetExceeded { .. }))
.expect("expected a BudgetExceeded exclusion");
let (actual_it, actual_at) = if let ExclusionReason::BudgetExceeded {
item_tokens,
available_tokens,
} = budget_item.reason
{
(item_tokens, available_tokens)
} else {
panic!("unexpected reason");
};
report.should().have_excluded_item_with_budget_details(
|e| e.item.content() == "giant",
actual_it,
actual_at,
);
}
#[test]
#[should_panic(expected = "have_excluded_item_with_budget_details failed")]
fn have_excluded_item_with_budget_details_panics() {
let pipeline = make_pipeline();
let items = vec![
ContextItemBuilder::new("small", 5).build().unwrap(),
ContextItemBuilder::new("giant", 500).build().unwrap(),
];
let report = run(&pipeline, &items, &budget(10));
report.should().have_excluded_item_with_budget_details(
|e| e.item.content() == "giant",
999_999, 999_999, );
}
#[test]
fn have_no_exclusions_for_kind_passes() {
let pipeline = make_pipeline();
let msg_kind = ContextKind::new("Message").unwrap();
let doc_kind = ContextKind::new("Document").unwrap();
let items = vec![
ContextItemBuilder::new("msg", 5)
.kind(msg_kind.clone())
.build()
.unwrap(),
ContextItemBuilder::new("big-doc", 9999)
.kind(doc_kind)
.build()
.unwrap(),
];
let report = run(&pipeline, &items, &budget(20));
report.should().have_no_exclusions_for_kind(msg_kind);
}
#[test]
#[should_panic(expected = "have_no_exclusions_for_kind(Document) failed")]
fn have_no_exclusions_for_kind_panics() {
let pipeline = make_pipeline();
let doc_kind = ContextKind::new("Document").unwrap();
let items = vec![
ContextItemBuilder::new("msg", 5).build().unwrap(),
ContextItemBuilder::new("big-doc", 9999)
.kind(doc_kind.clone())
.build()
.unwrap(),
];
let report = run(&pipeline, &items, &budget(20));
report.should().have_no_exclusions_for_kind(doc_kind);
}
#[test]
fn have_at_least_n_exclusions_passes() {
let pipeline = make_pipeline();
let items = vec![
ContextItemBuilder::new("small", 5).build().unwrap(),
ContextItemBuilder::new("too-big", 9999).build().unwrap(),
];
let report = run(&pipeline, &items, &budget(20));
report.should().have_at_least_n_exclusions(1);
}
#[test]
#[should_panic(expected = "have_at_least_n_exclusions(999) failed")]
fn have_at_least_n_exclusions_panics() {
let pipeline = make_pipeline();
let items = vec![
ContextItemBuilder::new("small", 5).build().unwrap(),
ContextItemBuilder::new("too-big", 9999).build().unwrap(),
];
let report = run(&pipeline, &items, &budget(20));
report.should().have_at_least_n_exclusions(999);
}
#[test]
fn excluded_items_are_sorted_by_score_descending_passes() {
let pipeline = make_pipeline();
let items = vec![
ContextItemBuilder::new("small", 5).build().unwrap(),
ContextItemBuilder::new("big1", 9000).build().unwrap(),
ContextItemBuilder::new("big2", 8000).build().unwrap(),
];
let report = run(&pipeline, &items, &budget(20));
report
.should()
.have_at_least_n_exclusions(1)
.excluded_items_are_sorted_by_score_descending();
}
#[test]
fn excluded_items_are_sorted_by_score_descending_vacuous_pass_on_zero_or_one() {
let pipeline = make_pipeline();
let items = vec![ContextItemBuilder::new("fits", 5).build().unwrap()];
let report = run(&pipeline, &items, &budget(100));
report
.should()
.excluded_items_are_sorted_by_score_descending();
}
#[test]
fn have_budget_utilization_above_passes() {
let pipeline = make_pipeline();
let items = vec![
ContextItemBuilder::new("a", 50).build().unwrap(),
ContextItemBuilder::new("b", 50).build().unwrap(),
];
let b = budget(110);
let report = run(&pipeline, &items, &b);
report.should().have_budget_utilization_above(0.5, &b);
}
#[test]
#[should_panic(expected = "have_budget_utilization_above(0.9999) failed")]
fn have_budget_utilization_above_panics() {
let pipeline = make_pipeline();
let items = vec![ContextItemBuilder::new("tiny", 5).build().unwrap()];
let b = budget(10000);
let report = run(&pipeline, &items, &b);
report.should().have_budget_utilization_above(0.9999, &b);
}
#[test]
fn have_kind_coverage_count_passes() {
let pipeline = make_pipeline();
let msg_kind = ContextKind::new("Message").unwrap();
let doc_kind = ContextKind::new("Document").unwrap();
let items = vec![
ContextItemBuilder::new("msg", 5)
.kind(msg_kind)
.build()
.unwrap(),
ContextItemBuilder::new("doc", 5)
.kind(doc_kind)
.build()
.unwrap(),
];
let report = run(&pipeline, &items, &budget(100));
report.should().have_kind_coverage_count(2);
}
#[test]
#[should_panic(expected = "have_kind_coverage_count(99) failed")]
fn have_kind_coverage_count_panics() {
let pipeline = make_pipeline();
let items = vec![ContextItemBuilder::new("a", 5).build().unwrap()];
let report = run(&pipeline, &items, &budget(100));
report.should().have_kind_coverage_count(99);
}
#[test]
fn place_item_at_edge_passes() {
let pipeline = make_priority_pipeline();
let items = vec![
ContextItemBuilder::new("first", 10)
.priority(30)
.build()
.unwrap(),
ContextItemBuilder::new("second", 10)
.priority(20)
.build()
.unwrap(),
ContextItemBuilder::new("third", 10)
.priority(10)
.build()
.unwrap(),
];
let report = run(&pipeline, &items, &budget(100));
report
.should()
.place_item_at_edge(|i| i.item.content() == "first");
}
#[test]
#[should_panic(expected = "place_item_at_edge failed")]
fn place_item_at_edge_panics() {
let pipeline = make_priority_pipeline();
let items = vec![
ContextItemBuilder::new("first", 10)
.priority(40)
.build()
.unwrap(), ContextItemBuilder::new("second", 10)
.priority(30)
.build()
.unwrap(), ContextItemBuilder::new("third", 10)
.priority(20)
.build()
.unwrap(), ContextItemBuilder::new("fourth", 10)
.priority(10)
.build()
.unwrap(), ];
let report = run(&pipeline, &items, &budget(100));
report
.should()
.place_item_at_edge(|i| i.item.content() == "third");
}
#[test]
fn place_top_n_scored_at_edges_n_zero_passes() {
let pipeline = make_pipeline();
let items = vec![ContextItemBuilder::new("a", 5).build().unwrap()];
let report = run(&pipeline, &items, &budget(100));
report.should().place_top_n_scored_at_edges(0);
}
#[test]
#[should_panic(expected = "place_top_n_scored_at_edges(99) failed: n=99 exceeds Included count=")]
fn place_top_n_scored_at_edges_n_exceeds_count_panics() {
let pipeline = make_pipeline();
let items = vec![ContextItemBuilder::new("a", 5).build().unwrap()];
let report = run(&pipeline, &items, &budget(100));
report.should().place_top_n_scored_at_edges(99);
}
#[test]
fn place_top_n_scored_at_edges_n2_passes() {
let pipeline = make_priority_pipeline();
let items = vec![
ContextItemBuilder::new("item0", 10)
.priority(40)
.build()
.unwrap(),
ContextItemBuilder::new("item1", 10)
.priority(30)
.build()
.unwrap(),
ContextItemBuilder::new("item2", 10)
.priority(20)
.build()
.unwrap(),
ContextItemBuilder::new("item3", 10)
.priority(10)
.build()
.unwrap(),
];
let report = run(&pipeline, &items, &budget(100));
report.should().place_top_n_scored_at_edges(2);
}
#[test]
fn chained_assertions_pass() {
let pipeline = make_pipeline();
let msg_kind = ContextKind::new("Message").unwrap();
let doc_kind = ContextKind::new("Document").unwrap();
let items = vec![
ContextItemBuilder::new("msg", 10)
.kind(msg_kind.clone())
.build()
.unwrap(),
ContextItemBuilder::new("doc", 10)
.kind(doc_kind)
.build()
.unwrap(),
ContextItemBuilder::new("oversized", 9999)
.kind(msg_kind.clone())
.build()
.unwrap(),
];
let report = run(&pipeline, &items, &budget(100));
report
.should()
.include_item_with_kind(msg_kind)
.have_at_least_n_exclusions(1)
.excluded_items_are_sorted_by_score_descending();
}