use std::collections::HashMap;
use crate::CupelError;
use crate::model::{ContextBudget, ContextItem, OverflowStrategy, ScoredItem};
use crate::placer::Placer;
/// Result of `place_items`: the placer's output, plus every item dropped during
/// overflow handling paired with the score it carried, or a `CupelError`.
type PlaceResult = Result<(Vec<ContextItem>, Vec<(ContextItem, f64)>), CupelError>;
/// Merges pinned and sliced items, resolves any budget overflow, and hands the
/// surviving items to `placer`.
///
/// Scoring:
/// - every item in `pinned` is assigned a fixed score of `1.0`;
/// - every item in `sliced` takes the score of the first entry in `sorted_scored`
///   whose content is equal to its own, or `0.0` when there is no such entry.
///
/// When the merged token total exceeds `budget.target_tokens()`, the overflow is
/// resolved according to `overflow_strategy` (see `handle_overflow`). Otherwise
/// all merged items are kept and nothing is dropped.
///
/// # Returns
///
/// The output of `placer.place` over the kept items, plus every item dropped
/// during overflow handling paired with the score it carried.
///
/// # Errors
///
/// Propagates the error from `handle_overflow`. For example,
/// `CupelError::Overflow` is returned when the merged items exceed the target
/// and `overflow_strategy` is `OverflowStrategy::Throw`.
pub(crate) fn place_items(
    pinned: &[ContextItem],
    sliced: &[ContextItem],
    sorted_scored: &[ScoredItem],
    budget: &ContextBudget,
    overflow_strategy: OverflowStrategy,
    placer: &dyn Placer,
) -> PlaceResult {
    // Map content -> score. `or_insert` keeps the first occurrence, so when
    // several scored items share the same content, the earliest entry in
    // `sorted_scored` determines the score.
    let mut score_map: HashMap<&str, f64> = HashMap::with_capacity(sorted_scored.len());
    for si in sorted_scored {
        score_map.entry(si.item.content()).or_insert(si.score);
    }

    let mut merged: Vec<ScoredItem> = Vec::with_capacity(pinned.len() + sliced.len());
    merged.extend(pinned.iter().map(|item| ScoredItem {
        item: item.clone(),
        score: 1.0,
    }));
    merged.extend(sliced.iter().map(|item| ScoredItem {
        item: item.clone(),
        score: score_map.get(item.content()).copied().unwrap_or(0.0),
    }));

    let target_tokens = budget.target_tokens();
    let merged_tokens: i64 = merged.iter().map(|si| si.item.tokens()).sum();

    // Only consult the overflow strategy when the budget is actually exceeded.
    let (kept, truncated) = if merged_tokens > target_tokens {
        handle_overflow(merged, target_tokens, overflow_strategy)?
    } else {
        (merged, Vec::new())
    };

    let placed = placer.place(&kept);
    let truncated_with_scores: Vec<(ContextItem, f64)> = truncated
        .into_iter()
        .map(|si| (si.item, si.score))
        .collect();
    Ok((placed, truncated_with_scores))
}
/// Resolves a token-budget overflow for `merged` according to `strategy`.
///
/// Returns `(kept, dropped)` on success.
///
/// - `Proceed`: every item is kept and nothing is dropped.
/// - `Throw`: fails with `CupelError::Overflow`, reporting the summed tokens of
///   `merged` alongside `target_tokens`.
/// - `Truncate`: items are stably sorted pinned-first and then by descending
///   score. They are then admitted greedily:
///   - a pinned item is always kept, and its tokens still count against the budget;
///   - a non-pinned item is kept only if it fits in the remaining budget;
///   - an item that does not fit is dropped, but later items may still be admitted.
///
/// NOTE(review): pinned-ness is read from `ContextItem::pinned()`, not from
/// which input list an item originated in; this assumes the two agree — confirm.
fn handle_overflow(
    mut merged: Vec<ScoredItem>,
    target_tokens: i64,
    strategy: OverflowStrategy,
) -> Result<(Vec<ScoredItem>, Vec<ScoredItem>), CupelError> {
    match strategy {
        OverflowStrategy::Proceed => Ok((merged, vec![])),
        OverflowStrategy::Throw => Err(CupelError::Overflow {
            merged_tokens: merged.iter().map(|si| si.item.tokens()).sum::<i64>(),
            target_tokens,
        }),
        OverflowStrategy::Truncate => {
            // Pinned (`true`) sorts ahead of unpinned; within each group,
            // higher scores come first. `sort_by` is stable, so ties keep
            // their incoming order.
            merged.sort_by(|lhs, rhs| {
                rhs.item
                    .pinned()
                    .cmp(&lhs.item.pinned())
                    .then_with(|| rhs.score.total_cmp(&lhs.score))
            });

            let mut used_tokens: i64 = 0;
            let mut kept = Vec::with_capacity(merged.len());
            let mut dropped = Vec::new();
            for candidate in merged {
                let cost = candidate.item.tokens();
                if candidate.item.pinned() || used_tokens + cost <= target_tokens {
                    used_tokens += cost;
                    kept.push(candidate);
                } else {
                    dropped.push(candidate);
                }
            }
            Ok((kept, dropped))
        }
    }
}