1use serde::{Deserialize, Serialize};
8
9#[cfg(feature = "freshness")]
10use ainl_contracts::ContextFreshness;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum Role {
16 System,
18 User,
20 Assistant,
22 Tool,
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
28#[serde(rename_all = "snake_case")]
29pub enum SegmentKind {
30 SystemPrompt,
32 OlderTurn,
34 RecentTurn,
36 ToolDefinitions,
38 ToolResult,
40 UserPrompt,
42 AnchoredSummaryRecall,
44 MemoryBlock,
46}
47
48impl SegmentKind {
49 #[must_use]
51 pub fn as_str(self) -> &'static str {
52 match self {
53 Self::SystemPrompt => "system_prompt",
54 Self::OlderTurn => "older_turn",
55 Self::RecentTurn => "recent_turn",
56 Self::ToolDefinitions => "tool_definitions",
57 Self::ToolResult => "tool_result",
58 Self::UserPrompt => "user_prompt",
59 Self::AnchoredSummaryRecall => "anchored_summary_recall",
60 Self::MemoryBlock => "memory_block",
61 }
62 }
63
64 #[must_use]
66 pub fn is_always_keep(self) -> bool {
67 matches!(self, Self::SystemPrompt | Self::UserPrompt)
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct Segment {
74 pub kind: SegmentKind,
76 pub role: Role,
78 pub content: String,
80 pub age_index: u32,
82 #[serde(default, skip_serializing_if = "Option::is_none")]
84 pub tool_name: Option<String>,
85 #[serde(default = "one")]
88 pub base_importance: f32,
89 #[cfg(feature = "freshness")]
93 #[serde(default, skip_serializing_if = "Option::is_none")]
94 pub freshness: Option<ContextFreshness>,
95}
96
/// Serde default for `Segment::base_importance`: returns `1.0`.
const fn one() -> f32 {
    1.0_f32
}
100
101impl Segment {
102 #[must_use]
104 pub fn user_prompt(content: impl Into<String>) -> Self {
105 Self {
106 kind: SegmentKind::UserPrompt,
107 role: Role::User,
108 content: content.into(),
109 age_index: 0,
110 tool_name: None,
111 base_importance: 2.0,
112 #[cfg(feature = "freshness")]
113 freshness: None,
114 }
115 }
116
117 #[must_use]
119 pub fn system_prompt(content: impl Into<String>) -> Self {
120 Self {
121 kind: SegmentKind::SystemPrompt,
122 role: Role::System,
123 content: content.into(),
124 age_index: u32::MAX, tool_name: None,
126 base_importance: 1.5,
127 #[cfg(feature = "freshness")]
128 freshness: None,
129 }
130 }
131
132 #[must_use]
134 pub fn recent_turn(role: Role, content: impl Into<String>, age_index: u32) -> Self {
135 Self {
136 kind: SegmentKind::RecentTurn,
137 role,
138 content: content.into(),
139 age_index,
140 tool_name: None,
141 base_importance: 1.0,
142 #[cfg(feature = "freshness")]
143 freshness: None,
144 }
145 }
146
147 #[must_use]
149 pub fn older_turn(role: Role, content: impl Into<String>, age_index: u32) -> Self {
150 Self {
151 kind: SegmentKind::OlderTurn,
152 role,
153 content: content.into(),
154 age_index,
155 tool_name: None,
156 base_importance: 0.7,
157 #[cfg(feature = "freshness")]
158 freshness: None,
159 }
160 }
161
162 #[must_use]
164 pub fn tool_result(
165 tool_name: impl Into<String>,
166 content: impl Into<String>,
167 age_index: u32,
168 ) -> Self {
169 Self {
170 kind: SegmentKind::ToolResult,
171 role: Role::Tool,
172 content: content.into(),
173 age_index,
174 tool_name: Some(tool_name.into()),
175 base_importance: 0.8,
176 #[cfg(feature = "freshness")]
177 freshness: None,
178 }
179 }
180
181 #[must_use]
183 pub fn tool_definitions(content: impl Into<String>) -> Self {
184 Self {
185 kind: SegmentKind::ToolDefinitions,
186 role: Role::System,
187 content: content.into(),
188 age_index: u32::MAX,
189 tool_name: None,
190 base_importance: 1.2,
191 #[cfg(feature = "freshness")]
192 freshness: None,
193 }
194 }
195
196 #[must_use]
198 pub fn memory_block(label: impl Into<String>, content: impl Into<String>) -> Self {
199 Self {
200 kind: SegmentKind::MemoryBlock,
201 role: Role::System,
202 content: content.into(),
203 age_index: 0,
204 tool_name: Some(label.into()),
205 base_importance: 1.0,
206 #[cfg(feature = "freshness")]
207 freshness: None,
208 }
209 }
210
211 #[must_use]
213 pub fn token_estimate(&self) -> usize {
214 ainl_compression::tokenize_estimate(&self.content)
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every `SegmentKind` variant paired with its expected stable label.
    const LABELS: [(SegmentKind, &str); 8] = [
        (SegmentKind::SystemPrompt, "system_prompt"),
        (SegmentKind::OlderTurn, "older_turn"),
        (SegmentKind::RecentTurn, "recent_turn"),
        (SegmentKind::ToolDefinitions, "tool_definitions"),
        (SegmentKind::ToolResult, "tool_result"),
        (SegmentKind::UserPrompt, "user_prompt"),
        (SegmentKind::AnchoredSummaryRecall, "anchored_summary_recall"),
        (SegmentKind::MemoryBlock, "memory_block"),
    ];

    /// Tolerant comparison for the hard-coded `f32` importance weights.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-6
    }

    #[test]
    fn always_keep_classification() {
        assert!(SegmentKind::SystemPrompt.is_always_keep());
        assert!(SegmentKind::UserPrompt.is_always_keep());
        for kind in [
            SegmentKind::OlderTurn,
            SegmentKind::RecentTurn,
            SegmentKind::ToolDefinitions,
            SegmentKind::ToolResult,
            SegmentKind::AnchoredSummaryRecall,
            SegmentKind::MemoryBlock,
        ] {
            assert!(!kind.is_always_keep(), "{kind:?} must not be always-keep");
        }
    }

    #[test]
    fn segment_token_estimate_nonzero() {
        let s = Segment::user_prompt("Hello world this is a test");
        assert!(s.token_estimate() > 0);
    }

    #[test]
    fn segment_kind_label_stable() {
        for (kind, label) in LABELS {
            assert_eq!(kind.as_str(), label);
        }
    }

    #[test]
    fn segment_kind_labels_unique() {
        let unique: HashSet<&str> = LABELS.iter().map(|(kind, _)| kind.as_str()).collect();
        assert_eq!(unique.len(), LABELS.len());
    }

    #[test]
    fn constructors_set_kind_role_age_and_weight() {
        let user = Segment::user_prompt("q");
        assert_eq!((user.kind, user.role, user.age_index), (SegmentKind::UserPrompt, Role::User, 0));
        assert!(user.tool_name.is_none());
        assert!(approx_eq(user.base_importance, 2.0));

        let system = Segment::system_prompt("s");
        assert_eq!(
            (system.kind, system.role, system.age_index),
            (SegmentKind::SystemPrompt, Role::System, u32::MAX)
        );
        assert!(approx_eq(system.base_importance, 1.5));

        let defs = Segment::tool_definitions("[]");
        assert_eq!(
            (defs.kind, defs.role, defs.age_index),
            (SegmentKind::ToolDefinitions, Role::System, u32::MAX)
        );
        assert!(approx_eq(defs.base_importance, 1.2));

        let recent = Segment::recent_turn(Role::Assistant, "hi", 1);
        assert_eq!(
            (recent.kind, recent.role, recent.age_index),
            (SegmentKind::RecentTurn, Role::Assistant, 1)
        );
        assert!(approx_eq(recent.base_importance, 1.0));

        let older = Segment::older_turn(Role::User, "old", 7);
        assert_eq!(
            (older.kind, older.role, older.age_index),
            (SegmentKind::OlderTurn, Role::User, 7)
        );
        assert!(approx_eq(older.base_importance, 0.7));

        let tool = Segment::tool_result("grep", "out", 3);
        assert_eq!((tool.kind, tool.role, tool.age_index), (SegmentKind::ToolResult, Role::Tool, 3));
        assert_eq!(tool.tool_name.as_deref(), Some("grep"));
        assert!(approx_eq(tool.base_importance, 0.8));

        let mem = Segment::memory_block("profile", "likes tea");
        assert_eq!((mem.kind, mem.role, mem.age_index), (SegmentKind::MemoryBlock, Role::System, 0));
        assert_eq!(mem.tool_name.as_deref(), Some("profile"));
        assert_eq!(mem.content, "likes tea");
        assert!(approx_eq(mem.base_importance, 1.0));
    }
}