use serde::{Deserialize, Serialize};
#[cfg(feature = "freshness")]
use ainl_contracts::ContextFreshness;
/// Speaker role attached to a [`Segment`].
///
/// Serialized and deserialized through serde with snake_case variant names
/// (see the `rename_all` attribute below).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
/// Role assigned by [`Segment::system_prompt`], [`Segment::tool_definitions`]
/// and [`Segment::memory_block`].
System,
/// Role assigned by [`Segment::user_prompt`].
User,
/// Not assigned by any constructor in this file; callers pass it explicitly
/// to the constructors that take a `role` argument.
Assistant,
/// Role assigned by [`Segment::tool_result`].
Tool,
}
/// Category of a [`Segment`].
///
/// Serialized and deserialized through serde with snake_case variant names.
/// [`SegmentKind::as_str`] returns a snake_case `'static` label per variant and
/// [`SegmentKind::is_always_keep`] singles out `SystemPrompt` and `UserPrompt`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SegmentKind {
/// Kind produced by [`Segment::system_prompt`]; reported as always-keep.
SystemPrompt,
/// Kind produced by [`Segment::older_turn`].
OlderTurn,
/// Kind produced by [`Segment::recent_turn`].
RecentTurn,
/// Kind produced by [`Segment::tool_definitions`].
ToolDefinitions,
/// Kind produced by [`Segment::tool_result`].
ToolResult,
/// Kind produced by [`Segment::user_prompt`]; reported as always-keep.
UserPrompt,
/// No constructor in this file produces this kind.
/// NOTE(review): its exact meaning is not visible here — confirm with the producer.
AnchoredSummaryRecall,
/// Kind produced by [`Segment::memory_block`].
MemoryBlock,
}
impl SegmentKind {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::SystemPrompt => "system_prompt",
Self::OlderTurn => "older_turn",
Self::RecentTurn => "recent_turn",
Self::ToolDefinitions => "tool_definitions",
Self::ToolResult => "tool_result",
Self::UserPrompt => "user_prompt",
Self::AnchoredSummaryRecall => "anchored_summary_recall",
Self::MemoryBlock => "memory_block",
}
}
#[must_use]
pub fn is_always_keep(self) -> bool {
matches!(self, Self::SystemPrompt | Self::UserPrompt)
}
}
/// A single piece of context: a kind, a role, its text and bookkeeping values.
///
/// Implements serde `Serialize`/`Deserialize`; the per-field attributes below
/// control which fields may be absent from the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Segment {
/// Category of this segment.
pub kind: SegmentKind,
/// Role attached to this segment.
pub role: Role,
/// Text payload; [`Segment::token_estimate`] is computed from this field.
pub content: String,
/// Index supplied by the caller for turns and tool results. The constructors
/// in this file use `0` for user prompts and memory blocks and `u32::MAX`
/// for system prompts and tool definitions.
/// NOTE(review): how consumers order or interpret this index is not visible
/// here — confirm before relying on it.
pub age_index: u32,
/// Optional name. [`Segment::tool_result`] stores the tool name here and
/// [`Segment::memory_block`] stores its label here; the other constructors
/// leave it `None`. Defaults to `None` when missing on deserialization and
/// is omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>,
/// Base importance value. The constructors in this file set it between `0.7`
/// and `2.0`; when missing on deserialization it defaults to `1.0` via `one`.
#[serde(default = "one")]
pub base_importance: f32,
/// Optional freshness value; the field only exists when the `freshness`
/// feature is enabled. Defaults to `None` when missing on deserialization
/// and is omitted from serialized output when `None`.
#[cfg(feature = "freshness")]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub freshness: Option<ContextFreshness>,
}
/// Serde default provider for `Segment::base_importance`; always returns `1.0`.
const fn one() -> f32 {
    const DEFAULT_BASE_IMPORTANCE: f32 = 1.0;
    DEFAULT_BASE_IMPORTANCE
}
impl Segment {
#[must_use]
pub fn user_prompt(content: impl Into<String>) -> Self {
Self {
kind: SegmentKind::UserPrompt,
role: Role::User,
content: content.into(),
age_index: 0,
tool_name: None,
base_importance: 2.0,
#[cfg(feature = "freshness")]
freshness: None,
}
}
#[must_use]
pub fn system_prompt(content: impl Into<String>) -> Self {
Self {
kind: SegmentKind::SystemPrompt,
role: Role::System,
content: content.into(),
age_index: u32::MAX, tool_name: None,
base_importance: 1.5,
#[cfg(feature = "freshness")]
freshness: None,
}
}
#[must_use]
pub fn recent_turn(role: Role, content: impl Into<String>, age_index: u32) -> Self {
Self {
kind: SegmentKind::RecentTurn,
role,
content: content.into(),
age_index,
tool_name: None,
base_importance: 1.0,
#[cfg(feature = "freshness")]
freshness: None,
}
}
#[must_use]
pub fn older_turn(role: Role, content: impl Into<String>, age_index: u32) -> Self {
Self {
kind: SegmentKind::OlderTurn,
role,
content: content.into(),
age_index,
tool_name: None,
base_importance: 0.7,
#[cfg(feature = "freshness")]
freshness: None,
}
}
#[must_use]
pub fn tool_result(tool_name: impl Into<String>, content: impl Into<String>, age_index: u32) -> Self {
Self {
kind: SegmentKind::ToolResult,
role: Role::Tool,
content: content.into(),
age_index,
tool_name: Some(tool_name.into()),
base_importance: 0.8,
#[cfg(feature = "freshness")]
freshness: None,
}
}
#[must_use]
pub fn tool_definitions(content: impl Into<String>) -> Self {
Self {
kind: SegmentKind::ToolDefinitions,
role: Role::System,
content: content.into(),
age_index: u32::MAX,
tool_name: None,
base_importance: 1.2,
#[cfg(feature = "freshness")]
freshness: None,
}
}
#[must_use]
pub fn memory_block(label: impl Into<String>, content: impl Into<String>) -> Self {
Self {
kind: SegmentKind::MemoryBlock,
role: Role::System,
content: content.into(),
age_index: 0,
tool_name: Some(label.into()),
base_importance: 1.0,
#[cfg(feature = "freshness")]
freshness: None,
}
}
#[must_use]
pub fn token_estimate(&self) -> usize {
ainl_compression::tokenize_estimate(&self.content)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `is_always_keep` against every variant, not just a sample, so a
    /// change to the classification of any kind is caught.
    #[test]
    fn always_keep_classification() {
        for &kind in &[SegmentKind::SystemPrompt, SegmentKind::UserPrompt] {
            assert!(kind.is_always_keep(), "{:?} should be always-keep", kind);
        }
        for &kind in &[
            SegmentKind::OlderTurn,
            SegmentKind::RecentTurn,
            SegmentKind::ToolDefinitions,
            SegmentKind::ToolResult,
            SegmentKind::AnchoredSummaryRecall,
            SegmentKind::MemoryBlock,
        ] {
            assert!(!kind.is_always_keep(), "{:?} should not be always-keep", kind);
        }
    }

    /// A non-empty prompt must produce a non-zero token estimate.
    #[test]
    fn segment_token_estimate_nonzero() {
        let s = Segment::user_prompt("Hello world this is a test");
        assert!(s.token_estimate() > 0);
    }

    /// Pins the label of every variant; the labels are part of the public
    /// contract of `as_str`, so each one is listed explicitly.
    #[test]
    fn segment_kind_label_stable() {
        let expected = [
            (SegmentKind::SystemPrompt, "system_prompt"),
            (SegmentKind::OlderTurn, "older_turn"),
            (SegmentKind::RecentTurn, "recent_turn"),
            (SegmentKind::ToolDefinitions, "tool_definitions"),
            (SegmentKind::ToolResult, "tool_result"),
            (SegmentKind::UserPrompt, "user_prompt"),
            (SegmentKind::AnchoredSummaryRecall, "anchored_summary_recall"),
            (SegmentKind::MemoryBlock, "memory_block"),
        ];
        for &(kind, label) in &expected {
            assert_eq!(kind.as_str(), label);
        }
    }

    /// Covers the constructor wiring that is easy to get wrong: kind, role,
    /// index and what ends up in `tool_name`.
    #[test]
    fn constructors_set_kind_role_and_tool_name() {
        let tool = Segment::tool_result("grep", "3 matches", 4);
        assert_eq!(tool.kind, SegmentKind::ToolResult);
        assert_eq!(tool.role, Role::Tool);
        assert_eq!(tool.age_index, 4);
        assert_eq!(tool.tool_name.as_deref(), Some("grep"));

        let memory = Segment::memory_block("facts", "user likes rust");
        assert_eq!(memory.kind, SegmentKind::MemoryBlock);
        assert_eq!(memory.role, Role::System);
        assert_eq!(memory.age_index, 0);
        assert_eq!(memory.tool_name.as_deref(), Some("facts"));

        let system = Segment::system_prompt("be brief");
        assert_eq!(system.kind, SegmentKind::SystemPrompt);
        assert_eq!(system.role, Role::System);
        assert_eq!(system.age_index, u32::MAX);
        assert!(system.tool_name.is_none());

        let turn = Segment::recent_turn(Role::Assistant, "hi", 2);
        assert_eq!(turn.kind, SegmentKind::RecentTurn);
        assert_eq!(turn.role, Role::Assistant);
        assert_eq!(turn.age_index, 2);
        assert!(turn.tool_name.is_none());
    }
}