use rucora_core::provider::types::{ChatMessage, Role};
/// Splits a flat, ordered message history into consecutive groups ("rounds").
///
/// A new group is opened, and the previous one closed, when the current group
/// is non-empty and either:
/// - the incoming message has role [`Role::User`], or
/// - [`should_start_new_group`] says so (currently: an `Assistant` message that
///   directly follows another `Assistant` message).
///
/// Every other message is appended to the current group. Message order is
/// preserved, no message is dropped or duplicated, and an empty `messages`
/// slice yields an empty `Vec`. Messages are cloned into the result.
pub fn group_messages_by_api_round(messages: &[ChatMessage]) -> Vec<Vec<ChatMessage>> {
    let mut groups: Vec<Vec<ChatMessage>> = Vec::new();
    let mut current_group: Vec<ChatMessage> = Vec::new();

    for msg in messages {
        // The role of the most recently grouped message is always
        // `current_group.last()`, so no separate "last role" state is needed.
        let starts_new_round = !current_group.is_empty()
            && (msg.role == Role::User || should_start_new_group(&current_group, msg));
        if starts_new_round {
            // Move the finished group out and leave an empty Vec in its place.
            groups.push(std::mem::take(&mut current_group));
        }
        current_group.push(msg.clone());
    }

    // Flush the trailing group; skipped only when `messages` was empty.
    if !current_group.is_empty() {
        groups.push(current_group);
    }
    groups
}
/// Decides whether `msg` must open a fresh group rather than join `current_group`.
///
/// Returns `true` only when the last message already in `current_group` and
/// `msg` both have role [`Role::Assistant`]. An empty `current_group` always
/// yields `false`.
fn should_start_new_group(current_group: &[ChatMessage], msg: &ChatMessage) -> bool {
    let Some(previous) = current_group.last() else {
        return false;
    };
    previous.role == Role::Assistant && msg.role == Role::Assistant
}
/// Returns clones of the leading groups, i.e. everything except the last
/// `preserve_count` groups.
///
/// If `groups` holds `preserve_count` groups or fewer, nothing is selected and
/// an empty `Vec` is returned. The relative order of the selected groups is
/// kept. (Presumably the leading groups are the oldest rounds; this depends on
/// the caller passing them in chronological order.)
pub fn select_groups_to_compact(
    groups: &[Vec<ChatMessage>],
    preserve_count: usize,
) -> Vec<Vec<ChatMessage>> {
    // Saturating subtraction clamps to 0 when preserve_count >= groups.len().
    let compact_len = groups.len().saturating_sub(preserve_count);
    groups.iter().take(compact_len).cloned().collect()
}
/// Renders message groups as a plain-text transcript.
///
/// Each group starts with a 1-based header line `=== 轮次 N ===`, followed by
/// one `[role]: content` line per message. The role label is the Chinese name
/// of the [`Role`] variant. Group blocks are joined with `"\n"`. Because every
/// block already ends in a newline, consecutive groups are separated by a
/// blank line. An empty `groups` slice yields an empty string.
pub fn groups_to_text(groups: &[Vec<ChatMessage>]) -> String {
    groups
        .iter()
        .enumerate()
        .map(|(idx, group)| {
            let header = format!("=== 轮次 {} ===\n", idx + 1);
            group.iter().fold(header, |mut block, msg| {
                let label = match msg.role {
                    Role::User => "用户",
                    Role::Assistant => "助手",
                    Role::System => "系统",
                    Role::Tool => "工具",
                };
                block.push_str(&format!("[{}]: {}\n", label, msg.content));
                block
            })
        })
        .collect::<Vec<String>>()
        .join("\n")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two user-initiated exchanges produce two groups of two messages, each
    /// opened by the user message.
    #[test]
    fn test_group_messages() {
        let messages = vec![
            ChatMessage::user("你好"),
            ChatMessage::assistant("你好!有什么可以帮助你的吗?"),
            ChatMessage::user("帮我写个函数"),
            ChatMessage::assistant("好的,我来帮你写。"),
        ];
        let groups = group_messages_by_api_round(&messages);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0].len(), 2);
        assert_eq!(groups[1].len(), 2);
        assert!(groups[0][0].role == Role::User);
        assert!(groups[1][0].role == Role::User);
    }

    /// Back-to-back assistant messages are split into separate groups.
    #[test]
    fn test_group_messages_splits_consecutive_assistant() {
        let messages = vec![
            ChatMessage::user("问题"),
            ChatMessage::assistant("回答 1"),
            ChatMessage::assistant("回答 2"),
        ];
        let groups = group_messages_by_api_round(&messages);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0].len(), 2);
        assert_eq!(groups[1].len(), 1);
        assert!(groups[1][0].role == Role::Assistant);
    }

    /// No messages means no groups (not a single empty group).
    #[test]
    fn test_group_messages_empty() {
        assert!(group_messages_by_api_round(&[]).is_empty());
    }

    /// All but the last `preserve_count` groups are selected, in order.
    #[test]
    fn test_select_groups_to_compact() {
        let groups = vec![
            vec![ChatMessage::user("消息 1")],
            vec![ChatMessage::user("消息 2")],
            vec![ChatMessage::user("消息 3")],
            vec![ChatMessage::user("消息 4")],
        ];
        let to_compact = select_groups_to_compact(&groups, 2);
        assert_eq!(to_compact.len(), 2);
        // The selected groups must be the leading ones.
        let text = groups_to_text(&to_compact);
        assert!(text.contains("消息 1"));
        assert!(text.contains("消息 2"));
        assert!(!text.contains("消息 3"));
        assert!(!text.contains("消息 4"));
    }

    /// Boundary values of `preserve_count`.
    #[test]
    fn test_select_groups_to_compact_boundaries() {
        let groups = vec![vec![ChatMessage::user("消息 1")]];
        assert!(select_groups_to_compact(&groups, 1).is_empty());
        assert!(select_groups_to_compact(&groups, 5).is_empty());
        assert_eq!(select_groups_to_compact(&groups, 0).len(), 1);
    }

    /// A single group renders with its header and both role labels.
    #[test]
    fn test_groups_to_text() {
        let groups = vec![vec![
            ChatMessage::user("你好"),
            ChatMessage::assistant("你好!"),
        ]];
        let text = groups_to_text(&groups);
        assert!(text.starts_with("=== 轮次 1 ===\n"));
        assert!(text.contains("轮次 1"));
        assert!(text.contains("用户"));
        assert!(text.contains("助手"));
    }

    /// Consecutive groups are numbered from 1 and separated by a blank line.
    #[test]
    fn test_groups_to_text_multiple_groups() {
        let groups = vec![
            vec![ChatMessage::user("第一")],
            vec![ChatMessage::user("第二")],
        ];
        let text = groups_to_text(&groups);
        assert!(text.starts_with("=== 轮次 1 ===\n"));
        assert!(text.contains("\n\n=== 轮次 2 ===\n"));
        assert!(groups_to_text(&[]).is_empty());
    }
}