1use kameo::prelude::*;
6
7use super::events::{
8 CompressStrategy, NotifyCompressedForReply, RequestAppend, RequestClear,
9 RequestCompress, RequestCompressed, RequestEstimateTokens, RequestExtend, RequestFindByRole,
10 RequestImportIncremental, RequestGet, RequestImmutable, RequestIncremental, RequestInsert,
11 RequestIsEmpty, RequestLen, RequestMessages, RequestPop, RequestRemove, RequestRetain,
12 RequestSend, RequestSendStream, RequestSubscribeCompressed, RequestUnsubscribeCompressed,
13 RequestExportIncremental, RequestExportAll, RequestUpdate,
14};
15use super::stream::AgentSendStream;
16use super::types::ContextBackend;
17use crate::error::AgentError;
18
19type CompressSubscriberReply<M> = (Vec<M>, Vec<M>);
20type CompressSubscriberRecipient<M> = ReplyRecipient<NotifyCompressedForReply<M>, CompressSubscriberReply<M>>;
21
22use crate::message::ContextMessage;
23use crate::readonly::ReadOnly;
24
/// Actor state holding a conversation context for a [`ContextBackend`].
///
/// The full context, in order, is: the `immutable` prefix, then the
/// `compressed` summary segment, then the `incremental` segment of recent
/// messages. Index-based editing messages (update/insert/remove/pop/retain/
/// clear) operate on the `incremental` segment only.
#[derive(Actor)]
pub struct AgentContext<B: ContextBackend> {
    /// Backend used to send requests and to build, convert, estimate and
    /// (de)serialize messages.
    backend: B,
    /// Fixed prefix supplied at construction; wrapped in `ReadOnly` so the
    /// handlers only read it.
    immutable: ReadOnly<B::Message>,
    /// Summary messages produced by compression; replaced wholesale each time
    /// a compression completes.
    compressed: Vec<B::Message>,
    /// Recent, editable messages appended by the handlers and by replies.
    incremental: Vec<B::Message>,
    /// Optional subscriber asked after each compression; its reply becomes the
    /// new compressed and incremental segments.
    on_compressed: Option<CompressSubscriberRecipient<B::Message>>,
}
51
52impl<B: ContextBackend> AgentContext<B> {
53 pub fn new(backend: B, immutable: Vec<B::Message>) -> Self {
58 Self {
59 backend,
60 immutable: ReadOnly::from(immutable),
61 compressed: Vec::new(),
62 incremental: Vec::new(),
63 on_compressed: None,
64 }
65 }
66
67 fn default_summary_prompt() -> String {
68 "请将以下对话历史压缩为简洁摘要,保留关键信息、决策和上下文。输出一条 system 消息。"
69 .to_string()
70 }
71
72
73 async fn notify_compressed_subscriber(
75 &self,
76 summary: Vec<B::Message>,
77 kept: Vec<B::Message>,
78 ) -> Result<(Vec<B::Message>, Vec<B::Message>), AgentError> {
79 if let Some(subscriber) = &self.on_compressed {
80 subscriber
81 .ask(NotifyCompressedForReply {
82 summary,
83 kept,
84 })
85 .send()
86 .await
87 .map_err(|e| AgentError::Context(e.to_string()))
88 } else {
89 Ok((summary, kept))
90 }
91 }
92
93 async fn compress_if_full(&mut self, opts: &B::Opts) -> Result<(), AgentError> {
95 let common = opts.as_ref();
96 let mut all: Vec<B::Message> = self
97 .immutable
98 .iter()
99 .chain(self.compressed.iter())
100 .chain(self.incremental.iter())
101 .cloned()
102 .collect();
103 if let Some(ref scratch) = common.scratch {
104 all.push(self.backend.system_message(scratch.clone()));
105 }
106 let tokens = self
107 .backend
108 .estimate_tokens(&all)
109 .await
110 .unwrap_or(usize::MAX);
111 if tokens < common.context_window {
112 return Ok(());
113 }
114 if !common.auto_compress {
115 return Err(AgentError::Context("上下文已满且未启用自动压缩".into()));
116 }
117 let total = self.incremental.len();
118 let keep = total / 2;
119 if total <= keep {
120 return Ok(());
121 }
122 let split = total - keep;
123 let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
124 if to_summarize.is_empty() {
125 return Ok(());
126 }
127 let mut summary_messages =
128 vec![self.backend.system_message(Self::default_summary_prompt())];
129 summary_messages.append(&mut self.compressed);
130 summary_messages.extend(to_summarize);
131 let response = self.backend.send(&summary_messages, opts).await?;
132 let raw_msgs = self
133 .backend
134 .extract_messages(std::slice::from_ref(&response))?;
135 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
136 let summary: Vec<B::Message> = request_msgs
137 .into_iter()
138 .map(|msg| self.backend.to_system_message(msg))
139 .collect();
140 let kept: Vec<B::Message> = self.incremental.drain(..).collect();
141 let (final_summary, final_kept) =
142 self.notify_compressed_subscriber(summary, kept).await?;
143 self.compressed = final_summary;
144 self.incremental = final_kept;
145 Ok(())
146 }
147}
148
149impl<B: ContextBackend> Message<RequestAppend<B::Message>> for AgentContext<B> {
154 type Reply = ();
155
156 async fn handle(
157 &mut self,
158 msg: RequestAppend<B::Message>,
159 _ctx: &mut Context<Self, Self::Reply>,
160 ) -> Self::Reply {
161 self.incremental.push(msg.message);
162 }
163}
164
165impl<B: ContextBackend> Message<RequestExtend<B::Message>> for AgentContext<B> {
166 type Reply = ();
167
168 async fn handle(
169 &mut self,
170 msg: RequestExtend<B::Message>,
171 _ctx: &mut Context<Self, Self::Reply>,
172 ) -> Self::Reply {
173 for m in msg.messages {
174 self.incremental.push(m);
175 }
176 }
177}
178
179impl<B: ContextBackend> Message<RequestUpdate<B::Message>> for AgentContext<B> {
180 type Reply = Result<(), AgentError>;
181
182 async fn handle(
183 &mut self,
184 msg: RequestUpdate<B::Message>,
185 _ctx: &mut Context<Self, Self::Reply>,
186 ) -> Self::Reply {
187 if msg.index >= self.incremental.len() {
188 return Err(AgentError::Context("索引越界".into()));
189 }
190 self.incremental[msg.index] = msg.message;
191 Ok(())
192 }
193}
194
195impl<B: ContextBackend> Message<RequestInsert<B::Message>> for AgentContext<B> {
196 type Reply = Result<(), AgentError>;
197
198 async fn handle(
199 &mut self,
200 msg: RequestInsert<B::Message>,
201 _ctx: &mut Context<Self, Self::Reply>,
202 ) -> Self::Reply {
203 if msg.index > self.incremental.len() {
204 return Err(AgentError::Context("索引越界".into()));
205 }
206 self.incremental.insert(msg.index, msg.message);
207 Ok(())
208 }
209}
210
211impl<B: ContextBackend> Message<RequestRemove> for AgentContext<B> {
212 type Reply = Result<(), AgentError>;
213
214 async fn handle(
215 &mut self,
216 msg: RequestRemove,
217 _ctx: &mut Context<Self, Self::Reply>,
218 ) -> Self::Reply {
219 if msg.index >= self.incremental.len() {
220 return Err(AgentError::Context("索引越界".into()));
221 }
222 self.incremental.remove(msg.index);
223 Ok(())
224 }
225}
226
227impl<B: ContextBackend> Message<RequestPop> for AgentContext<B> {
228 type Reply = Option<B::Message>;
229
230 async fn handle(
231 &mut self,
232 _msg: RequestPop,
233 _ctx: &mut Context<Self, Self::Reply>,
234 ) -> Self::Reply {
235 self.incremental.pop()
236 }
237}
238
239impl<B: ContextBackend> Message<RequestRetain> for AgentContext<B> {
240 type Reply = ();
241
242 async fn handle(
243 &mut self,
244 msg: RequestRetain,
245 _ctx: &mut Context<Self, Self::Reply>,
246 ) -> Self::Reply {
247 let mut removed = Vec::new();
248 let role = msg.role;
249 self.incremental.retain(|m| {
250 if m.role() == role {
251 true
252 } else {
253 removed.push(m.clone());
254 false
255 }
256 });
257 }
258}
259
260impl<B: ContextBackend> Message<RequestClear> for AgentContext<B> {
261 type Reply = ();
262
263 async fn handle(
264 &mut self,
265 _msg: RequestClear,
266 _ctx: &mut Context<Self, Self::Reply>,
267 ) -> Self::Reply {
268 self.incremental.clear();
269 }
270}
271
272impl<B: ContextBackend> Message<RequestLen> for AgentContext<B> {
277 type Reply = usize;
278
279 async fn handle(
280 &mut self,
281 _msg: RequestLen,
282 _ctx: &mut Context<Self, Self::Reply>,
283 ) -> Self::Reply {
284 self.immutable.len() + self.compressed.len() + self.incremental.len()
285 }
286}
287
288impl<B: ContextBackend> Message<RequestIsEmpty> for AgentContext<B> {
289 type Reply = bool;
290
291 async fn handle(
292 &mut self,
293 _msg: RequestIsEmpty,
294 _ctx: &mut Context<Self, Self::Reply>,
295 ) -> Self::Reply {
296 self.immutable.is_empty() && self.compressed.is_empty() && self.incremental.is_empty()
297 }
298}
299
300impl<B: ContextBackend> Message<RequestGet> for AgentContext<B> {
301 type Reply = Option<B::Message>;
302
303 async fn handle(
304 &mut self,
305 msg: RequestGet,
306 _ctx: &mut Context<Self, Self::Reply>,
307 ) -> Self::Reply {
308 let idx = msg.0;
309 let imm_len = self.immutable.len();
310 let comp_len = self.compressed.len();
311 if idx < imm_len {
312 Some(self.immutable[idx].clone())
313 } else if idx < imm_len + comp_len {
314 Some(self.compressed[idx - imm_len].clone())
315 } else if idx < imm_len + comp_len + self.incremental.len() {
316 Some(self.incremental[idx - imm_len - comp_len].clone())
317 } else {
318 None
319 }
320 }
321}
322
323impl<B: ContextBackend> Message<RequestMessages> for AgentContext<B> {
324 type Reply = Vec<B::Message>;
325
326 async fn handle(
327 &mut self,
328 _msg: RequestMessages,
329 _ctx: &mut Context<Self, Self::Reply>,
330 ) -> Self::Reply {
331 self.immutable
332 .iter()
333 .chain(self.compressed.iter())
334 .chain(self.incremental.iter())
335 .cloned()
336 .collect()
337 }
338}
339
340impl<B: ContextBackend> Message<RequestImmutable> for AgentContext<B> {
341 type Reply = Vec<B::Message>;
342
343 async fn handle(
344 &mut self,
345 _msg: RequestImmutable,
346 _ctx: &mut Context<Self, Self::Reply>,
347 ) -> Self::Reply {
348 self.immutable.to_vec()
349 }
350}
351
352impl<B: ContextBackend> Message<RequestCompressed> for AgentContext<B> {
353 type Reply = Vec<B::Message>;
354
355 async fn handle(
356 &mut self,
357 _msg: RequestCompressed,
358 _ctx: &mut Context<Self, Self::Reply>,
359 ) -> Self::Reply {
360 self.compressed.clone()
361 }
362}
363
364impl<B: ContextBackend> Message<RequestIncremental> for AgentContext<B> {
365 type Reply = Vec<B::Message>;
366
367 async fn handle(
368 &mut self,
369 _msg: RequestIncremental,
370 _ctx: &mut Context<Self, Self::Reply>,
371 ) -> Self::Reply {
372 self.incremental.clone()
373 }
374}
375
376impl<B: ContextBackend> Message<RequestFindByRole> for AgentContext<B> {
377 type Reply = Vec<B::Message>;
378
379 async fn handle(
380 &mut self,
381 msg: RequestFindByRole,
382 _ctx: &mut Context<Self, Self::Reply>,
383 ) -> Self::Reply {
384 self.immutable
385 .iter()
386 .chain(self.compressed.iter())
387 .chain(self.incremental.iter())
388 .filter(|m| m.role() == msg.0)
389 .cloned()
390 .collect()
391 }
392}
393
394impl<B: ContextBackend> Message<RequestSend<B::Opts>> for AgentContext<B> {
399 type Reply = Result<B::Response, AgentError>;
400
401 async fn handle(
402 &mut self,
403 msg: RequestSend<B::Opts>,
404 _ctx: &mut Context<Self, Self::Reply>,
405 ) -> Self::Reply {
406 self.compress_if_full(&msg.opts).await?;
407 let scratch = msg.opts.as_ref().scratch.clone();
408 let mut all_messages: Vec<B::Message> = self
409 .immutable
410 .iter()
411 .chain(self.compressed.iter())
412 .chain(self.incremental.iter())
413 .cloned()
414 .collect();
415 if let Some(content) = scratch {
416 all_messages.push(self.backend.system_message(content));
417 }
418 let response = self.backend.send(&all_messages, &msg.opts).await?;
419 let raw_msgs = self
420 .backend
421 .extract_messages(std::slice::from_ref(&response))?;
422 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
423 for msg in &request_msgs {
424 self.incremental.push(msg.clone());
425 }
426 Ok(response)
427 }
428}
429
430impl<B: ContextBackend + Clone> Message<RequestSendStream<B::Opts>> for AgentContext<B> {
431 type Reply = Result<AgentSendStream<B>, AgentError>;
432
433 async fn handle(
434 &mut self,
435 msg: RequestSendStream<B::Opts>,
436 _ctx: &mut Context<Self, Self::Reply>,
437 ) -> Self::Reply {
438 self.compress_if_full(&msg.opts).await?;
439 let scratch = msg.opts.as_ref().scratch.clone();
440 let mut all_messages: Vec<B::Message> = self
441 .immutable
442 .iter()
443 .chain(self.compressed.iter())
444 .chain(self.incremental.iter())
445 .cloned()
446 .collect();
447 if let Some(content) = scratch {
448 all_messages.push(self.backend.system_message(content));
449 }
450 let stream = self.backend.send_stream(all_messages, msg.opts);
451 Ok(AgentSendStream::new(stream))
452 }
453}
454
455impl<B: ContextBackend> Message<RequestCompress<B::Opts>> for AgentContext<B> {
460 type Reply = Result<(), AgentError>;
461
462 async fn handle(
463 &mut self,
464 msg: RequestCompress<B::Opts>,
465 _ctx: &mut Context<Self, Self::Reply>,
466 ) -> Self::Reply {
467 match msg.strategy {
468 CompressStrategy::Summarize { keep, prompt } => {
469 let total = self.incremental.len();
470 if total > keep {
471 let split = total - keep;
472 let to_summarize: Vec<B::Message> = self.incremental.drain(..split).collect();
473 if !to_summarize.is_empty() {
474 let summary_prompt = prompt.unwrap_or_else(Self::default_summary_prompt);
475 let mut summary_messages =
476 vec![self.backend.system_message(summary_prompt)];
477 summary_messages.append(&mut self.compressed);
478 summary_messages.extend(to_summarize);
479 let response = self.backend.send(&summary_messages, &msg.opts).await?;
480 let raw_msgs = self
481 .backend
482 .extract_messages(std::slice::from_ref(&response))?;
483 let request_msgs = self.backend.to_request_messages(raw_msgs)?;
484 let summary: Vec<B::Message> = request_msgs
485 .into_iter()
486 .map(|msg| self.backend.to_system_message(msg))
487 .collect();
488 let kept: Vec<B::Message> = self.incremental.drain(..).collect();
489 let (final_summary, final_kept) =
490 self.notify_compressed_subscriber(summary, kept).await?;
491 self.compressed = final_summary;
492 self.incremental = final_kept;
493 }
494 }
495 Ok(())
496 }
497 }
498 }
499}
500
501
502impl<B: ContextBackend> Message<RequestSubscribeCompressed<B::Message>> for AgentContext<B> {
503 type Reply = ();
504
505 async fn handle(
506 &mut self,
507 msg: RequestSubscribeCompressed<B::Message>,
508 _ctx: &mut Context<Self, Self::Reply>,
509 ) -> Self::Reply {
510 self.on_compressed = Some(msg.recipient);
511 }
512}
513
514impl<B: ContextBackend> Message<RequestUnsubscribeCompressed> for AgentContext<B> {
515 type Reply = ();
516
517 async fn handle(
518 &mut self,
519 _msg: RequestUnsubscribeCompressed,
520 _ctx: &mut Context<Self, Self::Reply>,
521 ) -> Self::Reply {
522 self.on_compressed = None;
523 }
524}
525
526impl<B: ContextBackend> Message<RequestEstimateTokens> for AgentContext<B> {
531 type Reply = usize;
532
533 async fn handle(
534 &mut self,
535 _msg: RequestEstimateTokens,
536 _ctx: &mut Context<Self, Self::Reply>,
537 ) -> Self::Reply {
538 let all: Vec<B::Message> = self
539 .immutable
540 .iter()
541 .chain(self.compressed.iter())
542 .chain(self.incremental.iter())
543 .cloned()
544 .collect();
545 self.backend.estimate_tokens(&all).await.unwrap_or(0)
546 }
547}
548
549impl<B: ContextBackend> Message<RequestExportAll> for AgentContext<B> {
550 type Reply = Result<String, AgentError>;
551
552 async fn handle(
553 &mut self,
554 _msg: RequestExportAll,
555 _ctx: &mut Context<Self, Self::Reply>,
556 ) -> Self::Reply {
557 let lines: Vec<String> = self
558 .immutable
559 .iter()
560 .chain(self.compressed.iter())
561 .chain(self.incremental.iter())
562 .map(|m| self.backend.message_to_json(m))
563 .collect::<Result<_, _>>()?;
564 Ok(lines.join("\n"))
565 }
566}
567
568impl<B: ContextBackend> Message<RequestExportIncremental> for AgentContext<B> {
569 type Reply = Result<String, AgentError>;
570
571 async fn handle(
572 &mut self,
573 _msg: RequestExportIncremental,
574 _ctx: &mut Context<Self, Self::Reply>,
575 ) -> Self::Reply {
576 let lines: Vec<String> = self
577 .incremental
578 .iter()
579 .map(|m| self.backend.message_to_json(m))
580 .collect::<Result<_, _>>()?;
581 Ok(lines.join("\n"))
582 }
583}
584
585impl<B: ContextBackend> Message<RequestImportIncremental> for AgentContext<B> {
586 type Reply = Result<(), AgentError>;
587
588 async fn handle(
589 &mut self,
590 msg: RequestImportIncremental,
591 _ctx: &mut Context<Self, Self::Reply>,
592 ) -> Self::Reply {
593 let mut messages = Vec::new();
594 for line in msg.json.lines() {
595 let line = line.trim();
596 if line.is_empty() {
597 continue;
598 }
599 messages.push(self.backend.message_from_json(line)?);
600 }
601 self.incremental.clear();
602 self.incremental = messages.clone();
603 Ok(())
604 }
605}